Refactor stacking for reuse and overlappability

Merged David Hammer requested to merge refactor-stacking into master
1 file   +30   −41
import concurrent.futures
import enum
import re
from timeit import default_timer
import numpy as np
from karabo.bound import (
BOOL_ELEMENT,
INT32_ELEMENT,
DOUBLE_ELEMENT,
KARABO_CLASSINFO,
NODE_ELEMENT,
OVERWRITE_ELEMENT,
STRING_ELEMENT,
TABLE_ELEMENT,
VECTOR_STRING_ELEMENT,
UINT32_ELEMENT,
ChannelMetaData,
Hash,
Schema,
MetricPrefix,
State,
Unit,
)
from TrainMatcher import TrainMatcher
from . import shmem_utils, utils
from .stacking_utils import StackingFriend
from .frameselection_utils import FrameselectionFriend
from ._version import version as deviceVersion
class GroupType(enum.Enum):
MULTISOURCE = "sources" # same key stacked across multiple sources into a new source
MULTIKEY = "keys" # multiple keys within each matched source are stacked into a new key
class MergeMethod(enum.Enum):
STACK = "stack"
INTERLEAVE = "interleave"
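# Rough illustration of the two merge methods (hypothetical arrays, not part
# of the device logic): stacking two arrays of shape (2, 3) along axis 0
# yields shape (2, 2, 3), while interleaving them yields shape (4, 3) with
# rows alternating between the two inputs, i.e. roughly
# np.stack([a, b], axis=1).reshape(4, 3) - the shape of a concatenation,
# but with interleaved element order.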
def merge_schema():
schema = Schema()
(
BOOL_ELEMENT(schema)
.key("select")
.displayedName("Select")
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("sourcePattern")
.displayedName("Source pattern")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("keyPattern")
.displayedName("Key pattern")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("replacement")
.displayedName("Replacement")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("groupType")
.displayedName("Group type")
.options(",".join(option.value for option in GroupType))
.assignmentOptional()
.defaultValue(GroupType.MULTISOURCE.value)
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("mergeMethod")
.displayedName("Merge method")
.options(",".join(option.value for option in MergeMethod))
.assignmentOptional()
.defaultValue(MergeMethod.STACK.value)
.reconfigurable()
.commit(),
INT32_ELEMENT(schema)
.key("axis")
.displayedName("Axis")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
)
return schema
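# A hypothetical row for the resulting table (source and key names are
# illustrative only):
# Hash("select", True, "sourcePattern", "DET/MOD\\d+:output",
#      "keyPattern", "image.data", "replacement", "DET/STACKED:output",
#      "groupType", "sources", "mergeMethod", "stack", "axis", 0)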
@KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
@staticmethod
def expectedParameters(expected):
(
TABLE_ELEMENT(expected)
.key("merge")
.displayedName("Array stacking")
.allowedStates(State.PASSIVE)
.description(
"Specify which source(s) or key(s) to stack or interleave."
"When stacking sources, the 'Source pattern' is interpreted as a "
"regular expression and the 'Key pattern' is interpreted as an "
"ordinary string. From all sources matching the source pattern, the "
"data under this key (should be array with same dimensions across all "
"stacked sources) is stacked in the same order as the sources are "
"listed in 'Data sources' and the result is under the same key name in "
"a new source named by 'Replacement'. "
"When stacking keys, both the 'Source pattern' and the 'Key pattern' "
"are regular expressions. Within each source matching the source "
"pattern, all keys matching the key pattern are stacked and the result "
"is put under the key named by 'Replacement'. "
"While source stacking is optimized and can use thread pool, key "
"stacking will iterate over all paths in matched sources and naively "
"call np.stack for each key pattern. In either case, data that is used "
"for stacking is removed from its original location (e.g. key is "
"erased from hash). "
"In both cases, the data can alternatively be interleaved. This is "
"essentially equivalent to stacking except followed by a reshape such "
"that the output is shaped like concatenation."
)
.setColumns(merge_schema())
.assignmentOptional()
.defaultValue([])
.reconfigurable()
.commit(),
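# Hypothetical examples of the two modes (all names invented for
# illustration):
# - Source stacking: source pattern "DET/MOD\\d+:output" with key pattern
#   "image.data" and replacement "DET/STACKED:output" stacks image.data from
#   all matching sources, in 'Data sources' order, into the new source.
# - Key stacking: source pattern "DET/RECV:output" with key pattern
#   "image\\.gain\\d+" and replacement "image.gain" stacks all matching keys
#   within each matched source into the single new key.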
# order is important for stacking, disable sorting
OVERWRITE_ELEMENT(expected)
.key("sortSources")
@@ -143,20 +34,28 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
.commit(),
BOOL_ELEMENT(expected)
.key("useThreadPool")
.displayedName("Use thread pool")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
BOOL_ELEMENT(expected)
.key("enableKaraboOutput")
.displayedName("Enable Karabo channel")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(True)
.reconfigurable()
.commit(),
DOUBLE_ELEMENT(expected)
.key("processingTime")
.displayedName("Processing time")
.unit(Unit.SECOND)
.metricPrefix(MetricPrefix.MILLI)
.readOnly()
.initialValue(0)
.commit(),
UINT32_ELEMENT(expected)
.key("processingThreads")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(16)
.reconfigurable()
.commit(),
@@ -170,374 +69,70 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
.assignmentOptional()
.defaultValue(True)
.commit(),
NODE_ELEMENT(expected)
.key("frameSelector")
.displayedName("Frame selection")
.commit(),
BOOL_ELEMENT(expected)
.key("frameSelector.enable")
.displayedName("Enabled")
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
STRING_ELEMENT(expected)
.key("frameSelector.arbiterSource")
.displayedName("Arbiter source")
.description(
"Source name to pull the frame selection pattern from, must be part of "
"matched sources."
)
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(expected)
.key("frameSelector.dataSourcePattern")
.displayedName("Data source pattern")
.description(
"Source name pattern to apply frame selection to. Should match "
"subset of matched sources."
)
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
VECTOR_STRING_ELEMENT(expected)
.key("frameSelector.dataKeys")
.displayedName("Data keys")
.description("Keys in data sources to apply frame selection to.")
.assignmentOptional()
.defaultValue([])
.reconfigurable()
.commit(),
)
FrameselectionFriend.add_schema(expected)
StackingFriend.add_schema(expected)
def __init__(self, config):
if config.get("useInfiniband", default=True):
from PipeToZeroMQ.utils import find_infiniband_ip
config["output.hostname"] = find_infiniband_ip()
self._processing_time_tracker = utils.ExponentialMovingAverage(alpha=0.3)
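# Presumably an exponential moving average: with alpha=0.3, each update is
# roughly 0.3 * latest_measurement + 0.7 * previous_estimate, smoothing the
# per-train processing time exposed (in ms) via the processingTime property.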
super().__init__(config)
self.info.merge(Hash("processingTime", 0))
def initialization(self):
self._stacking_buffers = {}
self._source_stacking_indices = {}
self._source_stacking_sources = {}
self._source_stacking_group_sizes = {}
self._key_stacking_sources = {}
self._have_prepared_merge_groups = False
self._prepare_merge_groups()
super().initialization()
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
if self.get("useThreadPool"):
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
else:
self._thread_pool = None
self._stacking_friend = StackingFriend(
self, self.get("sources"), self.get("merge")
)
self._frameselection_friend = FrameselectionFriend(self.get("frameSelector"))
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=self.get("processingThreads")
)
if not self.get("enableKaraboOutput"):
# it is already set by the superclass by default, so we only need to turn it off
self.output = None
self._frame_selection_enabled = False
self._frame_selection_arbiter = ""
self._frame_selection_source_pattern = ""
self._frame_selection_data_keys = []
self._have_prepared_frame_selection = False
self._prepare_frame_selection()
self.start() # Auto-start this type of matcher.
def preReconfigure(self, conf):
super().preReconfigure(conf)
if conf.has("merge") or conf.has("sources"):
self._have_prepared_merge_groups = False
self._stacking_friend.reconfigure(conf.get("sources"), conf.get("merge"))
# re-prepare in postReconfigure after sources *and* merge are in self
if conf.has("useThreadPool"):
if self._thread_pool is not None:
self._thread_pool.shutdown()
self._thread_pool = None
if conf["useThreadPool"]:
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
if conf.has("frameSelector"):
self._frameselection_friend.reconfigure(conf.get("frameSelector"))
if conf.has("processingThreads"):
self._thread_pool.shutdown()
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=conf.get("processingThreads")
)
if conf.has("enableKaraboOutput"):
if conf["enableKaraboOutput"]:
self.output = self._ss.getOutputChannel("output")
else:
self.output = None
if conf.has("frameSelector"):
self._have_prepared_frame_selection = False
def postReconfigure(self):
super().postReconfigure()
if not self._have_prepared_merge_groups:
self._prepare_merge_groups()
if not self._have_prepared_frame_selection:
self._prepare_frame_selection()
def _prepare_frame_selection(self):
self._frame_selection_enabled = self.get("frameSelector.enable")
self._frame_selection_arbiter = self.get("frameSelector.arbiterSource")
self._frame_selection_source_pattern = re.compile(
self.get("frameSelector.dataSourcePattern")
)
self._frame_selection_data_keys = list(self.get("frameSelector.dataKeys"))
def _prepare_merge_groups(self):
# not filtering by row["select"] to allow unselected sources to create gaps
source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
self._stacking_buffers.clear()
self._source_stacking_indices.clear()
self._source_stacking_sources.clear()
self._source_stacking_group_sizes.clear()
self._key_stacking_sources.clear()
# split by type, prepare regexes
for row in self.get("merge"):
if not row["select"]:
continue
group_type = GroupType(row["groupType"])
source_re = re.compile(row["sourcePattern"])
merge_method = MergeMethod(row["mergeMethod"])
axis = row["axis"]
if group_type is GroupType.MULTISOURCE:
key = row["keyPattern"]
new_source = row["replacement"]
merge_sources = [
source for source in source_names if source_re.match(source)
]
if len(merge_sources) == 0:
self.log.WARN(
f"Group pattern {source_re} did not match any known sources"
)
continue
for (i, source) in enumerate(merge_sources):
self._source_stacking_sources.setdefault(source, []).append(
(key, new_source, merge_method, axis)
)
self._source_stacking_indices[(source, new_source, key)] = i
self._source_stacking_group_sizes[(new_source, key)] = i + 1
else:
key_re = re.compile(row["keyPattern"])
new_key = row["replacement"]
for source in source_names:
if not source_re.match(source):
continue
self._key_stacking_sources.setdefault(source, []).append(
(key_re, new_key, merge_method, axis)
)
self._have_prepared_merge_groups = True
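# After preparation, a hypothetical row stacking "image.data" from sources
# "s0" and "s1" into new source "merged" would yield:
# _source_stacking_sources = {"s0": [("image.data", "merged", STACK, 0)],
#                             "s1": [("image.data", "merged", STACK, 0)]}
# _source_stacking_indices[("s0", "merged", "image.data")] = 0  # "s1" -> 1
# _source_stacking_group_sizes[("merged", "image.data")] = 2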
def _check_stacking_data(self, sources, frame_selection_mask):
if frame_selection_mask is not None:
orig_size = len(frame_selection_mask)
result_size = np.sum(frame_selection_mask)
stacking_data_shapes = {}
ignore_stacking = {}
fill_missed_data = {}
for source, keys in self._source_stacking_sources.items():
if source not in sources:
for key, new_source, _, _ in keys:
missed_sources = fill_missed_data.setdefault((new_source, key), [])
merge_index = self._source_stacking_indices[
(source, new_source, key)
]
missed_sources.append(merge_index)
continue
data_hash, timestamp = sources[source]
filtering = (
frame_selection_mask is not None and
self._frame_selection_source_pattern.match(source)
)
for key, new_source, merge_method, axis in keys:
merge_data_shape = None
if key in data_hash:
merge_data = data_hash[key]
merge_data_shape = merge_data.shape
else:
ignore_stacking[(new_source, key)] = "Some data is missing"
continue
if filtering and key in self._frame_selection_data_keys:
# !!! stacking is not expected to be used with filtering
if merge_data_shape[0] == orig_size:
merge_data_shape = (result_size,) + merge_data.shape[1:]
expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
(new_source, key),
(merge_data_shape, merge_method, axis, merge_data.dtype, timestamp)
)
if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
del stacking_data_shapes[(new_source, key)]
return stacking_data_shapes, ignore_stacking, fill_missed_data
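# The returned triple: stacking_data_shapes maps (new_source, key) to
# (shape, method, axis, dtype, timestamp) for groups that can be stacked,
# ignore_stacking maps skipped groups to a human-readable reason, and
# fill_missed_data maps groups to the stacking indices of absent sources.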
def _maybe_update_stacking_buffers(
self, stacking_data_shapes, fill_missed_data, new_sources_map
):
for (new_source, key), attr in stacking_data_shapes.items():
merge_data_shape, merge_method, axis, dtype, timestamp = attr
group_size = self._source_stacking_group_sizes[(new_source, key)]
if merge_method is MergeMethod.STACK:
expected_shape = utils.stacking_buffer_shape(
merge_data_shape, group_size, axis=axis)
else:
expected_shape = utils.interleaving_buffer_shape(
merge_data_shape, group_size, axis=axis)
merge_buffer = self._stacking_buffers.get((new_source, key))
if merge_buffer is None or merge_buffer.shape != expected_shape:
merge_buffer = np.empty(
shape=expected_shape, dtype=dtype)
self._stacking_buffers[(new_source, key)] = merge_buffer
for merge_index in fill_missed_data.get((new_source, key), []):
utils.set_on_axis(
merge_buffer,
0,
merge_index
if merge_method is MergeMethod.STACK
else np.index_exp[
slice(
merge_index,
None,
self._source_stacking_group_sizes[(new_source, key)],
)
],
axis,
)
if new_source not in new_sources_map:
new_sources_map[new_source] = (Hash(), timestamp)
new_source_hash = new_sources_map[new_source][0]
if not new_source_hash.has(key):
new_source_hash[key] = merge_buffer
def _handle_source(
self, source, data_hash, timestamp, new_sources_map, frame_selection_mask,
ignore_stacking
):
# dereference calng shmem handles
self._shmem_handler.dereference_shmem_handles(data_hash)
# apply frame_selection
if frame_selection_mask is not None and self._frame_selection_source_pattern.match(
source
):
for key in self._frame_selection_data_keys:
if not data_hash.has(key):
continue
if data_hash[key].shape[0] != frame_selection_mask.size:
self.log.WARN("Frame selection mask does not match the data size")
continue
data_hash[key] = data_hash[key][frame_selection_mask]
# stack across sources (many sources, same key)
for (
key,
new_source,
merge_method,
axis,
) in self._source_stacking_sources.get(source, ()):
if (new_source, key) in ignore_stacking:
continue
merge_data = data_hash.get(key)
merge_index = self._source_stacking_indices[
(source, new_source, key)
]
merge_buffer = self._stacking_buffers[(new_source, key)]
utils.set_on_axis(
merge_buffer,
merge_data,
merge_index
if merge_method is MergeMethod.STACK
else np.index_exp[
slice(
merge_index,
None,
self._source_stacking_group_sizes[(new_source, key)],
)
],
axis,
)
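# The interleaving slice above places source i's data at positions
# i, i + n, i + 2n, ... along the stacking axis (n = group size); e.g. with
# n = 2 and axis 0, source 0 fills rows 0, 2, 4, ... and source 1 fills
# rows 1, 3, 5, ...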
data_hash.erase(key)
# stack keys (multiple keys within this source)
for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
source, ()
):
# note: there should be no overlap between different key_re patterns
# note: if a later key_re matches an earlier new_key, results are unpredictable
keys = [key for key in data_hash.paths() if key_re.match(key)]
try:
# note: maybe we could reuse buffers here, too?
if merge_method is MergeMethod.STACK:
stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
else:
first = data_hash.get(keys[0])
stacked = np.empty(
shape=utils.stacking_buffer_shape(
first.shape, len(keys), axis=axis
),
dtype=first.dtype,
)
for i, key in enumerate(keys):
utils.set_on_axis(
stacked,
data_hash.get(key),
np.index_exp[slice(i, None, len(keys))],
axis,
)
except Exception as e:
self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
else:
for key in keys:
data_hash.erase(key)
data_hash[new_key] = stacked
def on_matched_data(self, train_id, sources):
new_sources_map = {}
frame_selection_mask = None
if self._frame_selection_enabled and self._frame_selection_arbiter in sources:
frame_selection_mask = np.array(
sources[self._frame_selection_arbiter][0][
"data.dataFramePattern"
], copy=False
).astype(bool, copy=False)
# prepare stacking
stacking_data_shapes, ignore_stacking, fill_missed_data = (
self._check_stacking_data(sources, frame_selection_mask)
)
self._maybe_update_stacking_buffers(
stacking_data_shapes, fill_missed_data, new_sources_map)
for (new_source, key), msg in ignore_stacking.items():
self.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
if self._thread_pool is None:
for source, (data, timestamp) in sources.items():
self._handle_source(
source, data, timestamp, new_sources_map, frame_selection_mask,
ignore_stacking
)
else:
concurrent.futures.wait(
[
self._thread_pool.submit(
self._handle_source,
source,
data,
timestamp,
new_sources_map,
frame_selection_mask,
ignore_stacking,
)
for source, (data, timestamp) in sources.items()
]
)
sources.update(new_sources_map)
def on_matched_data(self, train_id, sources):
ts_start = default_timer()
frame_selection_mask = self._frameselection_friend.get_mask(sources)
concurrent.futures.wait(
[
self._thread_pool.submit(
self._handle_source,
source,
data,
timestamp,
frame_selection_mask,
)
for source, (data, timestamp) in sources.items()
]
)
self._stacking_friend.process(sources, self._thread_pool)
# karabo output
if self.output is not None:
@@ -552,6 +147,18 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self.zmq_output.write(source, data, timestamp)
self.zmq_output.update()
self._processing_time_tracker.update(default_timer() - ts_start)
self.info["processingTime"] = self._processing_time_tracker.get() * 1000
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _handle_source(
self,
source,
data_hash,
timestamp,
frame_selection_mask,
):
self._shmem_handler.dereference_shmem_handles(data_hash)
self._frameselection_friend.apply_mask(source, data_hash, frame_selection_mask)
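# Sketch of the friend interfaces this device now relies on, as inferred from
# the calls above (see stacking_utils and frameselection_utils for the actual
# definitions):
# StackingFriend(device, sources, merge): add_schema(expected) on the class,
#     reconfigure(sources, merge), process(sources, thread_pool)
# FrameselectionFriend(conf): add_schema(expected) on the class,
#     reconfigure(conf), get_mask(sources), apply_mask(source, data_hash, mask)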