diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index adf8866e7d8e28198ad99711d400d23b61ce70b5..74d8c05880b81c155c9e13430460a49caf23cd81 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -35,16 +35,26 @@ class AgipdCorrection(BaseCorrection): _gpu_runner_class = AgipdGpuRunner _calcat_friend_class = AgipdCalcatFriend _constant_enum_class = AgipdConstants - _managed_keys = BaseCorrection._managed_keys[:] + _managed_keys = BaseCorrection._managed_keys[:] + ["overrideInputAxisOrder"] # this is just extending (not mandatory) - _schema_cache_fields = BaseCorrection._schema_cache_fields | {"sendGainMap"} + _schema_cache_fields = BaseCorrection._schema_cache_fields | { + "sendGainMap", + "overrideInputAxisOrder", + } @staticmethod def expectedParameters(expected): super(AgipdCorrection, AgipdCorrection).expectedParameters(expected) expected.setDefaultValue("dataFormat.memoryCells", 352) ( + BOOL_ELEMENT(expected) + .key("overrideInputAxisOrder") + .displayedName("Override input axis order") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), STRING_ELEMENT(expected) .key("gainMode") .displayedName("Gain mode") @@ -188,10 +198,10 @@ class AgipdCorrection(BaseCorrection): @property def input_data_shape(self): return ( - self.get("dataFormat.memoryCells"), + self._schema_cache["dataFormat.memoryCells"], 2, - self.get("dataFormat.pixelsX"), - self.get("dataFormat.pixelsY"), + self._schema_cache["dataFormat.pixelsX"], + self._schema_cache["dataFormat.pixelsY"], ) def __init__(self, config): @@ -269,6 +279,14 @@ class AgipdCorrection(BaseCorrection): and self._schema_cache["preview.enable"] ) + if self._schema_cache["overrideInputAxisOrder"]: + expected_shape = self.input_data_shape + if expected_shape != image_data.shape: + self.log.INFO( + f"Overriding input data order from {image_data.shape} to {expected_shape}" + ) + image_data.shape = expected_shape + with self._buffer_lock: # cell_table = cell_table[self.pulse_filter] pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter] @@ -280,7 +298,12 @@ class AgipdCorrection(BaseCorrection): "corrected." ) - self.gpu_runner.load_data(image_data) + try: + self.gpu_runner.load_data(image_data) + except ValueError as e: + self.log_status_warn(f"Failed to load data: {e}") + return + buffer_handle, buffer_array = self._shmem_buffer.next_slot() self.gpu_runner.load_cell_table(cell_table) self.gpu_runner.correct(self._correction_flag_enabled) diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 3f3d46860d4212c96d607ae37c3733b59a1d2503..11f68b429a7904dcc1e4797e6db1ed03cc853a54 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -57,8 +57,8 @@ class BaseCorrection(PythonDevice): ] # subclass must extend this and put it in schema _schema_cache_fields = { "doAnything", - "dataFormat.memoryCells", "constantParameters.memoryCells", + "dataFormat.memoryCells", "dataFormat.pixelsX", "dataFormat.pixelsY", "dataFormat.outputAxisOrder",