Commit c3e1c48f authored by David Hammer

WIP: restructure / simplify stacking execution

parent ca46ba88
Part of 2 merge requests: !74 Refactor DetectorAssembler, !73 Refactor stacking for reuse and overlappability
@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def initialization(self):
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
-        self._stacking_friend = StackingFriend(self.get("merge"), self.get("sources"))
+        self._stacking_friend = StackingFriend(
+            self, self.get("merge"), self.get("sources")
+        )
         self._frameselection_friend = FrameselectionFriend(self.get("frameSelector"))
         self._thread_pool = concurrent.futures.ThreadPoolExecutor(
             max_workers=self.get("processingThreads")
@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def on_matched_data(self, train_id, sources):
         frame_selection_mask = self._frameselection_friend.get_mask(sources)
         # note: should not do stacking and frame selection for now!
-        self._stacking_friend.prepare_stacking_for_train(sources)
-        concurrent.futures.wait(
-            [
-                self._thread_pool.submit(
-                    self._handle_source,
-                    source,
-                    data,
-                    timestamp,
-                    new_sources_map,
-                    frame_selection_mask,
-                )
-                for source, (data, timestamp) in sources.items()
-            ]
-        )
-        sources.update(new_sources_map)
+        with self._stacking_friend.stacking_context as stacker:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source,
+                        source,
+                        data,
+                        timestamp,
+                        stacker,
+                        frame_selection_mask,
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
+            sources.update(stacker.new_source_map)
         # karabo output
         if self.output is not None:
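`stacking_context` is new in this commit and its definition is not shown in the hunks below. One plausible shape for it, consistent with the `with ... as stacker:` usage and the `stacker.new_source_map` access above, is a property handing out a fresh per-train context manager; the `Stacker` name and the method bodies here are assumptions for illustration, not the commit's actual code.

import contextlib


class Stacker:
    """Hypothetical per-train helper matching the usage above."""

    def __init__(self, friend):
        self._friend = friend
        self.new_source_map = {}  # new source name -> (Hash, timestamp)

    def process(self, source, data_hash):
        # called from worker threads: copy this source's arrays into the
        # right slots of this train's stacking buffers
        pass


class StackingFriendSketch:
    @property
    @contextlib.contextmanager
    def stacking_context(self):
        # hand out a fresh helper per train so that overlapping trains
        # never share stacking buffers
        stacker = Stacker(self)
        yield stacker
        # per-train buffers simply go out of scope on exit

Keeping the buffers on the per-train stacker rather than on the shared friend is what would make overlapping trains (the goal of merge request !73) safe.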
@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         source,
         data_hash,
         timestamp,
-        new_sources_map,
+        stacker,
         frame_selection_mask,
-        ignore_stacking,
     ):
         self._shmem_handler.dereference_shmem_handles(data_hash)
         self._frameselection_friend.apply_mask(source, data_hash, frame_selection_mask)
-        self._stacking_friend.handle_source(...)
+        stacker.process(source, data_hash)
+import collections
 import enum
 import re
@@ -79,6 +80,22 @@ def merge_schema():
         .defaultValue(0)
         .reconfigurable()
         .commit(),
+        STRING_ELEMENT(schema)
+        .key("missingValue")
+        .displayedName("Missing value default")
+        .description(
+            "If some sources are missing within one group in multi-source stacking*, "
+            "the corresponding parts of the resulting stacked array will be set to "
+            "this value. Note that if no sources are present, the array is not created "
+            "at all. This field is a string to allow special values like float nan / "
+            "inf; it is your responsibility to make sure that data types match (i.e. "
+            "if this is 'nan', the stacked data had better be floats or doubles). "
+            "*Missing value handling is not yet implemented for multi-key stacking."
+        )
+        .assignmentOptional()
+        .defaultValue("0")
+        .reconfigurable()
+        .commit(),
     )
     return schema
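Because `missingValue` arrives as a string, the device has to convert it to the dtype of the stacked data at runtime; the `data.dtype.type(missing_value)` call with a `ValueError` fallback appears further down in this diff. A standalone sketch of that conversion (the function name is invented for illustration) shows why "nan" only works for floating-point data:

import numpy as np


def parse_missing_value(missing_value: str, dtype: np.dtype):
    # np.float32("nan") and np.float64("inf") parse fine, but integer dtypes
    # raise ValueError on "nan" -- hence the fallback to 0 with a warning
    try:
        return dtype.type(missing_value)
    except ValueError:
        return dtype.type(0)


print(parse_missing_value("nan", np.dtype(np.float32)))  # nan
print(parse_missing_value("nan", np.dtype(np.uint16)))   # 0 (fallback)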
@@ -121,147 +138,132 @@ class StackingFriend:
         .commit(),
     )

-    def __init__(self, merge_config, source_config):
-        self._stacking_buffers = {}
+    def __init__(self, device, source_config, merge_config):
         self._source_stacking_indices = {}
-        self._source_stacking_sources = {}
+        self._source_stacking_sources = collections.defaultdict(list)
         self._source_stacking_group_sizes = {}
-        self._key_stacking_sources = {}
-        self._merge_config = Hash()
-        self._source_config = Hash()
+        # (new source name, key) -> {original sources used}
+        self._new_sources_inputs = collections.defaultdict(set)
+        self._key_stacking_sources = collections.defaultdict(list)
+        self._merge_config = None
+        self._source_config = None
+        self._device = device
         self.reconfigure(merge_config, source_config)

     def reconfigure(self, merge_config, source_config):
+        print("merge_config", type(merge_config))
+        print("source_config", type(source_config))
         if merge_config is not None:
-            self._merge_config.merge(merge_config)
+            self._merge_config = merge_config
         if source_config is not None:
-            self._source_config.merge(source_config)
-        # not filtering by row["select"] to allow unselected sources to create gaps
-        source_names = [row["source"].partition("@")[0] for row in self._source_config]
-        self._stacking_buffers.clear()
+            self._source_config = source_config
         self._source_stacking_indices.clear()
         self._source_stacking_sources.clear()
         self._source_stacking_group_sizes.clear()
         self._key_stacking_sources.clear()
-        # split by type, prepare regexes
-        for row in self._merge_config:
-            if not row["select"]:
-                continue
-            group_type = GroupType(row["groupType"])
+        self._new_sources_inputs.clear()
+        # not filtering by row["select"] to allow unselected sources to create gaps
+        source_names = [row["source"].partition("@")[0] for row in self._source_config]
+        source_stacking_groups = [
+            row
+            for row in self._merge_config
+            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
+        ]
+        key_stacking_groups = [
+            row
+            for row in self._merge_config
+            if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
+        ]
+        for row in source_stacking_groups:
             source_re = re.compile(row["sourcePattern"])
             merge_method = MergeMethod(row["mergeMethod"])
-            axis = row["axis"]
-            if group_type is GroupType.MULTISOURCE:
-                key = row["keyPattern"]
-                new_source = row["replacement"]
-                merge_sources = [
-                    source for source in source_names if source_re.match(source)
-                ]
-                if len(merge_sources) == 0:
-                    self.log.WARN(
-                        f"Group pattern {source_re} did not match any known sources"
-                    )
-                    continue
-                self._source_stacking_group_sizes[(new_source, key)] = len(
-                    merge_sources
-                )
-                for i, source in enumerate(merge_sources):
-                    self._source_stacking_sources.setdefault(source, []).append(
-                        (key, new_source, merge_method, axis)
-                    )
-                    self._source_stacking_indices[(source, new_source, key)] = (
-                        i
-                        if merge_method is MergeMethod.STACK
-                        else np.index_exp[
-                            slice(
-                                i,
-                                None,
-                                self._source_stacking_group_sizes[(new_source, key)],
-                            )
-                        ]
-                    )
-            else:
-                key_re = re.compile(row["keyPattern"])
-                new_key = row["replacement"]
-                self._key_stacking_sources.setdefault(source, []).append(
-                    (key_re, new_key, merge_method, axis)
-                )
+            key = row["keyPattern"]
+            new_source = row["replacement"]
+            merge_sources = [
+                source for source in source_names if source_re.match(source)
+            ]
+            if len(merge_sources) == 0:
+                self._device.log.WARN(
+                    f"Group pattern {source_re} did not match any known sources"
+                )
+                continue
+            self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
+            for i, source in enumerate(merge_sources):
+                self._source_stacking_sources[source].append(
+                    (key, new_source, merge_method, row["axis"])
+                )
+                self._source_stacking_indices[(source, new_source, key)] = (
+                    i
+                    if merge_method is MergeMethod.STACK
+                    else np.index_exp[  # interleaving
+                        slice(
+                            i,
+                            None,
+                            len(merge_sources),
+                        )
+                    ]
+                )
-    def prepare_stacking_for_train(
-        self, sources, frame_selection_mask, new_sources_map
-    ):
-        if frame_selection_mask is not None:
-            orig_size = len(frame_selection_mask)
-            result_size = np.sum(frame_selection_mask)
+        for row in key_stacking_groups:
+            key_re = re.compile(row["keyPattern"])
+            new_key = row["replacement"]
+            self._key_stacking_sources[source].append(
+                (key_re, new_key, MergeMethod(row["mergeMethod"]), row["axis"])
+            )
-        stacking_data_shapes = {}
-        self._ignore_stacking = {}
-        self._fill_missed_data = {}
-        for source, keys in self._source_stacking_sources.items():
-            if source not in sources:
-                for key, new_source, _, _ in keys:
-                    missed_sources = self._fill_missed_data.setdefault(
-                        (new_source, key), []
-                    )
-                    merge_index = self._source_stacking_indices[
-                        (source, new_source, key)
-                    ]
-                    missed_sources.append(merge_index)
-                continue
-            data_hash, timestamp = sources[source]
-            filtering = (
-                frame_selection_mask is not None
-                and self._frame_selection_source_pattern.match(source)
-            )
-            for key, new_source, merge_method, axis in keys:
-                merge_data_shape = None
-                if key in data_hash:
-                    merge_data = data_hash[key]
-                    merge_data_shape = merge_data.shape
-                else:
-                    self._ignore_stacking[(new_source, key)] = "Some data is missed"
-                    continue
-                if filtering and key in self._frame_selection_data_keys:
-                    # !!! stacking is not expected to be used with filtering
-                    if merge_data_shape[0] == orig_size:
-                        merge_data_shape = (result_size,) + merge_data.shape[1:]
-                (
-                    expected_shape,
-                    _,
-                    _,
-                    expected_dtype,
-                    _,
-                ) = stacking_data_shapes.setdefault(
-                    (new_source, key),
-                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp),
-                )
-                if (
-                    expected_shape != merge_data_shape
-                    or expected_dtype != merge_data.dtype
-                ):
-                    self._ignore_stacking[
-                        (new_source, key)
-                    ] = "Shape or dtype is inconsistent"
-                    del stacking_data_shapes[(new_source, key)]
-        for (new_source, key), msg in self._ignore_stacking.items():
-            self._device.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = self._source_stacking_group_sizes[(new_source, key)]
-            if merge_method is MergeMethod.STACK:
-                expected_shape = utils.stacking_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
-                )
-            else:
-                expected_shape = utils.interleaving_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
-                )
+    def process(self, sources, thread_pool=None):
+        stacking_buffers = {}
+        new_source_map = collections.defaultdict(Hash)
+        missing_value_defaults = {}
+        # prepare for source stacking where sources are present
+        source_set = set(sources.keys())
+        for (
+            new_source,
+            data_key,
+            merge_method,
+            group_size,
+            axis,
+            missing_value,
+        ), original_sources in self._new_sources_inputs.items():
+            for present_source in source_set & original_sources:
+                data = sources[present_source].get(data_key)[0]
+                if data is None:
+                    continue
+                if merge_method is MergeMethod.STACK:
+                    expected_shape = utils.stacking_buffer_shape(
+                        data.shape, group_size, axis=axis
+                    )
+                else:
+                    expected_shape = utils.interleaving_buffer_shape(
+                        data.shape, group_size, axis=axis
+                    )
+                stacking_buffer = np.empty(
+                    expected_shape, dtype=data.dtype
+                )
+                stacking_buffers[(new_source, data_key)] = stacking_buffer
+                new_source_map[new_source][data_key] = stacking_buffer
+                try:
+                    missing_value_defaults[(new_source, data_key)] = data.dtype.type(
+                        missing_value
+                    )
+                except ValueError:
+                    self._device.log.WARN(
+                        f"Invalid missing data value for {new_source}.{data_key}, using 0"
+                    )
+                break
+            else:
+                # in this case: no present_source (if any) had data_key
+                self._device.log.WARN(
+                    f"No sources needed for {new_source}.{data_key} were present"
+                )
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = parent._source_stacking_group_sizes[(new_source, key)]
-            merge_buffer = self._stacking_buffers.get((new_source, key))
-            if merge_buffer is None or merge_buffer.shape != expected_shape:
-                merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
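`utils.stacking_buffer_shape` and `utils.interleaving_buffer_shape` are imported from elsewhere in the repository and not shown in this diff. A plausible reading, consistent with the `np.index_exp[slice(i, None, group_size)]` indices built in `reconfigure` (a sketch, not the repo's actual implementation; the indices in the diff only handle axis 0):

import numpy as np


def stacking_buffer_shape(shape, group_size, axis=0):
    # STACK inserts a new axis of length group_size: four (16, 512, 128)
    # modules stacked on axis 0 give (4, 16, 512, 128)
    return shape[:axis] + (group_size,) + shape[axis:]


def interleaving_buffer_shape(shape, group_size, axis=0):
    # INTERLEAVE grows the existing axis: four (16, 512, 128) modules
    # interleaved on axis 0 give (64, 512, 128), with module i occupying
    # buffer[i::group_size]
    return shape[:axis] + (shape[axis] * group_size,) + shape[axis + 1:]


buf = np.empty(interleaving_buffer_shape((16, 512, 128), 4), dtype=np.float32)
module = np.ones((16, 512, 128), dtype=np.float32)
buf[np.index_exp[slice(1, None, 4)]] = module  # slot for source index i=1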
@@ -277,19 +279,33 @@ class StackingFriend:
                     slice(
                         merge_index,
                         None,
-                        self._source_stacking_group_sizes[(new_source, key)],
+                        parent._source_stacking_group_sizes[(new_source, key)],
                     )
                 ],
                 axis,
             )
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
+            if new_source not in self.new_source_map:
+                self.new_source_map[new_source] = (Hash(), timestamp)
+            new_source_hash = self.new_source_map[new_source][0]
             if not new_source_hash.has(key):
                 new_source_hash[key] = merge_buffer

-    def handle_source(self, source, data_hash):
-        # stack across sources (many sources, same key)
+        # now actually do some work
+        fun = functools.partial(self._handle_source, ...)
+        if thread_pool is None:
+            for _ in map(fun, ...):
+                pass
+        else:
+            concurrent.futures.wait(thread_pool.map(fun, ...))

+    def _handle_expected_source(
+        self, merge_buffers, missing_values, actual_sources, expected_source
+    ):
+        """Helper function used in processing. Note that it should be called for
+        each source that was supposed to be there, so it can decide whether to
+        move data in for stacking, fill missing data in case a source is missing,
+        or skip in case no buffer was created (none of the necessary sources were
+        present)."""
+        if expected_source not in actual_sources:
+            if ex
+        for (
+            key,
+            new_source,
......
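The diff is cut off by the collapsed view above, so `_handle_expected_source` ends mid-statement. Based on its docstring, the per-source fill logic plausibly looks like the following standalone sketch; the parameter names and the exact control flow are assumptions, not the commit's actual code.

import numpy as np


def handle_expected_source(
    stacking_sources,   # source -> [(key, new_source, merge_method, axis)]
    stacking_indices,   # (source, new_source, key) -> index or slice expression
    merge_buffers,      # (new_source, key) -> np.ndarray stacking buffer
    missing_values,     # (new_source, key) -> scalar fill value
    actual_sources,     # source -> data hash (dict-like) for this train
    expected_source,
):
    """Sketch of the truncated helper, following its docstring (assumed logic)."""
    for key, new_source, merge_method, axis in stacking_sources[expected_source]:
        buffer = merge_buffers.get((new_source, key))
        if buffer is None:
            # skip: no buffer was created because none of the group's
            # sources were present this train
            continue
        index = stacking_indices[(expected_source, new_source, key)]
        if expected_source in actual_sources:
            # move the source's data into its slot of the stacked array
            buffer[index] = actual_sources[expected_source][key]
        else:
            # source missing this train: fill its slot with the default
            buffer[index] = missing_values.get((new_source, key), 0)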