From c3e1c48f8e54c19dde8953a179666e30abe53dda Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Mon, 16 Oct 2023 19:56:41 +0200
Subject: [PATCH] WIP: restructure / simplify stacking execution

---
 src/calng/ShmemTrainMatcher.py |  40 ++---
 src/calng/stacking_utils.py    | 262 +++++++++++++++++----------------
 2 files changed, 159 insertions(+), 143 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 72bab416..301b8d04 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def initialization(self):
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
-        self._stacking_friend = StackingFriend(self.get("merge"), self.get("sources"))
+        self._stacking_friend = StackingFriend(
+            self, self.get("merge"), self.get("sources")
+        )
         self._frameselection_friend = FrameselectionFriend(self.get("frameSelector"))
         self._thread_pool = concurrent.futures.ThreadPoolExecutor(
             max_workers=self.get("processingThreads")
@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def on_matched_data(self, train_id, sources):
         frame_selection_mask = self._frameselection_friend.get_mask(sources)
         # note: should not do stacking and frame selection for now!
-        self._stacking_friend.prepare_stacking_for_train(sources)
-
-        concurrent.futures.wait(
-            [
-                self._thread_pool.submit(
-                    self._handle_source,
-                    source,
-                    data,
-                    timestamp,
-                    new_sources_map,
-                    frame_selection_mask,
-                )
-                for source, (data, timestamp) in sources.items()
-            ]
-        )
-        sources.update(new_sources_map)
+        with self._stacking_friend.stacking_context as stacker:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source,
+                        source,
+                        data,
+                        timestamp,
+                        stacker,
+                        frame_selection_mask,
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
+        sources.update(stacker.new_source_map)
 
         # karabo output
         if self.output is not None:
@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         source,
         data_hash,
         timestamp,
-        new_sources_map,
+        stacker,
         frame_selection_mask,
-        ignore_stacking,
     ):
         self._shmem_handler.dereference_shmem_handles(data_hash)
         self._frameselection_friend.apply_mask(source, data_hash, frame_selection_mask)
-        self._stacking_friend.handle_source(...)
+        stacker.process(source, data_hash)
diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
index b6d21a3d..049480d4 100644
--- a/src/calng/stacking_utils.py
+++ b/src/calng/stacking_utils.py
@@ -1,3 +1,4 @@
+import collections
 import enum
 import re
 
@@ -79,6 +80,22 @@ def merge_schema():
         .defaultValue(0)
         .reconfigurable()
         .commit(),
+        STRING_ELEMENT(schema)
+        .key("missingValue")
+        .displayedName("Missing value default")
+        .description(
+            "If some sources are missing within one group in multi-source stacking*, "
+            "the corresponding parts of the resulting stacked array will be set to "
+            "this value. Note that if no sources are present, the array is not created "
+            "at all. This field is a string to allow special values like float nan / "
+            "inf; it is your responsibility to make sure that data types match (i.e. "
+            "if this is 'nan', the stacked data better be floats or doubles). *Missing "
+            "value handling is not yet implementedo for multi-key stacking."
+        )
+        .assignmentOptional()
+        .defaultValue("0")
+        .reconfigurable()
+        .commit(),
     )
 
     return schema
@@ -121,147 +138,132 @@ class StackingFriend:
             .commit(),
         )
 
-    def __init__(self, merge_config, source_config):
-        self._stacking_buffers = {}
+    def __init__(self, device, source_config, merge_config):
         self._source_stacking_indices = {}
-        self._source_stacking_sources = {}
+        self._source_stacking_sources = collections.defaultdict(list)
         self._source_stacking_group_sizes = {}
-        self._key_stacking_sources = {}
-        self._merge_config = Hash()
-        self._source_config = Hash()
+        # (new source name, key) -> {original sources used}
+        self._new_sources_inputs = collections.defaultdict(set)
+        self._key_stacking_sources = collections.defaultdict(list)
+        self._merge_config = None
+        self._source_config = None
+        self._device = device
         self.reconfigure(merge_config, source_config)
 
     def reconfigure(self, merge_config, source_config):
+        print("merge_config", type(merge_config))
+        print("source_config", type(source_config))
         if merge_config is not None:
-            self._merge_config.merge(merge_config)
+            self._merge_config = merge_config
         if source_config is not None:
