From abfe5e34f2c609b7bf9d7a66d59ff4fb10730e82 Mon Sep 17 00:00:00 2001 From: Egor Sobolev <egor.sobolev@xfel.eu> Date: Thu, 17 Oct 2024 18:53:07 +0200 Subject: [PATCH] Add next generation SPI hitfinder --- setup.py | 1 + src/sfx_addons/spi_hitfinder.py | 253 ++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+) create mode 100644 src/sfx_addons/spi_hitfinder.py diff --git a/setup.py b/setup.py index de2f1cd..2e9c177 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ setup( entry_points={ "calng.arbiter_kernel": [ "HitFinderSPI = sfx_addons.hitfinder_spi:HitFinderSPI", + "SPIhitfinder = sfx_addons.spi_hitfinder:SPIhitfinder", ], }, requires=[], diff --git a/src/sfx_addons/spi_hitfinder.py b/src/sfx_addons/spi_hitfinder.py new file mode 100644 index 0000000..518ef55 --- /dev/null +++ b/src/sfx_addons/spi_hitfinder.py @@ -0,0 +1,253 @@ +from collections import deque + +import numpy as np +from calng.arbiter_kernels.base_kernel import BaseArbiterKernel +from karabo.bound import ( + BOOL_ELEMENT, DOUBLE_ELEMENT, INT32_ELEMENT, + OVERWRITE_ELEMENT, STRING_ELEMENT, + VECTOR_UINT32_ELEMENT) + + +def module_number(source): + token = source.split("/")[-1] + return int(token[:-8]) if token.endswith("CH0:xtdf") else None + + +class SPIhitfinder(BaseArbiterKernel): + _node_name = "spiHitfinder" + + def __init__(self, device, name, config): + self._history = deque([], 100) + super().__init__(device, name, config) + + self._nlitpx_key = "litpixels.count" + self._nwrkpx_key = "litpixels.unmasked" + self._last_threshold = 0.0 + + @staticmethod + def extend_device_schema(schema, prefix): + ( + OVERWRITE_ELEMENT(schema) + .key(prefix) + .setNewDescription( + "This kernel selects the frames by comparing the number" + "of lit-pixels to the threshold." + ) + .commit(), + + VECTOR_UINT32_ELEMENT(schema) + .key(f"{prefix}.modules") + .assignmentOptional() + .defaultValue([3, 4, 8, 15]) + .reconfigurable() + .commit(), + + STRING_ELEMENT(schema) + .key(f"{prefix}.thresholdMode") + .options("fixed,adaptive") + .assignmentOptional() + .defaultValue("adaptive") + .reconfigurable() + .commit(), + + INT32_ELEMENT(schema) + .key(f"{prefix}.fixedThreshold") + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{prefix}.snr") + .assignmentOptional() + .defaultValue(3.5) + .reconfigurable() + .commit(), + + INT32_ELEMENT(schema) + .key(f"{prefix}.minFrames") + .assignmentOptional() + .defaultValue(100) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{prefix}.xgmNomalization") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{prefix}.suppressMaskedScores") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + ) + + @staticmethod + def extend_output_schema(schema, name): + ( + DOUBLE_ELEMENT(schema) + .key(f"{name}.threshold") + .assignmentOptional() + .defaultValue(np.nan) + .commit(), + + VECTOR_UINT32_ELEMENT(schema) + .key(f"{name}.litpixelCount") + .assignmentOptional() + .defaultValue([]) + .commit(), + + INT32_ELEMENT(schema) + .key(f"{name}.numberOfHits") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(schema) + .key(f"{name}.numberOfMiss") + .assignmentOptional() + .defaultValue(0) + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{name}.hitrate") + .assignmentOptional() + .defaultValue(0.0) + .commit(), + ) + + def reconfigure(self, config): + if config.has("modules"): + self._modules = set(config.get("modules")) + if config.has("snr"): + self._snr = config.get("snr") + if config.has("minFrames"): + self._min_frames = config.get("minFrames") + self._history = deque(self._history, self._min_frames) + + if config.has("thresholdMode"): + mode = config.get("thresholdMode") + self.get_threshold = ( + self._get_fixed_threshold if mode == "fixed" + else self._get_adaptive_threshold + ) + if config.has("fixedThreshold"): + self._fixed_threshold = int(config.get("fixedThreshold")) + + if config.has("xgmNomalization"): + self._xgm_normalization = config["xgmNomalization"] + + if config.has("suppressMaskedScores"): + self._zero_to_masked = config["suppressMaskedScores"] + + def _output_defaults(self, num_frames, out_hash): + decision = np.zeros(num_frames, dtype=bool) + out_hash.set(f"{self._name}.threshold", self._last_threshold) + out_hash.set(f"{self._name}.litpixelCount", np.zeros(num_frames, np.uint32)) + out_hash.set(f"{self._name}.numberOfHits", 0) + out_hash.set(f"{self._name}.numberOfMiss", 0) + out_hash.set(f"{self._name}.hits", decision) + out_hash.set(f"{self._name}.hitrate", 0.0) + + def _get_fixed_threshold(self, num_frames, nlitpx): + return self._fixed_threshold + + def _get_adaptive_threshold(self, num_frames, nlitpx): + if num_frames >= self._min_frames: + q1, mu, q3 = np.percentile(nlitpx, [25, 50, 75]) + else: + q1, mu, q3 = np.percentile(self._history, [25, 50, 75]) + + sig = (q3 - q1) / 1.34896 + return int(mu + self._snr * sig + 0.5) + + def get_normalization_factor(self, num_frames, mask, xgm, npulse, xgm_ix): + norm_factor = np.ones(num_frames, float) + if xgm is None or num_frames <= 1: + return norm_factor + + ncell = len(npulse) + cell_ix = np.repeat(np.arange(ncell), npulse).astype(int) + inten = np.bincount(cell_ix, weights=xgm[xgm_ix], minlength=ncell) + has_xray = npulse > 0 + if num_frames == np.sum(has_xray): + # filtered + inten = inten[has_xray] + if np.sum(has_xray) > 1: + mean_inten = np.mean(inten[has_xray]) + inten[(inten < 50) | ~has_xray] = mean_inten + norm_factor = mean_inten / inten + + return norm_factor + + def consider(self, train_id, sources, num_frames, mask, out_hash): + num_work_frames = np.sum(mask) + + xgm = None + nlitpx = [] + nwrkpx = [] + for source, (data, _) in sources.items(): + if data.has("data.intensityTD"): + xgm = np.array(data["data.intensityTD"]) + if data.has("data.nPulsePerFrame") and data.has("data.xgmPulseId"): + npulse = np.array(data["data.nPulsePerFrame"]) + xgm_ix = np.array(data["data.xgmPulseId"], dtype=int) + modno = module_number(source) + if modno not in self._modules: + continue + if data.has(self._nlitpx_key) & data.has(self._nwrkpx_key): + nlitpx.append(data[self._nlitpx_key]) + nwrkpx.append(data[self._nwrkpx_key]) + + if self._xgm_normalization: + norm_factor = self.get_normalization_factor( + num_frames, mask, xgm, npulse, xgm_ix) + else: + norm_factor = np.ones(num_frames, float) + + nmodules = len(nlitpx) + if nmodules == 0: + self._output_defaults(num_frames, out_hash) + return mask + + nlitpx = np.sum(np.array(nlitpx), axis=0) + nwrkpx = np.sum(np.array(nwrkpx), axis=0) + + if nlitpx.ndim > 1: + axes = tuple(range(1, nlitpx.ndim)) + nlitpx = np.sum(nlitpx, axis=axes) + nwrkpx = np.sum(nwrkpx, axis=axes) + + ntotpx = 65536 * nmodules + + flag = nwrkpx > 256 + nlitpx = (np.divide(nlitpx, nwrkpx, where=flag) * + ntotpx * norm_factor).astype(np.uint32) + nlitpx[~flag] = 0 + + self._history.extend(nlitpx[mask][-self._min_frames:]) + + if num_work_frames > 0: + threshold = self.get_threshold(num_frames, nlitpx[mask]) + else: + threshold = self._last_threshold + self._last_threshold = threshold + decision = (nlitpx > threshold) & mask + num_hits = int(np.sum(decision)) + num_miss = int(np.sum(~decision & mask)) + hitrate = num_hits / num_work_frames if num_work_frames > 0 else .0 + + if self._zero_to_masked: + nlitpx[~mask] = 0 + + out_hash.set(f"{self._name}.threshold", threshold) + out_hash.set(f"{self._name}.litpixelCount", list(map(int, nlitpx))) + out_hash.set(f"{self._name}.numberOfHits", num_hits) + out_hash.set(f"{self._name}.numberOfMiss", num_miss) + out_hash.set(f"{self._name}.hitrate", hitrate) + + return decision -- GitLab