From b2b5aca1199f372e63b116100cbd4573134b737d Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Thu, 5 May 2022 16:30:14 +0200
Subject: [PATCH] Allow thread pool usage

---
 src/calng/ShmemTrainMatcher.py | 142 +++++++++++++++++++--------------
 1 file changed, 82 insertions(+), 60 deletions(-)

diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 607ab5f2..788f328b 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import enum
 import re
 
@@ -91,6 +92,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             .defaultValue([])
             .reconfigurable()
             .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("useThreadPool")
+            .assignmentOptional()
+            .defaultValue(False)
+            .commit(),
         )
 
     def initialization(self):
@@ -101,6 +108,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self._prepare_merge_groups(self.get("merge"))
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
+        if self.get("useThreadPool"):
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        else:
+            self._thread_pool = None
 
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
@@ -165,70 +176,81 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                     (key_re, new_key)
                 )
 
-    def on_matched_data(self, train_id, sources):
+    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
         # dereference calng shmem handles
-        for (data, _) in sources.values():
-            self._shmem_handler.dereference_shmem_handles(data)
+        self._shmem_handler.dereference_shmem_handles(data_hash)
 
-        new_sources_map = {}
-        for source, (data, timestamp) in sources.items():
-            # stack across sources (many sources, same key)
-            # could probably save ~100 ns by "if ... in" instead of get
-            for (stack_key, new_source) in self._source_stacking_sources.get(
-                source, ()
-            ):
-                this_data = data.get(stack_key)
-                try:
-                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                    stack_index = self._source_stacking_indices[(source, stack_key)]
-                    this_buffer[stack_index] = this_data
-                except (ValueError, KeyError):
-                    # ValueError: wrong shape
-                    # KeyError: buffer doesn't exist yet
-                    # either way, create appropriate buffer now
-                    # TODO: complain if shape varies between sources
-                    self._stacking_buffers[(new_source, stack_key)] = np.empty(
-                        shape=(
-                            max(
-                                index_
-                                for (
-                                    source_,
-                                    key_,
-                                ), index_ in self._source_stacking_indices.items()
-                                if source_ == source and key_ == stack_key
-                            )
-                            + 1,
+        # stack across sources (many sources, same key)
+        # could probably save ~100 ns by "if ... in" instead of get
+        for (stack_key, new_source) in self._source_stacking_sources.get(source, ()):
+            this_data = data_hash.get(stack_key)
+            try:
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[(source, stack_key)]
+                this_buffer[stack_index] = this_data
+            except (ValueError, IndexError, KeyError):
+                # ValueError: wrong shape
+                # KeyError: buffer doesn't exist yet
+                # either way, create appropriate buffer now
+                # TODO: complain if shape varies between sources
+                self._stacking_buffers[(new_source, stack_key)] = np.empty(
+                    shape=(
+                        max(
+                            index_
+                            for (
+                                source_,
+                                key_,
+                            ), index_ in self._source_stacking_indices.items()
+                            if source_ == source and key_ == stack_key
                         )
-                        + this_data.shape,
-                        dtype=this_data.dtype,
+                        + 1,
                     )
-                    # and then try again
-                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                    stack_index = self._source_stacking_indices[(source, stack_key)]
-                    this_buffer[stack_index] = this_data
-                    # TODO: zero out unfilled buffer entries
-                data.erase(stack_key)
-
-                if new_source not in new_sources_map:
-                    new_sources_map[new_source] = (Hash(), timestamp)
-                new_source_hash = new_sources_map[new_source][0]
-                if not new_source_hash.has(stack_key):
-                    new_source_hash[stack_key] = this_buffer
-
-            # stack keys (multiple keys within this source)
-            for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
-                # note: please no overlap between different key_re
-                # note: if later key_re match earlier new_key, this gets spicy
-                stack_keys = [key for key in data.paths() if key_re.match(key)]
-                try:
-                    # TODO: consider reusing buffers here, too
-                    stacked = np.stack([data.get(key) for key in stack_keys], axis=0)
-                except Exception as e:
-                    self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
-                else:
-                    for key in stack_keys:
-                        data.erase(key)
-                    data[new_key] = stacked
+                    + this_data.shape,
+                    dtype=this_data.dtype,
+                )
+                # and then try again
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[(source, stack_key)]
+                this_buffer[stack_index] = this_data
+                # TODO: zero out unfilled buffer entries
+            data_hash.erase(stack_key)
+
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(stack_key):
+                new_source_hash[stack_key] = this_buffer
+
+        # stack keys (multiple keys within this source)
+        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
+            # note: please no overlap between different key_re
+            # note: if later key_re match earlier new_key, this gets spicy
+            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
+            try:
+                # TODO: consider reusing buffers here, too
+                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
+            except Exception as e:
+                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+            else:
+                for key in stack_keys:
+                    data_hash.erase(key)
+                data_hash[new_key] = stacked
+
+    def on_matched_data(self, train_id, sources):
+
+        new_sources_map = {}
+        if self._thread_pool is None:
+            for source, (data, timestamp) in sources.items():
+                self._handle_source(source, data, timestamp, new_sources_map)
+        else:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source, data, timestamp, new_sources_map
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
 
         sources.update(new_sources_map)
 
-- 
GitLab