diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index 50ed129d5e80fbdd86160fce2598250721b34635..136ea595e587e184ef8498aa3027e3ef9c466ff6 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -115,8 +115,8 @@ class AgipdCorrection(BaseCorrection): def process_input(self, data, metadata): """Registered for dataInput, handles all processing and sending""" - if not self.get("doAnything"): - if self.get("state") is State.PROCESSING: + if not self._schema_cache["doAnything"]: + if self._schema_cache["state"] is State.PROCESSING: self.updateState(State.ACTIVE) return @@ -143,7 +143,7 @@ class AgipdCorrection(BaseCorrection): return # original shape: memory_cell, data/raw_gain, x, y image_data = data.get("image.data") - if image_data.shape[0] != self.get("dataFormat.memoryCells"): + if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: self.set( "status", f"Updating input shapes based on received {image_data.shape}" ) @@ -153,21 +153,24 @@ class AgipdCorrection(BaseCorrection): # TODO: pulse filter update after reimplementation self._update_shapes() - if not self.get("state") is State.PROCESSING: + if not self._schema_cache["state"] is State.PROCESSING: self.updateState(State.PROCESSING) self.set("status", "Processing data") if self._state_reset_timer is None: self._state_reset_timer = utils.DelayableTimer( - timeout=self.get("processingStateTimeout"), + timeout=self._schema_cache["processingStateTimeout"], callback=self._reset_state_from_processing, ) else: - self._state_reset_timer.set_timeout(self.get("processingStateTimeout")) + self._state_reset_timer.set_timeout( + self._schema_cache["processingStateTimeout"] + ) - correction_cell_num = self.get("dataFormat.constantMemoryCells") - do_generate_preview = train_id % self.get( - "preview.trainIdModulo" - ) == 0 and self.get("preview.enable") + correction_cell_num = self._schema_cache["dataFormat.constantMemoryCells"] + do_generate_preview = ( + train_id % self._schema_cache["preview.trainIdModulo"] == 0 + and self._schema_cache["preview.enable"] + ) with self._buffer_lock: # cell_table = cell_table[self.pulse_filter] @@ -192,7 +195,7 @@ class AgipdCorrection(BaseCorrection): if do_generate_preview: if self._correction_flag_enabled != self._correction_flag_preview: self.gpu_runner.correct(self._correction_flag_preview) - preview_slice_index = self.get("preview.pulse") + preview_slice_index = self._schema_cache["preview.pulse"] if preview_slice_index >= 0: # look at pulse_table to find which index this pulse ID is in # TODO: move this to GPU @@ -230,8 +233,6 @@ class AgipdCorrection(BaseCorrection): self._buffered_status_update.set( "performance.lastProcessingDuration", time_spent * 1000 ) - if self.get("performance.rateUpdateOnEachInput"): - self._update_actual_rate() def _load_constant_to_gpu(self, constant_name, constant_data): # TODO: also hook flushConstants or whatever it is called diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index 7a2cc5ec8a71e0cb74659030b64edec8ae1708b5..b1dc274a6bc8f297af868cfefdafd8a956b8d34a 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -55,7 +55,9 @@ class DsscCorrection(BaseCorrection): @property def output_data_shape(self): - return utils.shape_after_transpose(self.input_data_shape, self._output_transpose) + return utils.shape_after_transpose( + self.input_data_shape, self._output_transpose + ) def __init__(self, config): super().__init__(config) @@ -72,8 +74,8 @@ class DsscCorrection(BaseCorrection): def process_input(self, data, metadata): """Registered for dataInput, handles all processing and sending""" - if not self.get("doAnything"): - if self.get("state") is State.PROCESSING: + if not self._schema_cache["doAnything"]: + if self._schema_cache["state"] is State.PROCESSING: self.updateState(State.ACTIVE) return @@ -100,7 +102,7 @@ class DsscCorrection(BaseCorrection): return # original shape: 400, 1, 128, 512 (memory cells, something, y, x) image_data = data.get("image.data") - if image_data.shape[0] != self.get("dataFormat.memoryCells"): + if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: self.set( "status", f"Updating input shapes based on received {image_data.shape}" ) @@ -110,21 +112,24 @@ class DsscCorrection(BaseCorrection): # TODO: pulse filter update after reimplementation self._update_shapes() - if not self.get("state") is State.PROCESSING: + if not self._schema_cache["state"] is State.PROCESSING: self.updateState(State.PROCESSING) self.set("status", "Processing data") if self._state_reset_timer is None: self._state_reset_timer = utils.DelayableTimer( - timeout=self.get("processingStateTimeout"), + timeout=self._schema_cache["processingStateTimeout"], callback=self._reset_state_from_processing, ) else: - self._state_reset_timer.set_timeout(self.get("processingStateTimeout")) + self._state_reset_timer.set_timeout( + self._schema_cache["processingStateTimeout"] + ) - correction_cell_num = self.get("dataFormat.constantMemoryCells") - do_generate_preview = train_id % self.get( - "preview.trainIdModulo" - ) == 0 and self.get("preview.enable") + correction_cell_num = self._schema_cache["dataFormat.constantMemoryCells"] + do_generate_preview = ( + train_id % self._schema_cache["preview.trainIdModulo"] == 0 + and self._schema_cache["preview.enable"] + ) with self._buffer_lock: # cell_table = cell_table[self.pulse_filter] @@ -148,7 +153,7 @@ class DsscCorrection(BaseCorrection): if do_generate_preview: if self._correction_flag_enabled != self._correction_flag_preview: self.gpu_runner.correct(self._correction_flag_preview) - preview_slice_index = self.get("preview.pulse") + preview_slice_index = self._schema_cache["preview.pulse"] if preview_slice_index >= 0: # look at pulse_table to find which index this pulse ID is in pulse_id_found = np.where(pulse_table == preview_slice_index)[0] @@ -185,8 +190,6 @@ class DsscCorrection(BaseCorrection): self._buffered_status_update.set( "performance.lastProcessingDuration", time_spent * 1000 ) - if self.get("performance.rateUpdateOnEachInput"): - self._update_actual_rate() def _update_pulse_filter(self, filter_string): """Called whenever the pulse filter changes, typically followed by diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index a7076c485c3096dc360411c0f4171418a9742236..cfc3c3a6dde280853ba3b2116c2e544239c7541b 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -32,7 +32,7 @@ from . import shmem_utils, utils class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): _correction_flag_class = None # subclass must override this with some enum class - _dict_cache_slots = { + _schema_cache_slots = { "doAnything", "dataFormat.memoryCells", "dataFormat.constantMemoryCells", @@ -42,7 +42,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): "preview.pulse", "preview.trainIdModulo", "processingStateTimeout", - "performance.rateUpdateOnEachInput", "state", } @@ -264,10 +263,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): FLOAT_ELEMENT(expected) .key("performance.rateUpdateInterval") .displayedName("Rate update interval") - .description( - "Maximum interval (seconds) between updates of the rate. Mostly " - "relevant if not rateUpdateOnEachInput or if input is slow." - ) + .description("Interval (seconds) between updates of processing rate.") .assignmentOptional() .defaultValue(1) .reconfigurable() @@ -280,19 +276,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): .defaultValue(20) .reconfigurable() .commit(), - BOOL_ELEMENT(expected) - .key("performance.rateUpdateOnEachInput") - .displayedName("Update rate on each input") - .description( - "Whether or not to update the device rate for each input (otherwise " - "only based on rateUpdateInterval). Note that processed trains are " - "always registered - this just impacts when the rate is computed " - "based on this." - ) - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), FLOAT_ELEMENT(expected) .key("processingStateTimeout") .description( @@ -392,7 +375,9 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ) def __init__(self, config): - self._dict_cache = {k: config.get(k) for k in self._dict_cache_slots} + self._schema_cache = { + k: config.get(k) for k in self._schema_cache_slots if config.has(k) + } super().__init__(config) self.KARABO_ON_DATA("dataInput", self.process_input) @@ -445,8 +430,8 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ) for path in config.getPaths(): - if path in self._dict_cache_slots: - self._dict_cache[path] = config.get(path) + if path in self._schema_cache_slots: + self._schema_cache[path] = config.get(path) # TODO: pulse filter (after reimplementing) if any( @@ -468,17 +453,11 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): # TODO: only call this if they are changed (is cheap, though) self._update_correction_flags() - def get(self, key): - if key in self._dict_cache_slots: - return self._dict_cache.get(key) - else: - return super().get(key) - def set(self, *args): if len(args) == 2: key, value = args - if key in self._dict_cache_slots: - self._dict_cache[key] = value + if key in self._schema_cache_slots: + self._schema_cache[key] = value super().set(*args) def requestConstant(self, name, mostRecent=False, tryRemote=True): @@ -522,7 +501,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): preview_hash = Hash() preview_hash.set("image.passport", [self.getInstanceId()]) preview_hash.set("image.trainId", train_id) - preview_hash.set("image.pulseId", self.get("preview.pulse")) + preview_hash.set("image.pulseId", self._schema_cache["preview.pulse"]) # note: have to construct because setting .tid after init is broken timestamp = Timestamp(Epochstamp(), Trainstamp(train_id))