diff --git a/src/calng/FrameSelectionArbiter.py b/src/calng/FrameSelectionArbiter.py index 1a077fa78fd673b97f6809acf0637a8c3e2eea93..da20600505d635c5c1b713e9bfe87c1a5b65d28a 100644 --- a/src/calng/FrameSelectionArbiter.py +++ b/src/calng/FrameSelectionArbiter.py @@ -1,40 +1,116 @@ from karabo.bound import ( + DOUBLE_ELEMENT, KARABO_CLASSINFO, NDARRAY_ELEMENT, NODE_ELEMENT, OUTPUT_CHANNEL, - SLOT_ELEMENT, STRING_ELEMENT, + VECTOR_BOOL_ELEMENT, ChannelMetaData, Epochstamp, Hash, Schema, Timestamp, Trainstamp, + Unit, ) import numpy as np from TrainMatcher import TrainMatcher from ._version import version as deviceVersion +from . import utils my_schema = Schema() ( - NDARRAY_ELEMENT(my_schema) - .key("veto") - .dtype("INT64").commit(), - - NDARRAY_ELEMENT(my_schema) - .key("vetoAsMask") - .dtype("UINT8").commit(), + NODE_ELEMENT(my_schema) + .key("data") + .commit(), + + VECTOR_BOOL_ELEMENT(my_schema) + .key("data.dataFramePattern") + .assignmentOptional() + .defaultValue([]) + .commit(), ) class BaseArbiterKernel: + _node = None # just for schema purposes, set in subclass pass + +class RandomSampler(BaseArbiterKernel): + _node = "randomSampler" + + def __init__(self, config, prefix="frameSelection.kernels"): + self._threshold = ( + 100 - config.get(f"{prefix}.{self._node}.probability") + ) / 100 + + @staticmethod + def add_schema(schema, prefix="frameSelection.kernels"): + node = RandomSampler._node + ( + NODE_ELEMENT(schema) + .key(f"{prefix}.{node}") + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{prefix}.{node}.probability") + .unit(Unit.PERCENT) + .assignmentOptional() + .defaultValue(50) + .reconfigurable() + .commit(), + ) + + def consider(self, train_id, sources, num_frames): + return (np.random.random(num_frames) > self._threshold).astype( + np.uint8, copy=False + ) + + class IntegratedIntensityArbiter(BaseArbiterKernel): - ... + _node = "integratedIntensity" + + def __init__(self, config, prefix="frameSelection.kernels"): + self._threshold = config.get(f"{prefix}.{self._node}.threshold") + + @staticmethod + def add_schema(schema, prefix="frameSelection.kernels"): + node = IntegratedIntensityArbiter._node + ( + NODE_ELEMENT(schema) + .key(f"{prefix}.{node}") + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{prefix}.{node}.threshold") + .assignmentOptional() + .defaultValue(1e8) + .reconfigurable() + .commit(), + ) + + def consider(self, train_id, sources, num_frames): + return ( + np.sum( + [ + data["image.integratedIntensity"] + for (data, _) in sources.values() + if data.has("image.integratedIntensity") + ], + axis=0, + ) + > self._threshold + ).astype(np.uint8, copy=False) + + +kernel_options = { + cls.__name__: cls for cls in (RandomSampler, IntegratedIntensityArbiter) +} + @KARABO_CLASSINFO("FrameSelectionArbiter", deviceVersion) class FrameSelectionArbiter(TrainMatcher.TrainMatcher): @@ -43,70 +119,47 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher): ( NODE_ELEMENT(expected) .key("frameSelection") - .displayedName("Soft veto") + .displayedName("Frame selection") .commit(), STRING_ELEMENT(expected) - .key("frameSelection.contextFile") - .displayedName("Context file") - .assignmentOptional() - .defaultValue("") - .commit(), - - STRING_ELEMENT(expected) - .key("frameSelection.kernelClass") - .displayedName("Kernel class name") - .assignmentOptional() - .defaultValue("") - .commit(), - - SLOT_ELEMENT(expected) - .key("frameSelection.reconfigure") - .displayedName("Reconfigure") + .key("frameSelection.kernelChoice") + .displayedName("Kernel to use") + .assignmentMandatory() + .options(",".join(kernel_options.keys())) + .reconfigurable() .commit(), OUTPUT_CHANNEL(expected) .key("output") .dataSchema(my_schema) .commit(), - ) - def __init__(self, config): - super().__init__(config) - - self.kernel = None - self.KARABO_SLOT( - self.frameSelection_reconfigure, slotName="frameSelection_reconfigure" + NODE_ELEMENT(expected) + .key("frameSelection.kernels") + .displayedName("Kernels") + .commit(), ) - self.registerInitialFunction(self.frameSelection_reconfigure) + for cls in kernel_options.values(): + cls.add_schema(expected) - def frameSelection_reconfigure(self): - spec = importlib.util.spec_from_file_location( - self.get("frameSelection.kernelClass"), - self.get("frameSelection.contextFile"), - ) - module = importlib.util.module_from_spec(spec) - # sys.modules[confi.get("frameSelection.kernelClass")] = module - spec.loader.exec_module(module) - self.kernel = getattr(module, self.get("frameSelection.kernelClass"))( - self._parameters - ) + def initialization(self): + super().initialization() + self._kernel_class = kernel_options[self.get("frameSelection.kernelChoice")] + self._kernel = self._kernel_class(self._parameters) def on_matched_data(self, train_id, sources): - if self.kernel is None: - return - + # TODO: robust frame deduction for (data, _) in sources.values(): - if not data.has("peakfinding"): + if not data.has("image.cellId"): continue - num_frames = data.get("peakfinding.numPeaks").size + num_frames = data.get("image.cellId").size break + decision = self._kernel.consider(train_id, sources, num_frames) result = Hash() - self.kernel(train_id, sources, result) - mask = np.zeros(num_frames, dtype=np.uint8) - mask[result["veto"]] = 1 - result["vetoAsMask"] = mask + # TODO: avoid recasting + result["data.dataFramePattern"] = list(map(bool, decision)) self.output.write( result, ChannelMetaData( @@ -116,3 +169,16 @@ class FrameSelectionArbiter(TrainMatcher.TrainMatcher): ) self.output.update() self.rate_out.update() + + def preReconfigure(self, conf): + super().preReconfigure(conf) + merged_conf = utils.ChainHash(conf, self._parameters) + if conf.has("frameSelection.kernelChoice"): + self.log.INFO("Switching frame selection kernel") + self._kernel_class = kernel_options[conf.get("frameSelection.kernelChoice")] + self._kernel = self._kernel_class(merged_conf) + elif conf.has(f"frameSelection.kernels.{self._kernel_class._node}"): + self.log.INFO("Reconfiguring frame selection kernel") + # TODO: update instead of rebuild for complex kernels + # (decide based on reconfigurability?) + self._kernel = self._kernel_class(merged_conf) diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py index e66ea8bef2f212f38abdae278678ed9b0f84edb7..814c488860fd2c1ee8b5a9465a1f7a3aa81e6546 100644 --- a/src/calng/ShmemTrainMatcher.py +++ b/src/calng/ShmemTrainMatcher.py @@ -188,8 +188,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .key("frameSelector.arbiterSource") .displayedName("Arbiter source") .description( - "Source name to pull the veto pattern from, must be part of matched " - "sources." + "Source name to pull the frame selection pattern from, must be part of " + "matched sources." ) .assignmentOptional() .defaultValue("") @@ -200,8 +200,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .key("frameSelector.dataSourcePattern") .displayedName("Data source pattern") .description( - "Source name pattern to apply veto pattern to, must be part of matched " - "sources." + "Source name pattern to apply frame selection to. Should match " + "subset of matched sources." ) .assignmentOptional() .defaultValue("") @@ -211,7 +211,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): VECTOR_STRING_ELEMENT(expected) .key("frameSelector.dataKeys") .displayedName("Data keys") - .description("Keys in data sources to apply veto pattern to.") + .description("Keys in data sources to apply frame selection to.") .assignmentOptional() .defaultValue([]) .reconfigurable() @@ -350,20 +350,20 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): ) def _handle_source( - self, source, data_hash, timestamp, new_sources_map, frame_selection_frames + self, source, data_hash, timestamp, new_sources_map, frame_selection_mask ): # dereference calng shmem handles self._shmem_handler.dereference_shmem_handles(data_hash) # apply frame_selection - if frame_selection_frames is not None and self._frame_selection_source_pattern.match( + if frame_selection_mask is not None and self._frame_selection_source_pattern.match( source ): for key in self._frame_selection_data_keys: if not data_hash.has(key): continue - data_hash[key] = data_hash[key][frame_selection_frames] + data_hash[key] = data_hash[key][frame_selection_mask] # stack across sources (many sources, same key) # could probably save ~100 ns by "if ... in" instead of get @@ -467,14 +467,16 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): def on_matched_data(self, train_id, sources): new_sources_map = {} - frame_selection_frames = None + frame_selection_mask = None if self._frame_selection_enabled and self._frame_selection_arbiter in sources: - frame_selection_frames = sources[self._frame_selection_arbiter][0]["veto"] + frame_selection_mask = sources[self._frame_selection_arbiter][0][ + "data.dataFramePattern" + ] if self._thread_pool is None: for source, (data, timestamp) in sources.items(): self._handle_source( - source, data, timestamp, new_sources_map, frame_selection_frames + source, data, timestamp, new_sources_map, frame_selection_mask ) else: concurrent.futures.wait( @@ -485,7 +487,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): data, timestamp, new_sources_map, - frame_selection_frames, + frame_selection_mask, ) for source, (data, timestamp) in sources.items() ] diff --git a/src/calng/utils.py b/src/calng/utils.py index 48e98452261ab7c4eec8a5057a0490ebfbd77740..b27aff2aa0a7dfc55aeca4a4de6d8848d5d60b5d 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -475,6 +475,9 @@ class ChainHash: return h[key] raise KeyError() + def get(self, key): + return self[key] + class SkippingThrottler: def __init__(self, min_period):