Skip to content
Snippets Groups Projects
Commit c3e1c48f authored by David Hammer
Browse files

WIP: restructure / simplify stacking execution

parent ca46ba88
No related branches found
No related tags found
2 merge requests: !74 "Refactor DetectorAssembler", !73 "Refactor stacking for reuse and overlappability"
...@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
def initialization(self): def initialization(self):
super().initialization() super().initialization()
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver() self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
self._stacking_friend = StackingFriend(self.get("merge"), self.get("sources")) self._stacking_friend = StackingFriend(
self, self.get("merge"), self.get("sources")
)
self._frameselection_friend = FrameselectionFriend(self.get("frameSelector")) self._frameselection_friend = FrameselectionFriend(self.get("frameSelector"))
self._thread_pool = concurrent.futures.ThreadPoolExecutor( self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=self.get("processingThreads") max_workers=self.get("processingThreads")
...@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
def on_matched_data(self, train_id, sources): def on_matched_data(self, train_id, sources):
frame_selection_mask = self._frameselection_friend.get_mask(sources) frame_selection_mask = self._frameselection_friend.get_mask(sources)
# note: should not do stacking and frame selection for now! # note: should not do stacking and frame selection for now!
self._stacking_friend.prepare_stacking_for_train(sources) with self._stacking_friend.stacking_context as stacker:
concurrent.futures.wait(
concurrent.futures.wait( [
[ self._thread_pool.submit(
self._thread_pool.submit( self._handle_source,
self._handle_source, source,
source, data,
data, timestamp,
timestamp, stacker,
new_sources_map, frame_selection_mask,
frame_selection_mask, )
) for source, (data, timestamp) in sources.items()
for source, (data, timestamp) in sources.items() ]
] )
) sources.update(stacker.new_source_map)
sources.update(new_sources_map)
# karabo output # karabo output
if self.output is not None: if self.output is not None:
...@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
source, source,
data_hash, data_hash,
timestamp, timestamp,
new_sources_map, stacker,
frame_selection_mask, frame_selection_mask,
ignore_stacking,
): ):
self._shmem_handler.dereference_shmem_handles(data_hash) self._shmem_handler.dereference_shmem_handles(data_hash)
self._frameselection_friend.apply_mask(source, data_hash, frame_selection_mask) self._frameselection_friend.apply_mask(source, data_hash, frame_selection_mask)
self._stacking_friend.handle_source(...) stacker.process(source, data_hash)
import collections
import enum import enum
import re import re
...@@ -79,6 +80,22 @@ def merge_schema(): ...@@ -79,6 +80,22 @@ def merge_schema():
.defaultValue(0) .defaultValue(0)
.reconfigurable() .reconfigurable()
.commit(), .commit(),
STRING_ELEMENT(schema)
.key("missingValue")
.displayedName("Missing value default")
.description(
"If some sources are missing within one group in multi-source stacking*, "
"the corresponding parts of the resulting stacked array will be set to "
"this value. Note that if no sources are present, the array is not created "
"at all. This field is a string to allow special values like float nan / "
"inf; it is your responsibility to make sure that data types match (i.e. "
"if this is 'nan', the stacked data better be floats or doubles). *Missing "
"value handling is not yet implemented for multi-key stacking."
)
.assignmentOptional()
.defaultValue("0")
.reconfigurable()
.commit(),
) )
return schema return schema
...@@ -121,147 +138,132 @@ class StackingFriend: ...@@ -121,147 +138,132 @@ class StackingFriend:
.commit(), .commit(),
) )
def __init__(self, merge_config, source_config): def __init__(self, device, source_config, merge_config):
self._stacking_buffers = {}
self._source_stacking_indices = {} self._source_stacking_indices = {}
self._source_stacking_sources = {} self._source_stacking_sources = collections.defaultdict(list)
self._source_stacking_group_sizes = {} self._source_stacking_group_sizes = {}
self._key_stacking_sources = {} # (new source name, key) -> {original sources used}
self._merge_config = Hash() self._new_sources_inputs = collections.defaultdict(set)
self._source_config = Hash() self._key_stacking_sources = collections.defaultdict(list)
self._merge_config = None
self._source_config = None
self._device = device
self.reconfigure(merge_config, source_config) self.reconfigure(merge_config, source_config)
def reconfigure(self, merge_config, source_config): def reconfigure(self, merge_config, source_config):
print("merge_config", type(merge_config))
print("source_config", type(source_config))
if merge_config is not None: if merge_config is not None:
self._merge_config.merge(merge_config) self._merge_config = merge_config
if source_config is not None: if source_config is not None:
self._source_config.merge(source_config) self._source_config = source_config
# not filtering by row["select"] to allow unselected sources to create gaps
source_names = [row["source"].partition("@")[0] for row in self._source_config]
self._stacking_buffers.clear()
self._source_stacking_indices.clear() self._source_stacking_indices.clear()
self._source_stacking_sources.clear() self._source_stacking_sources.clear()
self._source_stacking_group_sizes.clear() self._source_stacking_group_sizes.clear()
self._key_stacking_sources.clear() self._key_stacking_sources.clear()
# split by type, prepare regexes self._new_sources_inputs.clear()
for row in self._merge_config:
if not row["select"]: # not filtering by row["select"] to allow unselected sources to create gaps
continue source_names = [row["source"].partition("@")[0] for row in self._source_config]
group_type = GroupType(row["groupType"]) source_stacking_groups = [
row
for row in self._merge_config
if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
]
key_stacking_groups = [
row
for row in self._merge_config
if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
]
for row in source_stacking_groups:
source_re = re.compile(row["sourcePattern"]) source_re = re.compile(row["sourcePattern"])
merge_method = MergeMethod(row["mergeMethod"]) merge_method = MergeMethod(row["mergeMethod"])
axis = row["axis"] key = row["keyPattern"]
if group_type is GroupType.MULTISOURCE: new_source = row["replacement"]
key = row["keyPattern"] merge_sources = [
new_source = row["replacement"] source for source in source_names if source_re.match(source)
merge_sources = [ ]
source for source in source_names if source_re.match(source) if len(merge_sources) == 0:
] self._device.log.WARN(
if len(merge_sources) == 0: f"Group pattern {source_re} did not match any known sources"
self.log.WARN(
f"Group pattern {source_re} did not match any known sources"
)
continue
self._source_stacking_group_sizes[(new_source, key)] = len(
merge_sources
) )
for i, source in enumerate(merge_sources): continue
self._source_stacking_sources.setdefault(source, []).append( self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
(key, new_source, merge_method, axis) for i, source in enumerate(merge_sources):
) self._source_stacking_sources[source].append(
self._source_stacking_indices[(source, new_source, key)] = ( (key, new_source, merge_method, row["axis"])
i )
if merge_method is MergeMethod.STACK self._source_stacking_indices[(source, new_source, key)] = (
else np.index_exp[ i
slice( if merge_method is MergeMethod.STACK
i, else np.index_exp[ # interleaving
None, slice(
self._source_stacking_group_sizes[(new_source, key)], i,
) None,
] len(merge_sources),
) )
else: ]
key_re = re.compile(row["keyPattern"])
new_key = row["replacement"]
self._key_stacking_sources.setdefault(source, []).append(
(key_re, new_key, merge_method, axis)
) )
def prepare_stacking_for_train( for row in key_stacking_groups:
self, sources, frame_selection_mask, new_sources_map key_re = re.compile(row["keyPattern"])
): new_key = row["replacement"]
if frame_selection_mask is not None: self._key_stacking_sources[source].append(
orig_size = len(frame_selection_mask) (key_re, new_key, MergeMethod(row["mergeMethod"]), row["axis"])
result_size = np.sum(frame_selection_mask) )
def process(self, sources, thread_pool=None):
stacking_data_shapes = {} stacking_data_shapes = {}
self._ignore_stacking = {} stacking_buffers = {}
self._fill_missed_data = {} new_source_map = collections.defaultdict(Hash)
for source, keys in self._source_stacking_sources.items(): missing_value_defaults = {}
if source not in sources:
for key, new_source, _, _ in keys: # prepare for source stacking where sources are present
missed_sources = self._fill_missed_data.setdefault( source_set = set(sources.keys())
(new_source, key), [] for (
new_source,
data_key,
merge_method,
group_size,
axis,
missing_value,
), original_sources in self._new_sources_inputs.items():
for present_source in source_set & original_sources:
data = sources[present_source].get(data_key)[0]
if data is None:
continue
if merge_method is MergeMethod.STACK:
expected_shape = utils.stacking_buffer_shape(
data.shape, group_size, axis=axis
) )
merge_index = self._source_stacking_indices[
(source, new_source, key)
]
missed_sources.append(merge_index)
continue
data_hash, timestamp = sources[source]
filtering = (
frame_selection_mask is not None
and self._frame_selection_source_pattern.match(source)
)
for key, new_source, merge_method, axis in keys:
merge_data_shape = None
if key in data_hash:
merge_data = data_hash[key]
merge_data_shape = merge_data.shape
else: else:
self._ignore_stacking[(new_source, key)] = "Some data is missed" expected_shape = utils.interleaving_buffer_shape(
continue data.shape, group_size, axis=axis
)
if filtering and key in self._frame_selection_data_keys: stacking_buffer = np.empty(
# !!! stacking is not expected to be used with filtering expected_shape, dtype=data.dtype
if merge_data_shape[0] == orig_size:
merge_data_shape = (result_size,) + merge_data.shape[1:]
(
expected_shape,
_,
_,
expected_dtype,
_,
) = stacking_data_shapes.setdefault(
(new_source, key),
(merge_data_shape, merge_method, axis, merge_data.dtype, timestamp),
)
if (
expected_shape != merge_data_shape
or expected_dtype != merge_data.dtype
):
self._ignore_stacking[
(new_source, key)
] = "Shape or dtype is inconsistent"
del stacking_data_shapes[(new_source, key)]
for (new_source, key), msg in self._ignore_stacking.items():
self._device.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
for (new_source, key), attr in stacking_data_shapes.items():
merge_data_shape, merge_method, axis, dtype, timestamp = attr
group_size = self._source_stacking_group_sizes[(new_source, key)]
if merge_method is MergeMethod.STACK:
expected_shape = utils.stacking_buffer_shape(
merge_data_shape, group_size, axis=axis
) )
stacking_buffers[(new_source, data_key)] = stacking_buffer
new_source_map[new_source][data_key] = stacking_buffer
try:
missing_value_defaults[(new_source, data_key)] = data.dtype.type(
missing_value
)
except ValueError:
self._device.log.WARN(
f"Invalid missing data value for {new_source}.{data_key}, using 0"
)
break
else: else:
expected_shape = utils.interleaving_buffer_shape( # in this case: no present_source (if any) had data_key
merge_data_shape, group_size, axis=axis self._device.log.WARN(
f"No sources needed for {new_source}.{data_key} were present"
) )
for (new_source, key), attr in stacking_data_shapes.items():
merge_data_shape, merge_method, axis, dtype, timestamp = attr
group_size = parent._source_stacking_group_sizes[(new_source, key)]
merge_buffer = self._stacking_buffers.get((new_source, key)) merge_buffer = self._stacking_buffers.get((new_source, key))
if merge_buffer is None or merge_buffer.shape != expected_shape: if merge_buffer is None or merge_buffer.shape != expected_shape:
merge_buffer = np.empty(shape=expected_shape, dtype=dtype) merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
...@@ -277,19 +279,33 @@ class StackingFriend: ...@@ -277,19 +279,33 @@ class StackingFriend:
slice( slice(
merge_index, merge_index,
None, None,
self._source_stacking_group_sizes[(new_source, key)], parent._source_stacking_group_sizes[(new_source, key)],
) )
], ],
axis, axis,
) )
if new_source not in new_sources_map: if new_source not in self.new_source_map:
new_sources_map[new_source] = (Hash(), timestamp) self.new_source_map[new_source] = (Hash(), timestamp)
new_source_hash = new_sources_map[new_source][0] new_source_hash = self.new_source_map[new_source][0]
if not new_source_hash.has(key): if not new_source_hash.has(key):
new_source_hash[key] = merge_buffer new_source_hash[key] = merge_buffer
def handle_source(self, source, data_hash): # now actually do some work
# stack across sources (many sources, same key) fun = functools.partial(self._handle_source, ...)
if thread_pool is None:
for _ in map(fun, ...):
pass
else:
concurrent.futures.wait(thread_pool.map(fun, ...))
def _handle_expected_source(self, merge_buffers, missing_values, actual_sources, expected_source):
"""Helper function used in processing. Note that it should be called for each
source that was supposed to be there - so it can decide whether to move data in
for stacking, fill missing data in case a source is missing, or skip in case no
buffer was created (none of the necessary sources were present)."""
if expected_source not in actual_sources:
if ex
for ( for (
key, key,
new_source, new_source,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment