From 99de26dda08d02c1c2b5efc1f1f462cfc4f0ce63 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Fri, 3 Jun 2022 12:54:50 +0200
Subject: [PATCH] Squashed commit of the following:

commit beb32df59a994f132618320d0d30be6d6ed8b8d4
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 13 13:25:10 2022 +0200

    Fix reconfigure bug and prevent sources from being sorted

commit a79b501a1eb4b37e9798a71cbb1659ebeb014165
Author: David Hammer <dhammer@mailbox.org>
Date:   Wed May 11 12:35:09 2022 +0200

    Update description of stacking patterns

commit 426d707ba866e333fdd7359c45f3c4433a0df79c
Merge: 07b6c97 e1b0b8d
Author: David Hammer <dhammer@mailbox.org>
Date:   Mon May 9 12:50:30 2022 +0200

    Merge branch 'stacking-shmem-matcher' of ssh://git.xfel.eu:10022/karaboDevices/calng into stacking-shmem-matcher

commit 07b6c9726e54461e50d48f03fd9833021686ecd3
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 14:53:31 2022 +0200

    Adding option to skip Karabo channel output

commit 02a4ccd97667c07abd1a57eb1d0b50db74c99b83
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 13:27:59 2022 +0200

    Allow stacking on other axes, fix thread pool

commit b2b5aca1199f372e63b116100cbd4573134b737d
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 16:30:14 2022 +0200

    Allow thread pool usage

commit f8f380ccf1e554b2a9ee468809f3d7d82dadac6b
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 14:49:01 2022 +0200

    Use sources list to get source order, cache buffers

commit ef8982115cc714011cebfbff062bb02ee80a4c27
Author: David Hammer <dhammer@mailbox.org>
Date:   Tue May 3 15:49:21 2022 +0200

    Adding Philipp's proposed changes to TrainMatcher for merging

commit e1b0b8d18d5a2a87b51b4f0b9c1593ce0cd1906f
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 14:53:31 2022 +0200

    Adding option to skip Karabo channel output

commit bbb7cee4f11415d1469c343ddde21b91a14e9248
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 13:27:59 2022 +0200

    Allow stacking on other axes, fix thread pool

commit b01015c4527c8ef365b92154def6f92690b5cad1
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 16:30:14 2022 +0200

    Allow thread pool usage

commit e7f96bae6b55acf37769b0ab0a88ae9f6791f4c0
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 14:49:01 2022 +0200

    Use sources list to get source order, cache buffers

commit 193264ed1165d5e54c3d88a0d62f32a0f24dc085
Author: David Hammer <dhammer@mailbox.org>
Date:   Tue May 3 15:49:21 2022 +0200

    Adding Philipp's proposed changes to TrainMatcher for merging
---
 src/calng/ShmemTrainMatcher.py | 378 +++++++++++++++++++++++++++++++--
 src/calng/shmem_utils.py       |  12 ++
 src/calng/utils.py             |  25 +++
 3 files changed, 402 insertions(+), 13 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 3d481ccd..a3f919ee 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -1,26 +1,378 @@
-from karabo.bound import KARABO_CLASSINFO
+import concurrent.futures
+import enum
+import re
+
+import numpy as np
+from karabo.bound import (
+    BOOL_ELEMENT,
+    INT32_ELEMENT,
+    KARABO_CLASSINFO,
+    STRING_ELEMENT,
+    TABLE_ELEMENT,
+    ChannelMetaData,
+    Hash,
+    Schema,
+    State,
+    VectorHash,
+)
 from TrainMatcher import TrainMatcher
 
-from . import shmem_utils
+from . import shmem_utils, utils
 from ._version import version as deviceVersion
 
 
+class MergeGroupType(enum.Enum):
+    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
+    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key
+
+
+def merge_schema():
+    schema = Schema()
+    (
+        BOOL_ELEMENT(schema)
+        .key("select")
+        .displayedName("Select")
+        .assignmentOptional()
+        .defaultValue(False)
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("sourcePattern")
+        .displayedName("Source pattern")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("keyPattern")
+        .displayedName("Key pattern")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("replacement")
+        .displayedName("Replacement")
+        .assignmentOptional()
+        .defaultValue("")
+        .reconfigurable()
+        .commit(),
+
+        STRING_ELEMENT(schema)
+        .key("type")
+        .displayedName("Group type")
+        .options(",".join(option.value for option in MergeGroupType))
+        .assignmentOptional()
+        .defaultValue(MergeGroupType.MULTISOURCE.value)
+        .reconfigurable()
+        .commit(),
+
+        INT32_ELEMENT(schema)
+        .key("stackingAxis")
+        .displayedName("Axis")
+        .assignmentOptional()
+        .defaultValue(0)
+        .reconfigurable()
+        .commit(),
+    )
+
+    return schema
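+
+# Hypothetical example row (names are illustrative only): with select=True,
+# sourcePattern="MID_DET_AGIPD1M-1/DET/[0-9]+CH0:xtdf", keyPattern="image.data",
+# replacement="MID_DET_AGIPD1M-1/DET/STACKED:xtdf", type="sources", and
+# stackingAxis=0, the "image.data" arrays of all matching sources are stacked
+# along a new axis 0 and sent as the new source
+# "MID_DET_AGIPD1M-1/DET/STACKED:xtdf".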
+
+
 @KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
 class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
+    @staticmethod
+    def expectedParameters(expected):
+        (
+            TABLE_ELEMENT(expected)
+            .key("merge")
+            .displayedName("Array stacking")
+            .allowedStates(State.PASSIVE)
+            .description(
+                "Specify which source(s) or key(s) to stack. "
+                "When stacking sources, the 'Source pattern' is interpreted as a "
+                "regular expression and the 'Key pattern' is interpreted as an "
+                "ordinary string. From all sources matching the source pattern, the "
+                "data under this key (which should be an array with the same "
+                "dimensions across all stacked sources) is stacked in the order in "
+                "which the sources are listed in 'Data sources', and the result is "
+                "put under the same key name in a new source named by 'Replacement'. "
+                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
+                "are regular expressions. Within each source matching the source "
+                "pattern, all keys matching the key pattern are stacked and the "
+                "result is put under the key named by 'Replacement'. "
+                "While source stacking is optimized and can use the thread pool, key "
+                "stacking iterates over all paths in the matched sources and naively "
+                "calls np.stack for each key pattern. In either case, data used for "
+                "stacking is removed from its original location (e.g. the key is "
+                "erased from its source's hash)."
+            )
+            .setColumns(merge_schema())
+            .assignmentOptional()
+            .defaultValue([])
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("useThreadPool")
+            .displayedName("Use thread pool")
+            .allowedStates(State.PASSIVE)
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("enableKaraboOutput")
+            .displayedName("Enable Karabo channel")
+            .allowedStates(State.PASSIVE)
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
+
     def initialization(self):
+        self._stacking_buffers = {}
+        self._source_stacking_indices = {}
+        self._source_stacking_sources = {}
+        self._key_stacking_sources = {}
+        self._have_prepared_merge_groups = False
+        self._prepare_merge_groups()
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
+        if self.get("useThreadPool"):
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        else:
+            self._thread_pool = None
+
+        if not self.get("enableKaraboOutput"):
+            # the superclass sets up the output channel by default, so we
+            # only need to turn it off here
+            self.output = None
+
+    def preReconfigure(self, conf):
+        super().preReconfigure(conf)
+        if conf.has("merge") or conf.has("sources"):
+            self._have_prepared_merge_groups = False
+            # re-prepare in postReconfigure after sources *and* merge are in self
+        if conf.has("useThreadPool"):
+            if self._thread_pool is not None:
+                self._thread_pool.shutdown()
+                self._thread_pool = None
+            if conf["useThreadPool"]:
+                self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        if conf.has("enableKaraboOutput"):
+            if conf["enableKaraboOutput"]:
+                self.output = self._ss.getOutputChannel("output")
+            else:
+                self.output = None
+
+    def postReconfigure(self):
+        super().postReconfigure()
+        if not self._have_prepared_merge_groups:
+            self._prepare_merge_groups()
+
+    def _prepare_merge_groups(self):
+        source_group_patterns = []
+        key_group_patterns = []
+        # split by type, prepare regexes
+        for row in self.get("merge"):
+            if not row["select"]:
+                continue
+            group_type = MergeGroupType(row["type"])
+            if group_type is MergeGroupType.MULTISOURCE:
+                source_group_patterns.append(
+                    (
+                        re.compile(row["sourcePattern"]),
+                        row["keyPattern"],
+                        row["replacement"],
+                        row["stackingAxis"],
+                    )
+                )
+            else:
+                key_group_patterns.append(
+                    (
+                        re.compile(row["sourcePattern"]),
+                        re.compile(row["keyPattern"]),
+                        row["replacement"],
+                    )
+                )
+
+        # intentionally not filtering by row["select"]: unselected sources keep
+        # their position in the list, leaving gaps in the stacking order
+        source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
+
+        self._stacking_buffers.clear()
+        # handle source stacking groups
+        self._source_stacking_indices.clear()
+        self._source_stacking_sources.clear()
+        for source_re, key, new_source, stack_axis in source_group_patterns:
+            merge_sources = [
+                source for source in source_names if source_re.match(source)
+            ]
+            if len(merge_sources) == 0:
+                self.log.WARN(
+                    f"Group merge pattern {source_re} did not match any known sources"
+                )
+                continue
+            for (i, source) in enumerate(merge_sources):
+                self._source_stacking_sources.setdefault(source, []).append(
+                    (key, new_source, stack_axis)
+                )
+                self._source_stacking_indices[(source, new_source, key)] = i
+
+        # handle key stacking groups
+        self._key_stacking_sources.clear()
+        for source_re, key_re, new_key in key_group_patterns:
+            for source in source_names:
+                # TODO: maybe also warn if no matches here?
+                if not source_re.match(source):
+                    continue
+                self._key_stacking_sources.setdefault(source, []).append(
+                    (key_re, new_key)
+                )
+        self._have_prepared_merge_groups = True
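+
+        # illustration with the hypothetical example row above the schema:
+        # each matching source gets an entry
+        #   _source_stacking_sources[source] == [("image.data", new_source, 0)]
+        # and its slot in the stack is
+        #   _source_stacking_indices[(source, new_source, "image.data")]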
+
+    def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
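+        # (re)allocate the buffer for (new_source, key): the stack size is one
+        # more than the highest index registered for this group, and the buffer
+        # shape is individual_shape with the stack axis inserted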
+        # TODO: handle ValueError for max of empty sequence
+        stack_num = (
+            max(
+                index
+                for (
+                    _,
+                    new_source_,
+                    key_,
+                ), index in self._source_stacking_indices.items()
+                if new_source_ == new_source and key_ == key
+            )
+            + 1
+        )
+        self._stacking_buffers[(new_source, key)] = np.empty(
+            shape=utils.stacking_buffer_shape(individual_shape, stack_num, axis=axis),
+            dtype=dtype,
+        )
+
+    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
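+        # handles a single matched source; this may run on the thread pool,
+        # with new_sources_map shared between workers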
+        # dereference calng shmem handles
+        self._shmem_handler.dereference_shmem_handles(data_hash)
+
+        # stack across sources (many sources, same key)
+        # could probably save ~100 ns by "if ... in" instead of get
+        for (stack_key, new_source, stack_axis) in self._source_stacking_sources.get(
+            source, ()
+        ):
+            this_data = data_hash.get(stack_key)
+            try:
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[
+                    (source, new_source, stack_key)
+                ]
+                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
+            except (ValueError, IndexError, KeyError):
+                # ValueError: this_data has the wrong shape for the buffer
+                # IndexError: buffer too short along stack axis (new source?)
+                # KeyError: buffer does not exist yet
+                # in any case, (re)create an appropriately sized buffer now
+                # TODO: complain if shape varies between sources within train
+                self._update_stacking_buffer(
+                    new_source,
+                    stack_key,
+                    this_data.shape,
+                    axis=stack_axis,
+                    dtype=this_data.dtype,
+                )
+                # and then try again
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[
+                    (source, new_source, stack_key)
+                ]
+                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
+            # TODO: zero out unfilled buffer entries
+            data_hash.erase(stack_key)
+
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(stack_key):
+                new_source_hash[stack_key] = this_buffer
+
+        # stack keys (multiple keys within this source)
+        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
+            # note: key_re patterns should not overlap with each other
+            # note: if a later key_re matches an earlier new_key, behavior is undefined
+            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
+            try:
+                # TODO: consider reusing buffers here, too
+                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
+            except Exception as e:
+                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+            else:
+                for key in stack_keys:
+                    data_hash.erase(key)
+                data_hash[new_key] = stacked
 
     def on_matched_data(self, train_id, sources):
+        new_sources_map = {}
+        if self._thread_pool is None:
+            for source, (data, timestamp) in sources.items():
+                self._handle_source(source, data, timestamp, new_sources_map)
+        else:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source, source, data, timestamp, new_sources_map
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
+        sources.update(new_sources_map)
+
+        # karabo output
+        if self.output is not None:
+            for source, (data, timestamp) in sources.items():
+                self.output.write(data, ChannelMetaData(source, timestamp))
+            self.output.update()
+
+        # karabo bridge output
         for source, (data, timestamp) in sources.items():
-            if data.has("calngShmemPaths"):
-                shmem_paths = list(data["calngShmemPaths"])
-                data.erase("calngShmemPaths")
-                for shmem_path in shmem_paths:
-                    if not data.has(shmem_path):
-                        self.log.INFO(f"Hash from {source} did not have {shmem_path}")
-                        continue
-                    dereferenced = self._shmem_handler.get(data[shmem_path])
-                    data[shmem_path] = dereferenced
-
-        super().on_matched_data(train_id, sources)
+            self.zmq_output.write(source, data, timestamp)
+        self.zmq_output.update()
+
+        self.info["sent"] += 1
+        self.info["trainId"] = train_id
+        self.rate_out.update()
+
+    def _maybe_connect_data(self, conf, update=False, state=None):
+        """Temporary override on _maybe_connect_data to avoid sorting sources list (we
+        need it for stacking order)"""
+        if self["state"] not in (State.CHANGING, State.ACTIVE):
+            return
+
+        last_state = self["state"]
+        self.updateState(State.CHANGING)
+
+        # unwatch removed sources
+        def src_names(c):
+            # a def instead of an assigned lambda expression (flake8 E731)
+            return {s["source"] for s in c["sources"]}
+
+        for source in src_names(self).difference(src_names(conf)):
+            self.monitor.unwatch_source(source)
+
+        new_conf = VectorHash()
+        for src in conf["sources"]:
+            source = src["source"]
+            if src["select"]:
+                src["status"] = self.monitor.watch_source(source, src["offset"])
+            else:
+                self.monitor.unwatch_source(source)
+                src["status"] = ""
+            new_conf.append(src)
+
+        if update:
+            self.set("sources", new_conf)
+        else:
+            conf["sources"] = new_conf
+        self.updateState(state or last_state)
diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py
index 4c4838e2..e02e3fd0 100644
--- a/src/calng/shmem_utils.py
+++ b/src/calng/shmem_utils.py
@@ -46,6 +46,18 @@ class ShmemCircularBufferReceiver:
 
         return ary[index]
 
+    def dereference_shmem_handles(self, data_hash):
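+        # "calngShmemPaths" lists the keys under which shmem handles were
+        # sent; replace each handle with the ndarray it references and erase
+        # the bookkeeping key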
+        if data_hash.has("calngShmemPaths"):
+            shmem_paths = list(data_hash["calngShmemPaths"])
+            data_hash.erase("calngShmemPaths")
+            for shmem_path in shmem_paths:
+                if not data_hash.has(shmem_path):
+                    # TODO: proper warnings
+                    print(f"Warning: hash did not contain {shmem_path}")
+                    continue
+                dereferenced = self.get(data_hash[shmem_path])
+                data_hash[shmem_path] = dereferenced
+
 
 class ShmemCircularBuffer:
     """Convenience wrapper around posixshmem-backed ndarray buffers
diff --git a/src/calng/utils.py b/src/calng/utils.py
index 9737de73..54ed40fe 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
     return tuple(axis_order[axis] for axis in axes_out)
 
 
+def stacking_buffer_shape(array_shape, stack_num, axis=0):
+    """Figures out the shape you would need for np.stack"""
+    if axis > len(array_shape) or axis < -len(array_shape) - 1:
+        # complain when np.stack would
+        raise np.AxisError(
+            f"axis {axis} is out of bounds "
+            f"for array of dimension {len(array_shape) + 1}"
+        )
+    if axis < 0:
+        axis += len(array_shape) + 1
+    return array_shape[:axis] + (stack_num,) + array_shape[axis:]
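+
+# for instance (illustrative shapes only, mirroring np.stack semantics):
+#   stacking_buffer_shape((512, 128), 16, axis=0) == (16, 512, 128)
+#   stacking_buffer_shape((512, 128), 16, axis=-1) == (512, 128, 16)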
+
+
+def set_on_axis(array, vals, index, axis):
+    """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x"""
+    if axis >= array.ndim:
+        raise IndexError(
+            f"too many indices for array: array is {array.ndim}-dimensional, "
+            f"but {axis+1} were indexed"
+        )
+    # TODO: maybe support negative axis with wraparound
+    indices = np.index_exp[:] * axis + np.index_exp[index]
+    array[indices] = vals
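+
+# together, these helpers implement source stacking, roughly (hypothetical
+# names and shapes, illustration only):
+#   buf = np.empty(stacking_buffer_shape((512, 128), 16, axis=0), np.float32)
+#   set_on_axis(buf, module_data, 3, 0)  # equivalent to buf[3] = module_data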
+
+
 _np_typechar_to_c_typestring = {
     "?": "bool",
     "B": "unsigned char",
-- 
GitLab