Skip to content
Snippets Groups Projects
Commit 9495dda8 authored by David Hammer's avatar David Hammer
Browse files

Squashed commit of the following:

commit beb32df5
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 13 13:25:10 2022 +0200

    Fix reconfigure bug and prevent sources from being sorted

commit a79b501a
Author: David Hammer <dhammer@mailbox.org>
Date:   Wed May 11 12:35:09 2022 +0200

    Update description of stacking patterns

commit 426d707b
Merge: 07b6c972 e1b0b8d1
Author: David Hammer <dhammer@mailbox.org>
Date:   Mon May 9 12:50:30 2022 +0200

    Merge branch 'stacking-shmem-matcher' of ssh://git.xfel.eu:10022/karaboDevices/calng into stacking-shmem-matcher

commit 07b6c972
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 14:53:31 2022 +0200

    Adding option to skip Karabo channel output

commit 02a4ccd9
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 13:27:59 2022 +0200

    Allow stacking on other axes, fix thread pool

commit b2b5aca1
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 16:30:14 2022 +0200

    Allow thread pool usage

commit f8f380cc
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 14:49:01 2022 +0200

    Use sources list to get source order, cache buffers

commit ef898211
Author: David Hammer <dhammer@mailbox.org>
Date:   Tue May 3 15:49:21 2022 +0200

    Adding Philipp's proposed changes to TrainMatcher for merging

commit e1b0b8d1
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 14:53:31 2022 +0200

    Adding option to skip Karabo channel output

commit bbb7cee4
Author: David Hammer <dhammer@mailbox.org>
Date:   Fri May 6 13:27:59 2022 +0200

    Allow stacking on other axes, fix thread pool

commit b01015c4
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 16:30:14 2022 +0200

    Allow thread pool usage

commit e7f96bae
Author: David Hammer <dhammer@mailbox.org>
Date:   Thu May 5 14:49:01 2022 +0200

    Use sources list to get source order, cache buffers

commit 193264ed
Author: David Hammer <dhammer@mailbox.org>
Date:   Tue May 3 15:49:21 2022 +0200

    Adding Philipp's proposed changes to TrainMatcher for merging
parent 2e66e93e
No related branches found
No related tags found
No related merge requests found
from karabo.bound import KARABO_CLASSINFO
import concurrent.futures
import enum
import re
import numpy as np
from karabo.bound import (
BOOL_ELEMENT,
INT32_ELEMENT,
KARABO_CLASSINFO,
STRING_ELEMENT,
TABLE_ELEMENT,
ChannelMetaData,
Hash,
Schema,
State,
VectorHash,
)
from TrainMatcher import TrainMatcher
from . import shmem_utils
from . import shmem_utils, utils
from ._version import version as deviceVersion
class MergeGroupType(enum.Enum):
    """Kinds of stacking a row in the "merge" table can request."""

    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key
def merge_schema():
    """Build and return the row schema for the "merge" TABLE_ELEMENT.

    Each row selects a stacking group: patterns to match, the replacement
    name, the group type (see MergeGroupType), and the stacking axis.
    """
    schema = Schema()

    (
        BOOL_ELEMENT(schema)
        .key("select")
        .displayedName("Select")
        .assignmentOptional()
        .defaultValue(False)
        .reconfigurable()
        .commit()
    )

    (
        STRING_ELEMENT(schema)
        .key("sourcePattern")
        .displayedName("Source pattern")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit()
    )

    (
        STRING_ELEMENT(schema)
        .key("keyPattern")
        .displayedName("Key pattern")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit()
    )

    (
        STRING_ELEMENT(schema)
        .key("replacement")
        .displayedName("Replacement")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit()
    )

    (
        STRING_ELEMENT(schema)
        .key("type")
        .displayedName("Group type")
        .options(",".join(option.value for option in MergeGroupType))
        .assignmentOptional()
        .defaultValue(MergeGroupType.MULTISOURCE.value)
        .reconfigurable()
        .commit()
    )

    (
        INT32_ELEMENT(schema)
        .key("stackingAxis")
        .displayedName("Axis")
        .assignmentOptional()
        .defaultValue(0)
        .reconfigurable()
        .commit()
    )

    return schema
@KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
    """TrainMatcher variant that dereferences calng shmem handles and can
    stack arrays across sources or across keys according to the "merge"
    table, optionally using a thread pool and optionally skipping the
    Karabo channel output."""

    @staticmethod
    def expectedParameters(expected):
        # Schema additions on top of the base TrainMatcher schema.
        (
            TABLE_ELEMENT(expected)
            .key("merge")
            .displayedName("Array stacking")
            .allowedStates(State.PASSIVE)
            .description(
                # NOTE(review): missing space at this literal boundary renders
                # as "stack.When" in the GUI — confirm and fix upstream
                "Specify which source(s) or key(s) to stack."
                "When stacking sources, the 'Source pattern' is interpreted as a "
                "regular expression and the 'Key pattern' is interpreted as an "
                "ordinary string. From all sources matching the source pattern, the "
                "data under this key (should be array with same dimensions across all "
                "stacked sources) is stacked in the same order as the sources are "
                "listed in 'Data sources' and the result is under the same key name in "
                "a new source named by 'Replacement'. "
                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
                "are regular expressions. Within each source matching the source "
                "pattern, all keys matching the key pattern are stacked and the result "
                "is put under the key named by 'Replacement'. "
                "While source stacking is optimized and can use thread pool, key "
                "stacking will iterate over all paths in matched sources and naively "
                "call np.stack for each key pattern. In either case, data that is used "
                "for stacking is removed from its original location (e.g. key is "
                "erased from hash)."
            )
            .setColumns(merge_schema())
            .assignmentOptional()
            .defaultValue([])
            .reconfigurable()
            .commit(),

            # Whether _handle_source runs in a ThreadPoolExecutor per source
            BOOL_ELEMENT(expected)
            .key("useThreadPool")
            .displayedName("Use thread pool")
            .allowedStates(State.PASSIVE)
            .assignmentOptional()
            .defaultValue(False)
            .reconfigurable()
            .commit(),

            # When off, self.output is set to None and no Karabo channel output
            # is written (bridge output is unaffected)
            BOOL_ELEMENT(expected)
            .key("enableKaraboOutput")
            .displayedName("Enable Karabo channel")
            .allowedStates(State.PASSIVE)
            .assignmentOptional()
            .defaultValue(True)
            .reconfigurable()
            .commit(),
        )
    def initialization(self):
        """Set up stacking state, shmem receiver, thread pool and output.

        Order matters: the stacking lookup tables are prepared before
        super().initialization() so that matching can start with them in
        place.
        """
        # lookup tables and reusable buffers for array stacking; filled by
        # _prepare_merge_groups()
        self._stacking_buffers = {}
        self._source_stacking_indices = {}
        self._source_stacking_sources = {}
        self._key_stacking_sources = {}
        self._have_prepared_merge_groups = False
        self._prepare_merge_groups()
        super().initialization()
        self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
        if self.get("useThreadPool"):
            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
        else:
            self._thread_pool = None
        if not self.get("enableKaraboOutput"):
            # it is set already by super by default, so only need to turn off
            self.output = None
def preReconfigure(self, conf):
super().preReconfigure(conf)
if conf.has("merge") or conf.has("sources"):
self._have_prepared_merge_groups = False
# re-prepare in postReconfigure after sources *and* merge are in self
if conf.has("useThreadPool"):
if self._thread_pool is not None:
self._thread_pool.shutdown()
self._thread_pool = None
if conf["useThreadPool"]:
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
if conf.has("enableKaraboOutput"):
if conf["enableKaraboOutput"]:
self.output = self._ss.getOutputChannel("output")
else:
self.output = None
def postReconfigure(self):
super().postReconfigure()
if not self._have_prepared_merge_groups:
self._prepare_merge_groups()
def _prepare_merge_groups(self):
source_group_patterns = []
key_group_patterns = []
# split by type, prepare regexes
for row in self.get("merge"):
if not row["select"]:
continue
group_type = MergeGroupType(row["type"])
if group_type is MergeGroupType.MULTISOURCE:
source_group_patterns.append(
(
re.compile(row["sourcePattern"]),
row["keyPattern"],
row["replacement"],
row["stackingAxis"],
)
)
else:
key_group_patterns.append(
(
re.compile(row["sourcePattern"]),
re.compile(row["keyPattern"]),
row["replacement"],
)
)
# not filtering by row["select"] to allow unselected sources to create gaps
source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
self._stacking_buffers.clear()
# handle source stacking groups
self._source_stacking_indices.clear()
self._source_stacking_sources.clear()
for source_re, key, new_source, stack_axis in source_group_patterns:
merge_sources = [
source for source in source_names if source_re.match(source)
]
if len(merge_sources) == 0:
self.log.WARN(
f"Group merge pattern {source_re} did not match any known sources"
)
continue
for (i, source) in enumerate(merge_sources):
self._source_stacking_sources.setdefault(source, []).append(
(key, new_source, stack_axis)
)
self._source_stacking_indices[(source, new_source, key)] = i
# handle key stacking groups
self._key_stacking_sources.clear()
for source_re, key_re, new_key in key_group_patterns:
for source in source_names:
# TODO: maybe also warn if no matches here?
if not source_re.match(source):
continue
self._key_stacking_sources.setdefault(source, []).append(
(key_re, new_key)
)
self._have_prepared_merge_groups = True
def _update_stacking_buffer(self, new_source, key, individual_shape, axis, dtype):
# TODO: handle ValueError for max of empty sequence
stack_num = (
max(
index
for (
_,
new_source_,
key_,
), index in self._source_stacking_indices.items()
if new_source_ == new_source and key_ == key
)
+ 1
)
self._stacking_buffers[(new_source, key)] = np.empty(
shape=utils.stacking_buffer_shape(individual_shape, stack_num, axis=axis),
dtype=dtype,
)
    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
        """Dereference shmem handles and apply configured stacking for one source.

        data_hash is modified in place: shmem handles are replaced by data and
        stacked keys are erased. Newly created stacked sources are added to
        new_sources_map as new_source -> (Hash, timestamp), shared across all
        calls within one matched train. NOTE(review): may run concurrently for
        different sources (see on_matched_data); presumably safe because each
        source writes distinct buffer indices — confirm.
        """
        # dereference calng shmem handles
        self._shmem_handler.dereference_shmem_handles(data_hash)
        # stack across sources (many sources, same key)
        # could probably save ~100 ns by "if ... in" instead of get
        for (stack_key, new_source, stack_axis) in self._source_stacking_sources.get(
            source, ()
        ):
            this_data = data_hash.get(stack_key)
            try:
                this_buffer = self._stacking_buffers[(new_source, stack_key)]
                stack_index = self._source_stacking_indices[
                    (source, new_source, stack_key)
                ]
                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
            except (ValueError, IndexError, KeyError):
                # ValueError: wrong shape (react to this_data.shape)
                # KeyError: buffer doesn't exist yet
                # IndexError: new source? (buffer not long enough)
                # either way, create appropriate buffer now
                # TODO: complain if shape varies between sources within train
                self._update_stacking_buffer(
                    new_source,
                    stack_key,
                    this_data.shape,
                    axis=stack_axis,
                    dtype=this_data.dtype,
                )
                # and then try again
                this_buffer = self._stacking_buffers[(new_source, stack_key)]
                stack_index = self._source_stacking_indices[
                    (source, new_source, stack_key)
                ]
                utils.set_on_axis(this_buffer, this_data, stack_index, stack_axis)
            # TODO: zero out unfilled buffer entries
            # remove stacked data from its original location
            data_hash.erase(stack_key)
            if new_source not in new_sources_map:
                new_sources_map[new_source] = (Hash(), timestamp)
            new_source_hash = new_sources_map[new_source][0]
            if not new_source_hash.has(stack_key):
                new_source_hash[stack_key] = this_buffer
        # stack keys (multiple keys within this source)
        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
            # note: please no overlap between different key_re
            # note: if later key_re match earlier new_key, this gets spicy
            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
            try:
                # TODO: consider reusing buffers here, too
                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
            except Exception as e:
                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
            else:
                for key in stack_keys:
                    data_hash.erase(key)
                data_hash[new_key] = stacked
def on_matched_data(self, train_id, sources):
new_sources_map = {}
if self._thread_pool is None:
for source, (data, timestamp) in sources.items():
self._handle_source(source, data, timestamp, new_sources_map)
else:
concurrent.futures.wait(
[
self._thread_pool.submit(
self._handle_source, source, data, timestamp, new_sources_map
)
for source, (data, timestamp) in sources.items()
]
)
sources.update(new_sources_map)
# karabo output
if self.output is not None:
for source, (data, timestamp) in sources.items():
self.output.write(data, ChannelMetaData(source, timestamp))
self.output.update()
# karabo bridge output
for source, (data, timestamp) in sources.items():
if data.has("calngShmemPaths"):
shmem_paths = list(data["calngShmemPaths"])
data.erase("calngShmemPaths")
for shmem_path in shmem_paths:
if not data.has(shmem_path):
self.log.INFO(f"Hash from {source} did not have {shmem_path}")
continue
dereferenced = self._shmem_handler.get(data[shmem_path])
data[shmem_path] = dereferenced
super().on_matched_data(train_id, sources)
self.zmq_output.write(source, data, timestamp)
self.zmq_output.update()
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
    def _maybe_connect_data(self, conf, update=False, state=None):
        """Temporary override on _maybe_connect_data to avoid sorting sources list (we
        need it for stacking order)"""
        # only (re)connect while in a state where it is safe to do so
        if self["state"] not in (State.CHANGING, State.ACTIVE):
            return
        last_state = self["state"]
        self.updateState(State.CHANGING)

        # unwatch removed sources
        def src_names(c):
            # do not assign a lambda expression, use a def
            return {s["source"] for s in c["sources"]}

        for source in src_names(self).difference(src_names(conf)):
            self.monitor.unwatch_source(source)

        # rebuild the sources table in its configured (unsorted) order,
        # (un)watching each source and recording its watch status
        new_conf = VectorHash()
        for src in conf["sources"]:
            source = src["source"]
            if src["select"]:
                src["status"] = self.monitor.watch_source(source, src["offset"])
            else:
                self.monitor.unwatch_source(source)
                src["status"] = ""
            new_conf.append(src)
        if update:
            self.set("sources", new_conf)
        else:
            conf["sources"] = new_conf
        # restore the requested or previous state
        self.updateState(state or last_state)
......@@ -46,6 +46,18 @@ class ShmemCircularBufferReceiver:
return ary[index]
def dereference_shmem_handles(self, data_hash):
if data_hash.has("calngShmemPaths"):
shmem_paths = list(data_hash["calngShmemPaths"])
data_hash.erase("calngShmemPaths")
for shmem_path in shmem_paths:
if not data_hash.has(shmem_path):
# TODO: proper warnings
print(f"Warning: hash did not contain {shmem_path}")
continue
dereferenced = self.get(data_hash[shmem_path])
data_hash[shmem_path] = dereferenced
class ShmemCircularBuffer:
"""Convenience wrapper around posixshmem-backed ndarray buffers
......
......@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
return tuple(axis_order[axis] for axis in axes_out)
def stacking_buffer_shape(array_shape, stack_num, axis=0):
    """Figures out the shape you would need for np.stack

    Given the shape of one input array and the number of arrays to stack,
    returns the shape of the stacked result along the given axis.
    """
    # the stacked result has one more dimension than each input
    ndim_out = len(array_shape) + 1
    # complain when np.stack would
    if not -ndim_out <= axis <= ndim_out - 1:
        raise np.AxisError(
            f"axis {axis} is out of bounds "
            f"for array of dimension {len(array_shape) + 1}"
        )
    if axis < 0:
        axis = axis + ndim_out
    return array_shape[:axis] + (stack_num,) + array_shape[axis:]
def set_on_axis(array, vals, index, axis):
    """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x

    Assigns vals to the slice of array at position index along the given
    axis, with normal numpy broadcasting of vals.

    Raises IndexError when axis is not a valid dimension of array.
    """
    # bug fix: the bound is the number of dimensions, not len(array)
    # (len(array) is the size of the first axis, i.e. array.shape[0])
    if axis >= array.ndim:
        raise IndexError(
            f"too many indices for array: array is {len(array.shape)}-dimensional, "
            f"but {axis+1} were indexed"
        )
    # TODO: maybe support negative axis with wraparound
    # builds (slice(None),) * axis + (index,), i.e. array[:, ..., :, index]
    indices = np.index_exp[:] * axis + np.index_exp[index]
    array[indices] = vals
_np_typechar_to_c_typestring = {
"?": "bool",
"B": "unsigned char",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment