From 816d83328e562d95d7ecf3a2468ed473408abe5d Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Wed, 22 Apr 2020 23:01:28 +0200
Subject: [PATCH] improve cleanXGMdata for missing bunch pattern and XGM sigma
 arrays

---
 xgm.py | 86 ++++++++++++++++++++++++++++------------------------------
 1 file changed, 41 insertions(+), 45 deletions(-)

diff --git a/xgm.py b/xgm.py
index a886696..edade0b 100644
--- a/xgm.py
+++ b/xgm.py
@@ -17,7 +17,7 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
         The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0
         when there is no pulse. This function sorts the SASE 1 and SASE 3 pulses.
         For DAQ runs after April 2019, sase-resolved arrays can be used. For older runs,
-        the function selectSASEinXGM can be used to extract sase-resolved pulses.
+        the function selectSASEinXGM is used to extract sase-resolved pulses.
         Inputs:
             data: xarray Dataset containing XGM TD arrays.
             npulses: number of pulses, needed if pulse pattern not available.
@@ -33,51 +33,51 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
     if 'sase3' in data:
         if np.all(data['npulses_sase1'].where(data['npulses_sase3'] !=0,
                                               drop=True) == 0):
-            load_sa1 = False
             print('Dedicated trains, skip loading SASE 1 data.')
+            load_sa1 = False
         npulses_sa1 = None
+    else:
+        print('Missing bunch pattern info!')
+        if npulses is None:
+            raise TypeError('npulses argument is required when bunch pattern ' +
+                             'info is missing.')
     #pulse-resolved signals from XGMs
     keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1", 
             "XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma",
             "SCS_XGM", "SCS_SA3", "SCS_SA1",
             "SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"]
     
-    if ("SCS_SA3" not in data and "SCS_XGM" in data):
-        #no SASE-resolved arrays available
-        if not 'sase3' in data:
-            npulses_xgm = data['SCS_XGM'].where(data['SCS_XGM'], drop=True).shape[1]
-            npulses_sa1 = npulses_xgm - npulses
-            if npulses_sa1<=0:
-                load_sa1 = False
-        sa3 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase3', npulses=npulses,
-               sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename("SCS_SA3")
-        mergeList.append(sa3)
-        if load_sa1:
-            sa1 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase1',
-                                  npulses=npulses_sa1, sase3First=sase3First).rename(
-                                  {'XGMbunchId':'sa1_pId'}).rename("SCS_SA1")
-            mergeList.append(sa1)
-        dropList.append('SCS_XGM')
-        keys.remove('SCS_XGM')
-
-    load_sa1 = True
-    if ("XTD10_SA3" not in data and "XTD10_XGM" in data):
-        #no SASE-resolved arrays available
-        if not 'sase3' in data:
-            npulses_xgm = data['XTD10_XGM'].where(data['XTD10_XGM'], drop=True).shape[1]
-            npulses_sa1 = npulses_xgm - npulses
-            if npulses_sa1<=0:
-                load_sa1 = False
-        sa3 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase3', npulses=npulses,
-                   sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename("XTD10_SA3")
-        mergeList.append(sa3)
-        if load_sa1:
-            sa1 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase1', 
-                                  npulses=npulses_sa1, sase3First=sase3First).rename(
-                                  {'XGMbunchId':'sa1_pId'}).rename("XTD10_SA1")
-            mergeList.append(sa1)
-        dropList.append('XTD10_XGM')
-        keys.remove('XTD10_XGM')
+    for whichXgm in ['SCS', 'XTD10']:
+        load_sa1 = True
+        if (f"{whichXgm}_SA3" not in data and f"{whichXgm}_XGM" in data):
+            #no SASE-resolved arrays available
+            if not 'sase3' in data:
+                npulses_xgm = data[f'{whichXgm}_XGM'].where(data[f'{whichXgm}_XGM'], drop=True).shape[1]
+                npulses_sa1 = npulses_xgm - npulses
+                if npulses_sa1==0:
+                    load_sa1 = False
+                if npulses_sa1 < 0:
+                    raise ValueError(f'npulses = {npulses} is larger than the total number'
+                                     +f' of pulses per train = {npulses_xgm}')
+            sa3 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM', sase='sase3', npulses=npulses,
+                   sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename(f"{whichXgm}_SA3")
+            mergeList.append(sa3)
+            if f"{whichXgm}_XGM_sigma" in data:
+                sa3_sigma = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM_sigma', sase='sase3', npulses=npulses,
+                       sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename(f"{whichXgm}_SA3_sigma")
+                mergeList.append(sa3_sigma)
+                dropList.append(f'{whichXgm}_XGM_sigma')
+            if load_sa1:
+                sa1 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM', sase='sase1',
+                                      npulses=npulses_sa1, sase3First=sase3First).rename(
+                                      {'XGMbunchId':'sa1_pId'}).rename(f"{whichXgm}_SA1")
+                mergeList.append(sa1)
+                if f"{whichXgm}_XGM_sigma" in data:
+                    sa1_sigma = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM_sigma', sase='sase1', npulses=npulses_sa1,
+                           sase3First=sase3First).rename({'XGMbunchId':'sa1_pId'}).rename(f"{whichXgm}_SA1_sigma")
+                    mergeList.append(sa1_sigma)
+            dropList.append(f'{whichXgm}_XGM')
+            keys.remove(f'{whichXgm}_XGM')
         
     for key in keys:
         if key not in data:
@@ -115,7 +115,7 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
         Inputs:
             data: xarray Dataset containing xgm data
             sase: key of sase to select: {'sase1', 'sase3'}
-            xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'}
+            xgm: key of xgm to select: {'XTD10_XGM[_sigma]', 'SCS_XGM[_sigma]'}
             sase3First: bool, optional. Used in case no bunch pattern was recorded
             npulses: int, optional. Required in case no bunch pattern was recorded.
             
@@ -127,24 +127,20 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
     '''
     #1. case where bunch pattern is missing:
     if sase not in data:
-        print('Missing bunch pattern info!')
-        if npulses is None:
-            raise TypeError('npulses argument is required when bunch pattern ' +
-                             'info is missing.')
         print('Retrieving {} SASE {} pulses assuming that '.format(npulses, sase[4])
               +'SASE {} pulses come first.'.format('3' if sase3First else '1'))
         #in older version of DAQ, non-data numbers were filled with 0.0.
         xgmData = data[xgm].where(data[xgm]!=0.0, drop=True)
         xgmData = xgmData.fillna(0.0).where(xgmData!=1.0, drop=True)
         if (sase3First and sase=='sase3') or (not sase3First and sase=='sase1'):
-            return xgmData[:,:npulses]
+            return xgmData[:,:npulses].assign_coords(XGMbunchId=np.arange(npulses))
         else:
             if xr.ufuncs.isnan(xgmData).any():
                 raise Exception('The number of pulses changed during the run. '
                       'This is not supported yet.')
             else:
                 start=xgmData.shape[1]-npulses
-                return xgmData[:,start:start+npulses]
+                return xgmData[:,start:start+npulses].assign_coords(XGMbunchId=np.arange(npulses))
     
     #2. case where bunch pattern is provided
     #2.1 Merge sase1 and sase3 bunch patterns to get indices of all pulses
-- 
GitLab