Skip to content
Snippets Groups Projects
Commit 193264ed authored by David Hammer's avatar David Hammer
Browse files

Adding Philipp's proposed changes to TrainMatcher for merging

parent f07657ce
No related branches found
2 merge requests!10DetectorAssembler: assemble with extra shape (multiple frames),!9Stacking shmem matcher
from karabo.bound import KARABO_CLASSINFO
import re
import numpy as np
from karabo.bound import (
BOOL_ELEMENT,
KARABO_CLASSINFO,
STRING_ELEMENT,
TABLE_ELEMENT,
Hash,
Schema,
State,
)
from TrainMatcher import TrainMatcher
from . import shmem_utils
from ._version import version as deviceVersion
def merge_schema():
schema = Schema()
(
BOOL_ELEMENT(schema)
.key("select")
.displayedName("Select")
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("source_pattern")
.displayedName("Source pattern")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("key_pattern")
.displayedName("Key pattern")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("replacement")
.displayedName("Replacement")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
)
return schema
@KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
@staticmethod
def expectedParameters(expected):
(
TABLE_ELEMENT(expected)
.key("merge")
.displayedName("Array merging")
.allowedStates(State.PASSIVE)
.description(
"List source or key patterns to merge their data arrays, e.g. to "
"combine multiple detector sources or digitizer channels into a single "
"source or key. Both source or key patterns may be regular expressions,"
" but only one may have multiple matches at the same time. The merged "
"source or key is substituted by the replacement value."
)
.setColumns(merge_schema())
.assignmentOptional()
.defaultValue([])
.reconfigurable()
.commit(),
)
def initialization(self):
self._compile_merge_patterns(self.get("merge"))
super().initialization()
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
def preReconfigure(self, conf):
super().preReconfigure(conf)
if conf.has("merge"):
self._compile_merge_patterns(conf["merge"])
def _compile_merge_patterns(self, merge):
self._merge_patterns = [
(
re.compile(row["source_pattern"]),
re.compile(row["key_pattern"]),
row["replacement"],
)
for row in merge
if row["select"]
]
def on_matched_data(self, train_id, sources):
# dereference calng shmem handles
for source, (data, timestamp) in sources.items():
if data.has("calngShmemPaths"):
shmem_paths = list(data["calngShmemPaths"])
......@@ -23,4 +112,70 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
dereferenced = self._shmem_handler.get(data[shmem_path])
data[shmem_path] = dereferenced
# merge arrays
for source_re, key_re, replacement in self._merge_patterns:
# Find all sources matching the source pattern.
merge_sources = [
source for source in sources.keys() if source_re.match(source)
]
if len(merge_sources) > 1:
# More than one source match, merge by source.
if key_re.pattern in sources[merge_sources[0]][0]:
# Short-circuit the pattern for performance if itself is a key.
new_key = key_re.pattern
else:
# Find the first key matching the pattern.
for new_key in sources[merge_sources[0]][0].paths():
if key_re.match(new_key):
break
merge_keys = [new_key]
new_source = replacement
to_merge = [
sources[source][0][new_key] for source in sorted(merge_sources)
]
if len(to_merge) != len(merge_sources):
# Make sure all matched sources contain the key.
break
elif len(merge_sources) == 1:
# Exactly one source match, merge by key.
new_source = merge_sources[0]
new_key = replacement
merge_keys = [
key for key in sources[new_source][0].paths() if key_re.match(key)
]
if not merge_keys:
# No key match, ignore.
continue
to_merge = [sources[new_source][0][key] for key in sorted(merge_keys)]
else:
# No source match, ignore.
continue
# Stack data and insert into source data.
try:
new_data = np.stack(to_merge, axis=0)
except ValueError as e:
self.log.ERROR(
f"Failed to merge data for " f"{new_source}.{new_key}: {e}"
)
continue
sources.setdefault(new_source, (Hash(), sources[merge_sources[0]][1]))[0][
new_key
] = new_data
# Unset keys merged together across all source matches.
for source in merge_sources:
for key in merge_keys:
sources[source][0].erase(key)
super().on_matched_data(train_id, sources)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment