Skip to content
Snippets Groups Projects
Commit c3d2d437 authored by David Hammer's avatar David Hammer
Browse files

Prototype configurable frame selection arbiter

parent ba6eaa9f
No related branches found
No related tags found
5 merge requests: !59 Interface CrystFEL with Karabo and allow automatic parameter tuning with rcrystfel, !53 Train picker arbiter kernel, !52 SFX hitfinder and geometry for correction addons, !47 Draft: CrystFEL integration, !35 Frame selection
from karabo.bound import (
DOUBLE_ELEMENT,
KARABO_CLASSINFO,
NDARRAY_ELEMENT,
NODE_ELEMENT,
OUTPUT_CHANNEL,
SLOT_ELEMENT,
STRING_ELEMENT,
VECTOR_BOOL_ELEMENT,
ChannelMetaData,
Epochstamp,
Hash,
Schema,
Timestamp,
Trainstamp,
Unit,
)
import numpy as np
from TrainMatcher import TrainMatcher
from ._version import version as deviceVersion
from . import utils
# Data schema attached to the "output" channel of the arbiter device.
# Each element is committed as its own statement; the order of commits is
# kept so that the "data" node exists before its child key is added.
my_schema = Schema()
NDARRAY_ELEMENT(my_schema).key("veto").dtype("INT64").commit()
NDARRAY_ELEMENT(my_schema).key("vetoAsMask").dtype("UINT8").commit()
NODE_ELEMENT(my_schema).key("data").commit()
(
    VECTOR_BOOL_ELEMENT(my_schema)
    .key("data.dataFramePattern")
    .assignmentOptional()
    .defaultValue([])
    .commit()
)
class BaseArbiterKernel:
    """Common base class for frame selection kernels.

    Subclasses override ``_node`` with the name of their configuration node.
    """

    # Only used for schema purposes; concrete kernels set their own value.
    _node = None
class RandomSampler(BaseArbiterKernel):
    """Kernel which picks frames at random.

    The configured ``probability`` (given in percent) is turned into a
    threshold once at construction; a frame is selected whenever a fresh
    uniform random draw is strictly greater than that threshold.
    """

    _node = "randomSampler"

    def __init__(self, config, prefix="frameSelection.kernels"):
        """Read ``<prefix>.randomSampler.probability`` from ``config``."""
        probability = config.get(f"{prefix}.{self._node}.probability")
        self._threshold = (100 - probability) / 100

    @staticmethod
    def add_schema(schema, prefix="frameSelection.kernels"):
        """Add this kernel's node and its ``probability`` key to ``schema``."""
        node_key = f"{prefix}.{RandomSampler._node}"
        NODE_ELEMENT(schema).key(node_key).commit()
        (
            DOUBLE_ELEMENT(schema)
            .key(f"{node_key}.probability")
            .unit(Unit.PERCENT)
            .assignmentOptional()
            .defaultValue(50)
            .reconfigurable()
            .commit()
        )

    def consider(self, train_id, sources, num_frames):
        """Return a uint8 array of length ``num_frames``; 1 marks a selected frame.

        ``train_id`` and ``sources`` are accepted for interface compatibility
        but are not used by this kernel.
        """
        draws = np.random.random(num_frames)
        selected = draws > self._threshold
        return selected.astype(np.uint8, copy=False)
class IntegratedIntensityArbiter(BaseArbiterKernel):
    """Kernel which selects frames by summed integrated intensity.

    The ``image.integratedIntensity`` values of all sources providing that key
    are summed element-wise; a frame is selected when its sum is strictly
    greater than the configured ``threshold``.
    """

    _node = "integratedIntensity"

    def __init__(self, config, prefix="frameSelection.kernels"):
        """Read ``<prefix>.integratedIntensity.threshold`` from ``config``."""
        self._threshold = config.get(f"{prefix}.{self._node}.threshold")

    @staticmethod
    def add_schema(schema, prefix="frameSelection.kernels"):
        """Add this kernel's node and its ``threshold`` key to ``schema``."""
        node_key = f"{prefix}.{IntegratedIntensityArbiter._node}"
        NODE_ELEMENT(schema).key(node_key).commit()
        (
            DOUBLE_ELEMENT(schema)
            .key(f"{node_key}.threshold")
            .assignmentOptional()
            .defaultValue(1e8)
            .reconfigurable()
            .commit()
        )

    def consider(self, train_id, sources, num_frames):
        """Return a uint8 array with 1 for each frame above the threshold.

        Parameters
        ----------
        train_id
            Not used by this kernel.
        sources
            Mapping whose values are ``(data, metadata)`` pairs; only entries
            whose ``data`` has ``image.integratedIntensity`` contribute.
        num_frames
            Number of frames in the train; used to shape the result when no
            source contributes any intensity.
        """
        intensities = [
            data["image.integratedIntensity"]
            for (data, _) in sources.values()
            if data.has("image.integratedIntensity")
        ]
        if intensities:
            total = np.sum(intensities, axis=0)
        else:
            # np.sum([], axis=0) collapses to the scalar 0.0, which would make
            # this method return a 0-d value instead of a per-frame array.
            # Use an explicit zero total so the comparison against the
            # threshold is unchanged but the result has num_frames entries.
            total = np.zeros(num_frames)
        return (total > self._threshold).astype(np.uint8, copy=False)
# Registry of the available kernels, keyed by class name; the keys are what
# the device offers as options for choosing a kernel.
kernel_options = {
    kernel_cls.__name__: kernel_cls
    for kernel_cls in [RandomSampler, IntegratedIntensityArbiter]
}
@KARABO_CLASSINFO("FrameSelectionArbiter", deviceVersion)
class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
......@@ -43,70 +119,47 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
(
NODE_ELEMENT(expected)
.key("frameSelection")
.displayedName("Soft veto")
.displayedName("Frame selection")
.commit(),
STRING_ELEMENT(expected)
.key("frameSelection.contextFile")
.displayedName("Context file")
.assignmentOptional()
.defaultValue("")
.commit(),
STRING_ELEMENT(expected)
.key("frameSelection.kernelClass")
.displayedName("Kernel class name")
.assignmentOptional()
.defaultValue("")
.commit(),
SLOT_ELEMENT(expected)
.key("frameSelection.reconfigure")
.displayedName("Reconfigure")
.key("frameSelection.kernelChoice")
.displayedName("Kernel to use")
.assignmentMandatory()
.options(",".join(kernel_options.keys()))
.reconfigurable()
.commit(),
OUTPUT_CHANNEL(expected)
.key("output")
.dataSchema(my_schema)
.commit(),
)
def __init__(self, config):
super().__init__(config)
self.kernel = None
self.KARABO_SLOT(
self.frameSelection_reconfigure, slotName="frameSelection_reconfigure"
NODE_ELEMENT(expected)
.key("frameSelection.kernels")
.displayedName("Kernels")
.commit(),
)
self.registerInitialFunction(self.frameSelection_reconfigure)
for cls in kernel_options.values():
cls.add_schema(expected)
def frameSelection_reconfigure(self):
    """Load the kernel class from the configured context file and instantiate it.

    The module is loaded from ``frameSelection.contextFile`` under the name
    given by ``frameSelection.kernelClass``; the attribute of that same name
    is then called with the device parameters and stored as ``self.kernel``.
    """
    # NOTE(review): this executes whatever Python file the configuration
    # points at - confirm that the context file path is trusted.
    spec = importlib.util.spec_from_file_location(
        self.get("frameSelection.kernelClass"),
        self.get("frameSelection.contextFile"),
    )
    context_module = importlib.util.module_from_spec(spec)
    # The loaded module is deliberately not registered in sys.modules.
    spec.loader.exec_module(context_module)
    kernel_factory = getattr(context_module, self.get("frameSelection.kernelClass"))
    self.kernel = kernel_factory(self._parameters)
def initialization(self):
    """Run the base initialization, then build the configured kernel.

    The kernel class is looked up in ``kernel_options`` by the value of
    ``frameSelection.kernelChoice`` and instantiated with the device
    parameters.
    """
    super().initialization()
    chosen_name = self.get("frameSelection.kernelChoice")
    self._kernel_class = kernel_options[chosen_name]
    self._kernel = self._kernel_class(self._parameters)
def on_matched_data(self, train_id, sources):
if self.kernel is None:
return
# TODO: robust frame deduction
for (data, _) in sources.values():
if not data.has("peakfinding"):
if not data.has("image.cellId"):
continue
num_frames = data.get("peakfinding.numPeaks").size
num_frames = data.get("image.cellId").size
break
decision = self._kernel.consider(train_id, sources, num_frames)
result = Hash()
self.kernel(train_id, sources, result)
mask = np.zeros(num_frames, dtype=np.uint8)
mask[result["veto"]] = 1
result["vetoAsMask"] = mask
# TODO: avoid recasting
result["data.dataFramePattern"] = list(map(bool, decision))
self.output.write(
result,
ChannelMetaData(
......@@ -116,3 +169,16 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher):
)
self.output.update()
self.rate_out.update()
def preReconfigure(self, conf):
    """Rebuild the kernel when the incoming reconfiguration affects it.

    A changed ``frameSelection.kernelChoice`` switches the kernel class; a
    change under the current kernel's own configuration node rebuilds the
    kernel of the same class. Any other reconfiguration leaves the kernel
    untouched. The kernel is constructed from ``conf`` chained in front of
    the current device parameters.
    """
    super().preReconfigure(conf)
    merged_conf = utils.ChainHash(conf, self._parameters)

    choice_key = "frameSelection.kernelChoice"
    if conf.has(choice_key):
        self.log.INFO("Switching frame selection kernel")
        self._kernel_class = kernel_options[conf.get(choice_key)]
    elif conf.has(f"frameSelection.kernels.{self._kernel_class._node}"):
        self.log.INFO("Reconfiguring frame selection kernel")
        # TODO: update instead of rebuild for complex kernels
        # (decide based on reconfigurability?)
    else:
        # Nothing relevant to the kernel changed.
        return

    self._kernel = self._kernel_class(merged_conf)
......@@ -171,8 +171,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
.key("frameSelector.arbiterSource")
.displayedName("Arbiter source")
.description(
"Source name to pull the veto pattern from, must be part of matched "
"sources."
"Source name to pull the frame selection pattern from, must be part of "
"matched sources."
)
.assignmentOptional()
.defaultValue("")
......@@ -183,8 +183,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
.key("frameSelector.dataSourcePattern")
.displayedName("Data source pattern")
.description(
"Source name pattern to apply veto pattern to, must be part of matched "
"sources."
"Source name pattern to apply frame selection to. Should match "
"subset of matched sources."
)
.assignmentOptional()
.defaultValue("")
......@@ -194,7 +194,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
VECTOR_STRING_ELEMENT(expected)
.key("frameSelector.dataKeys")
.displayedName("Data keys")
.description("Keys in data sources to apply veto pattern to.")
.description("Keys in data sources to apply frame selection to.")
.assignmentOptional()
.defaultValue([])
.reconfigurable()
......@@ -347,20 +347,20 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
)
def _handle_source(
self, source, data_hash, timestamp, new_sources_map, frame_selection_frames
self, source, data_hash, timestamp, new_sources_map, frame_selection_mask
):
# dereference calng shmem handles
self._shmem_handler.dereference_shmem_handles(data_hash)
# apply frame_selection
if frame_selection_frames is not None and self._frame_selection_source_pattern.match(
if frame_selection_mask is not None and self._frame_selection_source_pattern.match(
source
):
for key in self._frame_selection_data_keys:
if not data_hash.has(key):
continue
data_hash[key] = data_hash[key][frame_selection_frames]
data_hash[key] = data_hash[key][frame_selection_mask]
# stack across sources (many sources, same key)
# could probably save ~100 ns by "if ... in" instead of get
......@@ -419,14 +419,16 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
def on_matched_data(self, train_id, sources):
new_sources_map = {}
frame_selection_frames = None
frame_selection_mask = None
if self._frame_selection_enabled and self._frame_selection_arbiter in sources:
frame_selection_frames = sources[self._frame_selection_arbiter][0]["veto"]
frame_selection_mask = sources[self._frame_selection_arbiter][0][
"data.dataFramePattern"
]
if self._thread_pool is None:
for source, (data, timestamp) in sources.items():
self._handle_source(
source, data, timestamp, new_sources_map, frame_selection_frames
source, data, timestamp, new_sources_map, frame_selection_mask
)
else:
concurrent.futures.wait(
......@@ -437,7 +439,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
data,
timestamp,
new_sources_map,
frame_selection_frames,
frame_selection_mask,
)
for source, (data, timestamp) in sources.items()
]
......
......@@ -463,6 +463,9 @@ class ChainHash:
return h[key]
raise KeyError()
def get(self, key):
    """Return the value for ``key``; method-call spelling of ``self[key]``.

    Whatever item access raises for a missing key propagates unchanged.
    """
    value = self[key]
    return value
class SkippingThrottler:
def __init__(self, min_period):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment