Skip to content
Snippets Groups Projects

Fix stacking buffer shape issue

Merged Egor Sobolev requested to merge fix/stacking-buffer-shape into master
3 unresolved threads
1 file
+ 29
- 8
Compare changes
  • Side-by-side
  • Inline
@@ -335,11 +335,15 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
result_size = np.sum(frame_selection_mask)
stacking_data_shapes = {}
ignore_stacking = {}
fill_missed_data = {}
for source, keys in self._source_stacking_sources.items():
if source not in sources:
for key, new_source, _, _ in keys:
ignore_stacking[(new_source, key)] = "Some source is missed"
continue
missed_sources = fill_missed_data.setdefault((new_source, key), [])
merge_index = self._source_stacking_indices[
(source, new_source, key)
]
missed_sources.append(merge_index)
data_hash, timestamp = sources[source]
filtering = (
frame_selection_mask is not None and
@@ -350,7 +354,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
if key in data_hash:
merge_data = data_hash[key]
merge_data_shape = merge_data.shape
if merge_data_shape is None:
ignore_stacking[(new_source, key)] = "Some data is missed"
continue
@@ -368,10 +372,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
del stacking_data_shapes[(new_source, key)]
return stacking_data_shapes, ignore_stacking
return stacking_data_shapes, ignore_stacking, fill_missed_data
def _maybe_update_stacking_buffers(
self, stacking_data_shapes, new_sources_map
self, stacking_data_shapes, fill_missed_data, new_sources_map
):
for (new_source, key), attr in stacking_data_shapes.items():
merge_data_shape, merge_method, axis, dtype, timestamp = attr
@@ -389,6 +393,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
shape=expected_shape, dtype=dtype)
self._stacking_buffers[(new_source, key)] = merge_buffer
for merge_index in fill_missed_data.get((new_source, key), []):
utils.set_on_axis(
merge_buffer,
0,
merge_index
if merge_method is MergeMethod.STACK
else np.index_exp[
slice(
merge_index,
None,
self._source_stacking_group_sizes[(new_source, key)],
)
],
axis,
)
if new_source not in new_sources_map:
new_sources_map[new_source] = (Hash(), timestamp)
new_source_hash = new_sources_map[new_source][0]
@@ -486,9 +505,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
).astype(np.bool, copy=False),
# prepare stacking
stacking_data_shapes, ignore_stacking = self._check_stacking_data(
sources, frame_selection_mask)
self._maybe_update_stacking_buffers(stacking_data_shapes, new_sources_map)
stacking_data_shapes, ignore_stacking, fill_missed_data = (
self._check_stacking_data(sources, frame_selection_mask)
)
self._maybe_update_stacking_buffers(
stacking_data_shapes, fill_missed_data, new_sources_map)
if self._thread_pool is None:
for source, (data, timestamp) in sources.items():
Loading