diff --git a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb index 55515eb7185ecd3666b5ef6b440b53e9685064be..35563c0bb8fdfbdc70ab84729564f36fc21668db 100644 --- a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb +++ b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb @@ -97,6 +97,18 @@ "%matplotlib inline" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def finite_flattened_slice(array, slice_=np.s_[:]):\n", + " \"\"\"Return flattened and finite values for a given slice.\"\"\"\n", + " sliced_array = array[slice_]\n", + " return sliced_array[np.isfinite(sliced_array)]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -500,8 +512,7 @@ "psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_find_edges))\n", "threadpool_limits(limits=remi.get_num_workers(mt_avg_trace))\n", "\n", - "edges_by_det = {}\n", - "avg_traces_by_det = {}\n", + "det_data = {}\n", "\n", "for det_name, det in remi['detector'].items():\n", " det_sourcekeys = remi.get_detector_sourcekeys(det_name)\n", @@ -552,11 +563,14 @@ " with timing(f'find_edges, {det_name}'):\n", " psh.map(find_edges, dc.select(det_sourcekeys))\n", " \n", - " edges_by_det[det_name] = edges\n", - " avg_traces_by_det[det_name] = avg_traces.sum(axis=0) / len(dc.train_ids)\n", " \n", " with np.printoptions(precision=2, suppress=True):\n", " print(edges[:5, :, :8])" + " \n", + " det_data[det_name] = {\n", + " 'edges': edges,\n", + " 'avg_trace': avg_traces.sum(axis=0) / len(dc.train_ids)\n", + " }" ] }, { @@ -578,7 +592,7 @@ " fig.text(0.02, 0.98, det_name.upper(), rotation=90, ha='left', va='top', size='x-large')\n", "\n", " for edge_idx, edge_name in enumerate(['u1', 'u2', 'v1', 'v2', 'w1', 'w2', 'mcp']):\n", - " axs[edge_idx].plot(avg_traces_by_det[det_name][edge_idx], lw=1)\n", + " axs[edge_idx].plot(det_data[det_name]['avg_trace'][edge_idx], lw=1)\n", " axs[edge_idx].tick_params(labelbottom=False)\n", " axs[edge_idx].set_ylabel(edge_name)\n", " \n", @@ -590,7 +604,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Sample for digitized traces" + "### Sample for found edges" ] }, { @@ -600,7 +614,7 @@ "outputs": [], "source": [ "for i, det_name in enumerate(remi['detector'].keys()):\n", - " edges = edges_by_det[det_name]\n", + " edges = det_data[det_name]['edges']\n", " \n", " fig = plt.figure(num=100+i, figsize=(9.5, 8))\n", " grid = fig.add_gridspec(ncols=2, nrows=4, left=0.1, right=0.98, top=0.98, bottom=0.1)\n", @@ -665,10 +679,10 @@ "for i, det_name in enumerate(remi['detector'].keys()):\n", " fig = plt.figure(num=20+i, figsize=(9.5, 6))\n", " \n", - " edges = edges_by_det[det_name]\n", + " edges = det_data[det_name]['edges']\n", " \n", - " min_edge = edges[np.isfinite(edges)].min()\n", - " max_edge = edges[np.isfinite(edges)].max()\n", + " min_edge = np.nanmin(edges)\n", + " max_edge = np.nanmax(edges)\n", "\n", " grid = fig.add_gridspec(ncols=3, nrows=3, left=0.08, right=0.98, top=0.95, hspace=0.4)\n", "\n", @@ -695,8 +709,6 @@ " numx.plot(np.arange(len(num_edges)) * agg_window, num_edges, label=edge_name, **plot_kwargs)\n", " max_num_edges = max(max_num_edges, num_edges.max())\n", "\n", - " cur_edges = edges[:, edge_idx, :].flatten()\n", - "\n", " if edge_idx < 6:\n", " row = 1 + edge_idx % 2\n", " col = edge_idx // 2\n", @@ -706,8 +718,9 @@ "\n", " ax = fig.add_subplot(grid[row, col])\n", " ax.set_title(f'TOF spectrum: {edge_name}')\n", - " y, _, _ = ax.hist(cur_edges[np.isfinite(cur_edges)], bins=int((max_edge - min_edge) // 5),\n", - " range=(min_edge, max_edge), color=plot_kwargs['c'], histtype='step', linewidth=1)\n", + " y, _, _ = ax.hist(finite_flattened_slice(edges, np.s_[:, edge_idx, :]),\n", + " bins=int((max_edge - min_edge) // 5), range=(min_edge, max_edge),\n", + " color=plot_kwargs['c'], histtype='step', linewidth=1)\n", " hist_axs.append(ax)\n", "\n", " max_spectral_intensity = max(max_spectral_intensity, y.max())\n", @@ -737,7 +750,7 @@ "outputs": [], "source": [ "for i, det_name in enumerate(remi['detector'].keys()):\n", - " edges = edges_by_det[det_name]\n", + " edges = det_data[det_name]['edges']\n", " \n", " sort = remi.get_dld_sorter(det_name)\n", " \n", @@ -781,12 +794,8 @@ "source": [ "psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_rec_hits))\n", "\n", - "signals_by_det = {}\n", - "hits_by_det = {}\n", - "hit_counts_by_det = {}\n", - "\n", "for det_name, det in remi['detector'].items():\n", - " edges = edges_by_det[det_name]\n", + " edges = det_data[det_name]['edges']\n", " \n", " signals = psh.alloc(shape=(num_pulses, 50), dtype=signal_dt, fill=np.nan)\n", " hits = psh.alloc(shape=(num_pulses, 50), dtype=hit_dt, fill=(np.nan, np.nan, np.nan, -1))\n", @@ -804,9 +813,7 @@ " with timing(f'rec_hits, {det_name}'):\n", " psh.map(reconstruct_hits, dc.train_ids)\n", " \n", - " signals_by_det[det_name] = signals\n", - " hits_by_det[det_name] = hits\n", - " hit_counts_by_det[det_name] = hit_counts" + " det_data[det_name].update(signals=signals, hits=hits, hit_counts=hit_counts)" ] }, { @@ -823,7 +830,7 @@ "for det_name in remi['detector'].keys():\n", " agg_window = num_pulses // 1000\n", " \n", - " num_hits = np.isfinite(hits_by_det[det_name]['x']).sum(axis=1)\n", + " num_hits = np.isfinite(det_data[det_name]['hits']['x']).sum(axis=1)\n", " num_hits = num_hits[:(len(num_hits) // agg_window) * agg_window]\n", " num_hits = num_hits.reshape(-1, agg_window).mean(axis=1)\n", " max_num_hits = max(max_num_hits, num_hits.max())\n", @@ -885,7 +892,7 @@ "outputs": [], "source": [ "for i, det_name in enumerate(remi['detector'].keys()):\n", - " hits = hits_by_det[det_name]\n", + " hits = det_data[det_name]['hits']\n", " \n", " fig, ax = plt.subplots(num=60+i, figsize=(9.5, 5), ncols=1, clear=True,\n", " gridspec_kw=dict(left=0.08, right=0.91, top=0.8))\n", @@ -960,7 +967,7 @@ "outputs": [], "source": [ "for i, det_name in enumerate(remi['detector'].keys()):\n", - " flat_hits = hits_by_det[det_name].reshape(-1)\n", + " flat_hits = det_data[det_name]['hits'].reshape(-1)\n", " flat_hits = flat_hits[np.isfinite(flat_hits[:]['x'])]\n", " flat_hits = flat_hits[flat_hits['m'] < 10]\n", "\n", @@ -1064,20 +1071,22 @@ " \n", " cur_fast_data = outp.create_instrument_source(f'{cur_device_id}:{det_output_key}')\n", " \n", + " cur_data = det_data[det_name]\n", + " \n", " if save_raw_triggers:\n", " cur_fast_data.create_key('raw.triggers', triggers[pulse_mask],\n", " chunks=tuple(chunks_triggers), **dataset_kwargs)\n", " \n", " if save_raw_edges:\n", - " cur_fast_data.create_key('raw.edges', edges_by_det[det_name][pulse_mask],\n", + " cur_fast_data.create_key('raw.edges', cur_data['edges'][pulse_mask],\n", " chunks=tuple(chunks_edges), **dataset_kwargs)\n", " \n", " if save_rec_signals:\n", - " cur_fast_data.create_key('rec.signals', signals_by_det[det_name][pulse_mask],\n", + " cur_fast_data.create_key('rec.signals', cur_data['signals'][pulse_mask],\n", " chunks=tuple(chunks_signals), **dataset_kwargs)\n", " \n", " if save_rec_hits:\n", - " cur_fast_data.create_key('rec.hits', hits_by_det[det_name][pulse_mask],\n", + " cur_fast_data.create_key('rec.hits', cur_data['hits'][pulse_mask],\n", " chunks=tuple(chunks_hits), **dataset_kwargs)\n", " \n", " cur_fast_data.create_index(raw=pulse_counts[train_mask], rec=pulse_counts[train_mask])\n",