import datetime
import json
import re
import zlib
from collections import OrderedDict
from glob import glob
from multiprocessing.pool import ThreadPool
from os import environ, listdir, path
from os.path import isfile
from pathlib import Path
from queue import Queue
from time import sleep
from typing import List, Optional, Tuple, Union
from urllib.parse import urljoin

import dateutil.parser
import h5py
import ipykernel
import numpy as np
import requests
import yaml
import zmq
from iCalibrationDB import ConstantMetaData, Versions
from metadata_client.metadata_client import MetadataClient
from notebook.notebookapp import list_running_servers

from .ana_tools import save_dict_to_hdf5
from .mdc_config import mdc_config


def parse_runs(runs, return_type=str):
    """Parse a comma-separated run specification into a list of runs.

    Spans like "7-9" are expanded end-exclusively; if return_type is
    str, runs are formatted as e.g. "r0007".
    """
    pruns = []
    if isinstance(runs, str):
        for rcomp in runs.split(","):
            if "-" in rcomp:
                start, end = rcomp.split("-")
                pruns += list(range(int(start), int(end)))
            else:
                pruns += [int(rcomp)]
    elif isinstance(runs, (list, tuple)):
        pruns = runs
    else:
        pruns = [runs, ]

    if return_type is str:
        return ["r{:04d}".format(r) for r in pruns]
    else:
        return pruns
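
# Example (sketch): expanding a mixed run specification. Note that a
# span like "7-9" is end-exclusive, yielding runs 7 and 8 only:
#     parse_runs("5,7-9")                  # -> ['r0005', 'r0007', 'r0008']
#     parse_runs([5, 7], return_type=int)  # -> [5, 7]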


def run_prop_seq_from_path(filename):
    """Extract run, proposal and sequence numbers from a file path.

    Each is returned as the zero-padded string found in the path,
    or None if not present.
    """
    run = re.findall(r'.*r([0-9]{4}).*', filename)
    run = run[0] if len(run) else None
    proposal = re.findall(r'.*p([0-9]{6}).*', filename)
    proposal = proposal[0] if len(proposal) else None
    sequence = re.findall(r'.*S([0-9]{5}).*', filename)
    sequence = sequence[0] if len(sequence) else None
    return run, proposal, sequence
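
# Example (hypothetical path): each component comes back as the
# zero-padded string matched in the path, or None if absent:
#     run_prop_seq_from_path(
#         "/gpfs/exfel/exp/SPB/202101/p002827/raw/r0009/"
#         "RAW-R0009-AGIPD00-S00002.h5")
#     # -> ('0009', '002827', '00002')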


def map_seq_files(
    run_dc: "extra_data.DataCollection",
    karabo_id: str,
    karabo_das: List[str],
    sequences: Optional[List[int]] = None,
) -> Tuple[dict, int]:

    """Using a DataCollection from extra-data to read
    available sequence files.

    Returns:
        Dict: with karabo_das keys and the corresponding sequence files.
        Int: for number of all sequence files for all karabo_das to process.
    """

    if sequences == [-1]:
        sequences = None
    if sequences is not None:
        sequences = set(int(seq) for seq in sequences)

    seq_fn_pat = re.compile(r".*-(?P<da>.*?)-S(?P<seq>.*?)\.h5")

    mapped_files = {kda: [] for kda in karabo_das}
    total_files = 0
    for fn in run_dc.select(f"*{karabo_id}*").files:
        fn = Path(fn.filename)
        if (match := seq_fn_pat.match(fn.name)) is not None:
            da = match.group("da")
            if da in mapped_files and (
                sequences is None or int(match.group("seq")) in sequences
            ):
                mapped_files[da].append(fn)
                total_files += 1

    return mapped_files, total_files
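
# Example (sketch; hypothetical run, requires extra-data):
#     from extra_data import RunDirectory
#     run_dc = RunDirectory("/gpfs/exfel/exp/SPB/202101/p002827/raw/r0009")
#     mapped_files, total_files = map_seq_files(
#         run_dc, "SPB_DET_AGIPD1M-1", ["AGIPD00", "AGIPD01"],
#         sequences=[0, 1])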


def map_modules_from_folder(in_folder, run, path_template, karabo_da,
                            sequences=None):
    """
    Prepare queues of files to process.
    Queues are stored in dictionary with module name Q{}M{} as a key

    :param in_folder: Input folder with raw data
    :param run: Run number
    :param path_template: Template for file name
                          e.g. `RAW-R{:04d}-{}-S{:05d}.h5`
    :param karabo_da: List of data aggregators e.g. [AGIPD00, AGIPD01]
    :param sequences: List of sequences to be considered
    :return: Dictionary of queues of files, dictionary of module indexes,
    total number of sequences, dictionary of number of sequences per module,
    total size of all files in bytes
    """
    module_files = OrderedDict()
    mod_ids = OrderedDict()
    total_sequences = 0
    total_file_size = 0
    sequences_qm = {}
    for inset in karabo_da:
        module_idx = int(inset[-2:])
        name = module_index_to_qm(module_idx)
        module_files[name] = Queue()
        sequences_qm[name] = 0
        mod_ids[name] = module_idx
        if sequences is None:
            fname = path_template.format(run, inset, 0).replace("S00000", "S*")
            abs_fname = "{}/r{:04d}/{}".format(in_folder, run, fname)

            for filename in glob(abs_fname):
                module_files[name].put(filename)
                total_sequences += 1
                sequences_qm[name] += 1
                total_file_size += path.getsize(filename)
        else:
            for sequence in sequences:
                fname = path_template.format(run, inset, sequence)
                abs_fname = "{}/r{:04d}/{}".format(in_folder, run, fname)
                if not isfile(abs_fname):
                    continue

                module_files[name].put(abs_fname)
                total_sequences += 1
                sequences_qm[name] += 1
                total_file_size += path.getsize(abs_fname)

    return (module_files, mod_ids, total_sequences,
            sequences_qm, total_file_size)
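
# Example (hypothetical folder; matching files must exist on disk):
#     (mapped_files, mod_ids, total_sequences,
#      sequences_qm, total_file_size) = map_modules_from_folder(
#         "/gpfs/exfel/exp/SPB/202101/p002827/raw", 9,
#         "RAW-R{:04d}-{}-S{:05d}.h5", ["AGIPD00", "AGIPD01"])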


def map_gain_stages(in_folder, runs, path_template, karabo_da, sequences=None):
    """
    Prepare queues of files to process.
    Queues are stored in dictionary with module name Q{}M{}
    and gain name as a keys
    :param in_folder: Input folder with raw data
    :param runs: Dictionary of runs with key naming the gain stages
    :param path_template: Template for file name
                          e.g. `RAW-R{:04d}-{}-S{:05d}.h5`
    :param karabo_da: List of data aggregators e.g. [AGIPD00, AGIPD01]
    :param sequences: List of sequences to be considered
    :return: Dictionary of queues of files,
    total number of sequences
    """
    total_sequences = 0
    total_file_size = 0
    gain_mapped_files = OrderedDict()
    for gain, run in runs.items():
        mapped_files, _, seq, _, fs = map_modules_from_folder(in_folder, run,
                                                              path_template,
                                                              karabo_da,
                                                              sequences)

        total_sequences += seq
        total_file_size += fs
        gain_mapped_files[gain] = mapped_files
    return gain_mapped_files, total_sequences, total_file_size / 1e9
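
# Example (hypothetical runs, keyed by gain stage):
#     gain_mapped_files, total_sequences, size_gb = map_gain_stages(
#         "/gpfs/exfel/exp/SPB/202101/p002827/raw",
#         {"high": 9, "med": 10, "low": 11},
#         "RAW-R{:04d}-{}-S{:05d}.h5", ["AGIPD00", "AGIPD01"])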


def map_modules_from_files(filelist, file_inset, quadrants, modules_per_quad):
    """Map module file queues from an explicit list of files."""
    total_sequences = 0
    total_file_size = 0
    module_files = {}
    mod_ids = {}
    for quadrant in range(quadrants):
        for module in range(modules_per_quad):
            name = "Q{}M{}".format(quadrant + 1, module + 1)
            module_files[name] = Queue()
            num = quadrant * modules_per_quad + module
            mod_ids[name] = num
            file_infix = "{}{:02d}".format(file_inset, num)
            for file in filelist:
                if file_infix in file:
                    module_files[name].put(file)
                    total_sequences += 1
                    total_file_size += path.getsize(file)

    return module_files, mod_ids, total_sequences, total_file_size


def gain_map_files(in_folder, runs, sequences, file_inset, quadrants,
                   mods_per_quad):
    """Map module file queues for each gain stage in `runs`."""
    total_sequences = 0
    total_file_size = 0
    gain_mapped_files = OrderedDict()
    for gain, run in runs.items():
        ginfolder = "{}/{}".format(in_folder, run)
        dirlist = listdir(ginfolder)
        file_list = []
        for entry in dirlist:
            # only consider .h5 files
            abs_entry = "{}/{}".format(ginfolder, entry)
            if path.isfile(abs_entry) and path.splitext(abs_entry)[1] == ".h5":

                if sequences is None:
                    file_list.append(abs_entry)
                else:
                    for seq in sequences:
                        if "{:05d}.h5".format(seq) in abs_entry:
                            file_list.append(path.abspath(abs_entry))

        mapped_files, mod_ids, seq, fs = map_modules_from_files(file_list,
                                                                file_inset,
                                                                quadrants,
                                                                mods_per_quad)
        total_sequences += seq
        total_file_size += fs
        gain_mapped_files[gain] = mapped_files
    return gain_mapped_files, total_sequences, total_file_size / 1e9


def get_notebook_name():
    """
    Return the full path of the jupyter notebook.
    """
    try:
        kernel_id = re.search('kernel-(.*).json',
                              ipykernel.connect.get_connection_file()).group(1)
        servers = list_running_servers()
        for ss in servers:
            response = requests.get(urljoin(ss['url'], 'api/sessions'),
                                    params={'token': ss.get('token', '')})
            for nn in json.loads(response.text):
                if nn['kernel']['id'] == kernel_id:
                    return nn['notebook']['path']
    except Exception:
        return environ.get("CAL_NOTEBOOK_NAME", "Unknown Notebook")


def get_run_info(proposal, run):
    """
    Return information about a run from the MDC

    :param proposal: proposal number
    :param run: run number
    :return: dictionary with run information
    """

    mdc = MetadataClient(
        client_id=mdc_config['user-id'],
        client_secret=mdc_config['user-secret'],
        user_email=mdc_config['user-email'],
        token_url=mdc_config['token-url'],
        refresh_url=mdc_config['refresh-url'],
        auth_url=mdc_config['auth-url'],
        scope=mdc_config['scope'],
        base_api_url=mdc_config['base-api-url'],
    )

    runs = mdc.get_proposal_runs(proposal_number=proposal,
                                 run_number=run)
    run_id = runs['data']['runs'][0]['id']

    resp = mdc.get_run_by_id_api(run_id)
    return resp.json()


def get_dir_creation_date(directory: Union[str, Path], run: int,
                          verbosity: int = 0) -> datetime.datetime:
    """
    Return run start time from MyDC.
    If not available from MyMDC, retrieve the data from the dataset's metadata
    in [directory]/[run] or, if the dataset is older than 2020, from the oldest
    file's modified time.

    If the data is not available from either source,
    this function will raise a ValueError.

    :param directory: path to directory which contains runs
    :param run: run number
    :param verbosity: Level of verbosity (0 - silent)
    :return: (datetime) modification time

    """
    directory = Path(directory)
    proposal = int(directory.parent.name[1:])

    try:
        run_info = get_run_info(proposal, run)
        return dateutil.parser.parse(run_info['begin_at'])
    except Exception as e:
        if verbosity > 0:
            print(e)

    directory = directory / 'r{:04d}'.format(run)

    # Loop a number of times to catch stale file handle errors, due to
    # migration or gpfs sync.
    ntries = 100
    while ntries > 0:
        try:
            rfiles = list(directory.glob('*.h5'))
            # get creation time for oldest file,
            # as creation time between run files
            # should differ by a few seconds only.
            rfile = sorted(rfiles, key=path.getmtime)[0]
            with h5py.File(rfile, 'r') as fin:
                cdate = fin['METADATA/creationDate'][0].decode()
                cdate = datetime.datetime.strptime(
                    cdate,
                    "%Y%m%dT%H%M%SZ").replace(tzinfo=datetime.timezone.utc)
            return cdate
        except (IndexError, IOError, ValueError):
            ntries -= 1
        except KeyError:  # The files are here, but it's an older dataset
            return datetime.datetime.fromtimestamp(rfile.stat().st_mtime)

    msg = 'Could not get the creation time from the directory'
    raise ValueError(msg, directory)
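
# Example (hypothetical proposal directory): the proposal number is
# derived from the parent directory name (here p002827):
#     creation_time = get_dir_creation_date(
#         "/gpfs/exfel/exp/SPB/202101/p002827/raw", 9)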


def _init_metadata(constant: 'iCalibrationDB.calibration_constant',
                   condition: 'iCalibrationDB.detector_conditions',
                   creation_time: Optional[datetime.datetime] = None
                   ) -> 'ConstantMetaData':

    """Initialize a ConstantMetaData instance and set
    the correct creation time of the constant metadata.
    """
    metadata = ConstantMetaData()
    metadata.calibration_constant = constant
    metadata.detector_condition = condition
    if creation_time is None:
        metadata.calibration_constant_version = Versions.Now()
    else:
        metadata.calibration_constant_version = Versions.Timespan(
            start=creation_time)
    return metadata


def save_const_to_h5(db_module: str, karabo_id: str,
                     constant: 'iCalibrationDB.calibration_constant',
                     condition: 'iCalibrationDB.detector_conditions',
                     data: np.ndarray, file_loc: str,
                     report: str,
                     creation_time: datetime.datetime,
                     out_folder: str) -> 'ConstantMetaData':

    """ Save constant in h5 file with its metadata
    (e.g. db_module, condition, creation_time)

    :param db_module: database module (PDU/Physical Detector Unit).
    :param karabo_id: karabo identifier.
    :param constant: Calibration constant known for given detector.
    :param condition: Calibration condition.
    :param data: Constant data to save.
    :param file_loc: Location of raw data "proposal:{} runs:{} {} {}".
    :param report: Report path to be stored alongside the constant metadata.
    :param creation_time: creation_time for the saved constant.
    :param out_folder: path to output folder.
    :return: metadata of the saved constant.
    """

    metadata = _init_metadata(constant, condition, creation_time)

    metadata.calibration_constant_version.raw_data_location = file_loc

    dpar = {
        parm.name: {
            'lower_deviation_value': parm.lower_deviation,
            'upper_deviation_value': parm.upper_deviation,
            'value': parm.value,
            'flg_logarithmic': parm.logarithmic,
        }
        for parm in metadata.detector_condition.parameters
    }

    creation_time = metadata.calibration_constant_version.begin_at
    raw_data = metadata.calibration_constant_version.raw_data_location
    constant_name = metadata.calibration_constant.__class__.__name__

    data_to_store = {
        'condition': dpar,
        'db_module': db_module,
        'karabo_id': karabo_id,
        'constant': constant_name,
        'data': data,
        'creation_time': creation_time,
        'file_loc': raw_data,
        'report': report,
    }

    ofile = f"{out_folder}/const_{constant_name}_{db_module}.h5"
    if isfile(ofile):
        print(f'File {ofile} already exists and will be overwritten')
    save_dict_to_hdf5(data_to_store, ofile)

    return metadata
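
# Example (sketch; assumes the iCalibrationDB constant/condition classes
# and an `offset_data` array exist in the calling scope):
#     from iCalibrationDB import Constants, Conditions
#     const = Constants.AGIPD.Offset()
#     cond = Conditions.Dark.AGIPD(memory_cells=352, bias_voltage=300)
#     md = save_const_to_h5(
#         "AGIPD_SIV1_AGIPDV11_M305", "SPB_DET_AGIPD1M-1", const, cond,
#         offset_data, "proposal:p002827 runs:9 10 11",
#         "/path/to/report.pdf", creation_time, out_folder)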


def get_random_db_interface(cal_db_interface):
    """
    Return interface to calibration DB with random (with given range) port.
    """
    if "#" in cal_db_interface:
        prot, serv, ran = cal_db_interface.split(":")
        r1, r2 = ran.split("#")
        return ":".join([prot, serv, str(np.random.randint(int(r1), int(r2)))])
    return cal_db_interface
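
# Example: "tcp://max-exfl016:8015#8025" may become
# "tcp://max-exfl016:8019", while an interface without a "#" range,
# e.g. "tcp://max-exfl016:8015", is returned unchanged.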


def get_report(out_folder: str, default_path: str = ""):
    """Get the report path from calibration_metadata.yml
    stored in the out_folder.
    """

    metadata = CalibrationMetadata(out_folder)
    report_path = metadata.get("report-path", default_path)
    if not report_path:
        print("WARNING: No report path will be injected "
              "with the constants.\n")
    return report_path


def get_pdu_from_db(karabo_id: str, karabo_da: Union[str, list],
                    constant: 'iCalibrationDB.calibration_constant',
                    condition: 'iCalibrationDB.detector_conditions',
                    cal_db_interface: str,
                    snapshot_at: Optional[datetime.datetime] = None,
                    timeout: int = 30000) -> List[str]:

    """Return all physical detector units for a
    karabo_id and list of karabo_da

    :param karabo_id: Karabo identifier.
    :param karabo_da: Karabo data aggregator, a list of aggregators, or "all".
    :param constant: Calibration constant object to
                     initialize CalibrationConstantMetadata class.
    :param condition: Detector condition object to
                      initialize CalibrationConstantMetadata class.
    :param cal_db_interface: Interface string, e.g. "tcp://max-exfl016:8015".
    :param snapshot_at: Database snapshot.
    :param timeout: Calibration Database timeout.
    :return: List of physical detector units (db_modules)
    """
    if not isinstance(karabo_da, (str, list)):
        raise TypeError("karabo_da should either be a list of multiple "
                        "karabo_da or a string of one karabo_da or 'all'")

    metadata = _init_metadata(constant, condition, None)

    # A random interface is chosen if there is # for an address range.
    db_interface = get_random_db_interface(cal_db_interface)

    # CalibrationDBRemote expects a string.
    if snapshot_at is not None and hasattr(snapshot_at, 'isoformat'):
        snapshot_at = snapshot_at.isoformat()

    pdu_dicts = metadata.retrieve_pdus_for_detector(receiver=db_interface,
                                                    karabo_id=karabo_id,
                                                    snapshot_at=snapshot_at,
                                                    timeout=timeout)
    # Get a list of pdus based on requested karabo_das
    if karabo_da == 'all':
        db_modules = [d["pdu_physical_name"] for d in pdu_dicts]
    else:
        k_indices = []
        if isinstance(karabo_da, str):
            karabo_da = [karabo_da]
        # Get indices of dict with the right karabo_da,
        # else use None.
        for k in karabo_da:
            pdu_found = False
            for i, d in enumerate(pdu_dicts):
                if d["karabo_da"] == k:
                    k_indices.append(i)
                    pdu_found = True
                    break
            if not pdu_found:
                k_indices.append(None)

        db_modules = []
        for i in k_indices:
            if i is None:
                db_modules.append(None)
            else:
                db_modules.append(pdu_dicts[i]["pdu_physical_name"])

    return db_modules
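
# Example (sketch; hypothetical detector, requires access to CalDB):
#     db_modules = get_pdu_from_db(
#         "SPB_DET_AGIPD1M-1", ["AGIPD00", "AGIPD01"], const, cond,
#         "tcp://max-exfl016:8015", snapshot_at=creation_time)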


already_printed = {}


def get_from_db(karabo_id: str, karabo_da: str,
                constant: 'iCalibrationDB.calibration_constant',
                condition: 'iCalibrationDB.detector_conditions',
                empty_constant: np.ndarray,
                cal_db_interface: str,
                creation_time: Optional[datetime.datetime] = None,
                verbosity: int = 1,
                timeout: int = 30000,
                ntries: int = 7,
                meta_only: bool = True,
                load_data: bool = True,
                version_info: bool = False,
                doraise: bool = False,
                strategy: str = "pdu_closest_by_time"
                ) -> Tuple[np.ndarray, 'ConstantMetaData']:

    """Return calibration constants and metadata requested from CalDB

    This function uses the karabo_id and karabo_da to retrieve the
    desired calibration constant version (CCV).

    :param karabo_id: karabo identifier (detector identifier).
    :param karabo_da: karabo data aggregator.
    :param constant: Calibration constant known for given detector.
    :param condition: Calibration condition.
    :param empty_constant: Constant to be returned in case of failure.
    :param cal_db_interface: Interface string, e.g. "tcp://max-exfl016:8015"
    :param creation_time: Latest time for constant to be created.
    :param verbosity: Level of verbosity (0 - silent)
    :param timeout: Timeout for zmq request.
    :param ntries: Number of tries to contact the database. The default
        of 7 means that with the timeout starting at 30s, the last
        timeout will be ~ 1h.
    :param meta_only: Retrieve only metadata via ZMQ. Constants are
        taken directly from the h5 file on Maxwell.
    :param load_data: If True and meta_only is set, load the constant
        data from the h5 file referenced by the retrieved metadata.
    :param version_info: Flag to show the info for the retrieved Constant.
    :param doraise: if True raise errors during communication with DB.
    :param strategy: Retrieving strategy for calibrationDBRemote.
    :return: Calibration constant, metadata.
    """

    if version_info:
        meta_only = False

    metadata = _init_metadata(constant, condition, creation_time)

    if karabo_id and karabo_da:
        when = None

        if creation_time is not None and hasattr(creation_time, 'isoformat'):
            when = creation_time.isoformat()

        metadata.calibration_constant_version.karabo_id = karabo_id
        metadata.calibration_constant_version.karabo_da = karabo_da

        # make sure to remove device name from metadata dict before
        # retrieving to keep using karabo_id and karabo_da only
        # during retrieval. As device_name could have been set after
        # retrieval from iCalibrationDB
        metadata.calibration_constant_version.device_name = None

        while ntries > 0:
            this_interface = get_random_db_interface(cal_db_interface)
            try:
                r = metadata.retrieve(this_interface, timeout=timeout,
                                      when=when, meta_only=meta_only,
                                      version_info=version_info,
                                      strategy=strategy)
                if version_info:
                    return r
                break
            except zmq.error.Again:
                ntries -= 1
                timeout *= 2
                sleep(np.random.randint(30))
                # TODO: reevaluate the need for doraise
                # and remove if not needed.
                if ntries == 0 and doraise:
                    raise
            except Exception as e:
                if verbosity > 0:
                    print(e)
                if 'missing_token' in str(e):
                    ntries -= 1
                else:
                    ntries = 0
                if ntries == 0 and doraise:
                    raise RuntimeError(f'{e}')

        if ntries > 0:
            mdata_const = metadata.calibration_constant_version
            if load_data and meta_only:
                hdf5path = getattr(mdata_const, 'hdf5path', None)
                filename = getattr(mdata_const, 'filename', None)
                h5path = getattr(mdata_const, 'h5path', None)
                if not (hdf5path and filename and h5path):
                    raise ValueError(
                        "Wrong metadata received to access the constant data."
                        f" Retrieved constant filepath is {hdf5path}/{filename}"  # noqa
                        f" and data_set_name is {h5path}."
                    )
                with h5py.File(Path(hdf5path, filename), "r") as f:
                    arr = f[f"{h5path}/data"][()]
                metadata.calibration_constant.data = arr

            if verbosity > 0:
                if constant.name not in already_printed or verbosity > 1:
                    already_printed[constant.name] = True
                    # TODO: Reset mdata_const.begin_at
                    # if comm_db_success is False.
                    begin_at = mdata_const.begin_at
                    print(f"Retrieved {constant.name} "
                          f"with creation time: {begin_at}")
            return constant.data, metadata
        else:
            return empty_constant, metadata
    else:
        return empty_constant, None
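
# Example (sketch; hypothetical values, requires access to CalDB and
# the iCalibrationDB constant/condition classes):
#     offset, md = get_from_db(
#         "SPB_DET_AGIPD1M-1", "AGIPD00", Constants.AGIPD.Offset(),
#         Conditions.Dark.AGIPD(memory_cells=352, bias_voltage=300),
#         np.zeros((128, 512, 352, 3)), "tcp://max-exfl016:8015",
#         creation_time=creation_time)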


def send_to_db(db_module: str, karabo_id: str, constant, condition,
               file_loc: str, report_path: str, cal_db_interface: str,
               creation_time: Optional[datetime.datetime] = None,
               timeout: int = 30000,
               ntries: int = 7,
               doraise: bool = False):
    """Return calibration constants and metadata requested from CalDB

    :param db_module: database module (PDU/Physical Detector Unit)
    :param karabo_id: karabo identifier
    :param constant: Calibration constant known for given detector
    :param condition: Calibration condition
    :param file_loc: Location of raw data.
    :param report_path: xfel-calibrate report path to inject along with
                        the calibration constant versions to the database.
    :param cal_db_interface: Interface string, e.g. "tcp://max-exfl016:8015"
    :param creation_time: Latest time for constant to be created
    :param timeout: Timeout for zmq request
    :param ntries: number of tries to contact the database. The default
           of 7 means that with the timeout starting at 30s, the last
           timeout will be ~ 1h.
    :param doraise: if True raise errors during communication with DB
    """

    success = False
    snapshot_at = None
    metadata = _init_metadata(constant, condition, creation_time)

    if db_module:

        # Add injected constant's file source info as a file location
        metadata.calibration_constant_version.raw_data_location = file_loc

        if report_path:
            # calibration_client expects a dict of injected report path
            # of at least 2 characters for each key.
            if not isinstance(report_path, str) or len(report_path) < 2:
                raise TypeError(
                    "\"report_path\" needs to be a string "
                    "of at least 2 characters."
                )

            report = {"name": path.basename(report_path),
                      "file_path": report_path}
            metadata.calibration_constant_version.report_path = report

        metadata.calibration_constant_version.karabo_id = karabo_id
        metadata.calibration_constant_version.device_name = db_module
        metadata.calibration_constant_version.karabo_da = None
        if constant.data is None:
            raise ValueError(
                "There is no data available to "
                "inject to the database."
            )
        while ntries > 0:

            this_interface = get_random_db_interface(cal_db_interface)

            if (
                creation_time is not None and
                hasattr(creation_time, 'isoformat')
            ):
                # This snapshot will be used only while retrieving
                # the correct PDU and appending its UUID.
                snapshot_at = creation_time.isoformat()

            try:
                metadata.send(
                    this_interface,
                    snapshot_at=snapshot_at,
                    timeout=timeout,
                )
                success = True  # TODO: use comm_db_success
                break
            except zmq.error.Again:
                ntries -= 1
                timeout *= 2
                sleep(np.random.randint(30))
                if ntries == 0 and doraise:
                    raise
            except Exception as e:
                # TODO: refactor to use custom exception class
                if "has already been taken" in str(e):
                    print(
                        f"WARNING: {constant.name} for {db_module}"
                        " has already been injected with the same "
                        "parameter conditions."
                    )
                else:
                    print(f"{e}\n")

                if 'missing_token' in str(e):
                    ntries -= 1
                else:
                    ntries = 0
                if ntries == 0 and doraise:
                    raise RuntimeError(f'{e}')

        if success:
            print(
                f"{constant.name} for {db_module} "
                "is injected with creation-time: "
                f"{metadata.calibration_constant_version.begin_at}."
            )
    return metadata
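
# Example (sketch; injection requires write access to CalDB):
#     md = send_to_db(
#         "AGIPD_SIV1_AGIPDV11_M305", "SPB_DET_AGIPD1M-1", const, cond,
#         "proposal:p002827 runs:9 10 11", "/path/to/report.pdf",
#         "tcp://max-exfl016:8015", creation_time=creation_time)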


def get_constant_from_db(karabo_id: str, karabo_da: str,
                         constant, condition, empty_constant,
                         cal_db_interface: str, creation_time=None,
                         print_once=True, timeout=30000, ntries=120,
                         meta_only=True):
    """Return calibration constants requested from CalDB
    """
    data, _ = get_from_db(karabo_id, karabo_da, constant,
                          condition, empty_constant,
                          cal_db_interface, creation_time,
                          int(print_once), timeout, ntries, meta_only)
    return data


def get_constant_from_db_and_time(karabo_id: str, karabo_da: str,
                                  constant, condition, empty_constant,
                                  cal_db_interface: str, creation_time=None,
                                  print_once=True, timeout=30000, ntries=120):
    """Return calibration constants requested from CalDB,
    alongside injection time
    """
    data, m = get_from_db(karabo_id, karabo_da, constant,
                          condition, empty_constant,
                          cal_db_interface, creation_time,
                          int(print_once), timeout, ntries)
    if m and m.comm_db_success:
        return data, m.calibration_constant_version.begin_at
    else:
        # return None for injection time if communication with db failed.
        # reasons (no constant or condition found,
        # or network problem)
        return data, None


def module_index_to_qm(index: int, total_modules: int = 16):
    """Maps module index (0-indexed) to quadrant + module string (1-indexed)"""
    assert index < total_modules, f'{index} is greater than {total_modules}'
    modules_per_quad = total_modules // 4
    quad, mod = divmod(index, modules_per_quad)
    return f"Q{quad+1}M{mod+1}"


class CalibrationMetadata(dict):
    """Convenience class: dictionary stored in metadata YAML file

    If metadata file already exists, it will be loaded (this may override
    additional constructor parameters given to this class). Use new=True to
    skip loading it.
    """

    def __init__(self, output_dir: Union[Path, str], *args, new=False):
        dict.__init__(self, *args)
        self._yaml_fn = Path(output_dir) / "calibration_metadata.yml"
        if (not new) and self._yaml_fn.exists():
            with self._yaml_fn.open("r") as fd:
                data = yaml.safe_load(fd)
            if isinstance(data, dict):
                self.update(data)
            else:
                print(f"Warning: existing {self._yaml_fn} is malformed, "
                      "will be overwritten")
        else:
            # TODO: update after resolving this discussion
            # https://git.xfel.eu/detectors/pycalibration/-/merge_requests/624
            self.save()

    def save(self):
        with self._yaml_fn.open("w") as fd:
            yaml.safe_dump(dict(self), fd)

    def save_copy(self, copy_dir: Path):
        with (copy_dir / self._yaml_fn.name).open("w") as fd:
            yaml.safe_dump(dict(self), fd)
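
# Example (hypothetical output directory): load (or create) the YAML
# file, record a report path and write it back:
#     md = CalibrationMetadata("/tmp/out")
#     md["report-path"] = "/path/to/report.pdf"
#     md.save()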


def write_compressed_frames(
        arr: np.ndarray,
        ofile: h5py.File,
        dataset_path: str,
        comp_threads: int = 1):
    """Compress gain/mask frames in multiple threads, and save their data

    This is significantly faster than letting HDF5 do the compression
    in a single thread.
    """

    def _compress_frame(idx):
        # Equivalent to the HDF5 'shuffle' filter: transpose bytes for better
        # compression.
        shuffled = np.ascontiguousarray(
            arr[idx].view(np.uint8).reshape((-1, arr.itemsize)).transpose()
        )
        return idx, zlib.compress(shuffled, level=1)

    # Gain/mask data are compressed with gzip level 1, but not
    # checksummed, as we would have to implement the checksum ourselves.
    dataset = ofile.create_dataset(
        dataset_path,
        shape=arr.shape,
        chunks=((1,) + arr.shape[1:]),
        compression="gzip",
        compression_opts=1,
        shuffle=True,
        dtype=arr.dtype,
    )

    with ThreadPool(comp_threads) as pool:
        for i, compressed in pool.imap(_compress_frame, range(len(arr))):
            # Each frame is 1 complete chunk
            chunk_start = (i,) + (0,) * (dataset.ndim - 1)
            dataset.id.write_direct_chunk(chunk_start, compressed)
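
# Example (sketch; hypothetical gain array and output file):
#     gain = np.zeros((100, 512, 128), dtype=np.uint8)
#     with h5py.File("corrected.h5", "w") as ofile:
#         write_compressed_frames(gain, ofile, "data/gain", comp_threads=8)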