diff --git a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
index 55515eb7185ecd3666b5ef6b440b53e9685064be..35563c0bb8fdfbdc70ab84729564f36fc21668db 100644
--- a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
+++ b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
@@ -97,6 +97,18 @@
     "%matplotlib inline"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def finite_flattened_slice(array, slice_=np.s_[:]):\n",
+    "    \"\"\"Return flattened and finite values for a given slice.\"\"\"\n",
+    "    sliced_array = array[slice_]\n",
+    "    return sliced_array[np.isfinite(sliced_array)]"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -500,8 +512,7 @@
     "psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_find_edges))\n",
     "threadpool_limits(limits=remi.get_num_workers(mt_avg_trace))\n",
     "\n",
-    "edges_by_det = {}\n",
-    "avg_traces_by_det = {}\n",
+    "det_data = {}\n",
     "\n",
     "for det_name, det in remi['detector'].items():\n",
     "    det_sourcekeys = remi.get_detector_sourcekeys(det_name)\n",
@@ -552,11 +563,14 @@
     "    with timing(f'find_edges, {det_name}'):\n",
     "        psh.map(find_edges, dc.select(det_sourcekeys))\n",
     "    \n",
-    "    edges_by_det[det_name] = edges\n",
-    "    avg_traces_by_det[det_name] = avg_traces.sum(axis=0) / len(dc.train_ids)\n",
     "    \n",
     "    with np.printoptions(precision=2, suppress=True):\n",
     "        print(edges[:5, :, :8])"
+    "    \n",
+    "    det_data[det_name] = {\n",
+    "        'edges': edges,\n",
+    "        'avg_trace': avg_traces.sum(axis=0) / len(dc.train_ids)\n",
+    "    }"
    ]
   },
   {
@@ -578,7 +592,7 @@
     "    fig.text(0.02, 0.98, det_name.upper(), rotation=90, ha='left', va='top', size='x-large')\n",
     "\n",
     "    for edge_idx, edge_name in enumerate(['u1', 'u2', 'v1', 'v2', 'w1', 'w2', 'mcp']):\n",
-    "        axs[edge_idx].plot(avg_traces_by_det[det_name][edge_idx], lw=1)\n",
+    "        axs[edge_idx].plot(det_data[det_name]['avg_trace'][edge_idx], lw=1)\n",
     "        axs[edge_idx].tick_params(labelbottom=False)\n",
     "        axs[edge_idx].set_ylabel(edge_name)\n",
     "    \n",
@@ -590,7 +604,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Sample for digitized traces"
+    "### Sample for found edges"
    ]
   },
   {
@@ -600,7 +614,7 @@
    "outputs": [],
    "source": [
     "for i, det_name in enumerate(remi['detector'].keys()):\n",
-    "    edges = edges_by_det[det_name]\n",
+    "    edges = det_data[det_name]['edges']\n",
     "    \n",
     "    fig = plt.figure(num=100+i, figsize=(9.5, 8))\n",
     "    grid = fig.add_gridspec(ncols=2, nrows=4, left=0.1, right=0.98, top=0.98, bottom=0.1)\n",
@@ -665,10 +679,10 @@
     "for i, det_name in enumerate(remi['detector'].keys()):\n",
     "    fig = plt.figure(num=20+i, figsize=(9.5, 6))\n",
     "    \n",
-    "    edges = edges_by_det[det_name]\n",
+    "    edges = det_data[det_name]['edges']\n",
     "    \n",
-    "    min_edge = edges[np.isfinite(edges)].min()\n",
-    "    max_edge = edges[np.isfinite(edges)].max()\n",
+    "    min_edge = np.nanmin(edges)\n",
+    "    max_edge = np.nanmax(edges)\n",
     "\n",
     "    grid = fig.add_gridspec(ncols=3, nrows=3, left=0.08, right=0.98, top=0.95, hspace=0.4)\n",
     "\n",
@@ -695,8 +709,6 @@
     "        numx.plot(np.arange(len(num_edges)) * agg_window, num_edges, label=edge_name, **plot_kwargs)\n",
     "        max_num_edges = max(max_num_edges, num_edges.max())\n",
     "\n",
-    "        cur_edges = edges[:, edge_idx, :].flatten()\n",
-    "\n",
     "        if edge_idx < 6:\n",
     "            row = 1 + edge_idx % 2\n",
     "            col = edge_idx // 2\n",
@@ -706,8 +718,9 @@
     "\n",
     "        ax = fig.add_subplot(grid[row, col])\n",
     "        ax.set_title(f'TOF spectrum: {edge_name}')\n",
-    "        y, _, _ = ax.hist(cur_edges[np.isfinite(cur_edges)], bins=int((max_edge - min_edge) // 5),\n",
-    "                          range=(min_edge, max_edge), color=plot_kwargs['c'], histtype='step', linewidth=1)\n",
+    "        y, _, _ = ax.hist(finite_flattened_slice(edges, np.s_[:, edge_idx, :]),\n",
+    "                          bins=int((max_edge - min_edge) // 5), range=(min_edge, max_edge),\n",
+    "                          color=plot_kwargs['c'], histtype='step', linewidth=1)\n",
     "        hist_axs.append(ax)\n",
     "\n",
     "        max_spectral_intensity = max(max_spectral_intensity, y.max())\n",
@@ -737,7 +750,7 @@
    "outputs": [],
    "source": [
     "for i, det_name in enumerate(remi['detector'].keys()):\n",
-    "    edges = edges_by_det[det_name]\n",
+    "    edges = det_data[det_name]['edges']\n",
     "    \n",
     "    sort = remi.get_dld_sorter(det_name)\n",
     "    \n",
@@ -781,12 +794,8 @@
    "source": [
     "psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_rec_hits))\n",
     "\n",
-    "signals_by_det = {}\n",
-    "hits_by_det = {}\n",
-    "hit_counts_by_det = {}\n",
-    "\n",
     "for det_name, det in remi['detector'].items():\n",
-    "    edges = edges_by_det[det_name]\n",
+    "    edges = det_data[det_name]['edges']\n",
     "    \n",
     "    signals = psh.alloc(shape=(num_pulses, 50), dtype=signal_dt, fill=np.nan)\n",
     "    hits = psh.alloc(shape=(num_pulses, 50), dtype=hit_dt, fill=(np.nan, np.nan, np.nan, -1))\n",
@@ -804,9 +813,7 @@
     "    with timing(f'rec_hits, {det_name}'):\n",
     "        psh.map(reconstruct_hits, dc.train_ids)\n",
     "        \n",
-    "    signals_by_det[det_name] = signals\n",
-    "    hits_by_det[det_name] = hits\n",
-    "    hit_counts_by_det[det_name] = hit_counts"
+    "    det_data[det_name].update(signals=signals, hits=hits, hit_counts=hit_counts)"
    ]
   },
   {
@@ -823,7 +830,7 @@
     "for det_name in remi['detector'].keys():\n",
     "    agg_window = num_pulses // 1000\n",
     "    \n",
-    "    num_hits = np.isfinite(hits_by_det[det_name]['x']).sum(axis=1)\n",
+    "    num_hits = np.isfinite(det_data[det_name]['hits']['x']).sum(axis=1)\n",
     "    num_hits = num_hits[:(len(num_hits) // agg_window) * agg_window]\n",
     "    num_hits = num_hits.reshape(-1, agg_window).mean(axis=1)\n",
     "    max_num_hits = max(max_num_hits, num_hits.max())\n",
@@ -885,7 +892,7 @@
    "outputs": [],
    "source": [
     "for i, det_name in enumerate(remi['detector'].keys()):\n",
-    "    hits = hits_by_det[det_name]\n",
+    "    hits = det_data[det_name]['hits']\n",
     "    \n",
     "    fig, ax = plt.subplots(num=60+i, figsize=(9.5, 5), ncols=1, clear=True,\n",
     "                           gridspec_kw=dict(left=0.08, right=0.91, top=0.8))\n",
@@ -960,7 +967,7 @@
    "outputs": [],
    "source": [
     "for i, det_name in enumerate(remi['detector'].keys()):\n",
-    "    flat_hits = hits_by_det[det_name].reshape(-1)\n",
+    "    flat_hits = det_data[det_name]['hits'].reshape(-1)\n",
     "    flat_hits = flat_hits[np.isfinite(flat_hits[:]['x'])]\n",
     "    flat_hits = flat_hits[flat_hits['m'] < 10]\n",
     "\n",
@@ -1064,20 +1071,22 @@
     "            \n",
     "            cur_fast_data = outp.create_instrument_source(f'{cur_device_id}:{det_output_key}')\n",
     "            \n",
+    "            cur_data = det_data[det_name]\n",
+    "            \n",
     "            if save_raw_triggers:\n",
     "                cur_fast_data.create_key('raw.triggers', triggers[pulse_mask],\n",
     "                                         chunks=tuple(chunks_triggers), **dataset_kwargs)\n",
     "                \n",
     "            if save_raw_edges:\n",
-    "                cur_fast_data.create_key('raw.edges', edges_by_det[det_name][pulse_mask],\n",
+    "                cur_fast_data.create_key('raw.edges', cur_data['edges'][pulse_mask],\n",
     "                                         chunks=tuple(chunks_edges), **dataset_kwargs)\n",
     "                \n",
     "            if save_rec_signals:\n",
-    "                cur_fast_data.create_key('rec.signals', signals_by_det[det_name][pulse_mask],\n",
+    "                cur_fast_data.create_key('rec.signals', cur_data['signals'][pulse_mask],\n",
     "                                         chunks=tuple(chunks_signals), **dataset_kwargs)\n",
     "                \n",
     "            if save_rec_hits:\n",
-    "                cur_fast_data.create_key('rec.hits', hits_by_det[det_name][pulse_mask],\n",
+    "                cur_fast_data.create_key('rec.hits', cur_data['hits'][pulse_mask],\n",
     "                                         chunks=tuple(chunks_hits), **dataset_kwargs)\n",
     "                \n",
     "            cur_fast_data.create_index(raw=pulse_counts[train_mask], rec=pulse_counts[train_mask])\n",