From 317a71c93fc0a774a3776e670f6d4dac9ed57be9 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 27 Oct 2021 14:52:04 +0200
Subject: [PATCH] Give schema up front

---
 src/calng/AgipdCorrection.py |  10 +++
 src/calng/base_correction.py | 125 +++++++++++++++++++++++++++--------
 src/calng/calcat_utils.py    |  37 +++++------
 3 files changed, 121 insertions(+), 51 deletions(-)

diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index 7ff10526..d3f42390 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -58,6 +58,16 @@ class AgipdCorrection(BaseCorrection):
             .commit(),
         )
         AgipdCorrection._managed_keys.append("sendGainMap")
+
+        (
+            STRING_ELEMENT(expected)
+            .key("dataOutput.schema.image.gainMap")
+            .displayedName("Gain map (optional)")
+            .assignmentOptional()
+            .defaultValue("")
+            .commit()
+        )
+
         AgipdCalcatFriend.add_schema(expected, AgipdCorrection._managed_keys)
         # this is not automatically done by superclass for complicated class reasons
         add_correction_step_schema(
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 8b575485..c99d82eb 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -10,6 +10,7 @@ from karabo.bound import (
     FLOAT_ELEMENT,
     INPUT_CHANNEL,
     INT32_ELEMENT,
+    INT64_ELEMENT,
     KARABO_CLASSINFO,
     NDARRAY_ELEMENT,
     NODE_ELEMENT,
@@ -90,10 +91,101 @@ class BaseCorrection(PythonDevice):
 
     @staticmethod
     def expectedParameters(expected):
+        output_schema = Schema()
+        (
+            NODE_ELEMENT(output_schema).key("image").commit(),
+            STRING_ELEMENT(output_schema)
+            .key("image.data")
+            .assignmentOptional()
+            .defaultValue("")
+            .commit(),
+            NDARRAY_ELEMENT(output_schema).key("image.length").dtype("UINT32").commit(),
+            NDARRAY_ELEMENT(output_schema).key("image.cellId").dtype("UINT16").commit(),
+            NDARRAY_ELEMENT(output_schema)
+            .key("image.pulseId")
+            .dtype("UINT64")
+            .commit(),
+            NDARRAY_ELEMENT(output_schema).key("image.status").commit(),
+            NDARRAY_ELEMENT(output_schema)
+            .key("image.trainId")
+            .dtype("UINT64")
+            .commit(),
+            VECTOR_STRING_ELEMENT(output_schema)
+            .key("calngShmemPaths")
+            .assignmentOptional()
+            .defaultValue(["image.data"])
+            .commit(),
+            NODE_ELEMENT(output_schema).key("metadata").commit(),
+            STRING_ELEMENT(output_schema)
+            .key("metadata.source")
+            .assignmentOptional()
+            .defaultValue("")
+            .commit(),
+            NODE_ELEMENT(output_schema).key("metadata.timestamp").commit(),
+            INT32_ELEMENT(output_schema)
+            .key("metadata.timestamp.tid")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            NODE_ELEMENT(output_schema).key("header").commit(),
+            INT32_ELEMENT(output_schema)
+            .key("header.minorTrainFormatVersion")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            INT32_ELEMENT(output_schema)
+            .key("header.majorTrainFormatVersion")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            INT32_ELEMENT(output_schema)
+            .key("header.trainId")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            INT64_ELEMENT(output_schema)
+            .key("header.linkId")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            INT64_ELEMENT(output_schema)
+            .key("header.dataId")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            INT64_ELEMENT(output_schema)
+            .key("header.pulseCount")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            NDARRAY_ELEMENT(output_schema).key("header.reserved").commit(),
+            NDARRAY_ELEMENT(output_schema).key("header.magicNumberBegin").commit(),
+            NODE_ELEMENT(output_schema).key("detector").commit(),
+            INT32_ELEMENT(output_schema)
+            .key("detector.trainId")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            NDARRAY_ELEMENT(output_schema).key("detector.data").commit(),
+            NODE_ELEMENT(output_schema).key("trailer").commit(),
+            NDARRAY_ELEMENT(output_schema).key("trailer.checksum").commit(),
+            NDARRAY_ELEMENT(output_schema).key("trailer.magicNumberEnd").commit(),
+            INT32_ELEMENT(output_schema)
+            .key("trailer.status")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+            INT32_ELEMENT(output_schema)
+            .key("trailer.trainId")
+            .assignmentOptional()
+            .defaultValue(0)
+            .commit(),
+        )
+        (OUTPUT_CHANNEL(expected).key("dataOutput").dataSchema(output_schema).commit(),)
+
         (
             INPUT_CHANNEL(expected).key("dataInput").commit(),
             # note: output schema not set, will be updated to match data later
-            OUTPUT_CHANNEL(expected).key("dataOutput").commit(),
             VECTOR_STRING_ELEMENT(expected)
             .key("fastSources")
             .displayedName("Fast data sources")
@@ -227,9 +319,11 @@ class BaseCorrection(PythonDevice):
 
         preview_schema = Schema()
         (
-            NODE_ELEMENT(expected).key("preview").displayedName("Preview").commit(),
             NODE_ELEMENT(preview_schema).key("data").commit(),
             NDARRAY_ELEMENT(preview_schema).key("data.adc").dtype("FLOAT").commit(),
+        )
+        (
+            NODE_ELEMENT(expected).key("preview").displayedName("Preview").commit(),
             OUTPUT_CHANNEL(expected)
             .key("preview.outputRaw")
             .dataSchema(preview_schema)
@@ -370,7 +464,6 @@ class BaseCorrection(PythonDevice):
         self._correction_flag_preview = self._correction_flag_class.NONE
 
         self._shmem_buffer = None
-        self._has_set_output_schema = False
         self._has_updated_shapes = False
         self._processing_time_ema = utils.ExponentialMovingAverage(alpha=0.3)
         self._rate_tracker = utils.WindowRateTracker()
@@ -521,11 +614,6 @@ class BaseCorrection(PythonDevice):
             Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")),
         )
 
-        if not self._has_set_output_schema:
-            self.updateState(State.CHANGING)
-            self._update_output_schema(data)
-            self.updateState(State.PROCESSING)
-
         channel = self.signalSlotable.getOutputChannel("dataOutput")
         channel.write(data, metadata, False)
         channel.update()
@@ -571,26 +659,6 @@ class BaseCorrection(PythonDevice):
         self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}")
         self.log.DEBUG(f"Corrections for preview: {str(preview)}")
 
-    def _update_output_schema(self, data):
-        """Updates the schema of dataOutput based on data we want to send
-
-        This should only be called once: when handling output for the first
-        time, we update the schema to match the modified data we'd send.
-        """
-
-        # TODO: remove when switching to specifying output schema up-front
-        self.log.INFO("Updating output schema")
-        my_schema_update = Schema()
-        data_schema = hashToSchema.HashToSchema(data).schema
-        (
-            OUTPUT_CHANNEL(my_schema_update)
-            .key("dataOutput")
-            .dataSchema(data_schema)
-            .commit()
-        )
-        self.appendSchema(my_schema_update)
-        self._has_set_output_schema = True
-
     def _update_shapes(self):
         """(Re)initialize buffers according to expected data shapes"""
         self.log.INFO("Updating shapes")
@@ -647,7 +715,6 @@ class BaseCorrection(PythonDevice):
                 self.updateState(State.ON)
 
     def handle_eos(self, channel):
-        self._has_set_output_schema = False
         self.updateState(State.ON)
         self.signalEndOfStream("dataOutput")
 
diff --git a/src/calng/calcat_utils.py b/src/calng/calcat_utils.py
index 9088f058..5b5818ca 100644
--- a/src/calng/calcat_utils.py
+++ b/src/calng/calcat_utils.py
@@ -199,10 +199,8 @@ class BaseCalcatFriend:
             .commit(),
             STRING_ELEMENT(schema)
             .key(f"{param_prefix}.constantVersionEventAt")
-            .displayedName("TODO")
-            .description(
-                "TODO"
-            )
+            .displayedName("Event at timestamp (for constant version)")
+            .description("TODO")
             .assignmentOptional()
             .defaultValue("")
             .reconfigurable()
@@ -284,7 +282,9 @@ class BaseCalcatFriend:
         self.cached_constants = {}
 
         if not secrets_fn.is_file():
-            self.device.log_status_warn(f"Missing CalCat secrets file (expected {secrets_fn})")
+            self.device.log_status_warn(
+                f"Missing CalCat secrets file (expected {secrets_fn})"
+            )
         with secrets_fn.open("r") as fd:
             calcat_secrets = json.load(fd)
 
@@ -321,7 +321,6 @@ class BaseCalcatFriend:
         """Helper to update information about found constants on device"""
         self.device.set(f"{self.status_prefix}.{constant.name}.{key}", value)
 
-
     @functools.cached_property
     def detector_id(self):
         resp = Detector.get_by_identifier(self.client, self._get_param("detectorName"))
@@ -338,8 +337,6 @@ class BaseCalcatFriend:
         self._set_param("detectorTypeId", str(res))
         return res
 
-    # TODO: support updating mapping (means del self.pdus and properties using it)
-    # TODO: support snapshot
     @functools.cached_property
     def pdus(self):
         resp = PhysicalDetectorUnit.get_all_by_detector(
@@ -368,7 +365,6 @@ class BaseCalcatFriend:
     @utils.threadsafe_cache
     def calibration_id(self, calibration_name: str):
         resp = Calibration.get_by_name(self.client, calibration_name)
-        # TODO: include calibration name in exception
         _check_resp(resp, CalibrationNotFound)
         return resp["data"]["id"]
 
@@ -395,9 +391,7 @@ class BaseCalcatFriend:
         _check_resp(resp)
         return resp["data"]["id"]
 
-    def get_constant_version(self, constant, snapshot_at=None):
-        # TODO: support snapshot
-        # TODO: support creation time
+    def get_constant_version(self, constant):
         # TODO: catch exceptions, give warnings appropriately
         karabo_da = self._get_param("karaboDa")
         self.device.log_status_info(f"Attempting to find {constant} for {karabo_da}")
@@ -423,11 +417,11 @@ class BaseCalcatFriend:
         )
         self._set_status(constant, "constantId", constant_id)
 
-        resp = CalibrationConstantVersion.get_by_uk(
+        resp = CalibrationConstantVersion.get_closest_by_time(
             self.client,
-            calibration_constant_id=constant_id,
+            calibration_constant_ids=[constant_id],
             physical_detector_unit_id=self._karabo_da_to_id[karabo_da],
-            event_at=None,
+            event_at=self._get_param("constantVersionEventAt"),
             snapshot_at=None,
         )
         _check_resp(resp)
@@ -459,7 +453,6 @@ class BaseCalcatFriend:
         file_path = (
             self.caldb_store / resp["data"]["path_to_file"] / resp["data"]["file_name"]
         )
-        # TODO: handle FileNotFoundError if we are led astray
         with h5py.File(file_path, "r") as fd:
             constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"])
         self.cached_constants[constant] = constant_data
@@ -471,14 +464,12 @@ class BaseCalcatFriend:
         self._set_status(constant, "found", True)
         return constant_data
 
-    def get_constant_version_and_call_me_back(
-        self, constant, callback, snapshot_at=None
-    ):
+    def get_constant_version_and_call_me_back(self, constant, callback):
         """Runs get_constant_version in thread, will call callback on completion"""
         # TODO: do we want to use asyncio / "modern" async?
         # TODO: consider moving out of this class, closer to correction device
         def aux():
-            data = self.get_constant_version(constant, snapshot_at)
+            data = self.get_constant_version(constant)
             callback(constant, data)
 
         thread = threading.Thread(target=aux)
@@ -581,11 +572,13 @@ class AgipdCalcatFriend(BaseCalcatFriend):
 
     def dark_condition(self):
         res = OperatingConditions()
+        res["Memory cells"] = self._get_param("memoryCells")
+        res["Sensor Bias Voltage"] = self._get_param("biasVoltage")
         res["Pixels X"] = self._get_param("pixelsX")
         res["Pixels Y"] = self._get_param("pixelsY")
-        res["Memory cells"] = self._get_param("memoryCells")
         res["Acquisition rate"] = self._get_param("acquisitionRate")
-        res["Sensor Bias Voltage"] = self._get_param("biasVoltage")
+        # TODO: make configurable whether or not to include gain setting?
+        res["Gain Setting"] = self._get_param("gainSetting")
         return res
 
     def illuminated_condition(self):
-- 
GitLab