Commit b2b5aca1 authored by David Hammer

Allow thread pool usage

parent f8f380cc
Part of 2 merge requests: !10 DetectorAssembler: assemble with extra shape (multiple frames); !9 Stacking shmem matcher
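The change in outline: the per-source work in on_matched_data (shmem handle dereferencing, source stacking, key stacking) moves into a _handle_source method, and a new useThreadPool option lets the matcher fan that work out over a concurrent.futures.ThreadPoolExecutor, blocking until every source has been handled. A minimal sketch of that fan-out/fan-in pattern, with illustrative names rather than the device's own:

import concurrent.futures

def handle(item, results):
    # stand-in for the per-source work
    results[item] = item * 2

def process_all(items, pool=None):
    results = {}
    if pool is None:
        # serial fallback, same semantics as the pooled path
        for item in items:
            handle(item, results)
    else:
        # one task per item; wait() blocks until all futures complete
        concurrent.futures.wait(
            [pool.submit(handle, item, results) for item in items]
        )
    return results

pool = concurrent.futures.ThreadPoolExecutor()
print(process_all(range(4), pool))  # {0: 0, 1: 2, 2: 4, 3: 6}
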
+import concurrent.futures
 import enum
 import re

@@ -91,6 +92,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             .defaultValue([])
             .reconfigurable()
             .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("useThreadPool")
+            .assignmentOptional()
+            .defaultValue(False)
+            .commit(),
         )

     def initialization(self):
@@ -101,6 +108,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self._prepare_merge_groups(self.get("merge"))
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
+        if self.get("useThreadPool"):
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        else:
+            self._thread_pool = None

     def preReconfigure(self, conf):
         super().preReconfigure(conf)
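Note that ThreadPoolExecutor() is created without max_workers, so the pool size falls back to the interpreter default (min(32, os.cpu_count() + 4) on recent Python versions). If the pool should instead match the expected number of matched sources, it could be sized explicitly; the values below are hypothetical tuning, not part of the commit:

import concurrent.futures

# hypothetical: one worker per expected detector module, named for debugging
pool = concurrent.futures.ThreadPoolExecutor(
    max_workers=16, thread_name_prefix="shmem-matcher"
)
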
@@ -165,70 +176,81 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                     (key_re, new_key)
                 )

-    def on_matched_data(self, train_id, sources):
+    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
         # dereference calng shmem handles
-        for (data, _) in sources.values():
-            self._shmem_handler.dereference_shmem_handles(data)
-
-        new_sources_map = {}
-        for source, (data, timestamp) in sources.items():
-            # stack across sources (many sources, same key)
-            # could probably save ~100 ns by "if ... in" instead of get
-            for (stack_key, new_source) in self._source_stacking_sources.get(
-                source, ()
-            ):
-                this_data = data.get(stack_key)
-                try:
-                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                    stack_index = self._source_stacking_indices[(source, stack_key)]
-                    this_buffer[stack_index] = this_data
-                except (ValueError, KeyError):
-                    # ValueError: wrong shape
-                    # KeyError: buffer doesn't exist yet
-                    # either way, create appropriate buffer now
-                    # TODO: complain if shape varies between sources
-                    self._stacking_buffers[(new_source, stack_key)] = np.empty(
-                        shape=(
-                            max(
-                                index_
-                                for (
-                                    source_,
-                                    key_,
-                                ), index_ in self._source_stacking_indices.items()
-                                if source_ == source and key_ == stack_key
-                            )
-                            + 1,
-                        )
-                        + this_data.shape,
-                        dtype=this_data.dtype,
-                    )
-                    # and then try again
-                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                    stack_index = self._source_stacking_indices[(source, stack_key)]
-                    this_buffer[stack_index] = this_data
-                # TODO: zero out unfilled buffer entries
-                data.erase(stack_key)
-
-                if new_source not in new_sources_map:
-                    new_sources_map[new_source] = (Hash(), timestamp)
-                new_source_hash = new_sources_map[new_source][0]
-                if not new_source_hash.has(stack_key):
-                    new_source_hash[stack_key] = this_buffer
-
-            # stack keys (multiple keys within this source)
-            for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
-                # note: please no overlap between different key_re
-                # note: if a later key_re matches an earlier new_key, this gets spicy
-                stack_keys = [key for key in data.paths() if key_re.match(key)]
-                try:
-                    # TODO: consider reusing buffers here, too
-                    stacked = np.stack([data.get(key) for key in stack_keys], axis=0)
-                except Exception as e:
-                    self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
-                else:
-                    for key in stack_keys:
-                        data.erase(key)
-                    data[new_key] = stacked
+        self._shmem_handler.dereference_shmem_handles(data_hash)
+
+        # stack across sources (many sources, same key)
+        # could probably save ~100 ns by "if ... in" instead of get
+        for (stack_key, new_source) in self._source_stacking_sources.get(source, ()):
+            this_data = data_hash.get(stack_key)
+            try:
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[(source, stack_key)]
+                this_buffer[stack_index] = this_data
+            except (ValueError, IndexError, KeyError):
+                # ValueError: wrong shape
+                # KeyError: buffer doesn't exist yet
+                # either way, create appropriate buffer now
+                # TODO: complain if shape varies between sources
+                self._stacking_buffers[(new_source, stack_key)] = np.empty(
+                    shape=(
+                        max(
+                            index_
+                            for (
+                                source_,
+                                key_,
+                            ), index_ in self._source_stacking_indices.items()
+                            if source_ == source and key_ == stack_key
+                        )
+                        + 1,
+                    )
+                    + this_data.shape,
+                    dtype=this_data.dtype,
+                )
+                # and then try again
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[(source, stack_key)]
+                this_buffer[stack_index] = this_data
+            # TODO: zero out unfilled buffer entries
+            data_hash.erase(stack_key)
+
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(stack_key):
+                new_source_hash[stack_key] = this_buffer
+
+        # stack keys (multiple keys within this source)
+        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
+            # note: please no overlap between different key_re
+            # note: if a later key_re matches an earlier new_key, this gets spicy
+            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
+            try:
+                # TODO: consider reusing buffers here, too
+                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
+            except Exception as e:
+                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+            else:
+                for key in stack_keys:
+                    data_hash.erase(key)
+                data_hash[new_key] = stacked
+
+    def on_matched_data(self, train_id, sources):
+        new_sources_map = {}
+        if self._thread_pool is None:
+            for source, (data, timestamp) in sources.items():
+                self._handle_source(source, data, timestamp, new_sources_map)
+        else:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source, source, data, timestamp, new_sources_map
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )

         sources.update(new_sources_map)
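
The except branch above sizes a fresh stacking buffer by taking the largest stack index registered for the (source, key) group, adding one to get the stack depth, and prepending that to the per-source entry shape. A worked sketch of that shape computation, with hypothetical values:

import numpy as np

# hypothetical: four detector modules stacked under one new source
source_stacking_indices = {
    (f"det/{i}", "image.data"): i for i in range(4)
}
this_data = np.zeros((512, 128), dtype=np.float32)

depth = max(
    index
    for (source, key), index in source_stacking_indices.items()
    if key == "image.data"
) + 1
buffer = np.empty(shape=(depth,) + this_data.shape, dtype=this_data.dtype)
print(buffer.shape)  # (4, 512, 128)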
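A plausible reading of why the pooled path works without explicit locks: each task writes to its own stack_index slot of the shared stacking buffer, so the numpy writes are disjoint, and the check-then-insert on new_sources_map relies on the GIL keeping individual dict operations atomic. A self-contained sketch of the disjoint-slot pattern, again with illustrative names:

import concurrent.futures

import numpy as np

frames = {f"source{i}": np.full((2, 2), i) for i in range(4)}
indices = {name: i for i, name in enumerate(frames)}  # source -> slot
buffer = np.empty((len(frames), 2, 2), dtype=np.int64)

def fill(name, data):
    # each task owns exactly one slot, so no two threads touch the same cells
    buffer[indices[name]] = data

with concurrent.futures.ThreadPoolExecutor() as pool:
    concurrent.futures.wait(
        [pool.submit(fill, name, data) for name, data in frames.items()]
    )

assert all((buffer[i] == i).all() for i in range(4))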