From 99de26dda08d02c1c2b5efc1f1f462cfc4f0ce63 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Fri, 3 Jun 2022 12:54:50 +0200
Subject: [PATCH] Squashed commit of the following:

commit beb32df59a994f132618320d0d30be6d6ed8b8d4
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 13 13:25:10 2022 +0200

    Fix reconfigure bug and prevent sources from being sorted

commit a79b501a1eb4b37e9798a71cbb1659ebeb014165
Author: David Hammer <dhammer@mailbox.org>
Date:   Wed May 11 12:35:09 2022 +0200

    Update description of stacking patterns

commit 426d707ba866e333fdd7359c45f3c4433a0df79c
Merge: 07b6c97 e1b0b8d
Author: David Hammer <dhammer@mailbox.org>
Date:   Mon May 9 12:50:30 2022 +0200

    Merge branch 'stacking-shmem-matcher' of
    ssh://git.xfel.eu:10022/karaboDevices/calng into stacking-shmem-matcher

commit 07b6c9726e54461e50d48f03fd9833021686ecd3
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 14:53:31 2022 +0200

    Adding option to skip Karabo channel output

commit 02a4ccd97667c07abd1a57eb1d0b50db74c99b83
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 13:27:59 2022 +0200

    Allow stacking on other axes, fix thread pool

commit b2b5aca1199f372e63b116100cbd4573134b737d
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 16:30:14 2022 +0200

    Allow thread pool usage

commit f8f380ccf1e554b2a9ee468809f3d7d82dadac6b
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 14:49:01 2022 +0200

    Use sources list to get source order, cache buffers

commit ef8982115cc714011cebfbff062bb02ee80a4c27
Author: David Hammer <dhammer@mailbox.org>
Date:   Tue May 3 15:49:21 2022 +0200

    Adding Philipp's proposed changes to TrainMatcher for merging

commit e1b0b8d18d5a2a87b51b4f0b9c1593ce0cd1906f
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 14:53:31 2022 +0200

    Adding option to skip Karabo channel output

commit bbb7cee4f11415d1469c343ddde21b91a14e9248
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 13:27:59 2022 +0200

    Allow stacking on other axes, fix thread pool

commit b01015c4527c8ef365b92154def6f92690b5cad1
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 16:30:14 2022 +0200

    Allow thread pool usage

commit e7f96bae6b55acf37769b0ab0a88ae9f6791f4c0
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 14:49:01 2022 +0200

    Use sources list to get source order, cache buffers

commit 193264ed1165d5e54c3d88a0d62f32a0f24dc085
Author: David Hammer <dhammer@mailbox.org>
Date:   Tue May 3 15:49:21 2022 +0200

    Adding Philipp's proposed changes to TrainMatcher for merging
---
 src/calng/ShmemTrainMatcher.py | 378 +++++++++++++++++++++++++++++++--
 src/calng/shmem_utils.py       |  12 ++
 src/calng/utils.py             |  25 +++
 3 files changed, 402 insertions(+), 13 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 3d481ccd..a3f919ee 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -1,26 +1,378 @@
-from karabo.bound import KARABO_CLASSINFO
+import concurrent.futures
+import enum
+import re
+
+import numpy as np
+from karabo.bound import (
+    BOOL_ELEMENT,
+    INT32_ELEMENT,
+    KARABO_CLASSINFO,
+    STRING_ELEMENT,
+    TABLE_ELEMENT,
+    ChannelMetaData,
+    Hash,
+    Schema,
+    State,
+    VectorHash,
+)
 from TrainMatcher import TrainMatcher
 
-from . import shmem_utils
+from . import shmem_utils, utils
 from ._version import version as deviceVersion
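+
+# Overview of the stacking flow implemented below (descriptive summary, not
+# present in the upstream TrainMatcher): on every matched train,
+# _handle_source runs once per source; it dereferences calng shmem handles and
+# then applies the configured merge groups. Source stacking copies one key
+# from many sources into a preallocated per-group buffer; key stacking calls
+# np.stack over all matching keys within a single source.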
+
+
+class MergeGroupType(enum.Enum):
+    MULTISOURCE = "sources"  # same key stacked from multiple sources into new source
+    MULTIKEY = "keys"  # multiple keys within each matched source stacked into new key
+
+
+def merge_schema():
+    schema = Schema()
+    (
+        BOOL_ELEMENT(schema)
+        .key("select")
+        .displayedName("Select")
+        .assignmentOptional()
+        .defaultValue(False)
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("sourcePattern")
+        .displayedName("Source pattern")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("keyPattern")
+        .displayedName("Key pattern")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("replacement")
+        .displayedName("Replacement")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("type")
+        .displayedName("Group type")
+        .options(",".join(option.value for option in MergeGroupType))
+        .assignmentOptional()
+        .defaultValue(MergeGroupType.MULTISOURCE.value)
+        .reconfigurable()
+        .commit(),
+
+        INT32_ELEMENT(schema)
+        .key("stackingAxis")
+        .displayedName("Axis")
+        .assignmentOptional()
+        .defaultValue(0)
+        .reconfigurable()
+        .commit(),
+    )
+
+    return schema
+
+
 @KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
 class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
+    @staticmethod
+    def expectedParameters(expected):
+        (
+            TABLE_ELEMENT(expected)
+            .key("merge")
+            .displayedName("Array stacking")
+            .allowedStates(State.PASSIVE)
+            .description(
+                "Specify which source(s) or key(s) to stack. "
+                "When stacking sources, the 'Source pattern' is interpreted as a "
+                "regular expression and the 'Key pattern' is interpreted as an "
+                "ordinary string. From all sources matching the source pattern, "
+                "the data under this key (which should be an array with the same "
+                "dimensions across all stacked sources) is stacked in the same "
+                "order as the sources are listed in 'Data sources' and the "
+                "result is put under the same key name in a new source named by "
+                "'Replacement'. "
+                "When stacking keys, both the 'Source pattern' and the 'Key "
+                "pattern' are regular expressions. Within each source matching "
+                "the source pattern, all keys matching the key pattern are "
+                "stacked and the result is put under the key named by "
+                "'Replacement'. "
+                "While source stacking is optimized and can use the thread pool, "
+                "key stacking iterates over all paths in the matched sources and "
+                "naively calls np.stack for each key pattern. In either case, "
+                "data used for stacking is removed from its original location "
+                "(e.g. the key is erased from its source's hash)."
+            )
+            .setColumns(merge_schema())
+            .assignmentOptional()
+            .defaultValue([])
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("useThreadPool")
+            .displayedName("Use thread pool")
+            .allowedStates(State.PASSIVE)
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("enableKaraboOutput")
+            .displayedName("Enable Karabo channel")
+            .allowedStates(State.PASSIVE)
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
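+
+    # Illustrative source-stacking row (the detector and replacement names are
+    # hypothetical, not part of the schema): select=True,
+    # sourcePattern="MYDET/DET/(\d+)CH0:xtdf", keyPattern="image.data",
+    # replacement="MYDET/DET/STACKED:xtdf", type="sources", stackingAxis=0.
+    # With MYDET/DET/0CH0:xtdf and MYDET/DET/1CH0:xtdf in 'Data sources', each
+    # train's two image.data arrays are stacked along a new axis 0 and appear
+    # as image.data of the new source MYDET/DET/STACKED:xtdf.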
+
     def initialization(self):
+        self._stacking_buffers = {}
+        self._source_stacking_indices = {}
+        self._source_stacking_sources = {}
+        self._key_stacking_sources = {}
+        self._have_prepared_merge_groups = False
+        self._prepare_merge_groups()
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
+        if self.get("useThreadPool"):
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        else:
+            self._thread_pool = None
+
+        if not self.get("enableKaraboOutput"):
+            # it is set already by super by default, so only need to turn off
+            self.output = None
+
+    def preReconfigure(self, conf):
+        super().preReconfigure(conf)
+        if conf.has("merge") or conf.has("sources"):
+            self._have_prepared_merge_groups = False
+            # re-prepare in postReconfigure after sources *and* merge are in self
+        if conf.has("useThreadPool"):
+            if self._thread_pool is not None:
+                self._thread_pool.shutdown()
+                self._thread_pool = None
+            if conf["useThreadPool"]:
+                self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        if conf.has("enableKaraboOutput"):
+            if conf["enableKaraboOutput"]:
+                self.output = self._ss.getOutputChannel("output")
+            else:
+                self.output = None
+
+    def postReconfigure(self):
+        super().postReconfigure()
+        if not self._have_prepared_merge_groups:
+            self._prepare_merge_groups()
+
+    def _prepare_merge_groups(self):
+        source_group_patterns = []
+        key_group_patterns = []
+        # split by type, prepare regexes
+        for row in self.get("merge"):
+            if not row["select"]:
+                continue
+            group_type = MergeGroupType(row["type"])
+            if group_type is MergeGroupType.MULTISOURCE:
+                source_group_patterns.append(
+                    (
+                        re.compile(row["sourcePattern"]),
+                        row["keyPattern"],
+                        row["replacement"],
+                        row["stackingAxis"],
+                    )
+                )
+            else:
+                key_group_patterns.append(
+                    (
+                        re.compile(row["sourcePattern"]),
+                        re.compile(row["keyPattern"]),
+                        row["replacement"],
+                    )
+                )
+
+        # not filtering by row["select"] to allow unselected sources to create gaps
+        source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
+
+        self._stacking_buffers.clear()
+        # handle source stacking groups
+        self._source_stacking_indices.clear()
+        self._source_stacking_sources.clear()
+        for source_re, key, new_source, stack_axis in source_group_patterns:
+            merge_sources = [
+                source for source in source_names if source_re.match(source)
+            ]
+            if len(merge_sources) == 0:
+                self.log.WARN(
+                    f"Group merge pattern {source_re} did not match any known sources"
+                )
+                continue
+            for i, source in enumerate(merge_sources):
+                self._source_stacking_sources.setdefault(source, []).append(
+                    (key, new_source, stack_axis)
+                )
+                self._source_stacking_indices[(source, new_source, key)] = i
+
+        # handle key stacking groups
+        self._key_stacking_sources.clear()
+        for source_re, key_re, new_key in key_group_patterns:
+            for source in source_names:
+                # TODO: maybe also warn if no matches here?
+                if not source_re.match(source):
+                    continue
+                self._key_stacking_sources.setdefault(source, []).append(
+                    (key_re, new_key)
+                )
+        self._have_prepared_merge_groups = True
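+
+    # Shape sketch (per stacking_buffer_shape in utils.py below): stacking 16
+    # sources of shape (512, 1024) on axis 0 needs a (16, 512, 1024) buffer;
+    # on axis 1, (512, 16, 1024).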
+
+    def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
+        # TODO: handle ValueError for max of empty sequence
+        stack_num = (
+            max(
+                index
+                for (
+                    _,
+                    new_source_,
+                    key_,
+                ), index in self._source_stacking_indices.items()
+                if new_source_ == new_source and key_ == key
+            )
+            + 1
+        )
+        self._stacking_buffers[(new_source, key)] = np.empty(
+            shape=utils.stacking_buffer_shape(individual_shape, stack_num, axis=axis),
+            dtype=dtype,
+        )
+
+    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
+        # dereference calng shmem handles
+        self._shmem_handler.dereference_shmem_handles(data_hash)
+
+        # stack across sources (many sources, same key)
+        # could probably save ~100 ns with "if ... in" instead of .get
+        for (stack_key, new_source, stack_axis) in self._source_stacking_sources.get(
+            source, ()
+        ):
+            this_data = data_hash.get(stack_key)
+            try:
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[
+                    (source, new_source, stack_key)
+                ]
+                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
+            except (ValueError, IndexError, KeyError):
+                # ValueError: wrong shape (react to this_data.shape)
+                # KeyError: buffer doesn't exist yet
+                # IndexError: new source? (buffer not long enough)
+                # either way, create an appropriate buffer now
+                # TODO: complain if shape varies between sources within a train
+                self._update_stacking_buffer(
+                    new_source,
+                    stack_key,
+                    this_data.shape,
+                    axis=stack_axis,
+                    dtype=this_data.dtype,
+                )
+                # and then try again
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[
+                    (source, new_source, stack_key)
+                ]
+                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
+            # TODO: zero out unfilled buffer entries
+            data_hash.erase(stack_key)
+
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(stack_key):
+                new_source_hash[stack_key] = this_buffer
+
+        # stack keys (multiple keys within this source)
+        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
+            # note: please no overlap between different key_re
+            # note: if a later key_re matches an earlier new_key, this gets spicy
+            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
+            try:
+                # TODO: consider reusing buffers here, too
+                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
+            except Exception as e:
+                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+            else:
+                for key in stack_keys:
+                    data_hash.erase(key)
+                data_hash[new_key] = stacked
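+
+    # Concurrency sketch: _handle_source may run once per source in the thread
+    # pool used below. Each source fills its own index along the stacking
+    # axis, so buffer writes do not overlap; updates to the shared
+    # new_sources_map dict are assumed safe under CPython's GIL.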
+
     def on_matched_data(self, train_id, sources):
+        new_sources_map = {}
+        if self._thread_pool is None:
+            for source, (data, timestamp) in sources.items():
+                self._handle_source(source, data, timestamp, new_sources_map)
+        else:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source, source, data, timestamp, new_sources_map
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
+        sources.update(new_sources_map)
+
+        # karabo output
+        if self.output is not None:
+            for source, (data, timestamp) in sources.items():
+                self.output.write(data, ChannelMetaData(source, timestamp))
+            self.output.update()
+
+        # karabo bridge output
         for source, (data, timestamp) in sources.items():
-            if data.has("calngShmemPaths"):
-                shmem_paths = list(data["calngShmemPaths"])
-                data.erase("calngShmemPaths")
-                for shmem_path in shmem_paths:
-                    if not data.has(shmem_path):
-                        self.log.INFO(f"Hash from {source} did not have {shmem_path}")
-                        continue
-                    dereferenced = self._shmem_handler.get(data[shmem_path])
-                    data[shmem_path] = dereferenced
-
-        super().on_matched_data(train_id, sources)
+            self.zmq_output.write(source, data, timestamp)
+        self.zmq_output.update()
+
+        self.info["sent"] += 1
+        self.info["trainId"] = train_id
+        self.rate_out.update()
+
+    def _maybe_connect_data(self, conf, update=False, state=None):
+        """Temporary override of _maybe_connect_data to avoid sorting the
+        sources list (we need its order for stacking)"""
+        if self["state"] not in (State.CHANGING, State.ACTIVE):
+            return
+
+        last_state = self["state"]
+        self.updateState(State.CHANGING)
+
+        # unwatch removed sources
+        def src_names(c):
+            # do not assign a lambda expression, use a def
+            return {s["source"] for s in c["sources"]}
+
+        for source in src_names(self).difference(src_names(conf)):
+            self.monitor.unwatch_source(source)
+
+        new_conf = VectorHash()
+        for src in conf["sources"]:
+            source = src["source"]
+            if src["select"]:
+                src["status"] = self.monitor.watch_source(source, src["offset"])
+            else:
+                self.monitor.unwatch_source(source)
+                src["status"] = ""
+            new_conf.append(src)
+
+        if update:
+            self.set("sources", new_conf)
+        else:
+            conf["sources"] = new_conf
+        self.updateState(state or last_state)
diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py
index 4c4838e2..e02e3fd0 100644
--- a/src/calng/shmem_utils.py
+++ b/src/calng/shmem_utils.py
@@ -46,6 +46,18 @@ class ShmemCircularBufferReceiver:
 
         return ary[index]
 
+    def dereference_shmem_handles(self, data_hash):
+        if data_hash.has("calngShmemPaths"):
+            shmem_paths = list(data_hash["calngShmemPaths"])
+            data_hash.erase("calngShmemPaths")
+            for shmem_path in shmem_paths:
+                if not data_hash.has(shmem_path):
+                    # TODO: proper warnings
+                    print(f"Warning: hash did not contain {shmem_path}")
+                    continue
+                dereferenced = self.get(data_hash[shmem_path])
+                data_hash[shmem_path] = dereferenced
+
 
 class ShmemCircularBuffer:
     """Convenience wrapper around posixshmem-backed ndarray buffers
diff --git a/src/calng/utils.py b/src/calng/utils.py
index 9737de73..54ed40fe 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
     return tuple(axis_order[axis] for axis in axes_out)
 
 
+def stacking_buffer_shape(array_shape, stack_num, axis=0):
+    """Figures out the shape you would need for np.stack"""
+    if axis > len(array_shape) or axis < -len(array_shape) - 1:
+        # complain exactly when np.stack would
+        raise np.AxisError(
+            f"axis {axis} is out of bounds "
+            f"for array of dimension {len(array_shape) + 1}"
+        )
+    if axis < 0:
+        axis += len(array_shape) + 1
+    return array_shape[:axis] + (stack_num,) + array_shape[axis:]
+
+
+def set_on_axis(array, vals, index, axis):
+    """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x"""
+    if axis >= len(array.shape):
+        raise IndexError(
+            f"too many indices for array: array is {len(array.shape)}-dimensional, "
+            f"but {axis+1} were indexed"
+        )
+    # TODO: maybe support negative axis with wraparound
+    indices = np.index_exp[:] * axis + np.index_exp[index]
+    array[indices] = vals
+
+
 _np_typechar_to_c_typestring = {
     "?": "bool",
     "B": "unsigned char",
-- 
GitLab