From 455562397f13bb414ef1a8c508e5b1a8b95e8961 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Mon, 24 Jul 2023 17:07:20 +0200
Subject: [PATCH] Allow multiple consecutive trains

---
 src/calng/PickyBoi.py | 167 +++++++++++++++++++++++++++++++-----------
 1 file changed, 123 insertions(+), 44 deletions(-)

diff --git a/src/calng/PickyBoi.py b/src/calng/PickyBoi.py
index 19777a7b..0d484425 100644
--- a/src/calng/PickyBoi.py
+++ b/src/calng/PickyBoi.py
@@ -18,12 +18,14 @@ from karabo.bound import (
     UINT64_ELEMENT,
     VECTOR_BOOL_ELEMENT,
     VECTOR_STRING_ELEMENT,
-    Hash,
+    ChannelMetaData,
+    Epochstamp,
     ImageData,
     PythonDevice,
     Schema,
     State,
     Timestamp,
+    Trainstamp,
     Types,
 )
 
@@ -37,9 +39,22 @@ class PickyBoi(PythonDevice):
         (
             OVERWRITE_ELEMENT(expected)
             .key("state")
+            .setNewOptions(
+                State.INIT,
+                State.MONITORING,  # waiting for picked train(s)
+                State.ACQUIRING,  # currently getting picked trains
+                State.PASSIVE,  # after getting picked trains with no new pick set
+                State.ERROR,  # missed the train
+            )
             .setNewDefaultValue(State.INIT)
             .commit(),
 
+            STRING_ELEMENT(expected)
+            .key("ppuFollowingState")
+            .readOnly()
+            .initialValue("OFF")
+            .commit(),
+
             INPUT_CHANNEL(expected)
             .key("input")
             .commit(),
@@ -52,7 +67,7 @@ class PickyBoi(PythonDevice):
             .commit(),
 
             UINT64_ELEMENT(expected)
-            .key("numTrainToCatch")
+            .key("numberOfTrainsToCatch")
             .assignmentOptional()
             .defaultValue(1)
             .reconfigurable()
@@ -65,7 +80,8 @@ class PickyBoi(PythonDevice):
             .commit(),
 
             SLOT_ELEMENT(expected)
-            .key("watchPpu")
+            .key("toggleFollowPpu")
+            .allowedStates([State.MONITORING, State.PASSIVE, State.ERROR])
             .commit(),
 
             SLOT_ELEMENT(expected)
@@ -74,43 +90,70 @@ class PickyBoi(PythonDevice):
 
             SLOT_ELEMENT(expected)
             .key("captureNextTrain")
+            .allowedStates([State.MONITORING, State.PASSIVE, State.ERROR])
             .commit(),
         )
 
     def __init__(self, config):
         super().__init__(config)
-        self.registerInitialFunction(self._initialization)
-        self._schema_is_set = False
         self._previous_tid = 0
+        self._trains_to_get = set()  # will hold range of trains
+        self._old_target_tid = None  # just used for warnings about missing trains
+
+        # manual override: forward starting from next train, whatever it is
         self._just_capture_next = False
-        self.KARABO_SLOT(self.resetCapturedSchema)
         self.KARABO_SLOT(self.captureNextTrain)
 
+        # output schema set from first data received; can be reset
+        self._schema_is_set = False
+        self.KARABO_SLOT(self.resetCapturedSchema)
+
+        self._following_ppu = None  # will hold name of PPU device when following
+        self.KARABO_SLOT(self.toggleFollowPpu)
+
+        self.registerInitialFunction(self._initialization)
+
+    def _initialization(self):
+        self.KARABO_ON_DATA("input", self.input_handler)
+        # if ppuDevice is set, will try to follow immediately
+        if self.get("ppuDevice"):
+            self.toggleFollowPpu()
+
     def resetCapturedSchema(self):
         self._schema_is_set = False
 
     def captureNextTrain(self):
         self._just_capture_next = True
+        if not self.get("state") is State.MONITORING:
+            self.updateState(State.MONITORING)
 
-    def watchPpu(self):
-        ppu_device_id = self.get("ppuDevice")
+    def toggleFollowPpu(self):
         client = self.remote()
-        conf = client.getConfiguration(ppu_device_id)
-        self.handlePpuDeviceConfiguration(conf)
-        client.registerDeviceMonitor(ppu_device_id, self.handlePpuDeviceConfiguration)
+        if self._following_ppu is None:
+            ppu_device_id = self.get("ppuDevice")
+            conf = client.getConfiguration(ppu_device_id)
+            self.handlePpuDeviceConfiguration(conf)
+            client.registerDeviceMonitor(
+                ppu_device_id, self.handlePpuDeviceConfiguration
+            )
+            self._following_ppu = ppu_device_id
+            self.set("ppuFollowingState", "ON")
+        else:
+            client.unregisterDeviceMonitor(self._following_ppu)
+            self.set("ppuFollowingState", "OFF")
+            self._following_ppu = None
 
     def handlePpuDeviceConfiguration(self, conf):
