From c7e948a9611dbbe9fa3452d57f0d8af1d2b58055 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 13 Sep 2023 14:59:07 +0200
Subject: [PATCH] Overhaul context warning lamps, add to arbiter kernel

---
 docs/extensions.md                            |  13 ++
 src/calng/FrameSelectionArbiter.py            |  44 +++++-
 src/calng/arbiter_kernels/base_kernel.py      |  12 ++
 src/calng/arbiter_kernels/boolean_ops.py      |  18 ++-
 src/calng/arbiter_kernels/ppu_arbiter.py      |  21 ++-
 src/calng/arbiter_kernels/reduce_threshold.py |  24 +--
 src/calng/base_correction.py                  |  42 ++----
 src/calng/utils.py                            | 141 ++++++++++--------
 8 files changed, 186 insertions(+), 129 deletions(-)

diff --git a/docs/extensions.md b/docs/extensions.md
index 36d635ef..4227348a 100644
--- a/docs/extensions.md
+++ b/docs/extensions.md
@@ -72,6 +72,19 @@ This class for now containts stubs for the methods a kernel will need to provide
 The configuration keys added by `extend_device_schema` will go in the arbiter schema under `frameSelection.kernels.[node name]` where the node name is automatically based on the class name.
 Note that this node name is automatically part of the `prefix` passed to `extend_device_schema`.
 
+Most parameters to most arbiter kernels should be reconfigurable - if not, changing them means restarting the kernel arbiter entirely.
+For now, changing a parameter for the kernel will cause the kernel to be reinstantiated; in the future, a `reconfigure` option similar to that of correction addons may be added for kernels which are reconfigured often or are expensive to reinstantiate.
+
+From the base class, an arbiter kernel gets the following properties / methods:
+
+- `_config`: points to the hash with the configuration passed to init
+- `_device`: will point to the host device - but only *after* `__init__`; should generally not be used directly
+- `geometry`: property containing the current detector geometry (obtained from the host device; requires a correctly configured `FrameSelectionArbiter`)
+- `warning_context`: helper function returning a context manager to allow emitting custom warnings via the host device, setting the warning state of `frameSelection.kernelState`
+    - Note that uncaught exceptions in `consider` will also trigger this with `KernelWarning.PROCESSING`
+    - Some common warning types are defined in the `base_kernel.KernelWarning` enum type; feel free to define additional warnings and pass them to the context manager
+    - See example kernels for examples of how to use this
+
 ## Existing arbiter kernels
 The following arbiter kernels are distributed with `calng` as examples:
 
diff --git a/src/calng/FrameSelectionArbiter.py b/src/calng/FrameSelectionArbiter.py
index 3bfcbf4d..34177312 100644
--- a/src/calng/FrameSelectionArbiter.py
+++ b/src/calng/FrameSelectionArbiter.py
@@ -18,6 +18,7 @@ from TrainMatcher import TrainMatcher
 
 from ._version import version as deviceVersion
 from . import utils, geom_utils
+from .arbiter_kernels.base_kernel import KernelWarning
 
 my_schema = Schema()
 (
@@ -71,15 +72,23 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
             .reconfigurable()
             .commit(),
 
-            OUTPUT_CHANNEL(expected)
-            .key("output")
-            .dataSchema(my_schema)
+            STRING_ELEMENT(expected)
+            .key("frameSelection.kernelState")
+            .setSpecialDisplayType("State")
+            .displayedName("Kernel state")
+            .readOnly()
+            .initialValue("OFF")
             .commit(),
 
             NODE_ELEMENT(expected)
             .key("frameSelection.kernels")
             .displayedName("Kernels")
             .commit(),
+
+            OUTPUT_CHANNEL(expected)
+            .key("output")
+            .dataSchema(my_schema)
+            .commit(),
         )
         for kernel_class in kernel_choice.values():
             kernel_class.extend_device_schema(
@@ -92,11 +101,18 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
 
     def initialization(self):
         super().initialization()
-        self._kernel_class = kernel_choice[self.get("frameSelection.kernelChoice")]
-        self._kernel = self._kernel_class(
-            self._parameters[f"frameSelection.kernels.{self._kernel_class.__name__}"]
+        self.warning_context = utils.WarningContextSystem(
+            self,
+            on_success={"frameSelection.kernelState": "ON"},
         )
-        self._kernel._device = self
+        with self.warning_context("frameSelection.kernelState", KernelWarning.INIT):
+            self._kernel_class = kernel_choice[self.get("frameSelection.kernelChoice")]
+            self._kernel = self._kernel_class(
+                self._parameters[
+                    f"frameSelection.kernels.{self._kernel_class.__name__}"
+                ]
+            )
+            self._kernel._device = self
 
         self._geometry = None
         if self.get("geometryDevice"):
@@ -110,13 +126,20 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
 
     def on_matched_data(self, train_id, sources):
         # TODO: robust frame deduction
+        num_frames = None
         for (data, _) in sources.values():
             if not data.has("image.cellId"):
                 continue
             num_frames = data.get("image.cellId").size
             break
+        else:
+            # TODO: also use warning context
+            self.log.WARN("Unable to figure out number of frames")
 
-        decision = self._kernel.consider(train_id, sources, num_frames)
+        with self.warning_context(
+            "frameSelection.kernelState", KernelWarning.PROCESSING
+        ):
+            decision = self._kernel.consider(train_id, sources, num_frames)
         if isinstance(decision, Hash):
             assert decision.has("data.dataFramePattern")
             decision["data.dataFramePattern"] = list(
@@ -137,6 +160,11 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
         self.output.update()
         self.rate_out.update()
 
+    def start(self):
+        self.set("frameSelection.kernelState", "OFF")
+        self.warning_context.reset()
+        super().start()
+
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
         self._reinstantiate_arbiter_kernel = False
diff --git a/src/calng/arbiter_kernels/base_kernel.py b/src/calng/arbiter_kernels/base_kernel.py
index f9a43614..6d36a956 100644
--- a/src/calng/arbiter_kernels/base_kernel.py
+++ b/src/calng/arbiter_kernels/base_kernel.py
@@ -1,6 +1,15 @@
+import enum
+
 import numpy as np
 
 
+class KernelWarning(enum.Enum):
+    INIT = "init"
+    MISSINGSOURCE = "missingsource"
+    MISSINGKEY = "missingkey"
+    PROCESSING = "processing"
+
+
 class BaseArbiterKernel:
     def __init__(self, config):
         """No need for prefix - hosting device should pass us the relevant subnode"""
@@ -11,6 +20,9 @@ class BaseArbiterKernel:
     def geometry(self):
         return self._device._geometry
 
+    def warning_context(self, warn_type):
+        return self._device.warning_context("frameSelection.kernelState", warn_type)
+
     @staticmethod
     def extend_device_schema(schema, prefix):
         """Should add configurability to the arbiter (matcher) the kernel will be
diff --git a/src/calng/arbiter_kernels/boolean_ops.py b/src/calng/arbiter_kernels/boolean_ops.py
index 91175dcb..135516f6 100644
--- a/src/calng/arbiter_kernels/boolean_ops.py
+++ b/src/calng/arbiter_kernels/boolean_ops.py
@@ -3,7 +3,7 @@ from karabo.bound import (
     STRING_ELEMENT,
 )
 import numpy as np
-from .base_kernel import BaseArbiterKernel
+from .base_kernel import BaseArbiterKernel, KernelWarning
 
 
 class BooleanCombination(BaseArbiterKernel):
@@ -44,7 +44,15 @@ class BooleanCombination(BaseArbiterKernel):
 
     def consider(self, train_id, sources, num_frames):
         # pretty sure this is special case of reduce and threshold in some algebra
-        return self._operator(
-            [data[self._key] for (data, _) in sources.values() if data.has(self._key)],
-            axis=0,
-        ).astype(np.uint8, copy=False)
+        with self.warning_context(KernelWarning.MISSINGKEY) as warn:
+            sources_with_key = [
+                data[self._key] for (data, _) in sources.values() if data.has(self._key)
+            ]
+            if not sources_with_key:
+                warn(f"No sources had '{self._key}'")
+            else:
+                return self._operator(
+                    sources_with_key,
+                    axis=0,
+                ).astype(np.uint8, copy=False)
+        return np.ones(num_frames, dtype=np.uint8)
diff --git a/src/calng/arbiter_kernels/ppu_arbiter.py b/src/calng/arbiter_kernels/ppu_arbiter.py
index ba753eed..04fa919d 100644
--- a/src/calng/arbiter_kernels/ppu_arbiter.py
+++ b/src/calng/arbiter_kernels/ppu_arbiter.py
@@ -5,7 +5,7 @@ from karabo.bound import (
     STRING_ELEMENT,
 )
 
-from .base_kernel import BaseArbiterKernel
+from .base_kernel import BaseArbiterKernel, KernelWarning
 
 
 class PpuKernel(BaseArbiterKernel):
@@ -14,6 +14,8 @@ class PpuKernel(BaseArbiterKernel):
     def __init__(self, config):
         super().__init__(config)
         self._ppu_device_id = config.get("ppuDevice")
+        self._target = 0
+        self._num_trains = np.iinfo(np.uint64).max
 
     @staticmethod
     def extend_device_schema(schema, prefix):
@@ -32,27 +34,22 @@ class PpuKernel(BaseArbiterKernel):
             )
             .assignmentOptional()
             .defaultValue("")
+            .reconfigurable()
             .commit(),
         )
 
     def consider(self, train_id, sources, num_frames):
-        # TODO: same kind of warnings that PickyBoi would emit
-        try:
+        with self.warning_context(KernelWarning.MISSINGSOURCE):
             ppu_data = next(
                 data
                 for source, (data, _) in sources.items()
                 if source == self._ppu_device_id
             )
-            target = ppu_data["trainTrigger.sequenceStart.value"]
-            num_trains = ppu_data["trainTrigger.numberOfTrains.value"]
-        except (StopIteration, RuntimeError) as ex:
-            if isinstance(ex, StopIteration):
-                print("No find PPU device in source :(")
-            else:
-                print("Didn't get expected keys from PPU device, check sources config")
-            return np.ones(num_frames, dtype=np.uint8)
+            with self.warning_context(KernelWarning.MISSINGKEY):
+                self._target = ppu_data["trainTrigger.sequenceStart.value"]
+                self._num_trains = ppu_data["trainTrigger.numberOfTrains.value"]
 
-        if train_id in range(target, target + num_trains):
+        if train_id in range(self._target, self._target + self._num_trains):
             return np.ones(num_frames, dtype=np.uint8)
         else:
             return np.zeros(num_frames, dtype=np.uint8)
diff --git a/src/calng/arbiter_kernels/reduce_threshold.py b/src/calng/arbiter_kernels/reduce_threshold.py
index a19add9e..8c18e18b 100644
--- a/src/calng/arbiter_kernels/reduce_threshold.py
+++ b/src/calng/arbiter_kernels/reduce_threshold.py
@@ -6,7 +6,7 @@ from karabo.bound import (
     STRING_ELEMENT,
 )
 import numpy as np
-from .base_kernel import BaseArbiterKernel
+from .base_kernel import BaseArbiterKernel, KernelWarning
 
 
 class ReduceAndThreshold(BaseArbiterKernel):
@@ -81,14 +81,14 @@ class ReduceAndThreshold(BaseArbiterKernel):
         )
 
     def consider(self, train_id, sources, num_frames):
-        return self._comparator(
-            self._reduction(
-                [
-                    data[self._key]
-                    for (data, _) in sources.values()
-                    if data.has(self._key)
-                ],
-                axis=0,
-            ),
-            self._threshold,
-        ).astype(np.uint8, copy=False)
+        with self.warning_context(KernelWarning.MISSINGKEY) as warn:
+            sources_with_key = [
+                data[self._key] for (data, _) in sources.values() if data.has(self._key)
+            ]
+            if not sources_with_key:
+                warn(f"No sources had '{self._key}'")
+            else:
+                return self._comparator(
+                    self._reduction(sources_with_key, axis=0), self._threshold
+                ).astype(np.uint8, copy=False)
+        return np.ones(num_frames, dtype=np.uint8)
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index d37a7412..dfedb57d 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -1,5 +1,4 @@
 import concurrent.futures
-import contextlib
 import enum
 import functools
 import gc
@@ -622,13 +621,14 @@ class BaseCorrection(PythonDevice):
 
         # register slots
         self.registerSlot(self.slotReceiveGeometry)
+
         def constant_override_fun(friend_fun, constant, preserve_fields):
             def aux():
                 self.flush_constants(
                     constants={constant}, preserve_fields=preserve_fields
                 )
                 with self.warning_context(
-                    f"foundConstants.{constant.name}.state", on_success="ON"
+                    f"foundConstants.{constant.name}.state"
                 ) as warn:
                     try:
                         constant_data = getattr(self.calcat_friend, friend_fun)(
@@ -679,33 +679,15 @@ class BaseCorrection(PythonDevice):
 
         self.registerInitialFunction(self._initialization)
 
-    @contextlib.contextmanager
-    def warning_context(
-        self,
-        schema_key,
-        warn_type=None,
-        on_success="NORMAL",
-        on_error="ERROR",
-        reraise=True,
-        only_print_once=False,
-    ):
-        tracker = self._warning_trackers[schema_key]
-        warn_fun = tracker.new_context(warn_type, only_print_once)
-        try:
-            yield warn_fun
-        except Exception as e:
-            warn_fun(f"Exception happened for {schema_key}, {warn_type}: {e}")
-            if reraise:
-                raise e
-        finally:
-            tracker.update_state(on_success=on_success, on_error=on_error)
-
     def _initialization(self):
         self.updateState(State.INIT)
-        self._warning_trackers = {
-            key: utils.ContextWarningLamp(self, key)
-            for key in self.getFullSchema().getDefaultValue("warningLamps")
-        }
+        self.warning_context = utils.WarningContextSystem(
+            self,
+            on_success={
+                f"foundConstants.{constant.name}.state": "ON"
+                for constant in self._constant_enum_class
+            },
+        )
         self._preview_friend = preview_utils.PreviewFriend(
             self,
             output_channels=self._preview_outputs,
@@ -713,9 +695,6 @@ class BaseCorrection(PythonDevice):
         self["availableScenes"] = self["availableScenes"] + [
             f"preview:{channel}" for channel in self._preview_outputs
         ]
-        for constant in self._constant_enum_class:
-            key = f"foundConstants.{constant.name}.state"
-            self._warning_trackers[key] = utils.ContextWarningLamp(self, key)
 
         self._geometry = None
         if self.get("geometryDevice"):
@@ -1146,7 +1125,6 @@ class BaseCorrection(PythonDevice):
             with self.warning_context(
                 "inputDataState",
                 WarningLampType.EMPTY_HASH,
-                only_print_once=True,
             ) as warn:
                 self._shmem_receiver.dereference_shmem_handles(data_hash)
                 try:
@@ -1199,7 +1177,7 @@ class BaseCorrection(PythonDevice):
                         "connection to timeserver."
                     )
             with self.warning_context(
-                "inputDataState", WarningLampType.TRAIN_ID, only_print_once=True
+                "inputDataState", WarningLampType.TRAIN_ID
             ) as warn:
                 if train_id > (
                     my_train_id
diff --git a/src/calng/utils.py b/src/calng/utils.py
index cd534b75..3eba08da 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -1,5 +1,6 @@
-import enum
+import contextlib
 import collections
+import enum
 import functools
 import inspect
 import itertools
@@ -10,67 +11,86 @@ from timeit import default_timer
 import numpy as np
 
 
-class ContextWarningLamp:
-    """Warning model: all warnings are generated within contexts. Each context handles
-    one type of warning (one lamp aggregates multiple types). If a warning is issued
-    within a context, the corresponding warning type is set. If not, then the
-    corresponding warning type is unset."""
-
-    def __init__(self, device, schema_key):
-        self._device = device
-        self._schema_key = schema_key
-        self._active_warnings = {}
-        # note: the following two attributes are only sets for future generalization
-        # warnings issued during "current" context
-        self._new_warnings = set()
-        # warning types used in current context (to clear if no warnings issued)
-        self._new_tested = set()
-
-    def new_context(self, warn_type=None, only_print_once=False):
-        """Use with "with", will give warning function with appropriate parameters
-
-        warn_type: The warning type which to set or unset based on outcome within
-        context (must be hashable, should probably be some enum member)
-        only_print_once: By default, the exact same string will not be printed twice for
-        any given warn_type. Some errors may, however, generate slightly differing
-        strings each time. With only_print_once, only the first warning for a given
-        warn_type is printed as long as the warning remains active.
-        """
-        # discard instead of clear in case of nesting (not yet fully supported though)
-        self._new_warnings.discard(warn_type)
-        self._new_tested.add(warn_type)
-        return functools.partial(
-            self.warn, warn_type=warn_type, only_print_once=only_print_once
-        )
+class WarningContextSystem:
+    """Helper object to trigger warning lamps based on different warning types in
+    contexts. Intended use: something that is checked multiple times and is good or bad
+    depending on last check (ex. per train).
+
+    Warning model: a warning is generated within a context (see __call__). Each warning
+    context handles one schema key and one warning type for that key. One lamp
+    aggregates multiple warning types; for each lamp, a set of currently active warnings
+    decides the ultimate fate of the lamp.
+
+    If a warning is issued within a context, the corresponding warning type is added to
+    the set. If not, then the corresponding warning type is removed from the set."""
 
-    def warn(self, message, warn_type, only_print_once):
-        # avoid duplicating current warning message for this type
-        if (warn_type not in self._active_warnings) or (
-            not only_print_once and message != self._active_warnings[warn_type]
-        ):
-            if warn_type is None:
-                self._device.log_status_warn(message)
+    def __init__(self, device, on_success=None, on_error=None):
+        self.device = device
+        # key -> set of warning types
+        self.active_warnings = collections.defaultdict(lambda: set())
+        # (key, warn type) -> True or False depending on last context
+        self.triggered_warnings = collections.defaultdict(lambda: False)
+
+        # key -> state string
+        self.on_success = collections.defaultdict(lambda: "ON")
+        if on_success is not None:
+            self.on_success.update(on_success)
+        self.on_error = collections.defaultdict(lambda: "ERROR")
+        if on_error is not None:
+            self.on_error.update(on_error)
+        # TODO: support mapping different warning types to different Karabo states
+        # (for colors per lamp depending on state signifier on active set)
+
+    @contextlib.contextmanager
+    def __call__(self, key, warn_type):
+        """Will return a context manager for the specified schema key and warning
+        type. It is your responsibility to ensure that the key is valid. Calling the
+        function yielded by the context manager will issue a warning and add the
+        warning type to the set of active warnings for the key."""
+        warn_fun = functools.partial(self.issue_warning, key, warn_type)
+        try:
+            yield warn_fun
+        except Exception as ex:
+            warn_fun(f"Unexpected exception: {ex}")
+        finally:
+            # warn_fun may have set trigger; we update active warnings out here
+            active_set = self.active_warnings[key]
+            if self.triggered_warnings[(key, warn_type)]:
+                if warn_type not in active_set:
+                    self.device.set(key, self.on_error[key])
+                    active_set.add(warn_type)
+                self.triggered_warnings[(key, warn_type)] = False
             else:
-                self._device.log_status_warn(f"{warn_type.name}: {message}")
-        self._new_warnings.add(warn_type)
-        self._active_warnings[warn_type] = message
-
-    def update_state(self, on_success="NORMAL", on_error="ERROR"):
-        for warn_type_lifted in self._active_warnings.keys() & (
-            self._new_tested - self._new_warnings
-        ):
-            del self._active_warnings[warn_type_lifted]
-            # None is used for "simpler" lamps (like for constants)
-            if warn_type_lifted is not None:
-                self._device.log_status_info(f"Lifted warning: {warn_type_lifted.name}")
-        current_state = self._device.unsafe_get(self._schema_key)
-        if self._active_warnings and current_state != on_error:
-            self._device.set(self._schema_key, on_error)
-        elif not self._active_warnings and current_state != on_success:
-            self._device.set(self._schema_key, on_success)
-        # TODO: maybe handle nesting / multi-type alarm context
-        self._new_warnings.clear()
-        self._new_tested.clear()
+                # warning not triggered, so remove this type from the active set
+                if warn_type in active_set:
+                    message = (
+                        f"{key}: warning lifted!"
+                        if warn_type is None
+                        else f"{key} > {warn_type.name}: lifted!"
+                    )
+                    self.device.log.INFO(message)
+                    self.device.set("status", message)
+                    active_set.discard(warn_type)
+                if not active_set:
+                    self.device.set(key, self.on_success[key])
+
+    def reset(self):
+        self.active_warnings.clear()
+        self.triggered_warnings.clear()
+
+    def issue_warning(self, key, warn_type, text):
+        self.triggered_warnings[(key, warn_type)] = True
+        # already active, so don't print
+        if warn_type in self.active_warnings[key]:
+            return
+        message = (
+            f"{key}: {text}"
+            if warn_type is None
+            else f"{key} > {warn_type.name}: {text}"
+        )
+        # TODO: make log + "status" behavior customizable
+        self.device.log.WARN(message)
+        self.device.set("status", message)
 
 
 class PreviewIndexSelectionMode(enum.Enum):
@@ -575,6 +595,7 @@ def grid_to_cover_shape_with_blocks(full_shape, block_shape):
 def add_unsafe_get(device_class):
     # forward-compatible unsafe_get proposed by @haufs
     if not hasattr(device_class, "unsafe_get"):
+
         def unsafe_get(self, key):
             """Look up key in device schema quickly, but without consistency locks
 
-- 
GitLab