From 85ee451d6050f69c1b516d74ff51acb8c31bdbb5 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 30 Aug 2023 16:17:13 +0200
Subject: [PATCH] Allow PickyBoi to handle multiple data per input

---
 src/calng/PickyBoi.py | 162 ++++++++++++++++++++++++++----------------
 1 file changed, 99 insertions(+), 63 deletions(-)

diff --git a/src/calng/PickyBoi.py b/src/calng/PickyBoi.py
index e0b87dc4..72392c67 100644
--- a/src/calng/PickyBoi.py
+++ b/src/calng/PickyBoi.py
@@ -96,6 +96,21 @@ class PickyBoi(PythonDevice):
             .defaultValue("")
             .commit(),
 
+            STRING_ELEMENT(expected)
+            .key("sourceToSetSchemaFrom")
+            .assignmentOptional()
+            .defaultValue("")
+            .description(
+                "This device will set its output schema based on the data received and"
+                "forwarded. In the case of multiple input sources, this parameter "
+                "allows you to name the input source which should dictate the output "
+                "schema (which is then probably wrong for the other forwarded inputs; "
+                "this is messy either way). If left empty, schema is set based on "
+                "first seen data."
+            )
+            .reconfigurable()
+            .commit(),
+
             INT64_ELEMENT(expected)
             .key("ppuTrainOffset")
             .description(
@@ -148,6 +163,7 @@ class PickyBoi(PythonDevice):
         super().__init__(config)
         self._previous_tid = 0
         self._trains_to_get = set()  # will hold range of trains
+        self._remaining_trains = set()  # will hold subset of trains to get not seen
         self._old_target_tid = None  # just used for warnings about missing trains
 
         # manual override: forward starting from next train, whatever it is
@@ -164,7 +180,7 @@ class PickyBoi(PythonDevice):
         self.registerInitialFunction(self._initialization)
 
     def _initialization(self):
-        self.KARABO_ON_DATA("input", self.input_handler)
+        self.KARABO_ON_INPUT("input", self.input_handler)
         self._train_ratio_tracker = utils.TrainRatioTracker()
         self._rate_update_timer = utils.RepeatingTimer(
             interval=1,
@@ -215,67 +231,86 @@ class PickyBoi(PythonDevice):
         ):
             self._update_target(offset=self.get("ppuTrainOffset"))
 
-    def input_handler(self, data, meta):
-        if not self._schema_is_set:
-            schema_update = Schema()
-            (
-                OUTPUT_CHANNEL(schema_update)
-                .key("output")
-                .dataSchema(hash_to_schema(data))
-                .commit(),
-            )
-            self.updateSchema(schema_update)
-            self._schema_is_set = True
-
-        # TODO: handle multiple (consecutive) trains picked
-        state = self.get("state")
-        current_tid = Timestamp.fromHashAttributes(
-            meta.getAttributes("timestamp")
-        ).getTrainId()
-        self._train_ratio_tracker.update(current_tid)
-        # TODO: check against timeserver to handle wild future trains
-        if self._just_capture_next:
-            self.set("nextTrainToCatch", current_tid)
-            self._trains_to_get = set(
-                range(current_tid, current_tid + self.get("numberOfTrainsToCatch"))
-            )
-            self._just_capture_next = False
-        target_tid = self.get("nextTrainToCatch")
-
-        if current_tid in self._trains_to_get:
-            # capture
-            if state is not State.ACQUIRING:
-                self.updateState(State.ACQUIRING)
-            channel = self.signalSlotable.getOutputChannel("output")
-            channel.write(
-                data,
-                # TODO: forward source name or use own?
-                ChannelMetaData(
-                    meta.get("source"), Timestamp(Epochstamp(), Trainstamp(current_tid))
-                ),
-                copyAllData=False,
-            )
-            channel.update()
-            # TODO: consider buffering instead and doing something clever
-            self._trains_to_get.discard(current_tid)
-            if not self._trains_to_get:
-                self.log.INFO("Caught all target trains :D")
-                self.updateState(State.PASSIVE)
-        elif target_tid > current_tid:
-            # wait
-            if state is not State.MONITORING:
-                self.updateState(State.MONITORING)
-        else:
-            # past capture range
-            if self._trains_to_get:
-                # note: wouuld also get triggered by receiving the same train twice
-                self.log.ERROR(f"Missed some train(s): {self._trains_to_get}")
-                self.updateState(State.ERROR)
-                self._trains_to_get.clear()
-            elif state not in (State.PASSIVE, State.ERROR):
-                self.log.INFO(f"Weird state: {state}; admonish the developer!")
-                self.updateState(State.PASSIVE)
-        self._previous_tid = current_tid
+    def input_handler(self, input_channel):
+        all_metadata = input_channel.getMetaData()
+        have_written_something = False
+
+        for input_index in range(input_channel.size()):
+            state = self.get("state")
+            data = input_channel.read(input_index)
+            meta = all_metadata[input_index]
+            source = meta.get("source")
+            current_tid = Timestamp.fromHashAttributes(
+                meta.getAttributes("timestamp")
+            ).getTrainId()
+
+            try:
+                self._train_ratio_tracker.update(current_tid)
+            except utils.NonMonotonicTrainIdWarning as ex:
+                self.log.WARN(
+                    f"Train ID issue: {ex}; last I saw was {self._previous_tid}"
+                )
+
+            if not self._schema_is_set and self.get("sourceToSetSchemaFrom") in (
+                "",
+                source,
+            ):
+                schema_update = Schema()
+                (
+                    OUTPUT_CHANNEL(schema_update)
+                    .key("output")
+                    .dataSchema(hash_to_schema(data))
+                    .commit(),
+                )
+                self.updateSchema(schema_update)
+                self._schema_is_set = True
+
+            if self._just_capture_next:
+                self._old_target_tid = self.get("nextTrainToCatch")
+                self.set("nextTrainToCatch", current_tid)
+                self._update_target()
+                self._just_capture_next = False
+
+            target_tid = self.get("nextTrainToCatch")
+
+            if current_tid in self._trains_to_get:
+                # capture
+                if state is not State.ACQUIRING:
+                    self.updateState(State.ACQUIRING)
+                channel = self.signalSlotable.getOutputChannel("output")
+                channel.write(
+                    data,
+                    # TODO: forward source name or use own?
+                    ChannelMetaData(
+                        source,
+                        Timestamp(Epochstamp(), Trainstamp(current_tid)),
+                    ),
+                    copyAllData=False,
+                )
+                have_written_something = True
+                # TODO: consider buffering instead and doing something clever
+                self._remaining_trains.discard(current_tid)
+                if not self._remaining_trains:
+                    self.log.INFO("Caught all target trains :D")
+                    self.updateState(State.PASSIVE)
+            elif target_tid > current_tid:
+                # wait
+                if state is not State.MONITORING:
+                    self.updateState(State.MONITORING)
+            else:
+                # past capture range
+                if self._remaining_trains:
+                    # note: wouuld also get triggered by receiving the same train twice
+                    self.log.ERROR(f"Missed some train(s): {self._remaining_trains}")
+                    self.updateState(State.ERROR)
+                    self._remaining_trains.clear()
+                elif state not in (State.PASSIVE, State.ERROR):
+                    self.log.INFO(f"Weird state: {state}; admonish the developer!")
+                    self.updateState(State.PASSIVE)
+
+            self._previous_tid = current_tid
+            if have_written_something:
+                channel.update()
 
     def _update_target(self, offset=0):
         # assumes nextTrainToCatch and numberOfTrainsToCatch etc. have been set
@@ -290,10 +325,11 @@ class PickyBoi(PythonDevice):
             self._trains_to_get = set(
                 (
                     new_target_tid + min(
-                        self.get("numberOfTrainsToCatch")-1, index_to_forward
+                        self.get("numberOfTrainsToCatch") - 1, index_to_forward
                     ),
                 )
             )
+        self._remaining_trains = self._trains_to_get.copy()
         if self._previous_tid >= new_target_tid:
             self.log.INFO(
                 f"Moved target train to {new_target_tid} even though last seen was "
-- 
GitLab