Skip to content
Snippets Groups Projects
Commit 14e7737d authored by David Hammer's avatar David Hammer
Browse files

Cleanup, update TODOs, docstrings

parent 43b679ab
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
...@@ -171,13 +171,11 @@ class AgipdCorrection(BaseCorrection): ...@@ -171,13 +171,11 @@ class AgipdCorrection(BaseCorrection):
self._gpu_runner_init_args = {"gain_mode": self.gain_mode} self._gpu_runner_init_args = {"gain_mode": self.gain_mode}
super().__init__(config) super().__init__(config)
output_axis_order = config.get("dataFormat.outputAxisOrder") self._output_transpose = {
if output_axis_order == "pixels-fast": "pixels-fast": None,
self._output_transpose = None "memorycells-fast": (2, 1, 0),
elif output_axis_order == "memorycells-fast": "no-reshape": None,
self._output_transpose = (2, 1, 0) }[config.get("dataFormat.outputAxisOrder")]
else:
self._output_transpose = None
self._update_shapes() self._update_shapes()
if config.get("corrections.overrideMdAdditionalOffset"): if config.get("corrections.overrideMdAdditionalOffset"):
self._override_md_additional_offset = config.get( self._override_md_additional_offset = config.get(
......
...@@ -133,9 +133,6 @@ class ModuleStacker(TrainMatcher.TrainMatcher): ...@@ -133,9 +133,6 @@ class ModuleStacker(TrainMatcher.TrainMatcher):
out_hash[self.path_to_stack] = stacked_data out_hash[self.path_to_stack] = stacked_data
out_hash["sources"] = stacked_sources out_hash["sources"] = stacked_sources
out_hash["modulesPresent"] = stacked_present out_hash["modulesPresent"] = stacked_present
if not out_hash.has("image.passport"):
out_hash.set("image.passport", [])
out_hash["image.passport"].append(self.getInstanceId())
channel = self.signalSlotable.getOutputChannel("output") channel = self.signalSlotable.getOutputChannel("output")
channel.write(out_hash, ChannelMetaData(self.getInstanceId(), timestamp)) channel.write(out_hash, ChannelMetaData(self.getInstanceId(), timestamp))
channel.update() channel.update()
......
...@@ -160,7 +160,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -160,7 +160,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
.description( .description(
"Axes of main data output can be reordered after correction. Choose " "Axes of main data output can be reordered after correction. Choose "
"between 'pixels-fast' (memory_cell, x, y), 'memorycells-fast' " "between 'pixels-fast' (memory_cell, x, y), 'memorycells-fast' "
"(x, y, memory_cell), and 'no-reshape' (memory_cell, y, x)" "(x, y, memory_cell), and 'no-reshape'"
) )
.options("pixels-fast,memorycells-fast,no-reshape") .options("pixels-fast,memorycells-fast,no-reshape")
.assignmentOptional() .assignmentOptional()
...@@ -363,12 +363,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -363,12 +363,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
.key("corrections.disableAll") .key("corrections.disableAll")
.displayedName("Disable corrections for dataOutput") .displayedName("Disable corrections for dataOutput")
.description( .description(
"Toggle whether not correction(s) are applied to image data. If " "Toggle for disabling all corrections for dataOutput at once. This "
"false, this device still reshapes data to output shape, applies the " "overrides the individual flags under corrections.enabled (does not "
"pulse filter, and casts to output dtype. Useful if constants are " "affect corrected preview). When corrections are disabled, the device "
"missing / bad, or if data is sent to application doing its own " "still reshapes data to output shape, applies the pulse filter, and "
"correction. Preview is still corrected based on selection of " "casts to output dtype. Disabling all corrections can be useful if "
"corrections independently of this." "constants are missing / bad, or if user software does not want "
"corrected data."
) )
.assignmentOptional() .assignmentOptional()
.defaultValue(False) .defaultValue(False)
...@@ -450,13 +451,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -450,13 +451,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
self._has_updated_shapes = False self._has_updated_shapes = False
def postReconfigure(self): def postReconfigure(self):
self.log.INFO("postReconfigure")
if not self._has_updated_shapes: if not self._has_updated_shapes:
self._update_shapes() self._update_shapes()
# TODO: only call this if they are changed (is cheap, though) # TODO: only call this if they are changed (is cheap, though)
self._update_correction_flags() self._update_correction_flags()
def set(self, *args): def set(self, *args):
"""Wrapper around PythonDevice.set to enable caching "hot" schema elements"""
if len(args) == 2: if len(args) == 2:
key, value = args key, value = args
if key in self._schema_cache_slots: if key in self._schema_cache_slots:
...@@ -464,8 +465,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -464,8 +465,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
super().set(*args) super().set(*args)
def requestConstant(self, name, mostRecent=False, tryRemote=True): def requestConstant(self, name, mostRecent=False, tryRemote=True):
"""constantLoaded hook would have gotten called without naming constant, so here """Wrapper around method from CalibrationReceiverBaseDevice
we go. Ugly hooking it."""
The superclass provides the constantLoaded hook, but it gets called without
arguments, losing the name of the freshly loaded constant. To handle individual
constants correctly, we set up our own hook instead.
"""
# TODO: remove when revamping constant retrieval
if name in self._cached_constants: if name in self._cached_constants:
del self._cached_constants[name] del self._cached_constants[name]
super().requestConstant(name, mostRecent, tryRemote) super().requestConstant(name, mostRecent, tryRemote)
...@@ -480,6 +486,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -480,6 +486,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
def flush_constants(self): def flush_constants(self):
"""Override from CalibrationReceiverBaseDevice to also flush GPU buffers""" """Override from CalibrationReceiverBaseDevice to also flush GPU buffers"""
# TODO: update when revamping constant retrieval
super().flush_constants() super().flush_constants()
for correction_step, _ in self._correction_slot_names: for correction_step, _ in self._correction_slot_names:
self.set(f"corrections.available.{correction_step}", False) self.set(f"corrections.available.{correction_step}", False)
...@@ -492,10 +499,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -492,10 +499,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")), Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")),
) )
if "image.passport" not in data:
data["image.passport"] = []
data["image.passport"].append(self.getInstanceId())
if not self._has_set_output_schema: if not self._has_set_output_schema:
self.updateState(State.CHANGING) self.updateState(State.CHANGING)
self._update_output_schema(data) self._update_output_schema(data)
...@@ -508,7 +511,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -508,7 +511,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
def _write_combiner_preview(self, data_raw, data_corrected, train_id, source): def _write_combiner_preview(self, data_raw, data_corrected, train_id, source):
# TODO: take into account updated pulse table after pulse filter # TODO: take into account updated pulse table after pulse filter
preview_hash = Hash() preview_hash = Hash()
preview_hash.set("image.passport", [self.getInstanceId()])
preview_hash.set("image.trainId", train_id) preview_hash.set("image.trainId", train_id)
preview_hash.set("image.pulseId", self._schema_cache["preview.pulse"]) preview_hash.set("image.pulseId", self._schema_cache["preview.pulse"])
...@@ -525,6 +527,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -525,6 +527,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
channel.update() channel.update()
def _update_correction_flags(self): def _update_correction_flags(self):
"""Based on constants loaded and settings, update bit mask flags for kernel"""
available = self._correction_flag_class.NONE available = self._correction_flag_class.NONE
enabled = self._correction_flag_class.NONE enabled = self._correction_flag_class.NONE
preview = self._correction_flag_class.NONE preview = self._correction_flag_class.NONE
...@@ -545,13 +548,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -545,13 +548,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
self.log.INFO(f"Corrections for preview: {str(preview)}") self.log.INFO(f"Corrections for preview: {str(preview)}")
def _update_output_schema(self, data): def _update_output_schema(self, data):
"""Updates the schema of dataOutput based on parameter data (a Hash) """Updates the schema of dataOutput based on data we want to send
This should only be called once: when handling output for the first This should only be called once: when handling output for the first
time, we update the schema to match the modified data we'd send. time, we update the schema to match the modified data we'd send.
""" """
# TODO: remove when switching to specifying output schema up-front
self.log.INFO("Updating output schema") self.log.INFO("Updating output schema")
my_schema_update = Schema() my_schema_update = Schema()
data_schema = hashToSchema.HashToSchema(data).schema data_schema = hashToSchema.HashToSchema(data).schema
...@@ -596,12 +599,14 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -596,12 +599,14 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
**self._gpu_runner_init_args, **self._gpu_runner_init_args,
) )
# TODO: put this under lock so dictionary doesn't change shape underneath us
for constant_name, constant_data in self._cached_constants.items(): for constant_name, constant_data in self._cached_constants.items():
self._load_constant_to_gpu(constant_name, constant_data) self._load_constant_to_gpu(constant_name, constant_data)
self._has_updated_shapes = True self._has_updated_shapes = True
def _reset_state_from_processing(self): def _reset_state_from_processing(self):
# TODO: merge with rate updates (buffer status update checking)
if self.get("state") is State.PROCESSING: if self.get("state") is State.PROCESSING:
self.updateState(State.ON) self.updateState(State.ON)
self._state_reset_timer = None self._state_reset_timer = None
...@@ -628,12 +633,12 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -628,12 +633,12 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
self.signalEndOfStream("dataOutput") self.signalEndOfStream("dataOutput")
def getConstant(self, name): def getConstant(self, name):
"""Hacky override of getConstant to actually return None on failure """Wrapper around getConstant to return None on failure
Full function is from CalibrationReceiverBaseDevice Full function is from CalibrationReceiverBaseDevice
""" """
# TODO: remove when revamping constant retrieval
const = super().getConstant(name) const = super().getConstant(name)
if const is not None and len(const.shape) == 1: if const is not None and len(const.shape) == 1:
self.log.WARN( self.log.WARN(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment