From 01998b0f26ec856380b77687d3d577b888e527f3 Mon Sep 17 00:00:00 2001 From: David Hammer <david.hammer@xfel.eu> Date: Fri, 2 Aug 2024 09:50:39 +0200 Subject: [PATCH] Saturation monitor addon with own output channel --- src/calng/SaturationWarningAggregator.py | 66 +------- src/calng/base_correction.py | 24 +-- src/calng/correction_addons/base_addon.py | 8 +- .../correction_addons/integrated_intensity.py | 2 +- .../correction_addons/litpixel_counter.py | 6 +- src/calng/correction_addons/peakfinder9.py | 8 +- src/calng/correction_addons/random_frames.py | 2 +- .../correction_addons/saturation_monitor.py | 150 +++++++++++------- src/calng/corrections/AgipdCorrection.py | 4 +- 9 files changed, 128 insertions(+), 142 deletions(-) diff --git a/src/calng/SaturationWarningAggregator.py b/src/calng/SaturationWarningAggregator.py index 8c5ba89b..d7330888 100644 --- a/src/calng/SaturationWarningAggregator.py +++ b/src/calng/SaturationWarningAggregator.py @@ -15,6 +15,7 @@ from karabo.bound import ( Trainstamp, ) from .DetectorAssembler import DetectorAssembler +from .correction_addons.saturation_monitor import saturation_monitoring_schema from ._version import version as deviceVersion @@ -22,76 +23,17 @@ from ._version import version as deviceVersion class SaturationWarningAggregator(DetectorAssembler): @staticmethod def expectedParameters(expected): + saturation_monitoring_schema(expected) ( OVERWRITE_ELEMENT(expected) .key("imageDataPath") - .setNewDefaultValue("saturationMonitor.maxImage") + .setNewDefaultValue("saturationMonitor.alarmImage") .commit(), OVERWRITE_ELEMENT(expected) .key("imageMaskPath") .setNewDefaultValue("") .commit(), - - # The reason for the node is compatibility with the saturation - # monitor add-on in calng. That way a MDL device can use this - # aggregator, the add-on or the SaturationMonitor from the - # ImageProcessor package with the same code - NODE_ELEMENT(expected) - .key("saturationMonitor") - .commit(), - - BOOL_ELEMENT(expected) - .key("saturationMonitor.warning") - .readOnly() - .initialValue(False) - .commit(), - - BOOL_ELEMENT(expected) - .key("saturationMonitor.alarm") - .readOnly() - .initialValue(False) - .commit(), - - UINT32_ELEMENT(expected) - .key("saturationMonitor.warnCount") - .description( - "Total number of pixels above warning threshold. Each pixel " - "is only counted once even if it exceeds the threshold in " - "multiple frames (/ memory cells)." - ) - .readOnly() - .initialValue(0) - .commit(), - - UINT32_ELEMENT(expected) - .key("saturationMonitor.alarmCount") - .description( - "Total number of pixels above alarm threshold. Each pixel is " - "only counted once even if it exceeds the threshold in " - "multiple frames (/ memory cells)." - ) - .readOnly() - .initialValue(0) - .commit(), - - UINT64_ELEMENT(expected) - .key("saturationMonitor.trainId") - .description( - "Total number of pixels above alarm threshold. Each pixel is " - "only counted once even if it exceeds the threshold in " - "multiple frames (/ memory cells)." - ) - .readOnly() - .initialValue(0) - .commit(), - - FLOAT_ELEMENT(expected) - .key("saturationMonitor.maxValue") - .description("Max pixel value in latest train with warning or alarm.") - .readOnly() - .initialValue(0) - .commit(), ) def on_matched_data(self, tid, sources): @@ -162,7 +104,7 @@ class SaturationWarningAggregator(DetectorAssembler): coords=coords, ) ) - self._preview_friend.write_outputs(my_timestamp, assembled_data) + self._preview_friend.write_outputs(assembled_data, timestamp=my_timestamp) self.info["sent"] += 1 self.info["trainId"] = tid diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index a11f9b96..527fe06c 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -557,19 +557,21 @@ class BaseCorrection(PythonDevice): self.KARABO_ON_DATA("dataInput", self.input_handler) self.KARABO_ON_EOS("dataInput", self.handle_eos) - self._enabled_addons = [ - addon_class(self._parameters[f"addons.{addon_class.__name__}"]) - for addon_class in self._available_addons - if self.get(f"addons.{addon_class.__name__}.enable") - ] - for addon in self._enabled_addons: + self._enabled_addons = [] + for addon_class in self._available_addons: + addon_prefix = f"addons.{addon_class.__name__}" + if not self.get(f"{addon_prefix}.enable"): + continue + addon = addon_class(self._parameters[addon_prefix]) addon._device = self + addon._prefix = addon_prefix + self._enabled_addons.append(addon) if ( (self.get("useShmemHandles") != self._use_shmem_handles) or self._enabled_addons ): schema_override = Schema() - output_schema_override = self._base_output_schema( + output_schema_override = self.__class__._base_output_schema( use_shmem_handles=self.get("useShmemHandles") ) for addon in self._enabled_addons: @@ -898,7 +900,11 @@ class BaseCorrection(PythonDevice): data_hash["corrections"] = corrections for addon in self._enabled_addons: addon.post_correction( - processed_buffer, cell_table, pulse_table, data_hash + timestamp.getTrainId(), + processed_buffer, + cell_table, + pulse_table, + data_hash, ) self.kernel_runner.reshape(processed_buffer, out=buffer_array) @@ -915,7 +921,7 @@ class BaseCorrection(PythonDevice): for addon in self._enabled_addons: addon.post_reshape( - buffer_array, cell_table, pulse_table, data_hash + timestamp.getTrainId(), buffer_array, cell_table, pulse_table, data_hash ) if self.unsafe_get("useShmemHandles"): diff --git a/src/calng/correction_addons/base_addon.py b/src/calng/correction_addons/base_addon.py index 95b5192f..8e6d8690 100644 --- a/src/calng/correction_addons/base_addon.py +++ b/src/calng/correction_addons/base_addon.py @@ -25,12 +25,16 @@ class BaseCorrectionAddon: device)""" return self._device._geometry - def post_correction(self, processed_data, cell_table, pulse_table, output_hash): + def post_correction( + self, train_id, processed_data, cell_table, pulse_table, output_hash + ): """Called directly after correction has happened. Processed data will still be on GPU if the correction device is generally running in GPU mode.""" pass - def post_reshape(self, reshaped_data, cell_table, pulse_table, output_hash): + def post_reshape( + self, train_id, reshaped_data, cell_table, pulse_table, output_hash + ): pass def reconfigure(self, changed_config): diff --git a/src/calng/correction_addons/integrated_intensity.py b/src/calng/correction_addons/integrated_intensity.py index c09d67cb..5cfaad9e 100644 --- a/src/calng/correction_addons/integrated_intensity.py +++ b/src/calng/correction_addons/integrated_intensity.py @@ -68,7 +68,7 @@ class IntegratedIntensity(BaseCorrectionAddon): .commit() ) - def post_correction(self, data, cell_table, pulse_table, output_hash): + def post_correction(self, tid, data, cell_table, pulse_table, output_hash): # Numpy should handle CuPy dispatch for us mask = np.isfinite(data) & (self._vmin < data) & (data < self._vmax) diff --git a/src/calng/correction_addons/litpixel_counter.py b/src/calng/correction_addons/litpixel_counter.py index 48468c46..7baa14fd 100644 --- a/src/calng/correction_addons/litpixel_counter.py +++ b/src/calng/correction_addons/litpixel_counter.py @@ -46,9 +46,9 @@ class LitPixelCounter(BaseCorrectionAddon): .commit() ) - def post_correction(self, processed_data, cell_table, pulse_table, output_hash): - n_cells, n_x, n_y = processed_data.shape - per_asic_data = processed_data.reshape(n_cells, 64, n_x // 64, 64, n_y // 64) + def post_correction(self, tid, data, cell_table, pulse_table, output_hash): + n_cells, n_x, n_y = data.shape + per_asic_data = data.reshape(n_cells, 64, n_x // 64, 64, n_y // 64) lit_pixels = np.sum(per_asic_data > self._threshold, axis=(1, 3)) unmasked_pixels = np.isfinite(per_asic_data).sum(axis=(1, 3)) diff --git a/src/calng/correction_addons/peakfinder9.py b/src/calng/correction_addons/peakfinder9.py index ddbac713..b4694ce2 100644 --- a/src/calng/correction_addons/peakfinder9.py +++ b/src/calng/correction_addons/peakfinder9.py @@ -124,9 +124,9 @@ class Peakfinder9(BaseCorrectionAddon): .commit(), ) - def post_correction(self, processed_data, cell_table, pulse_table, output_hash): + def post_correction(self, data, train_id, cell_table, pulse_table, output_hash): # assumes processed data shape is frames, pixels, pixels - if self._input_shape != processed_data.shape: + if self._input_shape != data.shape: try: del self._peakfinding_parameters except AttributeError: @@ -135,7 +135,7 @@ class Peakfinder9(BaseCorrectionAddon): del self._grid_and_block except AttributeError: pass - self._input_shape = processed_data.shape + self._input_shape = data.shape self._rebuild_buffers() kernel_params = self._peakfinding_parameters # this will create buffers self._peak_counts.fill(0) @@ -143,7 +143,7 @@ class Peakfinder9(BaseCorrectionAddon): *self._grid_and_block, ( *kernel_params, - processed_data.astype(cupy.float32, copy=False), + data.astype(cupy.float32, copy=False), self._peak_counts, self._peak_x, self._peak_y, diff --git a/src/calng/correction_addons/random_frames.py b/src/calng/correction_addons/random_frames.py index 8d6fb71e..d3447710 100644 --- a/src/calng/correction_addons/random_frames.py +++ b/src/calng/correction_addons/random_frames.py @@ -34,7 +34,7 @@ class RandomFrames(BaseCorrectionAddon): # TODO: figure out why no / 100 here... self._probability = config["probability"] - def post_correction(self, processed_data, cell_table, pulse_table, output_hash): + def post_correction(self, tid, data, cell_table, pulse_table, output_hash): output_hash["data.dataFramePattern"] = ( np.random.random(cell_table.size) < self._probability ).astype(np.uint8) diff --git a/src/calng/correction_addons/saturation_monitor.py b/src/calng/correction_addons/saturation_monitor.py index 6bac179d..1c48a219 100644 --- a/src/calng/correction_addons/saturation_monitor.py +++ b/src/calng/correction_addons/saturation_monitor.py @@ -5,11 +5,17 @@ from karabo.bound import ( FLOAT_ELEMENT, IMAGEDATA_ELEMENT, NODE_ELEMENT, + OUTPUT_CHANNEL, UINT32_ELEMENT, UINT64_ELEMENT, Dims, Encoding, + Epochstamp, + Hash, ImageData, + Schema, + Timestamp, + Trainstamp, ) from .base_addon import BaseCorrectionAddon @@ -22,6 +28,64 @@ def maybe_get(a): return a +def saturation_monitoring_schema(schema=None): + if schema is None: + schema = Schema() + ( + NODE_ELEMENT(schema) + .key("saturationMonitor") + .commit(), + + BOOL_ELEMENT(schema) + .key("saturationMonitor.warning") + .readOnly() + .initialValue(False) + .commit(), + + BOOL_ELEMENT(schema) + .key("saturationMonitor.alarm") + .readOnly() + .initialValue(False) + .commit(), + + UINT32_ELEMENT(schema) + .key("saturationMonitor.warnCount") + .description( + "Total number of pixels above warning threshold. Each pixel " + "is only counted once even if it exceeds the threshold in " + "multiple frames (/ memory cells = given axis)." + ) + .readOnly() + .initialValue(0) + .commit(), + + UINT32_ELEMENT(schema) + .key("saturationMonitor.alarmCount") + .description( + "Total number of pixels above alarm threshold. Each pixel " + "is only counted once even if it exceeds the threshold in " + "multiple frames (/ memory cells = given axis)." + ) + .readOnly() + .initialValue(0) + .commit(), + + FLOAT_ELEMENT(schema) + .key("saturationMonitor.maxValue") + .description("Max pixel value in latest train with warning or alarm.") + .readOnly() + .initialValue(0) + .commit(), + + UINT64_ELEMENT(schema) + .key("saturationMonitor.trainId") + .readOnly() + .initialValue(0) + .commit(), + ) + return schema + + class SaturationMonitor(BaseCorrectionAddon): def __init__(self, config): global cupy @@ -44,56 +108,15 @@ class SaturationMonitor(BaseCorrectionAddon): if changed_config.has("frameAxis"): self._frameAxis = changed_config["frameAxis"] - @staticmethod - def extend_output_schema(schema): + @classmethod + def extend_device_schema(cls, schema, prefix): + # will add own output channel with this schema + output_schema = saturation_monitoring_schema() ( - NODE_ELEMENT(schema) - .key("saturationMonitor") - .commit(), - - BOOL_ELEMENT(schema) - .key("saturationMonitor.warning") - .readOnly() - .commit(), - - BOOL_ELEMENT(schema) - .key("saturationMonitor.alarm") - .readOnly() - .commit(), - - UINT32_ELEMENT(schema) - .key("saturationMonitor.warnCount") - .description( - "Total number of pixels above warning threshold. Each pixel " - "is only counted once even if it exceeds the threshold in " - "multiple frames (/ memory cells = given axis)." - ) - .readOnly() - .commit(), - - UINT32_ELEMENT(schema) - .key("saturationMonitor.alarmCount") - .description( - "Total number of pixels above alarm threshold. Each pixel " - "is only counted once even if it exceeds the threshold in " - "multiple frames (/ memory cells = given axis)." - ) - .readOnly() - .commit(), - - FLOAT_ELEMENT(schema) - .key("saturationMonitor.maxValue") - .readOnly() - .commit(), - - # TODO: switch to image data - IMAGEDATA_ELEMENT(schema) + IMAGEDATA_ELEMENT(output_schema) .key("saturationMonitor.alarmImage") .commit(), ) - - @staticmethod - def extend_device_schema(schema, prefix): ( FLOAT_ELEMENT(schema) .key(f"{prefix}.alarmThreshold") @@ -133,30 +156,43 @@ class SaturationMonitor(BaseCorrectionAddon): UINT64_ELEMENT(schema) .key(f"{prefix}.frameAxis") - .displayedName('Multi-frame axis') + .displayedName("Multi-frame axis") .description("Axis for frames. Used to take the max over this axis.") .tags("managed") .assignmentOptional() .defaultValue(0) .reconfigurable() .commit(), + + OUTPUT_CHANNEL(schema) + .key(f"{prefix}.output") + .dataSchema(output_schema) + .commit(), ) - def post_correction(self, processed_data, cell_table, pulse_table, output_hash): + def post_correction(self, train_id, data, cell_table, pulse_table, data_hash): + output = Hash() + # only take the max if data has frames -> more than 2 dimensions - if processed_data.ndim > 2: - max_image = np.nanmax(processed_data, axis=self._frameAxis) + if data.ndim > 2: + max_image = np.nanmax(data, axis=self._frameAxis) else: - max_image = processed_data + max_image = data nb_pix_warning = int(np.nansum(max_image > self._warnThreshold)) nb_pix_alarm = int(np.nansum(max_image > self._alarmThreshold)) - output_hash["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount - output_hash["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount - output_hash["saturationMonitor.warnCount"] = nb_pix_warning - output_hash["saturationMonitor.alarmCount"] = nb_pix_alarm - output_hash["saturationMonitor.maxValue"] = float(np.nanmax(max_image)) + output["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount + output["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount + output["saturationMonitor.warnCount"] = nb_pix_warning + output["saturationMonitor.alarmCount"] = nb_pix_alarm + output["saturationMonitor.maxValue"] = float(np.nanmax(max_image)) + output["saturationMonitor.trainId"] = train_id max_image[max_image <= self._warnThreshold] = 0 - output_hash["saturationMonitor.alarmImage"] = ImageData( + output["saturationMonitor.alarmImage"] = ImageData( maybe_get(max_image), Dims(*max_image.shape), Encoding.GRAY, bitsPerPixel=32 ) + self._device.writeChannel( + f"{self._prefix}.output", + output, + timestamp=Timestamp(Epochstamp(), Trainstamp(train_id)), + ) diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py index 6632b6ef..b3900c25 100644 --- a/src/calng/corrections/AgipdCorrection.py +++ b/src/calng/corrections/AgipdCorrection.py @@ -756,9 +756,7 @@ class AgipdCorrection(base_correction.BaseCorrection): OUTPUT_CHANNEL(expected) .key("dataOutput") .dataSchema( - AgipdCorrection._base_output_schema( - use_shmem_handles=cls._use_shmem_handles - ) + cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles) ) .commit(), ) -- GitLab