import datetime
import json
import os
import re
import zlib
from collections import OrderedDict
from glob import glob
from multiprocessing.pool import ThreadPool
from os import environ, listdir, path
from os.path import isfile
from pathlib import Path
from queue import Queue
from tempfile import NamedTemporaryFile
from time import sleep
from typing import List, Optional, Tuple, Union
from urllib.parse import urljoin

import h5py
import ipykernel
import numpy as np
import requests
import yaml
import zmq
from extra_data import H5File, RunDirectory
from iCalibrationDB import ConstantMetaData, Versions
from notebook.notebookapp import list_running_servers

from .ana_tools import save_dict_to_hdf5


def parse_runs(runs, return_type=str):
    pruns = []
    if isinstance(runs, str):
        for rcomp in runs.split(","):
            if "-" in rcomp:
                start, end = rcomp.split("-")
                pruns += list(range(int(start), int(end)))
            else:
                pruns += [int(rcomp)]
    elif isinstance(runs, (list, tuple)):
        pruns = runs
    else:
        pruns = [runs, ]

    if return_type is str:
        return ["r{:04d}".format(r) for r in pruns]
    else:
        return pruns
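# Illustrative usage (a sketch; the run numbers below are hypothetical).
# Note that the upper bound of a "start-end" range is exclusive:
#   >>> parse_runs("9,11-13")
#   ['r0009', 'r0011', 'r0012']
#   >>> parse_runs("9,11-13", return_type=int)
#   [9, 11, 12]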


def run_prop_seq_from_path(filename):
    run = re.findall(r'.*r([0-9]{4}).*', filename)
    run = run[0] if len(run) else None
    proposal = re.findall(r'.*p([0-9]{6}).*', filename)
    proposal = proposal[0] if len(proposal) else None
    sequence = re.findall(r'.*S([0-9]{5}).*', filename)
    sequence = sequence[0] if len(sequence) else None
    return run, proposal, sequence
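# Illustrative usage (a sketch; the path below is hypothetical).
# The components are returned as strings, or None if not found:
#   >>> run_prop_seq_from_path(
#   ...     "/gpfs/exfel/exp/SPB/202201/p900123/raw/r0052/RAW-R0052-AGIPD00-S00000.h5")
#   ('0052', '900123', '00000')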


def map_seq_files(
    run_folder: Path,
    karabo_das: List[str],
    sequences: Optional[List[int]] = None,
) -> Tuple[dict, int]:

    """Glob run_folder and match the files based on the selected
    detectors and sequence numbers.

    Returns:
        Dict: karabo_da keys mapped to their corresponding sequence files.
        Int: total number of sequence files to process across all karabo_das.
    """

    if sequences == [-1]:
        sequences = None
    if sequences is not None:
        sequences = set(int(seq) for seq in sequences)

    seq_fn_pat = re.compile(r".*-(?P<da>.*?)-S(?P<seq>.*?)\.h5")

    mapped_files = {kda: [] for kda in karabo_das}
    total_files = 0

    for fn in run_folder.glob("*.h5"):
        if (match := seq_fn_pat.match(fn.name)) is not None:
            da = match.group("da")
            if da in mapped_files and (
                sequences is None or int(match.group("seq")) in sequences
            ):
                mapped_files[da].append(fn)
                total_files += 1

    # Return dict with sorted list of sequence files.
    for k in mapped_files:
        mapped_files[k].sort()

    return mapped_files, total_files
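# Illustrative usage (a sketch; the run path and aggregator names are hypothetical):
#   >>> mapped, total = map_seq_files(
#   ...     Path("/gpfs/.../raw/r0052"), ["AGIPD00", "AGIPD01"], sequences=[0, 1])
#   >>> sorted(mapped)    # one (possibly empty) file list per requested aggregator
#   ['AGIPD00', 'AGIPD01']
#   >>> total             # number of matched sequence files over all aggregators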


def map_modules_from_folder(in_folder, run, path_template, karabo_da,
                            sequences=None):
    """
    Prepare queues of files to process.
    Queues are stored in dictionary with module name Q{}M{} as a key

    :param in_folder: Input folder with raw data
    :param run: Run number
    :param path_template: Template for file name
                          e.g. `RAW-R{:04d}-{}-S{:05d}.h5`
    :param karabo_da: List of data aggregators e.g. [AGIPD00, AGIPD01]
    :param sequences: List of sequences to be considered
    :return: Dictionary of queues of files, dictionary of module indexes,
    total number of sequences, dictionary of number of sequences per module,
    and total size of all files in bytes
    """
    module_files = OrderedDict()
    mod_ids = OrderedDict()
    total_sequences = 0
    total_file_size = 0
    sequences_qm = {}
    for inset in karabo_da:
        module_idx = int(inset[-2:])
        name = module_index_to_qm(module_idx)
        module_files[name] = Queue()
        sequences_qm[name] = 0
        mod_ids[name] = module_idx
        if sequences is None:
            fname = path_template.format(run, inset, 0).replace("S00000", "S*")
            abs_fname = "{}/r{:04d}/{}".format(in_folder, run, fname)

            for filename in glob(abs_fname):
                module_files[name].put(filename)
                total_sequences += 1
                sequences_qm[name] += 1
                total_file_size += path.getsize(filename)
        else:
            for sequence in sequences:
                fname = path_template.format(run, inset, sequence)
                abs_fname = "{}/r{:04d}/{}".format(in_folder, run, fname)
                if not isfile(abs_fname):
                    continue

                module_files[name].put(abs_fname)
                total_sequences += 1
                sequences_qm[name] += 1
                total_file_size += path.getsize(abs_fname)

    return (module_files, mod_ids, total_sequences,
            sequences_qm, total_file_size)
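# Illustrative usage (a sketch; the input folder is hypothetical):
#   >>> files, ids, n_seq, seq_per_mod, size = map_modules_from_folder(
#   ...     "/gpfs/.../raw", 52, "RAW-R{:04d}-{}-S{:05d}.h5", ["AGIPD00", "AGIPD01"])
#   >>> dict(ids)          # module names mapped to module indices
#   {'Q1M1': 0, 'Q1M2': 1}
#   >>> files["Q1M1"]      # Queue of sequence files for module Q1M1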


def map_gain_stages(in_folder, runs, path_template, karabo_da, sequences=None):
    """
    Prepare queues of files to process.
    Queues are stored in dictionary with module name Q{}M{}
    and gain name as a keys
    :param in_folder: Input folder with raw data
    :param runs: Dictionary of runs with key naming the gain stages
    :param path_template: Template for file name
                          e.g. `RAW-R{:04d}-{}-S{:05d}.h5`
    :param karabo_da: List of data aggregators e.g. [AGIPD00, AGIPD01]
    :param sequences: List of sequences to be considered
    :return: Dictionary of queues of files,
    total number of sequences, and total file size in GB
    """
    total_sequences = 0
    total_file_size = 0
    gain_mapped_files = OrderedDict()
    for gain, run in runs.items():
        mapped_files, _, seq, _, fs = map_modules_from_folder(in_folder, run,
                                                              path_template,
                                                              karabo_da,
                                                              sequences)

        total_sequences += seq
        total_file_size += fs
        gain_mapped_files[gain] = mapped_files
    return gain_mapped_files, total_sequences, total_file_size / 1e9
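# Illustrative usage (a sketch; the runs and input folder are hypothetical):
#   >>> runs = {"high": 123, "med": 124, "low": 125}
#   >>> files, n_seq, size_gb = map_gain_stages(
#   ...     "/gpfs/.../raw", runs, "RAW-R{:04d}-{}-S{:05d}.h5", ["AGIPD00"])
#   >>> files["high"]["Q1M1"]   # Queue of high-gain sequence files for module Q1M1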


def map_modules_from_files(filelist, file_inset, quadrants, modules_per_quad):
    total_sequences = 0
    total_file_size = 0
    module_files = {}
    mod_ids = {}
    for quadrant in range(quadrants):
        for module in range(modules_per_quad):
            name = "Q{}M{}".format(quadrant + 1, module + 1)
            module_files[name] = Queue()
            num = quadrant * 4 + module
            mod_ids[name] = num
            file_infix = "{}{:02d}".format(file_inset, num)
            for file in filelist:
                if file_infix in file:
                    module_files[name].put(file)
                    total_sequences += 1
                    total_file_size += path.getsize(file)

    return module_files, mod_ids, total_sequences, total_file_size


def gain_map_files(in_folder, runs, sequences, file_inset, quadrants,
                   mods_per_quad):
    total_sequences = 0
    total_file_size = 0
    gain_mapped_files = OrderedDict()
    for gain, run in runs.items():
        ginfolder = "{}/{}".format(in_folder, run)
        dirlist = listdir(ginfolder)
        file_list = []
        for entry in dirlist:
            # Only consider .h5 files.
            abs_entry = "{}/{}".format(ginfolder, entry)
            if path.isfile(abs_entry) and path.splitext(abs_entry)[1] == ".h5":

                if sequences is None:
                    file_list.append(abs_entry)
                else:
                    for seq in sequences:
                        if "{:05d}.h5".format(seq) in abs_entry:
                            file_list.append(path.abspath(abs_entry))

        mapped_files, mod_ids, seq, fs = map_modules_from_files(file_list,
                                                                file_inset,
                                                                quadrants,
                                                                mods_per_quad)
        total_sequences += seq
        total_file_size += fs
        gain_mapped_files[gain] = mapped_files
    return gain_mapped_files, total_sequences, total_file_size / 1e9


def get_notebook_name():
    """
    Return the full path of the jupyter notebook.
    """
    try:
        kernel_id = re.search('kernel-(.*).json',
                              ipykernel.connect.get_connection_file()).group(1)
        servers = list_running_servers()
        for ss in servers:
            response = requests.get(urljoin(ss['url'], 'api/sessions'),
                                    params={'token': ss.get('token', '')})
            for nn in json.loads(response.text):
                if nn['kernel']['id'] == kernel_id:
                    return nn['notebook']['path']
    except Exception:
        return environ.get("CAL_NOTEBOOK_NAME", "Unknown Notebook")


def creation_date_file_metadata(
    run_folder: Path,
) -> Optional[datetime.datetime]:
    """Get run directory creation date from
    METADATA/CreationDate of the oldest file using EXtra-data.
    # TODO: update after DAQ store the same date as myMDC.

    :param dc: EXtra-data DataCollection for the run directory.
    :return Optional[datetime.datetime]: Run creation date.
    """
    md_dict = RunDirectory(run_folder).run_metadata()

    if md_dict["dataFormatVersion"] != "0.5":
        creation_dates = [
            H5File(f).run_metadata()["creationDate"]
            for f in run_folder.glob("*.h5")
        ]
        return datetime.datetime.strptime(
            min(creation_dates), "%Y%m%dT%H%M%S%z")
    else:
        print("WARNING: input files contains old datasets. "
              "No `METADATA/creationDate` to read.")


def creation_date_train_timestamp(
    dc: RunDirectory
) -> Optional[datetime.datetime]:
    """Get creation date from the timestamp of the first train.

    :param dc: EXtra-data DataCollection for the run directory.
    :return Optional[datetime.datetime]: Run creation date.
    """

    creation_date = np.datetime64(
        dc.select_trains(np.s_[0]).train_timestamps()[0], 'us').item()
    if creation_date is None:
        print("WARNING: input files contains old datasets without"
              " trains timestamps.")
        return None
    return creation_date.replace(tzinfo=datetime.timezone.utc)


def get_dir_creation_date(directory: Union[str, Path], run: int,
                          verbosity: int = 0) -> datetime.datetime:
    """Get the directory creation data based on 3 different methods.

    1) Return run start time from myMDC. (get_runtime_metadata_client)
    2) If myMDC connection is not set,
    get the date from the files metadata. (get_runtime_metadata_file)
    3) If data files are older than 2020 (dataformatversion == "0.5"),
    get the data from the oldest file's modified time.

    If the data is not available from either source,
    this function will raise a FileNotFoundError.

    :param directory: path to a directory which contains runs
    (e.g. /gpfs/exfel/data/exp/callab/202031/p900113/raw/).
    :param run: run number.
    :param verbosity: Level of verbosity (0 - silent)
    :return: creation datetime for the directory.

    """
    directory = Path(directory, f'r{run:04d}')

    # Validate the availability of the input folder.
    # And show a clear error message, if it was not found.
    try:
        dc = RunDirectory(directory)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            "- Failed to read creation time, wrong input folder",
            directory) from e

    cdate = creation_date_train_timestamp(dc)

    if cdate is not None:
        # Exposing the method used for reading the creation_date.
        print("Reading creation_date from input files metadata"
              " `INDEX/timestamp`")
    else:  # It's an older dataset.
        print("Reading creation_date from last modification data "
              "for the oldest input file.")
        cdate = datetime.datetime.fromtimestamp(
            sorted(
                directory.glob("*.h5"), key=path.getmtime,
            )[0].stat().st_mtime,
            tz=datetime.timezone.utc,
        )
    return cdate
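# Illustrative usage (a sketch; the proposal path is the one from the docstring above):
#   >>> get_dir_creation_date("/gpfs/exfel/data/exp/callab/202031/p900113/raw", 9)
#   # -> timezone-aware datetime.datetime for run r0009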


def calcat_creation_time(
    in_folder: Path,
    run: str,
    creation_time: Optional[str] = "",
    ) -> datetime.datetime:
    """Return the creation time to use with CALCAT."""
    # Run's creation time:
    if creation_time:
        creation_time = datetime.datetime.strptime(
            creation_time,
            '%Y-%m-%d %H:%M:%S').astimezone(tz=datetime.timezone.utc)
    else:
        creation_time = get_dir_creation_date(in_folder, run)
    return creation_time
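# Illustrative usage (a sketch; the folder and timestamp are hypothetical):
#   >>> calcat_creation_time(Path("/gpfs/.../raw"), 52, "2022-06-28 13:00:00")
#   # -> the given time parsed and converted to UTC
#   # With an empty creation_time, the date is read from the run directory instead.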


def _init_metadata(constant: 'iCalibrationDB.calibration_constant',
                   condition: 'iCalibrationDB.detector_conditions',
                   creation_time: Optional[datetime.datetime] = None
                   ) -> 'ConstantMetaData':

    """Initializing a ConstantMetaData class instance and
    add the correct creation time of the constant metadata.
    """
    metadata = ConstantMetaData()
    metadata.calibration_constant = constant
    metadata.detector_condition = condition
    if creation_time is None:
        metadata.calibration_constant_version = Versions.Now()
    else:
        metadata.calibration_constant_version = Versions.Timespan(
            start=creation_time)
    return metadata


def save_const_to_h5(db_module: str, karabo_id: str,
                     constant: 'iCalibrationDB.calibration_constant',
                     condition: 'iCalibrationDB.detector_conditions',
                     data: np.ndarray, file_loc: str,
                     report: str,
                     creation_time: datetime.datetime,
                     out_folder: str) -> 'ConstantMetaData':

    """ Save constant in h5 file with its metadata
    (e.g. db_module, condition, creation_time)

    :param db_module: database module (PDU/Physical Detector Unit).
    :param karabo_id: karabo identifier.
    :param constant: Calibration constant known for given detector.
    :param condition: Calibration condition.
    :param data: Constant data to save.
    :param file_loc: Location of raw data "proposal:{} runs:{} {} {}".
    :param report: Report path to store alongside the constant.
    :param creation_time: Creation time for the saved constant.
    :param out_folder: path to output folder.
    :return: metadata of the saved constant.
    """

    metadata = _init_metadata(constant, condition, creation_time)

    metadata.calibration_constant_version.raw_data_location = file_loc

    dpar = {
        parm.name: {
            'lower_deviation_value': parm.lower_deviation,
            'upper_deviation_value': parm.upper_deviation,
            'value': parm.value,
            'flg_logarithmic': parm.logarithmic,
        }
        for parm in metadata.detector_condition.parameters
    }

    creation_time = metadata.calibration_constant_version.begin_at
    raw_data = metadata.calibration_constant_version.raw_data_location
    constant_name = metadata.calibration_constant.__class__.__name__

    data_to_store = {
        'condition': dpar,
        'db_module': db_module,
        'karabo_id': karabo_id,
        'constant': constant_name,
        'data': data,
        'creation_time': creation_time,
        'file_loc': raw_data,
        'report': report,
    }

    ofile = f"{out_folder}/const_{constant_name}_{db_module}.h5"
    if isfile(ofile):
        print(f'File {ofile} already exists and will be overwritten')
    save_dict_to_hdf5(data_to_store, ofile)

    return metadata


def get_random_db_interface(cal_db_interface):
    """Return interface to calibration DB with
    random (with given range) port.
    """
    # Initialize the random generator with a random seed value,
    # in case the function was executed within a multiprocessing pool.
    np.random.seed()
    if "#" in cal_db_interface:
        prot, serv, ran = cal_db_interface.split(":")
        r1, r2 = ran.split("#")
        return ":".join(
            [prot, serv, str(np.random.randint(int(r1), int(r2)))])
    return cal_db_interface
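# Illustrative usage (a sketch; the host is the one used in docstrings in this module):
#   >>> get_random_db_interface("tcp://max-exfl016:8015#8025")
#   'tcp://max-exfl016:<port>'    # port drawn at random from [8015, 8025)
#   >>> get_random_db_interface("tcp://max-exfl016:8015")
#   'tcp://max-exfl016:8015'      # no '#' range, returned unchanged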


def get_report(out_folder: str, default_path: str = ""):
    """Get the report path from calibration_metadata.yml
    stored in the out_folder.
    """

    metadata = CalibrationMetadata(out_folder)
    report_path = metadata.get("report-path", default_path)
    if not report_path:
        print("WARNING: No report path will be injected "
              "with the constants.\n")
    return report_path


def get_pdu_from_db(karabo_id: str, karabo_da: Union[str, list],
                    constant: 'iCalibrationDB.calibration_constant',
                    condition: 'iCalibrationDB.detector_conditions',
                    cal_db_interface: str,
                    snapshot_at: Optional[datetime.datetime] = None,
                    timeout: int = 30000) -> List[str]:

    """Return all physical detector units for a
    karabo_id and list of karabo_da

    :param karabo_id: Karabo identifier.
    :param karabo_da: Karabo data aggregator.
    :param constant: Calibration constant object to
                     initialize CalibrationConstantMetadata class.
    :param condition: Detector condition object to
                      initialize CalibrationConstantMetadata class.
    :param cal_db_interface: Interface string, e.g. "tcp://max-exfl016:8015".
    :param snapshot_at: Database snapshot.
    :param timeout: Calibration Database timeout.
    :return: List of physical detector units (db_modules)
    """
    if not isinstance(karabo_da, (str, list)):
        raise TypeError("karabo_da should either be a list of multiple "
                        "karabo_da or a string of one karabo_da or 'all'")

    metadata = _init_metadata(constant, condition, None)

    # CalibrationDBRemote expects a string.
    if snapshot_at is not None and hasattr(snapshot_at, 'isoformat'):
        snapshot_at = snapshot_at.isoformat()

    # A random port is chosen if the address contains a '#' port range.
    db_interface = get_random_db_interface(cal_db_interface)

    pdu_dicts = metadata.retrieve_pdus_for_detector(receiver=db_interface,
                                                    karabo_id=karabo_id,
                                                    snapshot_at=snapshot_at,
                                                    timeout=timeout)
    # Get a list of pdus based on requested karabo_das
    if karabo_da == 'all':
        db_modules = [d["pdu_physical_name"] for d in pdu_dicts]
    else:
        k_indices = []
        if isinstance(karabo_da, str):
            karabo_da = [karabo_da]
        # Get indices of dict with the right karabo_da,
        # else use None.
        for k in karabo_da:
            pdu_found = False
            for i, d in enumerate(pdu_dicts):
                if d["karabo_da"] == k:
                    k_indices.append(i)
                    pdu_found = True
                    break
            if not pdu_found:
                k_indices.append(None)

        db_modules = []
        for i in k_indices:
            if i is None:
                db_modules.append(None)
            else:
                db_modules.append(pdu_dicts[i]["pdu_physical_name"])

    return db_modules


already_printed = {}


def get_from_db(karabo_id: str, karabo_da: str,
                constant: 'iCalibrationDB.calibration_constant',
                condition: 'iCalibrationDB.detector_conditions',
                empty_constant: np.ndarray,
                cal_db_interface: str,
                creation_time: Optional[datetime.datetime] = None,
                verbosity: int = 1,
                timeout: int = 30000,
                ntries: int = 7,
                meta_only: bool = True,
                load_data: bool = True,
                version_info: bool = False,
                doraise: bool = False,
                strategy: str = "pdu_closest_by_time"
                ) -> Tuple[np.ndarray, 'ConstantMetaData']:

    """Return calibration constants and metadata requested from CalDB

    This function uses the karabo_id and karabo_da to retrieve the
    desired CCV.

    :param karabo_id: karabo identifier (detector identifier).
    :param karabo_da: karabo data aggregator.
    :param constant: Calibration constant known for given detector.
    :param condition: Calibration condition.
    :param empty_constant: Constant to be returned in case of failure.
    :param cal_db_interface: Interface string, e.g. "tcp://max-exfl016:8015"
    :param creation_time: Latest time for constant to be created.
    :param verbosity: Level of verbosity (0 - silent)
    :param timeout: Timeout for the zmq request.
    :param ntries: Number of tries to contact the database. The default of 7
        means that, with an initial timeout of 30 s doubled on every retry,
        the total waiting time is roughly one hour.
    :param meta_only: Retrieve only metadata via ZMQ. Constants are
        taken directly from the h5 file on Maxwell.
    :param load_data: If True and meta_only is set, load the constant data
        from the file referenced by the retrieved metadata.
    :param version_info: Flag to show the info for the retrieved Constant.
    :param doraise: if True raise errors during communication with DB.
    :param strategy: Retrieving strategy for calibrationDBRemote.
    :return: Calibration constant, metadata.
    """

    if version_info:
        meta_only = False

    metadata = _init_metadata(constant, condition, creation_time)

    if karabo_id and karabo_da:
        when = None

        if creation_time is not None and hasattr(creation_time, 'isoformat'):
            when = creation_time.isoformat()

        metadata.calibration_constant_version.karabo_id = karabo_id
        metadata.calibration_constant_version.karabo_da = karabo_da

        # Make sure to remove the device name from the metadata before
        # retrieving, so that only karabo_id and karabo_da are used
        # during retrieval; device_name could have been set by an
        # earlier retrieval from iCalibrationDB.
        metadata.calibration_constant_version.device_name = None

        while ntries > 0:
            this_interface = get_random_db_interface(cal_db_interface)
            try:
                r = metadata.retrieve(this_interface, timeout=timeout,
                                      when=when, meta_only=meta_only,
                                      version_info=version_info,
                                      strategy=strategy)
                if version_info:
                    return r
                break
            except zmq.error.Again:
                ntries -= 1
                timeout *= 2
                sleep(np.random.randint(30))
                # TODO: reevaluate the need for doraise
                # and remove if not needed.
                if ntries == 0 and doraise:
                    raise
            except Exception as e:
                if verbosity > 0:
                    print(e)
                if 'missing_token' in str(e):
                    ntries -= 1
                else:
                    ntries = 0
                if ntries == 0 and doraise:
                    raise RuntimeError(f'{e}')

        if ntries > 0:
            mdata_const = metadata.calibration_constant_version
            if load_data and meta_only:
                hdf5path = getattr(mdata_const, 'hdf5path', None)
                filename = getattr(mdata_const, 'filename', None)
                h5path = getattr(mdata_const, 'h5path', None)
                if not (hdf5path and filename and h5path):
                    raise ValueError(
                        "Wrong metadata received to access the constant data."
                        f" Retrieved constant filepath is {hdf5path}/{filename}"  # noqa
                        f" and data_set_name is {h5path}."
                    )
                with h5py.File(Path(hdf5path, filename), "r") as f:
                    metadata.calibration_constant.data = f[f"{h5path}/data"][()]  # noqa
                    # The variant attribute is missing for old constants.
                    if "variant" in f[h5path].attrs.keys():
                        metadata.calibration_constant_version.variant = f[h5path].attrs["variant"]  # noqa

            if verbosity > 0:
                if constant.name not in already_printed or verbosity > 1:
                    already_printed[constant.name] = True
                    # TODO: Reset mdata_const.begin_at
                    # if comm_db_success is False.
                    begin_at = mdata_const.begin_at
                    print(f"Retrieved {constant.name} "
                          f"with creation time: {begin_at}")
            return constant.data, metadata
        else:
            return empty_constant, metadata
    else:
        return empty_constant, None
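# Illustrative usage (a sketch only; the condition/constant classes and detector
# values below are assumptions about the iCalibrationDB API, not taken from this file):
#   >>> from iCalibrationDB import Conditions, Constants
#   >>> cond = Conditions.Dark.AGIPD(memory_cells=352, bias_voltage=300)
#   >>> const = Constants.AGIPD.Offset()
#   >>> data, md = get_from_db(
#   ...     "SPB_DET_AGIPD1M-1", "AGIPD00", const, cond,
#   ...     empty_constant=None, cal_db_interface="tcp://max-exfl016:8015",
#   ...     creation_time=None, meta_only=True)   # normally a run creation datetime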


def send_to_db(db_module: str, karabo_id: str, constant, condition,
               file_loc: str, report_path: str, cal_db_interface: str,
               creation_time: Optional[datetime.datetime] = None,
               timeout: int = 30000,
               ntries: int = 7,
               doraise: bool = False,
               variant: int = 0):
    """Send new calibration constants and metadata requested to CalDB

    :param db_module: database module (PDU/Physical Detector Unit)
    :param karabo_id: karabo identifier
    :param constant: Calibration constant known for given detector
    :param condition: Calibration condition
    :param file_loc: Location of raw data.
    :param report_path: xfel-calibrate report path to inject along with
        the calibration constant versions to the database.
    :param cal_db_interface: Interface string, e.g. "tcp://max-exfl016:8015"
    :param creation_time: Latest time for constant to be created
    :param timeout: Timeout for zmq request
    :param ntries: number of tries to contact the database,
        ntries is set to 7 so that if the timeout started
        at 30s last timeout will be ~ 1h.
    :param doraise: if True raise errors during communication with DB
    :param variant: A calibration constant version variant attribute
        for the constant file.
    """

    success = False
    snapshot_at = None
    metadata = _init_metadata(constant, condition, creation_time)

    if db_module:

        # Add injected constant's file source info as a file location
        metadata.calibration_constant_version.raw_data_location = file_loc

        if report_path:
            # calibration_client expects the injected report path to be
            # a string of at least 2 characters.
            if not isinstance(report_path, str) or len(report_path) < 2:
                raise TypeError(
                    "\"report_path\" needs to be a string "
                    "of at least 2 characters."
                )

            report = {"name": path.basename(report_path),
                      "file_path": report_path}
            metadata.calibration_constant_version.report_path = report

        metadata.calibration_constant_version.karabo_id = karabo_id
        metadata.calibration_constant_version.device_name = db_module
        metadata.calibration_constant_version.karabo_da = None
        metadata.calibration_constant_version.raw_data_location = file_loc
        metadata.calibration_constant_version.variant = variant
        if constant.data is None:
            raise ValueError(
                "There is no data available to "
                "inject to the database."
            )
        while ntries > 0:

            this_interface = get_random_db_interface(cal_db_interface)

            if (
                creation_time is not None and
                hasattr(creation_time, 'isoformat')
            ):
                # This snapshot will be used only while retrieving
                # the correct PDU and appending its UUID.
                snapshot_at = creation_time.isoformat()

            try:
                metadata.send(
                    this_interface,
                    snapshot_at=snapshot_at,
                    timeout=timeout,
                )
                success = True  # TODO: use comm_db_success
                break
            except zmq.error.Again:
                ntries -= 1
                timeout *= 2
                sleep(np.random.randint(30))
                if ntries == 0 and doraise:
                    raise
            except Exception as e:
                # TODO: refactor to use custom exception class
                # Refactor error message for re-injecting an
                # identical CCV to the database.
                if all(s in str(e) for s in [
                    "Error creating calibration_constant_version",
                    "has already been taken",
                ]):
                    print(
                        f"WARNING: {constant.name} for {db_module}"
                        " has already been injected with the same "
                        "parameter conditions."
                    )
                else:
                    print(f"{e}\n")

                if 'missing_token' in str(e):
                    ntries -= 1
                else:
                    ntries = 0
                if ntries == 0 and doraise:
                    raise RuntimeError(f'{e}')

        if success:
            print(
                f"{constant.name} for {db_module} "
                "is injected with creation-time: "
                f"{metadata.calibration_constant_version.begin_at}."
            )
    return metadata


def get_constant_from_db(karabo_id: str, karabo_da: str,
                         constant, condition, empty_constant,
                         cal_db_interface: str, creation_time=None,
                         print_once=True, timeout=30000, ntries=120,
                         meta_only=True):
    """Return calibration constants requested from CalDB
    """
    data, _ = get_from_db(karabo_id, karabo_da, constant,
                          condition, empty_constant,
                          cal_db_interface, creation_time,
                          int(print_once), timeout, ntries, meta_only)
    return data


def get_constant_from_db_and_time(karabo_id: str, karabo_da: str,
                                  constant, condition, empty_constant,
                                  cal_db_interface: str, creation_time=None,
                                  print_once=True, timeout=30000, ntries=120):
    """Return calibration constants requested from CalDB,
    alongside injection time
    """
    data, m = get_from_db(karabo_id, karabo_da, constant,
                          condition, empty_constant,
                          cal_db_interface, creation_time,
                          int(print_once), timeout, ntries)
    if m and m.comm_db_success:
        return data, m.calibration_constant_version.begin_at
    else:
        # return None for injection time if communication with db failed.
        # reasons (no constant or condition found,
        # or network problem)
        return data, None


def module_index_to_qm(index: int, total_modules: int = 16):
    """Maps module index (0-indexed) to quadrant + module string (1-indexed)"""
    assert index < total_modules, \
        f'Module index {index} is out of range for {total_modules} modules'
    modules_per_quad = total_modules // 4
    quad, mod = divmod(index, modules_per_quad)
    return f"Q{quad+1}M{mod+1}"


def recursive_update(target: dict, source: dict):
    """Recursively merge source into target, checking for conflicts

    Conflicting entries will not be copied to target. Returns True if any
    conflicts were found.
    """
    conflict = False
    for k, v2 in source.items():
        v1 = target.get(k, None)
        if isinstance(v1, dict) and isinstance(v2, dict):
            conflict = recursive_update(v1, v2) or conflict
        elif (v1 is not None) and (v1 != v2):
            conflict = True
        else:
            target[k] = v2

    return conflict
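# Illustrative usage:
#   >>> target = {"a": 1, "b": {"c": 2}}
#   >>> recursive_update(target, {"a": 1, "b": {"d": 3}})
#   False
#   >>> target
#   {'a': 1, 'b': {'c': 2, 'd': 3}}
#   >>> recursive_update(target, {"a": 2})   # conflicting value, left unchanged
#   True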

class CalibrationMetadata(dict):
    """Convenience class: dictionary stored in metadata YAML file

    If metadata file already exists, it will be loaded (this may override
    additional constructor parameters given to this class). Use new=True to
    skip loading it.
    """

    def __init__(self, output_dir: Union[Path, str], *args, new=False):
        dict.__init__(self, args)
        self._yaml_fn = Path(output_dir) / "calibration_metadata.yml"
        if self._yaml_fn.exists():
            if new:
                # TODO: update after resolving this discussion
                # https://git.xfel.eu/detectors/pycalibration/-/merge_requests/624  # noqa
                self.save()
            else:
                with self._yaml_fn.open("r") as fd:
                    data = yaml.safe_load(fd)
                if isinstance(data, dict):
                    self.update(data)
                else:
                    print(f"Warning: existing {self._yaml_fn} is malformed, "
                           "will be overwritten")
    @property
    def filename(self):
        return self._yaml_fn

    def save(self):
        with self._yaml_fn.open("w") as fd:
            yaml.safe_dump(dict(self), fd)

    def save_copy(self, copy_dir: Path):
        with (copy_dir / self._yaml_fn.name).open("w") as fd:
            yaml.safe_dump(dict(self), fd)

    def add_fragment(self, data: dict):
        """Save metadata to a separate 'fragment' file to be merged later

        Avoids a risk of corrupting the main file by writing in parallel.
        """
        prefix = f"metadata_frag_j{os.environ.get('SLURM_JOB_ID', '')}_"
        with NamedTemporaryFile("w", dir=self._yaml_fn.parent,
                    prefix=prefix, suffix='.yml', delete=False) as fd:
            yaml.safe_dump(data, fd)

    def gather_fragments(self):
        """Merge in fragments saved by add_fragment(), then delete them"""
        frag_files = list(self._yaml_fn.parent.glob('metadata_frag_*.yml'))
        to_delete = []
        for fn in frag_files:
            with fn.open("r") as fd:
                data = yaml.safe_load(fd)
                if recursive_update(self, data):
                    print(f"{fn} contained conflicting metadata. "
                          f"This file will be left for debugging")
                else:
                    to_delete.append(fn)

        self.save()

        for fn in to_delete:
            fn.unlink()
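# Illustrative usage (a sketch; the paths are hypothetical):
#   >>> md = CalibrationMetadata("/gpfs/.../out_folder")
#   >>> md["report-path"] = "/gpfs/.../report.pdf"
#   >>> md.save()                                       # writes calibration_metadata.yml
#   >>> md.add_fragment({"retrieved-constants": {}})    # per-job fragment file
#   >>> md.gather_fragments()                           # merge fragments and delete them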


def save_constant_metadata(
    retrieved_constants: dict,
    mdata: ConstantMetaData,
    constant_name: str,
    ):
    """Save constant metadata to the input meta data dictionary.
    The constant's metadata stored are file path, dataset name,
    creation time, and physical detector unit name.

    :param retrieved_constants: A dictionary to store the metadata for
    the retrieved constant.
    :param mdata: A ConstantMetaData object after retrieving trying
    to retrieve a constant with get_from_db().
    :param constant_name: String for constant name to be used as a key.
    :param constants_key: The key name when all constants metadata
    will be stored.
    """

    mdata_const = mdata.calibration_constant_version
    const_mdata = retrieved_constants[constant_name] = dict()
    # check if constant was successfully retrieved.
    if mdata.comm_db_success:
        const_mdata["file-path"] = (
            f"{mdata_const.hdf5path}" f"{mdata_const.filename}"
        )
        const_mdata["dataset-name"] = mdata_const.h5path
        const_mdata["creation-time"] = mdata_const.begin_at
    else:
        const_mdata["file-path"] = None
        const_mdata["creation-time"] = None


def load_specified_constants(
    retrieved_constants: dict,
    empty_constants: Optional[dict] = None,
    ) -> Tuple[dict, dict]:
    """Load constant data from metadata in the
    retrieved_constants dictionary.

    :param retrieved_constants: A dict with the constant file paths and
      dataset names used to read the constant data arrays.
      {
        'Constant Name': {
            'file-path': '/gpfs/.../*.h5',
            'dataset-name': '/module_name/...',
            'creation-time': str(datetime),},
        }
    :param empty_constants: A dict keyed by constant name, giving
      the empty constant array to use for constants that were not retrieved.
    :return: A tuple of two dicts keyed by constant name: the constant data,
      and the creation time of each retrieved constant (None if missing).
    """
    const_data = dict()
    when = dict()

    for cname, mdata in retrieved_constants.items():
        const_data[cname] = dict()
        when[cname] = mdata["creation-time"]
        if when[cname]:
            with h5py.File(mdata["file-path"], "r") as cf:
                const_data[cname] = np.copy(
                    cf[f"{mdata['dataset-name']}/data"])
        else:
            const_data[cname] = (
                empty_constants[cname] if empty_constants else None)
    return const_data, when


def write_constants_fragment(
        out_folder: Path,
        det_metadata: dict,
        caldb_root: Path,
):
    """Record calibration constants metadata to a fragment file.

    Args:
        out_folder (Path): The output folder to store the fragment file.
        det_metadata (dict): A dictionary with the desired detector metadata.
            {karabo_da: {constant_name: metadata}}
        caldb_root (Path): The calibration database root path for constant files.
    """
    metadata = {"retrieved-constants": {}}
    for karabo_da, const_metadata in det_metadata.items():
        mod_metadata = {}
        mod_metadata["constants"] = {
            cname: {
                "path": str(caldb_root / ccv_metadata["path"]),
                "dataset": ccv_metadata["dataset"],
                "creation-time": ccv_metadata["begin_validity_at"],
                "ccv_id": ccv_metadata["ccv_id"],
            } for cname, ccv_metadata in const_metadata.items()
        }
        mod_metadata["physical-name"] = list(
                const_metadata.values())[0]["physical_name"]
        metadata["retrieved-constants"][karabo_da] = mod_metadata
    CalibrationMetadata(out_folder).add_fragment(metadata)


def write_compressed_frames(
        arr: np.ndarray,
        ofile: h5py.File,
        dataset_path: str,
        comp_threads: int = 1):
    """Compress gain/mask frames in multiple threads, and save their data

    This is significantly faster than letting HDF5 do the compression
    in a single thread.
    """

    def _compress_frame(idx):
        # Equivalent to the HDF5 'shuffle' filter: transpose bytes for better
        # compression.
        shuffled = np.ascontiguousarray(
            arr[idx].view(np.uint8).reshape((-1, arr.itemsize)).transpose()
        )
        return idx, zlib.compress(shuffled, level=1)

    # gain/mask compressed with gzip level 1, but not
    # checksummed as we would have to implement this.
    dataset = ofile.create_dataset(
        dataset_path,
        shape=arr.shape,
        chunks=((1,) + arr.shape[1:]),
        compression="gzip",
        compression_opts=1,
        shuffle=True,
        dtype=arr.dtype,
    )

    with ThreadPool(comp_threads) as pool:
        for i, compressed in pool.imap(_compress_frame, range(len(arr))):
            # Each frame is 1 complete chunk
            chunk_start = (i,) + (0,) * (dataset.ndim - 1)
            dataset.id.write_direct_chunk(chunk_start, compressed)

    return dataset
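# Illustrative usage (a sketch; the file name and dataset path are hypothetical):
#   >>> gain = np.zeros((100, 512, 128), dtype=np.uint8)
#   >>> with h5py.File("corrected.h5", "w") as f:
#   ...     write_compressed_frames(gain, f, "INSTRUMENT/DET/image/gain", comp_threads=8)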


def reorder_axes(a, from_order, to_order):
    """Rearrange axes of array a from from_order to to_order

    This does the same as np.transpose(), but making the before & after axes
    more explicit. from_order is a sequence of strings labelling the axes of a,
    and to_order is a similar sequence for the axes of the result.
    """
    assert len(from_order) == a.ndim
    assert sorted(from_order) == sorted(to_order)
    from_order = list(from_order)
    order = tuple([from_order.index(lbl) for lbl in to_order])
    return a.transpose(order)
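# Illustrative usage:
#   >>> a = np.zeros((16, 512, 128, 3))
#   >>> reorder_axes(a, ('modules', 'slow_scan', 'fast_scan', 'gain'),
#   ...              ('gain', 'modules', 'fast_scan', 'slow_scan')).shape
#   (3, 16, 128, 512)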