From ad4606b850f435e8f29a9bdff350b64c788b3ca4 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Tue, 17 Oct 2023 08:57:21 +0200
Subject: [PATCH] Finished overhaul of source stacking, added test

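Source stacking is now driven by two structures built in reconfigure():
_new_sources_inputs maps each (new source, key, method, group size, axis,
missing value) group to the original sources feeding it, and
_source_stacking_parts lists every (old source, key, new source) triple to
fill. process() sizes one merge buffer per group from the first present
source with data, fills the parts (optionally via a thread pool), and
substitutes the configured missing value for absent sources. Key stacking
is factored out into _handle_key_stacking_entry. The new unit test covers
plain source stacking and stacking with a missing source.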
---
 src/calng/stacking_utils.py | 241 +++++++++++++++++++-----------------
 src/tests/test_stacking.py  |  64 ++++++++++
 2 files changed, 188 insertions(+), 117 deletions(-)
 create mode 100644 src/tests/test_stacking.py

diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
index 049480d4..8149ceea 100644
--- a/src/calng/stacking_utils.py
+++ b/src/calng/stacking_utils.py
@@ -1,5 +1,7 @@
 import collections
+import concurrent.futures
 import enum
+import functools
 import re
 
 from karabo.bound import (
@@ -139,41 +141,42 @@ class StackingFriend:
         )
 
     def __init__(self, device, source_config, merge_config):
+        # (old source, new source, key) -> (index (maybe index_exp), axis)
         self._source_stacking_indices = {}
-        self._source_stacking_sources = collections.defaultdict(list)
-        self._source_stacking_group_sizes = {}
-        # (new source name, key) -> {original sources used}
+        # list of (old source, key, new source)
+        self._source_stacking_parts = []
+        # (new source, key, method, group size, axis, missing) -> {original sources}
         self._new_sources_inputs = collections.defaultdict(set)
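+        # source name -> [(key regex, new key, merge method, axis)]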
         self._key_stacking_sources = collections.defaultdict(list)
+
         self._merge_config = None
         self._source_config = None
         self._device = device
-        self.reconfigure(merge_config, source_config)
+        self.reconfigure(source_config, merge_config)
 
-    def reconfigure(self, merge_config, source_config):
+    def reconfigure(self, source_config, merge_config):
         print("merge_config", type(merge_config))
         print("source_config", type(source_config))
-        if merge_config is not None:
-            self._merge_config = merge_config
         if source_config is not None:
             self._source_config = source_config
+        if merge_config is not None:
+            self._merge_config = merge_config
         self._source_stacking_indices.clear()
-        self._source_stacking_sources.clear()
-        self._source_stacking_group_sizes.clear()
-        self._key_stacking_sources.clear()
+        self._source_stacking_parts.clear()
         self._new_sources_inputs.clear()
+        self._key_stacking_sources.clear()
 
         # not filtering by row["select"] to allow unselected sources to create gaps
         source_names = [row["source"].partition("@")[0] for row in self._source_config]
         source_stacking_groups = [
             row
             for row in self._merge_config
-            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
+            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.value
         ]
         key_stacking_groups = [
             row
             for row in self._merge_config
-            if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
+            if row["select"] and row["groupType"] == GroupType.MULTIKEY.value
         ]
 
         for row in source_stacking_groups:
@@ -181,6 +184,7 @@ class StackingFriend:
             merge_method = MergeMethod(row["mergeMethod"])
             key = row["keyPattern"]
             new_source = row["replacement"]
+            axis = row["axis"]
             merge_sources = [
                 source for source in source_names if source_re.match(source)
             ]
@@ -189,11 +193,7 @@ class StackingFriend:
                     f"Group pattern {source_re} did not match any known sources"
                 )
                 continue
-            self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
             for i, source in enumerate(merge_sources):
-                self._source_stacking_sources[source].append(
-                    (key, new_source, merge_method, row["axis"])
-                )
                 self._source_stacking_indices[(source, new_source, key)] = (
                     i
                     if merge_method is MergeMethod.STACK
@@ -204,7 +204,18 @@ class StackingFriend:
                             len(merge_sources),
                         )
                     ]
-                )
+                ), axis
+                self._new_sources_inputs[
+                    (
+                        new_source,
+                        key,
+                        merge_method,
+                        len(merge_sources),
+                        axis,
+                        row["missingValue"],
+                    )
+                ].add(source)
+                self._source_stacking_parts.append((source, key, new_source))
 
         for row in key_stacking_groups:
             key_re = re.compile(row["keyPattern"])
@@ -214,25 +225,27 @@ class StackingFriend:
             )
 
     def process(self, sources, thread_pool=None):
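+        """Build the stacked outputs for the configured multi-source groups.
+
+        sources maps source name -> (data Hash, timestamp) and is modified in
+        place: stacked keys are erased from their original sources and the
+        merged buffers are added under the new source names."""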
-        stacking_data_shapes = {}
-        stacking_buffers = {}
-        new_source_map = collections.defaultdict(Hash)
+        merge_buffers = {}
+        new_source_map = {}
         missing_value_defaults = {}
 
         # prepare for source stacking where sources are present
-        source_set = set(sources.keys())
         for (
             new_source,
             data_key,
             merge_method,
             group_size,
             axis,
-            missing_value,
+            missing_value_str,
         ), original_sources in self._new_sources_inputs.items():
-            for present_source in source_set & original_sources:
-                data = sources[present_source].get(data_key)[0]
+            for source in original_sources:
+                if source not in sources:
+                    continue
+                data_hash, timestamp = sources[source]
+                data = data_hash.get(data_key)
                 if data is None:
                     continue
+                # didn't hit continue => first present source in this group with data
                 if merge_method is MergeMethod.STACK:
                     expected_shape = utils.stacking_buffer_shape(
                         data.shape, group_size, axis=axis
@@ -241,119 +254,113 @@ class StackingFriend:
                     expected_shape = utils.interleaving_buffer_shape(
                         data.shape, group_size, axis=axis
                     )
-                stacking_buffer = np.empty(
-                    expected_shape, dtype=data.dtype
-                )
-                stacking_buffers[(new_source, data_key)] = stacking_buffer
-                new_source_map[new_source][data_key] = stacking_buffer
+                merge_buffer = np.empty(expected_shape, dtype=data.dtype)
+                merge_buffers[(new_source, data_key)] = merge_buffer
+                if new_source not in new_source_map:
+                    new_source_map[new_source] = (Hash(), timestamp)
+                new_source_map[new_source][0][data_key] = merge_buffer
                 try:
-                    missing_value_defaults[(new_source, data_key)] = data.dtype.type(
-                        missing_value
-                    )
+                    missing_value = data.dtype.type(missing_value_str)
                 except ValueError:
                     self._device.log.WARN(
-                        f"Invalid missing data value for {new_source}.{data_key}, using 0"
+                        f"Invalid missing data value for {new_source}.{data_key} "
+                        f"(tried making a {data.dtype} out of "
+                        f" '{missing_value_str}', using 0"
                     )
+                    missing_value = 0
+                missing_value_defaults[(new_source, data_key)] = missing_value
+                # now we have set up the buffer for this new source, so break
                 break
             else:
-                # in this case: no present_source (if any) had data_key
+                # for-else: none of this group's sources was present with data_key
                 self._device.log.WARN(
                     f"No sources needed for {new_source}.{data_key} were present"
                 )
 
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = parent._source_stacking_group_sizes[(new_source, key)]
-            merge_buffer = self._stacking_buffers.get((new_source, key))
-            if merge_buffer is None or merge_buffer.shape != expected_shape:
-                merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
-                self._stacking_buffers[(new_source, key)] = merge_buffer
-
-            for merge_index in self._fill_missed_data.get((new_source, key), []):
-                utils.set_on_axis(
-                    merge_buffer,
-                    0,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            parent._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            if new_source not in self.new_source_map:
-                self.new_source_map[new_source] = (Hash(), timestamp)
-            new_source_hash = self.new_source_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
-
         # now actually do some work
-        fun = functools.partial(self._handle_source, ...)
+        fun = functools.partial(
+            self._handle_source_stacking_part,
+            merge_buffers,
+            missing_value_defaults,
+            sources,
+        )
         if thread_pool is None:
-            for _ in map(fun, ...):
+            for _ in map(fun, self._source_stacking_parts):
                 pass
         else:
-            concurrent.futures.wait(thread_pool.map(fun, ...))
+            concurrent.futures.wait(thread_pool.map(fun, self._source_stacking_parts))
+
+        # note: stacked source names that match existing source names overwrite them here
+        sources.update(new_source_map)
 
-    def _handle_expected_source(self, merge_buffers, missing_values, actual_sources, expected_source):
+    def _handle_source_stacking_part(
+        self,
+        merge_buffers,
+        missing_values,
+        actual_sources,
+        stacking_triple,
+    ):
         """Helper function used in processing. Note that it should be called for each
-        source that was supposed to be there - so it can decide whether to move data in
-        for stacking, fill missing data in case a source is missing, or skip in case no
-        buffer was created (none of the necessary sources were present)."""
+        (original source, data key, new source) triple expected in the stacking scheme
+        regardless of whether that original source is present - so it can decide whether
+        to move data in for stacking, fill missing data in case a source is missing,
+        or skip in case no buffer was created (none of the necessary sources were
+        present)."""
 
-        if expected_source not in actual_sources:
-            if ex
-        for (
-            key,
-            new_source,
-            merge_method,
-            axis,
-        ) in self._source_stacking_sources.get(source, ()):
-            if (new_source, key) in self._ignore_stacking:
-                continue
-            merge_data = data_hash.get(key)
-            merge_index = self._source_stacking_indices[(source, new_source, key)]
-            merge_buffer = self._stacking_buffers[(new_source, key)]
+        old_source, data_key, new_source = stacking_triple
+
+        if (new_source, data_key) not in merge_buffers:
+            # preparatory step did not create buffer
+            # (i.e. no sources from this group were present)
+            return
+        merge_buffer = merge_buffers[(new_source, data_key)]
+        if old_source in actual_sources:
+            data_hash = actual_sources[old_source][0]
+        else:
+            data_hash = None
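+        # source absent or key missing: fall back to the group's missing value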
+        if data_hash is None or (merge_data := data_hash.get(data_key)) is None:
+            merge_data = missing_values[(new_source, data_key)]
+        try:
             utils.set_on_axis(
                 merge_buffer,
                 merge_data,
-                merge_index,
-                axis,
+                *self._source_stacking_indices[(old_source, new_source, data_key)],
+            )
+            if data_hash is not None:
+                data_hash.erase(data_key)
+        except Exception as ex:
+            self._device.log.WARN(
+                f"Failed to stack {data_key} from {old_source} into {new_source}: {ex}"
             )
-            data_hash.erase(key)
 
-        # stack keys (multiple keys within this source)
-        for key_re, new_key, merge_method, axis in self._key_stacking_sources.get(
-            source, ()
-        ):
-            # note: please no overlap between different key_re
-            # note: if later key_re match earlier new_key, this gets spicy
-            keys = [key for key in data_hash.paths() if key_re.match(key)]
-            try:
-                # note: maybe we could reuse buffers here, too?
-                if merge_method is MergeMethod.STACK:
-                    stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
-                else:
-                    first = data_hash.get(keys[0])
-                    stacked = np.empty(
-                        shape=utils.stacking_buffer_shape(
-                            first.shape, len(keys), axis=axis
-                        ),
-                        dtype=first.dtype,
-                    )
-                    for i, key in enumerate(keys):
-                        utils.set_on_axis(
-                            stacked,
-                            data_hash.get(key),
-                            np.index_exp[slice(i, None, len(keys))],
-                            axis,
-                        )
-            except Exception as e:
-                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+    def _handle_key_stacking_entry(
+        self, source, data_hash, key_re, new_key, merge_method, axis
+    ):
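+        """Stack all keys of data_hash matching key_re into new_key, in place."""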
+        # note: the different key_re patterns should not overlap
+        # note: if a later key_re matches an earlier new_key, this gets spicy
+        keys = [key for key in data_hash.paths() if key_re.match(key)]
+        try:
+            # note: maybe we could reuse buffers here, too?
+            if merge_method is MergeMethod.STACK:
+                stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
             else:
-                for key in keys:
-                    data_hash.erase(key)
-                data_hash[new_key] = stacked
+                first = data_hash.get(keys[0])
+                stacked = np.empty(
+                    shape=utils.stacking_buffer_shape(
+                        first.shape, len(keys), axis=axis
+                    ),
+                    dtype=first.dtype,
+                )
+                for i, key in enumerate(keys):
+                    utils.set_on_axis(
+                        stacked,
+                        data_hash.get(key),
+                        np.index_exp[slice(i, None, len(keys))],
+                        axis,
+                    )
+        except Exception as e:
+            self._device.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+        else:
+            for key in keys:
+                data_hash.erase(key)
+            data_hash[new_key] = stacked
diff --git a/src/tests/test_stacking.py b/src/tests/test_stacking.py
new file mode 100644
index 00000000..0c2ff00e
--- /dev/null
+++ b/src/tests/test_stacking.py
@@ -0,0 +1,64 @@
+from karabo.bound import Hash
+import numpy as np
+
+from calng import stacking_utils
+
+
+class NotADevice:
+    class log:
+        @staticmethod
+        def WARN(s):
+            print(f"Warning: {s}")
+
+
+source_table = [
+    {
+        "select": True,
+        "source": f"source{i}@device{i}:channel",
+    }
+    for i in range(3)
+]
+
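+# one rule: stack "keyToStack" from all sources matching source\d+ along axis 1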
+merge_rules = [
+    {
+        "select": True,
+        "sourcePattern": "source\\d+",
+        "keyPattern": "keyToStack",
+        "replacement": "newSource",
+        "groupType": "sources",
+        "mergeMethod": "stack",
+        "axis": 1,
+        "missingValue": "0",
+    }
+]
+
+friend = stacking_utils.StackingFriend(NotADevice(), source_table, merge_rules)
+
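+# three 4x2 arrays with distinct value ranges so stacking order is checkable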
+datas = [np.arange(i * 100, i * 100 + 8).reshape(4, 2) for i in range(3)]
+
+
+def test_simple_source_stacking():
+    sources = {
+        f"source{i}": (Hash("keyToStack", data), None) for i, data in enumerate(datas)
+    }
+
+    friend.process(sources)
+    assert "newSource" in sources
+    stacked = sources["newSource"][0]["keyToStack"]
+    assert stacked.shape == (4, 3, 2)
+    for i, data in enumerate(datas):
+        assert np.array_equal(stacked[:, i], data, equal_nan=True)
+
+
+def test_source_stacking_one_missing():
+    sources = {
+        f"source{i}": (Hash("keyToStack", data), None) for i, data in enumerate(datas)
+    }
+    del sources["source0"]
+    friend.process(sources)
+    assert "newSource" in sources
+    stacked = sources["newSource"][0]["keyToStack"]
+    print(stacked)
+    assert np.all(stacked[:, 0] == 0)
+    for i, data in enumerate(datas[1:], start=1):
+        assert np.array_equal(stacked[:, i], data, equal_nan=True)
-- 
GitLab