From 311de5aece5559629b79838212f7bf35a39cec59 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 2 Feb 2022 16:11:00 +0100
Subject: [PATCH] Replace _schema_cache with unsafe_get using _parameters

---
 src/calng/AgipdCorrection.py | 12 +++----
 src/calng/DsscCorrection.py  |  6 ++--
 src/calng/base_correction.py | 68 ++++++++++++++----------------------
 3 files changed, 36 insertions(+), 50 deletions(-)

diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index 3289fa65..6ad07eea 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -580,10 +580,10 @@ class AgipdCorrection(BaseCorrection):
     @property
     def input_data_shape(self):
         return (
-            self._schema_cache["dataFormat.memoryCells"],
+            self.unsafe_get("dataFormat.memoryCells"),
             2,
-            self._schema_cache["dataFormat.pixelsX"],
-            self._schema_cache["dataFormat.pixelsY"],
+            self.unsafe_get("dataFormat.pixelsX"),
+            self.unsafe_get("dataFormat.pixelsY"),
         )
 
     def __init__(self, config):
@@ -657,7 +657,7 @@ class AgipdCorrection(BaseCorrection):
         self.kernel_runner.load_cell_table(cell_table)
         self.kernel_runner.correct(self._correction_flag_enabled)
         self.kernel_runner.reshape(
-            output_order=self._schema_cache["dataFormat.outputAxisOrder"],
+            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
             out=buffer_array,
         )
         # after reshape, data for dataOutput is now safe in its own buffer
@@ -669,8 +669,8 @@ class AgipdCorrection(BaseCorrection):
                 preview_cell,
                 preview_pulse,
             ) = utils.pick_frame_index(
-                self._schema_cache["preview.selectionMode"],
-                self._schema_cache["preview.index"],
+                self.unsafe_get("preview.selectionMode"),
+                self.unsafe_get("preview.index"),
                 cell_table,
                 pulse_table,
                 warn_func=self.log_status_warn,
diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py
index 3fd0504a..c228648a 100644
--- a/src/calng/DsscCorrection.py
+++ b/src/calng/DsscCorrection.py
@@ -244,7 +244,7 @@ class DsscCorrection(BaseCorrection):
         self.kernel_runner.load_cell_table(cell_table)
         self.kernel_runner.correct(self._correction_flag_enabled)
         self.kernel_runner.reshape(
-            output_order=self._schema_cache["dataFormat.outputAxisOrder"],
+            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
             out=buffer_array,
         )
         if do_generate_preview:
@@ -255,8 +255,8 @@ class DsscCorrection(BaseCorrection):
                 preview_cell,
                 preview_pulse,
             ) = utils.pick_frame_index(
-                self._schema_cache["preview.selectionMode"],
-                self._schema_cache["preview.index"],
+                self.unsafe_get("preview.selectionMode"),
+                self.unsafe_get("preview.index"),
                 cell_table,
                 pulse_table,
                 warn_func=self.log_status_warn,
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 7d93990a..358ac605 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -197,22 +197,6 @@ class BaseCorrection(PythonDevice):
         "preview.trainIdModulo",
         "loadMostRecentConstants",
     }  # subclass can extend this, /must/ put it in schema as managedKeys
-    _schema_cache_fields = {
-        "doAnything",
-        "constantParameters.memoryCells",
-        "dataFormat.filteredFrames",
-        "dataFormat.memoryCells",
-        "dataFormat.pixelsX",
-        "dataFormat.pixelsY",
-        "dataFormat.outputAxisOrder",
-        "dataFormat.overrideInputAxisOrder",
-        "preview.enable",
-        "preview.index",
-        "preview.selectionMode",
-        "preview.trainIdModulo",
-        "processingStateTimeout",
-        "state",
-    }  # subclass should be aware of cache, but does not need to extend
     _image_data_path = "image.data"  # customize for *some* subclasses
     _cell_table_path = "image.cellId"
 
@@ -232,13 +216,13 @@ class BaseCorrection(PythonDevice):
         """Shape of corrected image data sent on dataOutput. Depends on data format
         parameters pixels x / y, and number of cells (optionally after frame filter)."""
         axis_lengths = {
-            "x": self._schema_cache["dataFormat.pixelsX"],
-            "y": self._schema_cache["dataFormat.pixelsY"],
-            "c": self._schema_cache["dataFormat.filteredFrames"],
+            "x": self.unsafe_get("dataFormat.pixelsX"),
+            "y": self.unsafe_get("dataFormat.pixelsY"),
+            "c": self.unsafe_get("dataFormat.filteredFrames"),
         }
         return tuple(
             axis_lengths[axis]
-            for axis in self._schema_cache["dataFormat.outputAxisOrder"]
+            for axis in self.unsafe_get("dataFormat.outputAxisOrder")
         )
 
     def process_data(
@@ -614,9 +598,6 @@ class BaseCorrection(PythonDevice):
         )
 
     def __init__(self, config):
