diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index c75f4c460f3cd2b6934b25affce717b5e8e68f70..19722de729a7cb0c55371c179c66be771f421503 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -329,30 +329,99 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
 
         self._have_prepared_merge_groups = True
 
-    def _update_stacking_buffer(
-        self, new_source, key, individual_shape, merge_method, axis, dtype
-    ):
-        if merge_method is MergeMethod.STACK:
-            self._stacking_buffers[(new_source, key)] = np.empty(
-                shape=utils.stacking_buffer_shape(
-                    individual_shape,
-                    self._source_stacking_group_sizes[(new_source, key)],
-                    axis=axis,
-                ),
-                dtype=dtype,
-            )
-        else:
-            self._stacking_buffers[(new_source, key)] = np.empty(
-                shape=utils.interleaving_buffer_shape(
-                    individual_shape,
-                    self._source_stacking_group_sizes[(new_source, key)],
-                    axis=axis,
-                ),
-                dtype=dtype,
+    def _check_stacking_data(self, sources, frame_selection_mask):
+        """Collect the expected shape, dtype, and timestamp per stacking
+        group; flag groups with missing or inconsistent data."""
+        if frame_selection_mask is not None:
+            orig_size = len(frame_selection_mask)
+            result_size = np.sum(frame_selection_mask)
+        stacking_data_shapes = {}
+        ignore_stacking = {}
+        fill_missed_data = {}
+        for source, keys in self._source_stacking_sources.items():
+            if source not in sources:
+                for key, new_source, _, _ in keys:
+                    missed_sources = fill_missed_data.setdefault((new_source, key), [])
+                    merge_index = self._source_stacking_indices[
+                        (source, new_source, key)
+                    ]
+                    missed_sources.append(merge_index)
+                continue
+            data_hash, timestamp = sources[source]
+            filtering = (
+                frame_selection_mask is not None and
+                self._frame_selection_source_pattern.match(source)
             )
+            for key, new_source, merge_method, axis in keys:
+                merge_data_shape = None
+                if key in data_hash:
+                    merge_data = data_hash[key]
+                    merge_data_shape = merge_data.shape
+                else:
+                    ignore_stacking[(new_source, key)] = "Some data is missing"
+                    continue
+
+                if filtering and key in self._frame_selection_data_keys:
+                    # !!! stacking is not expected to be used with filtering
+                    if merge_data_shape[0] == orig_size:
+                        merge_data_shape = (result_size,) + merge_data.shape[1:]
+
+                expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
+                    (new_source, key),
+                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp)
+                )
+                if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
+                    ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
+                    del stacking_data_shapes[(new_source, key)]
+
+        return stacking_data_shapes, ignore_stacking, fill_missed_data
+
+    def _maybe_update_stacking_buffers(
+        self, stacking_data_shapes, fill_missed_data, new_sources_map
+    ):
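+        """Create or resize stacking buffers to match this train's expected
+        shapes; zero-fill the slots of sources missing from the train."""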
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = self._source_stacking_group_sizes[(new_source, key)]
+            if merge_method is MergeMethod.STACK:
+                expected_shape = utils.stacking_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+            else:
+                expected_shape = utils.interleaving_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+
+            merge_buffer = self._stacking_buffers.get((new_source, key))
+            if merge_buffer is None or merge_buffer.shape != expected_shape:
+                merge_buffer = np.empty(
+                    shape=expected_shape, dtype=dtype)
+                self._stacking_buffers[(new_source, key)] = merge_buffer
+
+            # zero-fill the buffer slots of sources missing from this train
+            for merge_index in fill_missed_data.get((new_source, key), []):
+                utils.set_on_axis(
+                    merge_buffer,
+                    0,
+                    merge_index
+                    if merge_method is MergeMethod.STACK
+                    else np.index_exp[
+                        slice(
+                            merge_index,
+                            None,
+                            self._source_stacking_group_sizes[(new_source, key)],
+                        )
+                    ],
+                    axis,
+                )
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(key):
+                new_source_hash[key] = merge_buffer
 
     def _handle_source(
-        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask
+        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask,
+        ignore_stacking
     ):
         # dereference calng shmem handles
         self._shmem_handler.dereference_shmem_handles(data_hash)
@@ -364,76 +433,44 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         for key in self._frame_selection_data_keys:
             if not data_hash.has(key):
                 continue
+            if data_hash[key].shape[0] != frame_selection_mask.size:
+                self.log.WARN("Frame selection mask does not match the data size")
+                continue
             data_hash[key] = data_hash[key][frame_selection_mask]
 
         # stack across sources (many sources, same key)
-        # could probably save ~100 ns by "if ... in" instead of get
         for (
             key,
             new_source,
             merge_method,
            axis,
         ) in self._source_stacking_sources.get(source, ()):
+            if (new_source, key) in ignore_stacking:
+                continue
             merge_data = data_hash.get(key)
             merge_index = self._source_stacking_indices[
                 (source, new_source, key)
             ]
-            try:
-                merge_buffer = self._stacking_buffers[(new_source, key)]
-                utils.set_on_axis(
-                    merge_buffer,
-                    merge_data,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            except (ValueError, IndexError, KeyError):
-                # ValueError: wrong shape (react to merge_data.shape)
-                # KeyError: buffer doesn't exist yet
-                # IndexError: new source? (TODO: re-run _prepare_merge_groups or ERROR)
-                # either way, create appropriate buffer now
-                # TODO: complain if shape varies between sources within train
-                self._update_stacking_buffer(
-                    new_source,
-                    key,
-                    merge_data.shape,
-                    merge_method,
-                    axis=axis,
-                    dtype=merge_data.dtype,
-                )
-                # and then try again
-                merge_buffer = self._stacking_buffers[(new_source, key)]
-                utils.set_on_axis(
-                    merge_buffer,
-                    merge_data,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            # TODO: zero out unfilled buffer entries
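+            # the buffer for (new_source, key) was already created or
+            # resized by _maybe_update_stacking_buffers, so no fallback here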
+            merge_buffer = self._stacking_buffers[(new_source, key)]
+            utils.set_on_axis(
+                merge_buffer,
+                merge_data,
+                merge_index
+                if merge_method is MergeMethod.STACK
+                else np.index_exp[
+                    slice(
+                        merge_index,
+                        None,
+                        self._source_stacking_group_sizes[(new_source, key)],
+                    )
+                ],
+                axis,
+            )
             data_hash.erase(key)
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
-
         # stack keys (multiple keys within this source)
         for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
             source, ()
@@ -477,10 +514,20 @@
                 ],
                 copy=False
             ).astype(np.bool, copy=False),
+            # prepare stacking
+            stacking_data_shapes, ignore_stacking, fill_missed_data = (
+                self._check_stacking_data(sources, frame_selection_mask)
+            )
+            self._maybe_update_stacking_buffers(
+                stacking_data_shapes, fill_missed_data, new_sources_map)
+            for (new_source, key), msg in ignore_stacking.items():
+                self.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
+
         if self._thread_pool is None:
             for source, (data, timestamp) in sources.items():
                 self._handle_source(
-                    source, data, timestamp, new_sources_map, frame_selection_mask
+                    source, data, timestamp, new_sources_map, frame_selection_mask,
+                    ignore_stacking
                 )
         else:
             concurrent.futures.wait(
@@ -492,6 +539,7 @@
                         timestamp,
                         new_sources_map,
                         frame_selection_mask,
+                        ignore_stacking,
                     )
                     for source, (data, timestamp) in sources.items()
                 ]
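
Note on buffer layout (a sketch, not part of the patch): the diff relies on
utils.stacking_buffer_shape, utils.interleaving_buffer_shape, and
utils.set_on_axis; the numpy-only sketch below illustrates the assumed
semantics for axis=0, with illustrative shapes and without the calng helpers.

    import numpy as np

    group_size = 4                      # number of sources in the group
    individual = (16, 512)              # per-source data shape (assumed)
    data = np.ones(individual, dtype=np.float32)

    # MergeMethod.STACK: a new axis of length group_size is inserted,
    # and the source with merge_index i fills buffer[i]
    stack_buf = np.empty((group_size,) + individual, dtype=np.float32)
    stack_buf[2] = data

    # interleaving: the existing axis grows by a factor of group_size,
    # and merge_index i fills every group_size-th entry starting at i,
    # matching np.index_exp[slice(merge_index, None, group_size)] above
    il_buf = np.empty((individual[0] * group_size,) + individual[1:],
                      dtype=np.float32)
    il_buf[np.index_exp[slice(2, None, group_size)]] = data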