Fix stacking buffer shape issue

Merged Egor Sobolev requested to merge fix/stacking-buffer-shape into master
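
This MR moves stacking-buffer allocation out of the per-source `_handle_source` path. Previously the buffer was created (or recreated) inside a `try/except` whenever `utils.set_on_axis` failed, so the buffer shape could be rebuilt per source and disagree across sources within a train (see the old TODO "complain if shape varies between sources within train"), and frame selection changing the number of frames was not accounted for. Now `_check_stacking_data` collects the expected shape and dtype for every stacked `(new_source, key)` up front (taking the frame selection mask into account), flags groups with missing or inconsistent data, and `_maybe_update_stacking_buffers` reallocates a buffer only when its shape or dtype no longer matches, before the sources are handled serially or in the thread pool.

For reference, a minimal sketch of what the two shape helpers used in the diff are assumed to compute (`utils.stacking_buffer_shape` adds a new source axis, `utils.interleaving_buffer_shape` grows an existing axis); this is an illustration under those assumptions, not the calng implementation:

```python
import numpy as np

def stacking_buffer_shape(individual_shape, group_size, axis=0):
    # STACK: each source gets its own slot along a new axis of length group_size
    shape = list(individual_shape)
    shape.insert(axis, group_size)
    return tuple(shape)

def interleaving_buffer_shape(individual_shape, group_size, axis=0):
    # interleave: frames from all sources share one, longer axis
    shape = list(individual_shape)
    shape[axis] = shape[axis] * group_size
    return tuple(shape)

# e.g. 16 modules of (512, 128) pixels each (made-up sizes)
assert stacking_buffer_shape((512, 128), 16, axis=0) == (16, 512, 128)
assert interleaving_buffer_shape((512, 128), 16, axis=0) == (8192, 128)
buffer = np.empty(stacking_buffer_shape((512, 128), 16, axis=0), dtype=np.uint16)
```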
1 file changed: +91 −76
@@ -329,30 +329,75 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         self._have_prepared_merge_groups = True
 
-    def _update_stacking_buffer(
-        self, new_source, key, individual_shape, merge_method, axis, dtype
-    ):
-        if merge_method is MergeMethod.STACK:
-            self._stacking_buffers[(new_source, key)] = np.empty(
-                shape=utils.stacking_buffer_shape(
-                    individual_shape,
-                    self._source_stacking_group_sizes[(new_source, key)],
-                    axis=axis,
-                ),
-                dtype=dtype,
-            )
-        else:
-            self._stacking_buffers[(new_source, key)] = np.empty(
-                shape=utils.interleaving_buffer_shape(
-                    individual_shape,
-                    self._source_stacking_group_sizes[(new_source, key)],
-                    axis=axis,
-                ),
-                dtype=dtype,
-            )
+    def _check_stacking_data(self, sources, frame_selection_mask):
+        if frame_selection_mask is not None:
+            orig_size = len(frame_selection_mask)
+            result_size = np.sum(frame_selection_mask)
+
+        stacking_data_shapes = {}
+        ignore_stacking = {}
+        for source, keys in self._source_stacking_sources.items():
+            if source not in sources:
+                for key, new_source, _, _ in keys:
+                    ignore_stacking[(new_source, key)] = "Some source is missed"
+                continue
+            data_hash, timestamp = sources[source]
+            filtering = (
+                frame_selection_mask is not None and
+                self._frame_selection_source_pattern.match(source)
+            )
+            for key, new_source, merge_method, axis in keys:
+                merge_data_shape = None
+                if key in data_hash:
+                    merge_data = data_hash[key]
+                    merge_data_shape = merge_data.shape
+
+                if merge_data_shape is None:
+                    ignore_stacking[(new_source, key)] = "Some data is missed"
+                    continue
+
+                if filtering and key in self._frame_selection_data_keys:
+                    # !!! stacking is not expected to be used with filtering
+                    # merge_data_shape[0] == orig_size
+                    merge_data_shape = (result_size,) + merge_data.shape[1:]
+
+                expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
+                    (new_source, key),
+                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp)
+                )
+                if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
+                    ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
+                    del stacking_data_shapes[(new_source, key)]
+
+        return stacking_data_shapes, ignore_stacking
+
+    def _maybe_update_stacking_buffers(
+        self, stacking_data_shapes, new_sources_map
+    ):
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = self._source_stacking_group_sizes[(new_source, key)]
+            if merge_method is MergeMethod.STACK:
+                expected_shape = utils.stacking_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+            else:
+                expected_shape = utils.interleaving_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+
+            merge_buffer = self._stacking_buffers.get((new_source, key))
+            if merge_buffer is None or merge_buffer.shape != expected_shape:
+                merge_buffer = np.empty(
+                    shape=expected_shape, dtype=dtype)
+                self._stacking_buffers[(new_source, key)] = merge_buffer
+
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(key):
+                new_source_hash[key] = merge_buffer
+
     def _handle_source(
-        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask
+        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask,
+        ignore_stacking
     ):
         # dereference calng shmem handles
         self._shmem_handler.dereference_shmem_handles(data_hash)
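
A note on the consistency check added above: the first source that reports a given `(new_source, key)` records the reference shape and dtype via `dict.setdefault`; any later source that disagrees puts the group into `ignore_stacking` and drops it from `stacking_data_shapes`. A standalone illustration of that pattern (the source names and shapes below are made up):

```python
expected = {}
ignored = {}
# hypothetical per-source reports for one stacked key
reports = [
    ("det/mod1", (352, 512, 128), "uint16"),
    ("det/mod2", (352, 512, 128), "uint16"),
    ("det/mod3", (176, 512, 128), "uint16"),  # disagrees on frame count
]
for source, shape, dtype in reports:
    ref_shape, ref_dtype = expected.setdefault("stacked", (shape, dtype))
    if ref_shape != shape or ref_dtype != dtype:
        ignored["stacked"] = "Shape or dtype is inconsistent"
        del expected["stacked"]

print(expected)  # {}
print(ignored)   # {'stacked': 'Shape or dtype is inconsistent'}
```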
@@ -368,72 +413,35 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 data_hash[key] = data_hash[key][frame_selection_mask]
 
         # stack across sources (many sources, same key)
-        # could probably save ~100 ns by "if ... in" instead of get
         for (
             key,
             new_source,
             merge_method,
             axis,
         ) in self._source_stacking_sources.get(source, ()):
+            if (new_source, key) in ignore_stacking:
+                continue
             merge_data = data_hash.get(key)
             merge_index = self._source_stacking_indices[
                 (source, new_source, key)
             ]
-            try:
-                merge_buffer = self._stacking_buffers[(new_source, key)]
-                utils.set_on_axis(
-                    merge_buffer,
-                    merge_data,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            except (ValueError, IndexError, KeyError):
-                # ValueError: wrong shape (react to merge_data.shape)
-                # KeyError: buffer doesn't exist yet
-                # IndexError: new source? (TODO: re-run _prepare_merge_groups or ERROR)
-                # either way, create appropriate buffer now
-                # TODO: complain if shape varies between sources within train
-                self._update_stacking_buffer(
-                    new_source,
-                    key,
-                    merge_data.shape,
-                    merge_method,
-                    axis=axis,
-                    dtype=merge_data.dtype,
-                )
-                # and then try again
-                merge_buffer = self._stacking_buffers[(new_source, key)]
-                utils.set_on_axis(
-                    merge_buffer,
-                    merge_data,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            # TODO: zero out unfilled buffer entries
+            merge_buffer = self._stacking_buffers[(new_source, key)]
+            utils.set_on_axis(
+                merge_buffer,
+                merge_data,
+                merge_index
+                if merge_method is MergeMethod.STACK
+                else np.index_exp[
+                    slice(
+                        merge_index,
+                        None,
+                        self._source_stacking_group_sizes[(new_source, key)],
+                    )
+                ],
+                axis,
+            )
             data_hash.erase(key)
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
 
         # stack keys (multiple keys within this source)
         for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
             source, ()
@@ -477,10 +485,16 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 ], copy=False
             ).astype(np.bool, copy=False),
 
+        # prepare stacking
+        stacking_data_shapes, ignore_stacking = self._check_stacking_data(
+            sources, frame_selection_mask)
+        self._maybe_update_stacking_buffers(stacking_data_shapes, new_sources_map)
+
         if self._thread_pool is None:
             for source, (data, timestamp) in sources.items():
                 self._handle_source(
-                    source, data, timestamp, new_sources_map, frame_selection_mask
+                    source, data, timestamp, new_sources_map, frame_selection_mask,
+                    ignore_stacking
                 )
         else:
             concurrent.futures.wait(
@@ -492,6 +506,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                         timestamp,
                         new_sources_map,
                         frame_selection_mask,
+                        ignore_stacking,
                     )
                     for source, (data, timestamp) in sources.items()
                 ]
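
For completeness, the per-source write in `_handle_source` still distinguishes the two merge methods only by the index expression passed to `utils.set_on_axis`: the plain `merge_index` for STACK, and a strided slice for interleaving. A small NumPy sketch of the assumed placement, using direct indexing instead of `set_on_axis` and invented sizes:

```python
import numpy as np

group_size = 4          # number of sources stacked into one output
merge_index = 1         # this source's position within the group
data = np.full((2, 3), merge_index)   # this source's (frames, pixels) block

# STACK: every source occupies one slot along a new leading axis
stacked = np.zeros((group_size, 2, 3))
stacked[merge_index] = data

# interleave: frames from all sources alternate along the existing frame axis,
# i.e. the slice(merge_index, None, group_size) used in the diff above
interleaved = np.zeros((2 * group_size, 3))
interleaved[merge_index::group_size] = data
```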