From bbb7cee4f11415d1469c343ddde21b91a14e9248 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Fri, 6 May 2022 13:27:59 +0200
Subject: [PATCH] Allow stacking on other axes, fix thread pool

---
 src/calng/ShmemTrainMatcher.py | 103 ++++++++++++++++++++++-----------
 src/calng/utils.py             |  25 ++++++++
 2 files changed, 95 insertions(+), 33 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 788f328b..780be449 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -5,6 +5,7 @@ import re
 import numpy as np
 from karabo.bound import (
     BOOL_ELEMENT,
+    INT32_ELEMENT,
     KARABO_CLASSINFO,
     STRING_ELEMENT,
     TABLE_ELEMENT,
@@ -14,7 +15,7 @@ from karabo.bound import (
 )
 from TrainMatcher import TrainMatcher
 
-from . import shmem_utils
+from . import shmem_utils, utils
 from ._version import version as deviceVersion
 
 
@@ -35,7 +36,7 @@ def merge_schema():
         .commit(),
 
         STRING_ELEMENT(schema)
-        .key("source_pattern")
+        .key("sourcePattern")
         .displayedName("Source pattern")
         .assignmentOptional()
         .defaultValue("")
@@ -43,7 +44,7 @@ def merge_schema():
         .commit(),
 
         STRING_ELEMENT(schema)
-        .key("key_pattern")
+        .key("keyPattern")
         .displayedName("Key pattern")
         .assignmentOptional()
         .defaultValue("")
@@ -66,6 +67,14 @@ def merge_schema():
         .defaultValue(MergeGroupType.MULTISOURCE.value)
         .reconfigurable()
         .commit(),
+
+        INT32_ELEMENT(schema)
+        .key("stackingAxis")
+        .displayedName("Axis")
+        .assignmentOptional()
+        .defaultValue(0)
+        .reconfigurable()
+        .commit(),
     )
 
     return schema
@@ -78,7 +87,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         (
             TABLE_ELEMENT(expected)
             .key("merge")
-            .displayedName("Array merging")
+            .displayedName("Array stacking")
             .allowedStates(State.PASSIVE)
             .description(
                 "List source or key patterns to merge their data arrays, e.g. to "
@@ -95,8 +104,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
 
             BOOL_ELEMENT(expected)
             .key("useThreadPool")
+            .displayedName("Use thread pool")
+            .allowedStates(State.PASSIVE)
             .assignmentOptional()
             .defaultValue(False)
+            .reconfigurable()
             .commit(),
         )
 
@@ -117,6 +129,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         super().preReconfigure(conf)
         if conf.has("merge") or conf.has("sources"):
             self._prepare_merge_groups(conf["merge"])
+        if conf.has("useThreadPool"):
+            if self._thread_pool is not None:
+                self._thread_pool.shutdown()
+                self._thread_pool = None
+            if conf["useThreadPool"]:
+                self._thread_pool = concurrent.futures.ThreadPoolExecutor()
 
     def _prepare_merge_groups(self, merge):
         source_group_patterns = []
@@ -129,16 +147,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             if group_type is MergeGroupType.MULTISOURCE:
                 source_group_patterns.append(
                     (
-                        re.compile(row["source_pattern"]),
-                        row["key_pattern"],
+                        re.compile(row["sourcePattern"]),
+                        row["keyPattern"],
                         row["replacement"],
+                        row["stackingAxis"],
                     )
                 )
             else:
                 key_group_patterns.append(
                     (
-                        re.compile(row["source_pattern"]),
-                        re.compile(row["key_pattern"]),
+                        re.compile(row["sourcePattern"]),
+                        re.compile(row["keyPattern"]),
                         row["replacement"],
                     )
                 )
@@ -150,7 +169,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         # handle source stacking groups
         self._source_stacking_indices.clear()
         self._source_stacking_sources.clear()
-        for source_re, key, new_source in source_group_patterns:
+        for source_re, key, new_source, stack_axis in source_group_patterns:
             merge_sources = [
                 source for source in source_names if source_re.match(source)
             ]
@@ -161,9 +180,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 continue
             for (i, source) in enumerate(merge_sources):
                 self._source_stacking_sources.setdefault(source, []).append(
-                    (key, new_source)
+                    (key, new_source, stack_axis)
                 )
-                self._source_stacking_indices[(source, key)] = i
+                self._source_stacking_indices[(source, new_source, key)] = i
 
         # handle key stacking groups
         self._key_stacking_sources.clear()
@@ -176,43 +195,61 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                     (key_re, new_key)
                 )
 
+    def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
+        # TODO: handle ValueError for max of empty sequence
+        stack_num = (
+            max(
+                index
+                for (
+                    _,
+                    new_source_,
+                    key_,
+                ), index in self._source_stacking_indices.items()
+                if new_source_ == new_source and key_ == key
+            )
+            + 1
+        )
+        self._stacking_buffers[(new_source, key)] = np.empty(
+            shape=utils.stacking_buffer_shape(individual_shape, stack_num, axis=axis),
+            dtype=dtype,
+        )
+
     def _handle_source(self, source, data_hash, timestamp, new_sources_map):
         # dereference calng shmem handles
         self._shmem_handler.dereference_shmem_handles(data_hash)
 
         # stack across sources (many sources, same key)
         # could probably save ~100 ns by "if ... in" instead of get
-        for (stack_key, new_source) in self._source_stacking_sources.get(source, ()):
+        for (stack_key, new_source, stack_axis) in self._source_stacking_sources.get(
+            source, ()
+        ):
             this_data = data_hash.get(stack_key)
             try:
                 this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                stack_index = self._source_stacking_indices[(source, stack_key)]
-                this_buffer[stack_index] = this_data
+                stack_index = self._source_stacking_indices[
+                    (source, new_source, stack_key)
+                ]
+                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
             except (ValueError, IndexError, KeyError):
-                # ValueError: wrong shape
+                # ValueError: wrong shape (react to this_data.shape)
                 # KeyError: buffer doesn't exist yet
+                # IndexError: new source? (buffer not long enough)
                 # either way, create appropriate buffer now
-                # TODO: complain if shape varies between sources
-                self._stacking_buffers[(new_source, stack_key)] = np.empty(
-                    shape=(
-                        max(
-                            index_
-                            for (
-                                source_,
-                                key_,
-                            ), index_ in self._source_stacking_indices.items()
-                            if source_ == source and key_ == stack_key
-                        )
-                        + 1,
-                    )
-                    + this_data.shape,
+                # TODO: complain if shape varies between sources within train
+                self._update_stacking_buffer(
+                    new_source,
+                    stack_key,
+                    this_data.shape,
+                    axis=stack_axis,
                     dtype=this_data.dtype,
                 )
                 # and then try again
                 this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                stack_index = self._source_stacking_indices[(source, stack_key)]
-                this_buffer[stack_index] = this_data
-                # TODO: zero out unfilled buffer entries
+                stack_index = self._source_stacking_indices[
+                    (source, new_source, stack_key)
+                ]
+                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
+            # TODO: zero out unfilled buffer entries
             data_hash.erase(stack_key)
 
             if new_source not in new_sources_map:
@@ -246,7 +283,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             concurrent.futures.wait(
                 [
                     self._thread_pool.submit(
-                        self._handle_source, data, timestamp, new_sources_map
+                        self._handle_source, source, data, timestamp, new_sources_map
                     )
                     for source, (data, timestamp) in sources.items()
                 ]
diff --git a/src/calng/utils.py b/src/calng/utils.py
index d0f6ceb4..83289230 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
     return tuple(axis_order[axis] for axis in axes_out)
 
 
+def stacking_buffer_shape(array_shape, stack_num, axis=0):
+    """Figures out the shape you would need for np.stack"""
+    if axis > len(array_shape) or axis < -len(array_shape) - 1:
+        # complain when np.stack would
+        raise np.AxisError(
+            f"axis {axis} is out of bounds "
+            f"for array of dimension {len(array_shape) + 1}"
+        )
+    if axis < 0:
+        axis += len(array_shape) + 1
+    return array_shape[:axis] + (stack_num,) + array_shape[axis:]
+
+
+def set_on_axis(array, vals, index, axis):
+    """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x"""
+    if axis >= len(array.shape):
+        raise IndexError(
+            f"too many indices for array: array is {len(array.shape)}-dimensional, "
+            f"but {axis+1} were indexed"
+        )
+    # TODO: maybe support negative axis with wraparound
+    indices = np.index_exp[:] * axis + np.index_exp[index]
+    array[indices] = vals
+
+
 _np_typechar_to_c_typestring = {
     "?": "bool",
     "B": "unsigned char",
-- 
GitLab