Skip to content
Snippets Groups Projects
Commit f32d6e7f authored by Philipp Schmidt's avatar Philipp Schmidt
Browse files

Keep result arrays in a single dict indexed by detector name in REMI reconstruction

parent 966273a1
No related branches found
No related tags found
1 merge request!822[REMI] Save pulse amplitudes during discrimination
......@@ -97,6 +97,18 @@
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def finite_flattened_slice(array, slice_=np.s_[:]):\n",
" \"\"\"Return flattened and finite values for a given slice.\"\"\"\n",
" sliced_array = array[slice_]\n",
" return sliced_array[np.isfinite(sliced_array)]"
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -500,8 +512,7 @@
"psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_find_edges))\n",
"threadpool_limits(limits=remi.get_num_workers(mt_avg_trace))\n",
"\n",
"edges_by_det = {}\n",
"avg_traces_by_det = {}\n",
"det_data = {}\n",
"\n",
"for det_name, det in remi['detector'].items():\n",
" det_sourcekeys = remi.get_detector_sourcekeys(det_name)\n",
......@@ -552,11 +563,14 @@
" with timing(f'find_edges, {det_name}'):\n",
" psh.map(find_edges, dc.select(det_sourcekeys))\n",
" \n",
" edges_by_det[det_name] = edges\n",
" avg_traces_by_det[det_name] = avg_traces.sum(axis=0) / len(dc.train_ids)\n",
" \n",
" with np.printoptions(precision=2, suppress=True):\n",
" print(edges[:5, :, :8])"
" \n",
" det_data[det_name] = {\n",
" 'edges': edges,\n",
" 'avg_trace': avg_traces.sum(axis=0) / len(dc.train_ids)\n",
" }"
]
},
{
......@@ -578,7 +592,7 @@
" fig.text(0.02, 0.98, det_name.upper(), rotation=90, ha='left', va='top', size='x-large')\n",
"\n",
" for edge_idx, edge_name in enumerate(['u1', 'u2', 'v1', 'v2', 'w1', 'w2', 'mcp']):\n",
" axs[edge_idx].plot(avg_traces_by_det[det_name][edge_idx], lw=1)\n",
" axs[edge_idx].plot(det_data[det_name]['avg_trace'][edge_idx], lw=1)\n",
" axs[edge_idx].tick_params(labelbottom=False)\n",
" axs[edge_idx].set_ylabel(edge_name)\n",
" \n",
......@@ -590,7 +604,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sample for digitized traces"
"### Sample for found edges"
]
},
{
......@@ -600,7 +614,7 @@
"outputs": [],
"source": [
"for i, det_name in enumerate(remi['detector'].keys()):\n",
" edges = edges_by_det[det_name]\n",
" edges = det_data[det_name]['edges']\n",
" \n",
" fig = plt.figure(num=100+i, figsize=(9.5, 8))\n",
" grid = fig.add_gridspec(ncols=2, nrows=4, left=0.1, right=0.98, top=0.98, bottom=0.1)\n",
......@@ -665,10 +679,10 @@
"for i, det_name in enumerate(remi['detector'].keys()):\n",
" fig = plt.figure(num=20+i, figsize=(9.5, 6))\n",
" \n",
" edges = edges_by_det[det_name]\n",
" edges = det_data[det_name]['edges']\n",
" \n",
" min_edge = edges[np.isfinite(edges)].min()\n",
" max_edge = edges[np.isfinite(edges)].max()\n",
" min_edge = np.nanmin(edges)\n",
" max_edge = np.nanmax(edges)\n",
"\n",
" grid = fig.add_gridspec(ncols=3, nrows=3, left=0.08, right=0.98, top=0.95, hspace=0.4)\n",
"\n",
......@@ -695,8 +709,6 @@
" numx.plot(np.arange(len(num_edges)) * agg_window, num_edges, label=edge_name, **plot_kwargs)\n",
" max_num_edges = max(max_num_edges, num_edges.max())\n",
"\n",
" cur_edges = edges[:, edge_idx, :].flatten()\n",
"\n",
" if edge_idx < 6:\n",
" row = 1 + edge_idx % 2\n",
" col = edge_idx // 2\n",
......@@ -706,8 +718,9 @@
"\n",
" ax = fig.add_subplot(grid[row, col])\n",
" ax.set_title(f'TOF spectrum: {edge_name}')\n",
" y, _, _ = ax.hist(cur_edges[np.isfinite(cur_edges)], bins=int((max_edge - min_edge) // 5),\n",
" range=(min_edge, max_edge), color=plot_kwargs['c'], histtype='step', linewidth=1)\n",
" y, _, _ = ax.hist(finite_flattened_slice(edges, np.s_[:, edge_idx, :]),\n",
" bins=int((max_edge - min_edge) // 5), range=(min_edge, max_edge),\n",
" color=plot_kwargs['c'], histtype='step', linewidth=1)\n",
" hist_axs.append(ax)\n",
"\n",
" max_spectral_intensity = max(max_spectral_intensity, y.max())\n",
......@@ -737,7 +750,7 @@
"outputs": [],
"source": [
"for i, det_name in enumerate(remi['detector'].keys()):\n",
" edges = edges_by_det[det_name]\n",
" edges = det_data[det_name]['edges']\n",
" \n",
" sort = remi.get_dld_sorter(det_name)\n",
" \n",
......@@ -781,12 +794,8 @@
"source": [
"psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_rec_hits))\n",
"\n",
"signals_by_det = {}\n",
"hits_by_det = {}\n",
"hit_counts_by_det = {}\n",
"\n",
"for det_name, det in remi['detector'].items():\n",
" edges = edges_by_det[det_name]\n",
" edges = det_data[det_name]['edges']\n",
" \n",
" signals = psh.alloc(shape=(num_pulses, 50), dtype=signal_dt, fill=np.nan)\n",
" hits = psh.alloc(shape=(num_pulses, 50), dtype=hit_dt, fill=(np.nan, np.nan, np.nan, -1))\n",
......@@ -804,9 +813,7 @@
" with timing(f'rec_hits, {det_name}'):\n",
" psh.map(reconstruct_hits, dc.train_ids)\n",
" \n",
" signals_by_det[det_name] = signals\n",
" hits_by_det[det_name] = hits\n",
" hit_counts_by_det[det_name] = hit_counts"
" det_data[det_name].update(signals=signals, hits=hits, hit_counts=hit_counts)"
]
},
{
......@@ -823,7 +830,7 @@
"for det_name in remi['detector'].keys():\n",
" agg_window = num_pulses // 1000\n",
" \n",
" num_hits = np.isfinite(hits_by_det[det_name]['x']).sum(axis=1)\n",
" num_hits = np.isfinite(det_data[det_name]['hits']['x']).sum(axis=1)\n",
" num_hits = num_hits[:(len(num_hits) // agg_window) * agg_window]\n",
" num_hits = num_hits.reshape(-1, agg_window).mean(axis=1)\n",
" max_num_hits = max(max_num_hits, num_hits.max())\n",
......@@ -885,7 +892,7 @@
"outputs": [],
"source": [
"for i, det_name in enumerate(remi['detector'].keys()):\n",
" hits = hits_by_det[det_name]\n",
" hits = det_data[det_name]['hits']\n",
" \n",
" fig, ax = plt.subplots(num=60+i, figsize=(9.5, 5), ncols=1, clear=True,\n",
" gridspec_kw=dict(left=0.08, right=0.91, top=0.8))\n",
......@@ -960,7 +967,7 @@
"outputs": [],
"source": [
"for i, det_name in enumerate(remi['detector'].keys()):\n",
" flat_hits = hits_by_det[det_name].reshape(-1)\n",
" flat_hits = det_data[det_name]['hits'].reshape(-1)\n",
" flat_hits = flat_hits[np.isfinite(flat_hits[:]['x'])]\n",
" flat_hits = flat_hits[flat_hits['m'] < 10]\n",
"\n",
......@@ -1064,20 +1071,22 @@
" \n",
" cur_fast_data = outp.create_instrument_source(f'{cur_device_id}:{det_output_key}')\n",
" \n",
" cur_data = det_data[det_name]\n",
" \n",
" if save_raw_triggers:\n",
" cur_fast_data.create_key('raw.triggers', triggers[pulse_mask],\n",
" chunks=tuple(chunks_triggers), **dataset_kwargs)\n",
" \n",
" if save_raw_edges:\n",
" cur_fast_data.create_key('raw.edges', edges_by_det[det_name][pulse_mask],\n",
" cur_fast_data.create_key('raw.edges', cur_data['edges'][pulse_mask],\n",
" chunks=tuple(chunks_edges), **dataset_kwargs)\n",
" \n",
" if save_rec_signals:\n",
" cur_fast_data.create_key('rec.signals', signals_by_det[det_name][pulse_mask],\n",
" cur_fast_data.create_key('rec.signals', cur_data['signals'][pulse_mask],\n",
" chunks=tuple(chunks_signals), **dataset_kwargs)\n",
" \n",
" if save_rec_hits:\n",
" cur_fast_data.create_key('rec.hits', hits_by_det[det_name][pulse_mask],\n",
" cur_fast_data.create_key('rec.hits', cur_data['hits'][pulse_mask],\n",
" chunks=tuple(chunks_hits), **dataset_kwargs)\n",
" \n",
" cur_fast_data.create_index(raw=pulse_counts[train_mask], rec=pulse_counts[train_mask])\n",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment