Skip to content
Snippets Groups Projects
Commit beb32df5 authored by David Hammer's avatar David Hammer
Browse files

Fix reconfigure bug and prevent sources from being sorted

parent a79b501a
No related branches found
No related tags found
2 merge requests!10DetectorAssembler: assemble with extra shape (multiple frames),!9Stacking shmem matcher
......@@ -13,6 +13,7 @@ from karabo.bound import (
Hash,
Schema,
State,
VectorHash,
)
from TrainMatcher import TrainMatcher
......@@ -106,8 +107,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
"While source stacking is optimized and can use thread pool, key "
"stacking will iterate over all paths in matched sources and naively "
"call np.stack for each key pattern. In either case, data that is used "
"for stacking is removed from its original location (e.g. key is erased "
"from hash)."
"for stacking is removed from its original location (e.g. key is "
"erased from hash)."
)
.setColumns(merge_schema())
.assignmentOptional()
......@@ -139,7 +140,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self._source_stacking_indices = {}
self._source_stacking_sources = {}
self._key_stacking_sources = {}
self._prepare_merge_groups(self.get("merge"))
self._have_prepared_merge_groups = False
self._prepare_merge_groups()
super().initialization()
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
if self.get("useThreadPool"):
......@@ -154,7 +156,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
def preReconfigure(self, conf):
super().preReconfigure(conf)
if conf.has("merge") or conf.has("sources"):
self._prepare_merge_groups(conf["merge"])
self._have_prepared_merge_groups = False
# re-prepare in postReconfigure after sources *and* merge are in self
if conf.has("useThreadPool"):
if self._thread_pool is not None:
self._thread_pool.shutdown()
......@@ -167,11 +170,16 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
else:
self.output = None
def _prepare_merge_groups(self, merge):
def postReconfigure(self):
    """Hook run after a reconfiguration has been applied.

    If preReconfigure flagged the merge groups as stale (sources and/or
    merge rows changed), rebuild them now that both are present in self.
    """
    super().postReconfigure()
    if self._have_prepared_merge_groups:
        return
    self._prepare_merge_groups()
def _prepare_merge_groups(self):
source_group_patterns = []
key_group_patterns = []
# split by type, prepare regexes
for row in merge:
for row in self.get("merge"):
if not row["select"]:
continue
group_type = MergeGroupType(row["type"])
......@@ -225,6 +233,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self._key_stacking_sources.setdefault(source, []).append(
(key_re, new_key)
)
self._have_prepared_merge_groups = True
def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
# TODO: handle ValueError for max of empty sequence
......@@ -334,3 +343,36 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _maybe_connect_data(self, conf, update=False, state=None):
    """Temporary override on _maybe_connect_data to avoid sorting sources list (we
    need it for stacking order)"""
    # Only (re)connect while the device is in a state where that is allowed.
    if self["state"] not in (State.CHANGING, State.ACTIVE):
        return
    previous_state = self["state"]
    self.updateState(State.CHANGING)

    def source_names(c):
        # helper as a def rather than a lambda assignment (style guide)
        return {row["source"] for row in c["sources"]}

    # stop watching every source that disappeared from the new configuration
    for removed in source_names(self) - source_names(conf):
        self.monitor.unwatch_source(removed)

    # rebuild the sources table, watching / unwatching as selected,
    # preserving the row order as given (needed for stacking order)
    rebuilt = VectorHash()
    for row in conf["sources"]:
        name = row["source"]
        if row["select"]:
            row["status"] = self.monitor.watch_source(name, row["offset"])
        else:
            self.monitor.unwatch_source(name)
            row["status"] = ""
        rebuilt.append(row)

    if update:
        self.set("sources", rebuilt)
    else:
        conf["sources"] = rebuilt

    self.updateState(state or previous_state)
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment