From 15d2116f299cefb556b544eba4c8102567d36cb0 Mon Sep 17 00:00:00 2001 From: David Hammer <dhammer@mailbox.org> Date: Wed, 17 Nov 2021 23:35:10 +0100 Subject: [PATCH] Restructure input handler, DRY slightly --- src/calng/AgipdCorrection.py | 176 ++++++++++++----------------------- src/calng/DsscCorrection.py | 138 ++++++++++----------------- src/calng/base_correction.py | 139 +++++++++++++++++++++++---- 3 files changed, 227 insertions(+), 226 deletions(-) diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index e4d4be15..19169e26 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -1,5 +1,3 @@ -import timeit - import numpy as np from karabo.bound import ( BOOL_ELEMENT, @@ -36,14 +34,12 @@ class AgipdCorrection(BaseCorrection): _calcat_friend_class = AgipdCalcatFriend _constant_enum_class = AgipdConstants _managed_keys = BaseCorrection._managed_keys | { - "overrideInputAxisOrder", "sendGainMap", } # this is just extending (not mandatory) _schema_cache_fields = BaseCorrection._schema_cache_fields | { "sendGainMap", - "overrideInputAxisOrder", } @staticmethod @@ -52,13 +48,6 @@ class AgipdCorrection(BaseCorrection): expected.setDefaultValue("dataFormat.memoryCells", 352) expected.setDefaultValue("preview.selectionMode", "cell") ( - BOOL_ELEMENT(expected) - .key("overrideInputAxisOrder") - .displayedName("Override input axis order") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), STRING_ELEMENT(expected) .key("gainMode") .displayedName("Gain mode") @@ -222,7 +211,6 @@ class AgipdCorrection(BaseCorrection): } self._shmem_buffer_gain_map = None - self._update_shapes() # configurability: overriding md_additional_offset if config.get("corrections.relGainPc.overrideMdAdditionalOffset"): @@ -234,125 +222,89 @@ class AgipdCorrection(BaseCorrection): # configurability: disabling subset of bad pixel masking bits self._has_updated_bad_pixel_selection = False - self._update_bad_pixel_selection() + self.registerInitialFunction(self._update_bad_pixel_selection) self.updateState(State.ON) - def process_input(self, data, metadata): - """Registered for dataInput, handles all processing and sending""" - - source = metadata.get("source") - - if source not in self.sources: - self.log_status_info(f"Ignoring hash with unknown source {source}") - return - - if not data.has("image"): - self.log_status_info("Ignoring hash without image node") - return + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ): + """Called by input_handler for each data hash. Should correct data, optionally + compute preview, write data output, and optionally write preview outputs.""" + # original shape: memory_cell, data/raw_gain, x, y - time_start = timeit.default_timer() - self._last_processing_started = time_start + # TODO: add pulse filter back in + pulse_table = np.squeeze(data_hash.get("image.pulseId")) - train_id = metadata.getAttribute("timestamp", "tid") - cell_table = np.squeeze(data.get("image.cellId")) - if len(cell_table.shape) == 0: - self.log_status_warn( - "cellId had 0 dimensions. DAQ may not be sending data." - ) + try: + self.gpu_runner.load_data(image_data) + except ValueError as e: + self.log_status_warn(f"Failed to load data: {e}") return - # original shape: memory_cell, data/raw_gain, x, y - image_data = data.get("image.data") - if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: - self.log_status_info( - f"Updating input shapes based on received {image_data.shape}" - ) - self.set("dataFormat.memoryCells", image_data.shape[0]) - with self._buffer_lock: - # TODO: pulse filter update after reimplementation - self._update_shapes() - - if not self._schema_cache["state"] is State.PROCESSING: - self.updateState(State.PROCESSING) - self.log_status_info("Processing data") - - correction_cell_num = self._schema_cache["constantParameters.memoryCells"] - do_generate_preview = ( - train_id % self._schema_cache["preview.trainIdModulo"] == 0 - and self._schema_cache["preview.enable"] + except Exception as e: + self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") + + buffer_handle, buffer_array = self._shmem_buffer.next_slot() + self.gpu_runner.load_cell_table(cell_table) + self.gpu_runner.correct(self._correction_flag_enabled) + self.gpu_runner.reshape( + output_order=self._schema_cache["dataFormat.outputAxisOrder"], + out=buffer_array, ) - - if self._schema_cache["overrideInputAxisOrder"]: - expected_shape = self.input_data_shape - if expected_shape != image_data.shape: - image_data.shape = expected_shape - - with self._buffer_lock: - # cell_table = cell_table[self.pulse_filter] - pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter] - cell_table_max = np.max(cell_table) - if cell_table_max >= correction_cell_num: - self.log_status_info( - f"Max cell ID ({cell_table_max}) exceeds range for loaded " - f"constants ({correction_cell_num} cells). Some frames will not be " - "corrected." - ) - - try: - self.gpu_runner.load_data(image_data) - except ValueError as e: - self.log_status_warn(f"Failed to load data: {e}") - return - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.gpu_runner.load_cell_table(cell_table) - self.gpu_runner.correct(self._correction_flag_enabled) - self.gpu_runner.reshape( - output_order=self._schema_cache["dataFormat.outputAxisOrder"], - out=buffer_array, + # after reshape, data for dataOutput is now safe in its own buffer + if do_generate_preview: + if self._correction_flag_enabled != self._correction_flag_preview: + self.gpu_runner.correct(self._correction_flag_preview) + ( + preview_slice_index, + preview_cell, + preview_pulse, + ) = utils.pick_frame_index( + self._schema_cache["preview.selectionMode"], + self._schema_cache["preview.index"], + cell_table, + pulse_table, + warn_func=self.log_status_warn, ) - # after reshape, data for dataOutput is now safe in its own buffer - if do_generate_preview: - if self._correction_flag_enabled != self._correction_flag_preview: - self.gpu_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ) = utils.pick_frame_index( - self._schema_cache["preview.selectionMode"], - self._schema_cache["preview.index"], - cell_table, - pulse_table, - warn_func=self.log_status_warn, - ) - preview_raw, preview_corrected = self.gpu_runner.compute_preview( + ( + preview_raw, + preview_corrected, + ) = self.gpu_runner.compute_preview(preview_slice_index) + if self._schema_cache["sendGainMap"]: + preview_gain = self.gpu_runner.compute_preview_gain( preview_slice_index ) - if self._schema_cache["sendGainMap"]: - preview_gain = self.gpu_runner.compute_preview_gain( - preview_slice_index - ) - data.set("image.data", buffer_handle) + # reusing input data hash for sending + data_hash.set("image.data", buffer_handle) if self._schema_cache["sendGainMap"]: - buffer_handle, buffer_array = self._shmem_buffer_gain_map.next_slot() + ( + buffer_handle, + buffer_array, + ) = self._shmem_buffer_gain_map.next_slot() self.gpu_runner.get_gain_map( output_order=self._schema_cache["dataFormat.outputAxisOrder"], out=buffer_array, ) - data.set( + data_hash.set( "image.gainMap", buffer_handle, ) - data.set("calngShmemPaths", ["image.data", "image.gainMap"]) + data_hash.set("calngShmemPaths", ["image.data", "image.gainMap"]) else: - data.set("calngShmemPaths", ["image.data"]) + data_hash.set("calngShmemPaths", ["image.data"]) - data.set("image.cellId", cell_table[:, np.newaxis]) - data.set("image.pulseId", pulse_table[:, np.newaxis]) + data_hash.set("image.cellId", cell_table[:, np.newaxis]) + data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) - self._write_output(data, metadata) + self._write_output(data_hash, metadata) if do_generate_preview: if self._schema_cache["sendGainMap"]: self._write_combiner_previews( @@ -375,12 +327,6 @@ class AgipdCorrection(BaseCorrection): source, ) - # update rate etc. - self._buffered_status_update.set("trainId", train_id) - self._rate_tracker.update() - time_spent = timeit.default_timer() - time_start - self._processing_time_ema.update(time_spent) - def _load_constant_to_gpu(self, constant, constant_data): # TODO: encode correction / constant dependencies in a clever way if constant is AgipdConstants.ThresholdsDark: diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index 2b74bed7..2db4d915 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -1,5 +1,3 @@ -import timeit - import numpy as np from karabo.bound import KARABO_CLASSINFO, VECTOR_STRING_ELEMENT from karabo.common.states import State @@ -53,93 +51,57 @@ class DsscCorrection(BaseCorrection): super().__init__(config) self.updateState(State.ON) - def process_input(self, data, metadata): - """Registered for dataInput, handles all processing and sending""" - - source = metadata.get("source") - - if source not in self.sources: - self.log_status_info(f"Ignoring hash with unknown source {source}") - return - - if not data.has("image"): - self.log_status_info("Ignoring hash without image node") - return - - time_start = timeit.default_timer() - self._last_processing_started = time_start - - train_id = metadata.getAttribute("timestamp", "tid") - cell_table = np.squeeze(data.get("image.cellId")) - assert isinstance(cell_table, np.ndarray), "image.cellId should be ndarray" - if len(cell_table.shape) == 0: - self.log_status_warn( - "cellId had 0 dimensions. DAQ may not be sending data." - ) + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ): + # cell_table = cell_table[self.pulse_filter] + pulse_table = np.squeeze(data_hash.get("image.pulseId")) # [self.pulse_filter] + + try: + self.gpu_runner.load_data(image_data) + except ValueError as e: + self.log_status_warn(f"Failed to load data: {e}") return - # original shape: 400, 1, 128, 512 (memory cells, something, y, x) - image_data = data.get("image.data") - if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: - self.log_status_info( - f"Updating input shapes based on received {image_data.shape}" - ) - self.set("dataFormat.memoryCells", image_data.shape[0]) - with self._buffer_lock: - # TODO: pulse filter update after reimplementation - self._update_shapes() - - if not self._schema_cache["state"] is State.PROCESSING: - self.updateState(State.PROCESSING) - self.log_status_info("Processing data") - - correction_cell_num = self._schema_cache["constantParameters.memoryCells"] - do_generate_preview = ( - train_id % self._schema_cache["preview.trainIdModulo"] == 0 - and self._schema_cache["preview.enable"] + except Exception as e: + self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") + + buffer_handle, buffer_array = self._shmem_buffer.next_slot() + self.gpu_runner.load_cell_table(cell_table) + self.gpu_runner.correct(self._correction_flag_enabled) + self.gpu_runner.reshape( + output_order=self._schema_cache["dataFormat.outputAxisOrder"], + out=buffer_array, ) - - with self._buffer_lock: - # cell_table = cell_table[self.pulse_filter] - pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter] - cell_table_max = np.max(cell_table) - if cell_table_max >= correction_cell_num: - self.log_status_info( - f"Max cell ID ({cell_table_max}) exceeds range for loaded " - f"constant ({correction_cell_num} cells). Some frames will not be " - "corrected." - ) - - self.gpu_runner.load_data(image_data) - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.gpu_runner.load_cell_table(cell_table) - self.gpu_runner.correct(self._correction_flag_enabled) - self.gpu_runner.reshape( - output_order=self._schema_cache["dataFormat.outputAxisOrder"], - out=buffer_array, + if do_generate_preview: + if self._correction_flag_enabled != self._correction_flag_preview: + self.gpu_runner.correct(self._correction_flag_preview) + ( + preview_slice_index, + preview_cell, + preview_pulse, + ) = utils.pick_frame_index( + self._schema_cache["preview.selectionMode"], + self._schema_cache["preview.index"], + cell_table, + pulse_table, + warn_func=self.log_status_warn, + ) + preview_raw, preview_corrected = self.gpu_runner.compute_preview( + preview_slice_index, ) - if do_generate_preview: - if self._correction_flag_enabled != self._correction_flag_preview: - self.gpu_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ) = utils.pick_frame_index( - self._schema_cache["preview.selectionMode"], - self._schema_cache["preview.index"], - cell_table, - pulse_table, - warn_func=self.log_status_warn, - ) - preview_raw, preview_corrected = self.gpu_runner.compute_preview( - preview_slice_index, - ) - data.set("image.data", buffer_handle) - data.set("image.cellId", cell_table[:, np.newaxis]) - data.set("image.pulseId", pulse_table[:, np.newaxis]) - data.set("calngShmemPaths", ["image.data"]) - self._write_output(data, metadata) + data_hash.set("image.data", buffer_handle) + data_hash.set("image.cellId", cell_table[:, np.newaxis]) + data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) + data_hash.set("calngShmemPaths", ["image.data"]) + self._write_output(data_hash, metadata) if do_generate_preview: self._write_combiner_previews( ( @@ -150,12 +112,6 @@ class DsscCorrection(BaseCorrection): source, ) - # update rate etc. - self._buffered_status_update.set("trainId", train_id) - self._rate_tracker.update() - time_spent = timeit.default_timer() - time_start - self._processing_time_ema.update(time_spent) - def _update_pulse_filter(self, filter_string): """Called whenever the pulse filter changes, typically followed by _update_shapes""" diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 2dccd1e2..47f695f0 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -1,12 +1,12 @@ import pathlib import threading -import timeit +from timeit import default_timer import dateutil.parser import numpy as np from karabo.bound import ( BOOL_ELEMENT, - FLOAT_ELEMENT, + DOUBLE_ELEMENT, INPUT_CHANNEL, INT32_ELEMENT, INT64_ELEMENT, @@ -50,6 +50,7 @@ class BaseCorrection(PythonDevice): "outputShmemBufferSize", "dataFormat.outputAxisOrder", "dataFormat.outputImageDtype", + "dataFormat.overrideInputAxisOrder", "preview.enable", "preview.index", "preview.selectionMode", @@ -63,6 +64,7 @@ class BaseCorrection(PythonDevice): "dataFormat.pixelsX", "dataFormat.pixelsY", "dataFormat.outputAxisOrder", + "dataFormat.overrideInputAxisOrder", "preview.enable", "preview.index", "preview.selectionMode", @@ -235,6 +237,20 @@ class BaseCorrection(PythonDevice): .key("dataFormat") .displayedName("Data format (in/out)") .commit(), + BOOL_ELEMENT(expected) + .key("dataFormat.overrideInputAxisOrder") + .displayedName("Override input axis order") + .description( + "The shape of the image data ndarray as received from the " + "DataAggregator is sometimes wrong - the axes are actually in a " + "different order than the ndarray shape suggests. If this flag is on, " + "the shape of the ndarray will be overridden with the axis order we " + "expect." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), STRING_ELEMENT(expected) .key("dataFormat.inputImageDtype") .displayedName("Input image data dtype") @@ -368,8 +384,8 @@ class BaseCorrection(PythonDevice): .description( "The value of preview.index can be used in multiple ways, controlled " "by this value. If this is set to 'frame', preview.index is sliced " - "directly from data. If 'cell' (or 'pulse') is selected, I will look at " - "cell (or pulse) table for the requested cell (or pulse ID)." + "directly from data. If 'cell' (or 'pulse') is selected, I will look " + "at cell (or pulse) table for the requested cell (or pulse ID)." ) .options("frame,cell,pulse") .assignmentOptional() @@ -404,13 +420,9 @@ class BaseCorrection(PythonDevice): .key("performance") .displayedName("Performance measures") .commit(), - FLOAT_ELEMENT(expected) - .key("performance.processingDuration") + DOUBLE_ELEMENT(expected) + .key("performance.processingTime") .displayedName("Processing time") - .description( - "Exponential moving average over time spent processing individual " - "trains Time includes generating preview and sending data." - ) .unit(Unit.SECOND) .metricPrefix(MetricPrefix.MILLI) .readOnly() @@ -419,7 +431,7 @@ class BaseCorrection(PythonDevice): .info("Processing not fast enough for full speed") .needsAcknowledging(False) .commit(), - FLOAT_ELEMENT(expected) + DOUBLE_ELEMENT(expected) .key("performance.rate") .displayedName("Rate") .description( @@ -445,20 +457,22 @@ class BaseCorrection(PythonDevice): k: config.get(k) for k in self._schema_cache_fields if config.has(k) } super().__init__(config) + self.updateState(State.INIT) if not sorted(config.get("dataFormat.outputAxisOrder")) == ["c", "x", "y"]: # TODO: figure out how to get this information to operator self.log_status_error("Invalid output axis order string") return - self.KARABO_ON_DATA("dataInput", self.process_input) + self.KARABO_ON_INPUT("dataInput", self.input_handler) self.KARABO_ON_EOS("dataInput", self.handle_eos) self.sources = set(config.get("fastSources")) self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype")) self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype")) - self.gpu_runner = None # must call _update_shapes() in subclass init + self.gpu_runner = None # must call _update_shapes() + self.registerInitialFunction(self._update_shapes) self.calcat_friend = self._calcat_friend_class( self, pathlib.Path.cwd() / "calibration-client-secrets.json" @@ -477,10 +491,10 @@ class BaseCorrection(PythonDevice): 0, "performance.rate", 0, - "performance.processingDuration", + "performance.processingTime", 0, ) - self._last_processing_started = 0 # not input handler should put timestamp + self._last_processing_started = 0 # input handler should put timestamp self._rate_update_timer = utils.RepeatingTimer( interval=1, callback=self._update_rate_and_state, @@ -697,18 +711,103 @@ class BaseCorrection(PythonDevice): self._has_updated_shapes = True + def input_handler(self, input_channel): + """Main handler for data input: Do a few simple checks to determine whether to + even try processing. If yes, will pass data and information to subclass' + process_data function. + """ + + # Is device even ready for this? + state = self._schema_cache["state"] + if state is State.ERROR: + # in this case, we should have already issued warning + return + elif self.gpu_runner is None: + self.log_status_warn("Received data, but have not initialized kernels yet") + return + + all_metadata = input_channel.getMetaData() + for input_index in range(input_channel.size()): + self._last_processing_started = default_timer() + data_hash = input_channel.read(input_index) + metadata = all_metadata[input_index] + source = metadata.get("source") + + if source not in self.sources: + self.log_status_info(f"Ignoring hash with unknown source {source}") + return + elif not data_hash.has("image"): + self.log_status_info("Ignoring hash without image node") + return + + train_id = metadata.getAttribute("timestamp", "tid") + cell_table = np.squeeze(data_hash.get("image.cellId")) + if len(cell_table.shape) == 0: + self.log_status_warn( + "cellId had 0 dimensions. DAQ may not be sending data." + ) + return + + # no more common reasons to skip input, so go to processing + if state is State.ON: + self.updateState(State.PROCESSING) + self.log_status_info("Processing data") + + correction_cell_num = self._schema_cache["constantParameters.memoryCells"] + cell_table_max = np.max(cell_table) + if cell_table_max >= correction_cell_num: + self.log_status_info( + f"Max cell ID ({cell_table_max}) exceeds range for loaded " + f"constants ({correction_cell_num} cells). Some frames will not be " + "corrected." + ) + + image_data = data_hash.get("image.data") + if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: + self.log_status_info( + f"Updating new input shape {image_data.shape}, updating buffers" + ) + self.set("dataFormat.memoryCells", image_data.shape[0]) + with self._buffer_lock: + # TODO: pulse filter update after reimplementation + self._update_shapes() + + # DataAggregator typically tells us the wrong axis order + if self._schema_cache["dataFormat.overrideInputAxisOrder"]: + expected_shape = self.input_data_shape + if expected_shape != image_data.shape: + image_data.shape = expected_shape + + do_generate_preview = ( + train_id % self._schema_cache["preview.trainIdModulo"] == 0 + and self._schema_cache["preview.enable"] + ) + + with self._buffer_lock: + self.process_data( + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ) + self._buffered_status_update.set("trainId", train_id) + self._processing_time_ema.update( + default_timer() - self._last_processing_started + ) + self._rate_tracker.update() + def _update_rate_and_state(self): self._buffered_status_update.set("performance.rate", self._rate_tracker.get()) self._buffered_status_update.set( - "performance.processingDuration", self._processing_time_ema.get() * 1000 + "performance.processingTime", self._processing_time_ema.get() * 1000 ) # trainId should be set on _buffered_status_update in input handler self.set(self._buffered_status_update) - if ( - timeit.default_timer() - self._last_processing_started - > PROCESSING_STATE_TIMEOUT - ): + if default_timer() - self._last_processing_started > PROCESSING_STATE_TIMEOUT: if self.get("state") is State.PROCESSING: self.updateState(State.ON) self.log_status_info( -- GitLab