diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 19722de729a7cb0c55371c179c66be771f421503..8b7e4dde97a42664ebe3edccef243e22c90eed23 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -1,140 +1,29 @@
 import concurrent.futures
-import enum
 import re
 
 import numpy as np
 from karabo.bound import (
     BOOL_ELEMENT,
-    INT32_ELEMENT,
     KARABO_CLASSINFO,
     NODE_ELEMENT,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
-    TABLE_ELEMENT,
     VECTOR_STRING_ELEMENT,
     ChannelMetaData,
-    Hash,
-    Schema,
     State,
 )
 from TrainMatcher import TrainMatcher
 
-from . import shmem_utils, utils
+from . import shmem_utils
+from .stacking_utils import StackingFriend
 from ._version import version as deviceVersion
 
 
-class GroupType(enum.Enum):
-    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
-    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key
-
-
-class MergeMethod(enum.Enum):
-    STACK = "stack"
-    INTERLEAVE = "interleave"
-
-
-def merge_schema():
-    schema = Schema()
-    (
-        BOOL_ELEMENT(schema)
-        .key("select")
-        .displayedName("Select")
-        .assignmentOptional()
-        .defaultValue(False)
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("sourcePattern")
-        .displayedName("Source pattern")
-        .assignmentOptional()
-        .defaultValue("")
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("keyPattern")
-        .displayedName("Key pattern")
-        .assignmentOptional()
-        .defaultValue("")
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("replacement")
-        .displayedName("Replacement")
-        .assignmentOptional()
-        .defaultValue("")
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("groupType")
-        .displayedName("Group type")
-        .options(",".join(option.value for option in GroupType))
-        .assignmentOptional()
-        .defaultValue(GroupType.MULTISOURCE.value)
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("mergeMethod")
-        .displayedName("Merge method")
-        .options(",".join(option.value for option in MergeMethod))
-        .assignmentOptional()
-        .defaultValue(MergeMethod.STACK.value)
-        .reconfigurable()
-        .commit(),
-
-        INT32_ELEMENT(schema)
-        .key("axis")
-        .displayedName("Axis")
-        .assignmentOptional()
-        .defaultValue(0)
-        .reconfigurable()
-        .commit(),
-    )
-
-    return schema
-
-
 @KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
 class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     @staticmethod
     def expectedParameters(expected):
         (
-            TABLE_ELEMENT(expected)
-            .key("merge")
-            .displayedName("Array stacking")
-            .allowedStates(State.PASSIVE)
-            .description(
-                "Specify which source(s) or key(s) to stack or interleave."
-                "When stacking sources, the 'Source pattern' is interpreted as a "
-                "regular expression and the 'Key pattern' is interpreted as an "
-                "ordinary string. From all sources matching the source pattern, the "
-                "data under this key (should be array with same dimensions across all "
-                "stacked sources) is stacked in the same order as the sources are "
-                "listed in 'Data sources' and the result is under the same key name in "
-                "a new source named by 'Replacement'. "
-                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
-                "are regular expressions. Within each source matching the source "
-                "pattern, all keys matching the key pattern are stacked and the result "
-                "is put under the key named by 'Replacement'. "
-                "While source stacking is optimized and can use thread pool, key "
-                "stacking will iterate over all paths in matched sources and naively "
-                "call np.stack for each key pattern. In either case, data that is used "
-                "for stacking is removed from its original location (e.g. key is "
-                "erased from hash). "
-                "In both cases, the data can alternatively be interleaved. This is "
-                "essentially equivalent to stacking except followed by a reshape such "
-                "that the output is shaped like concatenation."
-            )
-            .setColumns(merge_schema())
-            .assignmentOptional()
-            .defaultValue([])
-            .reconfigurable()
-            .commit(),
-
             # order is important for stacking, disable sorting
             OVERWRITE_ELEMENT(expected)
             .key("sortSources")
@@ -217,6 +106,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             .reconfigurable()
             .commit(),
         )
+        StackingFriend.add_schema(expected)
 
     def __init__(self, config):
         if config.get("useInfiniband", default=True):
@@ -226,14 +116,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         super().__init__(config)
 
     def initialization(self):
