From 75762388cf4bf8944b36a5b3a347d57685ea2a00 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Fri, 15 Mar 2019 22:20:53 +0100
Subject: [PATCH] Add new file with pulsePatternInfo, xgm and TIM functions

---
 xgm.py | 330 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 330 insertions(+)
 create mode 100644 xgm.py

diff --git a/xgm.py b/xgm.py
new file mode 100644
index 0000000..97d8518
--- /dev/null
+++ b/xgm.py
@@ -0,0 +1,330 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import xarray as xr
+
+
+def pulsePatternInfo(data, plot=False):
+    ''' Print a summary of the SASE 1 and SASE 3 pulse patterns of a run.
+
+        Reports whether the two SASEs share trains, which one comes first in
+        the train (with the minimum separation between the two sub-trains) and,
+        for each SASE, every range of trains with a constant number of pulses.
+
+        Inputs:
+            data: xarray Dataset containing pulse pattern info from the bunch
+                decoder MDL: {'sase1', 'sase3', 'npulses_sase1', 'npulses_sase3'}
+            plot: bool enabling/disabling the plotting of the pulse patterns
+
+        Outputs:
+            print of pulse pattern info. If plot==True, plot of the number of
+            pulses per train of each SASE versus trainId.
+    '''
+    n_sa3 = data['npulses_sase3']
+    n_sa1 = data['npulses_sase1']
+    # A SASE runs "alone" when the other one never lases in the same trains.
+    sa3_alone = np.all(n_sa1.where(n_sa3 != 0, drop=True) == 0)
+    if sa3_alone:
+        print('No SASE 1 pulses during SASE 3 operation')
+    sa1_alone = np.all(n_sa3.where(n_sa1 != 0, drop=True) == 0)
+    if sa1_alone:
+        print('No SASE 3 pulses during SASE 1 operation')
+    if not (sa3_alone or sa1_alone):
+        # Shared trains: compare the extreme pulse Ids to see which SASE is first.
+        ids_sa1 = data['sase1'].where(n_sa1 != 0).where(data['sase1'] > 1)
+        ids_sa3 = data['sase3'].where(n_sa3 != 0).where(data['sase3'] > 1)
+        first_sa1, last_sa1 = ids_sa1.min().values, ids_sa1.max().values
+        first_sa3, last_sa3 = ids_sa3.min().values, ids_sa3.max().values
+        # 0.220: presumably the pulse Id spacing in µs -- TODO confirm
+        if first_sa1 > last_sa3:
+            t = 0.220*(first_sa1 - last_sa3 + 1)
+            print('SASE 3 pulses come before SASE 1 pulses (minimum separation %.1f µs)'%t)
+        elif first_sa3 > last_sa1:
+            t = 0.220*(first_sa3 - last_sa1 + 1)
+            print('SASE 1 pulses come before SASE 3 pulses (minimum separation %.1f µs)'%t)
+        else:
+            print('Interleaved mode')
+
+    # Pulse pattern of each SASE: one table row per range of constant pulse number.
+    for key in ('sase3', 'sase1'):
+        print('\n*** %s pulse pattern: ***'%key.upper())
+        npulses = data['npulses_%s'%key]
+        sase = data[key]
+        if not np.all(npulses == npulses[0]):
+            print('Warning: number of pulses per train changed during the run!')
+        # Trains whose number of pulses differs from that of the previous train.
+        steps = npulses.diff(dim='trainId')
+        steps = steps.where(steps != 0, drop=True)
+        changed = np.isin(npulses.trainId.values, steps.trainId.values,
+                          assume_unique=True)
+        # First index of every range; index 0 opens the initial range.
+        starts = np.insert(np.argwhere(changed)[:,0], 0, 0)
+        # Last index of every range: one before the next start, or the last train.
+        stops = np.append(starts[1:] - 1, npulses.shape[0] - 1)
+        print('npulses\tindex From\tindex To\ttrainId From\ttrainId To\trep. rate [kHz]')
+        for idxFrom, idxTo in zip(starts, stops):
+            n = npulses[idxFrom]
+            row = (n, idxFrom, idxTo, npulses.trainId[idxFrom], npulses.trainId[idxTo])
+            if n <= 1:
+                print('%i\t%i\t\t%i\t\t%i\t%i'%row)
+            else:
+                # Rep. rate from the spacing of the first two pulse Ids of the range.
+                # 222e-6: presumably the pulse Id spacing in ms -- TODO confirm
+                f = 1/((sase[idxFrom,1] - sase[idxFrom,0])*222e-6)
+                print('%i\t%i\t\t%i\t\t%i\t%i\t%.0f'%(row + (f,)))
+    print('\n')
+
+    if plot:
+        # Number of pulses per train versus trainId, for both SASEs.
+        plt.figure(figsize=(6,3))
+        plt.plot(n_sa3.trainId, n_sa3, 'o-', ms=3, label='SASE 3')
+        plt.plot(n_sa1.trainId, n_sa1, '^-', ms=3, color='C2', label='SASE 1')
+        plt.xlabel('trainId')
+        plt.ylabel('pulses per train')
+        plt.legend()
+        plt.tight_layout()
+        
+        
+def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM'):
+    ''' Extract SASE1- or SASE3-only XGM data.
+        There are various cases depending on i) the mode of operation (10 Hz
+        with fresh bunch, dedicated trains to one SASE, pulse on demand),
+        ii) the potential change of number of pulses per train in each SASE
+        and iii) the order (SASE1 first, SASE3 first, interleaved mode).
+
+        Inputs:
+            data: xarray Dataset containing xgm data and the pulse pattern info
+                {'sase1', 'sase3', 'npulses_sase1', 'npulses_sase3'}
+            sase: key of sase to select: {'sase1', 'sase3'}
+            xgm: key of xgm to select: {'SA3_XGM', 'SCS_XGM'}
+
+        Output:
+            DataArray that has all trainIds that contain a lasing
+            train in sase, with dimension equal to the maximum number of pulses of
+            that sase in the run. The missing values, in case of change of number
+            of pulses, are filled with NaNs. When both SASEs share trains, None is
+            returned if no range of trains with pulses of sase is found.
+
+        Raises:
+            NotImplementedError: if SASE 1 and SASE 3 share trains in interleaved
+            mode and data of such trains has to be extracted.
+    '''
+    npulses_sa3 = data['npulses_sase3']
+    npulses_sa1 = data['npulses_sase1']
+    dedicated = 0
+    if np.all(npulses_sa1.where(npulses_sa3 != 0, drop=True) == 0):
+        dedicated += 1
+        print('No SASE 1 pulses during SASE 3 operation')
+    if np.all(npulses_sa3.where(npulses_sa1 != 0, drop=True) == 0):
+        dedicated += 1
+        print('No SASE 3 pulses during SASE 1 operation')
+    npulses = npulses_sa1 if sase == 'sase1' else npulses_sa3
+    maxpulses = int(npulses.max().values)
+
+    # Alternating pattern with dedicated trains for SASE1 and SASE3:
+    if dedicated == 2:
+        result = data[xgm].where(npulses > 0, drop=True)[:, :maxpulses]
+        # entries equal to 1.0 are masked -- presumably the value of XGM slots
+        # without pulse, TODO confirm
+        return result.where(result != 1.0)
+
+    # SASE1 and SASE3 bunches in a same train: find minimum indices of first and
+    # maximum indices of last pulse per train
+    ids_sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1'] > 1)
+    ids_sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3'] > 1)
+    # sa3First stays None in interleaved mode, which the extraction below
+    # cannot handle
+    sa3First = None
+    if ids_sa1.min().values > ids_sa3.max().values:
+        sa3First = True
+    elif ids_sa3.min().values > ids_sa1.max().values:
+        sa3First = False
+    else:
+        print('Interleaved mode')
+
+    # indices where the number of pulses changes in SASE1 or SASE3; index 0
+    # opens the initial range of trains:
+    changes = [[0]]
+    for counts in (npulses_sa3, npulses_sa1):
+        # take the derivative along the trainId to track changes in pulse number
+        # and only keep trainIds where a change occurred:
+        diff = counts.diff(dim='trainId')
+        diff = diff.where(diff != 0, drop=True)
+        changes.append(np.argwhere(np.isin(counts.trainId.values, diff.trainId.values,
+                                           assume_unique=True))[:,0])
+    idx_change = np.unique(np.concatenate(changes)).astype(int)
+
+    chunks = []
+    for i, k in enumerate(idx_change):
+        # skip if no pulses after the change:
+        if npulses[k] == 0:
+            continue
+        if sa3First is None:
+            raise NotImplementedError('selectSASEinXGM does not support '
+                                      'interleaved SASE 1 and SASE 3 pulses')
+        n_sa1 = int(npulses_sa1[k].values)
+        n_sa3 = int(npulses_sa3[k].values)
+        # a: first XGM column of the selected sase, n: its number of pulses in
+        # this range. The SASE that comes first in the train starts at column 0,
+        # the other one right after it.
+        if sase == 'sase1':
+            a = n_sa3 if sa3First else 0
+            n = n_sa1
+        else:
+            a = 0 if sa3First else n_sa1
+            n = n_sa3
+        # trains of this range: up to the next change, or to the end of the run
+        if i == len(idx_change)-1:
+            l = None
+        else:
+            l = idx_change[i+1]
+        temp = data[xgm][k:l, a:a+maxpulses].copy()
+        # temp starts at column a, so the pulses of this range occupy its first
+        # n columns: blank the rest (n, not the absolute end column a+n)
+        temp[:, n:] = np.nan
+        chunks.append(temp)
+
+    if not chunks:
+        return None
+    if len(chunks) == 1:
+        return chunks[0]
+    return xr.concat(chunks, dim='trainId')
+
+def calcContribSASE(data, sase='sase1', xgm='SA3_XGM'):
+    ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses
+        for each train in the run. Supports fresh bunch, dedicated trains
+        and pulse on demand modes.
+
+        Inputs:
+            data: xarray Dataset containing xgm data
+            sase: key of sase for which the contribution is computed: {'sase1', 'sase3'}
+            xgm: key of xgm to select: {'SA3_XGM', 'SCS_XGM'}
+
+        Output:
+            1D DataArray equal to sum(sase)/sum(sase1+sase3)
+    '''
+    xgm_sa1 = selectSASEinXGM(data, 'sase1', xgm=xgm)
+    xgm_sa3 = selectSASEinXGM(data, 'sase3', xgm=xgm)
+    if not np.all(xgm_sa1.trainId.isin(xgm_sa3.trainId).values):
+        print('Dedicated mode')
+        # Put both arrays on the union of their trainIds. Only trainId is aligned:
+        # the pulse dimensions may have different lengths for the two SASEs.
+        pulse_dims = [d for d in xgm_sa1.dims + xgm_sa3.dims if d != 'trainId']
+        xgm_sa1, xgm_sa3 = xr.align(xgm_sa1, xgm_sa3, join='outer',
+                                    exclude=['SA3_XGM_dim', 'SA1_XGM_dim'] + pulse_dims)
+        # fillna returns a new array: trains missing in one SASE contribute 0
+        xgm_sa1 = xgm_sa1.fillna(0)
+        xgm_sa3 = xgm_sa3.fillna(0)
+    sum_sa1 = xgm_sa1.sum(axis=1)
+    contrib = sum_sa1/(sum_sa1 + xgm_sa3.sum(axis=1))
+    return contrib if sase == 'sase1' else 1 - contrib
+
+def filterOnTrains(data, key='sase3'):
+    ''' Keep only the trains that contain at least one pulse of a given SASE.
+
+        Inputs:
+            data: xarray Dataset containing 'npulses_sase1' or 'npulses_sase3'
+            key: SASE on which to filter: {'sase1', 'sase3'}
+
+        Output:
+            new xarray Dataset in which every variable is restricted to the
+            trains where the number of pulses of the selected SASE is > 0
+    '''
+    has_pulses = data['npulses_' + key] > 0
+    # Filter the variables one by one and rebuild the Dataset from the results.
+    selected = {name: data[name].where(has_pulses, drop=True) for name in data}
+    return xr.Dataset(selected)
+
+
+def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, t_offset=1760, mcp=1, npulses=None):
+    ''' Computes peak integration from raw MCP traces.
+
+        Inputs:
+            data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
+                and the 'sase3' pulse pattern
+            intstart, intstop: trace indices of integration start and stop
+            bkgstart, bkgstop: trace indices of background start and stop
+            t_offset: index separation between two pulses
+            mcp: MCP channel number
+            npulses: number of pulses; default int((max of 'sase3' + 1)/4)
+
+        Output:
+            results: DataArray of dims trainId x npulses. A ValueError is
+            raised if the raw trace of the selected MCP is not in data.
+    '''
+    keyraw = 'MCP{}raw'.format(mcp)
+    if keyraw not in data:
+        raise ValueError("Source not found: {}!".format(keyraw))
+    # sa3 gives the trainId axis of the output: needed even if npulses is given
+    sa3 = data['sase3'].where(data['sase3']>1)/4
+    sa3 -= sa3[:,0]
+    if npulses is None:
+        npulses = int((data['sase3'].max().values + 1)/4)
+    results = xr.DataArray(np.empty((sa3.shape[0], npulses)), coords=sa3.coords,
+                           dims=['trainId', 'MCP{}fromRaw'.format(mcp)])
+    for i in range(npulses):
+        shift = t_offset*i
+        a, b = intstart + shift, intstop + shift
+        # background: median of each trace over the shifted background window
+        bg = np.median(data[keyraw][:,bkgstart+shift:bkgstop+shift], axis=1)
+        bg = np.outer(bg, np.ones(b-a))
+        results[:,i] = np.trapz(data[keyraw][:,a:b] - bg, axis=1)
+    return results
+
+def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
+              bkgstart=None, bkgstop=None, t_offset=1760, npulses=None):
+    ''' Extract peak-integrated data from TIM where pulses are from SASE3 only.
+        If use_apd is False it calculates integration from raw traces.
+        The missing values, in case of change of number of pulses, are filled
+        with NaNs.
+
+        Inputs:
+            data: xarray Dataset containing the pulse pattern {'sase3',
+                'npulses_sase3'} and the MCP data (e.g. 'MCP1apd' or 'MCP1raw')
+            mcp: MCP channel number
+            use_apd: if True, use the 'MCP{mcp}apd' data; otherwise integrate
+                the raw traces with mcpPeaks() and the parameters below
+            intstart, intstop: trace indices of integration start and stop
+            bkgstart, bkgstop: trace indices of background start and stop
+            t_offset: index separation between two pulses
+            npulses: number of pulses to compute
+
+        Output:
+            tim: DataArray of shape trainId only for SASE3 pulses x N
+                 with N=max(number of pulses per train). None if no range of
+                 trains with SASE3 pulses is found.
+    '''
+    if use_apd:
+        apd = data['MCP{}apd'.format(mcp)]
+    else:
+        apd = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, t_offset, mcp, npulses)
+    npulses_sa3 = data['npulses_sase3']
+    sa3 = data['sase3'].where(data['sase3']>1, drop=True)/4
+    sa3 -= sa3[:,0]
+    sa3 = sa3.astype(int)
+    if np.all(npulses_sa3 == npulses_sa3[0]):
+        # constant number of pulses: select the columns given by the first row of sa3
+        return apd[:, sa3[0].values]
+    maxpulses = int(npulses_sa3.max().values)
+    # take the derivative along the trainId to track changes in pulse number
+    # and only keep trainIds where a change occurred:
+    diff = npulses_sa3.diff(dim='trainId')
+    diff = diff.where(diff != 0, drop=True)
+    idx_change = np.argwhere(np.isin(npulses_sa3.trainId.values,
+                                     diff.trainId.values, assume_unique=True))[:,0]
+    # add index 0 to get the initial pulse number per train:
+    idx_change = np.insert(idx_change, 0, 0)
+    chunks = []
+    for i,idx in enumerate(idx_change):
+        # skip the ranges of trains without SASE3 pulses:
+        if npulses_sa3[idx]==0:
+            continue
+        l = idx_change[i+1] if i < len(idx_change)-1 else None
+        b = int(npulses_sa3[idx].values)
+        temp = apd[idx:l,:maxpulses].copy()
+        # np.nan: the np.NaN alias was removed in NumPy 2.0
+        temp[:,b:] = np.nan
+        chunks.append(temp)
+    if not chunks:
+        return None
+    return chunks[0] if len(chunks) == 1 else xr.concat(chunks, dim='trainId')
+    
-- 
GitLab