diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index faac119562c1a1299b2635cda158e5e7cba302e6..607ab5f2cb400e9caddfee7a841fb8703c0f3017 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -1,3 +1,4 @@
+import enum
 import re

 import numpy as np
@@ -16,6 +17,11 @@ from . import shmem_utils
 from ._version import version as deviceVersion


+class MergeGroupType(enum.Enum):
+    MULTISOURCE = "sources"  # same key from multiple sources, stacked into a new source
+    MULTIKEY = "keys"  # multiple keys within each matched source, stacked into a new key
+
+
 def merge_schema():
     schema = Schema()
     (
@@ -50,6 +56,15 @@ def merge_schema():
         .defaultValue("")
         .reconfigurable()
         .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("type")
+        .displayedName("Group type")
+        .options(",".join(option.value for option in MergeGroupType))
+        .assignmentOptional()
+        .defaultValue(MergeGroupType.MULTISOURCE.value)
+        .reconfigurable()
+        .commit(),
     )

     return schema
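For reference, the new "type" column selects between the two behaviours declared in MergeGroupType. A sketch of plausible merge-table rows, using the field names read by _prepare_merge_groups below; the source and key names are made up for illustration, and note that for "sources" groups the key_pattern is used as a literal key rather than a regex:

    from karabo.bound import Hash

    merge_rows = [
        # "sources" group: image.data from every source matching source_pattern
        # is stacked along a new leading axis under the new source name
        Hash("select", True,
             "source_pattern", r"MYDET/DET/(\d+)CH0:output",
             "key_pattern", "image.data",
             "replacement", "MYDET/STACKED:output",
             "type", "sources"),
        # "keys" group: within each matched source, all keys matching key_pattern
        # are stacked into the single new key given by replacement
        Hash("select", True,
             "source_pattern", r"MYDET/ADC:output",
             "key_pattern", r"data\.adc_\d+",
             "replacement", "data.adc",
             "type", "keys"),
    ]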
@@ -79,103 +94,142 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         )

     def initialization(self):
-        self._compile_merge_patterns(self.get("merge"))
+        self._stacking_buffers = {}
+        self._source_stacking_indices = {}
+        self._source_stacking_sources = {}
+        self._key_stacking_sources = {}
+        self._prepare_merge_groups(self.get("merge"))
         super().initialization()

         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()

     def preReconfigure(self, conf):
         super().preReconfigure(conf)
-        if conf.has("merge"):
-            self._compile_merge_patterns(conf["merge"])
-
-    def _compile_merge_patterns(self, merge):
-        self._merge_patterns = [
-            (
-                re.compile(row["source_pattern"]),
-                re.compile(row["key_pattern"]),
-                row["replacement"],
-            )
-            for row in merge
-            if row["select"]
-        ]
+        if conf.has("merge") or conf.has("sources"):
+            # "merge" is absent from conf when only "sources" changed
+            self._prepare_merge_groups(
+                conf["merge"] if conf.has("merge") else self.get("merge")
+            )
+
+    def _prepare_merge_groups(self, merge):
+        source_group_patterns = []
+        key_group_patterns = []
+        # split by type, prepare regexes
+        for row in merge:
+            if not row["select"]:
+                continue
+            group_type = MergeGroupType(row["type"])
+            if group_type is MergeGroupType.MULTISOURCE:
+                # key_pattern is deliberately kept as a literal key here
+                source_group_patterns.append(
+                    (
+                        re.compile(row["source_pattern"]),
+                        row["key_pattern"],
+                        row["replacement"],
+                    )
+                )
+            else:
+                key_group_patterns.append(
+                    (
+                        re.compile(row["source_pattern"]),
+                        re.compile(row["key_pattern"]),
+                        row["replacement"],
+                    )
+                )

-    def on_matched_data(self, train_id, sources):
-        # dereference calng shmem handles
-        for source, (data, timestamp) in sources.items():
-            if data.has("calngShmemPaths"):
-                shmem_paths = list(data["calngShmemPaths"])
-                data.erase("calngShmemPaths")
-                for shmem_path in shmem_paths:
-                    if not data.has(shmem_path):
-                        self.log.INFO(f"Hash from {source} did not have {shmem_path}")
-                        continue
-                    dereferenced = self._shmem_handler.get(data[shmem_path])
-                    data[shmem_path] = dereferenced
-
-        # merge arrays
-        for source_re, key_re, replacement in self._merge_patterns:
-            # Find all sources matching the source pattern.
+        # deliberately not filtered by row["select"]: unselected sources keep
+        # their stacking index, so they show up as gaps in the stacked output
+        source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
+
+        self._stacking_buffers.clear()
+        # handle source stacking groups
+        self._source_stacking_indices.clear()
+        self._source_stacking_sources.clear()
+        for source_re, key, new_source in source_group_patterns:
             merge_sources = [
-                source for source in sources.keys() if source_re.match(source)
+                source for source in source_names if source_re.match(source)
             ]
-
-            if len(merge_sources) > 1:
-                # More than one source match, merge by source.
-
-                if key_re.pattern in sources[merge_sources[0]][0]:
-                    # Short-circuit the pattern for performance if itself is a key.
-                    new_key = key_re.pattern
-                else:
-                    # Find the first key matching the pattern.
-                    for new_key in sources[merge_sources[0]][0].paths():
-                        if key_re.match(new_key):
-                            break
-
-                merge_keys = [new_key]
-                new_source = replacement
-                to_merge = [
-                    sources[source][0][new_key] for source in sorted(merge_sources)
-                ]
-
-                if len(to_merge) != len(merge_sources):
-                    # Make sure all matched sources contain the key.
-                    break
-
-            elif len(merge_sources) == 1:
-                # Exactly one source match, merge by key.
-
-                new_source = merge_sources[0]
-                new_key = replacement
-                merge_keys = [
-                    key for key in sources[new_source][0].paths() if key_re.match(key)
-                ]
-
-                if not merge_keys:
-                    # No key match, ignore.
-                    continue
-
-                to_merge = [sources[new_source][0][key] for key in sorted(merge_keys)]
-
-            else:
-                # No source match, ignore.
-                continue
-
-            # Stack data and insert into source data.
-            try:
-                new_data = np.stack(to_merge, axis=0)
-            except ValueError as e:
-                self.log.ERROR(
-                    f"Failed to merge data for " f"{new_source}.{new_key}: {e}"
+            if len(merge_sources) == 0:
+                self.log.WARN(
+                    f"Group merge pattern {source_re.pattern} did not match any known sources"
                 )
                 continue
+            for (i, source) in enumerate(merge_sources):
+                self._source_stacking_sources.setdefault(source, []).append(
+                    (key, new_source)
+                )
+                self._source_stacking_indices[(source, key)] = i
+
+        # handle key stacking groups
+        self._key_stacking_sources.clear()
+        for source_re, key_re, new_key in key_group_patterns:
+            for source in source_names:
+                # TODO: maybe also warn if no matches here?
+                if not source_re.match(source):
+                    continue
+                self._key_stacking_sources.setdefault(source, []).append(
+                    (key_re, new_key)
+                )
+
+    def on_matched_data(self, train_id, sources):
+        # dereference calng shmem handles
+        for (data, _) in sources.values():
+            self._shmem_handler.dereference_shmem_handles(data)

-        sources.setdefault(new_source, (Hash(), sources[merge_sources[0]][1]))[0][
-            new_key
-        ] = new_data
+        new_sources_map = {}
+        for source, (data, timestamp) in sources.items():
+            # stack across sources (many sources, same key)
+            # could probably save ~100 ns by "if ... in" instead of get
+            for (stack_key, new_source) in self._source_stacking_sources.get(
+                source, ()
+            ):
+                this_data = data.get(stack_key)
+                try:
+                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                    stack_index = self._source_stacking_indices[(source, stack_key)]
+                    this_buffer[stack_index] = this_data
+                except (KeyError, IndexError, ValueError):
+                    # KeyError: buffer doesn't exist yet
+                    # IndexError: buffer is too small for this stacking index
+                    # ValueError: entry has the wrong shape
+                    # either way, (re)create a buffer sized for the whole group,
+                    # i.e. for the largest index of any source stacking this key
+                    # into this new source
+                    # TODO: complain if shape varies between sources
+                    self._stacking_buffers[(new_source, stack_key)] = np.empty(
+                        shape=(
+                            max(
+                                index_
+                                for (
+                                    source_,
+                                    key_,
+                                ), index_ in self._source_stacking_indices.items()
+                                if key_ == stack_key
+                                and (stack_key, new_source)
+                                in self._source_stacking_sources.get(source_, ())
+                            )
+                            + 1,
+                        )
+                        + this_data.shape,
+                        dtype=this_data.dtype,
+                    )
+                    # and then try again
+                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                    stack_index = self._source_stacking_indices[(source, stack_key)]
+                    this_buffer[stack_index] = this_data
+                # TODO: zero out unfilled buffer entries
+                data.erase(stack_key)
+
+                if new_source not in new_sources_map:
+                    new_sources_map[new_source] = (Hash(), timestamp)
+                new_source_hash = new_sources_map[new_source][0]
+                if not new_source_hash.has(stack_key):
+                    new_source_hash[stack_key] = this_buffer
+
+            # stack keys (multiple keys within this source)
+            for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
+                # note: the key_re patterns should not overlap each other
+                # note: if a later key_re matches an earlier new_key, this gets spicy
+                # sorted so the stacking order is deterministic
+                stack_keys = sorted(key for key in data.paths() if key_re.match(key))
+                try:
+                    # TODO: consider reusing buffers here, too
+                    stacked = np.stack([data.get(key) for key in stack_keys], axis=0)
+                except Exception as e:
+                    self.log.WARN(f"Failed to stack {key_re.pattern} for {source}: {e}")
+                else:
+                    for key in stack_keys:
+                        data.erase(key)
+                    data[new_key] = stacked

-            # Unset keys merged together across all source matches.
-            for source in merge_sources:
-                for key in merge_keys:
-                    sources[source][0].erase(key)
+        sources.update(new_sources_map)

         super().on_matched_data(train_id, sources)
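To make the source-stacking bookkeeping above concrete, here is a minimal numpy-only sketch of what one (new_source, stack_key) buffer ends up holding; the source names, shapes, and dtype are invented for illustration:

    import numpy as np

    # as built by _prepare_merge_groups for a three-module "sources" group
    stacking_indices = {
        ("det/0ch", "image.data"): 0,
        ("det/1ch", "image.data"): 1,
        ("det/2ch", "image.data"): 2,
    }

    entry_shape = (512, 128)
    buffer = np.empty(
        (max(stacking_indices.values()) + 1,) + entry_shape, dtype=np.uint16
    )
    for (source, key), index in stacking_indices.items():
        buffer[index] = index  # stand-in for that source's image.data array

    # buffer is what on_matched_data publishes as the stacked key of the new
    # source; if det/1ch were missing from a train, buffer[1] would keep stale
    # values, which is what the "zero out unfilled buffer entries" TODO refers to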
in" instead of get + for (stack_key, new_source) in self._source_stacking_sources.get( + source, () + ): + this_data = data.get(stack_key) + try: + this_buffer = self._stacking_buffers[(new_source, stack_key)] + stack_index = self._source_stacking_indices[(source, stack_key)] + this_buffer[stack_index] = this_data + except (ValueError, KeyError): + # ValueError: wrong shape + # KeyError: buffer doesn't exist yet + # either way, create appropriate buffer now + # TODO: complain if shape varies between sources + self._stacking_buffers[(new_source, stack_key)] = np.empty( + shape=( + max( + index_ + for ( + source_, + key_, + ), index_ in self._source_stacking_indices.items() + if source_ == source and key_ == stack_key + ) + + 1, + ) + + this_data.shape, + dtype=this_data.dtype, + ) + # and then try again + this_buffer = self._stacking_buffers[(new_source, stack_key)] + stack_index = self._source_stacking_indices[(source, stack_key)] + this_buffer[stack_index] = this_data + # TODO: zero out unfilled buffer entries + data.erase(stack_key) + + if new_source not in new_sources_map: + new_sources_map[new_source] = (Hash(), timestamp) + new_source_hash = new_sources_map[new_source][0] + if not new_source_hash.has(stack_key): + new_source_hash[stack_key] = this_buffer + + # stack keys (multiple keys within this source) + for (key_re, new_key) in self._key_stacking_sources.get(source, ()): + # note: please no overlap between different key_re + # note: if later key_re match earlier new_key, this gets spicy + stack_keys = [key for key in data.paths() if key_re.match(key)] + try: + # TODO: consider reusing buffers here, too + stacked = np.stack([data.get(key) for key in stack_keys], axis=0) + except Exception as e: + self.log.WARN(f"Failed to stack {key_re} for {source}: {e}") + else: + for key in stack_keys: + data.erase(key) + data[new_key] = stacked - # Unset keys merged together across all source matches. - for source in merge_sources: - for key in merge_keys: - sources[source][0].erase(key) + sources.update(new_sources_map) super().on_matched_data(train_id, sources) diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py index 4c4838e21fbb6df786d83508fd3be7630f110d08..e02e3fd0b3bbf9a46172dc14735fd33e8c630eac 100644 --- a/src/calng/shmem_utils.py +++ b/src/calng/shmem_utils.py @@ -46,6 +46,18 @@ class ShmemCircularBufferReceiver: return ary[index] + def dereference_shmem_handles(self, data_hash): + if data_hash.has("calngShmemPaths"): + shmem_paths = list(data_hash["calngShmemPaths"]) + data_hash.erase("calngShmemPaths") + for shmem_path in shmem_paths: + if not data_hash.has(shmem_path): + # TODO: proper warnings + print(f"Warning: hash did not contain {shmem_path}") + continue + dereferenced = self.get(data_hash[shmem_path]) + data_hash[shmem_path] = dereferenced + class ShmemCircularBuffer: """Convenience wrapper around posixshmem-backed ndarray buffers