From abfe5e34f2c609b7bf9d7a66d59ff4fb10730e82 Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Thu, 17 Oct 2024 18:53:07 +0200
Subject: [PATCH] Add next generation SPI hitfinder

---
 setup.py                        |   1 +
 src/sfx_addons/spi_hitfinder.py | 253 ++++++++++++++++++++++++++++++++
 2 files changed, 254 insertions(+)
 create mode 100644 src/sfx_addons/spi_hitfinder.py

diff --git a/setup.py b/setup.py
index de2f1cd..2e9c177 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@ setup(
     entry_points={
         "calng.arbiter_kernel": [
             "HitFinderSPI = sfx_addons.hitfinder_spi:HitFinderSPI",
+            "SPIhitfinder = sfx_addons.spi_hitfinder:SPIhitfinder",
         ],
     },
     requires=[],
diff --git a/src/sfx_addons/spi_hitfinder.py b/src/sfx_addons/spi_hitfinder.py
new file mode 100644
index 0000000..518ef55
--- /dev/null
+++ b/src/sfx_addons/spi_hitfinder.py
@@ -0,0 +1,334 @@
+from collections import deque
+
+import numpy as np
+from calng.arbiter_kernels.base_kernel import BaseArbiterKernel
+from karabo.bound import (
+    BOOL_ELEMENT, DOUBLE_ELEMENT, INT32_ELEMENT,
+    OVERWRITE_ELEMENT, STRING_ELEMENT,
+    VECTOR_UINT32_ELEMENT)
+
+
+def module_number(source):
+    """Extract the module number from a detector source name.
+
+    The last "/"-separated token of `source` must end with "CH0:xtdf";
+    the text in front of that suffix is parsed as an integer.
+
+    Returns the module number, or None if the token has another suffix
+    or the text in front of the suffix is not an integer.
+    """
+    suffix = "CH0:xtdf"
+    token = source.rsplit("/", 1)[-1]
+    if not token.endswith(suffix):
+        return None
+    try:
+        return int(token[:-len(suffix)])
+    except ValueError:
+        # Not a numbered module channel: report "no module" rather than
+        # aborting the caller's loop over all sources.
+        return None
+
+
+class SPIhitfinder(BaseArbiterKernel):
+    """Arbiter kernel selecting frames by their lit-pixel count.
+
+    The lit-pixel counts of the configured modules are summed per frame,
+    rescaled to the full pixel count and compared with either a fixed or
+    an adaptive (median + snr * sigma) threshold.
+    """
+
+    _node_name = "spiHitfinder"
+
+    # Approximate interquartile range of a normal distribution, in sigmas.
+    _IQR_PER_SIGMA = 1.34896
+    # Pixel count attributed to each module when rescaling the scores.
+    _PIXELS_PER_MODULE = 65536
+    # Frames with no more unmasked pixels than this get a zero score.
+    _MIN_UNMASKED_PIXELS = 256
+
+    def __init__(self, device, name, config):
+        # Created before the base initializer runs because `reconfigure`
+        # reads it; presumably the base class applies `config` there.
+        self._history = deque([], 100)
+        super().__init__(device, name, config)
+
+        self._nlitpx_key = "litpixels.count"
+        self._nwrkpx_key = "litpixels.unmasked"
+        self._last_threshold = 0.0
+
+    @staticmethod
+    def extend_device_schema(schema, prefix):
+        """Add the kernel's configuration parameters below `prefix`."""
+        (
+            OVERWRITE_ELEMENT(schema)
+            .key(prefix)
+            .setNewDescription(
+                "This kernel selects the frames by comparing the number "
+                "of lit-pixels to the threshold."
+            )
+            .commit(),
+
+            VECTOR_UINT32_ELEMENT(schema)
+            .key(f"{prefix}.modules")
+            .assignmentOptional()
+            .defaultValue([3, 4, 8, 15])
+            .reconfigurable()
+            .commit(),
+
+            STRING_ELEMENT(schema)
+            .key(f"{prefix}.thresholdMode")
+            .options("fixed,adaptive")
+            .assignmentOptional()
+            .defaultValue("adaptive")
+            .reconfigurable()
+            .commit(),
+
+            INT32_ELEMENT(schema)
+            .key(f"{prefix}.fixedThreshold")
+            .assignmentOptional()
+            .defaultValue(0)
+            .reconfigurable()
+            .commit(),
+
+            DOUBLE_ELEMENT(schema)
+            .key(f"{prefix}.snr")
+            .assignmentOptional()
+            .defaultValue(3.5)
+            .reconfigurable()
+            .commit(),
+
+            INT32_ELEMENT(schema)
+            .key(f"{prefix}.minFrames")
+            .assignmentOptional()
+            .defaultValue(100)
+            .reconfigurable()
+            .commit(),
+
+            # NOTE: the key is misspelled ("Nomalization"); it is left
+            # as is because renaming it would change the device schema.
+            BOOL_ELEMENT(schema)
+            .key(f"{prefix}.xgmNomalization")
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key(f"{prefix}.suppressMaskedScores")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
+
+    @staticmethod
+    def extend_output_schema(schema, name):
+        """Add the values this kernel publishes below `name`."""
+        (
+            DOUBLE_ELEMENT(schema)
+            .key(f"{name}.threshold")
+            .assignmentOptional()
+            .defaultValue(np.nan)
+            .commit(),
+
+            VECTOR_UINT32_ELEMENT(schema)
+            .key(f"{name}.litpixelCount")
+            .assignmentOptional()
+            .defaultValue([])
+            .commit(),
+
+            INT32_ELEMENT(schema)
+            .key(f"{name}.numberOfHits")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+
+            INT32_ELEMENT(schema)
+            .key(f"{name}.numberOfMiss")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+
+            DOUBLE_ELEMENT(schema)
+            .key(f"{name}.hitrate")
+            .assignmentOptional()
+            .defaultValue(0.0)
+            .commit(),
+        )
+
+    def reconfigure(self, config):
+        """Apply the settings present in `config`; others are kept."""
+        if config.has("modules"):
+            self._modules = set(config.get("modules"))
+        if config.has("snr"):
+            self._snr = config.get("snr")
+        if config.has("minFrames"):
+            self._min_frames = config.get("minFrames")
+            # Rebuild the history with the new capacity; a smaller
+            # capacity keeps only the most recent entries.
+            self._history = deque(self._history, self._min_frames)
+
+        if config.has("thresholdMode"):
+            mode = config.get("thresholdMode")
+            self.get_threshold = (
+                self._get_fixed_threshold if mode == "fixed"
+                else self._get_adaptive_threshold
+            )
+        if config.has("fixedThreshold"):
+            self._fixed_threshold = int(config.get("fixedThreshold"))
+
+        if config.has("xgmNomalization"):
+            self._xgm_normalization = config.get("xgmNomalization")
+
+        if config.has("suppressMaskedScores"):
+            self._zero_to_masked = config.get("suppressMaskedScores")
+
+    def _output_defaults(self, num_frames, out_hash):
+        """Fill `out_hash` with the no-hit output for `num_frames` frames."""
+        decision = np.zeros(num_frames, dtype=bool)
+        # The schema declares the threshold as a double.
+        out_hash.set(f"{self._name}.threshold", float(self._last_threshold))
+        out_hash.set(f"{self._name}.litpixelCount", np.zeros(num_frames, np.uint32))
+        out_hash.set(f"{self._name}.numberOfHits", 0)
+        out_hash.set(f"{self._name}.numberOfMiss", 0)
+        out_hash.set(f"{self._name}.hits", decision)
+        out_hash.set(f"{self._name}.hitrate", 0.0)
+
+    def _get_fixed_threshold(self, num_frames, nlitpx):
+        """Return the configured fixed threshold; arguments are unused."""
+        return self._fixed_threshold
+
+    def _get_adaptive_threshold(self, num_frames, nlitpx):
+        """Return median + snr * sigma of the scores, rounded half up.
+
+        The statistics are taken from `nlitpx` if the train has at least
+        `minFrames` frames and from the history of recent scores
+        otherwise. Sigma is derived from the interquartile range.
+        """
+        if num_frames >= self._min_frames:
+            samples = nlitpx
+        else:
+            samples = self._history
+        q1, mu, q3 = np.percentile(samples, [25, 50, 75])
+
+        sig = (q3 - q1) / self._IQR_PER_SIGMA
+        return int(mu + self._snr * sig + 0.5)
+
+    def get_normalization_factor(self, num_frames, mask, xgm, npulse, xgm_ix):
+        """Return per-frame factors that even out the pulse intensity.
+
+        The intensities `xgm[xgm_ix]` are summed per cell according to
+        `npulse`; a frame's factor is the mean intensity divided by its
+        own. Cells without pulses or with an intensity below 50 get a
+        factor of one.
+
+        An array of ones is returned if any XGM input is None, if there
+        is at most one frame, if the mean intensity is not positive, or
+        if the intensities cannot be matched to `num_frames` frames.
+        `mask` is not used.
+        """
+        unity = np.ones(num_frames, float)
+        if num_frames <= 1 or xgm is None or npulse is None or xgm_ix is None:
+            return unity
+
+        npulse = np.asarray(npulse)
+        ncell = len(npulse)
+        cell_ix = np.repeat(np.arange(ncell), npulse).astype(int)
+        inten = np.bincount(cell_ix, weights=xgm[xgm_ix], minlength=ncell)
+        has_xray = npulse > 0
+        if num_frames == np.sum(has_xray):
+            # As many frames as cells with pulses: presumably the frames
+            # are already filtered to those cells. Narrow the flags along
+            # with the intensities so that both keep the same length.
+            inten = inten[has_xray]
+            has_xray = has_xray[has_xray]
+        if len(inten) != num_frames or np.sum(has_xray) <= 1:
+            return unity
+
+        mean_inten = np.mean(inten[has_xray])
+        if not mean_inten > 0:
+            # Nothing sensible to normalize to; also avoids 0 / 0.
+            return unity
+        inten[(inten < 50) | ~has_xray] = mean_inten
+        return mean_inten / inten
+
+    def consider(self, train_id, sources, num_frames, mask, out_hash):
+        """Score the frames of one train and decide which are hits.
+
+        Threshold, per-frame scores, hit and miss counts and the hit rate
+        are written to `out_hash`. Returns the boolean per-frame decision,
+        or `mask` itself (after writing default outputs) if none of the
+        selected modules provided lit-pixel data.
+        """
+        num_work_frames = int(np.sum(mask))
+
+        # The XGM inputs stay None unless some source provides them.
+        xgm = npulse = xgm_ix = None
+        nlitpx = []
+        nwrkpx = []
+        for source, (data, _) in sources.items():
+            if data.has("data.intensityTD"):
+                xgm = np.array(data["data.intensityTD"])
+            if data.has("data.nPulsePerFrame") and data.has("data.xgmPulseId"):
+                npulse = np.array(data["data.nPulsePerFrame"])
+                xgm_ix = np.array(data["data.xgmPulseId"], dtype=int)
+            modno = module_number(source)
+            if modno not in self._modules:
+                continue
+            if data.has(self._nlitpx_key) and data.has(self._nwrkpx_key):
+                nlitpx.append(data[self._nlitpx_key])
+                nwrkpx.append(data[self._nwrkpx_key])
+
+        nmodules = len(nlitpx)
+        if nmodules == 0:
+            self._output_defaults(num_frames, out_hash)
+            return mask
+
+        if self._xgm_normalization:
+            norm_factor = self.get_normalization_factor(
+                num_frames, mask, xgm, npulse, xgm_ix)
+        else:
+            norm_factor = np.ones(num_frames, float)
+
+        nlitpx = np.sum(np.array(nlitpx), axis=0)
+        nwrkpx = np.sum(np.array(nwrkpx), axis=0)
+
+        # Sum over any trailing axes, leaving only the first one.
+        if nlitpx.ndim > 1:
+            axes = tuple(range(1, nlitpx.ndim))
+            nlitpx = np.sum(nlitpx, axis=axes)
+            nwrkpx = np.sum(nwrkpx, axis=axes)
+
+        ntotpx = self._PIXELS_PER_MODULE * nmodules
+
+        flag = nwrkpx > self._MIN_UNMASKED_PIXELS
+        # With `where` alone np.divide leaves the skipped elements
+        # uninitialized, so a zero-filled `out` is supplied.
+        ratio = np.divide(
+            nlitpx, nwrkpx, out=np.zeros(flag.shape, float), where=flag)
+        nlitpx = (ratio * ntotpx * norm_factor).astype(np.uint32)
+        nlitpx[~flag] = 0
+
+        self._history.extend(nlitpx[mask][-self._min_frames:])
+
+        if num_work_frames > 0:
+            threshold = self.get_threshold(num_frames, nlitpx[mask])
+        else:
+            threshold = self._last_threshold
+        self._last_threshold = threshold
+        decision = (nlitpx > threshold) & mask
+        num_hits = int(np.sum(decision))
+        num_miss = int(np.sum(~decision & mask))
+        hitrate = num_hits / num_work_frames if num_work_frames > 0 else .0
+
+        if self._zero_to_masked:
+            nlitpx[~mask] = 0
+
+        # The schema declares the threshold as a double.
+        out_hash.set(f"{self._name}.threshold", float(threshold))
+        out_hash.set(f"{self._name}.litpixelCount", list(map(int, nlitpx)))
+        out_hash.set(f"{self._name}.numberOfHits", num_hits)
+        out_hash.set(f"{self._name}.numberOfMiss", num_miss)
+        out_hash.set(f"{self._name}.hitrate", hitrate)
+
+        return decision
-- 
GitLab