From 5556a6351ab1a3b83c3afec2469f76a559f40ea2 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Fri, 13 Dec 2024 10:58:48 +0100
Subject: [PATCH] Add support to look behind the trigger for anode channels in
 REMI edge finding

---
 .../REMI/REMI_Digitize_and_Transform.ipynb    | 30 +++++++++++++------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
index c17e0fb58..b4b4b8d1f 100644
--- a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
+++ b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
@@ -36,6 +36,7 @@
     "ignore_fel = False  # Ignore any FEL entries in the PPT.\n",
     "ignore_ppl = False  # Ignore any PPL entries in the PPT.\n",
     "trailing_trigger = False  # Add a trigger after all regular pulses with the remaining trace.\n",
+    "anode_lookbehind = True  # Whether to include an entire wire runtime behind the pulse trigger. \n",
     "ppl_offset = 0  # In units of the PPT.\n",
     "laser_ppt_mask = -1  # Bit mask for used laser, negative to auto-detect from instrument. \n",
     "instrument_sase = 3  # Which SASE we're running at for PPT decoding.\n",
@@ -77,6 +78,7 @@
    "source": [
     "from datetime import datetime\n",
     "from logging import warning\n",
+    "from math import ceil\n",
     "from pathlib import Path\n",
     "import re\n",
     "\n",
@@ -531,7 +533,7 @@
     "\n",
     "        train_triggers = triggers[pulse_offsets[index]:int(pulse_offsets[index]+num_pulses)]\n",
     "        train_triggers['start'] = start_int + pulse_start_offset\n",
-    "        train_triggers['stop'] = start_int + int(pulse_len * 2 * clock_factor) - 1 + pulse_end_offset\n",
+    "        train_triggers['stop'] = start_int + int(pulse_len * 2 * clock_factor) + pulse_end_offset\n",
     "        train_triggers['offset'] = start_frac - start_int\n",
     "        train_triggers['pulse'] = all_pos.astype(np.int16)\n",
     "        train_triggers['fel'] = [pos in fel_pos for pos in all_pos]\n",
@@ -685,7 +687,16 @@
     "        bl_sym = remi['digitizer']['baseline_symmetry']\n",
     "        \n",
     "        time_cal = 1e9 / (2 * remi['digitizer']['clock_factor'] * (1.3e9 / 288))\n",
-    "        \n",
+    "\n",
+    "        if anode_lookbehind:\n",
+    "            # Include one additional wire runtime in the time window for the anode channels\n",
+    "            # (while none for the MCP channel) to successfully reconstruct hits at the boundary.\n",
+    "            # As the MCP signal always comes first, only a lookbehind is needed.\n",
+    "            lookbehinds = [ceil(remi['detector'][det_name]['runtimes'][i // 2] / time_cal)\n",
+    "                           for i in remi['detector'][det_name]['indices'][:6]] + [0]\n",
+    "        else:\n",
+    "            lookbehinds = [0] * 7\n",
+    "                               \n",
     "        traces_corr = np.empty((7, trace_len), dtype=np.float64)\n",
     "        baselines = np.empty(bl_sym, dtype=np.float64)\n",
     "        yield\n",
@@ -708,13 +719,11 @@
     "        for trigger, pulse_edges, pulse_amplitudes in zip(\n",
     "            triggers[pulses_slice], edges[pulses_slice], amplitudes[pulses_slice]\n",
     "        ):\n",
-    "            trigger_slice = np.s_[trigger['start']:trigger['stop']]\n",
-    "                                                 \n",
-    "            for trace, channel_params, channel_edges, channel_amplitudes in zip(\n",
-    "                traces_corr, discr_params, pulse_edges, pulse_amplitudes\n",
+    "            for trace, channel_params, channel_lookbehind, channel_edges, channel_amplitudes in zip(\n",
+    "                traces_corr, discr_params, lookbehinds, pulse_edges, pulse_amplitudes\n",
     "            ):\n",
-    "                discr_func(trace[trigger_slice], edges=channel_edges,\n",
-    "                           amplitudes=channel_amplitudes, **channel_params)\n",
+    "                discr_func(trace[trigger['start']:trigger['stop']+channel_lookbehind],\n",
+    "                           edges=channel_edges, amplitudes=channel_amplitudes, **channel_params)\n",
     "\n",
     "            if np.isfinite(pulse_edges).sum(axis=1).max() == det['max_hits']:\n",
     "                warning(f'Maximum number of edges reached in train {train_id}, pulse: {trigger[\"pulse\"]}')\n",
@@ -1097,6 +1106,9 @@
    "source": [
     "psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_rec_hits))\n",
     "\n",
+    "time_cal = 1e9 / (2 * remi['digitizer']['clock_factor'] * (1.3e9 / 288))\n",
+    "t_cutoffs = (triggers['stop'] - triggers['start']) * time_cal\n",
+    "\n",
     "for det_name, det in remi['detector'].items():\n",
     "    edges = det_data[det_name]['edges']\n",
     "    \n",
@@ -1111,7 +1123,7 @@
     "    @psh.with_init(prepare_hit_worker)\n",
     "    def reconstruct_hits(worker_id, index, train_id):\n",
     "        hit_counts[index] += sort.run_on_train(\n",
-    "            edges, signals, hits, pulse_offsets[index], pulse_offsets[index] + pulse_counts[index])\n",
+    "            edges, signals, hits, t_cutoffs, pulse_offsets[index], pulse_offsets[index] + pulse_counts[index])\n",
     "        \n",
     "    with timing(f'rec_hits, {det_name}'):\n",
     "        psh.map(reconstruct_hits, dc.train_ids)\n",
-- 
GitLab