From 02a4ccd97667c07abd1a57eb1d0b50db74c99b83 Mon Sep 17 00:00:00 2001 From: David Hammer <dhammer@mailbox.org> Date: Fri, 6 May 2022 13:27:59 +0200 Subject: [PATCH] Allow stacking on other axes, fix thread pool --- src/calng/ShmemTrainMatcher.py | 103 ++++++++++++++++++++++----------- src/calng/utils.py | 25 ++++++++ 2 files changed, 95 insertions(+), 33 deletions(-) diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py index 788f328b..780be449 100644 --- a/src/calng/ShmemTrainMatcher.py +++ b/src/calng/ShmemTrainMatcher.py @@ -5,6 +5,7 @@ import re import numpy as np from karabo.bound import ( BOOL_ELEMENT, + INT32_ELEMENT, KARABO_CLASSINFO, STRING_ELEMENT, TABLE_ELEMENT, @@ -14,7 +15,7 @@ from karabo.bound import ( ) from TrainMatcher import TrainMatcher -from . import shmem_utils +from . import shmem_utils, utils from ._version import version as deviceVersion @@ -35,7 +36,7 @@ def merge_schema(): .commit(), STRING_ELEMENT(schema) - .key("source_pattern") + .key("sourcePattern") .displayedName("Source pattern") .assignmentOptional() .defaultValue("") @@ -43,7 +44,7 @@ def merge_schema(): .commit(), STRING_ELEMENT(schema) - .key("key_pattern") + .key("keyPattern") .displayedName("Key pattern") .assignmentOptional() .defaultValue("") @@ -66,6 +67,14 @@ def merge_schema(): .defaultValue(MergeGroupType.MULTISOURCE.value) .reconfigurable() .commit(), + + INT32_ELEMENT(schema) + .key("stackingAxis") + .displayedName("Axis") + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), ) return schema @@ -78,7 +87,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ( TABLE_ELEMENT(expected) .key("merge") - .displayedName("Array merging") + .displayedName("Array stacking") .allowedStates(State.PASSIVE) .description( "List source or key patterns to merge their data arrays, e.g. 
to " @@ -95,8 +104,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): BOOL_ELEMENT(expected) .key("useThreadPool") + .displayedName("Use thread pool") + .allowedStates(State.PASSIVE) .assignmentOptional() .defaultValue(False) + .reconfigurable() .commit(), ) @@ -117,6 +129,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): super().preReconfigure(conf) if conf.has("merge") or conf.has("sources"): self._prepare_merge_groups(conf["merge"]) + if conf.has("useThreadPool"): + if self._thread_pool is not None: + self._thread_pool.shutdown() + self._thread_pool = None + if conf["useThreadPool"]: + self._thread_pool = concurrent.futures.ThreadPoolExecutor() def _prepare_merge_groups(self, merge): source_group_patterns = [] @@ -129,16 +147,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): if group_type is MergeGroupType.MULTISOURCE: source_group_patterns.append( ( - re.compile(row["source_pattern"]), - row["key_pattern"], + re.compile(row["sourcePattern"]), + row["keyPattern"], row["replacement"], + row["stackingAxis"], ) ) else: key_group_patterns.append( ( - re.compile(row["source_pattern"]), - re.compile(row["key_pattern"]), + re.compile(row["sourcePattern"]), + re.compile(row["keyPattern"]), row["replacement"], ) ) @@ -150,7 +169,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): # handle source stacking groups self._source_stacking_indices.clear() self._source_stacking_sources.clear() - for source_re, key, new_source in source_group_patterns: + for source_re, key, new_source, stack_axis in source_group_patterns: merge_sources = [ source for source in source_names if source_re.match(source) ] @@ -161,9 +180,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): continue for (i, source) in enumerate(merge_sources): self._source_stacking_sources.setdefault(source, []).append( - (key, new_source) + (key, new_source, stack_axis) ) - self._source_stacking_indices[(source, key)] = i + self._source_stacking_indices[(source, new_source, key)] = i # 
handle key stacking groups self._key_stacking_sources.clear() @@ -176,43 +195,61 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): (key_re, new_key) ) + def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype): + # TODO: handle ValueError for max of empty sequence + stack_num = ( + max( + index + for ( + _, + new_source_, + key_, + ), index in self._source_stacking_indices.items() + if new_source_ == new_source and key_ == key + ) + + 1 + ) + self._stacking_buffers[(new_source, key)] = np.empty( + shape=utils.stacking_buffer_shape(individual_shape, stack_num, axis=axis), + dtype=dtype, + ) + def _handle_source(self, source, data_hash, timestamp, new_sources_map): # dereference calng shmem handles self._shmem_handler.dereference_shmem_handles(data_hash) # stack across sources (many sources, same key) # could probably save ~100 ns by "if ... in" instead of get - for (stack_key, new_source) in self._source_stacking_sources.get(source, ()): + for (stack_key, new_source, stack_axis) in self._source_stacking_sources.get( + source, () + ): this_data = data_hash.get(stack_key) try: this_buffer = self._stacking_buffers[(new_source, stack_key)] - stack_index = self._source_stacking_indices[(source, stack_key)] - this_buffer[stack_index] = this_data + stack_index = self._source_stacking_indices[ + (source, new_source, stack_key) + ] + utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis) except (ValueError, IndexError, KeyError): - # ValueError: wrong shape + # ValueError: wrong shape (react to this_data.shape) # KeyError: buffer doesn't exist yet + # IndexError: new source? 
(buffer not long enough) # either way, create appropriate buffer now - # TODO: complain if shape varies between sources - self._stacking_buffers[(new_source, stack_key)] = np.empty( - shape=( - max( - index_ - for ( - source_, - key_, - ), index_ in self._source_stacking_indices.items() - if source_ == source and key_ == stack_key - ) - + 1, - ) - + this_data.shape, + # TODO: complain if shape varies between sources within train + self._update_stacking_buffer( + new_source, + stack_key, + this_data.shape, + axis=stack_axis, dtype=this_data.dtype, ) # and then try again this_buffer = self._stacking_buffers[(new_source, stack_key)] - stack_index = self._source_stacking_indices[(source, stack_key)] - this_buffer[stack_index] = this_data - # TODO: zero out unfilled buffer entries + stack_index = self._source_stacking_indices[ + (source, new_source, stack_key) + ] + utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis) + # TODO: zero out unfilled buffer entries data_hash.erase(stack_key) if new_source not in new_sources_map: @@ -246,7 +283,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): concurrent.futures.wait( [ self._thread_pool.submit( - self._handle_source, data, timestamp, new_sources_map + self._handle_source, source, data, timestamp, new_sources_map ) for source, (data, timestamp) in sources.items() ] diff --git a/src/calng/utils.py b/src/calng/utils.py index d0f6ceb4..83289230 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out): return tuple(axis_order[axis] for axis in axes_out) +def stacking_buffer_shape(array_shape, stack_num, axis=0): + """Figures out the shape you would need for np.stack""" + if axis > len(array_shape) or axis < -len(array_shape) - 1: + # complain when np.stack would + raise np.AxisError( + f"axis {axis} is out of bounds " + f"for array of dimension {len(array_shape) + 1}" + ) + if axis < 0: + axis += len(array_shape) + 1 + return array_shape[:axis] + 
(stack_num,) + array_shape[axis:] + + +def set_on_axis(array, vals, index, axis): + """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x""" + if axis >= len(array.shape): + raise IndexError( + f"too many indices for array: array is {len(array.shape)}-dimensional, " + f"but {axis+1} were indexed" + ) + # TODO: maybe support negative axis with wraparound + indices = np.index_exp[:] * axis + np.index_exp[index] + array[indices] = vals + + _np_typechar_to_c_typestring = { "?": "bool", "B": "unsigned char", -- GitLab