Skip to content
Snippets Groups Projects
hitfinder_spi.py 9.26 KiB
Newer Older
import numpy as np

from karabo.bound import (
    Hash,
    DOUBLE_ELEMENT,
    INT32_ELEMENT,
    STRING_ELEMENT,
    BOOL_ELEMENT,
)

from calng.arbiter_kernels.base_kernel import BaseArbiterKernel


class HitFinderSPI(BaseArbiterKernel):
    """Arbiter kernel flagging frames as hits by an SPI and/or SFX criterion.

    * SPI (``SPI`` option): thresholds the per-frame lit-pixel count of the
      selected modules, normalized by the number of unmasked pixels. The
      threshold is ``absoluteThreshold`` or, with ``useAdaptiveThreshold``,
      the larger of that and ``median + sigmaLevel * sigma`` estimated over a
      rolling history of recent values.
    * SFX (``SFX`` option): counts peaks from peak-finding output, optionally
      restricted to a radial band (``minRadius``/``maxRadius``) when pixel
      positions are available, and flags frames with at least ``minPeaks``.

    The reported frame pattern is the OR of the enabled criteria.
    """

    _node_name = "hitFinderSPI"

    # Pixels per module, used to scale the lit fraction back to a count.
    _PIXELS_PER_MODULE = 65536

    def reconfigure(self, config):
        """Apply the keys present in ``config``; absent keys keep their value.

        Note: automatically called in ``super().__init__``.
        """
        if config.has("SPI"):
            self._use_spi = config.get("SPI")
        if config.has("modules"):
            # NOTE(review): the ``modules`` string is evaluated as Python code
            # inside ``np.r_[...]`` (e.g. ":16" -> modules 0..15). Anyone who
            # can set this configuration value can execute arbitrary code
            # here; consider replacing eval with a dedicated parser.
            self._modules = set(eval("np.r_[{}]".format(config.get("modules"))))
        if config.has("absoluteThreshold"):
            self._absolute_threshold = config.get("absoluteThreshold")
        if config.has("useAdaptiveThreshold"):
            self._use_adaptive_threshold = config.get("useAdaptiveThreshold")
        if config.has("sigmaLevel"):
            self._sigma_level = config.get("sigmaLevel")
        if config.has("maxHistoryLength"):
            self._max_history_length = config.get("maxHistoryLength")
            # Changing the history length discards the accumulated history.
            self._cur_history_length = 0
            # float: the history stores normalized (fractional) counts; an
            # int array would silently truncate them.
            self._history = np.zeros(self._max_history_length, dtype=float)
        if config.has("SFX"):
            self._use_sfx = config.get("SFX")
        if config.has("minPeaks"):
            self._min_peaks = config.get("minPeaks")
        if config.has("minRadius"):
            self._min_r = config.get("minRadius")
        if config.has("maxRadius"):
            self._max_r = config.get("maxRadius")

    @staticmethod
    def extend_device_schema(schema, prefix):
        """Declare this kernel's reconfigurable parameters under ``prefix``."""
        (
            BOOL_ELEMENT(schema)
            .key(f"{prefix}.SPI")
            .assignmentOptional()
            .defaultValue(True)
            .reconfigurable()
            .commit(),

            STRING_ELEMENT(schema)
            .key(f"{prefix}.modules")
            .assignmentOptional()
            .defaultValue(":16")
            .reconfigurable()
            .commit(),

            DOUBLE_ELEMENT(schema)
            .key(f"{prefix}.absoluteThreshold")
            .assignmentOptional()
            .defaultValue(240.0)
            .reconfigurable()
            .commit(),

            BOOL_ELEMENT(schema)
            .key(f"{prefix}.useAdaptiveThreshold")
            .assignmentOptional()
            .defaultValue(True)
            .reconfigurable()
            .commit(),

            DOUBLE_ELEMENT(schema)
            .key(f"{prefix}.sigmaLevel")
            .assignmentOptional()
            .defaultValue(4.0)
            .reconfigurable()
            .commit(),

            INT32_ELEMENT(schema)
            .key(f"{prefix}.maxHistoryLength")
            .assignmentOptional()
            .defaultValue(200)
            .reconfigurable()
            .commit(),

            BOOL_ELEMENT(schema)
            .key(f"{prefix}.SFX")
            .assignmentOptional()
            .defaultValue(True)
            .reconfigurable()
            .commit(),

            INT32_ELEMENT(schema)
            .key(f"{prefix}.minPeaks")
            .assignmentOptional()
            .defaultValue(10)
            .reconfigurable()
            .commit(),

            DOUBLE_ELEMENT(schema)
            .key(f"{prefix}.minRadius")
            .assignmentOptional()
            .defaultValue(50.0)
            .reconfigurable()
            .commit(),

            DOUBLE_ELEMENT(schema)
            .key(f"{prefix}.maxRadius")
            .assignmentOptional()
            .defaultValue(700.0)
            .reconfigurable()
            .commit(),
        )

    @property
    def _pixel_pos(self):
        """Pixel positions of the current geometry in pixel units, as ints.

        Returns None when no geometry is set (attribute missing or None), so
        callers can skip geometry-dependent filtering instead of crashing.
        """
        # TODO: cache -- this rebuilds the full position array on every access.
        geometry = getattr(self, "geometry", None)
        if geometry is None:
            return None
        return (geometry.get_pixel_positions() / geometry.pixel_size).astype(int)

    @staticmethod
    def _module_number(source):
        """Parse the module number from a source name.

        Takes the last ``/``-separated component and drops its final 8
        characters before converting to int.
        NOTE(review): assumes names like ``.../15CH0:xtdf`` -- confirm.
        """
        return int(source.split("/")[-1][:-8])

    def consider(self, train_id, sources, num_frames, mask, out_hash):
        """Compute the per-frame hit pattern for one train.

        ``mask`` and ``out_hash`` are accepted for interface compatibility and
        are not used here.

        Returns a Hash with ``data.dataFramePattern`` (boolean per frame: OR
        of the enabled SPI/SFX criteria) merged with each criterion's output.
        """
        has_xray = self.get_litframe_pattern(train_id, sources, num_frames)

        result = Hash()
        hits = np.zeros_like(has_xray)

        if self._use_spi:
            hits_spi, result_spi = self.spi_hitfinder(
                train_id, sources, num_frames, has_xray
            )
            hits = hits | hits_spi
            result.merge(result_spi)

        if self._use_sfx:
            hits_sfx, result_sfx = self.sfx_hitfinder(
                train_id, sources, num_frames, has_xray
            )
            hits = hits | hits_sfx
            result.merge(result_sfx)

        result["data.dataFramePattern"] = hits

        return result

    def get_litframe_pattern(self, train_id, sources, num_frames):
        """Return a boolean array marking frames that received X-rays.

        Uses the first source whose ``data.nPulsePerFrame`` has exactly
        ``num_frames`` entries (frames with more than 0 pulses are lit).
        Sources with a different length are skipped with a warning. If no
        source qualifies, every frame is considered lit.
        """
        has_xray = np.ones(num_frames, dtype=bool)
        for source, (data, _) in sources.items():
            if not data.has("data.nPulsePerFrame"):
                continue

            lff_data = np.array(data["data.nPulsePerFrame"])
            if len(lff_data) == num_frames:
                has_xray = lff_data > 0
                break
            else:
                self.log.WARN("Ignoring LFF data of different length")

        return has_xray

    def spi_hitfinder(self, train_id, sources, num_frames, has_xray):
        """Flag frames whose normalized lit-pixel count exceeds a threshold.

        For each selected module (``self._modules``) with non-empty data,
        ``litpixels.count`` and ``litpixels.unmasked`` are summed over axes 1
        and 2. The per-frame value is ``lit / unmasked * (65536 * n_modules)``.
        Frames without X-rays, or with 256 or fewer unmasked pixels, get 0 and
        are never hits.

        Returns ``(hits, result)``:
        * ``hits`` is a boolean per-frame array.
        * ``result`` is a Hash with ``litpixels.normalizedCount``,
          ``hitfinder.litpixelThreshold`` (-1.0 when no frame is usable) and
          ``hitfinder.spiHits``.
        """
        num_lit = 0
        num_working = 0
        num_total = 0

        for source, (data, _) in sources.items():
            if not data.has("litpixels.count"):
                continue
            # AGIPD data.
            lit_count = data["litpixels.count"]
            modno = self._module_number(source)
            # Skip unselected modules and empty payloads: an empty payload
            # cannot be combined per frame with the other modules, and it
            # must not inflate the normalization either.
            if lit_count.size == 0 or modno not in self._modules:
                continue

            num_lit += np.sum(lit_count, axis=(1, 2))
            num_working += np.sum(data["litpixels.unmasked"], axis=(1, 2))
            num_total += self._PIXELS_PER_MODULE

        # Only frames with X-rays and enough unmasked pixels are usable.
        flag = has_xray & (num_working > 256)
        # Divide into a zeroed buffer: unusable frames stay exactly 0 rather
        # than holding uninitialized values.
        num_normalized = np.zeros(num_frames, dtype=float)
        np.divide(num_lit, num_working, out=num_normalized, where=flag)
        num_normalized *= num_total

        result = Hash(
            "litpixels.normalizedCount",
            num_normalized,
            "hitfinder.litpixelThreshold",
            -1.0,
        )

        num_good_frames = np.sum(flag)
        if num_good_frames == 0:
            hits = np.zeros(num_frames, dtype=bool)
            result["hitfinder.spiHits"] = hits
            return hits, result

        threshold = self._absolute_threshold

        if self._use_adaptive_threshold:
            good_counts = num_normalized[flag]
            # Newest values go to the front of the rolling history.
            self._history = np.roll(self._history, num_good_frames)
            self._history[:num_good_frames] = good_counts[: self._max_history_length]
            self._cur_history_length = min(
                self._cur_history_length + num_good_frames, self._max_history_length
            )

            # If this train alone exceeds the history capacity, use all of it.
            sample = (
                good_counts
                if num_good_frames > self._cur_history_length
                else self._history[: self._cur_history_length]
            )
            q1, mu, q3 = np.percentile(sample, [25, 50, 75])
            # Robust sigma: the interquartile range of a normal distribution
            # is about 1.349 sigma.
            sigma = (q3 - q1) / 1.34896
            threshold = max(threshold, self._sigma_level * sigma + mu)

        hits = flag & (num_normalized > threshold)
        result["hitfinder.litpixelThreshold"] = threshold
        result["hitfinder.spiHits"] = hits

        return hits, result

    def _sfx_result(self, num_peaks, x, y, module, intensity, hits):
        """Pack the SFX hit-finder output and its settings into a Hash."""
        result = Hash(
            "peakfinder.numPeaks",
            num_peaks,
            "peakfinder.peakX",
            x,
            "peakfinder.peakY",
            y,
            "peakfinder.peakModule",
            module,
            "peakfinder.peakIntensity",
            intensity,
            "hitfinder.sfxHits",
            hits,
        )

        result["hitfinder.peakNumberThreshold"] = self._min_peaks
        result["hitfinder.minRadius"] = self._min_r
        result["hitfinder.maxRadius"] = self._max_r
        return result

    def sfx_hitfinder(self, train_id, sources, num_frames, has_xray):
        """Flag frames containing at least ``minPeaks`` peaks.

        Peak-finding output (``peakfinding.*``) from every source providing
        it is stacked along a new module axis. Only the first ``numPeaks``
        entries per frame and module are kept, and peaks in frames without
        X-rays are dropped. When pixel positions are available, peaks whose
        radius from the geometry origin (in pixel units) lies outside
        ``(minRadius, maxRadius)`` are also dropped.

        Returns ``(hits, result)``:
        * ``hits`` is a boolean per-frame array.
        * ``result`` is a Hash with the surviving peaks (flattened), the
          per-frame peak counts and the hit-finder settings.

        If no source provides peak-finding output, no frame is a hit.
        """
        num_peaks, intensity, x, y = [], [], [], []
        modules = []
        for source, (data, _) in sources.items():
            if not data.has("peakfinding.numPeaks"):
                continue
            # AGIPD data.
            modules.append(self._module_number(source))
            num_peaks.append(data["peakfinding.numPeaks"])
            intensity.append(data["peakfinding.peakIntensity"])
            x.append(data["peakfinding.peakX"])
            y.append(data["peakfinding.peakY"])

        if not modules:
            # Nothing to stack (np.stack rejects an empty list): no peaks, no hits.
            hits = np.zeros(num_frames, dtype=bool)
            no_coords = np.zeros(0, dtype=float)
            return hits, self._sfx_result(
                np.zeros(num_frames, dtype=int),
                no_coords,
                no_coords,
                np.zeros(0, dtype=int),
                no_coords,
                hits,
            )

        num_peaks = np.stack(num_peaks, axis=1)
        intensity = np.stack(intensity, axis=1)
        x = np.stack(x, axis=1)
        y = np.stack(y, axis=1)

        # ncell, nmod, maxpeak
        max_peaks = intensity.shape[-1]
        module = np.tile(np.array(modules)[None, :, None], [num_frames, 1, max_peaks])
        # Frames without X-rays contribute no peaks.
        num_peaks[~has_xray, :] = 0
        # True for the first num_peaks[frame, module] slots of each peak list.
        mask = np.arange(max_peaks, dtype=int)[None, None, :] < num_peaks[..., None]

        x = x[mask]
        y = y[mask]
        module = module[mask]
        intensity = intensity[mask]

        # Evaluate once: the property rebuilds the position array each access.
        pixel_pos = self._pixel_pos
        if pixel_pos is not None:
            xi = np.clip(np.round(x).astype(int), 0, 511)
            yi = np.clip(np.round(y).astype(int), 0, 127)
            xc = pixel_pos[module, xi, yi, 0]
            yc = pixel_pos[module, xi, yi, 1]
            r = np.sqrt(xc * xc + yc * yc)
            radius_flag = (self._min_r < r) & (r < self._max_r)

            x = x[radius_flag]
            y = y[radius_flag]
            module = module[radius_flag]
            intensity = intensity[radius_flag]

            # Remove rejected peaks from the per-frame mask too, so that the
            # per-frame counts below match the surviving peaks.
            mask[mask] = radius_flag

        num_peaks = np.sum(mask, axis=(1, 2))

        hits = has_xray & (num_peaks >= self._min_peaks)
        return hits, self._sfx_result(num_peaks, x, y, module, intensity, hits)