diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index 607ab5f2cb400e9caddfee7a841fb8703c0f3017..788f328b2dca61ca8c92969c0349435c85bb606a 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import enum
 import re
 
@@ -91,6 +92,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             .defaultValue([])
             .reconfigurable()
             .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("useThreadPool")
+            .assignmentOptional()
+            .defaultValue(False)
+            .commit(),
         )
 
     def initialization(self):
@@ -101,6 +108,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self._prepare_merge_groups(self.get("merge"))
         super().initialization()
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
+        if self.get("useThreadPool"):
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor()
+        else:
+            self._thread_pool = None
 
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
@@ -165,70 +176,81 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 (key_re, new_key)
             )
 
-    def on_matched_data(self, train_id, sources):
+    def _handle_source(self, source, data_hash, timestamp, new_sources_map):
         # dereference calng shmem handles
-        for (data, _) in sources.values():
-            self._shmem_handler.dereference_shmem_handles(data)
+        self._shmem_handler.dereference_shmem_handles(data_hash)
 
-        new_sources_map = {}
-        for source, (data, timestamp) in sources.items():
-            # stack across sources (many sources, same key)
-            # could probably save ~100 ns by "if ... in" instead of get
-            for (stack_key, new_source) in self._source_stacking_sources.get(
-                source, ()
-            ):
-                this_data = data.get(stack_key)
-                try:
-                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                    stack_index = self._source_stacking_indices[(source, stack_key)]
-                    this_buffer[stack_index] = this_data
-                except (ValueError, KeyError):
-                    # ValueError: wrong shape
-                    # KeyError: buffer doesn't exist yet
-                    # either way, create appropriate buffer now
-                    # TODO: complain if shape varies between sources
-                    self._stacking_buffers[(new_source, stack_key)] = np.empty(
-                        shape=(
-                            max(
-                                index_
-                                for (
-                                    source_,
-                                    key_,
-                                ), index_ in self._source_stacking_indices.items()
-                                if source_ == source and key_ == stack_key
-                            )
-                            + 1,
-                        )
-                        + this_data.shape,
-                        dtype=this_data.dtype,
-                    )
-                    # and then try again
-                    this_buffer = self._stacking_buffers[(new_source, stack_key)]
-                    stack_index = self._source_stacking_indices[(source, stack_key)]
-                    this_buffer[stack_index] = this_data
-                # TODO: zero out unfilled buffer entries
-                data.erase(stack_key)
-
-                if new_source not in new_sources_map:
-                    new_sources_map[new_source] = (Hash(), timestamp)
-                new_source_hash = new_sources_map[new_source][0]
-                if not new_source_hash.has(stack_key):
-                    new_source_hash[stack_key] = this_buffer
-
-            # stack keys (multiple keys within this source)
-            for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
-                # note: please no overlap between different key_re
-                # note: if later key_re match earlier new_key, this gets spicy
-                stack_keys = [key for key in data.paths() if key_re.match(key)]
-                try:
-                    # TODO: consider reusing buffers here, too
-                    stacked = np.stack([data.get(key) for key in stack_keys], axis=0)
-                except Exception as e:
-                    self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
-                else:
-                    for key in stack_keys:
-                        data.erase(key)
-                    data[new_key] = stacked
+        # stack across sources (many sources, same key)
+        # could probably save ~100 ns by "if ... in" instead of get
+        for (stack_key, new_source) in self._source_stacking_sources.get(source, ()):
+            this_data = data_hash.get(stack_key)
+            try:
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[(source, stack_key)]
+                this_buffer[stack_index] = this_data
+            except (ValueError, IndexError, KeyError):
+                # ValueError: wrong shape, IndexError: buffer too small
+                # KeyError: buffer doesn't exist yet
+                # either way, create appropriate buffer now
+                # TODO: complain if shape varies between sources
+                self._stacking_buffers[(new_source, stack_key)] = np.empty(
+                    shape=(
+                        max(
+                            index_
+                            for (
+                                source_,
+                                key_,
+                            ), index_ in self._source_stacking_indices.items()
+                            if source_ == source and key_ == stack_key
+                        )
+                        + 1,
+                    )
+                    + this_data.shape,
+                    dtype=this_data.dtype,
+                )
+                # and then try again
+                this_buffer = self._stacking_buffers[(new_source, stack_key)]
+                stack_index = self._source_stacking_indices[(source, stack_key)]
+                this_buffer[stack_index] = this_data
+            # TODO: zero out unfilled buffer entries
+            data_hash.erase(stack_key)
+
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(stack_key):
+                new_source_hash[stack_key] = this_buffer
+
+        # stack keys (multiple keys within this source)
+        for (key_re, new_key) in self._key_stacking_sources.get(source, ()):
+            # note: please no overlap between different key_re
+            # note: if later key_re match earlier new_key, this gets spicy
+            stack_keys = [key for key in data_hash.paths() if key_re.match(key)]
+            try:
+                # TODO: consider reusing buffers here, too
+                stacked = np.stack([data_hash.get(key) for key in stack_keys], axis=0)
+            except Exception as e:
+                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
+            else:
+                for key in stack_keys:
+                    data_hash.erase(key)
+                data_hash[new_key] = stacked
+
+    def on_matched_data(self, train_id, sources):
+
+        new_sources_map = {}
+        if self._thread_pool is None:
+            for source, (data, timestamp) in sources.items():
+                self._handle_source(source, data, timestamp, new_sources_map)
+        else:
+            concurrent.futures.wait(
+                [
+                    self._thread_pool.submit(
+                        self._handle_source, source, data, timestamp, new_sources_map
+                    )
+                    for source, (data, timestamp) in sources.items()
+                ]
+            )
 
         sources.update(new_sources_map)