From 01998b0f26ec856380b77687d3d577b888e527f3 Mon Sep 17 00:00:00 2001
From: David Hammer <david.hammer@xfel.eu>
Date: Fri, 2 Aug 2024 09:50:39 +0200
Subject: [PATCH] Saturation monitor addon with own output channel

---
 src/calng/SaturationWarningAggregator.py      |  66 +-------
 src/calng/base_correction.py                  |  24 +--
 src/calng/correction_addons/base_addon.py     |   8 +-
 .../correction_addons/integrated_intensity.py |   2 +-
 .../correction_addons/litpixel_counter.py     |   6 +-
 src/calng/correction_addons/peakfinder9.py    |   8 +-
 src/calng/correction_addons/random_frames.py  |   2 +-
 .../correction_addons/saturation_monitor.py   | 150 +++++++++++-------
 src/calng/corrections/AgipdCorrection.py      |   4 +-
 9 files changed, 128 insertions(+), 142 deletions(-)

diff --git a/src/calng/SaturationWarningAggregator.py b/src/calng/SaturationWarningAggregator.py
index 8c5ba89b..d7330888 100644
--- a/src/calng/SaturationWarningAggregator.py
+++ b/src/calng/SaturationWarningAggregator.py
@@ -15,6 +15,7 @@ from karabo.bound import (
     Trainstamp,
 )
 from .DetectorAssembler import DetectorAssembler
+from .correction_addons.saturation_monitor import saturation_monitoring_schema
 from ._version import version as deviceVersion
 
 
@@ -22,76 +23,17 @@ from ._version import version as deviceVersion
 class SaturationWarningAggregator(DetectorAssembler):
     @staticmethod
     def expectedParameters(expected):
+        saturation_monitoring_schema(expected)
         (
             OVERWRITE_ELEMENT(expected)
             .key("imageDataPath")
-            .setNewDefaultValue("saturationMonitor.maxImage")
+            .setNewDefaultValue("saturationMonitor.alarmImage")
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
             .key("imageMaskPath")
             .setNewDefaultValue("")
             .commit(),
-
-            # The reason for the node is compatibility with the saturation
-            # monitor add-on in calng. That way a MDL device can use this
-            # aggregator, the add-on or the SaturationMonitor from the
-            # ImageProcessor package with the same code
-            NODE_ELEMENT(expected)
-            .key("saturationMonitor")
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("saturationMonitor.warning")
-            .readOnly()
-            .initialValue(False)
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("saturationMonitor.alarm")
-            .readOnly()
-            .initialValue(False)
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("saturationMonitor.warnCount")
-            .description(
-                "Total number of pixels above warning threshold. Each pixel "
-                "is only counted once even if it exceeds the threshold in "
-                "multiple frames (/ memory cells)."
-            )
-            .readOnly()
-            .initialValue(0)
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("saturationMonitor.alarmCount")
-            .description(
-                "Total number of pixels above alarm threshold. Each pixel is "
-                "only counted once even if it exceeds the threshold in "
-                "multiple frames (/ memory cells)."
-            )
-            .readOnly()
-            .initialValue(0)
-            .commit(),
-
-            UINT64_ELEMENT(expected)
-            .key("saturationMonitor.trainId")
-            .description(
-                "Total number of pixels above alarm threshold. Each pixel is "
-                "only counted once even if it exceeds the threshold in "
-                "multiple frames (/ memory cells)."
-            )
-            .readOnly()
-            .initialValue(0)
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("saturationMonitor.maxValue")
-            .description("Max pixel value in latest train with warning or alarm.")
-            .readOnly()
-            .initialValue(0)
-            .commit(),
         )
 
     def on_matched_data(self, tid, sources):
@@ -162,7 +104,7 @@ class SaturationWarningAggregator(DetectorAssembler):
                     coords=coords,
                 )
             )
-            self._preview_friend.write_outputs(my_timestamp, assembled_data)
+            self._preview_friend.write_outputs(assembled_data, timestamp=my_timestamp)
 
         self.info["sent"] += 1
         self.info["trainId"] = tid
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index a11f9b96..527fe06c 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -557,19 +557,21 @@ class BaseCorrection(PythonDevice):
         self.KARABO_ON_DATA("dataInput", self.input_handler)
         self.KARABO_ON_EOS("dataInput", self.handle_eos)
 
-        self._enabled_addons = [
-            addon_class(self._parameters[f"addons.{addon_class.__name__}"])
-            for addon_class in self._available_addons
-            if self.get(f"addons.{addon_class.__name__}.enable")
-        ]
-        for addon in self._enabled_addons:
+        self._enabled_addons = []
+        for addon_class in self._available_addons:
+            addon_prefix = f"addons.{addon_class.__name__}"
+            if not self.get(f"{addon_prefix}.enable"):
+                continue
+            addon = addon_class(self._parameters[addon_prefix])
             addon._device = self
+            addon._prefix = addon_prefix
+            self._enabled_addons.append(addon)
         if (
             (self.get("useShmemHandles") != self._use_shmem_handles)
             or self._enabled_addons
         ):
             schema_override = Schema()
-            output_schema_override = self._base_output_schema(
+            output_schema_override = self.__class__._base_output_schema(
                 use_shmem_handles=self.get("useShmemHandles")
             )
             for addon in self._enabled_addons:
@@ -898,7 +900,11 @@ class BaseCorrection(PythonDevice):
             data_hash["corrections"] = corrections
             for addon in self._enabled_addons:
                 addon.post_correction(
-                    processed_buffer, cell_table, pulse_table, data_hash
+                    timestamp.getTrainId(),
+                    processed_buffer,
+                    cell_table,
+                    pulse_table,
+                    data_hash,
                 )
             self.kernel_runner.reshape(processed_buffer, out=buffer_array)
 
@@ -915,7 +921,7 @@ class BaseCorrection(PythonDevice):
 
         for addon in self._enabled_addons:
             addon.post_reshape(
-                buffer_array, cell_table, pulse_table, data_hash
+                timestamp.getTrainId(), buffer_array, cell_table, pulse_table, data_hash
             )
 
         if self.unsafe_get("useShmemHandles"):
diff --git a/src/calng/correction_addons/base_addon.py b/src/calng/correction_addons/base_addon.py
index 95b5192f..8e6d8690 100644
--- a/src/calng/correction_addons/base_addon.py
+++ b/src/calng/correction_addons/base_addon.py
@@ -25,12 +25,16 @@ class BaseCorrectionAddon:
         device)"""
         return self._device._geometry
 
-    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
+    def post_correction(
+        self, train_id, processed_data, cell_table, pulse_table, output_hash
+    ):
         """Called directly after correction has happened. Processed data will still be
         on GPU if the correction device is generally running in GPU mode."""
         pass
 
-    def post_reshape(self, reshaped_data, cell_table, pulse_table, output_hash):
+    def post_reshape(
+        self, train_id, reshaped_data, cell_table, pulse_table, output_hash
+    ):
         pass
 
     def reconfigure(self, changed_config):
diff --git a/src/calng/correction_addons/integrated_intensity.py b/src/calng/correction_addons/integrated_intensity.py
index c09d67cb..5cfaad9e 100644
--- a/src/calng/correction_addons/integrated_intensity.py
+++ b/src/calng/correction_addons/integrated_intensity.py
@@ -68,7 +68,7 @@ class IntegratedIntensity(BaseCorrectionAddon):
             .commit()
         )
 
-    def post_correction(self, data, cell_table, pulse_table, output_hash):
+    def post_correction(self, tid, data, cell_table, pulse_table, output_hash):
         # Numpy should handle CuPy dispatch for us
         mask = np.isfinite(data) & (self._vmin < data) & (data < self._vmax)
 
diff --git a/src/calng/correction_addons/litpixel_counter.py b/src/calng/correction_addons/litpixel_counter.py
index 48468c46..7baa14fd 100644
--- a/src/calng/correction_addons/litpixel_counter.py
+++ b/src/calng/correction_addons/litpixel_counter.py
@@ -46,9 +46,9 @@ class LitPixelCounter(BaseCorrectionAddon):
             .commit()
         )
 
-    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
-        n_cells, n_x, n_y = processed_data.shape
-        per_asic_data = processed_data.reshape(n_cells, 64, n_x // 64, 64, n_y // 64)
+    def post_correction(self, tid, data, cell_table, pulse_table, output_hash):
+        n_cells, n_x, n_y = data.shape
+        per_asic_data = data.reshape(n_cells, 64, n_x // 64, 64, n_y // 64)
 
         lit_pixels = np.sum(per_asic_data > self._threshold, axis=(1, 3))
         unmasked_pixels = np.isfinite(per_asic_data).sum(axis=(1, 3))
diff --git a/src/calng/correction_addons/peakfinder9.py b/src/calng/correction_addons/peakfinder9.py
index ddbac713..b4694ce2 100644
--- a/src/calng/correction_addons/peakfinder9.py
+++ b/src/calng/correction_addons/peakfinder9.py
@@ -124,9 +124,9 @@ class Peakfinder9(BaseCorrectionAddon):
             .commit(),
         )
 
-    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
+    def post_correction(self, data, train_id, cell_table, pulse_table, output_hash):
         # assumes processed data shape is frames, pixels, pixels
-        if self._input_shape != processed_data.shape:
+        if self._input_shape != data.shape:
             try:
                 del self._peakfinding_parameters
             except AttributeError:
@@ -135,7 +135,7 @@ class Peakfinder9(BaseCorrectionAddon):
                 del self._grid_and_block
             except AttributeError:
                 pass
-            self._input_shape = processed_data.shape
+            self._input_shape = data.shape
             self._rebuild_buffers()
         kernel_params = self._peakfinding_parameters  # this will create buffers
         self._peak_counts.fill(0)
@@ -143,7 +143,7 @@ class Peakfinder9(BaseCorrectionAddon):
             *self._grid_and_block,
             (
                 *kernel_params,
-                processed_data.astype(cupy.float32, copy=False),
+                data.astype(cupy.float32, copy=False),
                 self._peak_counts,
                 self._peak_x,
                 self._peak_y,
diff --git a/src/calng/correction_addons/random_frames.py b/src/calng/correction_addons/random_frames.py
index 8d6fb71e..d3447710 100644
--- a/src/calng/correction_addons/random_frames.py
+++ b/src/calng/correction_addons/random_frames.py
@@ -34,7 +34,7 @@ class RandomFrames(BaseCorrectionAddon):
         # TODO: figure out why no / 100 here...
         self._probability = config["probability"]
 
-    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
+    def post_correction(self, tid, data, cell_table, pulse_table, output_hash):
         output_hash["data.dataFramePattern"] = (
             np.random.random(cell_table.size) < self._probability
         ).astype(np.uint8)
diff --git a/src/calng/correction_addons/saturation_monitor.py b/src/calng/correction_addons/saturation_monitor.py
index 6bac179d..1c48a219 100644
--- a/src/calng/correction_addons/saturation_monitor.py
+++ b/src/calng/correction_addons/saturation_monitor.py
@@ -5,11 +5,17 @@ from karabo.bound import (
     FLOAT_ELEMENT,
     IMAGEDATA_ELEMENT,
     NODE_ELEMENT,
+    OUTPUT_CHANNEL,
     UINT32_ELEMENT,
     UINT64_ELEMENT,
     Dims,
     Encoding,
+    Epochstamp,
+    Hash,
     ImageData,
+    Schema,
+    Timestamp,
+    Trainstamp,
 )
 
 from .base_addon import BaseCorrectionAddon
@@ -22,6 +28,64 @@ def maybe_get(a):
     return a
 
 
+def saturation_monitoring_schema(schema=None):
+    if schema is None:
+        schema = Schema()
+    (
+        NODE_ELEMENT(schema)
+        .key("saturationMonitor")
+        .commit(),
+
+        BOOL_ELEMENT(schema)
+        .key("saturationMonitor.warning")
+        .readOnly()
+        .initialValue(False)
+        .commit(),
+
+        BOOL_ELEMENT(schema)
+        .key("saturationMonitor.alarm")
+        .readOnly()
+        .initialValue(False)
+        .commit(),
+
+        UINT32_ELEMENT(schema)
+        .key("saturationMonitor.warnCount")
+        .description(
+            "Total number of pixels above warning threshold. Each pixel "
+            "is only counted once even if it exceeds the threshold in "
+            "multiple frames (/ memory cells = given axis)."
+        )
+        .readOnly()
+        .initialValue(0)
+        .commit(),
+
+        UINT32_ELEMENT(schema)
+        .key("saturationMonitor.alarmCount")
+        .description(
+            "Total number of pixels above alarm threshold. Each pixel "
+            "is only counted once even if it exceeds the threshold in "
+            "multiple frames (/ memory cells = given axis)."
+        )
+        .readOnly()
+        .initialValue(0)
+        .commit(),
+
+        FLOAT_ELEMENT(schema)
+        .key("saturationMonitor.maxValue")
+        .description("Max pixel value in latest train with warning or alarm.")
+        .readOnly()
+        .initialValue(0)
+        .commit(),
+
+        UINT64_ELEMENT(schema)
+        .key("saturationMonitor.trainId")
+        .readOnly()
+        .initialValue(0)
+        .commit(),
+    )
+    return schema
+
+
 class SaturationMonitor(BaseCorrectionAddon):
     def __init__(self, config):
         global cupy
@@ -44,56 +108,15 @@ class SaturationMonitor(BaseCorrectionAddon):
         if changed_config.has("frameAxis"):
             self._frameAxis = changed_config["frameAxis"]
 
-    @staticmethod
-    def extend_output_schema(schema):
+    @classmethod
+    def extend_device_schema(cls, schema, prefix):
+        # will add own output channel with this schema
+        output_schema = saturation_monitoring_schema()
         (
-            NODE_ELEMENT(schema)
-            .key("saturationMonitor")
-            .commit(),
-
-            BOOL_ELEMENT(schema)
-            .key("saturationMonitor.warning")
-            .readOnly()
-            .commit(),
-
-            BOOL_ELEMENT(schema)
-            .key("saturationMonitor.alarm")
-            .readOnly()
-            .commit(),
-
-            UINT32_ELEMENT(schema)
-            .key("saturationMonitor.warnCount")
-            .description(
-                "Total number of pixels above warning threshold. Each pixel "
-                "is only counted once even if it exceeds the threshold in "
-                "multiple frames (/ memory cells = given axis)."
-            )
-            .readOnly()
-            .commit(),
-
-            UINT32_ELEMENT(schema)
-            .key("saturationMonitor.alarmCount")
-                .description(
-                "Total number of pixels above alarm threshold. Each pixel "
-                "is only counted once even if it exceeds the threshold in "
-                "multiple frames (/ memory cells = given axis)."
-            )
-            .readOnly()
-            .commit(),
-
-            FLOAT_ELEMENT(schema)
-            .key("saturationMonitor.maxValue")
-            .readOnly()
-            .commit(),
-
-            # TODO: switch to image data
-            IMAGEDATA_ELEMENT(schema)
+            IMAGEDATA_ELEMENT(output_schema)
             .key("saturationMonitor.alarmImage")
             .commit(),
         )
-
-    @staticmethod
-    def extend_device_schema(schema, prefix):
         (
             FLOAT_ELEMENT(schema)
             .key(f"{prefix}.alarmThreshold")
@@ -133,30 +156,43 @@ class SaturationMonitor(BaseCorrectionAddon):
 
             UINT64_ELEMENT(schema)
             .key(f"{prefix}.frameAxis")
-            .displayedName('Multi-frame axis')
+            .displayedName("Multi-frame axis")
             .description("Axis for frames. Used to take the max over this axis.")
             .tags("managed")
             .assignmentOptional()
             .defaultValue(0)
             .reconfigurable()
             .commit(),
+
+            OUTPUT_CHANNEL(schema)
+            .key(f"{prefix}.output")
+            .dataSchema(output_schema)
+            .commit(),
         )
 
-    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
+    def post_correction(self, train_id, data, cell_table, pulse_table, data_hash):
+        output = Hash()
+
         # only take the max if data has frames -> more than 2 dimensions
-        if processed_data.ndim > 2:
-            max_image = np.nanmax(processed_data, axis=self._frameAxis)
+        if data.ndim > 2:
+            max_image = np.nanmax(data, axis=self._frameAxis)
         else:
-            max_image = processed_data
+            max_image = data
         nb_pix_warning = int(np.nansum(max_image > self._warnThreshold))
         nb_pix_alarm = int(np.nansum(max_image > self._alarmThreshold))
 
-        output_hash["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount
-        output_hash["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount
-        output_hash["saturationMonitor.warnCount"] = nb_pix_warning
-        output_hash["saturationMonitor.alarmCount"] = nb_pix_alarm
-        output_hash["saturationMonitor.maxValue"] = float(np.nanmax(max_image))
+        output["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount
+        output["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount
+        output["saturationMonitor.warnCount"] = nb_pix_warning
+        output["saturationMonitor.alarmCount"] = nb_pix_alarm
+        output["saturationMonitor.maxValue"] = float(np.nanmax(max_image))
+        output["saturationMonitor.trainId"] = train_id
         max_image[max_image <= self._warnThreshold] = 0
-        output_hash["saturationMonitor.alarmImage"] = ImageData(
+        output["saturationMonitor.alarmImage"] = ImageData(
             maybe_get(max_image), Dims(*max_image.shape), Encoding.GRAY, bitsPerPixel=32
         )
+        self._device.writeChannel(
+            f"{self._prefix}.output",
+            output,
+            timestamp=Timestamp(Epochstamp(), Trainstamp(train_id)),
+        )
diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py
index 6632b6ef..b3900c25 100644
--- a/src/calng/corrections/AgipdCorrection.py
+++ b/src/calng/corrections/AgipdCorrection.py
@@ -756,9 +756,7 @@ class AgipdCorrection(base_correction.BaseCorrection):
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
             .dataSchema(
-                AgipdCorrection._base_output_schema(
-                    use_shmem_handles=cls._use_shmem_handles
-                )
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
             )
             .commit(),
         )
-- 
GitLab