-        self._stacking_buffers = {}
-        self._source_stacking_indices = {}
-        self._source_stacking_sources = {}
-        self._source_stacking_group_sizes = {}
-        self._key_stacking_sources = {}
-        self._have_prepared_merge_groups = False
-        self._prepare_merge_groups()
         super().initialization()
+        self._stacking_friend = StackingFriend(
+            self.get("merge"), self.get("sources"), log=self.log
+        )
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
         if self.get("useThreadPool"):
             self._thread_pool = concurrent.futures.ThreadPoolExecutor()
@@ -256,7 +142,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
         if conf.has("merge") or conf.has("sources"):
-            self._have_prepared_merge_groups = False
+            self._stacking_friend.reconfigure(conf.get("merge"), conf.get("sources"))
             # re-prepare in postReconfigure after sources *and* merge are in self
         if conf.has("useThreadPool"):
             if self._thread_pool is not None:
@@ -274,8 +160,6 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
 
     def postReconfigure(self):
         super().postReconfigure()
-        if not self._have_prepared_merge_groups:
-            self._prepare_merge_groups()
         if not self._have_prepared_frame_selection:
             self._prepare_frame_selection()
 
@@ -287,136 +171,14 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         )
         self._frame_selection_data_keys = list(self.get("frameSelector.dataKeys"))
 
-    def _prepare_merge_groups(self):
-        # not filtering by row["select"] to allow unselected sources to create gaps
-        source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
-        self._stacking_buffers.clear()
-        self._source_stacking_indices.clear()
-        self._source_stacking_sources.clear()
-        self._source_stacking_group_sizes.clear()
-        self._key_stacking_sources.clear()
-        # split by type, prepare regexes
-        for row in self.get("merge"):
-            if not row["select"]:
-                continue
-            group_type = GroupType(row["groupType"])
-            source_re = re.compile(row["sourcePattern"])
-            merge_method = MergeMethod(row["mergeMethod"])
-            axis = row["axis"]
-            if group_type is GroupType.MULTISOURCE:
-                key = row["keyPattern"]
-                new_source = row["replacement"]
-                merge_sources = [
-                    source for source in source_names if source_re.match(source)
-                ]
-                if len(merge_sources) == 0:
-                    self.log.WARN(
-                        f"Group pattern {source_re} did not match any known sources"
-                    )
-                    continue
-                for (i, source) in enumerate(merge_sources):
-                    self._source_stacking_sources.setdefault(source, []).append(
-                        (key, new_source, merge_method, axis)
-                    )
-                    self._source_stacking_indices[(source, new_source, key)] = i
-                self._source_stacking_group_sizes[(new_source, key)] = i + 1
-            else:
-                key_re = re.compile(row["keyPattern"])
-                new_key = row["replacement"]
-                self._key_stacking_sources.setdefault(source, []).append(
-                    (key_re, new_key, merge_method, axis)
-                )
-
-        self._have_prepared_merge_groups = True
-
-    def _check_stacking_data(self, sources, frame_selection_mask):
-        if frame_selection_mask is not None:
-            orig_size = len(frame_selection_mask)
-            result_size = np.sum(frame_selection_mask)
-        stacking_data_shapes = {}
-        ignore_stacking = {}
-        fill_missed_data = {}
-        for source, keys in self._source_stacking_sources.items():
-            if source not in sources:
-                for key, new_source, _, _ in keys:
-                    missed_sources = fill_missed_data.setdefault((new_source, key), [])
-                    merge_index = self._source_stacking_indices[
-                        (source, new_source, key)
-                    ]
-                    missed_sources.append(merge_index)
-                continue
-            data_hash, timestamp = sources[source]
-            filtering = (
-                frame_selection_mask is not None and
-                self._frame_selection_source_pattern.match(source)
-            )
-            for key, new_source, merge_method, axis in keys:
-                merge_data_shape = None
-                if key in data_hash:
-                    merge_data = data_hash[key]
-                    merge_data_shape = merge_data.shape
-                else:
-                    ignore_stacking[(new_source, key)] = "Some data is missed"
-                    continue
-
-                if filtering and key in self._frame_selection_data_keys:
-                    # !!! stacking is not expected to be used with filtering
-                    if merge_data_shape[0] == orig_size:
-                        merge_data_shape = (result_size,) + merge_data.shape[1:]
-
-                expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
-                    (new_source, key),
-                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp)
-                )
-                if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
-                    ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
-                    del stacking_data_shapes[(new_source, key)]
-
-        return stacking_data_shapes, ignore_stacking, fill_missed_data
-
-    def _maybe_update_stacking_buffers(
-        self, stacking_data_shapes, fill_missed_data, new_sources_map
-    ):
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = self._source_stacking_group_sizes[(new_source, key)]
-            if merge_method is MergeMethod.STACK:
-                expected_shape = utils.stacking_buffer_shape(
-                    merge_data_shape, group_size, axis=axis)
-            else:
-                expected_shape = utils.interleaving_buffer_shape(
-                    merge_data_shape, group_size, axis=axis)
-
-            merge_buffer = self._stacking_buffers.get((new_source, key))
-            if merge_buffer is None or merge_buffer.shape != expected_shape:
-                merge_buffer = np.empty(
-                    shape=expected_shape, dtype=dtype)
-                self._stacking_buffers[(new_source, key)] = merge_buffer
-
-            for merge_index in fill_missed_data.get((new_source, key), []):
-                utils.set_on_axis(
-                    merge_buffer,
-                    0,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
-
     def _handle_source(
-        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask,
-        ignore_stacking
+        self,
+        source,
+        data_hash,
+        timestamp,
+        new_sources_map,
+        frame_selection_mask,
+        ignore_stacking,
     ):
         # dereference calng shmem handles
         self._shmem_handler.dereference_shmem_handles(data_hash)
@@ -434,68 +196,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
 
                 data_hash[key] = data_hash[key][frame_selection_mask]
 
-        # stack across sources (many sources, same key)
-        for (
-            key,
-            new_source,
-            merge_method,
-            axis,
-        ) in self._source_stacking_sources.get(source, ()):
-            if (new_source, key) in ignore_stacking:
-                continue
-            merge_data = data_hash.get(key)
-            merge_index = self._source_stacking_indices[
-                (source, new_source, key)
-            ]
-            merge_buffer = self._stacking_buffers[(new_source, key)]
-            utils.set_on_axis(
-                merge_buffer,
-                merge_data,
-                merge_index
-                if merge_method is MergeMethod.STACK
-                else np.index_exp[
-                    slice(
-                        merge_index,
-                        None,
-                        self._source_stacking_group_sizes[(new_source, key)],
-                    )
-                ],
-                axis,
-            )
-            data_hash.erase(key)
-
-        # stack keys (multiple keys within this source)
-        for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
-            source, ()
-        ):
-            # note: please no overlap between different key_re
-            # note: if later key_re match earlier new_key, this gets spicy
-            keys = [key for key in data_hash.paths() if key_re.match(key)]
-            try:
-                # note: maybe we could reuse buffers here, too?
-                if merge_method is MergeMethod.STACK:
-                    stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
-                else:
-                    first = data_hash.get(keys[0])
-                    stacked = np.empty(
-                        shape=utils.stacking_buffer_shape(
-                            first.shape, len(keys), axis=axis
-                        ),
-                        dtype=first.dtype,
-                    )
-                    for i, key in enumerate(keys):
-                        utils.set_on_axis(
-                            stacked,
-                            data_hash.get(key),
-                            np.index_exp[slice(i, None, len(keys))],
-                            axis,
-                        )
-            except Exception as e:
-                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
-            else:
-                for key in keys:
-                    data_hash.erase(key)
-                data_hash[new_key] = stacked
+        self._stacking_friend.handle_source(source, data_hash, ignore_stacking)
 
     def on_matched_data(self, train_id, sources):
         new_sources_map = {}
@@ -508,11 +209,20 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             ).astype(np.bool, copy=False),
 
         # prepare stacking