-            self._source_config.merge(source_config)
-
-        # not filtering by row["select"] to allow unselected sources to create gaps
-        source_names = [row["source"].partition("@")[0] for row in self._source_config]
-        self._stacking_buffers.clear()
+            self._source_config = source_config
         self._source_stacking_indices.clear()
         self._source_stacking_sources.clear()
         self._source_stacking_group_sizes.clear()
         self._key_stacking_sources.clear()
-        # split by type, prepare regexes
-        for row in self._merge_config:
-            if not row["select"]:
-                continue
-            group_type = GroupType(row["groupType"])
+        self._new_sources_inputs.clear()
+
+        # not filtering by row["select"] to allow unselected sources to create gaps
+        source_names = [row["source"].partition("@")[0] for row in self._source_config]
+        source_stacking_groups = [
+            row
+            for row in self._merge_config
+            if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
+        ]
+        key_stacking_groups = [
+            row
+            for row in self._merge_config
+            if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
+        ]
+
+        for row in source_stacking_groups:
             source_re = re.compile(row["sourcePattern"])
             merge_method = MergeMethod(row["mergeMethod"])
-            axis = row["axis"]
-            if group_type is GroupType.MULTISOURCE:
-                key = row["keyPattern"]
-                new_source = row["replacement"]
-                merge_sources = [
-                    source for source in source_names if source_re.match(source)
-                ]
-                if len(merge_sources) == 0:
-                    self.log.WARN(
-                        f"Group pattern {source_re} did not match any known sources"
-                    )
-                    continue
-                self._source_stacking_group_sizes[(new_source, key)] = len(
-                    merge_sources
+            key = row["keyPattern"]
+            new_source = row["replacement"]
+            merge_sources = [
+                source for source in source_names if source_re.match(source)
+            ]
+            if len(merge_sources) == 0:
+                self._device.log.WARN(
+                    f"Group pattern {source_re} did not match any known sources"
                 )
-                for i, source in enumerate(merge_sources):
-                    self._source_stacking_sources.setdefault(source, []).append(
-                        (key, new_source, merge_method, axis)
-                    )
-                    self._source_stacking_indices[(source, new_source, key)] = (
-                        i
-                        if merge_method is MergeMethod.STACK
-                        else np.index_exp[
-                            slice(
-                                i,
-                                None,
-                                self._source_stacking_group_sizes[(new_source, key)],
-                            )
-                        ]
-                    )
-            else:
-                key_re = re.compile(row["keyPattern"])
-                new_key = row["replacement"]
-                self._key_stacking_sources.setdefault(source, []).append(
-                    (key_re, new_key, merge_method, axis)
+                continue
+            self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
+            for i, source in enumerate(merge_sources):
+                self._source_stacking_sources[source].append(
+                    (key, new_source, merge_method, row["axis"])
+                )
+                self._source_stacking_indices[(source, new_source, key)] = (
+                    i
+                    if merge_method is MergeMethod.STACK
+                    else np.index_exp[  # interleaving
+                        slice(
+                            i,
+                            None,
+                            len(merge_sources),
+                        )
+                    ]
                 )
 
-    def prepare_stacking_for_train(
-        self, sources, frame_selection_mask, new_sources_map
-    ):
-        if frame_selection_mask is not None:
-            orig_size = len(frame_selection_mask)
-            result_size = np.sum(frame_selection_mask)
+        for row in key_stacking_groups:
+            key_re = re.compile(row["keyPattern"])
+            new_key = row["replacement"]
+            self._key_stacking_sources[source].append(
+                (key_re, new_key, MergeMethod(row["mergeMethod"]), row["axis"])
+            )
 
+    def process(self, sources, thread_pool=None):
         stacking_data_shapes = {}
-        self._ignore_stacking = {}
-        self._fill_missed_data = {}
-        for source, keys in self._source_stacking_sources.items():
-            if source not in sources:
-                for key, new_source, _, _ in keys:
-                    missed_sources = self._fill_missed_data.setdefault(
-                        (new_source, key), []
+        stacking_buffers = {}
+        new_source_map = collections.defaultdict(Hash)
+        missing_value_defaults = {}
+
+        # prepare for source stacking where sources are present
+        source_set = set(sources.keys())
+        for (
+            new_source,
+            data_key,
+            merge_method,
+            group_size,
+            axis,
+            missing_value,
+        ), original_sources in self._new_sources_inputs.items():
+            for present_source in source_set & original_sources:
+                data = sources[present_source].get(data_key)[0]
+                if data is None:
+                    continue
+                if merge_method is MergeMethod.STACK:
+                    expected_shape = utils.stacking_buffer_shape(
+                        data.shape, group_size, axis=axis
                     )
-                    merge_index = self._source_stacking_indices[
-                        (source, new_source, key)
-                    ]
-                    missed_sources.append(merge_index)
-                continue
-            data_hash, timestamp = sources[source]
-            filtering = (
-                frame_selection_mask is not None
-                and self._frame_selection_source_pattern.match(source)
-            )
-            for key, new_source, merge_method, axis in keys:
-                merge_data_shape = None
-                if key in data_hash:
-                    merge_data = data_hash[key]
-                    merge_data_shape = merge_data.shape
                 else:
-                    self._ignore_stacking[(new_source, key)] = "Some data is missed"
-                    continue
-
-                if filtering and key in self._frame_selection_data_keys:
-                    # !!! stacking is not expected to be used with filtering
-                    if merge_data_shape[0] == orig_size:
-                        merge_data_shape = (result_size,) + merge_data.shape[1:]
-
-                (
-                    expected_shape,
-                    _,
-                    _,
-                    expected_dtype,
-                    _,
-                ) = stacking_data_shapes.setdefault(
-                    (new_source, key),
-                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp),
-                )
-                if (
-                    expected_shape != merge_data_shape
-                    or expected_dtype != merge_data.dtype
-                ):
-                    self._ignore_stacking[
-                        (new_source, key)
-                    ] = "Shape or dtype is inconsistent"
-                    del stacking_data_shapes[(new_source, key)]
-
-        for (new_source, key), msg in self._ignore_stacking.items():
-            self._device.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
-
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = self._source_stacking_group_sizes[(new_source, key)]
-            if merge_method is MergeMethod.STACK:
-                expected_shape = utils.stacking_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
+                    expected_shape = utils.interleaving_buffer_shape(
+                        data.shape, group_size, axis=axis
+                    )
+                stacking_buffer = np.empty(
+                    expected_shape, dtype=data.dtype
                 )
+                stacking_buffers[(new_source, data_key)] = stacking_buffer
+                new_source_map[new_source][data_key] = stacking_buffer
+                try:
+                    missing_value_defaults[(new_source, data_key)] = data.dtype.type(
+                        missing_value
+                    )
+                except ValueError:
+                    self._device.log.WARN(
+                        f"Invalid missing data value for {new_source}.{data_key}, using 0"
+                    )
+                break
             else:
-                expected_shape = utils.interleaving_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
+                # in this case: no present_source (if any) had data_key
+                self._device.log.WARN(
+                    f"No sources needed for {new_source}.{data_key} were present"
                 )
 
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = parent._source_stacking_group_sizes[(new_source, key)]
             merge_buffer = self._stacking_buffers.get((new_source, key))
             if merge_buffer is None or merge_buffer.shape != expected_shape:
                 merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
@@ -277,19 +279,33 @@ class StackingFriend:
                         slice(
                             merge_index,
                             None,
-                            self._source_stacking_group_sizes[(new_source, key)],
+                            parent._source_stacking_group_sizes[(new_source, key)],
                         )
                     ],
                     axis,
                 )
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
+            if new_source not in self.new_source_map:
+                self.new_source_map[new_source] = (Hash(), timestamp)
+            new_source_hash = self.new_source_map[new_source][0]
             if not new_source_hash.has(key):
                 new_source_hash[key] = merge_buffer
 
-    def handle_source(self, source, data_hash):
-        # stack across sources (many sources, same key)
+        # now actually do some work
+        fun = functools.partial(self._handle_source, ...)
+        if thread_pool is None:
+            for _ in map(fun, ...):
+                pass
+        else:
+            concurrent.futures.wait(thread_pool.map(fun, ...))
+
+    def _handle_expected_source(self, merge_buffers, missing_values, actual_sources, expected_source):
+        """Helper function used in processing. Note that it should be called for each
+        source that was supposed to be there - so it can decide whether to move data in
+        for stacking, fill missing data in case a source is missing, or skip in case no
+        buffer was created (none of the necessary sources were present)."""
+
+        if expected_source not in actual_sources:
+            if ex
         for (
             key,
             new_source,
-- 
GitLab