From f94ab9e11b5d16b826eea39ae3f6a6ad8e34619a Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Thu, 13 Oct 2022 09:38:17 +0200
Subject: [PATCH] Generalize FEL/PPL pattern support and remove edge trigger
 detection

---
 .../REMI/REMI_Digitize_and_Transform.ipynb    | 249 ++++++++----------
 1 file changed, 106 insertions(+), 143 deletions(-)

diff --git a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
index 363ee1ae3..ae1f68e7b 100644
--- a/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
+++ b/notebooks/REMI/REMI_Digitize_and_Transform.ipynb
@@ -38,10 +38,14 @@
     "dataset_compression_opts = 3  # HDF GZIP compression level.\n",
     "\n",
     "# Trigger parameters.\n",
-    "ppl_offset = 0  # In units of the FEL pulses.\n",
+    "ppt_source = 'SQS_RR_UTC/TSYS/TIMESERVER:outputBunchPattern'\n",
+    "ppl_offset = 0  # In units of the PPT.\n",
+    "laser_ppt_mask = -1  # Bit mask for used laser, negative to auto-detect from instrument. \n",
+    "instrument_sase = 3\n",
+    "first_pulse_offset = 1000\n",
+    "single_pulse_length = 25000\n",
     "\n",
     "# Parallelization parameters.\n",
-    "mp_pulse_info = 8  # Parallelization for pulse statistics.\n",
     "mp_find_triggers = 0.5  # Parallelization for finding triggers.\n",
     "mp_find_edges = 0.5  # Parallelization for digitizing analog signal.\n",
     "mt_avg_trace = 2  # Parallelization for trace averaging.\n",
@@ -65,6 +69,7 @@
     "from datetime import datetime\n",
     "\n",
     "import pasha as psh\n",
+    "from euxfel_bunch_pattern import indices_at_sase, indices_at_laser\n",
     "from extra_data import RunDirectory\n",
     "from extra_remi import Analysis, trigger_dt\n",
     "from extra_remi.util import timing\n",
@@ -147,8 +152,54 @@
     "# * `pulse_offsets [int32: len(dc.train_ids)]` containing the global offset for the first pulse of each train.\n",
     "# * `num_pulses = pulse_counts.sum(axis=0)`\n",
     "\n",
+    "ppt_data = dc[ppt_source, 'data.bunchPatternTable']\n",
+    "\n",
+    "def get_pulse_positions(ppt, sase, laser, ppl_offset):\n",
+    "    # Combine FEL and PPL positions.\n",
+    "\n",
+    "    fel_pos = indices_at_sase(ppt, sase)\n",
+    "\n",
+    "    if len(fel_pos) > 0:\n",
+    "        ppl_pos = indices_at_laser(ppt, laser) + fel_pos[0] + ppl_offset\n",
+    "    else:\n",
+    "        # Just take them as they are\n",
+    "        ppl_pos = indices_at_laser(ppt, laser)\n",
+    "\n",
+    "    return np.union1d(fel_pos, ppl_pos), fel_pos, ppl_pos\n",
+    "\n",
+    "if laser_ppt_mask < 0:\n",
+    "    # If laser PPT mask is not specified, try to figure it out from device IDs.\n",
+    "    from euxfel_bunch_pattern import PPL_BITS\n",
+    "    \n",
+    "    instrument = karabo_id[:karabo_id.index('_')]\n",
+    "    \n",
+    "    try:\n",
+    "        laser_ppt_mask = PPL_BITS[f'LP_{instrument}']\n",
+    "    except KeyError:\n",
+    "        raise ValueError(f'Laser PPT mask unknown for instrument `{instrument}`')\n",
+    "\n",
     "with timing('pulse_info'):\n",
-    "    pulse_counts, pulse_offsets, num_pulses = remi.get_pulse_info(dc, ppl_offset, mp_pulse_info)"
+    "    psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_find_triggers))\n",
+    "    \n",
+    "    # Build the pulse index\n",
+    "    pulse_counts = psh.alloc(shape=len(dc.train_ids), dtype=np.uint64)\n",
+    "    has_ppt = psh.alloc(shape=len(dc.train_ids), dtype=bool, fill=False)\n",
+    "    \n",
+    "    def count_pulses(wid, index, tid, ppt):\n",
+    "        pulse_counts[index] = len(get_pulse_positions(ppt, instrument_sase, laser_ppt_mask, ppl_offset)[0])\n",
+    "        has_ppt[index] = True\n",
+    "    \n",
+    "    psh.map(count_pulses, ppt_data)\n",
+    "\n",
+    "    # Fill any missing values with the highest.\n",
+    "    pulse_counts[has_ppt == False] = pulse_counts.max()\n",
+    "\n",
+    "    # Compute offsets based on pulse counts.\n",
+    "    pulse_offsets = np.zeros_like(pulse_counts)\n",
+    "    np.cumsum(pulse_counts[:-1], out=pulse_offsets[1:])\n",
+    "\n",
+    "    # Total number of pulses.\n",
+    "    num_pulses = int(pulse_counts.sum())"
    ]
   },
   {
@@ -190,156 +241,52 @@
     "# * `triggers [(start: int32, stop: int32, offset: float64, fel: bool, ppl: bool): num_pulses]`\n",
     "#   containing the triggers for each pulse.\n",
     "# \n",
-    "# Triggers may be obtained through two different methods:\n",
-    "# \n",
-    "# * `ppt` uses the pulse puttern table to locate the pulse positions on the trace. Only number of pulses and\n",
-    "#   their distance can be drawn this way, leaving the absolute offset for the very first pulse to be\n",
-    "#   configured via `trigger/ppt/first_pulse_offset`. If a PPL laser is used, it will be included in the\n",
-    "#   trigger pattern. The ppt_offset parameter allows taking into account an offset betwen PPL and FEL assuming\n",
-    "#   the same rep rate.\n",
-    "# \n",
-    "# * `edge` uses the digitizer channel `trigger/edge/channel` and builds triggers around the edges found on it.\n",
-    "#   The boundaries relative to this edge may be configured with the `group_start`, `group_end` and `dead_time`\n",
-    "#   parameters. There is no support for PPL with this method.\n",
+    "# This uses the pulse puttern table to locate the pulse positions on the trace. Only number of pulses and\n",
+    "# their distance can be drawn this way, leaving the absolute offset for the very first pulse to be\n",
+    "# configured via `trigger/ppt/first_pulse_offset`. If a PPL is used, it will be included in the trigger\n",
+    "# pattern. The ppt_offset parameter allows taking into account an offset betwen PPL and FEL.\n",
     "\n",
     "psh.set_default_context('processes', num_workers=remi.get_num_workers(mp_find_triggers))\n",
     "\n",
     "triggers = psh.alloc(shape=(num_pulses,), dtype=trigger_dt, fill=(-1, -1, np.nan, 0, 0))\n",
     "\n",
-    "if remi['trigger']['method'] == 'ppt':\n",
-    "    from euxfel_bunch_pattern import indices_at_sase, indices_at_laser, PPL_BITS\n",
-    "    \n",
-    "    pptc = remi['trigger']['ppt']\n",
-    "    \n",
-    "    keydata = dc[remi.get_ppt_sourcekey()]\n",
-    "    sase = remi['instrument']['timeserver']['sase']\n",
-    "    first_pulse_offset = pptc['first_pulse_offset']\n",
-    "    single_pulse_length = pptc['single_pulse_length']\n",
-    "    clock_factor = remi['digitizer']['clock_factor']\n",
-    "    \n",
-    "    def trigger_by_ppt(worker_id, index, train_id, ppt):\n",
-    "        abs_pos = indices_at_sase(ppt, sase)\n",
-    "        num_pulses = len(abs_pos)\n",
-    "        \n",
-    "        if num_pulses > 1:\n",
-    "            rel_pos = (abs_pos - abs_pos[0])\n",
-    "            pulse_len = rel_pos[1] - rel_pos[0]\n",
-    "        elif num_pulses == 1:\n",
-    "            rel_pos = np.zeros(1)\n",
-    "            pulse_len = single_pulse_length\n",
-    "        elif num_pulses == 0:\n",
-    "            return\n",
-    "        \n",
-    "        pulse_offset = pulse_offsets[index]\n",
-    "        pulse_count = pulse_counts[index]\n",
-    "        \n",
-    "        if ppl_offset != 0:\n",
-    "            # Special support for asymmetric pump-probe patterns.\n",
-    "            # For now only offsets as a integer multiple of the FEL pulse length are supported.\n",
-    "            # The relative position array is extended by the number of required elements and\n",
-    "            # the FEL/PPL flags set later on.\n",
-    "            \n",
-    "            abs_offset = abs(ppl_offset)\n",
-    "            rel_pos = np.append(rel_pos, rel_pos[-1] + np.arange(abs_offset) * pulse_len)\n",
-    "        \n",
-    "        train_triggers = triggers[pulse_offset:pulse_offset+pulse_count]\n",
-    "        \n",
-    "        start_frac = first_pulse_offset + rel_pos * 2 * clock_factor\n",
-    "        start = start_frac.astype(int)\n",
-    "        \n",
-    "        if start.shape != train_triggers.shape:\n",
-    "            print(f'pulse number mismatch in train {index} / {train_id}, SKIPPING')\n",
-    "            return\n",
-    "        \n",
-    "        train_triggers['start'] = start\n",
-    "        train_triggers['stop'] = start + int(pulse_len * 2 * clock_factor) - 1\n",
-    "        train_triggers['offset'] = start_frac - start\n",
-    "        \n",
-    "        has_ppl = len(indices_at_laser(ppt, PPL_BITS.LP_SQS)) > 0\n",
-    "        \n",
-    "        if ppl_offset == 0:\n",
-    "            train_triggers['fel'] = True\n",
-    "            train_triggers['ppl'] = has_ppl\n",
-    "        elif ppl_offset < 0:\n",
-    "            ## There are PPL-only pulses before FEL pulses.\n",
-    "            train_triggers[abs_offset:]['fel'] = True\n",
-    "            train_triggers[:-abs_offset]['ppl'] = has_ppl\n",
-    "        elif ppl_offset > 0:\n",
-    "            # There are PPL-only pulses after FEL pulses.\n",
-    "            train_triggers[:-abs_offset]['fel'] = True\n",
-    "            train_triggers[-abs_offset:]['ppl'] = has_ppl\n",
-    "    \n",
-    "    with timing('find_triggers'):\n",
-    "        psh.map(trigger_by_ppt, keydata)\n",
-    "    \n",
-    "elif remi['trigger']['method'] == 'edge':\n",
-    "    edgec = remi['trigger']['edge']\n",
-    "    keydata = dc[remi.get_channel_sourcekey(edgec['channel'])]\n",
-    "\n",
-    "    trace_len = keydata.entry_shape[0]\n",
-    "    group_start = edgec['group_start']\n",
-    "    group_end = edgec['group_end']\n",
-    "    dead_time = edgec['dead_time']\n",
-    "    \n",
-    "    def prepare_trigger_edge_worker(worker_id):\n",
-    "        correct_func = remi.get_baseline_corrector()\n",
-    "        discr_func, discr_params = remi.get_discriminator([edgec['channel']])\n",
+    "clock_factor = remi['digitizer']['clock_factor']\n",
     "\n",
-    "        bl_start, bl_stop, _ = remi.get_baseline_limits(trace_len)\n",
-    "        bl_sym = remi['digitizer']['baseline_symmetry']\n",
-    "        \n",
-    "        edge_pos = np.empty(10000, dtype=np.float64)\n",
-    "        trace_corr = np.empty(trace_len, dtype=np.float64)\n",
-    "        baselines = np.empty(bl_sym, dtype=np.float64)\n",
-    "        yield\n",
-    "        \n",
-    "    def group_boundaries(trigger_edges):\n",
-    "        cur_edge = trigger_edges[0]\n",
+    "def trigger_by_ppt(worker_id, index, train_id, ppt):\n",
+    "    all_pos, fel_pos, ppl_pos = get_pulse_positions(ppt, instrument_sase, laser_ppt_mask, ppl_offset)\n",
     "\n",
-    "        for i in range(1, len(trigger_edges)):\n",
-    "            next_edge = trigger_edges[i]\n",
-    "            edge_diff = int(next_edge) - int(cur_edge)\n",
+    "    rel_pos = all_pos - all_pos[0]\n",
+    "    num_pulses = len(all_pos)\n",
     "\n",
-    "            if edge_diff <= dead_time:\n",
-    "                pass\n",
+    "    if num_pulses > 1:\n",
+    "        pulse_lengths = np.unique(rel_pos[1:] - rel_pos[:-1])\n",
     "\n",
-    "            elif edge_diff > dead_time and edge_diff >= group_end:\n",
-    "                yield cur_edge, int(cur_edge) + group_start, int(cur_edge) + group_end\n",
-    "                cur_edge = trigger_edges[i]\n",
+    "        if len(pulse_lengths) > 1:\n",
+    "            print('WARNING: Differing pulse lengths encountered, minimum is used!')\n",
     "\n",
-    "            elif edge_diff > dead_time and edge_diff < group_end:\n",
-    "                yield cur_edge, int(cur_edge) + group_start, int(next_edge)\n",
-    "                cur_edge = trigger_edges[i]\n",
+    "        pulse_len = pulse_lengths.min()\n",
     "\n",
-    "            elif edge_diff < group_end:\n",
-    "                pass\n",
+    "    elif num_pulses == 1:\n",
+    "        pulse_len = single_pulse_length\n",
     "\n",
-    "        yield cur_edge, int(cur_edge) + group_start, int(cur_edge) + group_end\n",
-    "    \n",
-    "    @psh.with_init(prepare_trigger_edge_worker)\n",
-    "    def trigger_by_edge(worker_id, index, train_id, trace_raw):\n",
-    "        correct_func(trace_raw, trace_corr, baselines, bl_start, bl_stop)\n",
+    "    elif num_pulses == 0:\n",
+    "        return\n",
     "\n",
-    "        pulse_offset = pulse_offsets[index]\n",
-    "        pulse_count = pulse_counts[index]\n",
-    "        num_triggers = discr_func(trace_corr, edge_pos, **discr_params[0])\n",
+    "    start_frac = first_pulse_offset + rel_pos * 2 * clock_factor\n",
+    "    start_int = start_frac.astype(int)\n",
     "\n",
-    "        groups = group_boundaries(edge_pos[:num_triggers])\n",
-    "        train_triggers = triggers[pulse_offset:pulse_offset+pulse_count]\n",
+    "    pulse_offset = pulse_offsets[index]\n",
+    "    pulse_count = pulse_counts[index]\n",
     "        \n",
-    "        if num_triggers == 0 or num_triggers != pulse_count:\n",
-    "            print(f'index={index}, train_id={train_id}: Unexpected '\n",
-    "                  f'num_triggers={num_triggers} for pulse_count={pulse_count}')\n",
-    "            return\n",
+    "    train_triggers = triggers[pulse_offset:pulse_offset+pulse_count]\n",
+    "    train_triggers['start'] = start_int\n",
+    "    train_triggers['stop'] = start_int + int(pulse_len * 2 * clock_factor) - 1\n",
+    "    train_triggers['offset'] = start_frac - start_int\n",
+    "    train_triggers['fel'] = [pos in fel_pos for pos in all_pos]\n",
+    "    train_triggers['ppl'] = [pos in ppl_pos for pos in all_pos]\n",
     "\n",
-    "        for (edge, start, stop), pulse_trigger in zip(groups, train_triggers):\n",
-    "            pulse_trigger['start'] = start\n",
-    "            pulse_trigger['stop'] = stop\n",
-    "            pulse_trigger['offset'] = start - edge\n",
-    "            pulse_trigger['fel'] = True\n",
-    "    \n",
-    "    with timing('find_triggers'):\n",
-    "        psh.map(trigger_by_edge, keydata)"
+    "with timing('find_triggers'):\n",
+    "    psh.map(trigger_by_ppt, ppt_data)"
    ]
   },
   {
@@ -349,22 +296,35 @@
    "outputs": [],
    "source": [
     "fig, (lx, rx) = plt.subplots(num=2, ncols=2, nrows=1, figsize=(9, 4), clear=True,\n",
-    "                       gridspec_kw=dict(top=0.75))\n",
+    "                             gridspec_kw=dict(top=0.75))\n",
     "\n",
     "# Display ~400 pulses or 10 trains, whatever is lower\n",
-    "n_trains = max(abs(pulse_offsets - 400).argmin(), 10)\n",
+    "n_trains = max(abs(pulse_offsets - 200).argmin(), 5)\n",
+    "\n",
+    "visible_triggers = triggers[:pulse_offsets[n_trains]] \n",
     "\n",
-    "visible_trigger_starts = triggers['start'][:pulse_offsets[n_trains]]\n",
+    "pulse_index = np.arange(len(visible_triggers))\n",
+    "pumped = visible_triggers['fel'] & visible_triggers['ppl']\n",
+    "fel_only = visible_triggers['fel'] & ~pumped\n",
+    "ppl_only = visible_triggers['ppl'] & ~pumped\n",
     "\n",
-    "lx.plot(visible_trigger_starts, '.', ms=2)\n",
-    "lx.vlines(pulse_offsets[:n_trains], 0, visible_trigger_starts.max(), color='grey', linewidth=1, alpha=0.2)\n",
+    "lx.plot(pulse_index[pumped], visible_triggers[pumped]['start'], ' .', ms=3, c='C0', label='FEL+PPL')\n",
+    "lx.plot(pulse_index[fel_only], visible_triggers[fel_only]['start'], '.', ms=3, c='C1', label='FEL-only')\n",
+    "lx.plot(pulse_index[ppl_only], visible_triggers[ppl_only]['start'], '.', ms=2, c='C2', label='PPL-only')\n",
+    "\n",
+    "max_start = visible_triggers['start'].max()\n",
+    "\n",
+    "lx.vlines(pulse_offsets[:n_trains], 0, max_start, color='grey', linewidth=1, alpha=0.2)\n",
     "lx.tick_params(right=True)\n",
+    "lx.set_ylim(-1, max_start+1)\n",
     "\n",
     "lx.set_xlabel('Pulse index')\n",
     "lx.set_xlim(-15, pulse_offsets[n_trains]+15)\n",
     "\n",
     "lx.set_ylabel('Trigger position')\n",
-    "lx.set_ylim(0, visible_trigger_starts.max())\n",
+    "lx.set_ylim(0, max_start)\n",
+    "\n",
+    "lx.legend(fontsize='small', loc='lower right')\n",
     "\n",
     "train_lx = lx.twiny()\n",
     "train_lx.set_xlabel('Train ID', labelpad=8)\n",
@@ -374,7 +334,10 @@
     "                         rotation=-45, fontsize='x-small')\n",
     "\n",
     "rx.plot(triggers['start'], lw=0.2)\n",
+    "\n",
+    "rx.set_xlabel('Pulse index')\n",
     "rx.tick_params(left=False, labelleft=False, right=True, labelright=True)\n",
+    "\n",
     "pass"
    ]
   },
-- 
GitLab