diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 72bab416cef4b49724935e44815745644d665c7c..301b8d040b0fac2712d89429316b5e66ee5f864f 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def initialization(self):
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
-        self._stacking_friend = StackingFriend(self.get("merge"), self.get("sources"))
+        self._stacking_friend = StackingFriend(
+            self, self.get("sources"), self.get("merge")
+        )
         self._frameselection_friend = FrameselectionFriend(self.get("frameSelector"))
         self._thread_pool = concurrent.futures.ThreadPoolExecutor(
             max_workers=self.get("processingThreads")
@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def on_matched_data(self, train_id, sources):
         frame_selection_mask = self._frameselection_friend.get_mask(sources)
         # note: should not do stacking and frame selection for now!
-        self._stacking_friend.prepare_stacking_for_train(sources)
-
-        concurrent.futures.wait(
-            [
-                self._thread_pool.submit(
-                    self._handle_source,
-                    source,
-                    data,
-                    timestamp,
-                    new_sources_map,
-                    frame_selection_mask,
-                )
-                for source, (data, timestamp) in sources.items()
-            ]
-        )
-        sources.update(new_sources_map)
+        with self._stacking_friend.stacking_context as stacker:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source,
+                        source,
+                        data,
+                        timestamp,
+                        stacker,
+                        frame_selection_mask,
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
+            sources.update(stacker.new_source_map)

         # karabo output
         if self.output is not None:
@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         source,
         data_hash,
         timestamp,
-        new_sources_map,
+        stacker,
         frame_selection_mask,
-        ignore_stacking,
     ):
         self._shmem_handler.dereference_shmem_handles(data_hash)
         self._frameselection_friend.apply_mask(source, data_hash, frame_selection_mask)
-        self._stacking_friend.handle_source(...)
+        stacker.process(source, data_hash)
diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
index b6d21a3d6d99a990acc82af28c7ea9735605c470..049480d4385311384c1f1159a9e5d2be727de79b 100644
--- a/src/calng/stacking_utils.py
+++ b/src/calng/stacking_utils.py
@@ -1,3 +1,6 @@
+import collections
+import concurrent.futures
 import enum
+import functools
 import re

@@ -79,6 +80,22 @@ def merge_schema():
         .defaultValue(0)
         .reconfigurable()
         .commit(),
+        STRING_ELEMENT(schema)
+        .key("missingValue")
+        .displayedName("Missing value default")
+        .description(
+            "If some sources are missing within one group in multi-source stacking*, "
+            "the corresponding parts of the resulting stacked array will be set to "
+            "this value. Note that if no sources are present, the array is not created "
+            "at all. This field is a string to allow special values like float nan / "
+            "inf; it is your responsibility to make sure that data types match (i.e. "
+            "if this is 'nan', the stacked data must be floats or doubles). *Missing "
+            "value handling is not yet implemented for multi-key stacking."
+        )
+        .assignmentOptional()
+        .defaultValue("0")
+        .reconfigurable()
+        .commit(),
     )

     return schema
@@ -121,147 +138,132 @@ class StackingFriend:
         .commit(),
     )

-    def __init__(self, merge_config, source_config):
-        self._stacking_buffers = {}
+    def __init__(self, device, source_config, merge_config):
         self._source_stacking_indices = {}
-        self._source_stacking_sources = {}
+        self._source_stacking_sources = collections.defaultdict(list)
         self._source_stacking_group_sizes = {}
-        self._key_stacking_sources = {}
-        self._merge_config = Hash()
-        self._source_config = Hash()
+        # (new source name, key) -> {original sources used}
+        self._new_sources_inputs = collections.defaultdict(set)
+        self._key_stacking_sources = collections.defaultdict(list)
+        self._merge_config = None
+        self._source_config = None
+        self._device = device
        self.reconfigure(merge_config, source_config)

     def reconfigure(self, merge_config, source_config):
         if merge_config is not None:
-            self._merge_config.merge(merge_config)
+            self._merge_config = merge_config
         if source_config is not None:
-            self._source_config.merge(source_config)
-
-        # not filtering by row["select"] to allow unselected sources to create gaps
-        source_names = [row["source"].partition("@")[0] for row in self._source_config]
-        self._stacking_buffers.clear()
+            self._source_config = source_config
         self._source_stacking_indices.clear()
         self._source_stacking_sources.clear()
         self._source_stacking_group_sizes.clear()
         self._key_stacking_sources.clear()
-        # split by type, prepare regexes
-        for row in self._merge_config:
-            if not row["select"]:
-                continue
-            group_type = GroupType(row["groupType"])
+        self._new_sources_inputs.clear()
+
+        # not filtering by row["select"] to allow unselected sources to create gaps
+        source_names = [row["source"].partition("@")[0] for row in self._source_config]
+        source_stacking_groups = [
+            row
+            for row in self._merge_config
+            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
+        ]
+        key_stacking_groups = [
+            row
+            for row in self._merge_config
+            if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
+        ]
+
+        for row in source_stacking_groups:
             source_re = re.compile(row["sourcePattern"])
             merge_method = MergeMethod(row["mergeMethod"])
-            axis = row["axis"]
-            if group_type is GroupType.MULTISOURCE:
-                key = row["keyPattern"]
-                new_source = row["replacement"]
-                merge_sources = [
-                    source for source in source_names if source_re.match(source)
-                ]
-                if len(merge_sources) == 0:
-                    self.log.WARN(
-                        f"Group pattern {source_re} did not match any known sources"
-                    )
-                    continue
-                self._source_stacking_group_sizes[(new_source, key)] = len(
-                    merge_sources
-                )
+            key = row["keyPattern"]
+            new_source = row["replacement"]
+            merge_sources = [
+                source for source in source_names if source_re.match(source)
+            ]
+            if len(merge_sources) == 0:
+                self._device.log.WARN(
+                    f"Group pattern {source_re} did not match any known sources"
+                )
+                continue
+            self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
+            for i, source in enumerate(merge_sources):
+                self._source_stacking_sources[source].append(
+                    (key, new_source, merge_method, row["axis"])
+                )
+                self._source_stacking_indices[(source, new_source, key)] = (
+                    i
+                    if merge_method is MergeMethod.STACK
+                    else np.index_exp[  # interleaving
+                        slice(
+                            i,
+                            None,
+                            len(merge_sources),
+                        )
+                    ]
+                )
-                for i, source in enumerate(merge_sources):
-                    self._source_stacking_sources.setdefault(source, []).append(
-                        (key, new_source, merge_method, axis)
-                    )
-                    self._source_stacking_indices[(source, new_source, key)] = (
-                        i
-                        if merge_method is MergeMethod.STACK
-                        else np.index_exp[
-                            slice(
-                                i,
-                                None,
-                                self._source_stacking_group_sizes[(new_source, key)],
-                            )
-                        ]
-                    )
-            else:
-                key_re = re.compile(row["keyPattern"])
-                new_key = row["replacement"]
-                self._key_stacking_sources.setdefault(source, []).append(
-                    (key_re, new_key, merge_method, axis)
-                )
-
-    def prepare_stacking_for_train(
-        self, sources, frame_selection_mask, new_sources_map
-    ):
-        if frame_selection_mask is not None:
-            orig_size = len(frame_selection_mask)
-            result_size = np.sum(frame_selection_mask)
+        for row in key_stacking_groups:
+            source_re = re.compile(row["sourcePattern"])
+            key_re = re.compile(row["keyPattern"])
+            new_key = row["replacement"]
+            for source in source_names:
+                if not source_re.match(source):
+                    continue
+                self._key_stacking_sources[source].append(
+                    (key_re, new_key, MergeMethod(row["mergeMethod"]), row["axis"])
+                )

+    def process(self, sources, thread_pool=None):
         stacking_data_shapes = {}
-        self._ignore_stacking = {}
-        self._fill_missed_data = {}
-        for source, keys in self._source_stacking_sources.items():
-            if source not in sources:
-                for key, new_source, _, _ in keys:
-                    missed_sources = self._fill_missed_data.setdefault(
-                        (new_source, key), []
-                    )
-                    merge_index = self._source_stacking_indices[
-                        (source, new_source, key)
-                    ]
-                    missed_sources.append(merge_index)
-                continue
-            data_hash, timestamp = sources[source]
-            filtering = (
-                frame_selection_mask is not None
-                and self._frame_selection_source_pattern.match(source)
-            )
-            for key, new_source, merge_method, axis in keys:
-                merge_data_shape = None
-                if key in data_hash:
-                    merge_data = data_hash[key]
-                    merge_data_shape = merge_data.shape
-                else:
-                    self._ignore_stacking[(new_source, key)] = "Some data is missed"
-                    continue
-
+        stacking_buffers = {}
+        new_source_map = collections.defaultdict(Hash)
+        missing_value_defaults = {}
+
+        # prepare for source stacking where sources are present
+        source_set = set(sources.keys())
+        for (
+            new_source,
+            data_key,
+            merge_method,
+            group_size,
+            axis,
+            missing_value,
+        ), original_sources in self._new_sources_inputs.items():
+            for present_source in source_set & original_sources:
+                data = sources[present_source][0].get(data_key)
+                if data is None:
+                    continue
+                if merge_method is MergeMethod.STACK:
+                    expected_shape = utils.stacking_buffer_shape(
+                        data.shape, group_size, axis=axis
+                    )
-                if filtering and key in self._frame_selection_data_keys:
-                    # !!! stacking is not expected to be used with filtering
-                    if merge_data_shape[0] == orig_size:
-                        merge_data_shape = (result_size,) + merge_data.shape[1:]
-
-                (
-                    expected_shape,
-                    _,
-                    _,
-                    expected_dtype,
-                    _,
-                ) = stacking_data_shapes.setdefault(
-                    (new_source, key),
-                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp),
-                )
-                if (
-                    expected_shape != merge_data_shape
-                    or expected_dtype != merge_data.dtype
-                ):
-                    self._ignore_stacking[
-                        (new_source, key)
-                    ] = "Shape or dtype is inconsistent"
-                    del stacking_data_shapes[(new_source, key)]
-
-        for (new_source, key), msg in self._ignore_stacking.items():
-            self._device.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
-
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = self._source_stacking_group_sizes[(new_source, key)]
-            if merge_method is MergeMethod.STACK:
-                expected_shape = utils.stacking_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
-                )
+                else:
+                    expected_shape = utils.interleaving_buffer_shape(
+                        data.shape, group_size, axis=axis
+                    )
+                stacking_buffer = np.empty(
+                    expected_shape, dtype=data.dtype
+                )
+                stacking_buffers[(new_source, data_key)] = stacking_buffer
+                new_source_map[new_source][data_key] = stacking_buffer
+                try:
+                    missing_value_defaults[(new_source, data_key)] = data.dtype.type(
+                        missing_value
+                    )
+                except ValueError:
+                    self._device.log.WARN(
+                        f"Invalid missing data value for {new_source}.{data_key}, "
+                        "using 0"
+                    )
+                break
             else:
-                expected_shape = utils.interleaving_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
-                )
+                # in this case: no present_source (if any) had data_key
+                self._device.log.WARN(
+                    f"No sources needed for {new_source}.{data_key} were present"
+                )
+
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = parent._source_stacking_group_sizes[(new_source, key)]
             merge_buffer = self._stacking_buffers.get((new_source, key))
             if merge_buffer is None or merge_buffer.shape != expected_shape:
                 merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
@@ -277,19 +279,33 @@ class StackingFriend:
                             slice(
                                 merge_index,
                                 None,
-                                self._source_stacking_group_sizes[(new_source, key)],
+                                parent._source_stacking_group_sizes[(new_source, key)],
                             )
                         ],
                         axis,
                     )
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
+            if new_source not in self.new_source_map:
+                self.new_source_map[new_source] = (Hash(), timestamp)
+            new_source_hash = self.new_source_map[new_source][0]
             if not new_source_hash.has(key):
                 new_source_hash[key] = merge_buffer

-    def handle_source(self, source, data_hash):
-        # stack across sources (many sources, same key)
+        # now actually do some work
+        fun = functools.partial(self._handle_expected_source, ...)
+        if thread_pool is None:
+            for _ in map(fun, ...):
+                pass
+        else:
+            concurrent.futures.wait(thread_pool.map(fun, ...))
+
+    def _handle_expected_source(
+        self, merge_buffers, missing_values, actual_sources, expected_source
+    ):
+        """Helper function used in processing. Note that it should be called for each
+        source that was supposed to be there - so it can decide whether to move data
+        in for stacking, fill missing data in case a source is missing, or skip in
+        case no buffer was created (none of the necessary sources were present)."""
+
+        if expected_source not in actual_sources:
+            if ex
+            for (
+                key,
+                new_source,
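
A note on the two merge methods wired up in reconfigure() above: for MergeMethod.STACK the stored index is the plain integer slot i, while the interleaving branch stores np.index_exp[slice(i, None, group_size)]. A minimal standalone sketch of what those indices do to the output buffer (plain numpy; the shapes, axis-0 layout, and three-source group are illustrative assumptions, not values from this patch):

    import numpy as np

    group_size = 3  # sources in one stacking group
    per_source = [np.full((4, 2), i) for i in range(group_size)]

    # MergeMethod.STACK: source i occupies slot i along a new leading axis,
    # so the buffer shape is (group_size,) + data.shape
    stacked = np.empty((group_size, 4, 2), dtype=per_source[0].dtype)
    for i, data in enumerate(per_source):
        stacked[i] = data

    # Interleaving: source i takes every group_size-th frame starting at
    # offset i, i.e. the np.index_exp[slice(i, None, group_size)] entry kept
    # in _source_stacking_indices; buffer shape is
    # (group_size * n_frames,) + data.shape[1:]
    interleaved = np.empty((group_size * 4, 2), dtype=per_source[0].dtype)
    for i, data in enumerate(per_source):
        interleaved[np.index_exp[slice(i, None, group_size)]] = data

    assert (interleaved[0::group_size] == per_source[0]).all()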
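The new missingValue field leans on numpy dtype scalar constructors accepting strings, which is what makes the data.dtype.type(missing_value) call and its ValueError fallback in process() work. A sketch of that idiom in isolation (buffer shape and values are made up):

    import numpy as np

    buffer = np.empty((3, 4), dtype=np.float32)  # one slot per source in a group

    # The schema keeps the default as a string so "nan"/"inf" stay expressible;
    # np.float32("nan") parses fine, while np.int32("nan") raises ValueError.
    missing_value = "nan"
    try:
        fill = buffer.dtype.type(missing_value)
    except ValueError:
        fill = buffer.dtype.type(0)  # same fallback as the WARN branch above

    buffer[1] = fill  # slot of a source that never arrived this train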
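For context, the fan-out in on_matched_data (submit one task per source, then block so the train is complete before output is sent) is the standard executor pattern; a self-contained sketch with placeholder source names and a dummy work function:

    import concurrent.futures

    def handle(source, payload):
        return source, len(payload)  # stand-in for per-source processing

    sources = {"det/mod0": b"...", "det/mod1": b"...."}
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(handle, src, data) for src, data in sources.items()]
        concurrent.futures.wait(futures)  # train fully processed past this point
        results = dict(f.result() for f in futures)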