From ad4606b850f435e8f29a9bdff350b64c788b3ca4 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Tue, 17 Oct 2023 08:57:21 +0200
Subject: [PATCH] Finished overhaul of source stacking, added test

---
 src/calng/stacking_utils.py | 241 +++++++++++++++++++-----------------
 src/tests/test_stacking.py  |  64 ++++++++++
 2 files changed, 188 insertions(+), 117 deletions(-)
 create mode 100644 src/tests/test_stacking.py

diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
index 049480d4..8149ceea 100644
--- a/src/calng/stacking_utils.py
+++ b/src/calng/stacking_utils.py
@@ -1,5 +1,7 @@
 import collections
+import concurrent.futures
 import enum
+import functools
 import re
 
 from karabo.bound import (
@@ -139,41 +141,42 @@ class StackingFriend:
         )
 
     def __init__(self, device, source_config, merge_config):
+        # (old source, new source, key) -> (index (maybe index_exp), axis)
         self._source_stacking_indices = {}
-        self._source_stacking_sources = collections.defaultdict(list)
-        self._source_stacking_group_sizes = {}
-        # (new source name, key) -> {original sources used}
+        # list of (old source, key, new source)
+        self._source_stacking_parts = []
+        # (new source, key, method, group size, axis, missing) -> {original sources}
         self._new_sources_inputs = collections.defaultdict(set)
         self._key_stacking_sources = collections.defaultdict(list)
+        self._merge_config = None
         self._source_config = None
         self._device = device
-        self.reconfigure(merge_config, source_config)
+        self.reconfigure(source_config, merge_config)
 
-    def reconfigure(self, merge_config, source_config):
+    def reconfigure(self, source_config, merge_config):
         print("merge_config", type(merge_config))
         print("source_config", type(source_config))
-        if merge_config is not None:
-            self._merge_config = merge_config
         if source_config is not None:
             self._source_config = source_config
+        if merge_config is not None:
+            self._merge_config = merge_config
         self._source_stacking_indices.clear()
-        self._source_stacking_sources.clear()
-        self._source_stacking_group_sizes.clear()
-        self._key_stacking_sources.clear()
+        self._source_stacking_parts.clear()
         self._new_sources_inputs.clear()
+        self._key_stacking_sources.clear()
 
         # not filtering by row["select"] to allow unselected sources to create gaps
         source_names = [row["source"].partition("@")[0] for row in self._source_config]
 
         source_stacking_groups = [
             row
             for row in self._merge_config
-            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
+            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.value
         ]
         key_stacking_groups = [
             row
             for row in self._merge_config
-            if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
+            if row["select"] and row["groupType"] == GroupType.MULTIKEY.value
         ]
 
         for row in source_stacking_groups:
@@ -181,6 +184,7 @@ class StackingFriend:
             merge_method = MergeMethod(row["mergeMethod"])
             key = row["keyPattern"]
             new_source = row["replacement"]
+            axis = row["axis"]
             merge_sources = [
                 source for source in source_names if source_re.match(source)
             ]
@@ -189,11 +193,7 @@ class StackingFriend:
                     f"Group pattern {source_re} did not match any known sources"
                 )
                 continue
-            self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
             for i, source in enumerate(merge_sources):
-                self._source_stacking_sources[source].append(
-                    (key, new_source, merge_method, row["axis"])
-                )
                 self._source_stacking_indices[(source, new_source, key)] = (
                     i
                     if merge_method is MergeMethod.STACK
@@ -204,7 +204,18 @@ class StackingFriend:
                     else np.index_exp[
                         slice(
                             i,
                             None,
                             len(merge_sources),
                         )
                     ]
-                )
+                ), axis
+                self._new_sources_inputs[
+                    (
+                        new_source,
+                        key,
+                        merge_method,
+                        len(merge_sources),
+                        axis,
+                        row["missingValue"],
+                    )
+                ].add(source)
+                self._source_stacking_parts.append((source, key, new_source))
 
         for row in key_stacking_groups:
             key_re = re.compile(row["keyPattern"])
@@ -214,25 +225,27 @@ class StackingFriend:
             )
 
     def process(self, sources, thread_pool=None):
-        stacking_data_shapes = {}
-        stacking_buffers = {}
-        new_source_map = collections.defaultdict(Hash)
+        merge_buffers = {}
+        new_source_map = {}
         missing_value_defaults = {}
 
         # prepare for source stacking where sources are present
-        source_set = set(sources.keys())
         for (
             new_source,
             data_key,
             merge_method,
             group_size,
             axis,
-            missing_value,
+            missing_value_str,
         ), original_sources in self._new_sources_inputs.items():
-            for present_source in source_set & original_sources:
-                data = sources[present_source].get(data_key)[0]
+            for source in original_sources:
+                if source not in sources:
+                    continue
+                data_hash, timestamp = sources[source]
+                data = data_hash.get(data_key)
                 if data is None:
                     continue
+                # didn't hit continue => this is the first present source with data
                 if merge_method is MergeMethod.STACK:
                     expected_shape = utils.stacking_buffer_shape(
                         data.shape, group_size, axis=axis
@@ -241,119 +254,113 @@ class StackingFriend:
                     expected_shape = utils.interleaving_buffer_shape(
                         data.shape, group_size, axis=axis
                     )
-                stacking_buffer = np.empty(
-                    expected_shape, dtype=data.dtype
-                )
-                stacking_buffers[(new_source, data_key)] = stacking_buffer
-                new_source_map[new_source][data_key] = stacking_buffer
+                merge_buffer = np.empty(expected_shape, dtype=data.dtype)
+                merge_buffers[(new_source, data_key)] = merge_buffer
+                if new_source not in new_source_map:
+                    new_source_map[new_source] = (Hash(), timestamp)
+                new_source_map[new_source][0][data_key] = merge_buffer
                 try:
-                    missing_value_defaults[(new_source, data_key)] = data.dtype.type(
-                        missing_value
-                    )
+                    missing_value = data.dtype.type(missing_value_str)
                 except ValueError:
                     self._device.log.WARN(
-                        f"Invalid missing data value for {new_source}.{data_key}, using 0"
+                        f"Invalid missing data value for {new_source}.{data_key} "
+                        f"(tried making a {data.dtype} out of "
+                        f"'{missing_value_str}'), using 0"
                     )
+                    missing_value = 0
+                missing_value_defaults[(new_source, data_key)] = missing_value
+                # now we have set up the buffer for this new source, so break
                 break
             else:
-                # in this case: no present_source (if any) had data_key
+                # in this case: no source (if any) had data_key
                 self._device.log.WARN(
                     f"No sources needed for {new_source}.{data_key} were present"
                 )
 
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = parent._source_stacking_group_sizes[(new_source, key)]
-            merge_buffer = self._stacking_buffers.get((new_source, key))
-            if merge_buffer is None or merge_buffer.shape != expected_shape:
-                merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
-                self._stacking_buffers[(new_source, key)] = merge_buffer
-
-            for merge_index in self._fill_missed_data.get((new_source, key), []):
-                utils.set_on_axis(
-                    merge_buffer,
-                    0,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            parent._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            if new_source not in self.new_source_map:
-                self.new_source_map[new_source] = (Hash(), timestamp)
-            new_source_hash = self.new_source_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
-
         # now actually do some work
-        fun = functools.partial(self._handle_source, ...)
+        fun = functools.partial(
+            self._handle_source_stacking_part,
+            merge_buffers,
+            missing_value_defaults,
+            sources,
+        )
         if thread_pool is None:
-            for _ in map(fun, ...):
+            for _ in map(fun, self._source_stacking_parts):
                 pass
         else:
-            concurrent.futures.wait(thread_pool.map(fun, ...))
+            concurrent.futures.wait(thread_pool.map(fun, self._source_stacking_parts))
+
+        # note: new source names for group stacking may not match existing source names
         sources.update(new_source_map)
 
-    def _handle_expected_source(self, merge_buffers, missing_values, actual_sources, expected_source):
+    def _handle_source_stacking_part(
+        self,
+        merge_buffers,
+        missing_values,
+        actual_sources,
+        stacking_triple,
+    ):
         """Helper function used in processing. Note that it should be called for each
-        source that was supposed to be there - so it can decide whether to move data in
-        for stacking, fill missing data in case a source is missing, or skip in case no
-        buffer was created (none of the necessary sources were present)."""
+        (original source, data key, new source) triple expected in the stacking scheme
+        regardless of whether that original source is present - so it can decide whether
+        to move data in for stacking, fill missing data in case a source is missing,
+        or skip in case no buffer was created (none of the necessary sources were
+        present)."""
 
-        if expected_source not in actual_sources:
-            if ex
-        for (
-            key,
-            new_source,
-            merge_method,
-            axis,
-        ) in self._source_stacking_sources.get(source, ()):
-            if (new_source, key) in self._ignore_stacking:
-                continue
-            merge_data = data_hash.get(key)
-            merge_index = self._source_stacking_indices[(source, new_source, key)]
-            merge_buffer = self._stacking_buffers[(new_source, key)]
+        old_source, data_key, new_source = stacking_triple
+
+        if (new_source, data_key) not in merge_buffers:
+            # preparatory step did not create buffer
+            # (i.e. no sources from this group were present)
+            return
+        merge_buffer = merge_buffers[(new_source, data_key)]
+        if old_source in actual_sources:
+            data_hash = actual_sources[old_source][0]
+        else:
+            data_hash = None
+        if data_hash is None or (merge_data := data_hash.get(data_key)) is None:
+            merge_data = missing_values[(new_source, data_key)]
+        try:
             utils.set_on_axis(
                 merge_buffer,
                 merge_data,
-                merge_index,
-                axis,
+                *self._source_stacking_indices[(old_source, new_source, data_key)],
+            )
+            if data_hash is not None:
+                data_hash.erase(data_key)
+        except Exception as ex:
+            self._device.log.WARN(
                f"Failed to stack {data_key} from {old_source} into {new_source}: {ex}"
             )
-            data_hash.erase(key)
 
-        # stack keys (multiple keys within this source)
-        for key_re, new_key, merge_method, axis in self._key_stacking_sources.get(
-            source, ()
-        ):
-            # note: please no overlap between different key_re
-            # note: if later key_re match earlier new_key, this gets spicy
-            keys = [key for key in data_hash.paths() if key_re.match(key)]
-            try:
-                # note: maybe we could reuse buffers here, too?
-                if merge_method is MergeMethod.STACK:
-                    stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
-                else:
-                    first = data_hash.get(keys[0])
-                    stacked = np.empty(
-                        shape=utils.stacking_buffer_shape(
-                            first.shape, len(keys), axis=axis
-                        ),
-                        dtype=first.dtype,
-                    )
-                    for i, key in enumerate(keys):
-                        utils.set_on_axis(
-                            stacked,
-                            data_hash.get(key),
-                            np.index_exp[slice(i, None, len(keys))],
-                            axis,
-                        )
-            except Exception as e:
-                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+    def _handle_key_stacking_entry(
+        self, source, data_hash, key_re, new_key, merge_method, axis
+    ):
+        # note: please no overlap between different key_re patterns
+        # note: if a later key_re matches an earlier new_key, this gets spicy
+        keys = [key for key in data_hash.paths() if key_re.match(key)]
+        try:
+            # note: maybe we could reuse buffers here, too?
+            if merge_method is MergeMethod.STACK:
+                stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
             else:
-                for key in keys:
-                    data_hash.erase(key)
-                data_hash[new_key] = stacked
+                first = data_hash.get(keys[0])
+                stacked = np.empty(
+                    shape=utils.interleaving_buffer_shape(
+                        first.shape, len(keys), axis=axis
+                    ),
+                    dtype=first.dtype,
+                )
+                for i, key in enumerate(keys):
+                    utils.set_on_axis(
+                        stacked,
+                        data_hash.get(key),
+                        np.index_exp[slice(i, None, len(keys))],
+                        axis,
+                    )
+        except Exception as e:
+            self._device.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+        else:
+            for key in keys:
+                data_hash.erase(key)
+            data_hash[new_key] = stacked
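
The new code leans on three helpers from calng.utils that this patch does not
touch: stacking_buffer_shape, interleaving_buffer_shape and set_on_axis. Below
is a minimal sketch of the semantics the call sites above assume - illustrative
only, the actual implementations live in calng.utils and may differ:

    import numpy as np

    def stacking_buffer_shape(data_shape, group_size, axis=0):
        # stacking inserts a new axis of length group_size:
        # (4, 2) stacked in groups of 3 on axis 1 -> (4, 3, 2)
        shape = list(data_shape)
        shape.insert(axis, group_size)
        return tuple(shape)

    def interleaving_buffer_shape(data_shape, group_size, axis=0):
        # interleaving stretches the existing axis by group_size:
        # (4, 2) interleaved in groups of 3 on axis 1 -> (4, 6)
        shape = list(data_shape)
        shape[axis] *= group_size
        return tuple(shape)

    def set_on_axis(buffer, value, index, axis):
        # assign value at `index` along `axis`; `index` is an integer when
        # stacking and an np.index_exp slice tuple when interleaving
        if not isinstance(index, tuple):
            index = (index,)
        buffer[(slice(None),) * axis + index] = value

Under these assumptions, a stack group of size 3 on axis 1 over sources of
shape (4, 2) gets a (4, 3, 2) buffer, and _handle_source_stacking_part writes
source i into buffer[:, i], substituting the parsed missingValue for sources
that are absent - which is what the tests below check.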
diff --git a/src/tests/test_stacking.py b/src/tests/test_stacking.py
new file mode 100644
index 00000000..0c2ff00e
--- /dev/null
+++ b/src/tests/test_stacking.py
@@ -0,0 +1,64 @@
+from karabo.bound import Hash
+import numpy as np
+
+from calng import stacking_utils
+
+
+class NotADevice:
+    class log:
+        @staticmethod
+        def WARN(s):
+            print(f"Warning: {s}")
+
+
+source_table = [
+    {
+        "select": True,
+        "source": f"source{i}@device{i}:channel",
+    }
+    for i in range(3)
+]
+
+merge_rules = [
+    {
+        "select": True,
+        "sourcePattern": "source\\d+",
+        "keyPattern": "keyToStack",
+        "replacement": "newSource",
+        "groupType": "sources",
+        "mergeMethod": "stack",
+        "axis": 1,
+        "missingValue": "0",
+    }
+]
+
+friend = stacking_utils.StackingFriend(NotADevice(), source_table, merge_rules)
+
+datas = [np.arange(i * 100, i * 100 + 8).reshape(4, 2) for i in range(3)]
+
+
+def test_simple_source_stacking():
+    sources = {
+        f"source{i}": (Hash("keyToStack", data), None) for i, data in enumerate(datas)
+    }
+
+    friend.process(sources)
+    assert "newSource" in sources
+    stacked = sources["newSource"][0]["keyToStack"]
+    assert stacked.shape == (4, 3, 2)
+    for i, data in enumerate(datas):
+        assert np.array_equal(stacked[:, i], data, equal_nan=True)
+
+
+def test_source_stacking_one_missing():
+    sources = {
+        f"source{i}": (Hash("keyToStack", data), None) for i, data in enumerate(datas)
+    }
+    del sources["source0"]
+    friend.process(sources)
+    assert "newSource" in sources
+    stacked = sources["newSource"][0]["keyToStack"]
+    print(stacked)
+    assert np.all(stacked[:, 0] == 0)
+    for i, data in enumerate(datas[1:], start=1):
+        assert np.array_equal(stacked[:, i], data, equal_nan=True)
-- 
GitLab
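
The new _handle_key_stacking_entry is not exercised by the tests above. Here is
a sketch of how it could be driven, using a stand-in device like the one in the
tests; the key names are made up, and MergeMethod is assumed to be importable
from calng.stacking_utils (it is the enum behind the "stack" mergeMethod value
used in the test's merge_rules):

    import re

    import numpy as np
    from karabo.bound import Hash

    from calng import stacking_utils


    class NotADevice:
        class log:
            @staticmethod
            def WARN(s):
                print(f"Warning: {s}")


    # empty source and merge tables: we only want the helper method
    friend = stacking_utils.StackingFriend(NotADevice(), [], [])
    h = Hash("key0", np.zeros((4, 2)), "key1", np.ones((4, 2)))
    friend._handle_key_stacking_entry(
        "someSource",           # source name, only used in the warning message
        h,                      # hash whose matching keys get stacked
        re.compile("key\\d+"),  # pattern selecting the keys to merge
        "stacked",              # new key replacing the matched ones
        stacking_utils.MergeMethod.STACK,
        0,                      # axis
    )
    # on success the matched keys are erased and replaced by the stacked array
    assert h["stacked"].shape == (2, 4, 2)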