Skip to content
Snippets Groups Projects
Commit 01998b0f authored by David Hammer's avatar David Hammer
Browse files

Saturation monitor addon with own output channel

parent 8b4e056b
No related branches found
No related tags found
1 merge request!104Saturation monitor addon with own output channel
...@@ -15,6 +15,7 @@ from karabo.bound import ( ...@@ -15,6 +15,7 @@ from karabo.bound import (
Trainstamp, Trainstamp,
) )
from .DetectorAssembler import DetectorAssembler from .DetectorAssembler import DetectorAssembler
from .correction_addons.saturation_monitor import saturation_monitoring_schema
from ._version import version as deviceVersion from ._version import version as deviceVersion
...@@ -22,76 +23,17 @@ from ._version import version as deviceVersion ...@@ -22,76 +23,17 @@ from ._version import version as deviceVersion
class SaturationWarningAggregator(DetectorAssembler): class SaturationWarningAggregator(DetectorAssembler):
@staticmethod @staticmethod
def expectedParameters(expected): def expectedParameters(expected):
saturation_monitoring_schema(expected)
( (
OVERWRITE_ELEMENT(expected) OVERWRITE_ELEMENT(expected)
.key("imageDataPath") .key("imageDataPath")
.setNewDefaultValue("saturationMonitor.maxImage") .setNewDefaultValue("saturationMonitor.alarmImage")
.commit(), .commit(),
OVERWRITE_ELEMENT(expected) OVERWRITE_ELEMENT(expected)
.key("imageMaskPath") .key("imageMaskPath")
.setNewDefaultValue("") .setNewDefaultValue("")
.commit(), .commit(),
# The reason for the node is compatibility with the saturation
# monitor add-on in calng. That way a MDL device can use this
# aggregator, the add-on or the SaturationMonitor from the
# ImageProcessor package with the same code
NODE_ELEMENT(expected)
.key("saturationMonitor")
.commit(),
BOOL_ELEMENT(expected)
.key("saturationMonitor.warning")
.readOnly()
.initialValue(False)
.commit(),
BOOL_ELEMENT(expected)
.key("saturationMonitor.alarm")
.readOnly()
.initialValue(False)
.commit(),
UINT32_ELEMENT(expected)
.key("saturationMonitor.warnCount")
.description(
"Total number of pixels above warning threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells)."
)
.readOnly()
.initialValue(0)
.commit(),
UINT32_ELEMENT(expected)
.key("saturationMonitor.alarmCount")
.description(
"Total number of pixels above alarm threshold. Each pixel is "
"only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells)."
)
.readOnly()
.initialValue(0)
.commit(),
UINT64_ELEMENT(expected)
.key("saturationMonitor.trainId")
.description(
"Total number of pixels above alarm threshold. Each pixel is "
"only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells)."
)
.readOnly()
.initialValue(0)
.commit(),
FLOAT_ELEMENT(expected)
.key("saturationMonitor.maxValue")
.description("Max pixel value in latest train with warning or alarm.")
.readOnly()
.initialValue(0)
.commit(),
) )
def on_matched_data(self, tid, sources): def on_matched_data(self, tid, sources):
...@@ -162,7 +104,7 @@ class SaturationWarningAggregator(DetectorAssembler): ...@@ -162,7 +104,7 @@ class SaturationWarningAggregator(DetectorAssembler):
coords=coords, coords=coords,
) )
) )
self._preview_friend.write_outputs(my_timestamp, assembled_data) self._preview_friend.write_outputs(assembled_data, timestamp=my_timestamp)
self.info["sent"] += 1 self.info["sent"] += 1
self.info["trainId"] = tid self.info["trainId"] = tid
......
...@@ -557,19 +557,21 @@ class BaseCorrection(PythonDevice): ...@@ -557,19 +557,21 @@ class BaseCorrection(PythonDevice):
self.KARABO_ON_DATA("dataInput", self.input_handler) self.KARABO_ON_DATA("dataInput", self.input_handler)
self.KARABO_ON_EOS("dataInput", self.handle_eos) self.KARABO_ON_EOS("dataInput", self.handle_eos)
self._enabled_addons = [ self._enabled_addons = []
addon_class(self._parameters[f"addons.{addon_class.__name__}"]) for addon_class in self._available_addons:
for addon_class in self._available_addons addon_prefix = f"addons.{addon_class.__name__}"
if self.get(f"addons.{addon_class.__name__}.enable") if not self.get(f"{addon_prefix}.enable"):
] continue
for addon in self._enabled_addons: addon = addon_class(self._parameters[addon_prefix])
addon._device = self addon._device = self
addon._prefix = addon_prefix
self._enabled_addons.append(addon)
if ( if (
(self.get("useShmemHandles") != self._use_shmem_handles) (self.get("useShmemHandles") != self._use_shmem_handles)
or self._enabled_addons or self._enabled_addons
): ):
schema_override = Schema() schema_override = Schema()
output_schema_override = self._base_output_schema( output_schema_override = self.__class__._base_output_schema(
use_shmem_handles=self.get("useShmemHandles") use_shmem_handles=self.get("useShmemHandles")
) )
for addon in self._enabled_addons: for addon in self._enabled_addons:
...@@ -898,7 +900,11 @@ class BaseCorrection(PythonDevice): ...@@ -898,7 +900,11 @@ class BaseCorrection(PythonDevice):
data_hash["corrections"] = corrections data_hash["corrections"] = corrections
for addon in self._enabled_addons: for addon in self._enabled_addons:
addon.post_correction( addon.post_correction(
processed_buffer, cell_table, pulse_table, data_hash timestamp.getTrainId(),
processed_buffer,
cell_table,
pulse_table,
data_hash,
) )
self.kernel_runner.reshape(processed_buffer, out=buffer_array) self.kernel_runner.reshape(processed_buffer, out=buffer_array)
...@@ -915,7 +921,7 @@ class BaseCorrection(PythonDevice): ...@@ -915,7 +921,7 @@ class BaseCorrection(PythonDevice):
for addon in self._enabled_addons: for addon in self._enabled_addons:
addon.post_reshape( addon.post_reshape(
buffer_array, cell_table, pulse_table, data_hash timestamp.getTrainId(), buffer_array, cell_table, pulse_table, data_hash
) )
if self.unsafe_get("useShmemHandles"): if self.unsafe_get("useShmemHandles"):
......
...@@ -25,12 +25,16 @@ class BaseCorrectionAddon: ...@@ -25,12 +25,16 @@ class BaseCorrectionAddon:
device)""" device)"""
return self._device._geometry return self._device._geometry
def post_correction(self, processed_data, cell_table, pulse_table, output_hash): def post_correction(
self, train_id, processed_data, cell_table, pulse_table, output_hash
):
"""Called directly after correction has happened. Processed data will still be """Called directly after correction has happened. Processed data will still be
on GPU if the correction device is generally running in GPU mode.""" on GPU if the correction device is generally running in GPU mode."""
pass pass
def post_reshape(self, reshaped_data, cell_table, pulse_table, output_hash): def post_reshape(
self, train_id, reshaped_data, cell_table, pulse_table, output_hash
):
pass pass
def reconfigure(self, changed_config): def reconfigure(self, changed_config):
......
...@@ -68,7 +68,7 @@ class IntegratedIntensity(BaseCorrectionAddon): ...@@ -68,7 +68,7 @@ class IntegratedIntensity(BaseCorrectionAddon):
.commit() .commit()
) )
def post_correction(self, data, cell_table, pulse_table, output_hash): def post_correction(self, tid, data, cell_table, pulse_table, output_hash):
# Numpy should handle CuPy dispatch for us # Numpy should handle CuPy dispatch for us
mask = np.isfinite(data) & (self._vmin < data) & (data < self._vmax) mask = np.isfinite(data) & (self._vmin < data) & (data < self._vmax)
......
...@@ -46,9 +46,9 @@ class LitPixelCounter(BaseCorrectionAddon): ...@@ -46,9 +46,9 @@ class LitPixelCounter(BaseCorrectionAddon):
.commit() .commit()
) )
def post_correction(self, processed_data, cell_table, pulse_table, output_hash): def post_correction(self, tid, data, cell_table, pulse_table, output_hash):
n_cells, n_x, n_y = processed_data.shape n_cells, n_x, n_y = data.shape
per_asic_data = processed_data.reshape(n_cells, 64, n_x // 64, 64, n_y // 64) per_asic_data = data.reshape(n_cells, 64, n_x // 64, 64, n_y // 64)
lit_pixels = np.sum(per_asic_data > self._threshold, axis=(1, 3)) lit_pixels = np.sum(per_asic_data > self._threshold, axis=(1, 3))
unmasked_pixels = np.isfinite(per_asic_data).sum(axis=(1, 3)) unmasked_pixels = np.isfinite(per_asic_data).sum(axis=(1, 3))
......
...@@ -124,9 +124,9 @@ class Peakfinder9(BaseCorrectionAddon): ...@@ -124,9 +124,9 @@ class Peakfinder9(BaseCorrectionAddon):
.commit(), .commit(),
) )
def post_correction(self, processed_data, cell_table, pulse_table, output_hash): def post_correction(self, data, train_id, cell_table, pulse_table, output_hash):
# assumes processed data shape is frames, pixels, pixels # assumes processed data shape is frames, pixels, pixels
if self._input_shape != processed_data.shape: if self._input_shape != data.shape:
try: try:
del self._peakfinding_parameters del self._peakfinding_parameters
except AttributeError: except AttributeError:
...@@ -135,7 +135,7 @@ class Peakfinder9(BaseCorrectionAddon): ...@@ -135,7 +135,7 @@ class Peakfinder9(BaseCorrectionAddon):
del self._grid_and_block del self._grid_and_block
except AttributeError: except AttributeError:
pass pass
self._input_shape = processed_data.shape self._input_shape = data.shape
self._rebuild_buffers() self._rebuild_buffers()
kernel_params = self._peakfinding_parameters # this will create buffers kernel_params = self._peakfinding_parameters # this will create buffers
self._peak_counts.fill(0) self._peak_counts.fill(0)
...@@ -143,7 +143,7 @@ class Peakfinder9(BaseCorrectionAddon): ...@@ -143,7 +143,7 @@ class Peakfinder9(BaseCorrectionAddon):
*self._grid_and_block, *self._grid_and_block,
( (
*kernel_params, *kernel_params,
processed_data.astype(cupy.float32, copy=False), data.astype(cupy.float32, copy=False),
self._peak_counts, self._peak_counts,
self._peak_x, self._peak_x,
self._peak_y, self._peak_y,
......
...@@ -34,7 +34,7 @@ class RandomFrames(BaseCorrectionAddon): ...@@ -34,7 +34,7 @@ class RandomFrames(BaseCorrectionAddon):
# TODO: figure out why no / 100 here... # TODO: figure out why no / 100 here...
self._probability = config["probability"] self._probability = config["probability"]
def post_correction(self, processed_data, cell_table, pulse_table, output_hash): def post_correction(self, tid, data, cell_table, pulse_table, output_hash):
output_hash["data.dataFramePattern"] = ( output_hash["data.dataFramePattern"] = (
np.random.random(cell_table.size) < self._probability np.random.random(cell_table.size) < self._probability
).astype(np.uint8) ).astype(np.uint8)
......
...@@ -5,11 +5,17 @@ from karabo.bound import ( ...@@ -5,11 +5,17 @@ from karabo.bound import (
FLOAT_ELEMENT, FLOAT_ELEMENT,
IMAGEDATA_ELEMENT, IMAGEDATA_ELEMENT,
NODE_ELEMENT, NODE_ELEMENT,
OUTPUT_CHANNEL,
UINT32_ELEMENT, UINT32_ELEMENT,
UINT64_ELEMENT, UINT64_ELEMENT,
Dims, Dims,
Encoding, Encoding,
Epochstamp,
Hash,
ImageData, ImageData,
Schema,
Timestamp,
Trainstamp,
) )
from .base_addon import BaseCorrectionAddon from .base_addon import BaseCorrectionAddon
...@@ -22,6 +28,64 @@ def maybe_get(a): ...@@ -22,6 +28,64 @@ def maybe_get(a):
return a return a
def saturation_monitoring_schema(schema=None):
if schema is None:
schema = Schema()
(
NODE_ELEMENT(schema)
.key("saturationMonitor")
.commit(),
BOOL_ELEMENT(schema)
.key("saturationMonitor.warning")
.readOnly()
.initialValue(False)
.commit(),
BOOL_ELEMENT(schema)
.key("saturationMonitor.alarm")
.readOnly()
.initialValue(False)
.commit(),
UINT32_ELEMENT(schema)
.key("saturationMonitor.warnCount")
.description(
"Total number of pixels above warning threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells = given axis)."
)
.readOnly()
.initialValue(0)
.commit(),
UINT32_ELEMENT(schema)
.key("saturationMonitor.alarmCount")
.description(
"Total number of pixels above alarm threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells = given axis)."
)
.readOnly()
.initialValue(0)
.commit(),
FLOAT_ELEMENT(schema)
.key("saturationMonitor.maxValue")
.description("Max pixel value in latest train with warning or alarm.")
.readOnly()
.initialValue(0)
.commit(),
UINT64_ELEMENT(schema)
.key("saturationMonitor.trainId")
.readOnly()
.initialValue(0)
.commit(),
)
return schema
class SaturationMonitor(BaseCorrectionAddon): class SaturationMonitor(BaseCorrectionAddon):
def __init__(self, config): def __init__(self, config):
global cupy global cupy
...@@ -44,56 +108,15 @@ class SaturationMonitor(BaseCorrectionAddon): ...@@ -44,56 +108,15 @@ class SaturationMonitor(BaseCorrectionAddon):
if changed_config.has("frameAxis"): if changed_config.has("frameAxis"):
self._frameAxis = changed_config["frameAxis"] self._frameAxis = changed_config["frameAxis"]
@staticmethod @classmethod
def extend_output_schema(schema): def extend_device_schema(cls, schema, prefix):
# will add own output channel with this schema
output_schema = saturation_monitoring_schema()
( (
NODE_ELEMENT(schema) IMAGEDATA_ELEMENT(output_schema)
.key("saturationMonitor")
.commit(),
BOOL_ELEMENT(schema)
.key("saturationMonitor.warning")
.readOnly()
.commit(),
BOOL_ELEMENT(schema)
.key("saturationMonitor.alarm")
.readOnly()
.commit(),
UINT32_ELEMENT(schema)
.key("saturationMonitor.warnCount")
.description(
"Total number of pixels above warning threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells = given axis)."
)
.readOnly()
.commit(),
UINT32_ELEMENT(schema)
.key("saturationMonitor.alarmCount")
.description(
"Total number of pixels above alarm threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells = given axis)."
)
.readOnly()
.commit(),
FLOAT_ELEMENT(schema)
.key("saturationMonitor.maxValue")
.readOnly()
.commit(),
# TODO: switch to image data
IMAGEDATA_ELEMENT(schema)
.key("saturationMonitor.alarmImage") .key("saturationMonitor.alarmImage")
.commit(), .commit(),
) )
@staticmethod
def extend_device_schema(schema, prefix):
( (
FLOAT_ELEMENT(schema) FLOAT_ELEMENT(schema)
.key(f"{prefix}.alarmThreshold") .key(f"{prefix}.alarmThreshold")
...@@ -133,30 +156,43 @@ class SaturationMonitor(BaseCorrectionAddon): ...@@ -133,30 +156,43 @@ class SaturationMonitor(BaseCorrectionAddon):
UINT64_ELEMENT(schema) UINT64_ELEMENT(schema)
.key(f"{prefix}.frameAxis") .key(f"{prefix}.frameAxis")
.displayedName('Multi-frame axis') .displayedName("Multi-frame axis")
.description("Axis for frames. Used to take the max over this axis.") .description("Axis for frames. Used to take the max over this axis.")
.tags("managed") .tags("managed")
.assignmentOptional() .assignmentOptional()
.defaultValue(0) .defaultValue(0)
.reconfigurable() .reconfigurable()
.commit(), .commit(),
OUTPUT_CHANNEL(schema)
.key(f"{prefix}.output")
.dataSchema(output_schema)
.commit(),
) )
def post_correction(self, processed_data, cell_table, pulse_table, output_hash): def post_correction(self, train_id, data, cell_table, pulse_table, data_hash):
output = Hash()
# only take the max if data has frames -> more than 2 dimensions # only take the max if data has frames -> more than 2 dimensions
if processed_data.ndim > 2: if data.ndim > 2:
max_image = np.nanmax(processed_data, axis=self._frameAxis) max_image = np.nanmax(data, axis=self._frameAxis)
else: else:
max_image = processed_data max_image = data
nb_pix_warning = int(np.nansum(max_image > self._warnThreshold)) nb_pix_warning = int(np.nansum(max_image > self._warnThreshold))
nb_pix_alarm = int(np.nansum(max_image > self._alarmThreshold)) nb_pix_alarm = int(np.nansum(max_image > self._alarmThreshold))
output_hash["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount output["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount
output_hash["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount output["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount
output_hash["saturationMonitor.warnCount"] = nb_pix_warning output["saturationMonitor.warnCount"] = nb_pix_warning
output_hash["saturationMonitor.alarmCount"] = nb_pix_alarm output["saturationMonitor.alarmCount"] = nb_pix_alarm
output_hash["saturationMonitor.maxValue"] = float(np.nanmax(max_image)) output["saturationMonitor.maxValue"] = float(np.nanmax(max_image))
output["saturationMonitor.trainId"] = train_id
max_image[max_image <= self._warnThreshold] = 0 max_image[max_image <= self._warnThreshold] = 0
output_hash["saturationMonitor.alarmImage"] = ImageData( output["saturationMonitor.alarmImage"] = ImageData(
maybe_get(max_image), Dims(*max_image.shape), Encoding.GRAY, bitsPerPixel=32 maybe_get(max_image), Dims(*max_image.shape), Encoding.GRAY, bitsPerPixel=32
) )
self._device.writeChannel(
f"{self._prefix}.output",
output,
timestamp=Timestamp(Epochstamp(), Trainstamp(train_id)),
)
...@@ -756,9 +756,7 @@ class AgipdCorrection(base_correction.BaseCorrection): ...@@ -756,9 +756,7 @@ class AgipdCorrection(base_correction.BaseCorrection):
OUTPUT_CHANNEL(expected) OUTPUT_CHANNEL(expected)
.key("dataOutput") .key("dataOutput")
.dataSchema( .dataSchema(
AgipdCorrection._base_output_schema( cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
use_shmem_handles=cls._use_shmem_handles
)
) )
.commit(), .commit(),
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment