# spi_hitfinder.py
import re
from collections import deque

import numpy as np
from calng.arbiter_kernels.base_kernel import BaseArbiterKernel
from karabo.bound import (BOOL_ELEMENT, DOUBLE_ELEMENT, INT32_ELEMENT,
                          OVERWRITE_ELEMENT, STRING_ELEMENT,
                          VECTOR_FLOAT_ELEMENT, VECTOR_UINT32_ELEMENT)
# Compiled patterns for detector source names.  Each one captures the module
# number in the named group ``modno``; guess_module_number() tries them in
# the order given here.
DET_SOURCE_PATTERNS = [
    re.compile(regex)
    for regex in (
        r".+/DET/(?P<modno>\d+)CH0:xtdf",
        r".+/DET/.*?(?P<modno>\d+):daqOutput",
    )
]

def guess_module_number(source):
    """Extract the detector module number from a source name.

    The name is matched against each entry of ``DET_SOURCE_PATTERNS`` in
    order; the first pattern that matches supplies the ``modno`` group.

    Returns:
        The module number as an ``int``, or ``-1`` when no pattern matches.
    """
    attempts = (pattern.match(source) for pattern in DET_SOURCE_PATTERNS)
    first_hit = next((found for found in attempts if found is not None), None)
    return -1 if first_hit is None else int(first_hit["modno"])


class SPIhitfinder(BaseArbiterKernel):
    """Frame-selection kernel driven by lit-pixel (or intensity) scores.

    For every train the ``litpixels`` statistics of the configured detector
    modules are summed per frame, rescaled from the unmasked to the total
    pixel count and optionally normalised by the XGM pulse energy.  Frames
    whose score exceeds a fixed or adaptive threshold are reported as hits.
    """

    _node_name = "spiHitfinder"

    def __init__(self, device, name, config):
        """Create the rolling histories and delegate to the base class.

        The deques are created *before* ``super().__init__`` because
        ``reconfigure`` re-wraps them with the configured lengths and must
        find them in place (presumably the base constructor triggers
        ``reconfigure`` -- confirm against ``BaseArbiterKernel``).
        """
        self._history_litpx = deque([], 100)
        self._history_intens = deque([], 100)
        self._history_hitrate = deque([], 20)
        super().__init__(device, name, config)

        # Reported while no threshold has been computed yet.
        self._last_threshold = 0.0

    @staticmethod
    def extend_device_schema(schema, prefix):
        """Add this kernel's configuration keys below ``prefix``.

        Args:
            schema: Karabo schema to extend in place.
            prefix: Key under which the kernel's settings are placed.
        """
        (
            OVERWRITE_ELEMENT(schema)
            .key(prefix)
            .setNewDescription(
                "This kernel selects the frames by comparing the number "
                "of lit-pixels to the threshold."
            )
            .commit(),

            # Which per-frame quantity is compared to the threshold.
            STRING_ELEMENT(schema)
            .key(f"{prefix}.scores")
            .options("litpixels,intensities")
            .assignmentOptional()
            .defaultValue("litpixels")
            .reconfigurable()
            .commit(),

            # Module numbers whose sources contribute to the scores.
            VECTOR_UINT32_ELEMENT(schema)
            .key(f"{prefix}.modules")
            .assignmentOptional()
            .defaultValue([3, 4, 8, 15])
            .reconfigurable()
            .commit(),

            STRING_ELEMENT(schema)
            .key(f"{prefix}.thresholdMode")
            .options("fixed,adaptive")
            .assignmentOptional()
            .defaultValue("adaptive")
            .reconfigurable()
            .commit(),

            # Threshold used when thresholdMode is "fixed".
            INT32_ELEMENT(schema)
            .key(f"{prefix}.fixedThreshold")
            .assignmentOptional()
            .defaultValue(0)
            .reconfigurable()
            .commit(),

            # Multiplier of the robust sigma in the adaptive threshold.
            DOUBLE_ELEMENT(schema)
            .key(f"{prefix}.snr")
            .assignmentOptional()
            .defaultValue(3.5)
            .reconfigurable()
            .commit(),

            # Length of the score history; also the train size from which
            # the adaptive threshold uses the current train alone.
            INT32_ELEMENT(schema)
            .key(f"{prefix}.minFrames")
            .assignmentOptional()
            .defaultValue(100)
            .reconfigurable()
            .commit(),

            # NOTE: the key is spelled "xgmNomalization"; it is part of the
            # device interface and therefore kept as is.
            BOOL_ELEMENT(schema)
            .key(f"{prefix}.xgmNomalization")
            .assignmentOptional()
            .defaultValue(False)
            .reconfigurable()
            .commit(),

            # Report a lit-pixel count of 0 for frames excluded by the mask.
            BOOL_ELEMENT(schema)
            .key(f"{prefix}.suppressMaskedScores")
            .assignmentOptional()
            .defaultValue(True)
            .reconfigurable()
            .commit(),

            # Number of trains averaged into "averageHitrate".
            INT32_ELEMENT(schema)
            .key(f"{prefix}.hitrateAverageWindow")
            .assignmentOptional()
            .defaultValue(20)
            .reconfigurable()
            .commit(),
        )

    @staticmethod
    def extend_output_schema(schema, name):
        """Declare the per-train result keys written below ``name``.

        Args:
            schema: Karabo output schema to extend in place.
            name: Key under which the kernel's results are placed.
        """
        (
            DOUBLE_ELEMENT(schema)
            .key(f"{name}.threshold")
            .assignmentOptional()
            .defaultValue(np.nan)
            .commit(),

            VECTOR_UINT32_ELEMENT(schema)
            .key(f"{name}.litpixelCount")
            .assignmentOptional()
            .defaultValue([])
            .commit(),

            VECTOR_FLOAT_ELEMENT(schema)
            .key(f"{name}.intensity")
            .assignmentOptional()
            .defaultValue([])
            .commit(),

            INT32_ELEMENT(schema)
            .key(f"{name}.numberOfHits")
            .assignmentOptional()
            .defaultValue(0)
            .commit(),

            INT32_ELEMENT(schema)
            .key(f"{name}.numberOfMiss")
            .assignmentOptional()
            .defaultValue(0)
            .commit(),

            DOUBLE_ELEMENT(schema)
            .key(f"{name}.hitrate")
            .assignmentOptional()
            .defaultValue(0.0)
            .commit(),

            DOUBLE_ELEMENT(schema)
            .key(f"{name}.averageHitrate")
            .assignmentOptional()
            .defaultValue(0.0)
            .commit(),
        )

    def reconfigure(self, config):
        """Apply the settings present in ``config``; others stay untouched.

        Changing ``minFrames`` or ``hitrateAverageWindow`` re-wraps the
        corresponding history deque with the new maximum length, keeping
        the most recent entries that still fit.
        """
        if config.has("scores"):
            self._scores = config["scores"]
        if config.has("modules"):
            self._modules = set(config.get("modules"))
        if config.has("snr"):
            self._snr = config.get("snr")
        if config.has("minFrames"):
            self._min_frames = config.get("minFrames")
            self._history_litpx = deque(self._history_litpx, self._min_frames)
            self._history_intens = deque(self._history_intens, self._min_frames)

        if config.has("thresholdMode"):
            mode = config.get("thresholdMode")
            # Bind the strategy once so consider() does not branch on it.
            self.get_threshold = (
                self._get_fixed_threshold if mode == "fixed"
                else self._get_adaptive_threshold
            )
        if config.has("fixedThreshold"):
            self._fixed_threshold = int(config.get("fixedThreshold"))
        if config.has("xgmNomalization"):
            self._xgm_normalization = config["xgmNomalization"]
        if config.has("suppressMaskedScores"):
            self._zero_to_masked = config["suppressMaskedScores"]
        if config.has("hitrateAverageWindow"):
            self._hitrate_window = config["hitrateAverageWindow"]
            self._history_hitrate = deque(self._history_hitrate, self._hitrate_window)

    def _output_defaults(self, num_frames, out_hash):
        """Write neutral results for a train without usable module data.

        The last known threshold is repeated, all counts are zero and no
        frame is marked as a hit.  ``intensity`` and ``averageHitrate`` are
        not written here.
        """
        decision = np.zeros(num_frames, dtype=bool)
        out_hash.set(f"{self._name}.threshold", self._last_threshold)
        out_hash.set(f"{self._name}.litpixelCount", np.zeros(num_frames, np.uint32))
        out_hash.set(f"{self._name}.numberOfHits", 0)
        out_hash.set(f"{self._name}.numberOfMiss", 0)
        out_hash.set(f"{self._name}.hits", decision)
        out_hash.set(f"{self._name}.hitrate", 0.0)

    def _get_fixed_threshold(self, num_frames, scores, history):
        """Return the configured fixed threshold (arguments are ignored)."""
        return self._fixed_threshold

    def _get_adaptive_threshold(self, num_frames, scores, history):
        """Return ``median + snr * sigma`` rounded to the nearest integer.

        Args:
            num_frames: Number of frames in the current train.
            scores: Scores of the unmasked frames of the current train.
            history: Rolling history of scores from recent trains.

        The quartiles come from the current train when it holds at least
        ``minFrames`` frames, otherwise from the history.
        """
        if num_frames >= self._min_frames:
            q1, mu, q3 = np.percentile(scores, [25, 50, 75])
        else:
            # Too few frames in this train for stable quartiles.
            q1, mu, q3 = np.percentile(history, [25, 50, 75])

        # The interquartile range of a normal distribution is 1.34896 sigma.
        sig = (q3 - q1) / 1.34896
        return int(mu + self._snr * sig + 0.5)

    def get_normalization_factor(self, num_frames, mask, xgm, npulse, xgm_ix):
        """Return per-frame factors that scale scores to the mean XGM intensity.

        Args:
            num_frames: Number of frames in the current train.
            mask: Unused; kept for interface compatibility.
            xgm: Per-pulse XGM intensities, or ``None`` if unavailable.
            npulse: Number of pulses per memory cell, or ``None``.
            xgm_ix: Indices into ``xgm`` of the pulses listed in ``npulse``,
                or ``None``.

        Returns:
            Array of ``mean_intensity / frame_intensity``.  All ones when any
            XGM input is missing, when there is at most one frame, or when
            fewer than two cells received X-rays.
        """
        norm_factor = np.ones(num_frames, float)
        if xgm is None or npulse is None or xgm_ix is None or num_frames <= 1:
            return norm_factor

        ncell = len(npulse)
        # Sum the pulse intensities belonging to each cell.
        cell_ix = np.repeat(np.arange(ncell), npulse).astype(int)
        inten = np.bincount(cell_ix, weights=xgm[xgm_ix], minlength=ncell)
        has_xray = npulse > 0
        if num_frames == np.sum(has_xray):
            # filtered: frames exist only for cells with pulses, so drop the
            # empty cells from both the intensities and the mask to keep
            # them the same length.
            inten = inten[has_xray]
            has_xray = has_xray[has_xray]
        if np.sum(has_xray) > 1:
            mean_inten = np.mean(inten[has_xray])
            # Weak or dark frames get the mean, i.e. a factor of one.
            inten[(inten < 50) | ~has_xray] = mean_inten
            norm_factor = mean_inten / inten

        return norm_factor

    @staticmethod
    def _per_unmasked_pixel(values, unmasked, valid):
        """Return ``values / unmasked`` where ``valid``, and 0.0 elsewhere."""
        # A zeroed output buffer keeps the skipped slots well defined;
        # without it np.divide leaves them uninitialised.
        return np.divide(
            values, unmasked, out=np.zeros(np.shape(values), float), where=valid
        )

    def consider(self, train_id, sources, num_frames, mask, out_hash):
        """Score the frames of one train and decide which ones are hits.

        Args:
            train_id: Identifier of the train (not used here).
            sources: Mapping of source name to a ``(data, _)`` pair, where
                ``data`` supports ``has(key)`` and ``data[key]``.
            num_frames: Number of frames in the train.
            mask: Frames eligible for selection; it is used as an index and
                negated with ``~``, so a boolean NumPy array of length
                ``num_frames`` is expected.
            out_hash: Hash receiving the results under ``self._name``.

        Returns:
            Boolean array marking the hits, or ``mask`` unchanged when none
            of the configured modules delivered ``litpixels`` data.

        Raises:
            KeyError: If the configured score type is unknown.
        """
        num_work_frames = np.sum(mask)

        xgm = None
        npulse = None
        xgm_ix = None
        nlitpx = []
        nwrkpx = []
        ntotpx = []
        intens = []
        for source, (data, _) in sources.items():
            # XGM inputs are looked up before the module filter below, so
            # any source may supply them.
            if data.has("data.intensityTD"):
                xgm = np.array(data["data.intensityTD"])
            if data.has("data.nPulsePerFrame") and data.has("data.xgmPulseId"):
                npulse = np.array(data["data.nPulsePerFrame"])
                xgm_ix = np.array(data["data.xgmPulseId"], dtype=int)
            modno = guess_module_number(source)
            if (modno < 0) or (modno not in self._modules):
                continue
            if data.has("litpixels"):
                nlitpx.append(data["litpixels.count"])
                nwrkpx.append(data["litpixels.unmasked"])
                ntotpx.append(data["litpixels.total"])
                intens.append(data["litpixels.intensity"])

        if self._xgm_normalization:
            # Falls back to ones if any XGM input is missing for this train.
            norm_factor = self.get_normalization_factor(
                num_frames, mask, xgm, npulse, xgm_ix)
        else:
            norm_factor = np.ones(num_frames, float)

        nmodules = len(nlitpx)
        if nmodules == 0:
            self._output_defaults(num_frames, out_hash)
            return mask

        # Stack the modules on axis 0 and sum over everything except axis 1.
        nlitpx = np.array(nlitpx)
        axes = (0,) + tuple(range(2, nlitpx.ndim))
        nlitpx = np.sum(nlitpx, axis=axes)
        intens = np.sum(np.array(intens), axis=axes)
        nwrkpx = np.sum(np.array(nwrkpx), axis=axes)
        ntotpx = np.sum(np.array(ntotpx), axis=axes)

        # Frames without any unmasked pixel cannot be rescaled; score them 0.
        # NOTE(review): `flag` was not defined in the reviewed source; this
        # definition is inferred from its use as the `where` of the division.
        flag = nwrkpx > 0

        # Extrapolate from the unmasked to the total number of pixels.
        nlitpx = (self._per_unmasked_pixel(nlitpx, nwrkpx, flag) *
                  ntotpx * norm_factor).astype(np.uint32)
        nlitpx[~flag] = 0

        intens = (self._per_unmasked_pixel(intens, nwrkpx, flag) *
                  ntotpx * norm_factor).astype(np.float32)
        intens[~flag] = 0

        self._history_litpx.extend(nlitpx[mask][-self._min_frames:])
        self._history_intens.extend(intens[mask][-self._min_frames:])

        if self._scores == "litpixels":
            scores = nlitpx
            history = self._history_litpx
        elif self._scores == "intensities":
            scores = intens
            history = self._history_intens
        else:
            raise KeyError("Unknown scores:", self._scores)

        if num_work_frames > 0:
            threshold = self.get_threshold(num_frames, scores[mask], history)
            self._last_threshold = threshold
        else:
            threshold = self._last_threshold

        # Only eligible frames can be hits, so hits + misses == work frames.
        # NOTE(review): the comparison was missing from the reviewed source;
        # strict ">" is assumed -- confirm against the upstream kernel.
        decision = (scores > threshold) & mask
        num_hits = int(np.sum(decision))
        num_miss = int(np.sum(~decision & mask))
        hitrate = num_hits / num_work_frames if num_work_frames > 0 else .0
        self._history_hitrate.append(hitrate)

        # Done after the decision so it only affects the reported counts.
        if self._zero_to_masked:
            nlitpx[~mask] = 0

        out_hash.set(f"{self._name}.threshold", threshold)
        out_hash.set(f"{self._name}.litpixelCount", list(map(int, nlitpx)))
        out_hash.set(f"{self._name}.intensity", list(map(float, intens)))
        out_hash.set(f"{self._name}.numberOfHits", num_hits)
        out_hash.set(f"{self._name}.numberOfMiss", num_miss)
        out_hash.set(f"{self._name}.hitrate", hitrate)
        out_hash.set(f"{self._name}.averageHitrate", np.mean(self._history_hitrate))

        return decision