From beb32df59a994f132618320d0d30be6d6ed8b8d4 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Fri, 13 May 2022 13:25:10 +0200
Subject: [PATCH] Fix reconfigure bug and prevent sources from being sorted

---
 src/calng/ShmemTrainMatcher.py | 54 ++++++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 6 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index c77240fe..a3f919ee 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -13,6 +13,7 @@ from karabo.bound import (
     Hash,
     Schema,
     State,
+    VectorHash,
 )
 
 from TrainMatcher import TrainMatcher
@@ -106,8 +107,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 "While source stacking is optimized and can use thread pool, key "
                 "stacking will iterate over all paths in matched sources and naively "
                 "call np.stack for each key pattern. In either case, data that is used "
-                "for stacking is removed from its original location (e.g. key is erased "
-                "from hash)."
+                "for stacking is removed from its original location (e.g. key is "
+                "erased from hash)."
             )
             .setColumns(merge_schema())
             .assignmentOptional()
@@ -139,7 +140,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self._source_stacking_indices = {}
         self._source_stacking_sources = {}
         self._key_stacking_sources = {}
-        self._prepare_merge_groups(self.get("merge"))
+        self._have_prepared_merge_groups = False
+        self._prepare_merge_groups()
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
         if self.get("useThreadPool"):
@@ -154,7 +156,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
         if conf.has("merge") or conf.has("sources"):
-            self._prepare_merge_groups(conf["merge"])
+            self._have_prepared_merge_groups = False
+            # re-prepare in postReconfigure after sources *and* merge are in self
         if conf.has("useThreadPool"):
             if self._thread_pool is not None:
                 self._thread_pool.shutdown()
@@ -167,11 +170,16 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         else:
             self.output = None
 
-    def _prepare_merge_groups(self, merge):
+    def postReconfigure(self):
+        super().postReconfigure()
+        if not self._have_prepared_merge_groups:
+            self._prepare_merge_groups()
+
+    def _prepare_merge_groups(self):
         source_group_patterns = []
         key_group_patterns = []
         # split by type, prepare regexes
-        for row in merge:
+        for row in self.get("merge"):
             if not row["select"]:
                 continue
             group_type = MergeGroupType(row["type"])
@@ -225,6 +233,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 self._key_stacking_sources.setdefault(source, []).append(
                     (key_re, new_key)
                 )
+        self._have_prepared_merge_groups = True
 
     def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
         # TODO: handle ValueError for max of empty sequence
@@ -334,3 +343,36 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self.info["sent"] += 1
         self.info["trainId"] = train_id
         self.rate_out.update()
+
+    def _maybe_connect_data(self, conf, update=False, state=None):
+        """Temporary override on _maybe_connect_data to avoid sorting sources list (we
+        need it for stacking order)"""
+        if self["state"] not in (State.CHANGING, State.ACTIVE):
+            return
+
+        last_state = self["state"]
+        self.updateState(State.CHANGING)
+
+        # unwatch removed sources
+        def src_names(c):
+            # do not assign a lambda expression, use a def
+            return {s["source"] for s in c["sources"]}
+
+        for source in src_names(self).difference(src_names(conf)):
+            self.monitor.unwatch_source(source)
+
+        new_conf = VectorHash()
+        for src in conf["sources"]:
+            source = src["source"]
+            if src["select"]:
+                src["status"] = self.monitor.watch_source(source, src["offset"])
+            else:
+                self.monitor.unwatch_source(source)
+                src["status"] = ""
+            new_conf.append(src)
+
+        if update:
+            self.set("sources", new_conf)
+        else:
+            conf["sources"] = new_conf
+        self.updateState(state or last_state)
-- 
GitLab