Commit b01015c4 authored by David Hammer

Allow thread pool usage

parent e7f96bae
Merge requests: !10 (DetectorAssembler: assemble with extra shape (multiple frames)), !9 (Stacking shmem matcher)
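In short: the per-source work that on_matched_data used to do inline moves into a new _handle_source method, and a new useThreadPool property lets the device fan that work out over a concurrent.futures.ThreadPoolExecutor instead of looping serially. A minimal standalone sketch of that dispatch pattern (handle, sources, and merged are illustrative stand-ins, not names from this codebase):

import concurrent.futures

# Toy stand-ins; the real device passes Karabo Hashes and timestamps.
sources = {"det/1": ([1, 2], "ts1"), "det/2": ([3, 4], "ts2")}
merged = {}

def handle(source, data, timestamp, out):
    # per-source work; results are collected into a shared dict
    out[source] = (sum(data), timestamp)

pool = concurrent.futures.ThreadPoolExecutor()  # None would mean serial mode

if pool is None:
    for source, (data, ts) in sources.items():
        handle(source, data, ts, merged)
else:
    # one task per source; wait() blocks until every future has finished
    concurrent.futures.wait(
        [pool.submit(handle, s, d, ts, merged) for s, (d, ts) in sources.items()]
    )

print(merged)  # {'det/1': (3, 'ts1'), 'det/2': (7, 'ts2')}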
import concurrent.futures
import enum
import re
@@ -91,6 +92,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
            .defaultValue([])
            .reconfigurable()
            .commit(),

            BOOL_ELEMENT(expected)
            .key("useThreadPool")
            .assignmentOptional()
            .defaultValue(False)
            .commit(),
        )

    def initialization(self):
@@ -101,6 +108,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
        self._prepare_merge_groups(self.get("merge"))
        super().initialization()
        self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
        if self.get("useThreadPool"):
            # default worker count: min(32, os.cpu_count() + 4) on Python 3.8+
            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
        else:
            self._thread_pool = None

    def preReconfigure(self, conf):
        super().preReconfigure(conf)
@@ -165,70 +176,81 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                (key_re, new_key)
            )

    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
        # dereference calng shmem handles
        self._shmem_handler.dereference_shmem_handles(data_hash)

        # stack across sources (many sources, same key)
        # could probably save ~100 ns by "if ... in" instead of get
        for (stack_key, new_source) in self._source_stacking_sources.get(source, ()):
            this_data = data_hash.get(stack_key)
            try:
                this_buffer = self._stacking_buffers[(new_source, stack_key)]
                stack_index = self._source_stacking_indices[(source, stack_key)]
                this_buffer[stack_index] = this_data
            except (ValueError, IndexError, KeyError):
                # ValueError: wrong shape
                # IndexError: stack index beyond current buffer size
                # KeyError: buffer doesn't exist yet
                # in any case, create appropriate buffer now
                # TODO: complain if shape varies between sources
                self._stacking_buffers[(new_source, stack_key)] = np.empty(
                    shape=(
                        max(
                            index_
                            for (
                                source_,
                                key_,
                            ), index_ in self._source_stacking_indices.items()
                            if source_ == source and key_ == stack_key
                        )
                        + 1,
                    )
                    + this_data.shape,
                    dtype=this_data.dtype,
                )
                # and then try again
                this_buffer = self._stacking_buffers[(new_source, stack_key)]
                stack_index = self._source_stacking_indices[(source, stack_key)]
                this_buffer[stack_index] = this_data
            # TODO: zero out unfilled buffer entries
            data_hash.erase(stack_key)
            if new_source not in new_sources_map:
                new_sources_map[new_source] = (Hash(), timestamp)
            new_source_hash = new_sources_map[new_source][0]
            if not new_source_hash.has(stack_key):
                new_source_hash[stack_key] = this_buffer

        # stack keys (multiple keys within this source)
        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
            # note: please no overlap between different key_re
            # note: if later key_re match earlier new_key, this gets spicy
            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
            try:
                # TODO: consider reusing buffers here, too
                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
            except Exception as e:
                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
            else:
                for key in stack_keys:
                    data_hash.erase(key)
                data_hash[new_key] = stacked

    def on_matched_data(self, train_id, sources):
        new_sources_map = {}
        if self._thread_pool is None:
            for source, (data, timestamp) in sources.items():
                self._handle_source(source, data, timestamp, new_sources_map)
        else:
            concurrent.futures.wait(
                [
                    self._thread_pool.submit(
                        self._handle_source, source, data, timestamp, new_sources_map
                    )
                    for source, (data, timestamp) in sources.items()
                ]
            )
        sources.update(new_sources_map)
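For reference, the buffer logic in _handle_source sizes each stacking buffer as (highest stack index + 1,) + this_data.shape and lets each source write its slice at its assigned index, reallocating on first use or on shape mismatch. A plain-numpy sketch of that idea (the detector names, key, and shapes are made up, and the index filter is simplified; the real code keys buffers by (new_source, stack_key)):

import numpy as np

# Hypothetical setup: two sources stack the same key into one buffer.
indices = {("det/0", "image.data"): 0, ("det/1", "image.data"): 1}
buffers = {}

def stack_into_buffer(source, key, data, new_source="merged"):
    try:
        buffers[(new_source, key)][indices[(source, key)]] = data
    except (ValueError, IndexError, KeyError):
        # (re)allocate: leading axis fits the highest index, rest matches the data
        size = max(i for (_, k), i in indices.items() if k == key) + 1
        buffers[(new_source, key)] = np.empty((size,) + data.shape, dtype=data.dtype)
        buffers[(new_source, key)][indices[(source, key)]] = data

for i, source in enumerate(("det/0", "det/1")):
    stack_into_buffer(source, "image.data", np.full((2, 2), i))

print(buffers[("merged", "image.data")].shape)  # (2, 2, 2)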