From bb0c4ba17c0d874871b7f77613f9e2a7d98966aa Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Thu, 22 Dec 2022 10:58:17 +0100
Subject: [PATCH] rename sase to loc and make additional check in
 npulse_has_changed

---
 src/toolbox_scs/misc/bunch_pattern.py | 57 +++++++++++++++------------
 1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/src/toolbox_scs/misc/bunch_pattern.py b/src/toolbox_scs/misc/bunch_pattern.py
index 57a5f40..6431fd1 100644
--- a/src/toolbox_scs/misc/bunch_pattern.py
+++ b/src/toolbox_scs/misc/bunch_pattern.py
@@ -31,17 +31,17 @@ __all__ = [
 log = logging.getLogger(__name__)
 
 
-def npulses_has_changed(run, sase='sase3', run_mnemonics=None):
+def npulses_has_changed(run, loc='sase3', run_mnemonics=None):
     """
     Checks if the number of pulses has changed during the run for
-    a specific location `sase` (='sase1', 'sase3', 'scs_ppl' or 'laser')
+    a specific location `loc` (='sase1', 'sase3', 'scs_ppl' or 'laser')
     If the source is not found in the run, returns True.
 
     Parameters
     ----------
     run: extra_data.DataCollection
         DataCollection containing the data.
-    sase: str
+    loc: str
         The location where to check: {'sase1', 'sase3', 'scs_ppl'}
     run_mnemonics: dict
         the mnemonics for the run (see `menonics_for_run`)
@@ -53,22 +53,25 @@ def npulses_has_changed(run, sase='sase3', run_mnemonics=None):
         found, False if the number of pulses did not change.
     """
     sase_list = ['sase1', 'sase3', 'laser', 'scs_ppl']
-    if sase not in sase_list:
-        raise ValueError(f"Unknow sase location '{sase}'. Expected one in"
+    if loc not in sase_list:
+        raise ValueError(f"Unknow sase location '{loc}'. Expected one in"
                          f"{sase_list}")
     if run_mnemonics is None:
         run_mnemonics = mnemonics_for_run(run)
-    if sase == 'scs_ppl':
-        sase = 'laser'
-    if sase not in run_mnemonics:
+    if loc == 'scs_ppl':
+        loc = 'laser'
+    if loc not in run_mnemonics:
         return True
-    npulses = run.get_array(*run_mnemonics['npulses_'+sase].values())
+    if run_mnemonics[loc]['key'] not in run[run_mnemonics[loc]['source']].keys():
+        log.info(f'Mnemonic {loc} not found in run.')
+        return True
+    npulses = run.get_array(*run_mnemonics['npulses_'+loc].values())
     if len(np.unique(npulses)) == 1:
         return False
     return True
 
 
-def get_unique_sase_pId(run, sase='sase3', run_mnemonics=None):
+def get_unique_sase_pId(run, loc='sase3', run_mnemonics=None):
     """
     Assuming that the number of pulses did not change during the run,
     returns the pulse Ids as the run value of the sase mnemonic.
@@ -77,7 +80,7 @@ def get_unique_sase_pId(run, sase='sase3', run_mnemonics=None):
     ----------
     run: extra_data.DataCollection
         DataCollection containing the data.
-    sase: str
+    loc: str
         The location where to check: {'sase1', 'sase3', 'scs_ppl'}
     run_mnemonics: dict
         the mnemonics for the run (see `menonics_for_run`)
@@ -90,29 +93,30 @@ def get_unique_sase_pId(run, sase='sase3', run_mnemonics=None):
     """
     if run_mnemonics is None:
         run_mnemonics = mnemonics_for_run(run)
-    if sase == 'scs_ppl':
-        sase = 'laser'
-    if sase not in run_mnemonics:
+    if loc == 'scs_ppl':
+        loc = 'laser'
+    if loc not in run_mnemonics:
         # bunch pattern not recorded
         return None
-    npulses = run.get_run_value(run_mnemonics['npulses_'+sase]['source'],
-                                run_mnemonics['npulses_'+sase]['key'])
-    pulseIds = run.get_run_value(run_mnemonics[sase]['source'],
-                                 run_mnemonics[sase]['key'])[:npulses]
+    npulses = run.get_run_value(run_mnemonics['npulses_'+loc]['source'],
+                                run_mnemonics['npulses_'+loc]['key'])
+    pulseIds = run.get_run_value(run_mnemonics[loc]['source'],
+                                 run_mnemonics[loc]['key'])[:npulses]
     return pulseIds
 
 
-def get_sase_pId(run, sase='sase3', run_mnemonics=None,
+def get_sase_pId(run, loc='sase3', run_mnemonics=None,
                  bpt=None, merge_with=None):
     """
-    Returns the pulse Ids of the specified `sase` during a run.
+    Returns the pulse Ids of the specified `loc` during a run.
     If the number of pulses has changed during the run, it loads the
-    bunch pattern table and extract all pulse Ids used
+    bunch pattern table and extract all pulse Ids used.
+    
     Parameters
     ----------
     run: extra_data.DataCollection
         DataCollection containing the data.
-    sase: str
+    loc: str
         The location where to check: {'sase1', 'sase3', 'scs_ppl'}
     run_mnemonics: dict
         the mnemonics for the run (see `menonics_for_run`)
@@ -131,12 +135,12 @@ def get_sase_pId(run, sase='sase3', run_mnemonics=None,
         the pulse ids at the specified location. Returns None if the
         mnemonic is not in the run.
     """
-    if npulses_has_changed(run, sase, run_mnemonics) is False:
-        return get_unique_sase_pId(run, sase, run_mnemonics)
+    if npulses_has_changed(run, loc, run_mnemonics) is False:
+        return get_unique_sase_pId(run, loc, run_mnemonics)
     if bpt is None:
         bpt = load_bpt(run, merge_with, run_mnemonics)
     if bpt is not None:
-        mask = is_pulse_at(bpt, sase)
+        mask = is_pulse_at(bpt, loc)
         return np.unique(np.nonzero(mask.values)[1])
     return None
 
@@ -334,13 +338,16 @@ def repRate(data=None, runNB=None, proposalNB=None, key='sase3'):
     ''' Calculates the pulse repetition rate (in kHz) in sase
         according to the bunch pattern and assuming a grid of
         4.5 MHz.
+        
         Inputs:
+        -------
             data: xarray Dataset containing pulse pattern, needed if runNB is none
             runNB: int or str, run number. Needed if data is None
             proposal: int or str, proposal where to find the run. Needed if data is None
             key: str in [sase1, sase2, sase3, scs_ppl], source for which the
                  repetition rate is calculated
         Output:
+        -------
             f: repetition rate in kHz
     '''
     if runNB is None and data is None:
-- 
GitLab