diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py index e5204e4d1009df1f775ebbcdfe56554ff3baba1c..ae36803076211b8afa40c55cbe882a9c3ee962de 100644 --- a/src/calng/DetectorAssembler.py +++ b/src/calng/DetectorAssembler.py @@ -1,4 +1,5 @@ import enum +from timeit import default_timer import functools import re @@ -20,7 +21,7 @@ from karabo.bound import ( ) from TrainMatcher import TrainMatcher -from . import geom_utils, preview_utils, scenes, schemas +from . import geom_utils, preview_utils, scenes, schemas, utils from ._version import version as deviceVersion @@ -56,6 +57,15 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): .needsAcknowledging(False) .commit(), + DOUBLE_ELEMENT(expected) + .key("processingTime") + .displayedName("Processing time") + .unit(Unit.SECOND) + .metricPrefix(MetricPrefix.MILLI) + .readOnly() + .initialValue(0) + .commit(), + STRING_ELEMENT(expected) .key("pathToStack") .assignmentOptional() @@ -108,7 +118,9 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): def __init__(self, conf): super().__init__(conf) - self.info.merge(Hash("timeOfFlight", 0)) + self.info.merge(Hash("timeOfFlight", 0, "processingTime", 0)) + self._tof_tracker = utils.ExponentialMovingAverage(alpha=0.3) + self._processing_time_tracker = utils.ExponentialMovingAverage(alpha=0.3) self.registerSlot(self.slotReceiveGeometry) def initialization(self): @@ -176,6 +188,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): ) def on_matched_data(self, train_id, sources): + ts_start = default_timer() if self._geometry is None: self.log.WARN("Have not received a geometry yet, will not send anything") return @@ -255,9 +268,11 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): assembled, ) - self.info["timeOfFlight"] = ( - Timestamp().toTimestamp() - earliest_source_timestamp - ) * 1000 + self._processing_time_tracker.update(default_timer() - ts_start) + self._tof_tracker.update(Timestamp().toTimestamp() - earliest_source_timestamp) + + 
self.info["processingTime"] = self._processing_time_tracker.get() * 1000 + self.info["timeOfFlight"] = self._tof_tracker.get() * 1000 self.info["sent"] += 1 self.info["trainId"] = train_id self.rate_out.update() diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py index 19722de729a7cb0c55371c179c66be771f421503..8ebaebfa4a2e0a2de435f5490705c329c92d5a5a 100644 --- a/src/calng/ShmemTrainMatcher.py +++ b/src/calng/ShmemTrainMatcher.py @@ -1,140 +1,31 @@ import concurrent.futures -import enum -import re +from timeit import default_timer -import numpy as np from karabo.bound import ( BOOL_ELEMENT, - INT32_ELEMENT, + DOUBLE_ELEMENT, KARABO_CLASSINFO, - NODE_ELEMENT, OVERWRITE_ELEMENT, - STRING_ELEMENT, - TABLE_ELEMENT, - VECTOR_STRING_ELEMENT, + UINT32_ELEMENT, ChannelMetaData, Hash, - Schema, + MetricPrefix, State, + Unit, ) from TrainMatcher import TrainMatcher from . import shmem_utils, utils +from .stacking_utils import StackingFriend +from .frameselection_utils import FrameselectionFriend from ._version import version as deviceVersion -class GroupType(enum.Enum): - MULTISOURCE = "sources" # same key stacked from multiple sources in new source - MULTIKEY = "keys" # multiple keys within each matched source is stacked in new key - - -class MergeMethod(enum.Enum): - STACK = "stack" - INTERLEAVE = "interleave" - - -def merge_schema(): - schema = Schema() - ( - BOOL_ELEMENT(schema) - .key("select") - .displayedName("Select") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key("sourcePattern") - .displayedName("Source pattern") - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key("keyPattern") - .displayedName("Key pattern") - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key("replacement") - .displayedName("Replacement") - .assignmentOptional() - .defaultValue("") - 
.reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key("groupType") - .displayedName("Group type") - .options(",".join(option.value for option in GroupType)) - .assignmentOptional() - .defaultValue(GroupType.MULTISOURCE.value) - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key("mergeMethod") - .displayedName("Merge method") - .options(",".join(option.value for option in MergeMethod)) - .assignmentOptional() - .defaultValue(MergeMethod.STACK.value) - .reconfigurable() - .commit(), - - INT32_ELEMENT(schema) - .key("axis") - .displayedName("Axis") - .assignmentOptional() - .defaultValue(0) - .reconfigurable() - .commit(), - ) - - return schema - - @KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion) class ShmemTrainMatcher(TrainMatcher.TrainMatcher): @staticmethod def expectedParameters(expected): ( - TABLE_ELEMENT(expected) - .key("merge") - .displayedName("Array stacking") - .allowedStates(State.PASSIVE) - .description( - "Specify which source(s) or key(s) to stack or interleave." - "When stacking sources, the 'Source pattern' is interpreted as a " - "regular expression and the 'Key pattern' is interpreted as an " - "ordinary string. From all sources matching the source pattern, the " - "data under this key (should be array with same dimensions across all " - "stacked sources) is stacked in the same order as the sources are " - "listed in 'Data sources' and the result is under the same key name in " - "a new source named by 'Replacement'. " - "When stacking keys, both the 'Source pattern' and the 'Key pattern' " - "are regular expressions. Within each source matching the source " - "pattern, all keys matching the key pattern are stacked and the result " - "is put under the key named by 'Replacement'. " - "While source stacking is optimized and can use thread pool, key " - "stacking will iterate over all paths in matched sources and naively " - "call np.stack for each key pattern. 
In either case, data that is used " - "for stacking is removed from its original location (e.g. key is " - "erased from hash). " - "In both cases, the data can alternatively be interleaved. This is " - "essentially equivalent to stacking except followed by a reshape such " - "that the output is shaped like concatenation." - ) - .setColumns(merge_schema()) - .assignmentOptional() - .defaultValue([]) - .reconfigurable() - .commit(), - # order is important for stacking, disable sorting OVERWRITE_ELEMENT(expected) .key("sortSources") @@ -143,20 +34,28 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .commit(), BOOL_ELEMENT(expected) - .key("useThreadPool") - .displayedName("Use thread pool") + .key("enableKaraboOutput") + .displayedName("Enable Karabo channel") .allowedStates(State.PASSIVE) .assignmentOptional() - .defaultValue(False) + .defaultValue(True) .reconfigurable() .commit(), - BOOL_ELEMENT(expected) - .key("enableKaraboOutput") - .displayedName("Enable Karabo channel") + DOUBLE_ELEMENT(expected) + .key("processingTime") + .displayedName("Processing time") + .unit(Unit.SECOND) + .metricPrefix(MetricPrefix.MILLI) + .readOnly() + .initialValue(0) + .commit(), + + UINT32_ELEMENT(expected) + .key("processingThreads") .allowedStates(State.PASSIVE) .assignmentOptional() - .defaultValue(True) + .defaultValue(16) .reconfigurable() .commit(), @@ -170,374 +69,70 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .assignmentOptional() .defaultValue(True) .commit(), - - NODE_ELEMENT(expected) - .key("frameSelector") - .displayedName("Frame selection") - .commit(), - - BOOL_ELEMENT(expected) - .key("frameSelector.enable") - .displayedName("Enabled") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - - STRING_ELEMENT(expected) - .key("frameSelector.arbiterSource") - .displayedName("Arbiter source") - .description( - "Source name to pull the frame selection pattern from, must be part of " - "matched sources." 
- ) - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - STRING_ELEMENT(expected) - .key("frameSelector.dataSourcePattern") - .displayedName("Data source pattern") - .description( - "Source name pattern to apply frame selection to. Should match " - "subset of matched sources." - ) - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - VECTOR_STRING_ELEMENT(expected) - .key("frameSelector.dataKeys") - .displayedName("Data keys") - .description("Keys in data sources to apply frame selection to.") - .assignmentOptional() - .defaultValue([]) - .reconfigurable() - .commit(), ) + FrameselectionFriend.add_schema(expected) + StackingFriend.add_schema(expected) def __init__(self, config): if config.get("useInfiniband", default=True): from PipeToZeroMQ.utils import find_infiniband_ip config["output.hostname"] = find_infiniband_ip() + self._processing_time_tracker = utils.ExponentialMovingAverage(alpha=0.3) super().__init__(config) + self.info.merge(Hash("processingTime", 0)) def initialization(self): - self._stacking_buffers = {} - self._source_stacking_indices = {} - self._source_stacking_sources = {} - self._source_stacking_group_sizes = {} - self._key_stacking_sources = {} - self._have_prepared_merge_groups = False - self._prepare_merge_groups() super().initialization() self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver() - if self.get("useThreadPool"): - self._thread_pool = concurrent.futures.ThreadPoolExecutor() - else: - self._thread_pool = None + self._stacking_friend = StackingFriend( + self, self.get("sources"), self.get("merge") + ) + self._frameselection_friend = FrameselectionFriend(self.get("frameSelector")) + self._thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=self.get("processingThreads") + ) if not self.get("enableKaraboOutput"): # it is set already by super by default, so only need to turn off self.output = None - self._frame_selection_enabled = False - 
def preReconfigure(self, conf):
    """Forward an incoming reconfiguration to the helpers owned by the device.

    The parent hook runs first. Afterwards only the settings that are actually
    part of ``conf`` are acted on: the stacking helper, the frame selection
    helper, the worker thread pool and the Karabo output channel.

    Args:
        conf: the configuration update; queried with ``has`` / ``get`` and
            item access.
    """
    super().preReconfigure(conf)

    if conf.has("merge") or conf.has("sources"):
        # An update may carry only one of the two tables. The stacking helper
        # keeps its previous value for whichever argument is None, so pass
        # None explicitly instead of asking conf for a key it does not have.
        self._stacking_friend.reconfigure(
            conf.get("sources") if conf.has("sources") else None,
            conf.get("merge") if conf.has("merge") else None,
        )

    if conf.has("frameSelector"):
        self._frameselection_friend.reconfigure(conf.get("frameSelector"))

    if conf.has("processingThreads"):
        # Replace the pool so that the new worker count takes effect.
        # (The previous code referenced PriorityThreadPoolExecutor, a name that
        # is not defined or imported anywhere, and raised NameError here.)
        self._thread_pool.shutdown()
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=conf.get("processingThreads")
        )

    if conf.has("enableKaraboOutput"):
        if conf["enableKaraboOutput"]:
            self.output = self._ss.getOutputChannel("output")
        else:
            self.output = None
self.get("sources")] - self._stacking_buffers.clear() - self._source_stacking_indices.clear() - self._source_stacking_sources.clear() - self._source_stacking_group_sizes.clear() - self._key_stacking_sources.clear() - # split by type, prepare regexes - for row in self.get("merge"): - if not row["select"]: - continue - group_type = GroupType(row["groupType"]) - source_re = re.compile(row["sourcePattern"]) - merge_method = MergeMethod(row["mergeMethod"]) - axis = row["axis"] - if group_type is GroupType.MULTISOURCE: - key = row["keyPattern"] - new_source = row["replacement"] - merge_sources = [ - source for source in source_names if source_re.match(source) - ] - if len(merge_sources) == 0: - self.log.WARN( - f"Group pattern {source_re} did not match any known sources" - ) - continue - for (i, source) in enumerate(merge_sources): - self._source_stacking_sources.setdefault(source, []).append( - (key, new_source, merge_method, axis) - ) - self._source_stacking_indices[(source, new_source, key)] = i - self._source_stacking_group_sizes[(new_source, key)] = i + 1 - else: - key_re = re.compile(row["keyPattern"]) - new_key = row["replacement"] - self._key_stacking_sources.setdefault(source, []).append( - (key_re, new_key, merge_method, axis) - ) - - self._have_prepared_merge_groups = True - - def _check_stacking_data(self, sources, frame_selection_mask): - if frame_selection_mask is not None: - orig_size = len(frame_selection_mask) - result_size = np.sum(frame_selection_mask) - stacking_data_shapes = {} - ignore_stacking = {} - fill_missed_data = {} - for source, keys in self._source_stacking_sources.items(): - if source not in sources: - for key, new_source, _, _ in keys: - missed_sources = fill_missed_data.setdefault((new_source, key), []) - merge_index = self._source_stacking_indices[ - (source, new_source, key) - ] - missed_sources.append(merge_index) - continue - data_hash, timestamp = sources[source] - filtering = ( - frame_selection_mask is not None and - 
self._frame_selection_source_pattern.match(source) - ) - for key, new_source, merge_method, axis in keys: - merge_data_shape = None - if key in data_hash: - merge_data = data_hash[key] - merge_data_shape = merge_data.shape - else: - ignore_stacking[(new_source, key)] = "Some data is missed" - continue - - if filtering and key in self._frame_selection_data_keys: - # !!! stacking is not expected to be used with filtering - if merge_data_shape[0] == orig_size: - merge_data_shape = (result_size,) + merge_data.shape[1:] - - expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault( - (new_source, key), - (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp) - ) - if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype: - ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent" - del stacking_data_shapes[(new_source, key)] - - return stacking_data_shapes, ignore_stacking, fill_missed_data - - def _maybe_update_stacking_buffers( - self, stacking_data_shapes, fill_missed_data, new_sources_map - ): - for (new_source, key), attr in stacking_data_shapes.items(): - merge_data_shape, merge_method, axis, dtype, timestamp = attr - group_size = self._source_stacking_group_sizes[(new_source, key)] - if merge_method is MergeMethod.STACK: - expected_shape = utils.stacking_buffer_shape( - merge_data_shape, group_size, axis=axis) - else: - expected_shape = utils.interleaving_buffer_shape( - merge_data_shape, group_size, axis=axis) - - merge_buffer = self._stacking_buffers.get((new_source, key)) - if merge_buffer is None or merge_buffer.shape != expected_shape: - merge_buffer = np.empty( - shape=expected_shape, dtype=dtype) - self._stacking_buffers[(new_source, key)] = merge_buffer - for merge_index in fill_missed_data.get((new_source, key), []): - utils.set_on_axis( - merge_buffer, - 0, - merge_index - if merge_method is MergeMethod.STACK - else np.index_exp[ - slice( - merge_index, - None, - 
self._source_stacking_group_sizes[(new_source, key)], - ) - ], - axis, + def on_matched_data(self, train_id, sources): + ts_start = default_timer() + frame_selection_mask = self._frameselection_friend.get_mask(sources) + concurrent.futures.wait( + [ + self._thread_pool.submit( + self._handle_source, + source, + data, + timestamp, + frame_selection_mask, ) - if new_source not in new_sources_map: - new_sources_map[new_source] = (Hash(), timestamp) - new_source_hash = new_sources_map[new_source][0] - if not new_source_hash.has(key): - new_source_hash[key] = merge_buffer - - def _handle_source( - self, source, data_hash, timestamp, new_sources_map, frame_selection_mask, - ignore_stacking - ): - # dereference calng shmem handles - self._shmem_handler.dereference_shmem_handles(data_hash) - - # apply frame_selection - if frame_selection_mask is not None and self._frame_selection_source_pattern.match( - source - ): - for key in self._frame_selection_data_keys: - if not data_hash.has(key): - continue - if data_hash[key].shape[0] != frame_selection_mask.size: - self.log.WARN("Frame selection mask does not match the data size") - continue - - data_hash[key] = data_hash[key][frame_selection_mask] - - # stack across sources (many sources, same key) - for ( - key, - new_source, - merge_method, - axis, - ) in self._source_stacking_sources.get(source, ()): - if (new_source, key) in ignore_stacking: - continue - merge_data = data_hash.get(key) - merge_index = self._source_stacking_indices[ - (source, new_source, key) + for source, (data, timestamp) in sources.items() ] - merge_buffer = self._stacking_buffers[(new_source, key)] - utils.set_on_axis( - merge_buffer, - merge_data, - merge_index - if merge_method is MergeMethod.STACK - else np.index_exp[ - slice( - merge_index, - None, - self._source_stacking_group_sizes[(new_source, key)], - ) - ], - axis, - ) - data_hash.erase(key) - - # stack keys (multiple keys within this source) - for (key_re, new_key, merge_method, axis) in 
self._key_stacking_sources.get( - source, () - ): - # note: please no overlap between different key_re - # note: if later key_re match earlier new_key, this gets spicy - keys = [key for key in data_hash.paths() if key_re.match(key)] - try: - # note: maybe we could reuse buffers here, too? - if merge_method is MergeMethod.STACK: - stacked = np.stack([data_hash.get(key) for key in keys], axis=axis) - else: - first = data_hash.get(keys[0]) - stacked = np.empty( - shape=utils.stacking_buffer_shape( - first.shape, len(keys), axis=axis - ), - dtype=first.dtype, - ) - for i, key in enumerate(keys): - utils.set_on_axis( - stacked, - data_hash.get(key), - np.index_exp[slice(i, None, len(keys))], - axis, - ) - except Exception as e: - self.log.WARN(f"Failed to stack {key_re} for {source}: {e}") - else: - for key in keys: - data_hash.erase(key) - data_hash[new_key] = stacked - - def on_matched_data(self, train_id, sources): - new_sources_map = {} - frame_selection_mask = None - if self._frame_selection_enabled and self._frame_selection_arbiter in sources: - frame_selection_mask = np.array( - sources[self._frame_selection_arbiter][0][ - "data.dataFramePattern" - ], copy=False - ).astype(np.bool, copy=False), - - # prepare stacking - stacking_data_shapes, ignore_stacking, fill_missed_data = ( - self._check_stacking_data(sources, frame_selection_mask) ) - self._maybe_update_stacking_buffers( - stacking_data_shapes, fill_missed_data, new_sources_map) - for (new_source, key), msg in ignore_stacking.items(): - self.log.WARN(f"Failed to stack {new_source}.{key}: {msg}") - - if self._thread_pool is None: - for source, (data, timestamp) in sources.items(): - self._handle_source( - source, data, timestamp, new_sources_map, frame_selection_mask, - ignore_stacking - ) - else: - concurrent.futures.wait( - [ - self._thread_pool.submit( - self._handle_source, - source, - data, - timestamp, - new_sources_map, - frame_selection_mask, - ignore_stacking, - ) - for source, (data, timestamp) in 
class FrameselectionFriend:
    """Schema, configuration and logic for frame selection.

    A boolean mask is read from an "arbiter" source and applied along the first
    axis of selected keys in every source whose name matches a pattern.
    """

    @staticmethod
    def add_schema(schema):
        """Append the ``frameSelector`` node and its settings to ``schema``."""
        (
            NODE_ELEMENT(schema)
            .key("frameSelector")
            .displayedName("Frame selection")
            .commit(),

            BOOL_ELEMENT(schema)
            .key("frameSelector.enable")
            .displayedName("Enabled")
            .assignmentOptional()
            .defaultValue(False)
            .reconfigurable()
            .commit(),

            STRING_ELEMENT(schema)
            .key("frameSelector.arbiterSource")
            .displayedName("Arbiter source")
            .description(
                "Source name to pull the frame selection pattern from, must be part of "
                "matched sources."
            )
            .assignmentOptional()
            .defaultValue("")
            .reconfigurable()
            .commit(),

            STRING_ELEMENT(schema)
            .key("frameSelector.dataSourcePattern")
            .displayedName("Data source pattern")
            .description(
                "Source name pattern to apply frame selection to. Should match "
                "subset of matched sources."
            )
            .assignmentOptional()
            .defaultValue("")
            .reconfigurable()
            .commit(),

            VECTOR_STRING_ELEMENT(schema)
            .key("frameSelector.dataKeys")
            .displayedName("Data keys")
            .description("Keys in data sources to apply frame selection to.")
            .assignmentOptional()
            .defaultValue([])
            .reconfigurable()
            .commit(),
        )

    def __init__(self, config, device=None):
        """Set up the helper from the ``frameSelector`` configuration node.

        Args:
            config: configuration with the keys ``enable``, ``arbiterSource``,
                ``dataSourcePattern`` and ``dataKeys``.
            device: optional owner whose ``log.WARN`` receives warnings. When
                omitted, warnings go to the standard ``logging`` module.
        """
        self._config = Hash()
        self._device = device
        self._enabled = False
        self._arbiter = ""
        # Same attribute that reconfigure() sets and apply_mask() reads.
        self._source_pattern = re.compile("")
        self._data_keys = []
        self.reconfigure(config)

    def reconfigure(self, config):
        """Merge ``config`` into the stored settings and refresh cached values.

        ``config`` may be a partial update; keys it does not contain keep the
        value from an earlier call.
        """
        self._config.merge(config)
        self._enabled = self._config.get("enable")
        self._arbiter = self._config.get("arbiterSource")
        self._source_pattern = re.compile(self._config.get("dataSourcePattern"))
        self._data_keys = list(self._config.get("dataKeys"))

    def get_mask(self, sources):
        """Return the boolean frame mask for this train, or ``None``.

        Args:
            sources: mapping of source name to a ``(data, timestamp)`` pair.

        Returns:
            The arbiter's ``data.dataFramePattern`` as a boolean array, or
            ``None`` when frame selection is disabled or the arbiter source is
            not among ``sources``.
        """
        if not self._enabled or self._arbiter not in sources:
            return None
        pattern = sources[self._arbiter][0]["data.dataFramePattern"]
        # The builtin bool replaces the np.bool alias (removed in NumPy 1.24).
        # asarray copies only when needed, unlike np.array(..., copy=False),
        # which raises on NumPy 2 if a copy cannot be avoided.
        return np.asarray(pattern).astype(bool, copy=False)

    def apply_mask(self, source, data_hash, mask):
        """Filter the configured keys of ``data_hash`` in place with ``mask``.

        Nothing happens when ``mask`` is ``None`` or ``source`` does not match
        the configured pattern. Keys that are absent are skipped. Keys whose
        first dimension differs from the mask size are left untouched and a
        warning is emitted.
        """
        if mask is None or not self._source_pattern.match(source):
            return
        for key in self._data_keys:
            if not data_hash.has(key):
                continue
            data = data_hash[key]
            if data.shape[0] != mask.size:
                self._warn("Frame selection mask does not match the data size")
                continue
            data_hash[key] = data[mask]

    def _warn(self, message):
        """Emit ``message`` via the owning device's log, else via ``logging``."""
        if self._device is not None:
            self._device.log.WARN(message)
        else:
            import logging

            logging.getLogger(__name__).warning(message)
class GroupType(enum.Enum):
    """Kind of grouping a merge rule describes."""

    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key


class MergeMethod(enum.Enum):
    """How the grouped arrays are combined."""

    STACK = "stack"
    INTERLEAVE = "interleave"


def merge_schema():
    """Build and return the row schema of the ``merge`` table."""
    schema = Schema()
    (
        BOOL_ELEMENT(schema)
        .key("select")
        .displayedName("Select")
        .assignmentOptional()
        .defaultValue(False)
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("sourcePattern")
        .displayedName("Source pattern")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("keyPattern")
        .displayedName("Key pattern")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("replacement")
        .displayedName("Replacement")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("groupType")
        .displayedName("Group type")
        .options(",".join(option.value for option in GroupType))
        .assignmentOptional()
        .defaultValue(GroupType.MULTISOURCE.value)
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("mergeMethod")
        .displayedName("Merge method")
        .options(",".join(option.value for option in MergeMethod))
        .assignmentOptional()
        .defaultValue(MergeMethod.STACK.value)
        .reconfigurable()
        .commit(),

        INT32_ELEMENT(schema)
        .key("axis")
        .displayedName("Axis")
        .assignmentOptional()
        .defaultValue(0)
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("missingValue")
        .displayedName("Missing value default")
        .description(
            "If some sources are missing within one group in multi-source stacking*, "
            "the corresponding parts of the resulting stacked array will be set to "
            "this value. Note that if no sources are present, the array is not created "
            "at all. This field is a string to allow special values like float nan / "
            "inf; it is your responsibility to make sure that data types match (i.e. "
            "if this is 'nan', the stacked data better be floats or doubles). *Missing "
            "value handling is not yet implemented for multi-key stacking."
        )
        .assignmentOptional()
        .defaultValue("0")
        .reconfigurable()
        .commit(),
    )

    return schema


class StackingFriend:
    """Schema, configuration and logic for stacking / interleaving arrays.

    Two rule types are handled: combining one key across several sources into
    a new source, and combining several keys inside one source into a new key.
    """

    @staticmethod
    def add_schema(schema):
        """Append the ``merge`` table element to ``schema``."""
        (
            TABLE_ELEMENT(schema)
            .key("merge")
            .displayedName("Array stacking")
            .allowedStates(State.PASSIVE)
            .description(
                "Specify which source(s) or key(s) to stack or interleave. "
                "When stacking sources, the 'Source pattern' is interpreted as a "
                "regular expression and the 'Key pattern' is interpreted as an "
                "ordinary string. From all sources matching the source pattern, the "
                "data under this key (should be array with same dimensions across all "
                "stacked sources) is stacked in the same order as the sources are "
                "listed in 'Data sources' and the result is under the same key name in "
                "a new source named by 'Replacement'. "
                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
                "are regular expressions. Within each source matching the source "
                "pattern, all keys matching the key pattern are stacked and the result "
                "is put under the key named by 'Replacement'. "
                "While source stacking is optimized and can use thread pool, key "
                "stacking will iterate over all paths in matched sources and naively "
                "call np.stack for each key pattern. In either case, data that is used "
                "for stacking is removed from its original location (e.g. key is "
                "erased from hash). "
                "In both cases, the data can alternatively be interleaved. This is "
                "essentially equivalent to stacking except followed by a reshape such "
                "that the output is shaped like concatenation."
            )
            .setColumns(merge_schema())
            .assignmentOptional()
            .defaultValue([])
            .reconfigurable()
            .commit(),
        )

    def __init__(self, device, source_config, merge_config):
        """Create the helper and parse the initial configuration.

        Args:
            device: owner object; only its ``log.WARN`` is used here.
            source_config: rows of the sources table (each with a ``source``
                entry of the form ``name@...``).
            merge_config: rows of the ``merge`` table (see ``merge_schema``).
        """
        # used during pre-processing to set up buffers
        # (new source, key, method, group size, axis, missing) -> {original sources}
        self._new_sources_inputs = collections.defaultdict(set)

        # used for source stacking
        # (old source, new source, key) -> (index (maybe index_exp), axis)
        self._source_stacking_indices = {}
        # list of (old source, key, new source)
        self._source_stacking_parts = []

        # used for key stacking
        # list of entries used in _handle_key_stacking_entry
        self._key_stacking_entries = []

        self._merge_config = None
        self._source_config = None
        self._device = device
        self.reconfigure(source_config, merge_config)

    def reconfigure(self, source_config, merge_config):
        """Rebuild the stacking plan from the sources table and merge rules.

        Either argument may be ``None`` to keep the previously stored value. A
        table that was never provided is treated as empty.
        """
        if source_config is not None:
            self._source_config = source_config
        if merge_config is not None:
            self._merge_config = merge_config
        self._source_stacking_indices.clear()
        self._source_stacking_parts.clear()
        self._new_sources_inputs.clear()
        self._key_stacking_entries.clear()

        source_rows = () if self._source_config is None else self._source_config
        merge_rows = () if self._merge_config is None else self._merge_config

        # not filtering by row["select"] to allow unselected sources to create gaps
        source_names = [row["source"].partition("@")[0] for row in source_rows]

        for row in merge_rows:
            if not row["select"]:
                continue
            source_re = re.compile(row["sourcePattern"])
            group_type = GroupType(row["groupType"])
            merge_method = MergeMethod(row["mergeMethod"])
            axis = row["axis"]
            merge_sources = [
                source for source in source_names if source_re.match(source)
            ]
            if not merge_sources:
                self._device.log.WARN(
                    f"Group pattern {source_re} did not match any known sources"
                )
                continue
            if group_type is GroupType.MULTISOURCE:
                key = row["keyPattern"]
                new_source = row["replacement"]
                missing = row["missingValue"]
                group_size = len(merge_sources)
                for i, source in enumerate(merge_sources):
                    if merge_method is MergeMethod.STACK:
                        index = i
                    else:
                        # interleaving: every group_size-th entry, offset by i
                        index = np.index_exp[slice(i, None, group_size)]
                    self._source_stacking_indices[(source, new_source, key)] = (
                        index,
                        axis,
                    )
                    self._new_sources_inputs[
                        (new_source, key, merge_method, group_size, axis, missing)
                    ].add(source)
                    self._source_stacking_parts.append((source, key, new_source))
            else:
                key_re = re.compile(row["keyPattern"])
                new_key = row["replacement"]
                for source in merge_sources:
                    self._key_stacking_entries.append(
                        (source, key_re, new_key, merge_method, axis)
                    )

    def process(self, sources, thread_pool=None):
        """Apply all configured stacking rules to ``sources`` in place.

        Args:
            sources: mapping of source name to a ``(data, timestamp)`` pair;
                stacked inputs are erased from their hashes and new sources
                are added to this mapping.
            thread_pool: optional executor; when given, the individual copy
                steps are submitted to it and awaited before returning.
        """
        merge_buffers = {}
        new_source_map = {}
        missing_value_defaults = {}

        # prepare for source stacking where sources are present
        for (
            new_source,
            data_key,
            merge_method,
            group_size,
            axis,
            missing_value_str,
        ), original_sources in self._new_sources_inputs.items():
            for source in original_sources:
                if source not in sources:
                    continue
                data_hash, timestamp = sources[source]
                data = data_hash.get(data_key)
                if not isinstance(data, np.ndarray):
                    continue
                # didn't hit continue => first source in group present and with data
                if merge_method is MergeMethod.STACK:
                    expected_shape = utils.stacking_buffer_shape(
                        data.shape, group_size, axis=axis
                    )
                else:
                    expected_shape = utils.interleaving_buffer_shape(
                        data.shape, group_size, axis=axis
                    )
                merge_buffer = np.empty(expected_shape, dtype=data.dtype)
                merge_buffers[(new_source, data_key)] = merge_buffer
                if new_source not in new_source_map:
                    new_source_map[new_source] = (Hash(), timestamp)
                new_source_map[new_source][0][data_key] = merge_buffer
                try:
                    missing_value = data.dtype.type(missing_value_str)
                except ValueError:
                    self._device.log.WARN(
                        f"Invalid missing data value for {new_source}.{data_key} "
                        f"(tried making a {data.dtype} out of "
                        f"'{missing_value_str}'), using 0"
                    )
                    missing_value = 0
                missing_value_defaults[(new_source, data_key)] = missing_value
                # now we have set up the buffer for this new source, so break
                break
            else:
                # in this case: no source (if any) had data_key
                self._device.log.WARN(
                    f"No sources needed for {new_source}.{data_key} were present"
                )

        # now actually do some work
        if thread_pool is None:
            for thing in self._source_stacking_parts:
                self._handle_source_stacking_part(
                    merge_buffers, missing_value_defaults, sources, *thing
                )
            for thing in self._key_stacking_entries:
                self._handle_key_stacking_entry(sources, *thing)
        else:
            awaitables = [
                thread_pool.submit(
                    self._handle_source_stacking_part,
                    merge_buffers,
                    missing_value_defaults,
                    sources,
                    *thing,
                )
                for thing in self._source_stacking_parts
            ]
            awaitables.extend(
                thread_pool.submit(self._handle_key_stacking_entry, sources, *thing)
                for thing in self._key_stacking_entries
            )
            concurrent.futures.wait(awaitables)

        # note: new source names for group stacking may not match existing source names
        sources.update(new_source_map)

    def _handle_source_stacking_part(
        self,
        merge_buffers,
        missing_values,
        actual_sources,
        old_source,
        data_key,
        new_source,
    ):
        """Helper function used in processing. Note that it should be called for each
        (original source, data key, new source) triple expected in the stacking scheme
        regardless of whether that original source is present - so it can decide whether
        to move data in for stacking, fill missing data in case a source is missing,
        or skip in case no buffer was created (none of the necessary sources were
        present)."""

        if (new_source, data_key) not in merge_buffers:
            # preparatory step did not create buffer
            # (i.e. no sources from this group were present)
            return
        merge_buffer = merge_buffers[(new_source, data_key)]
        if old_source in actual_sources:
            data_hash = actual_sources[old_source][0]
        else:
            data_hash = None
        if data_hash is None or (merge_data := data_hash.get(data_key)) is None:
            merge_data = missing_values[(new_source, data_key)]
        merge_index, merge_axis = self._source_stacking_indices[
            (old_source, new_source, data_key)
        ]
        try:
            utils.set_on_axis(merge_buffer, merge_data, merge_index, merge_axis)
            if data_hash is not None:
                data_hash.erase(data_key)
        except Exception as ex:
            # best effort: report the problem and fill this part with the default
            self._device.log.WARN(
                f"Failed to stack {data_key} from {old_source} into {new_source}: {ex}"
            )
            utils.set_on_axis(
                merge_buffer,
                missing_values[(new_source, data_key)],
                merge_index,
                merge_axis,
            )

    def _handle_key_stacking_entry(
        self, sources, source, key_re, new_key, merge_method, axis
    ):
        """Combine all keys of ``source`` matching ``key_re`` into ``new_key``.

        On success the matched keys are erased from the source's hash and the
        combined array is stored under ``new_key``. If the source is absent,
        no key matches, or combining fails, a warning is logged and the hash
        is left unchanged.
        """
        # note: please no overlap between different key_re
        # note: if later key_re match earlier new_key, this gets spicy
        if source not in sources:
            self._device.log.WARN(f"Source {source} not found for key stacking")
            return
        data_hash = sources[source][0]
        keys = [key for key in data_hash.paths() if key_re.match(key)]
        if not keys:
            self._device.log.WARN(
                f"Source {source} had no keys matching {key_re} for stacking"
            )
            return
        try:
            # note: maybe we could reuse buffers here, too?
            if merge_method is MergeMethod.STACK:
                stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
            else:
                first = data_hash.get(keys[0])
                # Same buffer shape helper as the multi-source interleaving
                # path in process(); the strided index used below addresses
                # an interleaved axis, not an extra stacking axis.
                stacked = np.empty(
                    shape=utils.interleaving_buffer_shape(
                        first.shape, len(keys), axis=axis
                    ),
                    dtype=first.dtype,
                )
                for i, key in enumerate(keys):
                    utils.set_on_axis(
                        stacked,
                        data_hash.get(key),
                        np.index_exp[slice(i, None, len(keys))],
                        axis,
                    )
        except Exception as e:
            # this class has no log of its own; warnings go via the device
            self._device.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
        else:
            for key in keys:
                data_hash.erase(key)
            data_hash[new_key] = stacked
stacking_utils.StackingFriend(device, source_table, merge_rules) + + @pytest.fixture + def sources(self): + return { + f"source{i}": (Hash("keyToStack", data), None) + for i, data in enumerate(datas) + } + + def test_simple(self, friend, sources, thread_pool): + friend.process(sources, thread_pool) + assert "newSource" in sources + stacked = sources["newSource"][0]["keyToStack"] + assert stacked.shape == (4, 3, 2) + for i, data in enumerate(datas): + assert np.array_equal(stacked[:, i], data, equal_nan=True) + assert not friend._device.warnings + + def test_missing_source(self, friend, sources, thread_pool): + del sources["source0"] + friend.process(sources, thread_pool) + assert "newSource" in sources + stacked = sources["newSource"][0]["keyToStack"] + assert np.all(stacked[:, 0] == 0) + for i, data in enumerate(datas[1:], start=1): + assert np.array_equal(stacked[:, i], data, equal_nan=True) + assert not friend._device.warnings + + def test_missing_data(self, friend, sources, thread_pool): + sources["source0"][0].erase("keyToStack") + friend.process(sources, thread_pool) + assert "newSource" in sources + stacked = sources["newSource"][0]["keyToStack"] + assert np.all(stacked[:, 0] == 0) + for i, data in enumerate(datas[1:], start=1): + assert np.array_equal(stacked[:, i], data, equal_nan=True) + assert not friend._device.warnings + + def test_source_stacking_no_sources(self, friend, sources, thread_pool): + sources = {} + friend.process(sources, thread_pool) + assert friend._device.warnings + assert not sources + + def test_source_stacking_erroneous_data(self, friend, sources, thread_pool): + sources["source1"][0][ + "keyToStack" + ] = "and now for something completely different" + friend.process(sources, thread_pool) + assert "newSource" in sources + stacked = sources["newSource"][0]["keyToStack"] + for i, data in enumerate(datas): + if i == 1: + assert np.all(stacked[:, i] == 0) + else: + assert np.array_equal(stacked[:, i], data, equal_nan=True) + assert 
class TestKeyStacking(CommonTestFixtureGuy):
    """Key stacking: several keys within one source are merged into a new key."""

    @pytest.fixture
    def friend(self):
        """StackingFriend set up with one selected source and one key rule."""
        selected_sources = [dict(select=True, source="source@device:channel")]
        key_rule = dict(
            select=True,
            sourcePattern="source",
            keyPattern="key\\d+",
            replacement="newKey",
            groupType="keys",
            mergeMethod="stack",
            axis=1,
            missingValue="0",
        )
        return stacking_utils.StackingFriend(NotADevice(), selected_sources, [key_rule])

    @pytest.fixture
    def sources(self):
        """A single source whose hash holds one test array per key0, key1, ..."""
        data_hash = Hash()
        for index, array in enumerate(datas):
            data_hash[f"key{index}"] = array
        return {"source": (data_hash, None)}

    def test_simple(self, friend, sources, thread_pool):
        """All keys present: each array ends up at its own index along axis 1."""
        friend.process(sources, thread_pool)
        result_hash = sources["source"][0]
        assert result_hash.has("newKey")
        stacked = result_hash["newKey"]
        for index, expected in enumerate(datas):
            assert np.array_equal(stacked[:, index], expected, equal_nan=True)
        assert not friend._device.warnings