Skip to content
Snippets Groups Projects
Commit af4f1f23 authored by Egor Sobolev's avatar Egor Sobolev Committed by spbonc
Browse files

Add filling of places for missed sources in stacked data with zeros

parent 44059e17
No related branches found
No related tags found
2 merge requests!72Fix stacking buffer shape issue,!61Draft: analysis development branch (do not merge)
This commit is part of merge request !72. Comments created here will be created in the context of that merge request.
......@@ -335,11 +335,15 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
result_size = np.sum(frame_selection_mask)
stacking_data_shapes = {}
ignore_stacking = {}
fill_missed_data = {}
for source, keys in self._source_stacking_sources.items():
if source not in sources:
for key, new_source, _, _ in keys:
ignore_stacking[(new_source, key)] = "Some source is missed"
continue
missed_sources = fill_missed_data.setdefault((new_source, key), [])
merge_index = self._source_stacking_indices[
(source, new_source, key)
]
missed_sources.append(merge_index)
data_hash, timestamp = sources[source]
filtering = (
frame_selection_mask is not None and
......@@ -350,7 +354,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
if key in data_hash:
merge_data = data_hash[key]
merge_data_shape = merge_data.shape
if merge_data_shape is None:
ignore_stacking[(new_source, key)] = "Some data is missed"
continue
......@@ -368,10 +372,10 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
del stacking_data_shapes[(new_source, key)]
return stacking_data_shapes, ignore_stacking
return stacking_data_shapes, ignore_stacking, fill_missed_data
def _maybe_update_stacking_buffers(
self, stacking_data_shapes, new_sources_map
self, stacking_data_shapes, fill_missed_data, new_sources_map
):
for (new_source, key), attr in stacking_data_shapes.items():
merge_data_shape, merge_method, axis, dtype, timestamp = attr
......@@ -389,6 +393,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
shape=expected_shape, dtype=dtype)
self._stacking_buffers[(new_source, key)] = merge_buffer
for merge_index in fill_missed_data.get((new_source, key), []):
utils.set_on_axis(
merge_buffer,
0,
merge_index
if merge_method is MergeMethod.STACK
else np.index_exp[
slice(
merge_index,
None,
self._source_stacking_group_sizes[(new_source, key)],
)
],
axis,
)
if new_source not in new_sources_map:
new_sources_map[new_source] = (Hash(), timestamp)
new_source_hash = new_sources_map[new_source][0]
......@@ -486,9 +505,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
).astype(np.bool, copy=False),
# prepare stacking
stacking_data_shapes, ignore_stacking = self._check_stacking_data(
sources, frame_selection_mask)
self._maybe_update_stacking_buffers(stacking_data_shapes, new_sources_map)
stacking_data_shapes, ignore_stacking, fill_missed_data = (
self._check_stacking_data(sources, frame_selection_mask)
)
self._maybe_update_stacking_buffers(
stacking_data_shapes, fill_missed_data, new_sources_map)
if self._thread_pool is None:
for source, (data, timestamp) in sources.items():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment