diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index 19169e2636f191d299146cbf2520410357d27eab..a9c5ecc623a31c6fda3dddaa3bd5f453d7a7e84a 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -240,8 +240,17 @@ class AgipdCorrection(BaseCorrection): compute preview, write data output, and optionally write preview outputs.""" # original shape: memory_cell, data/raw_gain, x, y - # TODO: add pulse filter back in pulse_table = np.squeeze(data_hash.get("image.pulseId")) + if self._frame_filter is not None: + try: + cell_table = cell_table[self._frame_filter] + pulse_table = pulse_table[self._frame_filter] + image_data = image_data[self._frame_filter] + except IndexError: + self.log_status_warn( + "Failed to apply frame filter, please check that it is valid!" + ) + return try: self.gpu_runner.load_data(image_data) @@ -362,19 +371,8 @@ class AgipdCorrection(BaseCorrection): self._update_correction_flags() self.log_status_info(f"Done loading {constant.name} to GPU") - def _update_pulse_filter(self, filter_string): - """Called whenever the pulse filter changes, typically followed by - _update_shapes""" - - if filter_string.strip() == "": - new_filter = np.arange(self.get("dataFormat.memoryCells"), dtype=np.uint16) - else: - new_filter = np.array(eval(filter_string), dtype=np.uint16) - assert np.max(new_filter) < self.get("dataFormat.memoryCells") - self.pulse_filter = new_filter - - def _update_shapes(self): - super()._update_shapes() + def _update_buffers(self): + super()._update_buffers() # TODO: pack four pixels per byte if self._schema_cache["sendGainMap"]: if self._shmem_buffer_gain_map is None: diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index 2db4d915d3c7cb99116b4715d44e3fa53f65e5bb..df287d78725e4ba96c84abfbbc3d25daec6c55b3 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -61,8 +61,17 @@ class DsscCorrection(BaseCorrection): cell_table, do_generate_preview, ): - # cell_table = cell_table[self.pulse_filter] - pulse_table = np.squeeze(data_hash.get("image.pulseId")) # [self.pulse_filter] + pulse_table = np.squeeze(data_hash.get("image.pulseId")) + if self._frame_filter is not None: + try: + cell_table = cell_table[self._frame_filter] + pulse_table = pulse_table[self._frame_filter] + image_data = image_data[self._frame_filter] + except IndexError: + self.log_status_warn( + "Failed to apply frame filter, please check that it is valid!" + ) + return try: self.gpu_runner.load_data(image_data) @@ -112,17 +121,6 @@ class DsscCorrection(BaseCorrection): source, ) - def _update_pulse_filter(self, filter_string): - """Called whenever the pulse filter changes, typically followed by - _update_shapes""" - - if filter_string.strip() == "": - new_filter = np.arange(self.get("dataFormat.memoryCells"), dtype=np.uint16) - else: - new_filter = np.array(eval(filter_string), dtype=np.uint16) - assert np.max(new_filter) < self.get("dataFormat.memoryCells") - self.pulse_filter = new_filter - def _load_constant_to_gpu(self, constant, constant_data): assert constant is DsscConstants.Offset self.gpu_runner.load_offset_map(constant_data) diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 0a4f1eae228f87da17aed2892b0c0ef86c616764..c8a64eb44dea00bb86a13d53e657a2fe70f0331e 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -51,6 +51,7 @@ class BaseCorrection(PythonDevice): "dataFormat.outputAxisOrder", "dataFormat.outputImageDtype", "dataFormat.overrideInputAxisOrder", + "frameFilter", "preview.enable", "preview.index", "preview.selectionMode", @@ -60,6 +61,7 @@ class BaseCorrection(PythonDevice): _schema_cache_fields = { "doAnything", "constantParameters.memoryCells", + "dataFormat.filteredFrames", "dataFormat.memoryCells", "dataFormat.pixelsX", "dataFormat.pixelsY", @@ -85,7 +87,7 @@ class BaseCorrection(PythonDevice): axis_lengths = { "x": self._schema_cache["dataFormat.pixelsX"], "y": self._schema_cache["dataFormat.pixelsY"], - "c": self._schema_cache["dataFormat.memoryCells"], + "c": self._schema_cache["dataFormat.filteredFrames"], } return tuple( axis_lengths[axis] @@ -200,16 +202,20 @@ class BaseCorrection(PythonDevice): .defaultValue([]) .commit(), STRING_ELEMENT(expected) - .key("pulseFilter") - .displayedName("[disabled] Pulse filter") + .key("frameFilter") + .displayedName("Frame filter") .description( - "Filter pulses: will be evaluated as array of indices to keep from " - "data. Can be anything which can be turned into numpy uint16 array. " - "Numpy is available as np. Take care not to include duplicates. If " - "empty, will not filter at all." + "The frame filter - if set - slices the input data. Frames not in the " + "filter will be discarded before any processing happens (will not get " + "to dataOutput or preview. Note that this filter goes by frame index " + "rather than cell ID or pulse ID; set accordingly (and only if you " + "know what you are doing as it can break everything). Filter is " + "evaluated into numpy uint16 array; a valid filter could be " + "'np.arange(0, 352, 2)'." ) - .readOnly() - .initialValue("") + .assignmentOptional() + .defaultValue("") + .reconfigurable() .commit(), UINT32_ELEMENT(expected) .key("outputShmemBufferSize") @@ -293,6 +299,13 @@ class BaseCorrection(PythonDevice): .assignmentOptional() .noDefaultValue() .commit(), + UINT32_ELEMENT(expected) + .key("dataFormat.filteredFrames") + .displayedName("Frames after filter") + .description("Number of frames left after applying frame filter") + .readOnly() + .initialValue(0) + .commit(), STRING_ELEMENT(expected) .key("dataFormat.outputAxisOrder") .displayedName("Output axis order") @@ -471,8 +484,9 @@ class BaseCorrection(PythonDevice): self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype")) self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype")) - self.gpu_runner = None # must call _update_shapes() - self.registerInitialFunction(self._update_shapes) + self.gpu_runner = None # must call _update_buffers() + self.registerInitialFunction(self._update_frame_filter) + self.registerInitialFunction(self._update_buffers) self.calcat_friend = self._calcat_friend_class( self, pathlib.Path.cwd() / "calibration-client-secrets.json" @@ -482,7 +496,6 @@ class BaseCorrection(PythonDevice): self._correction_flag_preview = self._correction_flag_class.NONE self._shmem_buffer = None - self._has_updated_shapes = False self._processing_time_ema = utils.ExponentialMovingAverage(alpha=0.3) self._rate_tracker = utils.WindowRateTracker() @@ -557,6 +570,9 @@ class BaseCorrection(PythonDevice): self._schema_cache[path] = update.get(path) # TODO: pulse filter (after reimplementing) + if update.has("frameFilter"): + with self._buffer_lock: + self._update_frame_filter() if any( update.has(shape_param) for shape_param in ( @@ -564,9 +580,11 @@ class BaseCorrection(PythonDevice): "dataFormat.pixelsY", "dataFormat.memoryCells", "constantParameters.memoryCells", + "frameFilter", ) ): - self._update_shapes() + with self._buffer_lock: + self._update_buffers() # TODO: only call this if they are changed (is cheap, though) self._update_correction_flags() @@ -645,7 +663,6 @@ class BaseCorrection(PythonDevice): channel.update() def _write_combiner_previews(self, channel_data_pairs, train_id, source): - # TODO: take into account updated pulse table after pulse filter # TODO: send as ImageData (requires updated assembler) # TODO: allow sending *all* frames for commissioning (request: Jola) preview_hash = Hash() @@ -679,12 +696,27 @@ class BaseCorrection(PythonDevice): self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}") self.log.DEBUG(f"Corrections for preview: {str(preview)}") - def _update_shapes(self): + def _update_frame_filter(self): + filter_string = self.get("frameFilter") + if filter_string.strip() == "": + self._frame_filter = None + self.set("dataFormat.filteredFrames", self.get("dataFormat.memoryCells")) + else: + self._frame_filter = np.array(eval(filter_string), dtype=np.uint16) + self.set("dataFormat.filteredFrames", self._frame_filter.size) + if self._frame_filter.min() < 0 or self._frame_filter.max() >= self.get( + "dataFormat.memoryCells" + ): + self.log_status_warn("Invalid frame filter set, expect exceptions!") + + def _update_buffers(self): """(Re)initialize buffers according to expected data shapes""" - self.log.INFO("Updating shapes") + self.log.INFO("Updating buffers according to data shapes") # reflect the axis reordering in the expected output shape self.set("dataFormat.inputDataShape", list(self.input_data_shape)) self.set("dataFormat.outputDataShape", list(self.output_data_shape)) + self.log.INFO(f"Input shape: {self.input_data_shape}") + self.log.INFO(f"Output shape: {self.output_data_shape}") if self._shmem_buffer is None: shmem_buffer_name = self.getInstanceId() + ":dataOutput" @@ -704,7 +736,7 @@ class BaseCorrection(PythonDevice): self.gpu_runner = self._gpu_runner_class( self.get("dataFormat.pixelsX"), self.get("dataFormat.pixelsY"), - self.get("dataFormat.memoryCells"), + self.get("dataFormat.filteredFrames"), int(self.get("constantParameters.memoryCells")), input_data_dtype=self.input_data_dtype, output_data_dtype=self.output_data_dtype, @@ -719,8 +751,6 @@ class BaseCorrection(PythonDevice): self.log_status_info(f"Reload constant {constant}") self._load_constant_to_gpu(constant, data) - self._has_updated_shapes = True - def input_handler(self, input_channel): """Main handler for data input: Do a few simple checks to determine whether to even try processing. If yes, will pass data and information to subclass' @@ -779,8 +809,8 @@ class BaseCorrection(PythonDevice): ) self.set("dataFormat.memoryCells", image_data.shape[0]) with self._buffer_lock: - # TODO: pulse filter update after reimplementation - self._update_shapes() + # TODO: re-validate frame filter against new number of cells + self._update_buffers() # DataAggregator typically tells us the wrong axis order if self._schema_cache["dataFormat.overrideInputAxisOrder"]: