diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py index 4bd3554fdb861acf1ed61493e5bb2d661cf5a982..72bab416cef4b49724935e44815745644d665c7c 100644 --- a/src/calng/ShmemTrainMatcher.py +++ b/src/calng/ShmemTrainMatcher.py @@ -5,6 +5,7 @@ from karabo.bound import ( BOOL_ELEMENT, KARABO_CLASSINFO, OVERWRITE_ELEMENT, + UINT32_ELEMENT, ChannelMetaData, State, ) @@ -29,20 +30,18 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .commit(), BOOL_ELEMENT(expected) - .key("useThreadPool") - .displayedName("Use thread pool") + .key("enableKaraboOutput") + .displayedName("Enable Karabo channel") .allowedStates(State.PASSIVE) .assignmentOptional() - .defaultValue(False) + .defaultValue(True) .reconfigurable() .commit(), - BOOL_ELEMENT(expected) - .key("enableKaraboOutput") - .displayedName("Enable Karabo channel") + UINT32_ELEMENT(expected) + .key("processingThreads") .allowedStates(State.PASSIVE) - .assignmentOptional() - .defaultValue(True) + .defaultValue(16) .reconfigurable() .commit(), @@ -72,10 +71,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver() self._stacking_friend = StackingFriend(self.get("merge"), self.get("sources")) self._frameselection_friend = FrameselectionFriend(self.get("frameSelector")) - if self.get("useThreadPool"): - self._thread_pool = concurrent.futures.ThreadPoolExecutor() - else: - self._thread_pool = None + self._thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=self.get("processingThreads") + ) if not self.get("enableKaraboOutput"): # it is set already by super by default, so only need to turn off @@ -90,12 +88,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): # re-prepare in postReconfigure after sources *and* merge are in self if conf.has("frameSelector"): self._frameselection_friend.reconfigure(conf.get("frameSelector")) - if conf.has("useThreadPool"): - if self._thread_pool is not None: - self._thread_pool.shutdown() - self._thread_pool = None - if conf["useThreadPool"]: - self._thread_pool = concurrent.futures.ThreadPoolExecutor() + if conf.has("processingThreads"): + self._thread_pool.shutdown() + self._thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=self.get("processingThreads") + ) if conf.has("enableKaraboOutput"): if conf["enableKaraboOutput"]: self.output = self._ss.getOutputChannel("output") @@ -103,31 +100,23 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): self.output = None def on_matched_data(self, train_id, sources): - new_sources_map = {} frame_selection_mask = self._frameselection_friend.get_mask(sources) - self._stacking_friend.prepare_stacking_for_train( - sources, frame_selection_mask, new_sources_map - ) - - if self._thread_pool is None: - for source, (data, timestamp) in sources.items(): - self._handle_source( - source, data, timestamp, new_sources_map, frame_selection_mask, + # note: should not do stacking and frame selection for now! + self._stacking_friend.prepare_stacking_for_train(sources) + + concurrent.futures.wait( + [ + self._thread_pool.submit( + self._handle_source, + source, + data, + timestamp, + new_sources_map, + frame_selection_mask, ) - else: - concurrent.futures.wait( - [ - self._thread_pool.submit( - self._handle_source, - source, - data, - timestamp, - new_sources_map, - frame_selection_mask, - ) - for source, (data, timestamp) in sources.items() - ] - ) + for source, (data, timestamp) in sources.items() + ] + ) sources.update(new_sources_map) # karabo output