Skip to content
Snippets Groups Projects
xgm.py 16.6 KiB
Newer Older
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr


def pulsePatternInfo(data, plot=False):
    ''' display general information on the pulse patterns operated by SASE1 and SASE3.
        This is useful to track changes of number of pulses or mode of operation of
        SASE1 and SASE3. It also determines which SASE comes first in the train and
        the minimum separation between the two SASE sub-trains.

        Inputs:
            data: xarray Dataset containing pulse pattern info from the bunch decoder MDL:
            {'sase1', 'sase3', 'npulses_sase1', 'npulses_sase3'}
            plot: bool enabling/disabling the plotting of the pulse patterns

        Outputs:
            print of pulse pattern info. If plot==True, plot of the pulse pattern.
    '''
    # Time between two consecutive bunch ids, in microseconds. The same value
    # is used for the SASE separation and for the repetition rate below.
    bunch_spacing_us = 0.222

    #Which SASE comes first?
    npulses_sa3 = data['npulses_sase3']
    npulses_sa1 = data['npulses_sase1']
    dedicated = False
    # np.all() of an empty selection is True, so a SASE that never lases
    # also makes the run count as "dedicated".
    if np.all(npulses_sa1.where(npulses_sa3 != 0, drop=True) == 0):
        dedicated = True
        print('No SASE 1 pulses during SASE 3 operation')
    if np.all(npulses_sa3.where(npulses_sa1 != 0, drop=True) == 0):
        dedicated = True
        print('No SASE 3 pulses during SASE 1 operation')
    if not dedicated:
        # bunch ids of each SASE in trains where it lases (ids <= 1 ignored)
        ids_sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1'] > 1)
        ids_sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3'] > 1)
        pulseIdmin_sa1 = ids_sa1.min().values
        pulseIdmax_sa1 = ids_sa1.max().values
        pulseIdmin_sa3 = ids_sa3.min().values
        pulseIdmax_sa3 = ids_sa3.max().values
        if pulseIdmin_sa1 > pulseIdmax_sa3:
            t = bunch_spacing_us*(pulseIdmin_sa1 - pulseIdmax_sa3 + 1)
            print('SASE 3 pulses come before SASE 1 pulses (minimum separation %.1f µs)'%t)
        elif pulseIdmin_sa3 > pulseIdmax_sa1:
            t = bunch_spacing_us*(pulseIdmin_sa3 - pulseIdmax_sa1 + 1)
            print('SASE 1 pulses come before SASE 3 pulses (minimum separation %.1f µs)'%t)
        else:
            print('Interleaved mode')

    #What is the pulse pattern of each SASE?
    for key in ['sase3', 'sase1']:
        print('\n*** %s pulse pattern: ***'%key.upper())
        npulses = data['npulses_%s'%key]
        sase = data[key]
        if not np.all(npulses == npulses[0]):
            print('Warning: number of pulses per train changed during the run!')
        #take the derivative along the trainId to track changes in pulse number:
        diff = npulses.diff(dim='trainId')
        #only keep trainIds where a change occured:
        diff = diff.where(diff != 0, drop=True)
        #get a list of indices where a change occured:
        idx_change = np.argwhere(np.isin(npulses.trainId.values,
                                         diff.trainId.values, assume_unique=True))[:, 0]
        #add index 0 to get the initial pulse number per train:
        idx_change = np.insert(idx_change, 0, 0)
        print('npulses\tindex From\tindex To\ttrainId From\ttrainId To\trep. rate [kHz]')
        for i, idx in enumerate(idx_change):
            n = npulses[idx]
            idxFrom = idx
            trainIdFrom = npulses.trainId[idx]
            if i < len(idx_change)-1:
                idxTo = idx_change[i+1]-1
            else:
                idxTo = npulses.shape[0]-1
            trainIdTo = npulses.trainId[idxTo]
            if n <= 1:
                print('%i\t%i\t\t%i\t\t%i\t%i'%(n, idxFrom, idxTo, trainIdFrom, trainIdTo))
            else:
                # period = (bunch id step) x bunch spacing [µs]; 1e3/period[µs] -> kHz
                f = 1e3/((sase[idxFrom, 1] - sase[idxFrom, 0])*bunch_spacing_us)
                print('%i\t%i\t\t%i\t\t%i\t%i\t%.0f'%(n, idxFrom, idxTo, trainIdFrom, trainIdTo, f))
    print('\n')
    if plot:
        plt.figure(figsize=(6, 3))
        plt.plot(data['npulses_sase3'].trainId, data['npulses_sase3'], 'o-', ms=3, label='SASE 3')
        plt.xlabel('trainId')
        plt.ylabel('pulses per train')
        plt.plot(data['npulses_sase1'].trainId, data['npulses_sase1'], '^-', ms=3, color='C2', label='SASE 1')
        plt.legend()
        plt.tight_layout()
        
        
def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM'):
    ''' Extract SASE1- or SASE3-only XGM data.
        There are various cases depending on i) the mode of operation (10 Hz
        with fresh bunch, dedicated trains to one SASE, pulse on demand),
        ii) the potential change of number of pulses per train in each SASE
        and iii) the order (SASE1 first, SASE3 first). Interleaved SASE1/SASE3
        pulses within the same trains are not supported.

        Inputs:
            data: xarray Dataset containing xgm data
            sase: key of sase to select: {'sase1', 'sase3'}
            xgm: key of xgm to select: {'SA3_XGM', 'SCS_XGM'}

        Output:
            DataArray that has all trainIds that contain a lasing
            train in sase, with dimension equal to the maximum number of pulses of
            that sase in the run. The missing values, in case of change of number of pulses,
            are filled with NaNs. None is returned if no train segment contains
            pulses of the selected sase.

        Raises:
            NotImplementedError: if SASE1 and SASE3 pulses are interleaved.
    '''
    npulses_sa3 = data['npulses_sase3']
    npulses_sa1 = data['npulses_sase1']
    dedicated = 0
    if np.all(npulses_sa1.where(npulses_sa3 != 0, drop=True) == 0):
        dedicated += 1
        print('No SASE 1 pulses during SASE 3 operation')
    if np.all(npulses_sa3.where(npulses_sa1 != 0, drop=True) == 0):
        dedicated += 1
        print('No SASE 3 pulses during SASE 1 operation')

    npulses = npulses_sa1 if sase == 'sase1' else npulses_sa3

    #Alternating pattern with dedicated pulses in SASE1 and SASE3:
    if dedicated == 2:
        result = data[xgm].where(npulses > 0, drop=True)[:, :int(npulses.max().values)]
        # entries equal to 1.0 are masked (presumably placeholders for absent
        # pulses - TODO confirm with the XGM device documentation)
        return result.where(result != 1.0)

    # SASE1 and SASE3 bunches in a same train: find minimum indices of first and
    # maximum indices of last pulse per train
    pulseIdmin_sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1'] > 1).min().values
    pulseIdmax_sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1'] > 1).max().values
    pulseIdmin_sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3'] > 1).min().values
    pulseIdmax_sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3'] > 1).max().values
    if pulseIdmin_sa1 > pulseIdmax_sa3:
        sa3First = True
    elif pulseIdmin_sa3 > pulseIdmax_sa1:
        sa3First = False
    else:
        print('Interleaved mode')
        # without a defined order the column offsets below are meaningless
        raise NotImplementedError('Interleaved SASE1/SASE3 pulse patterns are not supported')

    #take the derivative along the trainId to track changes in pulse number:
    diff = npulses_sa3.diff(dim='trainId')
    #only keep trainIds where a change occured:
    diff = diff.where(diff != 0, drop=True)
    #get a list of indices where a change occured:
    idx_change_sa3 = np.argwhere(np.isin(npulses_sa3.trainId.values,
                                         diff.trainId.values, assume_unique=True))[:, 0]

    #Same for SASE 1:
    diff = npulses_sa1.diff(dim='trainId')
    diff = diff.where(diff != 0, drop=True)
    idx_change_sa1 = np.argwhere(np.isin(npulses_sa1.trainId.values,
                                         diff.trainId.values, assume_unique=True))[:, 0]

    #create index that locates all changes of pulse number in both SASE1 and 3:
    #add index 0 to get the initial pulse number per train:
    idx_change = np.unique(np.concatenate(([0], idx_change_sa3, idx_change_sa1))).astype(int)
    maxpulses = int(npulses.max().values)

    pieces = []
    for i, k in enumerate(idx_change):
        #skip if no pulses after the change:
        if npulses[k] == 0:
            continue
        n3 = int(npulses_sa3[k].values)
        n1 = int(npulses_sa1[k].values)
        # column of the first pulse of the selected sase in the xgm array,
        # and number of its pulses in this segment of trains
        if sase == 'sase1':
            first = n3 if sa3First else 0
            n = n1
        else:
            first = 0 if sa3First else n1
            n = n3
        l = None if i == len(idx_change)-1 else idx_change[i+1]
        temp = data[xgm][k:l, first:first+maxpulses].copy()
        # columns of temp are relative to `first`: blank out everything past
        # the n pulses of this segment
        temp[:, n:] = np.nan
        pieces.append(temp)
    if not pieces:
        return None
    return xr.concat(pieces, dim='trainId')

def calcContribSASE(data, sase='sase1', xgm='SA3_XGM'):
    ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses
        for each train in the run. Supports fresh bunch, dedicated trains
        and pulse on demand modes.

        Inputs:
            data: xarray Dataset containing xgm data
            sase: key of sase for which the contribution is computed: {'sase1', 'sase3'}
            xgm: key of xgm to select: {'SA3_XGM', 'SCS_XGM'}

        Output:
            1D DataArray equal to sum(sase)/sum(sase1+sase3). When the two
            SASEs do not lase in the same set of trains, the arrays are
            outer-joined on trainId and missing trains contribute 0.
    '''
    xgm_sa1 = selectSASEinXGM(data, 'sase1', xgm=xgm)
    xgm_sa3 = selectSASEinXGM(data, 'sase3', xgm=xgm)
    sa1_in_sa3 = np.all(xgm_sa1.trainId.isin(xgm_sa3.trainId.values).values)
    sa3_in_sa1 = np.all(xgm_sa3.trainId.isin(xgm_sa1.trainId.values).values)
    if not (sa1_in_sa3 and sa3_in_sa1):
        print('Dedicated mode')
        # align along trainId only: the pulse dimensions generally have
        # different lengths for the two SASEs and must not be reindexed
        pulse_dims = sorted((set(xgm_sa1.dims) | set(xgm_sa3.dims)) - {'trainId'})
        xgm_sa1, xgm_sa3 = xr.align(xgm_sa1, xgm_sa3, join='outer', exclude=pulse_dims)
        # trains without pulses of one SASE contribute no energy from it
        xgm_sa1 = xgm_sa1.fillna(0)
        xgm_sa3 = xgm_sa3.fillna(0)
    sum_sa1 = xgm_sa1.sum(axis=1)
    sum_sa3 = xgm_sa3.sum(axis=1)
    contrib = sum_sa1/(sum_sa1 + sum_sa3)
    if sase == 'sase1':
        return contrib
    return 1 - contrib

def filterOnTrains(data, key='sase3'):
    ''' Drop the trains in which the selected SASE branch delivered no pulse.

        Inputs:
            data: xarray Dataset containing the variable 'npulses_<key>'
            key: SASE on which to filter: {'sase1', 'sase3'}

        Output:
            filtered xarray Dataset, keeping only the trains where
            data['npulses_<key>'] > 0
    '''
    has_pulses = data['npulses_' + key] > 0
    return data.where(has_pulses, drop=True)

def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, t_offset=1760, mcp=1, npulses=None):
    ''' Computes peak integration from raw MCP traces.

        For pulse i, the trace is integrated (trapezoidal rule) between
        intstart + i*t_offset and intstop + i*t_offset, after subtracting,
        train by train, the median of the trace between
        bkgstart + i*t_offset and bkgstop + i*t_offset.

        Inputs:
            data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            t_offset: index separation between two pulses
            mcp: MCP channel number
            npulses: number of pulses to integrate. If None, it is derived
                     from the largest value of data['sase3'] as int((max + 1)/4).

        Output:
            results: DataArray with dims trainId x npulses
                     (second dimension named 'MCP{mcp}fromRaw')

        Raises:
            ValueError: if the raw trace is missing from data or if one of
                        the integration/background bounds is None.
    '''
    keyraw = 'MCP{}raw'.format(mcp)
    if keyraw not in data:
        raise ValueError("Source not found: {}!".format(keyraw))
    bounds = {'intstart': intstart, 'intstop': intstop,
              'bkgstart': bkgstart, 'bkgstop': bkgstop}
    missing = [name for name, value in bounds.items() if value is None]
    if missing:
        raise ValueError('Integration bounds must be provided: {}'.format(', '.join(missing)))
    if npulses is None:
        npulses = int((data['sase3'].max().values + 1)/4)
    raw = data[keyraw]
    sase3 = data['sase3']
    results = xr.DataArray(np.empty((sase3.shape[0], npulses)), coords=sase3.coords,
                           dims=['trainId', 'MCP{}fromRaw'.format(mcp)])
    # np.trapz was renamed np.trapezoid in NumPy 2.0
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    for i in range(npulses):
        shift = t_offset*i
        a = intstart + shift
        b = intstop + shift
        # per-train baseline: median of the background window
        bg = np.median(raw[:, bkgstart + shift:bkgstop + shift], axis=1)
        peak = np.asarray(raw[:, a:b]) - bg[:, np.newaxis]
        results[:, i] = trapezoid(peak, axis=1)
    return results

def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
              bkgstart=None, bkgstop=None, t_offset=1760, npulses=None):
    ''' Extract peak-integrated data from TIM where pulses are from SASE3 only.
        If use_apd is False it calculates integration from raw traces.
        The missing values, in case of change of number of pulses, are filled
        with NaNs.

        Inputs:
            data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
                  or peak-integrated data (e.g. 'MCP1apd')
            mcp: MCP channel number
            use_apd: if True, use data['MCP{mcp}apd']; if False, integrate
                     the raw traces with mcpPeaks
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            t_offset: index separation between two pulses
            npulses: number of pulses to compute

        Output:
            tim: DataArray of shape trainId only for SASE3 pulses x N
                 with N=max(number of pulses per train)
    '''
    if use_apd:
        apd = data['MCP{}apd'.format(mcp)]
    else:
        apd = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, t_offset, mcp, npulses)
    npulses_sa3 = data['npulses_sase3']

    if np.all(npulses_sa3 == npulses_sa3[0]):
        # constant pattern: select the columns given by the SASE3 bunch ids
        # (divided by 4) of the first train, relative to its first pulse
        sa3 = data['sase3'].where(data['sase3'] > 1, drop=True)/4
        sa3 -= sa3[:, 0]
        sa3 = sa3.astype(int)
        return apd[:, sa3[0].values]

    maxpulses = int(npulses_sa3.max().values)
    #take the derivative along the trainId to track changes in pulse number:
    diff = npulses_sa3.diff(dim='trainId')
    #only keep trainIds where a change occured:
    diff = diff.where(diff != 0, drop=True)
    #get a list of indices where a change occured:
    idx_change = np.argwhere(np.isin(npulses_sa3.trainId.values,
                                     diff.trainId.values, assume_unique=True))[:, 0]
    #add index 0 to get the initial pulse number per train:
    idx_change = np.insert(idx_change, 0, 0)

    pieces = []
    for i, idx in enumerate(idx_change):
        n = int(npulses_sa3[idx].values)
        if n == 0:
            continue
        l = None if i == len(idx_change)-1 else idx_change[i+1]
        temp = apd[idx:l, :maxpulses].copy()
        # blank out the columns beyond the pulses of this segment
        temp[:, n:] = np.nan
        pieces.append(temp)
    if not pieces:
        return None
    return xr.concat(pieces, dim='trainId')
    
def calibrateTIM(data, rollingWindow=200, mcp=1, use_apd=True, intstart=None, intstop=None,
              bkgstart=None, bkgstop=None, t_offset=1760, npulses_apd=None):
    ''' Calibrate TIM signal (Peak-integrated signal) to the slow ion signal of SCS_XGM
        (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
        The aim is to find F so that E_tim_peak[uJ] = F x TIM_peak. For this, we want to
        match the SASE3-only average TIM pulse peak per train (TIM_avg) to the slow XGM
        signal E_slow.
        Since E_slow is the average energy per pulse over all SASE1 and SASE3
        pulses (N1 and N3), we first extract the relative contribution C of the SASE3 pulses
        by looking at the pulse-resolved signals of the SA3_XGM in the tunnel.
        There, the signal of SASE1 is usually strong enough to be above noise level.
        Let TIM_avg be the average of the TIM pulses (SASE3 only).
        The calibration factor is then defined as: F = E_slow * C * (N1+N3) / ( N3 * TIM_avg ).
        If N3 changes during the run, we locate the first and last trains for which N3
        is maximum and only use the trains in between, skipping the first
        rollingWindow of them so that the running average is not built from
        trains with fewer pulses. The window is selected by trainId, since the
        intermediate arrays may not contain every train of the run.

        Warning: the calibration does not include the transmission by the KB mirrors!

        Inputs:
            data: xarray Dataset
            rollingWindow: number of trains to perform a running average on to match
                           TIM-avg and E_slow
            mcp: MCP channel number
            use_apd: boolean. If False, the TIM pulse peaks are extracted from raw
                     traces (mcpPeaks, via getTIMapd)
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            t_offset: index separation between two pulses
            npulses_apd: number of pulses

        Output:
            F: float, TIM calibration factor.
    '''
    npulses = data['npulses_sase3']
    window = None
    if not np.all(npulses == npulses[0]):
        values = npulses.values
        trainIds = npulses.trainId.values
        ntrains = values.shape[0]
        # first and last train index with the largest number of SASE3 pulses
        first = int(np.argmax(values))
        last = ntrains - 1 - int(np.argmax(values[::-1]))
        if last - first < rollingWindow:
            print('not enough consecutive data points with the largest number of pulses per train')
        lo = first + rollingWindow
        # if lo runs past the end, the window is empty and F is NaN
        tid_lo = trainIds[lo] if lo < ntrains else np.inf
        window = (tid_lo, trainIds[last])
    filteredTIM = getTIMapd(data, mcp, use_apd, intstart, intstop, bkgstart, bkgstop, t_offset, npulses_apd)
    sa3contrib = calcContribSASE(data, 'sase3', 'SA3_XGM')
    avgFast = filteredTIM.mean(axis=1).rolling(trainId=rollingWindow).mean()
    ratio = ((data['npulses_sase3']+data['npulses_sase1']) *
             data['SCS_XGM_SLOW'] * sa3contrib) / (avgFast*data['npulses_sase3'])
    if window is not None:
        # label-based selection: ratio may have fewer trains than npulses
        in_window = (ratio.trainId >= window[0]) & (ratio.trainId <= window[1])
        ratio = ratio.where(in_window)
    F = float(ratio.median().values)
    return F