-        ...
-        # TODO
-        self._set_new_target_tid(new_target_tid)
-        self.set("nextTrainToCatch", new_target_tid)
-
-    def _initialization(self):
-        self.KARABO_ON_DATA("input", self.input_handler)
+        if conf.has("numberOfTrains"):
+            self.set("numberOfTrainsToCatch", conf["numberOfTrains"])
+        if conf.has("sequenceStart"):
+            self._old_target_tid = self.get("nextTrainToCatch")
+            self.set("nextTrainToCatch", conf["sequenceStart"])
+        if conf.has("numberOfTrains") or conf.has("sequenceStart"):
+            self._update_target()
 
     def input_handler(self, data, meta):
         if not self._schema_is_set:
-            self.updateState(State.ROTATING)
             schema_update = Schema()
             (
                 OUTPUT_CHANNEL(schema_update)
@@ -126,51 +169,87 @@ class PickyBoi(PythonDevice):
         current_tid = Timestamp.fromHashAttributes(
             meta.getAttributes("timestamp")
         ).getTrainId()
+        # TODO: check against timeserver to handle wild future trains
         if self._just_capture_next:
             self.set("nextTrainToCatch", current_tid)
+            self._traint_to_get = set(
+                range(current_tid, current_tid + self.get("numberOfTrainsToCatch"))
+            )
             self._just_capture_next = False
         target_tid = self.get("nextTrainToCatch")
 
-        if target_tid > current_tid:
+        if current_tid in self._trains_to_get:
+            # capture
+            if state is not State.ACQUIRING:
+                self.updateState(State.ACQUIRING)
+            channel = self.signalSlotable.getOutputChannel("output")
+            channel.write(
+                data,
+                # TODO: forward source name or use own?
+                ChannelMetaData(
+                    meta.get("source"), Timestamp(Epochstamp(), Trainstamp(current_tid))
+                ),
+                copyAllData=False,
+            )
+            channel.update()
+            # TODO: consider buffering instead and doing something clever
+            self._trains_to_get.discard(current_tid)
+            if not self._trains_to_get:
+                self.log.INFO("Caught all target trains :D")
+                self.updateState(State.PASSIVE)
+        elif target_tid > current_tid:
+            # wait
             if state is not State.MONITORING:
                 self.updateState(State.MONITORING)
-        elif target_tid < current_tid:
-            if self._previous_tid < target_tid and state is not State.ERROR:
-                self.log.ERROR(f"Missed target train of {target_tid} :(")
-                self.updateState(State.ERROR)
-            else:
-                if state is not State.PASSIVE:
-                    self.updateState(State.PASSIVE)
         else:
-            self.updateState(State.FILLING)
-            self.log.INFO(f"Got target train {target} now :D")
-            # TODO: copy train ID on metadata
-            self.writeChannel("output", data)
-            self.updateState(State.DISENGAGED)
-        self._previous_tid = current
-
-    def _set_new_target_tid(self, new_target_tid):
-        # assumes nextTrainToCatch gets set *after* this function
-        current_target_tid = self.get("nextTrainToCatch")
+            # past capture range
+            if self._trains_to_get:
+                # note: wouuld also get triggered by receiving the same train twice
+                self.log.ERROR(f"Missed some train(s): {self._trains_to_get}")
+                self.updateState(State.ERROR)
+                self._trains_to_get.clear()
+            elif state not in (State.PASSIVE, State.ERROR):
+                self.log.INFO(f"Weird state: {state}; admonish the developer!")
+                self.updateState(State.PASSIVE)
+        self._previous_tid = current_tid
+
+    def _update_target(self, new_target_tid):
+        # assumes nextTrainToCatch and numberOfTrainsToCatch have been set
+        new_target_tid = self.get("nextTrainToCatch")
+        self._trains_to_get = set(
+            range(new_target_tid, new_target_tid + self.get("numberOfTrainsToCatch"))
+        )
         if self._previous_tid >= new_target_tid:
             self.log.INFO(
                 f"Moved target train to {new_target_tid} even though last seen was "
-                f"{self._previous_tid} - will not be able to retroactively catch this."
+                f"{self._previous_tid} - will miss some trains!"
             )
-            self.updateState(State.ERROR)
         else:
-            if current_target_tid < new_target_tid:
+            if self._old_target_tid < new_target_tid:
                 self.log.INFO(
-                    f"Moved target train from {current_target_tid} to {new_target_tid}"
-                    f"even though last seen was {self._previous_tid} "
-                    f"effectively skipping {current_target_tid}"
+                    f"Moved target train from {self._old_target_tid} to "
+                    f"{new_target_tid} (last received was {self._previous_tid}), "
+                    f"effectively skipping {self._old_target_tid}"
                 )
             self.updateState(State.MONITORING)
 
     def preReconfigure(self, config):
         super().preReconfigure(config)
         if config.has("nextTrainToCatch"):
-            self._set_new_target_tid(config["nextTrainToCatch"])
+            self._old_target_tid = self.get("nextTrainToCatch")
+        self._cached_update = config
+
+    def postReconfigure(self):
+        super().postReconfigure()
+        if not hasattr(self, "_cached_update"):
+            self.log.WARN("postReconfigure update caching trick failed")
+            return
+
+        if (
+                self._cached_update.has("nextTrainToCatch")
+                or self._cached_update.has("numberOfTrainsToCatch")
+        ):
+            self._update_target()
 
 
 def hash_to_schema(h, root=None, prefix=""):
-- 
GitLab