From aa3009190e45d651ee90a984c149eaacbdc72b60 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Fri, 7 Apr 2023 17:29:18 +0200
Subject: [PATCH] Make REMI reconstruction plots robust against no edges or
 hits in data

---
 .../REMI/REMI_Digitize_and_Transform.ipynb    | 71 +++++++++++++------
 1 file changed, 50 insertions(+), 21 deletions(-)

diff --git a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
index 0429805ce..ddb1c0c17 100644
--- a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
+++ b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
@@ -73,15 +73,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from datetime import datetime\n",
+    "from logging import warning\n",
+    "from pathlib import Path\n",
+    "import re\n",
+    "\n",
     "import numpy as np\n",
     "import matplotlib.pyplot as plt\n",
     "from matplotlib.colors import LogNorm\n",
     "from threadpoolctl import threadpool_limits\n",
     "\n",
-    "import re\n",
     "import h5py\n",
-    "from pathlib import Path\n",
-    "from datetime import datetime\n",
     "\n",
     "import pasha as psh\n",
     "from euxfel_bunch_pattern import indices_at_sase, indices_at_laser\n",
@@ -432,10 +434,10 @@
     "\n",
     "    if len(pulse_deltas) > 1:\n",
     "        delta_str = ', '.join([str(x) for x in sorted(pulse_deltas)])\n",
-    "        print(f'WARNING: Different pulse lengths (PPT: {delta_str}) encountered within single trains, '\n",
-    "              f'separated pulse spectra may split up signals!')\n",
+    "        warning(f'Different pulse lengths (PPT: {delta_str}) encountered within single trains, '\n",
+    "                f'separated pulse spectra may split up signals!')\n",
     "    else:\n",
-    "        print('WARNING: Different pulse lengths encountered across trains, separation may be unstable!')"
+    "        warning('Different pulse lengths encountered across trains, separation may be unstable!')"
    ]
   },
   {
@@ -569,6 +571,9 @@
     "    with timing(f'find_edges, {det_name}'):\n",
     "        psh.map(find_edges, dc.select(det_sourcekeys))\n",
     "    \n",
+    "    if not np.isfinite(edges).any():\n",
+    "        warning(f'No edges found for {det_name}')\n",
+    "    \n",
     "    fig, (ux, bx) = plt.subplots(num=110+i, ncols=1, nrows=2, figsize=(9.5, 8), clear=True,\n",
     "                                 gridspec_kw=dict(left=0.1, right=0.98, top=0.98, bottom=0.1, hspace=0.25))\n",
     "    \n",
@@ -663,6 +668,7 @@
     "        \n",
     "        finite_edges = np.isfinite(edges[:, signal_idx, 0])\n",
     "        if not finite_edges.any():\n",
+    "            warning(f'No edges found for {det_name}/{signal_name}')\n",
     "            continue\n",
     "            \n",
     "        pulse_idx = np.uint64(finite_edges.nonzero()[0][0])  # Is combined with other uint64 values below.\n",
@@ -732,9 +738,23 @@
     "    hist_axs = []\n",
     "\n",
     "    for edge_idx, edge_name in enumerate(['u1', 'u2', 'v1', 'v2', 'w1', 'w2', 'mcp']):\n",
+    "        if edge_idx < 6:\n",
+    "            row = 1 + edge_idx % 2\n",
+    "            col = edge_idx // 2\n",
+    "        else:\n",
+    "            row = 0\n",
+    "            col = np.s_[1:3]\n",
+    "\n",
+    "        ax = fig.add_subplot(grid[row, col])\n",
+    "        ax.set_title(f'TOF spectrum: {edge_name}')\n",
+    "        \n",
     "        num_edges = np.isfinite(edges[:, edge_idx, :]).sum(axis=1)\n",
     "        num_edges = num_edges[:((len(num_edges) // agg_window) * agg_window)]\n",
     "        num_edges = num_edges.reshape(-1, agg_window).mean(axis=1)\n",
+    "        \n",
+    "        if (num_edges == 0).all():\n",
+    "            warning(f'No edges found for {det_name}/{edge_name}')\n",
+    "            continue\n",
     "\n",
     "        if edge_idx < 6:\n",
     "            plot_kwargs = dict(c=f'C{edge_idx}', ls='solid', lw=1.0)\n",
@@ -744,15 +764,6 @@
     "        numx.plot(np.arange(len(num_edges)) * agg_window, num_edges, label=edge_name, **plot_kwargs)\n",
     "        max_num_edges = max(max_num_edges, num_edges.max())\n",
     "\n",
-    "        if edge_idx < 6:\n",
-    "            row = 1 + edge_idx % 2\n",
-    "            col = edge_idx // 2\n",
-    "        else:\n",
-    "            row = 0\n",
-    "            col = np.s_[1:3]\n",
-    "\n",
-    "        ax = fig.add_subplot(grid[row, col])\n",
-    "        ax.set_title(f'TOF spectrum: {edge_name}')\n",
     "        y, _, _ = ax.hist(finite_flattened_slice(edges, np.s_[:, edge_idx, :]),\n",
     "                          bins=int((max_edge - min_edge) // 5), range=(min_edge, max_edge),\n",
     "                          color=plot_kwargs['c'], histtype='step', linewidth=1)\n",
@@ -794,6 +805,10 @@
     "    is_valid = remi.get_presort_mask(edges, edge_idx=0, w=not quad_anode,\n",
     "                                     sum_limit=max(sort.uncorrected_time_sum_half_widths),\n",
     "                                     sum_shifts=sum_shifts)\n",
+    "    \n",
+    "    if not is_valid.any():\n",
+    "        warning(f'No valid preliminary edge combinations found for {det_name}')\n",
+    "    \n",
     "    signals, sums = remi.get_signals_and_sums(edges, indices=sort.channel_indices, sum_shifts=sum_shifts,\n",
     "                                              mask=is_valid)\n",
     "    fig = plot_detector_diagnostics(signals=signals, sums=sums, fig_num=30+i, im_scale=1.5,\n",
@@ -933,6 +948,10 @@
     "                           gridspec_kw=dict(left=0.08, right=0.91, top=0.8))\n",
     "    \n",
     "    fig.text(0.02, 0.98, det_name.upper(), rotation=90, ha='left', va='top', size='x-large')\n",
+    "    \n",
+    "    if not (hits['m'] >= 0).any():\n",
+    "        warning(f'No hits found for {det_name}')\n",
+    "        continue\n",
     "\n",
     "    method_bins = np.bincount(hits['m'][hits['m'] >= 0], minlength=20)\n",
     "    ax.bar(np.arange(20), method_bins, width=0.5)\n",
@@ -1004,8 +1023,8 @@
     "for i, det_name in enumerate(remi['detector'].keys()):\n",
     "    flat_hits = det_data[det_name]['hits'].reshape(-1)\n",
     "    flat_hits = flat_hits[np.isfinite(flat_hits[:]['x'])]\n",
-    "    flat_hits = flat_hits[flat_hits['m'] < 10]\n",
-    "\n",
+    "    flat_hits = flat_hits[flat_hits['m'] <= 10]\n",
+    "    \n",
     "    fig = plt.figure(num=70+i, figsize=(9, 13.5))\n",
     "    \n",
     "    fig.text(0.02, 0.98, det_name.upper(), rotation=90, ha='left', va='top', size='x-large')\n",
@@ -1015,10 +1034,11 @@
     "    txp = fig.add_axes([0.1, 0.28, 0.85, 0.22])\n",
     "    typ = fig.add_axes([0.1, 0.04, 0.85, 0.22])\n",
     "    \n",
-    "    im_radius = remi['detector'][det_name]['mcp_radius']*1.1\n",
+    "    if flat_hits.size == 0:\n",
+    "        warning(f'No hits found for {det_name}')\n",
+    "        continue\n",
     "    \n",
-    "    min_tof = flat_hits['t'].min()\n",
-    "    max_tof = flat_hits['t'].max()\n",
+    "    im_radius = remi['detector'][det_name]['mcp_radius']*1.1\n",
     "    \n",
     "    imp.hist2d(flat_hits['x'], flat_hits['y'], bins=(256, 256),\n",
     "               range=[[-im_radius, im_radius], [-im_radius, im_radius]], norm=LogNorm())\n",
@@ -1027,9 +1047,18 @@
     "    imp.set_ylabel('Y / mm')\n",
     "    imp.tick_params(right=True, labelright=True, top=True, labeltop=True)\n",
     "    imp.grid()\n",
+    "    \n",
+    "    min_tof = flat_hits['t'].min()\n",
+    "    max_tof = flat_hits['t'].max()\n",
+    "    \n",
+    "    num_tof_bins = int((max_tof - min_tof) // 5)\n",
+    "    \n",
+    "    if num_tof_bins == 0:\n",
+    "        warning(f'All TOFs limited to single bin for {det_name}')\n",
+    "        continue\n",
     "\n",
     "    for ax, dim_label in zip([txp, typ], ['x', 'y']):\n",
-    "        ax.hist2d(flat_hits['t'], flat_hits[dim_label], bins=(int((max_tof - min_tof) // 5), 256),\n",
+    "        ax.hist2d(flat_hits['t'], flat_hits[dim_label], bins=(num_tof_bins, 256),\n",
     "                   range=[[min_tof, max_tof], [-im_radius, im_radius]], norm=LogNorm())\n",
     "        ax.set_ylabel(f'{dim_label.upper()} / mm')\n",
     "        \n",
-- 
GitLab