Skip to content
Snippets Groups Projects

Draft: Overlap train processing in ShmemTrainMatcher

Open David Hammer requested to merge overlapping-processing-matcher into refactor-stacking
2 files
+ 59
20
Compare changes
  • Side-by-side
  • Inline
Files
2
import concurrent.futures
import time
from timeit import default_timer
import threading
import queue
from karabo.bound import (
BOOL_ELEMENT,
@@ -92,6 +95,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=self.get("processingThreads")
)
# TODO: set capacity limit
self._future_output_queue = queue.Queue(maxsize=100)
self._output_writer_thread = threading.Thread(
target=self._output_writer, daemon=True
)
self._output_writer_thread.start()
if not self.get("enableKaraboOutput"):
# it is set already by super by default, so only need to turn off
@@ -118,6 +127,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
self.output = None
def on_matched_data(self, train_id, sources):
    """Schedule processing of a matched train and queue the pending result.

    The train is submitted to ``self._thread_pool`` to be handled by
    ``self._match_handler``. The returned future is paired with
    ``train_id`` and put on ``self._future_output_queue``.

    The ``put`` is a blocking one, so this call waits whenever the queue
    has no free slot.
    """
    # Resolve the queue's put first, then submit, then enqueue.
    enqueue = self._future_output_queue.put
    pending = self._thread_pool.submit(self._match_handler, train_id, sources)
    enqueue((train_id, pending), block=True)
def _match_handler(self, train_id, sources):
ts_start = default_timer()
frame_selection_mask = self._frameselection_friend.get_mask(sources)
concurrent.futures.wait(
@@ -132,26 +152,39 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
for source, (data, timestamp) in sources.items()
]
)
self._stacking_friend.process(sources, self._thread_pool)
# karabo output
if self.output is not None:
for source, (data, timestamp) in sources.items():
self.output.write(
data, ChannelMetaData(source, timestamp), copyAllData=False
)
self.output.update(safeNDArray=True)
# karabo bridge output
for source, (data, timestamp) in sources.items():
self.zmq_output.write(source, data, timestamp)
self.zmq_output.update()
self._stacking_friend.process(train_id, sources)
time.sleep(200 / 1000)
self._processing_time_tracker.update(default_timer() - ts_start)
self.info["processingTime"] = self._processing_time_tracker.get() * 1000
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
return sources
def _output_writer(self):
    """Consume queued processing futures and send their results out, forever.

    Each item taken from ``self._future_output_queue`` is a
    ``(train_id, future)`` pair. The future is awaited for at most
    60 seconds. If that raises, a warning is logged and the train is
    dropped. Otherwise the resulting ``sources`` mapping is written to the
    Karabo output (only when ``self.output`` is not ``None``) and to the
    ZMQ bridge output, after which the ``sent`` / ``trainId`` info fields
    and the output rate tracker are updated. Any exception raised while
    writing is logged as a warning and the loop carries on.

    This method never returns; it is used as the target of
    ``self._output_writer_thread``.
    """
    while True:
        train_id, future = self._future_output_queue.get()
        try:
            sources = future.result(timeout=60)
        except Exception as ex:
            self.log.WARN(f"Processing failed for {train_id}: {ex}")
        else:
            try:
                # karabo output (skipped entirely when disabled)
                if self.output is not None:
                    for name, (payload, stamp) in sources.items():
                        self.output.write(
                            payload,
                            ChannelMetaData(name, stamp),
                            copyAllData=False,
                        )
                    self.output.update(safeNDArray=True)

                # karabo bridge output
                for name, (payload, stamp) in sources.items():
                    self.zmq_output.write(name, payload, stamp)
                self.zmq_output.update()

                # bookkeeping happens only after both outputs succeeded
                self.info["sent"] += 1
                self.info["trainId"] = train_id
                self.rate_out.update()
            except Exception as ex:
                self.log.WARN(f"Failed to write result for {train_id}: {ex}")
def _handle_source(
self,
Loading