Skip to content
Snippets Groups Projects
Commit ab79ee5f authored by David Hammer's avatar David Hammer
Browse files

Use explicit cache instead of overriding get

parent 02b50039
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
...@@ -115,8 +115,8 @@ class AgipdCorrection(BaseCorrection): ...@@ -115,8 +115,8 @@ class AgipdCorrection(BaseCorrection):
def process_input(self, data, metadata): def process_input(self, data, metadata):
"""Registered for dataInput, handles all processing and sending""" """Registered for dataInput, handles all processing and sending"""
if not self.get("doAnything"): if not self._schema_cache["doAnything"]:
if self.get("state") is State.PROCESSING: if self._schema_cache["state"] is State.PROCESSING:
self.updateState(State.ACTIVE) self.updateState(State.ACTIVE)
return return
...@@ -143,7 +143,7 @@ class AgipdCorrection(BaseCorrection): ...@@ -143,7 +143,7 @@ class AgipdCorrection(BaseCorrection):
return return
# original shape: memory_cell, data/raw_gain, x, y # original shape: memory_cell, data/raw_gain, x, y
image_data = data.get("image.data") image_data = data.get("image.data")
if image_data.shape[0] != self.get("dataFormat.memoryCells"): if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]:
self.set( self.set(
"status", f"Updating input shapes based on received {image_data.shape}" "status", f"Updating input shapes based on received {image_data.shape}"
) )
...@@ -153,21 +153,24 @@ class AgipdCorrection(BaseCorrection): ...@@ -153,21 +153,24 @@ class AgipdCorrection(BaseCorrection):
# TODO: pulse filter update after reimplementation # TODO: pulse filter update after reimplementation
self._update_shapes() self._update_shapes()
if not self.get("state") is State.PROCESSING: if not self._schema_cache["state"] is State.PROCESSING:
self.updateState(State.PROCESSING) self.updateState(State.PROCESSING)
self.set("status", "Processing data") self.set("status", "Processing data")
if self._state_reset_timer is None: if self._state_reset_timer is None:
self._state_reset_timer = utils.DelayableTimer( self._state_reset_timer = utils.DelayableTimer(
timeout=self.get("processingStateTimeout"), timeout=self._schema_cache["processingStateTimeout"],
callback=self._reset_state_from_processing, callback=self._reset_state_from_processing,
) )
else: else:
self._state_reset_timer.set_timeout(self.get("processingStateTimeout")) self._state_reset_timer.set_timeout(
self._schema_cache["processingStateTimeout"]
)
correction_cell_num = self.get("dataFormat.constantMemoryCells") correction_cell_num = self._schema_cache["dataFormat.constantMemoryCells"]
do_generate_preview = train_id % self.get( do_generate_preview = (
"preview.trainIdModulo" train_id % self._schema_cache["preview.trainIdModulo"] == 0
) == 0 and self.get("preview.enable") and self._schema_cache["preview.enable"]
)
with self._buffer_lock: with self._buffer_lock:
# cell_table = cell_table[self.pulse_filter] # cell_table = cell_table[self.pulse_filter]
...@@ -192,7 +195,7 @@ class AgipdCorrection(BaseCorrection): ...@@ -192,7 +195,7 @@ class AgipdCorrection(BaseCorrection):
if do_generate_preview: if do_generate_preview:
if self._correction_flag_enabled != self._correction_flag_preview: if self._correction_flag_enabled != self._correction_flag_preview:
self.gpu_runner.correct(self._correction_flag_preview) self.gpu_runner.correct(self._correction_flag_preview)
preview_slice_index = self.get("preview.pulse") preview_slice_index = self._schema_cache["preview.pulse"]
if preview_slice_index >= 0: if preview_slice_index >= 0:
# look at pulse_table to find which index this pulse ID is in # look at pulse_table to find which index this pulse ID is in
# TODO: move this to GPU # TODO: move this to GPU
...@@ -230,8 +233,6 @@ class AgipdCorrection(BaseCorrection): ...@@ -230,8 +233,6 @@ class AgipdCorrection(BaseCorrection):
self._buffered_status_update.set( self._buffered_status_update.set(
"performance.lastProcessingDuration", time_spent * 1000 "performance.lastProcessingDuration", time_spent * 1000
) )
if self.get("performance.rateUpdateOnEachInput"):
self._update_actual_rate()
def _load_constant_to_gpu(self, constant_name, constant_data): def _load_constant_to_gpu(self, constant_name, constant_data):
# TODO: also hook flushConstants or whatever it is called # TODO: also hook flushConstants or whatever it is called
......
...@@ -55,7 +55,9 @@ class DsscCorrection(BaseCorrection): ...@@ -55,7 +55,9 @@ class DsscCorrection(BaseCorrection):
@property @property
def output_data_shape(self): def output_data_shape(self):
return utils.shape_after_transpose(self.input_data_shape, self._output_transpose) return utils.shape_after_transpose(
self.input_data_shape, self._output_transpose
)
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
...@@ -72,8 +74,8 @@ class DsscCorrection(BaseCorrection): ...@@ -72,8 +74,8 @@ class DsscCorrection(BaseCorrection):
def process_input(self, data, metadata): def process_input(self, data, metadata):
"""Registered for dataInput, handles all processing and sending""" """Registered for dataInput, handles all processing and sending"""
if not self.get("doAnything"): if not self._schema_cache["doAnything"]:
if self.get("state") is State.PROCESSING: if self._schema_cache["state"] is State.PROCESSING:
self.updateState(State.ACTIVE) self.updateState(State.ACTIVE)
return return
...@@ -100,7 +102,7 @@ class DsscCorrection(BaseCorrection): ...@@ -100,7 +102,7 @@ class DsscCorrection(BaseCorrection):
return return
# original shape: 400, 1, 128, 512 (memory cells, something, y, x) # original shape: 400, 1, 128, 512 (memory cells, something, y, x)
image_data = data.get("image.data") image_data = data.get("image.data")
if image_data.shape[0] != self.get("dataFormat.memoryCells"): if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]:
self.set( self.set(
"status", f"Updating input shapes based on received {image_data.shape}" "status", f"Updating input shapes based on received {image_data.shape}"
) )
...@@ -110,21 +112,24 @@ class DsscCorrection(BaseCorrection): ...@@ -110,21 +112,24 @@ class DsscCorrection(BaseCorrection):
# TODO: pulse filter update after reimplementation # TODO: pulse filter update after reimplementation
self._update_shapes() self._update_shapes()
if not self.get("state") is State.PROCESSING: if not self._schema_cache["state"] is State.PROCESSING:
self.updateState(State.PROCESSING) self.updateState(State.PROCESSING)
self.set("status", "Processing data") self.set("status", "Processing data")
if self._state_reset_timer is None: if self._state_reset_timer is None:
self._state_reset_timer = utils.DelayableTimer( self._state_reset_timer = utils.DelayableTimer(
timeout=self.get("processingStateTimeout"), timeout=self._schema_cache["processingStateTimeout"],
callback=self._reset_state_from_processing, callback=self._reset_state_from_processing,
) )
else: else:
self._state_reset_timer.set_timeout(self.get("processingStateTimeout")) self._state_reset_timer.set_timeout(
self._schema_cache["processingStateTimeout"]
)
correction_cell_num = self.get("dataFormat.constantMemoryCells") correction_cell_num = self._schema_cache["dataFormat.constantMemoryCells"]
do_generate_preview = train_id % self.get( do_generate_preview = (
"preview.trainIdModulo" train_id % self._schema_cache["preview.trainIdModulo"] == 0
) == 0 and self.get("preview.enable") and self._schema_cache["preview.enable"]
)
with self._buffer_lock: with self._buffer_lock:
# cell_table = cell_table[self.pulse_filter] # cell_table = cell_table[self.pulse_filter]
...@@ -148,7 +153,7 @@ class DsscCorrection(BaseCorrection): ...@@ -148,7 +153,7 @@ class DsscCorrection(BaseCorrection):
if do_generate_preview: if do_generate_preview:
if self._correction_flag_enabled != self._correction_flag_preview: if self._correction_flag_enabled != self._correction_flag_preview:
self.gpu_runner.correct(self._correction_flag_preview) self.gpu_runner.correct(self._correction_flag_preview)
preview_slice_index = self.get("preview.pulse") preview_slice_index = self._schema_cache["preview.pulse"]
if preview_slice_index >= 0: if preview_slice_index >= 0:
# look at pulse_table to find which index this pulse ID is in # look at pulse_table to find which index this pulse ID is in
pulse_id_found = np.where(pulse_table == preview_slice_index)[0] pulse_id_found = np.where(pulse_table == preview_slice_index)[0]
...@@ -185,8 +190,6 @@ class DsscCorrection(BaseCorrection): ...@@ -185,8 +190,6 @@ class DsscCorrection(BaseCorrection):
self._buffered_status_update.set( self._buffered_status_update.set(
"performance.lastProcessingDuration", time_spent * 1000 "performance.lastProcessingDuration", time_spent * 1000
) )
if self.get("performance.rateUpdateOnEachInput"):
self._update_actual_rate()
def _update_pulse_filter(self, filter_string): def _update_pulse_filter(self, filter_string):
"""Called whenever the pulse filter changes, typically followed by """Called whenever the pulse filter changes, typically followed by
......
...@@ -32,7 +32,7 @@ from . import shmem_utils, utils ...@@ -32,7 +32,7 @@ from . import shmem_utils, utils
class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
_correction_flag_class = None # subclass must override this with some enum class _correction_flag_class = None # subclass must override this with some enum class
_dict_cache_slots = { _schema_cache_slots = {
"doAnything", "doAnything",
"dataFormat.memoryCells", "dataFormat.memoryCells",
"dataFormat.constantMemoryCells", "dataFormat.constantMemoryCells",
...@@ -42,7 +42,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -42,7 +42,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
"preview.pulse", "preview.pulse",
"preview.trainIdModulo", "preview.trainIdModulo",
"processingStateTimeout", "processingStateTimeout",
"performance.rateUpdateOnEachInput",
"state", "state",
} }
...@@ -264,10 +263,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -264,10 +263,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
FLOAT_ELEMENT(expected) FLOAT_ELEMENT(expected)
.key("performance.rateUpdateInterval") .key("performance.rateUpdateInterval")
.displayedName("Rate update interval") .displayedName("Rate update interval")
.description( .description("Interval (seconds) between updates of processing rate.")
"Maximum interval (seconds) between updates of the rate. Mostly "
"relevant if not rateUpdateOnEachInput or if input is slow."
)
.assignmentOptional() .assignmentOptional()
.defaultValue(1) .defaultValue(1)
.reconfigurable() .reconfigurable()
...@@ -280,19 +276,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -280,19 +276,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
.defaultValue(20) .defaultValue(20)
.reconfigurable() .reconfigurable()
.commit(), .commit(),
BOOL_ELEMENT(expected)
.key("performance.rateUpdateOnEachInput")
.displayedName("Update rate on each input")
.description(
"Whether or not to update the device rate for each input (otherwise "
"only based on rateUpdateInterval). Note that processed trains are "
"always registered - this just impacts when the rate is computed "
"based on this."
)
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
FLOAT_ELEMENT(expected) FLOAT_ELEMENT(expected)
.key("processingStateTimeout") .key("processingStateTimeout")
.description( .description(
...@@ -392,7 +375,9 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -392,7 +375,9 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
) )
def __init__(self, config): def __init__(self, config):
self._dict_cache = {k: config.get(k) for k in self._dict_cache_slots} self._schema_cache = {
k: config.get(k) for k in self._schema_cache_slots if config.has(k)
}
super().__init__(config) super().__init__(config)
self.KARABO_ON_DATA("dataInput", self.process_input) self.KARABO_ON_DATA("dataInput", self.process_input)
...@@ -445,8 +430,8 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -445,8 +430,8 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
) )
for path in config.getPaths(): for path in config.getPaths():
if path in self._dict_cache_slots: if path in self._schema_cache_slots:
self._dict_cache[path] = config.get(path) self._schema_cache[path] = config.get(path)
# TODO: pulse filter (after reimplementing) # TODO: pulse filter (after reimplementing)
if any( if any(
...@@ -468,17 +453,11 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -468,17 +453,11 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
# TODO: only call this if they are changed (is cheap, though) # TODO: only call this if they are changed (is cheap, though)
self._update_correction_flags() self._update_correction_flags()
def get(self, key):
if key in self._dict_cache_slots:
return self._dict_cache.get(key)
else:
return super().get(key)
def set(self, *args): def set(self, *args):
if len(args) == 2: if len(args) == 2:
key, value = args key, value = args
if key in self._dict_cache_slots: if key in self._schema_cache_slots:
self._dict_cache[key] = value self._schema_cache[key] = value
super().set(*args) super().set(*args)
def requestConstant(self, name, mostRecent=False, tryRemote=True): def requestConstant(self, name, mostRecent=False, tryRemote=True):
...@@ -522,7 +501,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -522,7 +501,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
preview_hash = Hash() preview_hash = Hash()
preview_hash.set("image.passport", [self.getInstanceId()]) preview_hash.set("image.passport", [self.getInstanceId()])
preview_hash.set("image.trainId", train_id) preview_hash.set("image.trainId", train_id)
preview_hash.set("image.pulseId", self.get("preview.pulse")) preview_hash.set("image.pulseId", self._schema_cache["preview.pulse"])
# note: have to construct because setting .tid after init is broken # note: have to construct because setting .tid after init is broken
timestamp = Timestamp(Epochstamp(), Trainstamp(train_id)) timestamp = Timestamp(Epochstamp(), Trainstamp(train_id))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment