Skip to content
Snippets Groups Projects
lpdlib.py 33.96 KiB
import copy
from typing import List, Optional, Tuple
from warnings import warn

import h5py
import numpy as np
from iCalibrationDB import Conditions, Constants, Detectors

from cal_tools.enums import BadPixels
from cal_tools.tools import get_constant_from_db, get_constant_from_db_and_time


class LpdCorrections:
    """
    The LpdCorrections class performs LPD offline correction

    The following example shows a typical use case::


        infile = h5py.File(filename, "r", driver="core")
        outfile = h5py.File(filename_out, "w")

        lpd_corr = LpdCorrections(infile, outfile, max_cells, channel,
                                  max_pulses,
                                  bins_gain_vs_signal, bins_signal_low_range,
                                  bins_signal_high_range)

        try:
            # get_valid_image_idx() must run first: it sets first_index,
            # last_index and valid, which initialize() relies on
            lpd_corr.get_valid_image_idx()
            lpd_corr.initialize(offset, rel_gain, rel_gain_offset, mask,
                                noise, flatfield)
        except IOError:
            return

        for irange in lpd_corr.get_iteration_range():
            lpd_corr.correct_lpd(irange)

        hists, edges = lpd_corr.get_histograms()

    """

    def __init__(self, infile, outfile, max_cells, channel, max_pulses,
                 bins_gain_vs_signal, bins_signal_low_range,
                 bins_signal_high_range,
                 raw_fmt_version=2, chunk_size=512,
                 h5_data_path="INSTRUMENT/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/",
                 h5_index_path="INDEX/FXE_DET_LPD1M-1/DET/{}CH0:xtdf/",
                 do_ff=True, correct_non_linear=True, karabo_data_mode=False,
                 linear_between=None, mark_non_lin_region=True, nlc_version=2):
        """
        Initialize an LpdCorrections Class

        No file access happens here; see `get_valid_image_idx` and
        `initialize` for that.

        :param infile: to be corrected h5py input file
        :param outfile: writeable h5py output file
        :param max_cells: maximum number of memory cells to correct; images
            of cells with an ID at or above it are skipped. `initialize`
            lowers it further to what the constants cover.
        :param channel: module/channel to correct
        :param max_pulses: maximum pulse id to consider for preview histograms
        :param bins_gain_vs_signal: number of bins for gain vs signal histogram
        :param bins_signal_low_range: number of bins for the low signal
            range histogram
        :param bins_signal_high_range: number of bins for the high signal
            range histogram
        :param raw_fmt_version: raw file format version to use (1 or 2)
        :param chunk_size: images per chunk to compute upon
        :param h5_data_path: path in HDF5 file which is prefixed to the
            image/data section; ``{}`` is replaced by `channel`
        :param h5_index_path: path in HDF5 file which is prefixed to the
            index section; ``{}`` is replaced by `channel`
        :param do_ff: perform flat field corrections
        :param correct_non_linear: perform non-linear transition region corr.
        :param karabo_data_mode: set to true to use data iterated with karabo
            data
        :param linear_between: optional sequence, indexed by gain stage, of
            ``(low, high)`` bounds; offset-corrected values outside these
            bounds are flagged as non-linear
        :param mark_non_lin_region: flag non-linear pixels in the mask and
            write a per-image non-linear fraction to ``image/nonLinear``
        :param nlc_version: non-linearity correction scheme: 1 damps the
            low-gain stage with an exponential factor, 2 applies per-gain
            linear+exponential correction functions (`cnl_const`)
        """
        self.lpd_base = h5_data_path.format(channel)
        self.idx_base = h5_index_path.format(channel)
        self.infile = infile
        self.outfile = outfile
        self.channel = channel
        self.index_v = raw_fmt_version
        self.chunksize = chunk_size
        self.initialized = False
        self.max_pulses = max_pulses
        self.max_cells = max_cells
        # histogram accumulators; become arrays once the first chunk is done
        self.hists_signal_low = 0
        self.hists_signal_high = 0
        self.hists_gain_vs_signal = 0
        self.bins_gain_vs_signal = bins_gain_vs_signal
        self.bins_signal_low_range = bins_signal_low_range
        self.bins_signal_high_range = bins_signal_high_range
        # index of the next output image; 0 means "first chunk"
        self.cidx = 0
        self.do_ff = do_ff
        # modules for which memory cells 0 and 1 are excluded from correction
        filter_modules = []
        self.filter_cells = [0, 1] if channel in filter_modules else []
        self.cnl = correct_non_linear
        self.karabo_data_mode = karabo_data_mode
        self.linear_between = linear_between
        self.mark_nonlin = mark_non_lin_region
        self.nlc_version = nlc_version
        # emprically determined from APD datasets p900038, r155,r156
        self.cnl_const = {
            'high': {'A': -0.000815934, 'lam': 0.00811281, 'c': 1908.89,
                     'm': 0, 'b': 0},
            'med': {'A': 0.0999894, 'lam': -0.00137652, 'c': 3107.83,
                    'm': 3.89982e-06, 'b': -0.116811},
            'low': {'A': 0.0119132, 'lam': -0.0002, 'c': 36738.6,
                    'm': 2.00273e-08, 'b': 0.245537}}

    def get_iteration_range(self):
        """Return the valid image indices split into contiguous chunks.

        `firange` (set by `initialize`) is split into
        ``firange.size // chunksize`` pieces. Each piece therefore holds at
        least `chunksize` indices, or all of them if there are fewer.
        At least one piece is always returned.
        """
        # guard against sections == 0 (fewer images than chunksize) and
        # chunksize == 0, which np.array_split / // would reject
        n_chunks = max(1, self.firange.size // max(1, self.chunksize))
        return np.array_split(self.firange, n_chunks)

    def initialize(self, offset, rel_gain, rel_gain_offset, mask, noise,
                   flatfield, swap_axis=True):
        """ Initialize the calibration constants and the output data file.

        Any data that is not touched by the corrections is copied into the
        output file at this point. Also data validity tests are performed.
        This functions must be called before `correct_lpd` is executed, and
        (unless in karabo_data_mode) after `get_valid_image_idx`.

        Constants passed as None keep any previously set value, so this may
        be called several times (e.g. DB constants first, then file
        constants). Masks are OR-ed into an existing mask. The output file
        is only set up on the first call.

        :param offset: offset map to use for corrections
        :param rel_gain: relative gain map to use for corrections
        :param rel_gain_offset: relative gain offset to use for corrections
        :param mask: bad pixel mask to use for corrections
        :param noise: noise map to use for corrections
        :param flatfield: flatfield map to use for corrections
        :param swap_axis: set to True if using data from the calibration
            database.
        :raises IOError: if no image of the file can be calibrated
        """

        if swap_axis:
            def cell_first(const):
                # move axis 2 to the front, then swap the original axes 0
                # and 1: (a0, a1, a2, ...) -> (a2, a1, a0, ...)
                return np.ascontiguousarray(
                    np.moveaxis(np.moveaxis(const, 2, 0), 2, 1))

            offset = None if offset is None else cell_first(offset)
            rel_gain = None if rel_gain is None else cell_first(rel_gain)
            mask = None if mask is None else cell_first(mask)
            if rel_gain_offset is not None:
                rel_gain_offset = cell_first(rel_gain_offset)
            noise = None if noise is None else cell_first(noise)
            if flatfield is not None:
                flatfield = np.ascontiguousarray(flatfield)

        if offset is not None:
            self.offset = offset
        if rel_gain is not None:
            self.rel_gain = rel_gain
        if rel_gain_offset is not None:
            self.rel_gain_b = rel_gain_offset
        if mask is not None:
            if not hasattr(self, "mask"):
                self.mask = mask
            else:
                self.mask |= mask[:self.mask.shape[0], ...]
        if noise is not None:
            self.noise = noise
        if flatfield is not None:
            self.flatfield = flatfield

        if not self.initialized:
            self.median_noise = np.nanmedian(self.noise[0, ...])
            # never correct more cells than the constants cover
            allcells = [self.offset.shape[0],
                        self.mask.shape[0],
                        self.max_cells]
            self.max_cells = np.min(allcells)
            if not self.karabo_data_mode:
                self.gen_valid_range()
                if self.firange.size < self.chunksize:
                    self.chunksize = self.firange.size
                self.copy_and_sanitize_non_cal_data()
                self.create_output_datasets()
            self.initialized = True

    @staticmethod
    def split_gain(d):
        """ Split gain information off 16-bit LPD data

        Bits 0-11 hold the ADC value; gain information can be found in
        bits 12 and 13 (0-based).

        :param d: raw integer data array
        :return: tuple ``(data, gain)`` of arrays shaped like `d`
        """
        # scalar masks broadcast without allocating image-sized mask arrays
        data = np.bitwise_and(d, np.uint16(0b0000111111111111))
        gain = np.bitwise_and(np.right_shift(d, 12),
                              np.uint16(0b0000000000000011))
        return data, gain

    def correct_lpd(self, irange):
        """ Correct Raw LPD data for offset and relative gain effects

        The raw chunk is split into ADC value and gain stage. It is then
        offset corrected and relative-gain corrected (optionally also
        flat-field and non-linearity corrected). Non-finite and out-of-range
        values are zeroed and flagged in the mask. The preview histograms
        are filled on the first chunk only.

        :param irange: array of image indices to work on, should be
            contiguous; in karabo_data_mode a dict-like chunk with
            ``image.*`` entries instead, which is updated in place
        :return: the updated `irange` in karabo_data_mode, otherwise None
            (results are written to the output file)

        Will raise an RuntimeError if `initialize()` has not been called first.
        """
        if not self.initialized:
            raise RuntimeError("Must call initialize() first!")
        if not self.karabo_data_mode:
            lpd_base = self.lpd_base
            cidx = self.cidx
            image_path = lpd_base + "image/"

            def read(key):
                return self.infile[image_path + key][irange, ...]

            im = np.array(read("data"))
            trainId = np.squeeze(read("trainId"))
            pulseId = np.squeeze(read("pulseId"))
            status = np.squeeze(read("status"))
            cells = np.squeeze(read("cellId"))
            length = np.squeeze(read("length"))
        else:
            cidx = 1  # do not produce any histograms
            im = irange['image.data']
            trainId = np.squeeze(irange['image.trainId'])
            status = np.squeeze(irange['image.status'])
            pulseId = np.squeeze(irange['image.pulseId'])
            cells = np.squeeze(irange['image.cellId'])
            length = np.squeeze(irange['image.length'])

        # split gain and image info into separate arrays; axis 1 of the raw
        # data is dropped, leaving (images, rows, columns)
        im, gain = self.split_gain(im[:, 0, ...])

        # we need data as float from here on
        im = im.astype(np.float32)

        # invalid gain values
        im[gain > 2] = np.nan

        # on first iteration create histograms for reports
        if cidx == 0:
            H, xe, ye = np.histogram2d(im.flatten(), gain.flatten(),
                                       bins=self.bins_gain_vs_signal,
                                       range=[[0, 4096], [0, 4]])
            self.hists_gain_vs_signal += H
            self.signal_edges = (xe, ye)

        # zero invalid gains so they can index the per-gain constants
        gain[gain > 2] = 0

        #  select constants by memory cells
        om = self.offset[cells, ...]
        rc = self.rel_gain[cells, ...]
        rbc = self.rel_gain_b[cells, ...]

        # and then by gain setting
        og = np.choose(gain, (om[..., 0], om[..., 1], om[..., 2]))
        rg = np.choose(gain, (rc[..., 0], rc[..., 1], rc[..., 2]))
        rgb = np.choose(gain, (rbc[..., 0], rbc[..., 1], rbc[..., 2]))

        mskg = self.mask[cells, ...]
        msk = np.choose(gain, (mskg[..., 0], mskg[..., 1], mskg[..., 2]))

        # correct offset
        im -= og

        nlf = 0
        if self.mark_nonlin and self.linear_between is not None:
            for gl, lr in enumerate(self.linear_between):
                midx = (gain == gl) & ((im < lr[0]) | (im > lr[1]))
                # OR the flag in so existing bad-pixel bits are kept
                msk[midx] |= BadPixels.NON_LIN_RESPONSE_REGION
                nlf += np.count_nonzero(midx, axis=(1, 2))
            # fraction of non-linear pixels per image: normalize by the
            # number of pixels in one image (rows * columns)
            nlf = nlf / float(im.shape[1] * im.shape[2])

        # hacky way of smoothening transition region between med and low
        cfac = 1
        if self.nlc_version == 1 and self.cnl:
            cfac = 0.314 * np.exp(-im * 0.001)

        # perform relative gain correction with additional gain-deduced
        # offset
        im = (im - rgb) / rg
        if self.do_ff:
            im /= self.flatfield[None, :, :]

        # hacky way of smoothening transition region between med and low
        if self.nlc_version == 1 and self.cnl:
            im[gain == 2] -= im[gain == 2] * cfac[gain == 2]

        # perform non-linear corrections if requested
        if self.cnl and self.nlc_version == 2:
            self._correct_non_linearity(im, gain)

        # create bad pixels masks, here non-finite values
        bidx = ~np.isfinite(im)
        im[bidx] = 0
        msk[bidx] |= BadPixels.VALUE_IS_NAN

        # values which are unphysically large or small
        bidx = (im < -1e7) | (im > 1e7)
        im[bidx] = 0
        msk[bidx] |= BadPixels.VALUE_OUT_OF_RANGE

        # on first iteration we create histograms for the report
        if cidx == 0:
            self._fill_signal_histograms(im, pulseId)

        if not self.karabo_data_mode:
            # write data out
            # upper end of indices we are processing
            nidx = int(cidx + irange.size)
            self.ddset[cidx:nidx, ...] = im
            self.gdset[cidx:nidx, ...] = gain
            self.mdset[cidx:nidx, ...] = msk

            for key, values in (("cellId", cells), ("trainId", trainId),
                                ("pulseId", pulseId), ("status", status),
                                ("length", length)):
                self.outfile[lpd_base + "image/" + key][cidx:nidx] = values
            if self.mark_nonlin:
                self.outfile[lpd_base + "image/nonLinear"][cidx:nidx] = nlf
            self.cidx = nidx
        else:
            irange['image.data'] = im
            irange['image.gain'] = gain
            irange['image.mask'] = msk
            irange['image.cellId'] = cells
            irange['image.trainId'] = trainId
            irange['image.pulseId'] = pulseId
            irange['image.status'] = status
            irange['image.length'] = length
            return irange

    def _correct_non_linearity(self, im, gain):
        """Apply the nlc_version 2 per-gain-stage correction to `im` in place.

        For each gain stage the correction factor
        ``m*x + b + A*exp(lam*(x - c))`` (constants from `cnl_const`) is
        clipped and ``factor * x`` is subtracted from the pixel values.
        """
        def lin_exp_fun(x, m, b, A, lam, c):
            return m * x + b + A * np.exp(lam * (x - c))

        # (gain stage, cnl_const key, clipping function, clipping limit)
        stages = ((0, 'high', np.maximum, -0.2),
                  (1, 'med', np.minimum, 0.2),
                  (2, 'low', np.minimum, 0.45))
        for stage, key, clip, limit in stages:
            sel = gain == stage
            x = im[sel]
            cnl = self.cnl_const[key]
            cf = lin_exp_fun(x, cnl['m'], cnl['b'], cnl['A'], cnl['lam'],
                             cnl['c'])
            im[sel] -= clip(cf, limit) * x

    def _fill_signal_histograms(self, im, pulseId):
        """Accumulate the mean-signal vs. pulse ID preview histograms."""
        copim = copy.copy(im)
        # pixels below the median noise do not contribute to the image mean
        copim[copim < self.median_noise] = np.nan
        mean_signal = np.nanmean(copim, axis=(1, 2))
        pulse_range = [0, self.max_pulses + 1]

        H, xe, ye = np.histogram2d(
            mean_signal, pulseId,
            bins=(self.bins_signal_low_range, self.max_pulses),
            range=[[-50, 1000], pulse_range])
        self.hists_signal_low += H
        self.low_edges = (xe, ye)

        H, xe, ye = np.histogram2d(
            mean_signal, pulseId,
            bins=(self.bins_signal_high_range, self.max_pulses),
            range=[[0, 200000], pulse_range])
        self.hists_signal_high += H
        self.high_edges = (xe, ye)

    def _plausible_trains(self):
        """Return the index train IDs and a mask of plausible ones.

        A train ID is plausible when it lies within 1e4 of the median train
        ID of the file; this discards spurious train IDs.
        """
        idxtrains = np.squeeze(self.infile["/INDEX/trainId"])
        median_train = np.nanmedian(idxtrains)
        plausible = ((idxtrains > median_train - 1e4) &
                     (idxtrains < median_train + 1e4))
        return idxtrains, plausible

    def get_valid_image_idx(self):
        """ Determine the indices of valid data from the index section

        Sets `valid` (per-train validity mask), `first_index`, `last_index`
        (image index range) and `idxtrains` on the instance. Must be called
        before `initialize`.

        A train is valid if it has images and a plausible train ID.
        Format version 2 requires a non-zero ``count``; version 1 requires
        a non-zero ``status``.

        :raises IOError: if the file holds no valid trains
        :raises AttributeError: for an unknown raw format version
        """
        idx_base = self.idx_base
        if self.index_v == 2:
            count = np.squeeze(self.infile[idx_base + "image/count"])
            first = np.squeeze(self.infile[idx_base + "image/first"])
            if np.count_nonzero(count != 0) == 0:
                raise IOError("File has no valid counts")
            idxtrains, plausible = self._plausible_trains()
            valid = (count != 0) & plausible
            if not valid.any():
                raise IOError("File {} has no valid trains".format(
                    self.infile))

            last_index = int(first[valid][-1] + count[valid][-1])
            first_index = int(first[valid][0])

        elif self.index_v == 1:
            status = np.squeeze(self.infile[idx_base + "image/status"])
            if np.count_nonzero(status != 0) == 0:
                raise IOError("File {} has no valid counts".format(
                    self.infile))
            last = np.squeeze(self.infile[idx_base + "image/last"])
            idxtrains, plausible = self._plausible_trains()
            valid = (status != 0) & plausible
            if not valid.any():
                raise IOError("File {} has no valid trains".format(
                    self.infile))

            last_index = int(last[valid][-1])
            # NOTE(review): the start is taken from `last` of the first
            # valid train, not a `first` index; confirm against the v1
            # raw format.
            first_index = int(last[valid][0])
        else:
            raise AttributeError(
                "Not a known raw format version: {}".format(self.index_v))

        self.valid = valid
        self.first_index = first_index
        self.last_index = last_index
        self.idxtrains = idxtrains

    def gen_valid_range(self):
        """ Generate an index range to pass to `correct_lpd`.

        Selects the images between `first_index` and `last_index` whose
        memory cell is below `max_cells` and not in `filter_cells`. The
        selected image indices are stored as `firange`, the first image as
        `single_image`, and the output image-stack shape as `oshape`.

        :raises IOError: if no image in the range can be calibrated
        """
        first_index = self.first_index
        last_index = self.last_index
        lpd_base = self.lpd_base
        allcells = self.infile[lpd_base + "image/cellId"]
        allcells = np.squeeze(allcells[first_index:last_index, ...])
        single_image = self.infile[lpd_base + "image/data"]
        single_image = np.array(single_image[first_index, ...])
        can_calibrate = allcells < self.max_cells
        for c in self.filter_cells:
            can_calibrate &= allcells != c
        if np.count_nonzero(can_calibrate) == 0:
            # previously a silent return that left `firange` unset and made
            # initialize() fail with an AttributeError
            raise IOError("File {} has no images that can be calibrated"
                          .format(self.infile))
        firange = np.arange(first_index, last_index)[can_calibrate]
        self.oshape = (firange.size,
                       single_image.shape[1],
                       single_image.shape[2])
        self.firange = firange
        self.single_image = single_image

    def copy_and_sanitize_non_cal_data(self):
        """ Copy and sanitize data in `infile` that is not touched by
        `correct_lpd`

        Every group and dataset of `infile` is copied to `outfile` except:

        * the image datasets that `correct_lpd` writes
        * the image ``first``/``count`` index datasets, which are rebuilt
          here from the train IDs of the images selected in `firange`
        """
        lpd_base = self.lpd_base
        idx_base = self.idx_base
        first_index = self.first_index
        last_index = self.last_index
        firange = self.firange
        # alltrains[i] is the train ID of image first_index + i
        alltrains = self.infile[lpd_base + "image/trainId"]
        alltrains = np.squeeze(alltrains[first_index:last_index, ...])
        dont_copy = ["data", "cellId", "trainId", "pulseId", "status",
                     "length"]
        dont_copy = [lpd_base + "image/{}".format(do)
                     for do in dont_copy]

        dont_copy += [idx_base + "{}/first".format(do)
                      for do in ["image"]]
        dont_copy += [idx_base + "{}/count".format(do)
                      for do in ["image"]]

        # a visitor to copy outstanding data
        def visitor(k, item):
            if k not in dont_copy:

                if isinstance(item, h5py.Group):
                    self.outfile.create_group(k)
                elif isinstance(item, h5py.Dataset):
                    group = str(k).split("/")
                    group = "/".join(group[:-1])
                    self.infile.copy(k, self.outfile[group])

        self.infile.visititems(visitor)
        # sanitize indices
        for do in ["image", ]:
            # train IDs of the selected images; firange holds absolute image
            # indices, so offset by first_index (the start of alltrains)
            uq, fidxv, cntsv = np.unique(alltrains[firange - first_index],
                                         return_index=True,
                                         return_counts=True)

            duq = (uq[1:] - uq[:-1]).astype(np.int64)

            # expand to one entry per train ID, inserting zero first/count
            # entries for train IDs without images
            cfidxv = [fidxv[0], ]
            ccntsv = [cntsv[0], ]
            for i, du in enumerate(duq.tolist()):
                if du > 1000:
                    # treat implausibly large train-ID gaps as a single step
                    du = 1
                    # NOTE(review): cntsv[i] has already been copied into
                    # ccntsv at this point, so this assignment has no effect
                    # on the output; confirm the intended entry to zero.
                    cntsv[i] = 0
                cfidxv += [0] * (du - 1) + [fidxv[i + 1], ]
                ccntsv += [0] * (du - 1) + [cntsv[i + 1], ]

            # presumably self.valid has one entry per train of the index
            # section — TODO confirm it is at least as long as cfidxv
            mv = len(cfidxv)
            fidx = np.zeros(len(cfidxv), fidxv.dtype)
            fidx[self.valid[:mv]] = np.array(cfidxv)[self.valid[:mv]]

            # back-fill zero first indices from the following entry
            # NOTE(review): the loop stops at i == 3, so fidx[0] and fidx[1]
            # are never back-filled; confirm whether that is intended.
            for i in range(len(fidx) - 1, 2, -1):
                if fidx[i - 1] == 0 and fidx[i] != 0:
                    fidx[i - 1] = fidx[i]

            cnts = np.zeros(len(cfidxv), cntsv.dtype)
            cnts[self.valid[:mv]] = np.array(ccntsv)[self.valid[:mv]]

            self.outfile.create_dataset(idx_base + "{}/first".format(do),
                                        fidx.shape,
                                        dtype=fidx.dtype,
                                        data=fidx,
                                        fletcher32=True)
            self.outfile.create_dataset(idx_base + "{}/count".format(do),
                                        cnts.shape,
                                        dtype=cnts.dtype,
                                        data=cnts,
                                        fletcher32=True)

    def create_output_datasets(self):
        """ Initialize output data sets

        Creates the corrected image, gain and mask stacks (shape `oshape`,
        chunked by `chunksize` images), plus one per-image dataset for each
        metadata key. ``image/nonLinear`` is only created when
        `mark_nonlin` is set.
        """
        lpdbase = self.lpd_base
        oshape = self.oshape
        chunks = (self.chunksize, oshape[1], oshape[2])
        self.ddset = self.outfile.create_dataset(lpdbase + "image/data",
                                                 oshape, chunks=chunks,
                                                 dtype=np.float32,
                                                 fletcher32=True)
        self.gdset = self.outfile.create_dataset(lpdbase + "image/gain",
                                                 oshape, chunks=chunks,
                                                 dtype=np.uint8,
                                                 compression="gzip",
                                                 compression_opts=1,
                                                 shuffle=True,
                                                 fletcher32=True)
        self.mdset = self.outfile.create_dataset(lpdbase + "image/mask",
                                                 oshape, chunks=chunks,
                                                 dtype=np.uint32,
                                                 compression="gzip",
                                                 compression_opts=1,
                                                 shuffle=True,
                                                 fletcher32=True)
        fsz = self.firange.shape
        for key, dtype in (("cellId", np.uint16), ("trainId", np.uint64),
                           ("pulseId", np.uint64), ("status", np.uint16),
                           ("length", np.uint32)):
            self.outfile.create_dataset(lpdbase + "image/" + key, fsz,
                                        dtype=dtype, fletcher32=True)

        if self.mark_nonlin:
            self.outfile.create_dataset(lpdbase + "image/nonLinear", fsz,
                                        dtype=np.float32, fletcher32=True)

    def get_histograms(self):
        """ Return preview histograms computed from the first chunk

        :return: ``((low, high, gain_vs_signal), (low_edges, high_edges,
            signal_edges))``
        """
        return ((self.hists_signal_low, self.hists_signal_high,
                 self.hists_gain_vs_signal),
                (self.low_edges, self.high_edges, self.signal_edges))

    def initialize_from_db(self, dbparms: List[Tuple['DBParms', 'DBParms_timeout']],
                           karabo_id: str, karabo_da: str,
                           only_dark: Optional[bool] = False):
        """ Initialize calibration constants from the calibration database

        :param dbparms: a sequence of database parameters, either:
            * cal_db_interface, creation_time, max_cells_db, capacitor,
              bias_voltage, photon_energy
              in which case the db timeout is set to 300000 (300 seconds)

            * cal_db_interface, creation_time, max_cells_db, capacitor,
              bias_voltage, photon_energy, timeout
              additionally a timeout is given

        :param karabo_id: karabo identifier
        :param karabo_da: karabo data aggregator
        :param only_dark: load only dark image derived constants. This
            implies that a `calfile` is used to load the remaining
            constants. Useful to reduce DB traffic and interactions
            for non-frequently changing constants, i.e. such which are
            not usually updated during a beamtime.
        :return: the `when` information returned with the Offset constant

        The `cal_db_interface` parameter in the `dbparms` tuple may be in
        one of the following notations:
            * tcp://host:port to directly identify the host and port to
              connect to
            * tcp://host:port_low#port_high to specify a port range from
              which a random port will be picked. E.g. specifying

              tcp://max-exfl016:8015#8025

              will randomly pick an address in the range max-exfl016:8015 and
              max-exfl016:8025.


        The latter notation allows for load-balancing.

        This routine loads the following constants as given in
        `iCalibrationDB`:

            Dark Image Derived
            ------------------

            * Constants.LPD.Offset
            * Constants.LPD.Noise
            * Constants.LPD.BadPixelsDark


            CI  Derived
            -----------

            * Constants.LPD.SlopesCI
            * Constants.LPD.BadPixelsCI

            Flat-Field Derived
            ------------------

            * Constants.LPD.SlopesFF
            * Constants.LPD.BadPixelsFF

        """
        if len(dbparms) == 6:
            (cal_db_interface, creation_time, max_cells_db, capacitor,
             bias_voltage, photon_energy) = dbparms
            timeout = 300000
        else:
            (cal_db_interface, creation_time, max_cells_db, capacitor,
             bias_voltage, photon_energy, timeout) = dbparms

        # a fresh condition object is built for every query
        def dark_condition():
            return Conditions.Dark.LPD(
                memory_cells=max_cells_db,
                bias_voltage=bias_voltage,
                capacitor=capacitor)

        def illuminated_condition():
            return Conditions.Illuminated.LPD(
                max_cells_db, bias_voltage, photon_energy,
                pixels_x=256, pixels_y=256, beam_energy=None,
                capacitor=capacitor)

        def query(constant, condition, default):
            return get_constant_from_db(karabo_id, karabo_da,
                                        constant, condition, default,
                                        cal_db_interface,
                                        creation_time=creation_time,
                                        timeout=timeout)

        #TODO: remove sw duplication during LPD correction modifications
        offset, when = get_constant_from_db_and_time(
            karabo_id, karabo_da,
            Constants.LPD.Offset(),
            dark_condition(),
            np.zeros((256, 256, max_cells_db, 3)),
            cal_db_interface,
            creation_time=creation_time,
            timeout=timeout)

        noise = query(Constants.LPD.Noise(), dark_condition(),
                      np.zeros((256, 256, max_cells_db, 3)))

        bpixels = query(Constants.LPD.BadPixelsDark(), dark_condition(),
                        np.zeros((256, 256, max_cells_db, 3))
                        ).astype(np.uint32)

        # done loading constants and returning
        if only_dark:
            self.initialize(offset, None, None, bpixels, noise, None)
            return when

        slopesCI = query(Constants.LPD.SlopesCI(), dark_condition(),
                         np.ones((256, 256, max_cells_db, 2)))

        rel_gains = slopesCI[..., 0]
        rel_gain_offset = slopesCI[..., 1]

        flat_fields = np.squeeze(
            query(Constants.LPD.SlopesFF(), illuminated_condition(),
                  np.ones((256, 256))))

        bpixels |= query(Constants.LPD.BadPixelsCI(), dark_condition(),
                         np.zeros((256, 256, max_cells_db, 3))
                         ).astype(np.uint32)

        bpix = query(Constants.LPD.BadPixelsFF(), illuminated_condition(),
                     np.zeros((256, 256, max_cells_db))).astype(np.uint32)
        # the FF mask has no gain axis: apply it to all gain stages
        bpixels |= bpix[..., None]
        self.initialize(offset, rel_gains, rel_gain_offset, bpixels, noise,
                        flat_fields)
        return when

    def initialize_from_file(self, filename, qm, with_dark=True):
        """ Initialize calibration constants from a calibration file

        :param filename: path to a file containing the calibration
        constants. It is expected to have the following structure:

            /{qm}/BadPixelsFF/data
            /{qm}/BadPixelsCI/data
            /{qm}/Offset/data
            /{qm}/Noise/data
            /{qm}/BadPixelsDark/data
            /{qm}/SlopesFF/data
            /{qm}/SlopesCI/data

            where qm is the `qm` parameter.

        :param qm: quadrant and module of the constants to load in Q1M1
            notation
        :param with_dark: also load dark image derived constants from the
            file. This will overwrite any constants previously loaded from
            the calibration database. Without it, offset and noise are
            passed as None, so earlier values are kept.

        """
        offsets = None
        noises = None
        with h5py.File(filename, "r") as calfile:
            bpixels = calfile["{}/{}/data".format(qm, "BadPixelsCI")][()]
            bpix = calfile["{}/{}/data".format(qm, "BadPixelsFF")][()]
            # add two trailing axes so the FF mask broadcasts onto the CI mask
            bpixels |= bpix[..., None, None]
            if with_dark:
                offsets = calfile["{}/{}/data".format(qm, "Offset")][()]
                noises = calfile["{}/{}/data".format(qm, "Noise")][()]
                bpix = calfile["{}/{}/data".format(qm, "BadPixelsDark")]
                bpixels |= bpix[()]

            slopesCI = calfile["{}/{}/data".format(qm, "SlopesCI")][()]
            rel_gains = slopesCI[..., 0]
            rel_gains_b = slopesCI[..., 1]

            # flat field is stored flipped along both axes
            flat_fields = calfile["{}/{}/data".format(qm, "SlopesFF")][()][
                          ::-1, ::-1]

        self.initialize(offsets, rel_gains, rel_gains_b, bpixels, noises,
                        flat_fields)


def get_mem_cell_pattern(run, sources) -> np.ndarray:
    """Load the memory cell order to use as a condition to find constants

    This looks at the first train with data for each source, issuing a
    warning if the pattern differs between sources.

    :param run: run object indexable as ``run[source, 'image.cellId']``
        (presumably an EXtra-data DataCollection — confirm)
    :param sources: source names to inspect
    :return: flattened cell ID array of the first source that has data
    :raises ValueError: if none of the sources has cell ID data
    """
    patterns = []
    for source in sources:
        cell_id_data = run[source, 'image.cellId'].drop_empty_trains()
        if len(cell_id_data.train_ids) == 0:
            continue  # No data for this module
        cell_ids = cell_id_data[0].ndarray().flatten()
        if not any(np.array_equal(cell_ids, p) for p in patterns):
            patterns.append(cell_ids)

    if not patterns:
        raise ValueError("Couldn't find memory cell order for any modules")
    if len(patterns) > 1:
        # explicit '+': without it the adjacent literals merge and the
        # message prefix becomes part of the join separator
        warn("Memory cell order varies between detector modules: "
             + "; ".join([f"{s[:10]}...{s[-10:]}" for s in patterns]))

    return patterns[0]


def make_cell_order_condition(use_param, cellid_pattern) -> Optional[str]:
    """Turn a memory cell ID sequence into a condition string.

    :param use_param: ``'auto'`` uses the order only when it wraps around,
        i.e. more than two cells and at least one decrease between
        neighbours. ``'always'`` always uses it; any other value never does.
    :param cellid_pattern: array of memory cell IDs in acquisition order
    :return: comma-separated cell IDs with a trailing comma, or None when
        the cell order is not used
    """
    if use_param != 'auto':
        use_order = use_param == 'always'
    else:
        # short-circuit: only inspect the steps for patterns of > 2 cells
        use_order = len(cellid_pattern) > 2 and bool(
            (np.diff(cellid_pattern.astype(np.int32)) < 0).any()
        )

    if not use_order:
        return None
    return ",".join(map(str, cellid_pattern)) + ","