From 0229c1b9aa0cf887153e999f1629f30bc7fb9235 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Mon, 12 Dec 2022 11:34:29 +0100
Subject: [PATCH] Move load_bpt and get_unique_sase_pId to bunchpattern.py, add
 npulses_has_changed() and get_sase_pId() functions

---
 src/toolbox_scs/load.py               |  61 +++-------
 src/toolbox_scs/misc/bunch_pattern.py | 157 +++++++++++++++++++++++++-
 2 files changed, 173 insertions(+), 45 deletions(-)

diff --git a/src/toolbox_scs/load.py b/src/toolbox_scs/load.py
index a5bec21..0ce3c1c 100644
--- a/src/toolbox_scs/load.py
+++ b/src/toolbox_scs/load.py
@@ -20,6 +20,8 @@ from .constants import mnemonics as _mnemonics
 from .mnemonics_machinery import mnemonics_for_run
 from .util.exceptions import ToolBoxValueError
 import toolbox_scs.detectors as tbdet
+from .misc.bunch_pattern import (npulses_has_changed,
+                                 get_unique_sase_pId, load_bpt)
 
 __all__ = [
     'concatenateRuns',
@@ -55,7 +57,7 @@ def load(proposalNB=None, runNB=None,
          laser_bp=None,
          ):
     """
-    Load a run and extract the data. Output is an xarray with aligned    
+    Load a run and extract the data. Output is an xarray with aligned
     trainIds.
 
     Parameters
@@ -97,9 +99,9 @@ def load(proposalNB=None, runNB=None,
         'FastADC3peaks') and aligns the pulse Id according to the fadc_bp bunch
         pattern.
     extract_fadc2: bool
-        If True, extracts the peaks from FastADC2 variables (e.g. 'FastADC2_5raw',
-        'FastADC2_3peaks') and aligns the pulse Id according to the fadc2_bp bunch
-        pattern.
+        If True, extracts the peaks from FastADC2 variables (e.g.
+        'FastADC2_5raw', 'FastADC2_3peaks') and aligns the pulse Id according
+        to the fadc2_bp bunch pattern.
     extract_xgm: bool
         If True, extracts the values from XGM variables (e.g. 'SCS_SA3',
         'XTD10_XGM') and aligns the pulse Id with the sase1 / sase3 bunch
@@ -153,8 +155,10 @@ def load(proposalNB=None, runNB=None,
     data_arrays = []
     run_mnemonics = mnemonics_for_run(run)
     # load pulse pattern info only if number of sase 3 pulses changed
-    sase3, sase3_changed = get_sase_pId(run)
-    if sase3_changed:
+    sase3 = None
+    if npulses_has_changed(run, run_mnemonics=run_mnemonics) is False:
+        sase3 = get_unique_sase_pId(run, run_mnemonics=run_mnemonics)
+    else:
         log.warning('Number of pulses changed during the run. '
                     'Loading bunch pattern table.')
         bpt = load_bpt(run, run_mnemonics=run_mnemonics)
@@ -220,7 +224,7 @@ def load(proposalNB=None, runNB=None,
 
     data = xr.merge(data_arrays, join='inner')
     data.attrs['runFolder'] = runFolder
-    
+
     # backward compatibility with old-defined variables:
     if extract_tim is not None:
         extract_adq412 = extract_tim
@@ -230,13 +234,14 @@ def load(proposalNB=None, runNB=None,
         adq412_bp = tim_bp
     if laser_bp is not None:
         fadc_bp = laser_bp
-    
+
     adq412 = [k for k in run_mnemonics if 'MCP' in k and k in data]
     if extract_adq412 and len(adq412) > 0:
-        data = tbdet.get_digitizer_peaks(run, mnemonics=adq412, merge_with=data,
-                                        bunchPattern=adq412_bp)
+        data = tbdet.get_digitizer_peaks(run, mnemonics=adq412,
+                                         merge_with=data,
+                                         bunchPattern=adq412_bp)
 
-    fadc = [k for k in run_mnemonics if ('FastADC' in k) 
+    fadc = [k for k in run_mnemonics if ('FastADC' in k)
             and ('FastADC2_' not in k) and (k in data)]
     if extract_fadc and len(fadc) > 0:
         data = tbdet.get_digitizer_peaks(run, mnemonics=fadc, merge_with=data,
@@ -251,7 +256,8 @@ def load(proposalNB=None, runNB=None,
            'SCS_SA1', 'SCS_SA1_sigma', 'SCS_SA3', 'SCS_SA3_sigma']
     xgm = [k for k in xgm if k in data]
     if extract_xgm and len(xgm) > 0:
-        data = tbdet.get_xgm(run, mnemonics=xgm, merge_with=data)
+        data = tbdet.get_xgm(run, mnemonics=xgm, merge_with=data,
+                             sase3=sase3)
 
     bam = [k for k in run_mnemonics if 'BAM' in k and k in data]
     if extract_bam and len(bam) > 0:
@@ -492,34 +498,3 @@ def concatenateRuns(runs):
     for k in orderedRuns[0].attrs.keys():
         result.attrs[k] = [run.attrs[k] for run in orderedRuns]
     return result
-
-
-def load_bpt(run, merge_with=None, run_mnemonics=None):
-    if run_mnemonics is None:
-        run_mnemonics = mnemonics_for_run(run)
-
-    for key in ['bunchPatternTable', 'bunchPatternTable_SA3']:
-        if bool(merge_with) and key in merge_with:
-            log.debug(f'Using {key} from merge_with dataset.')
-            return merge_with[key]
-        if key in run_mnemonics:
-            bpt = run.get_array(*run_mnemonics[key].values(),
-                                name='bunchPatternTable')
-            log.debug(f'Loaded {key} from DataCollection.')
-            return bpt
-    log.debug('Could not find bunch pattern table.')
-    return None
-
-
-def get_sase_pId(run, sase='sase3'):
-    mnemonics = mnemonics_for_run(run)
-    if sase not in mnemonics:
-        # bunch pattern not recorded
-        return [], True
-    npulse_sase = np.unique(get_array(run, 'npulses_' + sase))
-    if len(npulse_sase) == 1:
-        return np.unique(load_run_values(run)[sase])[1:], False
-    # number of pulses changed during the run
-    return np.unique(get_array(run, sase))[1:], True
-
-
diff --git a/src/toolbox_scs/misc/bunch_pattern.py b/src/toolbox_scs/misc/bunch_pattern.py
index 90deb34..34ba18c 100644
--- a/src/toolbox_scs/misc/bunch_pattern.py
+++ b/src/toolbox_scs/misc/bunch_pattern.py
@@ -1,12 +1,14 @@
 # -*- coding: utf-8 -*-
 """ Toolbox for SCS.
 
-    Various utilities function to quickly process data measured at the SCS instruments.
+    Various utility functions to quickly process data
+    measured at the SCS instruments.
 
     Copyright (2019) SCS Team.
 """
 
 import os
+import logging
 
 import numpy as np
 import xarray as xr
@@ -15,13 +17,164 @@ from extra_data import RunDirectory
 
 # import and hide variable, such that it does not alter namespace.
 from ..constants import mnemonics as _mnemonics_bp
+from ..mnemonics_machinery import mnemonics_for_run
+from .bunch_pattern_external import is_pulse_at
 
 __all__ = [
     'extractBunchPattern',
+    'get_sase_pId',
+    'npulses_has_changed',
     'pulsePatternInfo',
-    'repRate'
+    'repRate',
 ]
 
+log = logging.getLogger(__name__)
+
+
+def npulses_has_changed(run, sase='sase3', run_mnemonics=None):
+    """
+    Checks if the number of pulses has changed during the run for
+    a specific location `sase` (='sase1', 'sase3', 'scs_ppl' or 'laser').
+    If the source is not found, returns True.
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+        DataCollection containing the data.
+    sase: str
+        The location where to check: {'sase1', 'sase3', 'scs_ppl'}
+    run_mnemonics: dict
+        the mnemonics for the run (see `mnemonics_for_run`)
+
+    Returns
+    -------
+    ret: bool
+        True if the number of pulses has changed or the source was not
+        found, False if the number of pulses did not change.
+    """
+    if run_mnemonics is None:
+        run_mnemonics = mnemonics_for_run(run)
+    if sase == 'scs_ppl':
+        sase = 'laser'
+    if sase not in run_mnemonics:
+        return True
+    npulses = run.get_array(*run_mnemonics['npulses_'+sase].values())
+    if len(np.unique(npulses)) == 1:
+        return False
+    return True
+
+
+def get_unique_sase_pId(run, sase='sase3', run_mnemonics=None):
+    """
+    Assuming that the number of pulses did not change during the run,
+    returns the pulse Ids as the run value of the sase mnemonic.
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+        DataCollection containing the data.
+    sase: str
+        The location where to check: {'sase1', 'sase3', 'scs_ppl'}
+    run_mnemonics: dict
+        the mnemonics for the run (see `mnemonics_for_run`)
+
+    Returns
+    -------
+    pulseIds: np.array
+        the pulse ids at the specified location. Returns None if the
+        mnemonic is not in the run.
+    """
+    if run_mnemonics is None:
+        run_mnemonics = mnemonics_for_run(run)
+    if sase == 'scs_ppl':
+        sase = 'laser'
+    if sase not in run_mnemonics:
+        # bunch pattern not recorded
+        return None
+    npulses = run.get_run_value(run_mnemonics['npulses_'+sase]['source'],
+                                run_mnemonics['npulses_'+sase]['key'])
+    pulseIds = run.get_run_value(run_mnemonics[sase]['source'],
+                                 run_mnemonics[sase]['key'])[:npulses]
+    return pulseIds
+
+
+def get_sase_pId(run, sase='sase3', run_mnemonics=None,
+                 bpt=None, merge_with=None):
+    """
+    Returns the pulse Ids of the specified `sase` during a run.
+    If the number of pulses has changed during the run, it loads the
+    bunch pattern table and extracts all pulse Ids used.
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+        DataCollection containing the data.
+    sase: str
+        The location where to check: {'sase1', 'sase3', 'scs_ppl'}
+    run_mnemonics: dict
+        the mnemonics for the run (see `mnemonics_for_run`)
+    bpt: 2D-array
+        The bunch pattern table. Used only if the number of pulses
+        has changed. If None, it is loaded on the fly.
+    merge_with: xarray.Dataset
+        dataset that may contain the bunch pattern table to use in
+        case the number of pulses has changed. If merge_with does
+        not contain the bunch pattern table, it is loaded and added
+        as a variable 'bunchPatternTable' to merge_with.
+
+    Returns
+    -------
+    pulseIds: np.array
+        the pulse ids at the specified location. Returns None if the
+        mnemonic is not in the run.
+    """
+    if npulses_has_changed(run, sase, run_mnemonics) is False:
+        return get_unique_sase_pId(run, sase, run_mnemonics)
+    if bpt is None:
+        bpt = load_bpt(run, merge_with, run_mnemonics)
+    if bpt is not None:
+        mask = is_pulse_at(bpt, sase)
+        return np.unique(np.nonzero(mask.values)[1])
+    return None
+
+
+def load_bpt(run, merge_with=None, run_mnemonics=None):
+    """
+    Load the bunch pattern table. It returns the one contained in
+    merge_with if possible. Or, it adds it to merge_with once it is
+    loaded.
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+        DataCollection containing the data.
+    merge_with: xarray.Dataset
+        dataset that may contain the bunch pattern table or to which
+        add the bunch pattern table once loaded.
+    run_mnemonics: dict
+        the mnemonics for the run (see `mnemonics_for_run`)
+
+    Returns
+    -------
+    bpt: xarray.Dataset
+        the bunch pattern table as specified by the mnemonics
+        'bunchPatternTable'
+    """
+    if run_mnemonics is None:
+        run_mnemonics = mnemonics_for_run(run)
+    for key in ['bunchPatternTable', 'bunchPatternTable_SA3']:
+        if merge_with is not None and key in merge_with:
+            log.debug(f'Using {key} from merge_with dataset.')
+            return merge_with[key]
+        if key in run_mnemonics:
+            bpt = run.get_array(*run_mnemonics[key].values(),
+                                name='bunchPatternTable')
+            log.debug(f'Loaded {key} from DataCollection.')
+            if merge_with is not None:
+                merge_with = merge_with.merge(bpt, join='inner')
+            return bpt
+    log.debug('Could not find bunch pattern table.')
+    return None
+
 
 def extractBunchPattern(bp_table=None, key='sase3', runDir=None):
     ''' generate the bunch pattern and number of pulses of a source directly from the
-- 
GitLab