-        stacking_data_shapes, ignore_stacking, fill_missed_data = (
-            self._check_stacking_data(sources, frame_selection_mask)
+        (
+            stacking_data_shapes,
+            ignore_stacking,
+            fill_missed_data,
+        ) = self._stacking_friend.check_stacking_data(
+            sources,
+            frame_selection_mask,
+            self._frame_selection_source_pattern,
+            self._frame_selection_data_keys,
+        )
+        self._stacking_friend.maybe_update_stacking_buffers(
+            stacking_data_shapes, fill_missed_data, new_sources_map
         )
-        self._maybe_update_stacking_buffers(
-            stacking_data_shapes, fill_missed_data, new_sources_map)
+
         for (new_source, key), msg in ignore_stacking.items():
             self.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
 
diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8f453bf1c5febaaef10bff5e62ce9bf474f1631
--- /dev/null
+++ b/src/calng/stacking_utils.py
@@ -0,0 +1,398 @@
+import enum
+import logging
+import re
+
+import numpy as np
+from karabo.bound import (
+    BOOL_ELEMENT,
+    INT32_ELEMENT,
+    STRING_ELEMENT,
+    TABLE_ELEMENT,
+    Hash,
+    Schema,
+    State,
+)
+
+from . import utils
+
+
+class GroupType(enum.Enum):
+    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
+    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key
+
+
+class MergeMethod(enum.Enum):
+    STACK = "stack"
+    INTERLEAVE = "interleave"
+
+
+def merge_schema():
+    """Build the row schema (columns) of the "merge" table element."""
+    schema = Schema()
+    (
+        BOOL_ELEMENT(schema)
+        .key("select")
+        .displayedName("Select")
+        .assignmentOptional()
+        .defaultValue(False)
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("sourcePattern")
+        .displayedName("Source pattern")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("keyPattern")
+        .displayedName("Key pattern")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("replacement")
+        .displayedName("Replacement")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("groupType")
+        .displayedName("Group type")
+        .options(",".join(option.value for option in GroupType))
+        .assignmentOptional()
+        .defaultValue(GroupType.MULTISOURCE.value)
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("mergeMethod")
+        .displayedName("Merge method")
+        .options(",".join(option.value for option in MergeMethod))
+        .assignmentOptional()
+        .defaultValue(MergeMethod.STACK.value)
+        .reconfigurable()
+        .commit(),
+
+        INT32_ELEMENT(schema)
+        .key("axis")
+        .displayedName("Axis")
+        .assignmentOptional()
+        .defaultValue(0)
+        .reconfigurable()
+        .commit(),
+    )
+
+    return schema
+
+
+class StackingFriend:
+    """Array stacking and interleaving logic, separated from the device class.
+
+    Keeps the lookup tables derived from the "merge" and "sources" tables and
+    the reusable output buffers used when stacking across sources.
+    """
+
+    @staticmethod
+    def add_schema(schema):
+        """Add the "merge" table element to the given schema."""
+        (
+            TABLE_ELEMENT(schema)
+            .key("merge")
+            .displayedName("Array stacking")
+            .allowedStates(State.PASSIVE)
+            .description(
+                "Specify which source(s) or key(s) to stack or interleave. "
+                "When stacking sources, the 'Source pattern' is interpreted as a "
+                "regular expression and the 'Key pattern' is interpreted as an "
+                "ordinary string. From all sources matching the source pattern, the "
+                "data under this key (should be array with same dimensions across all "
+                "stacked sources) is stacked in the same order as the sources are "
+                "listed in 'Data sources' and the result is under the same key name in "
+                "a new source named by 'Replacement'. "
+                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
+                "are regular expressions. Within each source matching the source "
+                "pattern, all keys matching the key pattern are stacked and the result "
+                "is put under the key named by 'Replacement'. "
+                "While source stacking is optimized and can use thread pool, key "
+                "stacking will iterate over all paths in matched sources and naively "
+                "call np.stack for each key pattern. In either case, data that is used "
+                "for stacking is removed from its original location (e.g. key is "
+                "erased from hash). "
+                "In both cases, the data can alternatively be interleaved. This is "
+                "essentially equivalent to stacking except followed by a reshape such "
+                "that the output is shaped like concatenation."
+            )
+            .setColumns(merge_schema())
+            .assignmentOptional()
+            .defaultValue([])
+            .reconfigurable()
+            .commit(),
+        )
+
+    def __init__(self, merge_config, source_config, log=None):
+        """Create empty lookup tables and parse the initial configuration.
+
+        merge_config and source_config are the rows of the "merge" and
+        "sources" tables. log is the logger of the owning device (only its
+        WARN method is used); without it, warnings go to the logging module.
+        """
+        self._stacking_buffers = {}
+        self._source_stacking_indices = {}
+        self._source_stacking_sources = {}
+        self._source_stacking_group_sizes = {}
+        self._key_stacking_sources = {}
+        self._merge_config = []
+        self._source_config = []
+        self._log = log
+        self.reconfigure(merge_config, source_config)
+
+    def _warn(self, message):
+        """Send a warning to the device log if there is one, else to logging."""
+        if self._log is not None:
+            self._log.WARN(message)
+        else:
+            logging.getLogger(__name__).warning(message)
+
+    def reconfigure(self, merge_config, source_config):
+        """Rebuild the stacking lookup tables and drop the cached buffers.
+
+        Either argument may be None, meaning that table did not change since
+        the previous call.
+        """
+        # table values are lists of rows: replace them wholesale
+        if merge_config is not None:
+            self._merge_config = list(merge_config)
+        if source_config is not None:
+            self._source_config = list(source_config)
+
+        # not filtering by row["select"] to allow unselected sources to create gaps
+        source_names = [row["source"].partition("@")[0] for row in self._source_config]
+        self._stacking_buffers.clear()
+        self._source_stacking_indices.clear()
+        self._source_stacking_sources.clear()
+        self._source_stacking_group_sizes.clear()
+        self._key_stacking_sources.clear()
+        # split by type, prepare regexes
+        for row in self._merge_config:
+            if not row["select"]:
+                continue
+            group_type = GroupType(row["groupType"])
+            source_re = re.compile(row["sourcePattern"])
+            merge_method = MergeMethod(row["mergeMethod"])
+            axis = row["axis"]
+            if group_type is GroupType.MULTISOURCE:
+                key = row["keyPattern"]
+                new_source = row["replacement"]
+                merge_sources = [
+                    source for source in source_names if source_re.match(source)
+                ]
+                if len(merge_sources) == 0:
+                    self._warn(
+                        f"Group pattern {source_re} did not match any known sources"
+                    )
+                    continue
+                for (i, source) in enumerate(merge_sources):
+                    self._source_stacking_sources.setdefault(source, []).append(
+                        (key, new_source, merge_method, axis)
+                    )
+                    self._source_stacking_indices[(source, new_source, key)] = i
+                self._source_stacking_group_sizes[(new_source, key)] = i + 1
+            else:
+                key_re = re.compile(row["keyPattern"])
+                new_key = row["replacement"]
+                # the rule applies to every source matching the source pattern
+                for source in source_names:
+                    if not source_re.match(source):
+                        continue
+                    self._key_stacking_sources.setdefault(source, []).append(
+                        (key_re, new_key, merge_method, axis)
+                    )
+
+    def check_stacking_data(
+        self,
+        sources,
+        frame_selection_mask,
+        frame_selection_source_pattern=None,
+        frame_selection_data_keys=(),
+    ):
+        """Work out which source stacking groups can be filled for this train.
+
+        Returns (stacking_data_shapes, ignore_stacking, fill_missed_data):
+        buffer parameters per (new source, key) group, the groups to skip with
+        the reason why, and per group the indices of sources missing from this
+        train. The frame selection arguments tell which sources and keys the
+        caller filters with frame_selection_mask before stacking.
+        """
+        if frame_selection_mask is not None:
+            orig_size = len(frame_selection_mask)
+            result_size = np.sum(frame_selection_mask)
+        stacking_data_shapes = {}
+        ignore_stacking = {}
+        fill_missed_data = {}
+        for source, keys in self._source_stacking_sources.items():
+            if source not in sources:
+                for key, new_source, _, _ in keys:
+                    missed_sources = fill_missed_data.setdefault((new_source, key), [])
+                    merge_index = self._source_stacking_indices[
+                        (source, new_source, key)
+                    ]
+                    missed_sources.append(merge_index)
+                continue
+            data_hash, timestamp = sources[source]
+            filtering = (
+                frame_selection_mask is not None
+                and frame_selection_source_pattern is not None
+                and frame_selection_source_pattern.match(source)
+            )
+            for key, new_source, merge_method, axis in keys:
+                if key in data_hash:
+                    merge_data = data_hash[key]
+                    merge_data_shape = merge_data.shape
+                else:
+                    ignore_stacking[(new_source, key)] = "Some data is missed"
+                    continue
+
+                if filtering and key in frame_selection_data_keys:
+                    # !!! stacking is not expected to be used with filtering
+                    if merge_data_shape[0] == orig_size:
+                        merge_data_shape = (result_size,) + merge_data.shape[1:]
+
+                expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
+                    (new_source, key),
+                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp)
+                )
+                if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
+                    ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
+
+        # ignored groups must not get a buffer: nothing would be written to it
+        for group in ignore_stacking:
+            stacking_data_shapes.pop(group, None)
+
+        return stacking_data_shapes, ignore_stacking, fill_missed_data
+
+    def maybe_update_stacking_buffers(
+        self, stacking_data_shapes, fill_missed_data, new_sources_map
+    ):
+        """Make sure each stacking group has a buffer and publish it.
+
+        Buffers are reallocated when the required shape or dtype changed, slots
+        of sources listed in fill_missed_data are zeroed, and each buffer is put
+        into new_sources_map (modified in place) under its new source and key.
+        """
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = self._source_stacking_group_sizes[(new_source, key)]
+            if merge_method is MergeMethod.STACK:
+                expected_shape = utils.stacking_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+            else:
+                expected_shape = utils.interleaving_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+
+            merge_buffer = self._stacking_buffers.get((new_source, key))
+            if (
+                merge_buffer is None
+                or merge_buffer.shape != expected_shape
+                or merge_buffer.dtype != dtype
+            ):
+                merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
+                self._stacking_buffers[(new_source, key)] = merge_buffer
+
+            for merge_index in fill_missed_data.get((new_source, key), []):
+                utils.set_on_axis(
+                    merge_buffer,
+                    0,
+                    merge_index
+                    if merge_method is MergeMethod.STACK
+                    else np.index_exp[
+                        slice(
+                            merge_index,
+                            None,
+                            self._source_stacking_group_sizes[(new_source, key)],
+                        )
+                    ],
+                    axis,
+                )
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(key):
+                new_source_hash[key] = merge_buffer
+
+    def handle_source(self, source, data_hash, ignore_stacking):
+        """Apply all stacking rules to the data of one source, in place.
+
+        The source's contribution to each source stacking group is copied into
+        the group buffer (unless the group is in ignore_stacking), then the key
+        stacking rules are applied within data_hash. Data used for stacking is
+        erased from data_hash.
+        """
+        # stack across sources (many sources, same key)
+        for (
+            key,
+            new_source,
+            merge_method,
+            axis,
+        ) in self._source_stacking_sources.get(source, ()):
+            if (new_source, key) in ignore_stacking:
+                continue
+            merge_data = data_hash.get(key)
+            merge_index = self._source_stacking_indices[
+                (source, new_source, key)
+            ]
+            merge_buffer = self._stacking_buffers[(new_source, key)]
+            utils.set_on_axis(
+                merge_buffer,
+                merge_data,
+                merge_index
+                if merge_method is MergeMethod.STACK
+                else np.index_exp[
+                    slice(
+                        merge_index,
+                        None,
+                        self._source_stacking_group_sizes[(new_source, key)],
+                    )
+                ],
+                axis,
+            )
+            data_hash.erase(key)
+
+        # stack keys (multiple keys within this source)
+        for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
+            source, ()
+        ):
+            # note: please no overlap between different key_re
+            # note: if later key_re match earlier new_key, this gets spicy
+            keys = [key for key in data_hash.paths() if key_re.match(key)]
+            try:
+                # note: maybe we could reuse buffers here, too?
+                if merge_method is MergeMethod.STACK:
+                    stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
+                else:
+                    first = data_hash.get(keys[0])
+                    stacked = np.empty(
+                        shape=utils.interleaving_buffer_shape(
+                            first.shape, len(keys), axis=axis
+                        ),
+                        dtype=first.dtype,
+                    )
+                    for i, key in enumerate(keys):
+                        utils.set_on_axis(
+                            stacked,
+                            data_hash.get(key),
+                            np.index_exp[slice(i, None, len(keys))],
+                            axis,
+                        )
+            except Exception as e:
+                self._warn(f"Failed to stack {key_re} for {source}: {e}")
+            else:
+                for key in keys:
+                    data_hash.erase(key)
+                data_hash[new_key] = stacked