Skip to content
Snippets Groups Projects
Commit 02a4ccd9 authored by David Hammer's avatar David Hammer
Browse files

Allow stacking on other axes, fix thread pool

parent b2b5aca1
No related branches found
No related tags found
2 merge requests!10DetectorAssembler: assemble with extra shape (multiple frames),!9Stacking shmem matcher
...@@ -5,6 +5,7 @@ import re ...@@ -5,6 +5,7 @@ import re
import numpy as np import numpy as np
from karabo.bound import ( from karabo.bound import (
BOOL_ELEMENT, BOOL_ELEMENT,
INT32_ELEMENT,
KARABO_CLASSINFO, KARABO_CLASSINFO,
STRING_ELEMENT, STRING_ELEMENT,
TABLE_ELEMENT, TABLE_ELEMENT,
...@@ -14,7 +15,7 @@ from karabo.bound import ( ...@@ -14,7 +15,7 @@ from karabo.bound import (
) )
from TrainMatcher import TrainMatcher from TrainMatcher import TrainMatcher
from . import shmem_utils from . import shmem_utils, utils
from ._version import version as deviceVersion from ._version import version as deviceVersion
...@@ -35,7 +36,7 @@ def merge_schema(): ...@@ -35,7 +36,7 @@ def merge_schema():
.commit(), .commit(),
STRING_ELEMENT(schema) STRING_ELEMENT(schema)
.key("source_pattern") .key("sourcePattern")
.displayedName("Source pattern") .displayedName("Source pattern")
.assignmentOptional() .assignmentOptional()
.defaultValue("") .defaultValue("")
...@@ -43,7 +44,7 @@ def merge_schema(): ...@@ -43,7 +44,7 @@ def merge_schema():
.commit(), .commit(),
STRING_ELEMENT(schema) STRING_ELEMENT(schema)
.key("key_pattern") .key("keyPattern")
.displayedName("Key pattern") .displayedName("Key pattern")
.assignmentOptional() .assignmentOptional()
.defaultValue("") .defaultValue("")
...@@ -66,6 +67,14 @@ def merge_schema(): ...@@ -66,6 +67,14 @@ def merge_schema():
.defaultValue(MergeGroupType.MULTISOURCE.value) .defaultValue(MergeGroupType.MULTISOURCE.value)
.reconfigurable() .reconfigurable()
.commit(), .commit(),
INT32_ELEMENT(schema)
.key("stackingAxis")
.displayedName("Axis")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
) )
return schema return schema
...@@ -78,7 +87,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -78,7 +87,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
( (
TABLE_ELEMENT(expected) TABLE_ELEMENT(expected)
.key("merge") .key("merge")
.displayedName("Array merging") .displayedName("Array stacking")
.allowedStates(State.PASSIVE) .allowedStates(State.PASSIVE)
.description( .description(
"List source or key patterns to merge their data arrays, e.g. to " "List source or key patterns to merge their data arrays, e.g. to "
...@@ -95,8 +104,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -95,8 +104,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
BOOL_ELEMENT(expected) BOOL_ELEMENT(expected)
.key("useThreadPool") .key("useThreadPool")
.displayedName("Use thread pool")
.allowedStates(State.PASSIVE)
.assignmentOptional() .assignmentOptional()
.defaultValue(False) .defaultValue(False)
.reconfigurable()
.commit(), .commit(),
) )
...@@ -117,6 +129,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -117,6 +129,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
super().preReconfigure(conf) super().preReconfigure(conf)
if conf.has("merge") or conf.has("sources"): if conf.has("merge") or conf.has("sources"):
self._prepare_merge_groups(conf["merge"]) self._prepare_merge_groups(conf["merge"])
if conf.has("useThreadPool"):
if self._thread_pool is not None:
self._thread_pool.shutdown()
self._thread_pool = None
if conf["useThreadPool"]:
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
def _prepare_merge_groups(self, merge): def _prepare_merge_groups(self, merge):
source_group_patterns = [] source_group_patterns = []
...@@ -129,16 +147,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -129,16 +147,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
if group_type is MergeGroupType.MULTISOURCE: if group_type is MergeGroupType.MULTISOURCE:
source_group_patterns.append( source_group_patterns.append(
( (
re.compile(row["source_pattern"]), re.compile(row["sourcePattern"]),
row["key_pattern"], row["keyPattern"],
row["replacement"], row["replacement"],
row["stackingAxis"],
) )
) )
else: else:
key_group_patterns.append( key_group_patterns.append(
( (
re.compile(row["source_pattern"]), re.compile(row["sourcePattern"]),
re.compile(row["key_pattern"]), re.compile(row["keyPattern"]),
row["replacement"], row["replacement"],
) )
) )
...@@ -150,7 +169,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -150,7 +169,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
# handle source stacking groups # handle source stacking groups
self._source_stacking_indices.clear() self._source_stacking_indices.clear()
self._source_stacking_sources.clear() self._source_stacking_sources.clear()
for source_re, key, new_source in source_group_patterns: for source_re, key, new_source, stack_axis in source_group_patterns:
merge_sources = [ merge_sources = [
source for source in source_names if source_re.match(source) source for source in source_names if source_re.match(source)
] ]
...@@ -161,9 +180,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -161,9 +180,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
continue continue
for (i, source) in enumerate(merge_sources): for (i, source) in enumerate(merge_sources):
self._source_stacking_sources.setdefault(source, []).append( self._source_stacking_sources.setdefault(source, []).append(
(key, new_source) (key, new_source, stack_axis)
) )
self._source_stacking_indices[(source, key)] = i self._source_stacking_indices[(source, new_source, key)] = i
# handle key stacking groups # handle key stacking groups
self._key_stacking_sources.clear() self._key_stacking_sources.clear()
...@@ -176,43 +195,61 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -176,43 +195,61 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
(key_re, new_key) (key_re, new_key)
) )
def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
    """(Re)allocate the stacking buffer for (new_source, key).

    The buffer must hold one slot per source previously registered in
    self._source_stacking_indices for this (new_source, key) pair; the
    slot count is derived from the highest registered index + 1.
    """
    # TODO: handle ValueError for max of empty sequence
    matching_indices = [
        index
        for (_, new_source_, key_), index in self._source_stacking_indices.items()
        if new_source_ == new_source and key_ == key
    ]
    stack_num = max(matching_indices) + 1
    buffer_shape = utils.stacking_buffer_shape(individual_shape, stack_num, axis=axis)
    self._stacking_buffers[(new_source, key)] = np.empty(
        shape=buffer_shape, dtype=dtype
    )
def _handle_source(self, source, data_hash, timestamp, new_sources_map): def _handle_source(self, source, data_hash, timestamp, new_sources_map):
# dereference calng shmem handles # dereference calng shmem handles
self._shmem_handler.dereference_shmem_handles(data_hash) self._shmem_handler.dereference_shmem_handles(data_hash)
# stack across sources (many sources, same key) # stack across sources (many sources, same key)
# could probably save ~100 ns by "if ... in" instead of get # could probably save ~100 ns by "if ... in" instead of get
for (stack_key, new_source) in self._source_stacking_sources.get(source, ()): for (stack_key, new_source, stack_axis) in self._source_stacking_sources.get(
source, ()
):
this_data = data_hash.get(stack_key) this_data = data_hash.get(stack_key)
try: try:
this_buffer = self._stacking_buffers[(new_source, stack_key)] this_buffer = self._stacking_buffers[(new_source, stack_key)]
stack_index = self._source_stacking_indices[(source, stack_key)] stack_index = self._source_stacking_indices[
this_buffer[stack_index] = this_data (source, new_source, stack_key)
]
utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
except (ValueError, IndexError, KeyError): except (ValueError, IndexError, KeyError):
# ValueError: wrong shape # ValueError: wrong shape (react to this_data.shape)
# KeyError: buffer doesn't exist yet # KeyError: buffer doesn't exist yet
# IndexError: new source? (buffer not long enough)
# either way, create appropriate buffer now # either way, create appropriate buffer now
# TODO: complain if shape varies between sources # TODO: complain if shape varies between sources within train
self._stacking_buffers[(new_source, stack_key)] = np.empty( self._update_stacking_buffer(
shape=( new_source,
max( stack_key,
index_ this_data.shape,
for ( axis=stack_axis,
source_,
key_,
), index_ in self._source_stacking_indices.items()
if source_ == source and key_ == stack_key
)
+ 1,
)
+ this_data.shape,
dtype=this_data.dtype, dtype=this_data.dtype,
) )
# and then try again # and then try again
this_buffer = self._stacking_buffers[(new_source, stack_key)] this_buffer = self._stacking_buffers[(new_source, stack_key)]
stack_index = self._source_stacking_indices[(source, stack_key)] stack_index = self._source_stacking_indices[
this_buffer[stack_index] = this_data (source, new_source, stack_key)
# TODO: zero out unfilled buffer entries ]
utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
# TODO: zero out unfilled buffer entries
data_hash.erase(stack_key) data_hash.erase(stack_key)
if new_source not in new_sources_map: if new_source not in new_sources_map:
...@@ -246,7 +283,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ...@@ -246,7 +283,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
concurrent.futures.wait( concurrent.futures.wait(
[ [
self._thread_pool.submit( self._thread_pool.submit(
self._handle_source, data, timestamp, new_sources_map self._handle_source, source, data, timestamp, new_sources_map
) )
for source, (data, timestamp) in sources.items() for source, (data, timestamp) in sources.items()
] ]
......
...@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out): ...@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
return tuple(axis_order[axis] for axis in axes_out) return tuple(axis_order[axis] for axis in axes_out)
def stacking_buffer_shape(array_shape, stack_num, axis=0):
    """Figures out the shape you would need for np.stack

    Stacking `stack_num` arrays of shape `array_shape` along `axis` yields an
    array of rank `len(array_shape) + 1`; this returns that resulting shape
    without allocating anything.

    Raises AxisError (subclass of ValueError and IndexError) for an axis that
    np.stack itself would reject.
    """
    out_ndim = len(array_shape) + 1
    if axis > len(array_shape) or axis < -out_ndim:
        # complain when np.stack would
        # np.AxisError moved to np.exceptions in numpy >= 1.25 and was removed
        # from the top-level namespace in numpy 2.0; resolve it compatibly.
        axis_error = getattr(np, "exceptions", np).AxisError
        raise axis_error(
            f"axis {axis} is out of bounds "
            f"for array of dimension {len(array_shape) + 1}"
        )
    if axis < 0:
        # normalize negative axis against the rank of the stacked result
        axis += len(array_shape) + 1
    return array_shape[:axis] + (stack_num,) + array_shape[axis:]
def set_on_axis(array, vals, index, axis):
    """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x

    Assigns `vals` into the slice of `array` at position `index` along
    dimension `axis` (non-negative axes only).
    """
    # BUG FIX: the guard used len(array) — the size of the *first dimension* —
    # where the number of dimensions was intended, so e.g. a (1, 5) array
    # wrongly raised for axis=1 while a (10, 2) array let axis=5 through.
    if axis >= array.ndim:
        raise IndexError(
            f"too many indices for array: array is {len(array.shape)}-dimensional, "
            f"but {axis+1} were indexed"
        )
    # TODO: maybe support negative axis with wraparound
    # Build the index tuple (slice(None),) * axis + (index,), i.e. [:, ..., index]
    indices = np.index_exp[:] * axis + np.index_exp[index]
    array[indices] = vals
_np_typechar_to_c_typestring = { _np_typechar_to_c_typestring = {
"?": "bool", "?": "bool",
"B": "unsigned char", "B": "unsigned char",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment