From 660045ed0cc20fdcce528c56e36c6cdca4d7b8ab Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Tue, 29 Oct 2019 09:46:35 +0100
Subject: [PATCH] Adds documentation, simplifies selectSASEinXGM()

---
 xgm.py | 97 +++++++++++++++++++++++++---------------------------------
 1 file changed, 41 insertions(+), 56 deletions(-)

diff --git a/xgm.py b/xgm.py
index 85e0058..61eba7e 100644
--- a/xgm.py
+++ b/xgm.py
@@ -199,11 +199,9 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
 
 def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=None):
     ''' Given an array containing both SASE1 and SASE3 data, extracts SASE1-
-        or SASE3-only XGM data.
-        There are various cases depending on i) the mode of operation (10 Hz
-        with fresh bunch, dedicated trains to one SASE, pulse on demand), 
-        ii) the potential change of number of pulses per train in each SASE
-        and iii) the order (SASE1 first, SASE3 first, interleaved mode).
+        or SASE3-only XGM data. The function tracks the changes of bunch patterns
+        in SASE 1 and SASE 3 and applies a mask to the XGM array to extract the
+        relevant pulses. This way, all complicated patterns are accounted for.
         
         Inputs:
             data: xarray Dataset containing xgm data
@@ -240,6 +238,7 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
                 return xgmData[:,start:start+npulses]
     
     #2. case where bunch pattern is provided
+    #2.1 Merge sase1 and sase3 bunch patterns to get indices of all pulses
     xgm_arr = data[xgm].where(data[xgm] != 1., drop=True)
     sa3 = data['sase3'].where(data['sase3'] > 1, drop=True)
     sa3_val=np.unique(sa3)
@@ -252,51 +251,38 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
                           dims=['trainId', 'bunchId'],
                           coords={'trainId':data.trainId},
                           name='sase_all')
-    idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True)
-    idxList3 = np.unique(sa3)
-    idxList1 = np.unique(sa1)
-
     if sase=='sase3':
-        big_sa3 = []
-        for i,idxXGM in enumerate(idxListAll):
-            idxXGM = np.isin(idxXGM, idxList3)
-            idxTid = invAll==i
-            mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), 
-                            dims=['trainId', 'XGMbunchId'],
-                            coords={'trainId':data.trainId, 
-                                    'XGMbunchId':sa_all['bunchId'].values}, 
-                            name='mask')
-            mask[idxTid, idxXGM] = True
-            sa3 = xgm_arr.where(mask, drop=True)
-            if sa3.trainId.size > 0:
-                sa3 = sa3.assign_coords(XGMbunchId=np.arange(sa3.XGMbunchId.size))
-                big_sa3.append(sa3)
-        if len(big_sa3) > 0:
-            da_sa3 = xr.concat(big_sa3, dim='trainId').rename('{}_SA3'.format(xgm.split('_')[0]))
-        else:
-            da_sa3 = xr.DataArray([], dims=['trainId'], name='{}_SA3'.format(xgm.split('_')[0]))
-        return da_sa3
-
-    if sase=='sase1':
-        big_sa1 = []
-        for i,idxXGM in enumerate(idxListAll):
-            idxXGM = np.isin(idxXGM, idxList1)
-            idxTid = invAll==i
-            mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), 
-                            dims=['trainId', 'XGMbunchId'],
-                            coords={'trainId':data.trainId, 
-                                    'XGMbunchId':sa_all['bunchId'].values}, 
-                            name='mask')
-            mask[idxTid, idxXGM] = True
-            sa1 = xgm_arr.where(mask, drop=True)
-            if sa1.trainId.size > 0:
-                sa1 = sa1.assign_coords(XGMbunchId=np.arange(sa1.XGMbunchId.size))
-                big_sa1.append(sa1)
-        if len(big_sa1) > 0:
-            da_sa1 = xr.concat(big_sa1, dim='trainId').rename('{}_SA1'.format(xgm.split('_')[0]))
-        else:
-            da_sa1 = xr.DataArray([], dims=['trainId'], name='{}_SA1'.format(xgm.split('_')[0]))  
-        return da_sa1
+        idxListSase = np.unique(sa3)
+        newName = xgm.split('_')[0] + '_SA3'
+    else:
+        idxListSase = np.unique(sa1)
+        newName = xgm.split('_')[0] + '_SA1'
+        
+    #2.2 track the changes of pulse patterns and the indices at which they occurred (invAll)
+    idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True)
+    
+    #2.3 define a mask, loop over the different patterns and update the mask for each pattern
+    mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), 
+                    dims=['trainId', 'XGMbunchId'],
+                    coords={'trainId':data.trainId, 
+                            'XGMbunchId':sa_all['bunchId'].values}, 
+                    name='mask')
+
+    big_sase = []
+    for i,idxXGM in enumerate(idxListAll):
+        mask.values = np.zeros(mask.shape, dtype=bool)
+        idxXGM = np.isin(idxXGM, idxListSase)
+        idxTid = invAll==i
+        mask[idxTid, idxXGM] = True
+        sa_arr = xgm_arr.where(mask, drop=True)
+        if sa_arr.trainId.size > 0:
+            sa_arr = sa_arr.assign_coords(XGMbunchId=np.arange(sa_arr.XGMbunchId.size))
+            big_sase.append(sa_arr)
+    if len(big_sase) > 0:
+        da_sase = xr.concat(big_sase, dim='trainId').rename(newName)
+    else:
+        da_sase = xr.DataArray([], dims=['trainId'], name=newName)
+    return da_sase
 
 def saseContribution(data, sase='sase1', xgm='XTD10_XGM'):
     ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses 
@@ -533,9 +519,8 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
     ''' Extract peak-integrated data from TIM where pulses are from SASE3 only.
         If use_apd is False it calculates integration from raw traces. 
         The missing values, in case of change of number of pulses, are filled
-        with NaNs.
-        If no bunch pattern info is available, the function assumes that
-        SASE 3 comes first and that the number of pulses is fixed in both
+        with NaNs. If no bunch pattern info is available, the function assumes
+        that SASE 3 comes first and that the number of pulses is fixed in both
         SASE 1 and 3.
         
         Inputs:
@@ -544,7 +529,7 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
             intstop: trace index of integration stop
             bkgstart: trace index of background start
             bkgstop: trace index of background stop
-            t_offset: index separation between two pulses
+            t_offset: number of ADC samples between two pulses
             mcp: MCP channel number
             npulses: int, optional. Number of pulses to compute. Required if
                 no bunch pattern info is available.
@@ -602,6 +587,7 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
                          npulses=pulseIdDim)
     
     sa3 = sa3/period
+    #2.3 track the changes of pulse patterns and the indices at which they occurred (inv)
     idxList, inv = np.unique(sa3, axis=0, return_inverse=True)
     mask = xr.DataArray(np.zeros((data.dims['trainId'], pulseIdDim), dtype=bool), 
                         dims=['trainId', pulseId],
@@ -993,11 +979,10 @@ def mergeFastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop,
             intstop: trace index of integration stop
             bkgstart: trace index of background start
             bkgstop: trace index of background stop
-            period: Number of samples separation between two pulses. Needed
+            period: Number of ADC samples between two pulses. Needed
                 if bunch pattern info is not available. If None, checks the 
                 pulse pattern and determine the period assuming a resolution
-                of 9.23 ns per sample which leads to 24 samples between
-                two bunches @ 4.5 MHz. 
+                of 9.23 ns per sample = 24 samples between two pulses @ 4.5 MHz. 
             npulses: number of pulses. If None, takes the maximum number of
                 pulses according to the bunch patter (field 'npulses_sase3')
             dim: name of the xr dataset dimension along the peaks
-- 
GitLab