From dbb705b5501647b44e3e863d916f3b3411f4f1bb Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Sun, 17 Nov 2019 06:43:56 +0100
Subject: [PATCH] Generalized extractSaseBunchPattern() and renamed it to
 extractBunchPattern()

---
 Load.py | 97 ++++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 58 insertions(+), 39 deletions(-)

diff --git a/Load.py b/Load.py
index d41edee..7b54417 100644
--- a/Load.py
+++ b/Load.py
@@ -375,55 +375,71 @@ mnemonics = {
                     'dim': ['gott_pId','pixelId']}
 }
 
-def extractSaseBunchPattern(runDir, sase=3):
-    ''' generate the "saseX" and "npulse_saseX" arrays directly from the bunch pattern
-        table and not using the MDL device BUNCH_DECODER. This is inspired from the 
-        euxfel_bunch_pattern project, 
-        https://git.xfel.eu/gitlab/karaboDevices/euxfel_bunch_pattern.git
+def extractBunchPattern(bp_table=None, key='sase3', runDir=None):
+    ''' generate the bunch pattern and number of pulses of a source directly from the
+        bunch pattern table and not using the MDL device BUNCH_DECODER. This is 
+        inspired by the euxfel_bunch_pattern package, 
+        https://git.xfel.eu/gitlab/karaboDevices/euxfel_bunch_pattern
         Inputs:
-            runDir: run directory obtained by karabo_data.runDirectory()
-            sase: int, sase number between 1 and 3
+            bp_table: DataArray corresponding to the mnemonics "bunchPatternTable".
+                      If None, the bunch pattern table is loaded using runDir.
+            key: str, ['sase1', 'sase2', 'sase3', 'scs_ppl']
+            runDir: karabo_data run directory. Required only if bp_table is None.
             
         Outputs:
-            sase: DataArray containing indices of the sase pulses for each train
-            npulses_sase: DataArray containing the number of pulses for each train
+            bunchPattern: DataArray containing indices of the sase/laser pulses for 
+            each train
+            npulses: DataArray containing the number of pulses for each train
                   
     '''
-    if not (1 <= sase <= 3):
-        raise ValueError("Invalid SASE value {!r}, expected 1-3")
-    # define relevant masks, see euxfel_bunch_pattern project for details
+    keys=['sase1', 'sase2', 'sase3', 'scs_ppl']
+    if key not in keys:
+        raise ValueError(f'Invalid key "{key}", possible values are {keys}')
+    if bp_table is None:
+        if runDir is None:
+            raise ValueError('bp_table and runDir cannot both be None')
+        bp_mnemo = mnemonics['bunchPatternTable']
+        if bp_mnemo['source'] not in runDir.all_sources:
+            raise ValueError('Source {} not found in run'.format(
+                                mnemonics['bunchPatternTable']['source']))
+        else:
+            bp_table = runDir.get_array(bp_mnemo['source'],bp_mnemo['key'], 
+                                        extra_dims=bp_mnemo['dim'])
+    # define relevant masks, see euxfel_bunch_pattern package for details
     DESTINATION_MASK = 0xf << 18
     DESTINATION_T4D = 4 << 18   # SASE1/3 dump
     DESTINATION_T5D = 2 << 18  # SASE2 dump
     PHOTON_LINE_DEFLECTION = 1 << 27  # Soft kick (e.g. SA3)
-    bp_mnemo = mnemonics['bunchPatternTable']
-    bp_table = runDir.get_array(bp_mnemo['source'],bp_mnemo['key'], 
-                                extra_dims=bp_mnemo['dim'])
-    destination = DESTINATION_T5D if (sase == 2) else DESTINATION_T4D
-    matched = (bp_table & DESTINATION_MASK) == destination
-
-    if sase == 1:
-        # Pulses to SASE 1 when soft kick is off
-        matched &= (bp_table & PHOTON_LINE_DEFLECTION) == 0
-    elif sase == 3:
-        # Pulses to SASE 3 when soft kick is on
-        matched &= (bp_table & PHOTON_LINE_DEFLECTION) != 0
-
+    LASER_SEED6 = 1 << 13
+    if 'sase' in key:
+        sase = int(key[4])
+        destination = DESTINATION_T5D if (sase == 2) else DESTINATION_T4D
+        matched = (bp_table & DESTINATION_MASK) == destination
+        if sase == 1:
+            # Pulses to SASE 1 when soft kick is off
+            matched &= (bp_table & PHOTON_LINE_DEFLECTION) == 0
+        elif sase == 3:
+            # Pulses to SASE 3 when soft kick is on
+            matched &= (bp_table & PHOTON_LINE_DEFLECTION) != 0
+    elif key=='scs_ppl':
+        matched = (bp_table & LASER_SEED6) != 0
+    
+    # create table of indices where bunch pattern and mask match
     nz = np.nonzero(matched.values)
     dim_pId = matched.shape[1]
-    sase_array = np.ones(matched.shape, dtype=np.uint64)*dim_pId
-    sase_array[nz] = nz[1]
-    sase_array = np.sort(sase_array)
-    sase_array[sase_array == dim_pId] = 0
+    bunchPattern = np.ones(matched.shape, dtype=np.uint64)*dim_pId
+    bunchPattern[nz] = nz[1]
+    bunchPattern = np.sort(bunchPattern)
+    npulses = np.count_nonzero(bunchPattern<dim_pId, axis=1)
+    bunchPattern[bunchPattern == dim_pId] = 0
 
-    sase_da = xr.DataArray(sase_array[:,:1000], dims=['trainId', 'bunchId'],
+    bunchPattern = xr.DataArray(bunchPattern[:,:1000], dims=['trainId', 'bunchId'],
                           coords={'trainId':matched.trainId}, 
-                          name=f'sase{sase}')
-    npulses_sase = xr.DataArray(np.count_nonzero(sase_da, axis=1), dims=['trainId'],
+                          name=key)
+    npulses = xr.DataArray(npulses, dims=['trainId'],
                                 coords={'trainId':matched.trainId}, 
-                                name=f'npulses_sase{sase}')
-    return sase_da, npulses_sase
-
+                                name=f'npulses_{key}')
+    return bunchPattern, npulses
 
 def load(fields, runNB, proposalNB, subFolder='raw', display=False, validate=False,
          subset=by_index[:], rois={}, useBPTable=True):
@@ -469,14 +485,17 @@ def load(fields, runNB, proposalNB, subFolder='raw', display=False, validate=Fal
     keys = []
     vals = []
 
-    # load pulse pattern infos
+    # load pulse pattern info
     if useBPTable:
-        if mnemonics['bunchPatternTable']['source'] not in run.all_sources:
+        bp_mnemo = mnemonics['bunchPatternTable']
+        if bp_mnemo['source'] not in run.all_sources:
             print('Source {} not found in run. Skipping!'.format(
                                 mnemonics['bunchPatternTable']['source']))
         else:
-            sase1, npulses_sase1 = extractSaseBunchPattern(run, 1)
-            sase3, npulses_sase3 = extractSaseBunchPattern(run, 3)
+            bp_table = run.get_array(bp_mnemo['source'],bp_mnemo['key'], 
+                                        extra_dims=bp_mnemo['dim'])
+            sase1, npulses_sase1 = extractBunchPattern(bp_table, 'sase1')
+            sase3, npulses_sase3 = extractBunchPattern(bp_table, 'sase3')
             keys += ["sase1", "npulses_sase1", "sase3", "npulses_sase3"]
             vals += [sase1, npulses_sase1, sase3, npulses_sase3]
     else:
-- 
GitLab