-        self._schema_cache = {
-            k: config[k] for k in self._schema_cache_fields if config.has(k)
-        }
         super().__init__(config)
 
         self.input_data_dtype = np.dtype(config["dataFormat.inputImageDtype"])
@@ -719,9 +700,6 @@ class BaseCorrection(PythonDevice):
             return
 
         update = self._prereconfigure_update_hash
-        for path in update.getPaths():
-            if path in self._schema_cache_fields:
-                self._schema_cache[path] = update.get(path)
 
         if update.has("frameFilter"):
             with self._buffer_lock:
@@ -769,15 +747,6 @@ class BaseCorrection(PythonDevice):
         self.log.ERROR(msg)
         self.updateState(State.ERROR)
 
-    def set(self, *args):
-        """Wrapper around PythonDevice.set to enable caching "hot" schema elements"""
-        # TODO: handle other cases of PythonDevice.set arguments
-        if len(args) == 2 and not isinstance(args[0], Hash):
-            key, value = args
-            if key in self._schema_cache_fields:
-                self._schema_cache[key] = value
-        super().set(*args)
-
     def requestScene(self, params):
         payload = Hash()
         name = params.get("name", default="")
@@ -951,7 +920,7 @@ class BaseCorrection(PythonDevice):
         method provided by subclass."""
 
         # Is device even ready for this?
-        state = self._schema_cache["state"]
+        state = State[self.unsafe_get("state")]
         if state is State.ERROR:
             # in this case, we should have already issued warning
             return
@@ -991,11 +960,11 @@ class BaseCorrection(PythonDevice):
                 self.updateState(State.PROCESSING)
                 self.log_status_info("Processing data")
 
-            correction_cell_num = self._schema_cache["constantParameters.memoryCells"]
+            correction_cell_num = self.unsafe_get("constantParameters.memoryCells")
             cell_table_max = np.max(cell_table)
 
             image_data = data_hash.get(self._image_data_path)
-            if cell_table.size != self._schema_cache["dataFormat.memoryCells"]:
+            if cell_table.size != self.unsafe_get("dataFormat.memoryCells"):
                 self.log_status_info(
                     f"Updating new input shape {image_data.shape}, updating buffers"
                 )
@@ -1004,14 +973,14 @@ class BaseCorrection(PythonDevice):
                     self._update_frame_filter()
 
             # DataAggregator typically tells us the wrong axis order
-            if self._schema_cache["dataFormat.overrideInputAxisOrder"]:
+            if self.unsafe_get("dataFormat.overrideInputAxisOrder"):
                 expected_shape = self.input_data_shape
                 if expected_shape != image_data.shape:
                     image_data.shape = expected_shape
 
             do_generate_preview = (
-                train_id % self._schema_cache["preview.trainIdModulo"] == 0
-                and self._schema_cache["preview.enable"]
+                train_id % self.unsafe_get("preview.trainIdModulo") == 0
+                and self.unsafe_get("preview.enable")
             )
 
             with self._buffer_lock:
@@ -1059,6 +1028,23 @@ class BaseCorrection(PythonDevice):
         self.signalEndOfStream("dataOutput")
 
 
+# forward-compatible unsafe_get proposed by @haufs
+if not hasattr(BaseCorrection, "unsafe_get"):
+    def unsafe_get(self, key):
+        """Look up key in device schema quickly, but without consistency locks
+
+        This is only relevant for use in hot path (input handler).  Circumvents the
+        locking done by PythonDevice.get. Note that PythonDevice.get does handle some
+        special types (by looking at full schema for type information).  In particular,
+        device state enum: `self.get("state")` will return a State whereas
+        `self.unsafe_get("state")` will return a string. Handle with care!"""
+
+        # at least until Karabo 2.14, self._parameters is maintained by PythonDevice
+        return self._parameters.get(key)
+
+    setattr(BaseCorrection, "unsafe_get", unsafe_get)
+
+
 def add_correction_step_schema(schema, managed_keys, field_flag_mapping):
     """Using the fields in the provided mapping, will add nodes to schema
 
-- 
GitLab