From 4c1cf2dfd54f9d0f1cd14287ec3f042140058996 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Mon, 28 Oct 2019 05:22:22 +0100
Subject: [PATCH] Fix SASE pulse selection in XGM cleanup and re-index pulse coordinates

---
 xgm.py | 125 ++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 89 insertions(+), 36 deletions(-)

diff --git a/xgm.py b/xgm.py
index d7522cb..fecea7d 100644
--- a/xgm.py
+++ b/xgm.py
@@ -133,33 +133,42 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
     '''
     dropList = []
     mergeList = []
-    if ("XTD10_SA3" not in data and "XTD10_XGM" in data) or (
-        "SCS_SA3" not in data and "SCS_XGM" in data):
+    dedicated = False
+    if 'sase3' in data:
+        if np.all(data['npulses_sase1'].where(data['npulses_sase3'] !=0,
+                                              drop=True) == 0):
+            dedicated = True
+            print('Dedicated trains, skip loading SASE 1 data.')
+    keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1", 
+            "XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma",
+            "SCS_XGM", "SCS_SA3", "SCS_SA1",
+            "SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"]
+    
+    if ("SCS_SA3" not in data and "SCS_XGM" in data):
         #no SASE-resolved arrays available
-        if 'SCS_XGM' in data:
-            sa3 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase3', npulses=npulses,
-                   sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename('SCS_SA3')
-            mergeList.append(sa3)
-            sa1 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase1', npulses=npulses,
-                   sase3First=sase3First).rename({'XGMbunchId':'sa1_pId'}).rename('SCS_SA1')
+        sa3 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase3', npulses=npulses,
+               sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'})
+        mergeList.append(sa3)
+        if not dedicated:
+            sa1 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase1',
+                                  npulses=npulses, sase3First=sase3First).rename(
+                                  {'XGMbunchId':'sa1_pId'})
             mergeList.append(sa1)
-            dropList.append('SCS_XGM')
-
-        if 'XTD10_XGM' in data:
-            sa3 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase3', npulses=npulses,
-                       sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename('XTD10_SA3')
-            mergeList.append(sa3)
-            sa1 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase1', npulses=npulses,
-                       sase3First=sase3First).rename({'XGMbunchId':'sa1_pId'}).rename('XTD10_SA1')
+        dropList.append('SCS_XGM')
+        keys.remove('SCS_XGM')
+
+    if ("XTD10_SA3" not in data and "XTD10_XGM" in data):
+        #no SASE-resolved arrays available
+        sa3 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase3', npulses=npulses,
+                   sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'})
+        mergeList.append(sa3)
+        if not dedicated:
+            sa1 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase1', 
+                                  npulses=npulses, sase3First=sase3First).rename(
+                                  {'XGMbunchId':'sa1_pId'})
             mergeList.append(sa1)
-            dropList.append('XTD10_XGM')
-        keys = []
-        
-    else:
-        keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1",
-                "XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma"]
-        keys += ["SCS_XGM", "SCS_SA3", "SCS_SA1",
-                 "SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"]
+        dropList.append('XTD10_XGM')
+        keys.remove('XTD10_XGM')
         
     for key in keys:
         if key not in data:
@@ -168,6 +177,9 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
             sase = 'sa3_'
         elif "sa1" in key.lower():
             sase = 'sa1_'
+            if dedicated:
+                dropList.append(key)
+                continue
         else:
             dropList.append(key)
             continue
@@ -176,7 +188,7 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
         dropList.append(key)
         mergeList.append(res)
     mergeList.append(data.drop(dropList))
-    subset = xr.merge(mergeList, join='outer')
+    subset = xr.merge(mergeList, join='inner')
     for k in data.attrs.keys():
         subset.attrs[k] = data.attrs[k]
     return subset
@@ -237,13 +249,51 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
                           dims=['trainId', 'bunchId'],
                           coords={'trainId':data.trainId},
                           name='sase_all')
-    mask_sa1 = sa_all.sel(trainId=sa1.trainId).isin(sa1_val).rename({'bunchId':'XGMbunchId'})
-    mask_sa3 = sa_all.sel(trainId=sa3.trainId).isin(sa3_val).rename({'bunchId':'XGMbunchId'})
+    idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True)
+    idxList3 = np.unique(sa3)
+    idxList1 = np.unique(sa1)
+
+    if sase=='sase3':
+        big_sa3 = []
+        for i,idxXGM in enumerate(idxListAll):
+            idxXGM = np.isin(idxXGM, idxList3)
+            idxTid = invAll==i
+            mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), 
+                            dims=['trainId', 'XGMbunchId'],
+                            coords={'trainId':data.trainId, 
+                                    'XGMbunchId':sa_all['bunchId'].values}, 
+                            name='mask')
+            mask[idxTid, idxXGM] = True
+            sa3 = xgm_arr.where(mask, drop=True)
+            if sa3.trainId.size > 0:
+                sa3 = sa3.assign_coords(XGMbunchId=np.arange(sa3.XGMbunchId.size))
+                big_sa3.append(sa3)
+        if len(big_sa3) > 0:
+            da_sa3 = xr.concat(big_sa3, dim='trainId').rename('{}_SA3'.format(xgm.split('_')[0]))
+        else:
+            da_sa3 = xr.DataArray([], dims=['trainId'], name='{}_SA3'.format(xgm.split('_')[0]))
+        return da_sa3
+
     if sase=='sase1':
-        return xgm_arr.where(mask_sa1, drop=True).rename('{}_SA1'.format(xgm.split('_')[0]))
-    else:
-        return xgm_arr.where(mask_sa3, drop=True).rename('{}_SA3'.format(xgm.split('_')[0]))
-    
+        big_sa1 = []
+        for i,idxXGM in enumerate(idxListAll):
+            idxXGM = np.isin(idxXGM, idxList1)
+            idxTid = invAll==i
+            mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), 
+                            dims=['trainId', 'XGMbunchId'],
+                            coords={'trainId':data.trainId, 
+                                    'XGMbunchId':sa_all['bunchId'].values}, 
+                            name='mask')
+            mask[idxTid, idxXGM] = True
+            sa1 = xgm_arr.where(mask, drop=True)
+            if sa1.trainId.size > 0:
+                sa1 = sa1.assign_coords(XGMbunchId=np.arange(sa1.XGMbunchId.size))
+                big_sa1.append(sa1)
+        if len(big_sa1) > 0:
+            da_sa1 = xr.concat(big_sa1, dim='trainId').rename('{}_SA1'.format(xgm.split('_')[0]))
+        else:
+            da_sa1 = xr.DataArray([], dims=['trainId'], name='{}_SA1'.format(xgm.split('_')[0]))  
+        return da_sa1
 
 def saseContribution(data, sase='sase1', xgm='XTD10_XGM'):
     ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses 
@@ -539,14 +589,17 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
     
     #2.1 case where apd is used:
     if use_apd:
-        return data[f'MCP{mcp}apd'].where(mask, drop=True)
-
+        res = data[f'MCP{mcp}apd'].where(mask, drop=True)
+        res = res.assign_coords(apdId=np.arange(res['apdId'].shape[0]))
+        return res
     #2.2 case where integration is performed on raw trace:
     else:
-        peaks = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp, 
-                               t_offset=period, npulses=data.dims['apdId'])
+        peaks = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp, t_offset=period,
+                         npulses=data.dims['apdId'])
         mask = mask.rename({'apdId':f'MCP{mcp}fromRaw'})
-        return peaks.where(mask, drop=True)
+        res = peaks.where(mask, drop=True)
+        res = res.assign_coords({f'MCP{mcp}fromRaw':np.arange(res[f'MCP{mcp}fromRaw'].shape[0])})
+        return res
 
 
 def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intstart=None,
-- 
GitLab