From 14e7737db48db3bcbd914ce11005fba55a4ece6a Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 22 Sep 2021 15:54:30 +0200
Subject: [PATCH] Cleanup, update TODOs, docstrings

---
 src/calng/AgipdCorrection.py | 12 +++++-----
 src/calng/ModuleStacker.py   |  3 ---
 src/calng/base_correction.py | 43 ++++++++++++++++++++----------------
 3 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index 599e0f28..dec34dd3 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -171,13 +171,11 @@ class AgipdCorrection(BaseCorrection):
         self._gpu_runner_init_args = {"gain_mode": self.gain_mode}
 
         super().__init__(config)
-        output_axis_order = config.get("dataFormat.outputAxisOrder")
-        if output_axis_order == "pixels-fast":
-            self._output_transpose = None
-        elif output_axis_order == "memorycells-fast":
-            self._output_transpose = (2, 1, 0)
-        else:
-            self._output_transpose = None
+        self._output_transpose = {
+            "pixels-fast": None,
+            "memorycells-fast": (2, 1, 0),
+            "no-reshape": None,
+        }[config.get("dataFormat.outputAxisOrder")]
         self._update_shapes()
         if config.get("corrections.overrideMdAdditionalOffset"):
             self._override_md_additional_offset = config.get(
diff --git a/src/calng/ModuleStacker.py b/src/calng/ModuleStacker.py
index 63842cb5..967e080d 100644
--- a/src/calng/ModuleStacker.py
+++ b/src/calng/ModuleStacker.py
@@ -133,9 +133,6 @@ class ModuleStacker(TrainMatcher.TrainMatcher):
         out_hash[self.path_to_stack] = stacked_data
         out_hash["sources"] = stacked_sources
         out_hash["modulesPresent"] = stacked_present
-        if not out_hash.has("image.passport"):
-            out_hash.set("image.passport", [])
-        out_hash["image.passport"].append(self.getInstanceId())
         channel = self.signalSlotable.getOutputChannel("output")
         channel.write(out_hash, ChannelMetaData(self.getInstanceId(), timestamp))
         channel.update()
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 11746d85..15ed974d 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -160,7 +160,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .description(
                 "Axes of main data output can be reordered after correction. Choose "
                 "between 'pixels-fast' (memory_cell, x, y), 'memorycells-fast' "
-                "(x, y, memory_cell), and 'no-reshape' (memory_cell, y, x)"
+                "(x, y, memory_cell), and 'no-reshape'"
             )
             .options("pixels-fast,memorycells-fast,no-reshape")
             .assignmentOptional()
@@ -363,12 +363,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .key("corrections.disableAll")
             .displayedName("Disable corrections for dataOutput")
             .description(
-                "Toggle whether not correction(s) are applied to image data. If "
-                "false, this device still reshapes data to output shape, applies the "
-                "pulse filter, and casts to output dtype. Useful if constants are "
-                "missing / bad, or if data is sent to application doing its own "
-                "correction.  Preview is still corrected based on selection of "
-                "corrections independently of this."
+                "Toggle for disabling all corrections for dataOutput at once. This "
+                "overrides the individual flags under corrections.enabled (does not "
+                "affect corrected preview). When corrections are disabled, the device "
+                "still reshapes data to output shape, applies the pulse filter, and "
+                "casts to output dtype. Disabling all corrections can be useful if "
+                "constants are missing / bad, or if user software does not want "
+                "corrected data."
             )
             .assignmentOptional()
             .defaultValue(False)
@@ -450,13 +451,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             self._has_updated_shapes = False
 
     def postReconfigure(self):
-        self.log.INFO("postReconfigure")
         if not self._has_updated_shapes:
             self._update_shapes()
         # TODO: only call this if they are changed (is cheap, though)
         self._update_correction_flags()
 
     def set(self, *args):
+        """Wrapper around PythonDevice.set to enable caching "hot" schema elements"""
         if len(args) == 2:
             key, value = args
             if key in self._schema_cache_slots:
@@ -464,8 +465,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         super().set(*args)
 
     def requestConstant(self, name, mostRecent=False, tryRemote=True):
-        """constantLoaded hook would have gotten called without naming constant, so here
-        we go. Ugly hooking it."""
+        """Wrapper around method from CalibrationReceiverBaseDevice
+
+        The superclass provides the constantLoaded hook, but it gets called without
+        arguments, losing the name of the freshly loaded constant. To handle individual
+        constants correctly, we set up our own hook instead.
+        """
+        # TODO: remove when revamping constant retrieval
         if name in self._cached_constants:
             del self._cached_constants[name]
         super().requestConstant(name, mostRecent, tryRemote)
@@ -480,6 +486,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
 
     def flush_constants(self):
         """Override from CalibrationReceiverBaseDevice to also flush GPU buffers"""
+        # TODO: update when revamping constant retrieval
         super().flush_constants()
         for correction_step, _ in self._correction_slot_names:
             self.set(f"corrections.available.{correction_step}", False)
@@ -492,10 +499,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")),
         )
 
-        if "image.passport" not in data:
-            data["image.passport"] = []
-        data["image.passport"].append(self.getInstanceId())
-
         if not self._has_set_output_schema:
             self.updateState(State.CHANGING)
             self._update_output_schema(data)
@@ -508,7 +511,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     def _write_combiner_preview(self, data_raw, data_corrected, train_id, source):
         # TODO: take into account updated pulse table after pulse filter
         preview_hash = Hash()
-        preview_hash.set("image.passport", [self.getInstanceId()])
         preview_hash.set("image.trainId", train_id)
         preview_hash.set("image.pulseId", self._schema_cache["preview.pulse"])
 
@@ -525,6 +527,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             channel.update()
 
     def _update_correction_flags(self):
+        """Based on constants loaded and settings, update bit mask flags for kernel"""
         available = self._correction_flag_class.NONE
         enabled = self._correction_flag_class.NONE
         preview = self._correction_flag_class.NONE
@@ -545,13 +548,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         self.log.INFO(f"Corrections for preview: {str(preview)}")
 
     def _update_output_schema(self, data):
-        """Updates the schema of dataOutput based on parameter data (a Hash)
+        """Updates the schema of dataOutput based on data we want to send
 
         This should only be called once: when handling output for the first
         time, we update the schema to match the modified data we'd send.
-
         """
 
+        # TODO: remove when switching to specifying output schema up-front
         self.log.INFO("Updating output schema")
         my_schema_update = Schema()
         data_schema = hashToSchema.HashToSchema(data).schema
@@ -596,12 +599,14 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             **self._gpu_runner_init_args,
         )
 
+        # TODO: put this under lock so dictionary doesn't change shape underneath us
         for constant_name, constant_data in self._cached_constants.items():
             self._load_constant_to_gpu(constant_name, constant_data)
 
         self._has_updated_shapes = True
 
     def _reset_state_from_processing(self):
+        # TODO: merge with rate updates (buffer status update checking)
         if self.get("state") is State.PROCESSING:
             self.updateState(State.ON)
             self._state_reset_timer = None
@@ -628,12 +633,12 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         self.signalEndOfStream("dataOutput")
 
     def getConstant(self, name):
-        """Hacky override of getConstant to actually return None on failure
+        """Wrapper around getConstant to return None on failure
 
         Full function is from CalibrationReceiverBaseDevice
-
         """
 
+        # TODO: remove when revamping constant retrieval
         const = super().getConstant(name)
         if const is not None and len(const.shape) == 1:
             self.log.WARN(
-- 
GitLab