From beb32df59a994f132618320d0d30be6d6ed8b8d4 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Fri, 13 May 2022 13:25:10 +0200
Subject: [PATCH] Fix reconfigure bug and prevent sources from being sorted

---
 src/calng/ShmemTrainMatcher.py | 54 ++++++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 6 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index c77240fe..a3f919ee 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -13,6 +13,7 @@ from karabo.bound import (
     Hash,
     Schema,
     State,
+    VectorHash,
 )
 from TrainMatcher import TrainMatcher
 
@@ -106,8 +107,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 "While source stacking is optimized and can use thread pool, key "
                 "stacking will iterate over all paths in matched sources and naively "
                 "call np.stack for each key pattern. In either case, data that is used "
-                "for stacking is removed from its original location (e.g. key is erased "
-                "from hash)."
+                "for stacking is removed from its original location (e.g. key is "
+                "erased from hash)."
             )
             .setColumns(merge_schema())
             .assignmentOptional()
@@ -139,7 +140,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self._source_stacking_indices = {}
         self._source_stacking_sources = {}
         self._key_stacking_sources = {}
-        self._prepare_merge_groups(self.get("merge"))
+        self._have_prepared_merge_groups = False
+        self._prepare_merge_groups()
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
         if self.get("useThreadPool"):
@@ -154,7 +156,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
         if conf.has("merge") or conf.has("sources"):
-            self._prepare_merge_groups(conf["merge"])
+            self._have_prepared_merge_groups = False
+            # re-prepare in postReconfigure after sources *and* merge are in self
         if conf.has("useThreadPool"):
             if self._thread_pool is not None:
                 self._thread_pool.shutdown()
@@ -167,11 +170,16 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             else:
                 self.output = None
 
-    def _prepare_merge_groups(self, merge):
+    def postReconfigure(self):  # rebuilds merge groups after a reconfigure
+        super().postReconfigure()
+        if not self._have_prepared_merge_groups:  # flag cleared in preReconfigure
+            self._prepare_merge_groups()  # reads "merge" from self, sets flag again
+
+    def _prepare_merge_groups(self):
         source_group_patterns = []
         key_group_patterns = []
         # split by type, prepare regexes
-        for row in merge:
+        for row in self.get("merge"):
             if not row["select"]:
                 continue
             group_type = MergeGroupType(row["type"])
@@ -225,6 +233,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 self._key_stacking_sources.setdefault(source, []).append(
                     (key_re, new_key)
                 )
+        self._have_prepared_merge_groups = True
 
     def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
         # TODO: handle ValueError for max of empty sequence
@@ -334,3 +343,36 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self.info["sent"] += 1
         self.info["trainId"] = train_id
         self.rate_out.update()
+
+    def _maybe_connect_data(self, conf, update=False, state=None):
+        """Temporary override of _maybe_connect_data that keeps the sources list in
+        its configured order instead of sorting it (stacking relies on that order)."""
+        if self["state"] not in (State.CHANGING, State.ACTIVE):
+            return  # nothing is (re)connected in any other state
+
+        last_state = self["state"]  # remembered so it can be restored at the end
+        self.updateState(State.CHANGING)
+
+        # unwatch sources that are in the current config but no longer in conf
+        def src_names(c):
+            # set of source names in c; a def, not a lambda (lint rule E731)
+            return {s["source"] for s in c["sources"]}
+
+        for source in src_names(self).difference(src_names(conf)):
+            self.monitor.unwatch_source(source)
+
+        new_conf = VectorHash()  # rows are appended in the order given by conf
+        for src in conf["sources"]:
+            source = src["source"]
+            if src["select"]:
+                src["status"] = self.monitor.watch_source(source, src["offset"])
+            else:
+                self.monitor.unwatch_source(source)
+                src["status"] = ""  # deselected rows get an empty status
+            new_conf.append(src)
+
+        if update:  # update=True: store via self.set; otherwise patch conf in place
+            self.set("sources", new_conf)
+        else:
+            conf["sources"] = new_conf
+        self.updateState(state or last_state)  # explicit state wins, else restore
-- 
GitLab