Skip to content
Snippets Groups Projects
Commit e8a7fdd0 authored by David Hammer's avatar David Hammer
Browse files

Start adding thread pool for assembler, too

parent 713d3ee3
No related branches found
No related tags found
1 merge request!10DetectorAssembler: assemble with extra shape (multiple frames)
import concurrent.futures
import enum
import functools
import re
import numpy as np
from karabo.bound import (
BOOL_ELEMENT,
DOUBLE_ELEMENT,
FLOAT_ELEMENT,
IMAGEDATA_ELEMENT,
......@@ -23,6 +25,7 @@ from karabo.bound import (
ImageData,
MetricPrefix,
Schema,
State,
Timestamp,
Trainstamp,
Unit,
......@@ -185,6 +188,42 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
)
.assignmentMandatory()
.commit(),
BOOL_ELEMENT(expected)
.key("useThreadPool")
.displayedName("Use thread pool")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
BOOL_ELEMENT(expected)
.key("enableKaraboOutput")
.displayedName("Enable Karabo channel for regular TrainMatcher output")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(True)
.reconfigurable()
.commit(),
BOOL_ELEMENT(expected)
.key("enableKaraboAssembledOutput")
.displayedName("Enable Karabo channel for assembled output")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(True)
.reconfigurable()
.commit(),
BOOL_ELEMENT(expected)
.key("enablePreview")
.displayedName("Enable preview output")
.allowedStates(State.PASSIVE)
.assignmentOptional()
.defaultValue(True)
.reconfigurable()
.commit(),
)
def __init__(self, conf):
......@@ -196,12 +235,15 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
# TODO: match inside device, fill multiple independent buffers
self._throttler = utils.SkippingThrottler(1 / self.get("preview.maxRate"))
self._preview_throttler = utils.SkippingThrottler(
1 / self.get("preview.maxRate")
)
self._path_to_stack = self.get("pathToStack")
self._geometry = None
self._stack_input_buffer = None
self._position_output_buffer = None
self._extra_shape = ()
self._module_indices_unfilled = set()
self.KARABO_SLOT(self.requestScene)
......@@ -222,8 +264,13 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
self.remote().registerDeviceMonitor(geometry_device, self._receive_geometry)
# regular self.output is gotten by superclass
self.assembled_output = self.signalSlotable.getOutputChannel("assembledOutput")
self.preview_output = self.signalSlotable.getOutputChannel("preview.output")
if self.get("useThreadPool"):
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
else:
self._thread_pool = None
self.start()
def requestScene(self, params):
......@@ -257,30 +304,27 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
self.log.INFO("New geometry empty, will ignore update.")
return
self._geometry = geom_utils.deserialize_geometry(serialized_geometry)
# TODO: allow multiple memory cells (extra geom notion of extra dimensions)
def on_matched_data(self, train_id, sources):
if self._geometry is None:
self.log.WARN("Have not received a geometry yet, will not send anything")
return
self._module_indices_unfilled.clear()
my_timestamp = Timestamp(Epochstamp(), Trainstamp(train_id))
my_source = self.getInstanceId()
bridge_output_choice = BridgeOutputOptions(
self.unsafe_get("outputForBridgeOutput")
)
def _stack_this_guy(self, source, data):
# TODO: handle failure to "parse" source, get data out
module_index = self._source_to_index(source)
self._stack_input_buffer[..., module_index, :, :] = data.get(
self._path_to_stack
).astype(np.float32, copy=False) # TODO: set dtype based on input?
self._module_indices_unfilled.discard(module_index)
# check and maybe update stacking, output buffers
input_shape = next(iter(sources.values()))[0].get(self._path_to_stack).shape
def _maybe_update_buffers(self, input_shape):
input_extra_shape = input_shape[:-2]
if self._stack_input_buffer is None or input_extra_shape != self._extra_shape:
self._extra_shape = input_extra_shape
self._stack_input_buffer = np.zeros(
self._extra_shape + self._geometry.expected_data_shape,
dtype=np.float32
self._extra_shape + self._geometry.expected_data_shape, dtype=np.float32
)
self._position_output_buffer = self._geometry.output_array_for_position_fast(
extra_shape=self._extra_shape, dtype=np.float32
self._position_output_buffer = (
self._geometry.output_array_for_position_fast(
extra_shape=self._extra_shape, dtype=np.float32
)
)
self.log.INFO(
f"Updating stacking buffer to shape: {self._stack_input_buffer.shape}"
......@@ -288,95 +332,137 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
self.log.INFO(
f"Updating output buffer to shape: {self._position_output_buffer.shape}"
)
self._module_indices_unfilled.clear()
module_indices_unfilled = set(range(self._geometry.n_modules))
earliest_source_timestamp = float("inf")
for source, (data, source_timestamp) in sources.items():
# regular TrainMatcher output
self.output.write(data, ChannelMetaData(source, source_timestamp))
if bridge_output_choice is BridgeOutputOptions.MATCHED:
self.zmq_output.write(source, data, source_timestamp)
# prepare for assembly
# TODO: handle failure to "parse" source, get data out
module_index = self._source_to_index(source)
self._stack_input_buffer[..., module_index, :, :] = data.get(
self._path_to_stack
).astype(np.float32, copy=False) # TODO: set dtype based on input?
module_indices_unfilled.discard(module_index)
earliest_source_timestamp = min(
earliest_source_timestamp, source_timestamp.toTimestamp()
)
self.output.update()
if bridge_output_choice is BridgeOutputOptions.MATCHED:
self.zmq_output.update()
def on_matched_data(self, train_id, sources):
my_timestamp = Timestamp(Epochstamp(), Trainstamp(train_id))
my_source = self.getInstanceId()
bridge_output_choice = BridgeOutputOptions(
self.unsafe_get("outputForBridgeOutput")
)
for module_index in module_indices_unfilled:
self._stack_input_buffer[module_index].fill(0)
# TODO: configurable treatment of missing modules
# regular TrainMatcher output
with utils.Stopwatch(name="Write Karabo output", track=True):
if self.unsafe_get("enableKaraboOutput"):
for source, (data, source_timestamp) in sources.items():
# regular TrainMatcher output
self.output.write(
data, ChannelMetaData(source, source_timestamp), copyAllData=False
)
self.output.update()
with utils.Stopwatch(name="Write bridge output", track=True):
if bridge_output_choice is BridgeOutputOptions.MATCHED:
for source, (data, source_timestamp) in sources.items():
self.zmq_output.write(source, data, source_timestamp)
self.zmq_output.update()
# TODO: reusable output buffer to save on allocation
assembled, _ = self._geometry.position_modules_fast(
self._stack_input_buffer, out=self._position_output_buffer
)
# for the rest, we need geometry
if self._geometry is None:
self.log.WARN("Have not received a geometry yet, will not send anything")
return
# TODO: optionally include control data
output_hash = Hash(
"image.data",
assembled,
"trainId",
train_id,
)
output_metadata = ChannelMetaData(my_source, my_timestamp)
self.assembled_output.write(output_hash, output_metadata)
self.assembled_output.update()
if bridge_output_choice is BridgeOutputOptions.ASSEMBLED:
self.zmq_output.write(my_source, output_hash, my_timestamp)
self.zmq_output.update()
if self._throttler.test_and_set():
downsampling_factor = self.unsafe_get("preview.downsamplingFactor")
if downsampling_factor > 1:
assembled = downsample_2d(
assembled,
downsampling_factor,
reduction_fun=getattr(
np, self.unsafe_get("preview.downsamplingFunction")
),
with utils.Stopwatch(name="Get ready", track=True):
# get ready to stack and assemble
self._module_indices_unfilled.update(range(self._geometry.n_modules))
# TODO: selectable subset of source to stack
# check and maybe update stacking, output buffers
self._maybe_update_buffers(
next(iter(sources.values()))[0].get(self._path_to_stack).shape
)
with utils.Stopwatch(name="Fill buffer and assemble", track=True):
if self._thread_pool is None:
for source, (data, _) in sources.items():
self._stack_this_guy(source, data)
for module_index in self._module_indices_unfilled:
self._stack_input_buffer[..., module_index, :, :].fill(0)
# TODO: configurable treatment of missing modules
else:
concurrent.futures.wait(
[
self._thread_pool.submit(
self._stack_this_guy,
source,
data
)
for source, (data, _) in sources.items()
]
)
assembled[np.isnan(assembled)] = self.unsafe_get("preview.replaceNanWith")
self._thread_pool.map(
lambda module_index: self._stack_input_buffer[
..., module_index, :, :
].fill(0),
self._module_indices_unfilled,
)
assembled, _ = self._geometry.position_modules_fast(
self._stack_input_buffer,
out=self._position_output_buffer,
threadpool=self._thread_pool,
)
with utils.Stopwatch(name="Write assembled", track=True):
# TODO: optionally work like regular TrainMatcher for other sources
output_hash = Hash(
"image.data",
ImageData(
# TODO: get around this being mirrored...
assembled[..., ::-1, ::-1],
Dims(*assembled.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
assembled,
"trainId",
train_id,
)
self.preview_output.write(
output_hash,
output_metadata,
)
self.preview_output.update()
if bridge_output_choice is BridgeOutputOptions.PREVIEW:
self.zmq_output.write(
my_source,
output_hash,
my_timestamp,
)
output_metadata = ChannelMetaData(my_source, my_timestamp)
if self.unsafe_get("enableKaraboAssembledOutput"):
self.assembled_output.write(output_hash, output_metadata, copyAllData=False)
self.assembled_output.update()
if bridge_output_choice is BridgeOutputOptions.ASSEMBLED:
self.zmq_output.write(my_source, output_hash, my_timestamp)
self.zmq_output.update()
self.info["timeOfFlight"] = (
Timestamp().toTimestamp() - earliest_source_timestamp
) * 1000
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
if self.unsafe_get("enablePreview") and self._preview_throttler.test_and_set():
with utils.Stopwatch(name="Write preview", track=True):
downsampling_factor = self.unsafe_get("preview.downsamplingFactor")
if downsampling_factor > 1:
assembled = downsample_2d(
assembled,
downsampling_factor,
reduction_fun=getattr(
np, self.unsafe_get("preview.downsamplingFunction")
),
)
assembled[np.isnan(assembled)] = self.unsafe_get("preview.replaceNanWith")
output_hash = Hash(
"image.data",
ImageData(
# TODO: get around this being mirrored...
assembled[..., ::-1, ::-1],
Dims(*assembled.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"trainId",
train_id,
)
self.preview_output.write(
output_hash,
output_metadata,
copyAllData=False,
)
self.preview_output.update()
if bridge_output_choice is BridgeOutputOptions.PREVIEW:
self.zmq_output.write(
my_source,
output_hash,
my_timestamp,
)
self.zmq_output.update()
with utils.Stopwatch(name="Update trackers", track=True):
self.info["timeOfFlight"] = (
Timestamp().toTimestamp()
- min(ts.toTimestamp() for (_, ts) in sources.values())
) * 1000
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
print()
def on_new_data(self, channel, data, meta):
super().on_new_data(channel, data, meta)
......@@ -402,9 +488,15 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
def preReconfigure(self, conf):
super().preReconfigure(conf)
if conf.has("preview.maxRate"):
self._throttler = utils.SkippingThrottler(
self._preview_throttler = utils.SkippingThrottler(
1 / conf["preview.maxRate"]
)
if conf.has("useThreadPool"):
if self._thread_pool is not None:
self._thread_pool.shutdown()
self._thread_pool = None
if conf["useThreadPool"]:
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
def downsample_2d(arr, factor, reduction_fun=np.nanmax):
......
......@@ -166,7 +166,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self._thread_pool = concurrent.futures.ThreadPoolExecutor()
if conf.has("enableKaraboOutput"):
if conf["enableKaraboOutput"]:
self.output = self._ss.getOutputChannel("output")
self.output = self.signalSlotable.getOutputChannel("output")
else:
self.output = None
......
......@@ -273,10 +273,14 @@ class Stopwatch:
name: if not None, will appear in string representation
also, if not None, will automatically print self when done
"""
_trackers = {}
def __init__(self, name=None):
def __init__(self, name=None, track=False):
self.stop_time = None
self.name = name
self.track = track
if name not in self._trackers:
self._trackers[name] = ExponentialMovingAverage(0.1)
def __enter__(self):
self.start_time = default_timer()
......@@ -286,6 +290,10 @@ class Stopwatch:
self.stop_time = default_timer()
if self.name is not None:
print(repr(self))
if self.track:
tracker = self._trackers[self.name]
tracker.update(self.elapsed)
print(tracker.get())
@property
def elapsed(self):
......@@ -299,9 +307,9 @@ class Stopwatch:
def __repr__(self):
if self.name is None:
return f"{self.elapsed():.3f} s"
return f"{self.elapsed:.3f} s"
else:
return f"{self.name}: {self.elapsed():.3f} s"
return f"{self.name}: {self.elapsed:.3f} s"
class StateContext:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment