diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py
index c75f4c460f3cd2b6934b25affce717b5e8e68f70..19722de729a7cb0c55371c179c66be771f421503 100644
--- a/src/calng/ShmemTrainMatcher.py
+++ b/src/calng/ShmemTrainMatcher.py
@@ -329,30 +329,94 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
 
         self._have_prepared_merge_groups = True
 
-    def _update_stacking_buffer(
-        self, new_source, key, individual_shape, merge_method, axis, dtype
-    ):
-        if merge_method is MergeMethod.STACK:
-            self._stacking_buffers[(new_source, key)] = np.empty(
-                shape=utils.stacking_buffer_shape(
-                    individual_shape,
-                    self._source_stacking_group_sizes[(new_source, key)],
-                    axis=axis,
-                ),
-                dtype=dtype,
-            )
-        else:
-            self._stacking_buffers[(new_source, key)] = np.empty(
-                shape=utils.interleaving_buffer_shape(
-                    individual_shape,
-                    self._source_stacking_group_sizes[(new_source, key)],
-                    axis=axis,
-                ),
-                dtype=dtype,
+    def _check_stacking_data(self, sources, frame_selection_mask):
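+        """Pre-check this train's sources ahead of stacking.
+
+        Returns three dicts, each keyed by (new_source, key):
+
+        - stacking_data_shapes: (shape, merge_method, axis, dtype, timestamp)
+          expected for the group, taken from the first source seen
+        - ignore_stacking: reason string for groups to skip this train
+        - fill_missed_data: merge indices of sources absent from this train,
+          whose buffer slots are zero-filled later
+        """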
+        if frame_selection_mask is not None:
+            orig_size = len(frame_selection_mask)
+            result_size = np.sum(frame_selection_mask)
+        stacking_data_shapes = {}
+        ignore_stacking = {}
+        fill_missed_data = {}
+        for source, keys in self._source_stacking_sources.items():
+            if source not in sources:
+                for key, new_source, _, _ in keys:
+                    missed_indices = fill_missed_data.setdefault((new_source, key), [])
+                    merge_index = self._source_stacking_indices[
+                        (source, new_source, key)
+                    ]
+                    missed_indices.append(merge_index)
+                continue
+            data_hash, timestamp = sources[source]
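+            # frame filtering only applies to sources matching the
+            # frame selection source pattern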
+            filtering = (
+                frame_selection_mask is not None
+                and self._frame_selection_source_pattern.match(source)
+            )
+            for key, new_source, merge_method, axis in keys:
+                if key not in data_hash:
+                    ignore_stacking[(new_source, key)] = "Some data is missing"
+                    continue
+                merge_data = data_hash[key]
+                merge_data_shape = merge_data.shape
+
+                if filtering and key in self._frame_selection_data_keys:
+                    # stacking is not expected to be combined with frame
+                    # filtering, but if it is, expect the filtered size
+                    if merge_data_shape[0] == orig_size:
+                        merge_data_shape = (result_size,) + merge_data.shape[1:]
+
+                expected_shape, _, _, expected_dtype, _ = (
+                    stacking_data_shapes.setdefault(
+                        (new_source, key),
+                        (merge_data_shape, merge_method, axis,
+                         merge_data.dtype, timestamp),
+                    )
+                )
+                if (expected_shape != merge_data_shape
+                        or expected_dtype != merge_data.dtype):
+                    ignore_stacking[(new_source, key)] = (
+                        "Shape or dtype is inconsistent across sources"
+                    )
+                    del stacking_data_shapes[(new_source, key)]
+
+        return stacking_data_shapes, ignore_stacking, fill_missed_data
+
+    def _maybe_update_stacking_buffers(
+        self, stacking_data_shapes, fill_missed_data, new_sources_map
+    ):
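+        """(Re)allocate stacking buffers to match the expected shapes,
+        zero-fill the slots of missing sources and register each buffer
+        in new_sources_map under its virtual (stacked) source.
+        """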
+        for (new_source, key), attr in stacking_data_shapes.items():
+            merge_data_shape, merge_method, axis, dtype, timestamp = attr
+            group_size = self._source_stacking_group_sizes[(new_source, key)]
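+            # stacking inserts a new axis of length group_size; interleaving
+            # instead scales the existing axis by group_size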
+            if merge_method is MergeMethod.STACK:
+                expected_shape = utils.stacking_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+            else:
+                expected_shape = utils.interleaving_buffer_shape(
+                    merge_data_shape, group_size, axis=axis)
+
+            merge_buffer = self._stacking_buffers.get((new_source, key))
+            if merge_buffer is None or merge_buffer.shape != expected_shape:
+                merge_buffer = np.empty(
+                    shape=expected_shape, dtype=dtype)
+                self._stacking_buffers[(new_source, key)] = merge_buffer
+
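+            # zero-fill the buffer slots of sources missing from this train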
+            for merge_index in fill_missed_data.get((new_source, key), []):
+                utils.set_on_axis(
+                    merge_buffer,
+                    0,
+                    merge_index
+                    if merge_method is MergeMethod.STACK
+                    else np.index_exp[
+                        slice(
+                            merge_index,
+                            None,
+                            self._source_stacking_group_sizes[(new_source, key)],
+                        )
+                    ],
+                    axis,
+                )
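+            # publish the buffer under the stacked (virtual) source name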
+            if new_source not in new_sources_map:
+                new_sources_map[new_source] = (Hash(), timestamp)
+            new_source_hash = new_sources_map[new_source][0]
+            if not new_source_hash.has(key):
+                new_source_hash[key] = merge_buffer
 
     def _handle_source(
-        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask
+        self, source, data_hash, timestamp, new_sources_map, frame_selection_mask,
+        ignore_stacking
     ):
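+        """Dereference shmem handles, apply frame selection and write this
+        source's slice into the pre-allocated stacking buffers.
+
+        Groups listed in ignore_stacking (flagged inconsistent for this
+        train) are skipped.
+        """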
         # dereference calng shmem handles
         self._shmem_handler.dereference_shmem_handles(data_hash)
@@ -364,76 +428,42 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             for key in self._frame_selection_data_keys:
                 if not data_hash.has(key):
                     continue
+                if data_hash[key].shape[0] != frame_selection_mask.size:
+                    self.log.WARN(
+                        f"Frame selection mask size does not match data "
+                        f"size for {key}; not filtering"
+                    )
+                    continue
 
                 data_hash[key] = data_hash[key][frame_selection_mask]
 
         # stack across sources (many sources, same key)
-        # could probably save ~100 ns by "if ... in" instead of get
         for (
             key,
             new_source,
             merge_method,
             axis,
         ) in self._source_stacking_sources.get(source, ()):
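+            # skip groups that failed the stacking pre-check this train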
+            if (new_source, key) in ignore_stacking:
+                continue
             merge_data = data_hash.get(key)
             merge_index = self._source_stacking_indices[
                 (source, new_source, key)
             ]
-            try:
-                merge_buffer = self._stacking_buffers[(new_source, key)]
-                utils.set_on_axis(
-                    merge_buffer,
-                    merge_data,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            except (ValueError, IndexError, KeyError):
-                # ValueError: wrong shape (react to merge_data.shape)
-                # KeyError: buffer doesn't exist yet
-                # IndexError: new source? (TODO: re-run _prepare_merge_groups or ERROR)
-                # either way, create appropriate buffer now
-                # TODO: complain if shape varies between sources within train
-                self._update_stacking_buffer(
-                    new_source,
-                    key,
-                    merge_data.shape,
-                    merge_method,
-                    axis=axis,
-                    dtype=merge_data.dtype,
-                )
-                # and then try again
-                merge_buffer = self._stacking_buffers[(new_source, key)]
-                utils.set_on_axis(
-                    merge_buffer,
-                    merge_data,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            # TODO: zero out unfilled buffer entries
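+            # STACK writes one block at merge_index along the axis, while
+            # interleaving writes every group_size-th entry starting at
+            # merge_index (e.g. 1, 5, 9, ... for index 1 and group size 4)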
+            merge_buffer = self._stacking_buffers[(new_source, key)]
+            utils.set_on_axis(
+                merge_buffer,
+                merge_data,
+                merge_index
+                if merge_method is MergeMethod.STACK
+                else np.index_exp[
+                    slice(
+                        merge_index,
+                        None,
+                        self._source_stacking_group_sizes[(new_source, key)],
+                    )
+                ],
+                axis,
+            )
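+            # drop the per-source key; the data is sent via the stacked source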
             data_hash.erase(key)
 
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
-
         # stack keys (multiple keys within this source)
         for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
             source, ()
@@ -477,10 +507,20 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                 ], copy=False
             ).astype(np.bool, copy=False),
 
+        # prepare stacking: check shape/dtype consistency across sources and
+        # (re)allocate the merge buffers before the per-source handlers run
+        stacking_data_shapes, ignore_stacking, fill_missed_data = (
+            self._check_stacking_data(sources, frame_selection_mask)
+        )
+        self._maybe_update_stacking_buffers(
+            stacking_data_shapes, fill_missed_data, new_sources_map
+        )
+        for (new_source, key), msg in ignore_stacking.items():
+            self.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
+
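+        # with buffers pre-allocated, per-source handlers can run in
+        # parallel without racing on buffer creation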
         if self._thread_pool is None:
             for source, (data, timestamp) in sources.items():
                 self._handle_source(
-                    source, data, timestamp, new_sources_map, frame_selection_mask
+                    source, data, timestamp, new_sources_map, frame_selection_mask,
+                    ignore_stacking
                 )
         else:
             concurrent.futures.wait(
@@ -492,6 +532,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
                         timestamp,
                         new_sources_map,
                         frame_selection_mask,
+                        ignore_stacking,
                     )
                     for source, (data, timestamp) in sources.items()
                 ]