Skip to content
Snippets Groups Projects
xgm.py 48.3 KiB
Newer Older
# -*- coding: utf-8 -*-
""" Toolbox for SCS.

    Various utility functions to quickly process data measured at the SCS instruments.

    Copyright (2019) SCS Team.
"""
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr


def pulsePatternInfo(data, plot=False):
    ''' Display general information on the pulse patterns operated by SASE1 and SASE3.
        This is useful to track changes of number of pulses or mode of operation of
        SASE1 and SASE3. It also determines which SASE comes first in the train and
        the minimum separation between the two SASE sub-trains.

        Inputs:
            data: xarray Dataset containing pulse pattern info from the bunch decoder MDL:
            {'sase1', 'sase3', 'npulses_sase1', 'npulses_sase3'}
            plot: bool enabling/disabling the plotting of the pulse patterns

        Outputs:
            print of pulse pattern info. If plot==True, plot of the pulse pattern.
    '''
    # Which SASE comes first?
    npulses_sa3 = data['npulses_sase3']
    npulses_sa1 = data['npulses_sase1']
    dedicated = False
    if np.all(npulses_sa1.where(npulses_sa3 != 0, drop=True) == 0):
        dedicated = True
        print('No SASE 1 pulses during SASE 3 operation')
    if np.all(npulses_sa3.where(npulses_sa1 != 0, drop=True) == 0):
        dedicated = True
        print('No SASE 3 pulses during SASE 1 operation')
    if not dedicated:
        # Only bunch ids > 1 of trains that contain pulses are considered
        # (smaller values presumably mark empty pattern entries).
        sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1'] > 1)
        sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3'] > 1)
        pulseIdmin_sa1 = sa1.min().values
        pulseIdmax_sa1 = sa1.max().values
        pulseIdmin_sa3 = sa3.min().values
        pulseIdmax_sa3 = sa3.max().values
        if pulseIdmin_sa1 > pulseIdmax_sa3:
            t = 0.220*(pulseIdmin_sa1 - pulseIdmax_sa3 + 1)
            print('SASE 3 pulses come before SASE 1 pulses (minimum separation %.1f µs)' % t)
        elif pulseIdmin_sa3 > pulseIdmax_sa1:
            t = 0.220*(pulseIdmin_sa3 - pulseIdmax_sa1 + 1)
            print('SASE 1 pulses come before SASE 3 pulses (minimum separation %.1f µs)' % t)
        else:
            print('Interleaved mode')

    # What is the pulse pattern of each SASE?
    for key in ['sase3', 'sase1']:
        print('\n*** %s pulse pattern: ***' % key.upper())
        npulses = data['npulses_%s' % key]
        sase = data[key]
        if not np.all(npulses == npulses[0]):
            print('Warning: number of pulses per train changed during the run!')
        # take the derivative along the trainId to track changes in pulse number,
        # and only keep trainIds where a change occurred:
        diff = npulses.diff(dim='trainId')
        diff = diff.where(diff != 0, drop=True)
        # positional indices where a change occurred, plus index 0 to get the
        # initial pulse number per train:
        idx_change = np.argwhere(np.isin(npulses.trainId.values,
                                         diff.trainId.values, assume_unique=True))[:, 0]
        idx_change = np.insert(idx_change, 0, 0)
        print('npulses\tindex From\tindex To\ttrainId From\ttrainId To\trep. rate [kHz]')
        for i, idxFrom in enumerate(idx_change):
            n = npulses[idxFrom]
            trainIdFrom = npulses.trainId[idxFrom]
            # a segment ends one train before the next change, or at the run end
            if i < len(idx_change) - 1:
                idxTo = idx_change[i + 1] - 1
            else:
                idxTo = npulses.shape[0] - 1
            trainIdTo = npulses.trainId[idxTo]
            if n <= 1:
                print('%i\t%i\t\t%i\t\t%i\t%i' % (n, idxFrom, idxTo, trainIdFrom, trainIdTo))
            else:
                # bunch-id spacing of the first two pulses, 222e-6 ms per id -> kHz
                f = 1/((sase[idxFrom, 1] - sase[idxFrom, 0])*222e-6)
                print('%i\t%i\t\t%i\t\t%i\t%i\t%.0f' % (n, idxFrom, idxTo, trainIdFrom,
                                                        trainIdTo, f))
    print('\n')
    if plot:
        plt.figure(figsize=(6, 3))
        plt.plot(data['npulses_sase3'].trainId, data['npulses_sase3'], 'o-',
                 ms=3, label='SASE 3')
        plt.plot(data['npulses_sase1'].trainId, data['npulses_sase1'], '^-',
                 ms=3, color='C2', label='SASE 1')
        plt.xlabel('trainId')
        plt.ylabel('pulses per train')
        # the curves are labelled, so show the legend
        plt.legend()
        plt.tight_layout()


def repRate(data, sase='sase3'):
    ''' Calculates the pulse repetition rate in sase according
        to the bunch pattern and assuming a minimum pulse
        separation of 221.54 ns (one unit of bunch id).

        Inputs:
            data: xarray Dataset containing pulse pattern
            sase: key of the sase in which the repetition rate is
                  calculated ('sase1', 'sase2' or 'sase3'); data must also
                  contain the matching 'npulses_<sase>' key
        Output:
            f: repetition rate in kHz, or 0 if no train contains more
               than one pulse
    '''
    assert sase in data, 'key "{}" not found in data!'.format(sase)
    # keep only trains with at least two pulses, so a spacing can be measured
    pattern = data[sase].where(data['npulses_{}'.format(sase)] > 1,
                               drop=True).values
    if len(pattern) == 0:
        print('Not enough pulses to extract repetition rate')
        return 0
    # spacing of the first two pulses of the first multi-pulse train, in ms
    # (221.54e-6 ms per bunch id), inverted to get kHz
    f = 1/((pattern[0, 1] - pattern[0, 0])*221.54e-6)
    return f


def cleanXGMdata(data, npulses=None, sase3First=True):
    ''' Cleans the XGM data arrays obtained from load() function.
        The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0
        when there is no pulse. This function sorts the SASE 1 and SASE 3 pulses.
        For recent DAQ runs, sase-resolved arrays can be used. For older runs,
        the function selectSASEinXGM can be used to extract sase-resolved pulses.

        Inputs:
            data: xarray Dataset containing XGM TD arrays.
            npulses: number of pulses, needed if pulse pattern not available.
            sase3First: bool, needed if pulse pattern not available.

        Output:
            xarray Dataset containing sase- and pulse-resolved XGM data, with
                dimension names 'sa1_pId' and 'sa3_pId'. The attributes of
                data are copied to the result.
    '''
    dropList = []
    mergeList = []
    if ("XTD10_SA3" not in data and "XTD10_XGM" in data) or (
            "SCS_SA3" not in data and "SCS_XGM" in data):
        # no SASE-resolved arrays available: split the full XGM arrays
        for xgmName in ('SCS', 'XTD10'):
            xgmKey = '{}_XGM'.format(xgmName)
            if xgmKey not in data:
                continue
            for sase in ('sase3', 'sase1'):
                sa = 'sa{}'.format(sase[4])
                selected = selectSASEinXGM(data, xgm=xgmKey, sase=sase, npulses=npulses,
                                           sase3First=sase3First)
                mergeList.append(selected.rename({'XGMbunchId': '{}_pId'.format(sa)})
                                 .rename('{}_{}'.format(xgmName, sa.upper())))
            dropList.append(xgmKey)
        keys = []
    else:
        keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1",
                "XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma"]
        keys += ["SCS_XGM", "SCS_SA3", "SCS_SA1",
                 "SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"]

    for key in keys:
        if key not in data:
            continue
        if "sa3" in key.lower():
            sase = 'sa3_'
        elif "sa1" in key.lower():
            sase = 'sa1_'
        else:
            # full (non sase-resolved) arrays are dropped from the output
            dropList.append(key)
            continue
        # 1.0 is the filler value of the TD arrays when there is no pulse
        res = data[key].where(data[key] != 1.0, drop=True).rename(
                {'XGMbunchId': '{}pId'.format(sase)}).rename(key)
        dropList.append(key)
        mergeList.append(res)
    mergeList.append(data.drop(dropList))
    subset = xr.merge(mergeList, join='inner')
    # copy the Dataset attributes explicitly onto the merged result
    subset.attrs.update(data.attrs)
    return subset


def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=None):
    ''' Extract SASE1- or SASE3-only XGM data.
        There are various cases depending on i) the mode of operation (10 Hz
        with fresh bunch, dedicated trains to one SASE, pulse on demand),
        ii) the potential change of number of pulses per train in each SASE
        and iii) the order (SASE1 first, SASE3 first, interleaved mode).

        Inputs:
            data: xarray Dataset containing xgm data
            sase: key of sase to select: {'sase1', 'sase3'}
            xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'}
            sase3First: bool, optional. Used in case no bunch pattern was recorded
            npulses: int, optional. Required in case no bunch pattern was recorded.

        Output:
            DataArray that has all trainIds that contain a lasing
            train in sase, with dimension equal to the maximum number of pulses of
            that sase in the run. The missing values, in case of change of number of pulses,
            are filled with NaNs. None if no train of the run has pulses in sase.

        Raises:
            TypeError: the bunch pattern is missing and npulses is not given.
            NotImplementedError: the number of pulses changed in a run without
                bunch pattern, or interleaved mode without loadable
                sase-dedicated XGM data.
    '''
    patternKeys = ('sase1', 'sase3', 'npulses_sase1', 'npulses_sase3')
    if not all(k in data for k in patternKeys):
        return _xgmSelectWithoutPattern(data, sase, xgm, sase3First, npulses)

    npulses_sa3 = data['npulses_sase3']
    npulses_sa1 = data['npulses_sase1']
    dedicated = 0
    if np.all(npulses_sa1.where(npulses_sa3 != 0, drop=True) == 0):
        dedicated += 1
        print('No SASE 1 pulses during SASE 3 operation')
    if np.all(npulses_sa3.where(npulses_sa1 != 0, drop=True) == 0):
        dedicated += 1
        print('No SASE 3 pulses during SASE 1 operation')
    npulses_sel = npulses_sa1 if sase == 'sase1' else npulses_sa3
    maxpulses = int(npulses_sel.max().values)

    # Alternating pattern with dedicated trains for SASE1 and SASE3:
    if dedicated == 2:
        result = data[xgm].where(npulses_sel > 0, drop=True)[:, :maxpulses]
        return result.where(result != 1.0)

    # SASE1 and SASE3 bunches in a same train: compare the minimum bunch id
    # of one SASE with the maximum bunch id of the other to find the order.
    sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1'] > 1)
    sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3'] > 1)
    if sa1.min().values > sa3.max().values:
        sa3First = True
    elif sa3.min().values > sa1.max().values:
        sa3First = False
    else:
        print('Interleaved mode, but no sase-dedicated XGM data loaded.')
        return _xgmLoadSASEdedicated(data, sase, xgm, npulses)

    # Locate every change of pulse number in SASE1 or SASE3; index 0 is added
    # to get the initial pulse numbers.
    idx_change = np.unique(np.concatenate((
        [0], _xgmPulseChangeIndices(npulses_sa3),
        _xgmPulseChangeIndices(npulses_sa1)))).astype(int)
    pieces = []
    for i, k in enumerate(idx_change):
        # skip segments without pulses in the selected SASE
        if npulses_sel[k] == 0:
            continue
        n3 = int(npulses_sa3[k].values)
        n1 = int(npulses_sa1[k].values)
        # first XGM column of the selected SASE, and its number of pulses
        if sase == 'sase1':
            first, count = (n3 if sa3First else 0), n1
        else:
            first, count = (0 if sa3First else n1), n3
        stop = idx_change[i + 1] if i < len(idx_change) - 1 else None
        temp = data[xgm][k:stop, first:first + maxpulses].copy()
        # NaN-fill columns beyond this segment's pulse count; the column index
        # is relative to the slice start `first`, not absolute.
        temp[:, count:] = np.nan
        pieces.append(temp)
    if not pieces:
        return None
    return xr.concat(pieces, dim='trainId')


def _xgmPulseChangeIndices(npulses):
    ''' Positional indices along trainId of the trains where npulses changes. '''
    diff = npulses.diff(dim='trainId')
    # only keep trainIds where a change occurred
    diff = diff.where(diff != 0, drop=True)
    return np.argwhere(np.isin(npulses.trainId.values, diff.trainId.values,
                               assume_unique=True))[:, 0]


def _xgmSelectWithoutPattern(data, sase, xgm, sase3First, npulses):
    ''' selectSASEinXGM fallback for runs where no bunch pattern was recorded. '''
    print('Missing bunch pattern info!')
    if npulses is None:
        raise TypeError('npulses argument is required when bunch pattern ' +
                        'info is missing.')
    print('Retrieving {} SASE {} pulses assuming that '.format(npulses, sase[4])
          + 'SASE {} pulses come first.'.format('3' if sase3First else '1'))
    # in older version of DAQ, non-data numbers were filled with 0.0.
    xgmData = data[xgm].where(data[xgm] != 0.0, drop=True)
    xgmData = xgmData.fillna(0.0).where(xgmData != 1.0, drop=True)
    if (sase3First and sase == 'sase3') or (not sase3First and sase == 'sase1'):
        return xgmData[:, :npulses]
    if xgmData.isnull().any():
        raise NotImplementedError('The number of pulses changed during the run. '
                                  'This is not supported yet.')
    start = xgmData.shape[1] - npulses
    return xgmData[:, start:start + npulses]


def _xgmLoadSASEdedicated(data, sase, xgm, npulses):
    ''' Load the sase-dedicated XGM array from the run stored in data.attrs['run'].
        Raises NotImplementedError (chained to the original error) if loading fails.
    '''
    saseStr = 'SA{}'.format(sase[4])
    xgmStr = xgm.split('_')[0]
    print('Loading {}_{} data...'.format(xgmStr, saseStr))
    sources = {'XTD10': 'SA3_XTD10_XGM/XGM/DOOCS:output',
               'SCS': 'SCS_BLU_XGM/XGM/DOOCS:output'}
    try:
        if npulses is None:
            npulses = data['npulses_sase{}'.format(sase[4])].max().values
        key = 'data.intensitySa{}TD'.format(sase[4])
        result = data.attrs['run'].get_array(sources[xgmStr], key,
                                             extra_dims=['XGMbunchId'])
    except Exception as err:
        print('Could not load {}_{} data. '.format(xgmStr, saseStr) +
              'Interleaved mode and no sase-dedicated data is not yet supported.')
        raise NotImplementedError('Interleaved mode without sase-dedicated XGM '
                                  'data is not supported.') from err
    return result.isel(XGMbunchId=slice(0, int(npulses)))


def saseContribution(data, sase='sase1', xgm='XTD10_XGM'):
    ''' Compute, train by train, the fraction of the XGM signal produced by
        SASE 1 or SASE 3 pulses. Fresh bunch, dedicated trains and pulse on
        demand modes are all handled.

        Inputs:
            data: xarray Dataset containing xgm data
            sase: key of the sase whose contribution is returned: {'sase1', 'sase3'}
            xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'}

        Output:
            1D DataArray equal to sum(sase)/sum(sase1+sase3)
    '''
    # Outer-align the two selections on every dimension except XGMbunchId;
    # trains present in only one selection get NaN in the other, then 0.
    aligned_sa1, aligned_sa3 = xr.align(selectSASEinXGM(data, 'sase1', xgm=xgm),
                                        selectSASEinXGM(data, 'sase3', xgm=xgm),
                                        join='outer', exclude=['XGMbunchId'])
    energy_sa1 = aligned_sa1.fillna(0).sum(axis=1)
    energy_sa3 = aligned_sa3.fillna(0).sum(axis=1)
    fraction_sa1 = energy_sa1/(energy_sa1 + energy_sa3)
    return fraction_sa1 if sase == 'sase1' else 1 - fraction_sa1

def calibrateXGMs(data, rollingWindow=200, plot=False):
    ''' Calibrate the fast (pulse-resolved) signals of the XTD10 and SCS XGM
        (read in intensityTD property) to the respective slow ion signal
        (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
        If the slowTrain signals ('XTD10_slowTrain' / 'SCS_slowTrain',
        introduced in May 2019) are recorded, the calibration is defined as
        the mean ratio between the photocurrent and the low-pass slowTrain
        signal. Otherwise, calibrateXGMsFromAllPulses() is called.

        Inputs:
            data: xarray Dataset
            rollingWindow: length of running average to calculate E_fast_avg
                (only used by calibrateXGMsFromAllPulses)
            plot: boolean, plot the calibration output (only used by
                calibrateXGMsFromAllPulses)

        Output:
            factors: numpy ndarray of shape (2,) containing
                     [XTD10 calibration factor, SCS calibration factor];
                     a factor is NaN when the corresponding data is missing.
    '''
    if "XTD10_slowTrain" not in data and "SCS_slowTrain" not in data:
        return calibrateXGMsFromAllPulses(data, rollingWindow, plot)

    XTD10_factor = np.nan
    SCS_factor = np.nan
    if "XTD10_slowTrain" in data:
        XTD10_factor = float(np.mean(data.XTD10_photonFlux/data.XTD10_slowTrain))
    else:
        print('no XTD10 XGM data. Skipping calibration for XTD10 XGM')
    if "SCS_slowTrain" in data:
        SCS_factor = float(np.mean(data.SCS_photonFlux/data.SCS_slowTrain))
    else:
        print('no SCS XGM data. Skipping calibration for SCS XGM')
    # TODO: plot the results of calibration
    return np.array([XTD10_factor, SCS_factor])
        
        
def calibrateXGMsFromAllPulses(data, rollingWindow=200, plot=False):
    ''' Calibrate the fast (pulse-resolved) signals of the XTD10 and SCS XGM
        (read in intensityTD property) to the respective slow ion signal
        (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
        One has to take into account the possible signal created by SASE1 pulses. In the
        tunnel, this signal is usually large enough to be read by the XGM and the relative
        contribution C of SASE3 pulses to the overall signal is computed.
        In the tunnel, the calibration F is defined as:
            F = E_slow / E_fast_avg, where
        E_fast_avg is the rolling average (with window rollingWindow) of the fast signal.
        In SCS XGM, the signal from SASE1 is usually in the noise, so we calculate the
        average over the pulse-resolved signal of SASE3 pulses only and calibrate it to the
        slow signal modulated by the SASE3 contribution:
            F = (N1+N3) * E_avg * C/(N3 * E_fast_avg_sase3), where N1 and N3 are the number
        of pulses in SASE1 and SASE3, E_fast_avg_sase3 is the rolling average (with window
        rollingWindow) of the SASE3-only fast signal.

        Inputs:
            data: xarray Dataset
            rollingWindow: length of running average to calculate E_fast_avg
            plot: boolean, plot the calibration output
        Output:
            factors: numpy ndarray of shape (2,) containing
                     [XTD10 calibration factor, SCS calibration factor]
    '''
    XTD10_factor = np.nan
    SCS_factor = np.nan
    noSCS = noXTD10 = False
    if 'SCS_XGM' not in data:
        print('no SCS XGM data. Skipping calibration for SCS XGM')
        noSCS = True
    if 'XTD10_XGM' not in data:
        print('no XTD10 XGM data. Skipping calibration for XTD10 XGM')
        noXTD10 = True
    if noSCS and noXTD10:
        return np.array([XTD10_factor, SCS_factor])
    if not noSCS and noXTD10:
        print('XTD10 data is needed to calibrate SCS XGM.')
        return np.array([XTD10_factor, SCS_factor])
    # From here on XTD10 data is always present (both cases without it returned).

    start = 0
    stop = None
    npulses = data['npulses_sase3']
    ntrains = npulses.shape[0]
    # First, in case of change in number of pulses, locate a region where
    # the number of pulses is maximum.
    if not np.all(npulses == npulses[0]):
        print('Warning: Number of pulses per train changed during the run!')
        start = np.argmax(npulses.values)
        # position of the LAST train with the maximum number of pulses
        stop = ntrains - np.argmax(npulses.values[::-1]) - 1
        if stop - start < rollingWindow:
            print('not enough consecutive data points with the largest number of pulses per train')
        start += rollingWindow
        stop = np.min((ntrains, stop + rollingWindow))

    # Slow (Keithley) signal per pulse, split into SASE3 and SASE1 shares
    nTotal = data['npulses_sase3'] + data['npulses_sase1']
    sa3contrib = saseContribution(data, 'sase3', 'XTD10_XGM')
    SA3_SLOW = data['XTD10_photonFlux']*nTotal*sa3contrib/data['npulses_sase3']
    SA1_SLOW = data['XTD10_photonFlux']*nTotal*(1-sa3contrib)/data['npulses_sase1']

    # Calibrate XTD10 XGM: SASE3-only fast signal vs. SASE3 share of slow signal
    xgm_avg = selectSASEinXGM(data, 'sase3', 'XTD10_XGM').mean(axis=1)
    rolling_sa3_xgm = xgm_avg.rolling(trainId=rollingWindow).mean()
    ratio = SA3_SLOW/rolling_sa3_xgm
    XTD10_factor = ratio[start:stop].mean().values
    print('calibration factor XTD10 XGM: %f' % XTD10_factor)

    # Calibrate SCS XGM with SASE3-only contribution
    if not noSCS:
        SCS_SLOW = data['SCS_photonFlux']*nTotal*sa3contrib/data['npulses_sase3']
        scs_sase3_fast = selectSASEinXGM(data, 'sase3', 'SCS_XGM').mean(axis=1)
        meanFast = scs_sase3_fast.rolling(trainId=rollingWindow).mean()
        ratio = SCS_SLOW/meanFast
        SCS_factor = ratio[start:stop].median().values
        print('calibration factor SCS XGM: %f' % SCS_factor)

    if plot:
        plt.figure(figsize=(8, 4) if noSCS else (8, 8))
        plt.subplot(211)
        plt.title('E[uJ] = %.2f x IntensityTD' % (XTD10_factor))
        plt.plot(SA3_SLOW, label='SA3 slow', color='C1')
        plt.plot(rolling_sa3_xgm*XTD10_factor,
                 label='SA3 fast signal rolling avg', color='C4')
        plt.plot(xgm_avg*XTD10_factor, label='SA3 fast signal train avg', alpha=0.2, color='C4')
        plt.ylabel('Energy [uJ]')
        plt.xlabel('train in run')
        plt.legend(loc='upper left', fontsize=10)
        plt.twinx()
        plt.plot(SA1_SLOW, label='SA1 slow', alpha=0.2, color='C2')
        plt.ylabel('SA1 slow signal [uJ]')
        plt.legend(loc='lower right', fontsize=10)

        if not noSCS:
            plt.subplot(212)
            plt.title('E[uJ] = %.2g x HAMP' % SCS_factor)
            plt.plot(SCS_SLOW, label='SCS slow', color='C1')
            plt.plot(meanFast*SCS_factor, label='SCS HAMP rolling avg', color='C2')
            plt.ylabel('Energy [uJ]')
            plt.xlabel('train in run')
            plt.plot(scs_sase3_fast*SCS_factor, label='SCS HAMP train avg', alpha=0.2, color='C2')
            plt.legend(loc='upper left', fontsize=10)
        plt.tight_layout()

    return np.array([XTD10_factor, SCS_factor])


def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=1, t_offset=None, npulses=None):
    ''' Computes peak integration from raw MCP traces.

        Inputs:
            data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            mcp: MCP channel number
            t_offset: index separation between two pulses. Needed if bunch
                pattern info is not available. If None, checks the pulse
                pattern and determine the t_offset assuming minimum pulse
                separation of 220 ns and digitizer resolution of 2 GHz.
            npulses: number of pulses. If None, takes the maximum number of
                pulses according to the bunch pattern (field 'npulses_sase3')

        Output:
            results: DataArray with dims trainId x max(sase3 pulses)

        Raises:
            ValueError: if the raw trace 'MCP<mcp>raw' is not in data.
    '''
    keyraw = 'MCP{}raw'.format(mcp)
    if keyraw not in data:
        raise ValueError("Source not found: {}!".format(keyraw))
    if npulses is None:
        npulses = int(data['npulses_sase3'].max().values)
    if t_offset is None:
        if npulses > 1:
            # Calculate the number of pulses between two lasing pulses (step)
            sa3 = data['sase3'].where(data['sase3'] > 1)
            step = sa3.where(data['npulses_sase3'] > 1, drop=True)[0, :2].values
            step = int(step[1] - step[0])
            # multiply by elementary samples length (220 ns @ 2 GHz = 440)
            t_offset = 440 * step
        else:
            # only pulse i=0 is integrated, so the spacing is irrelevant
            t_offset = 0
    raw = data[keyraw]
    results = xr.DataArray(np.empty((data.trainId.shape[0], npulses)), coords=raw.coords,
                           dims=['trainId', 'MCP{}fromRaw'.format(mcp)])
    for i in range(npulses):
        a = intstart + t_offset*i
        b = intstop + t_offset*i
        bkga = bkgstart + t_offset*i
        bkgb = bkgstop + t_offset*i
        # per-train median of the background window, repeated over the peak window
        bg = np.outer(np.median(raw[:, bkga:bkgb], axis=1), np.ones(b-a))
        results[:, i] = np.trapz(raw[:, a:b] - bg, axis=1)
    return results


def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
              bkgstart=None, bkgstop=None, t_offset=None, npulses=None,
              stride=1):
    ''' Extract peak-integrated data from TIM where pulses are from SASE3 only.
        If use_apd is False it calculates integration from raw traces.
        The missing values, in case of change of number of pulses, are filled
        with NaNs.
        If no bunch pattern info is available, the function assumes that
        SASE 3 comes first and that the number of pulses is fixed in both
        SASE 1 and 3.

        Inputs:
            data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
            mcp: MCP channel number
            use_apd: bool. If True, use the 'MCP<mcp>apd' peak array, otherwise
                integrate the raw traces with mcpPeaks
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            t_offset: index separation between two pulses
            npulses: int, optional. Number of pulses to compute. Required if
                no bunch pattern info is available.
            stride: int, optional. Used to select pulses in the APD array if
                no bunch pattern info is available.
        Output:
            tim: DataArray of shape trainId only for SASE3 pulses x N
                 with N=max(number of pulses per train); None if no train
                 contains SASE3 pulses.
        Raises:
            TypeError: the bunch pattern is missing and npulses is not given.
    '''
    if 'sase3' not in data or 'npulses_sase3' not in data:
        print('Missing bunch pattern info!\n')
        if npulses is None:
            raise TypeError('npulses argument is required when bunch pattern ' +
                            'info is missing.')
        print('Retrieving {} SASE 3 pulses assuming that '.format(npulses) +
              'SASE 3 pulses come first.')
        if use_apd:
            return data['MCP{}apd'.format(mcp)][:, :npulses:stride]
        return mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp,
                        t_offset=t_offset, npulses=npulses)

    sa3 = data['sase3'].where(data['sase3'] > 1, drop=True)
    npulses_sa3 = data['npulses_sase3']
    maxpulses = int(npulses_sa3.max().values)
    if npulses is not None:
        maxpulses = np.min([npulses, maxpulses])
    step = 1
    if maxpulses > 1:
        # Calculate the number of non-lasing pulses between two lasing pulses (step)
        step = sa3.where(data['npulses_sase3'] > 1, drop=True)[0, :2].values
        step = int(step[1] - step[0])
    if use_apd:
        apd = data['MCP{}apd'.format(mcp)]
        initialDelay = data.attrs['run'].get_array(
            'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.initialDelay.value')[0].values
        upperLimit = data.attrs['run'].get_array(
            'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.upperLimit.value')[0].values
        nsamples = upperLimit - initialDelay
        # bunch slots covered by one APD entry, assuming 440 digitizer samples per slot
        npulses_per_apd = int(nsamples/440)
        sa3 /= npulses_per_apd
    else:
        apd = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp,
                       t_offset=t_offset, npulses=npulses)
        sa3 /= step
    # pulse positions relative to the first pulse of each train, used as
    # column indices into apd
    sa3 -= sa3[:, 0]
    sa3 = sa3.astype(int)
    if np.all(npulses_sa3 == npulses_sa3[0]):
        # constant pattern: the same columns are selected in every train
        return apd[:, sa3[0].values[:maxpulses]]

    stride = 1
    if use_apd:
        stride = np.max([stride, int(step/npulses_per_apd)])
    diff = npulses_sa3.diff(dim='trainId')
    # only keep trainIds where a change occurred:
    diff = diff.where(diff != 0, drop=True)
    # get a list of indices where a change occurred:
    idx_change = np.argwhere(np.isin(npulses_sa3.trainId.values,
                                     diff.trainId.values, assume_unique=True))[:, 0]
    # add index 0 to get the initial pulse number per train:
    idx_change = np.insert(idx_change, 0, 0)
    pieces = []
    for i, idx in enumerate(idx_change):
        if npulses_sa3[idx] == 0:
            continue
        stop = idx_change[i + 1] if i < len(idx_change) - 1 else None
        n = int(npulses_sa3[idx].values)
        temp = apd[idx:stop, :maxpulses*stride:stride].copy()
        # pulses beyond this segment's count are marked missing
        temp[:, n:] = np.nan
        pieces.append(temp)
    if not pieces:
        return None
    return xr.concat(pieces, dim='trainId')


def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intstart=None,
                 intstop=None, bkgstart=None, bkgstop=None, t_offset=None, npulses_apd=None):
    ''' Calibrate TIM signal (Peak-integrated signal) to the slow ion signal of SCS_XGM
        (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
        The aim is to find F so that E_tim_peak[uJ] = F x TIM_peak. For this, we want to
        match the SASE3-only average TIM pulse peak per train (TIM_avg) to the slow XGM 
        signal E_slow.
        Since E_slow is the average energy per pulse over all SASE1 and SASE3 
        pulses (N1 and N3), we first extract the relative contribution C of the SASE3 pulses
        by looking at the pulse-resolved signals of the SA3_XGM in the tunnel.
        There, the signal of SASE1 is usually strong enough to be above noise level.
        Let TIM_avg be the running average (over rollingWindow trains) of the
        train-averaged TIM pulses (SASE3 only).
        The calibration factor is then defined as the median over trains of
        F = E_slow * C * (N1+N3) / ( N3 * TIM_avg ).
        If N3 changes during the run, F is only evaluated between the first and
        the last train where N3 is maximum, skipping the first rollingWindow trains
        of that range so that TIM_avg only averages trains with maximum N3. If that
        range is shorter than rollingWindow, a message is printed and the whole
        range is used.
        
        Warning: the calibration does not include the transmission by the KB mirrors!
        
        Inputs:
            data: xarray Dataset
            rollingWindow: length of running average to calculate TIM_avg
            mcp: MCP channel number
            plot: boolean. If True, plot calibration results.
            use_apd: boolean. If True, uses the digitizer APD peaks; if False, the
                     TIM pulse peaks are integrated from raw traces (both via getTIMapd)
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            t_offset: index separation between two pulses
            npulses_apd: number of pulses
            
        Output:
            F: float, TIM calibration factor.
        
    '''
    start = 0
    stop = None
    npulses = data['npulses_sase3']
    ntrains = npulses.shape[0]
    if not np.all(npulses == npulses[0]):
        # Indices of the first and last trains where N3 is maximum.
        first = int(np.argmax(npulses.values))
        last = ntrains - 1 - int(np.argmax(npulses.values[::-1]))
        if last - first < rollingWindow:
            print('not enough consecutive data points with the largest number of pulses per train')
            # Best effort: use the whole max-N3 range, accepting running
            # averages that partially cover trains with fewer pulses.
            start = first
        else:
            # rolling() is trailing by default: the value at index i averages
            # trains i-rollingWindow+1 .. i. Skipping rollingWindow trains keeps
            # only windows lying entirely inside the max-N3 range.
            start = first + rollingWindow
        # Do not extend past the last max-N3 train: later windows would mix in
        # trains with fewer pulses.
        stop = last + 1
    filteredTIM = getTIMapd(data, mcp, use_apd, intstart, intstop, bkgstart, bkgstop, t_offset, npulses_apd)
    sa3contrib = saseContribution(data, 'sase3', 'XTD10_XGM')
    avgFast = filteredTIM.mean(axis=1).rolling(trainId=rollingWindow).mean()
    ratio = ((data['npulses_sase3']+data['npulses_sase1']) *
             data['SCS_photonFlux'] * sa3contrib) / (avgFast*data['npulses_sase3'])
    F = float(ratio[start:stop].median().values)

    if plot:
        plt.figure(figsize=(8,5))
        ax = plt.subplot(211)
        ax.set_title('E[uJ] = {:2e} x TIM (MCP{})'.format(F, mcp))
        ax.plot(data['SCS_photonFlux'], label='SCS XGM slow (all SASE)', color='C0')
        # slow XGM energy per SASE3 pulse: E_slow * C * (N1+N3) / N3
        slow_avg_sase3 = data['SCS_photonFlux']*(data['npulses_sase1']
                                                 +data['npulses_sase3'])*sa3contrib/data['npulses_sase3']
        ax.plot(slow_avg_sase3, label='SCS XGM slow (SASE3 only)', color='C1')
        ax.plot(avgFast*F, label='Calibrated TIM rolling avg', color='C2')
        ax.set_ylabel(r'Energy [$\mu$J]', size=10)
        ax.plot(filteredTIM.mean(axis=1)*F, label='Calibrated TIM train avg', alpha=0.2, color='C9')
        ax.legend(loc='best', fontsize=8, ncol=2)
        ax.set_xlabel('train in run')
        
        ax = plt.subplot(234)
        xgm_fast = selectSASEinXGM(data)
        ax.scatter(filteredTIM, xgm_fast, s=5, alpha=0.1, rasterized=True)
        tim_flat = filteredTIM.values.flatten()
        xgm_flat = xgm_fast.values.flatten()
        # getTIMapd may pad trains with fewer pulses with NaN; polyfit cannot
        # handle NaNs, so fit only pairs where both values are finite.
        valid = np.isfinite(tim_flat) & np.isfinite(xgm_flat)
        fit, cov = np.polyfit(tim_flat[valid], xgm_flat[valid], 1, cov=True)
        y = np.poly1d(fit)
        x = np.linspace(tim_flat[valid].min(), tim_flat[valid].max(), 10)
        ax.plot(x, y(x), lw=2, color='r')
        ax.set_ylabel(r'Raw HAMP [$\mu$J]', size=10)
        ax.set_xlabel('TIM (MCP{}) signal'.format(mcp), size=10)
        # Text passed positionally: the 's' keyword was renamed 'text' in matplotlib.
        ax.annotate('y(x) = F x + A\n'+
                    'F = %.3e\n$\\Delta$F/F = %.2e\n'%(fit[0],np.abs(np.sqrt(cov[0,0])/fit[0]))+
                    'A = %.3e'%fit[1],
                    xy=(0.5,0.6), xycoords='axes fraction', fontsize=10, color='r')
        print('TIM calibration factor: %e'%(F))
        
        ax = plt.subplot(235)
        ax.hist(tim_flat*F, bins=50, rwidth=0.8)
        ax.set_ylabel('number of pulses', size=10)
        ax.set_xlabel('Pulse energy MCP{} [uJ]'.format(mcp), size=10)
        ax.set_yscale('log')
        
        ax = plt.subplot(236)
        if not use_apd:
            pulseStart = intstart
            pulseStop = intstop
        else:
            # NOTE(review): the APD window is read from channel_0 whatever `mcp`
            # is, while checkTimApdWindow maps MCP1 -> channel_3 — confirm.
            pulseStart = data.attrs['run'].get_array(
                'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStart.value')[0].values
            pulseStop = data.attrs['run'].get_array(
                'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStop.value')[0].values
            
        if 'MCP{}raw'.format(mcp) not in data:
            # Load a single train; do not shadow `data` with the train dict.
            _, train = data.attrs['run'].train_from_index(0)
            trace = train['SCS_UTC1_ADQ/ADC/1:network']['digitizers.channel_1_D.raw.samples']
            print('no raw data for MCP{}. Loading trace from MCP1'.format(mcp))
            label_trace='MCP1 Voltage [V]'
        else:
            trace = data['MCP{}raw'.format(mcp)][0]
            label_trace='MCP{} Voltage [V]'.format(mcp)
        ax.plot(trace[:pulseStop+25], 'o-', ms=2, label='trace')
        ax.axvspan(pulseStart, pulseStop, color='C2', alpha=0.2, label='APD region')
        ax.axvline(pulseStart, color='gray', ls='--')
        ax.axvline(pulseStop, color='gray', ls='--')
        ax.set_xlim(pulseStart - 25, pulseStop + 25)
        ax.set_ylabel(label_trace, size=10)
        ax.set_xlabel('sample #', size=10)
        ax.legend(fontsize=8)
        plt.tight_layout()

    return F
# TIM calibration table
# Dict with key = photon energy [eV] and value = array of polynomial coefficients,
# one row per MCP (rows 0, 1, 2 for MCP 1, 2, 3). Each row holds 5 coefficients,
# highest power first (the order expected by np.poly1d, as used in
# timFactorFromTable).
# The polynomials correspond to a fit of the logarithm of the calibration factor as a
# function of MCP voltage. If P is a polynomial and V the MCP voltage, the calibration
# factor (in microjoule per APD signal) is given by -exp(P(V)).
# This table was generated from the calibration of March 2019, proposal 900074,
# semester 201930,
#     runs 69 - 111 (Ni edge):  https://in.xfel.eu/elog/SCS+Beamline/2323
#     runs 113 - 153 (Co edge): https://in.xfel.eu/elog/SCS+Beamline/2334
#     runs 163 - 208 (Fe edge): https://in.xfel.eu/elog/SCS+Beamline/2349
tim_calibration_table = {
    705.5: np.array([
        [-6.85344690e-12,  5.00931986e-08, -1.27206912e-04, 1.15596821e-01, -3.15215367e+01],
        [ 1.25613942e-11, -5.41566381e-08,  8.28161004e-05, -7.27230153e-02,  3.10984925e+01],
        [ 1.14094964e-12,  7.72658935e-09, -4.27504907e-05, 4.07253378e-02, -7.00773062e+00]]),
    779: np.array([
        [ 4.57610777e-12, -2.33282497e-08,  4.65978738e-05, -6.43305156e-02,  3.73958623e+01],
        [ 2.96325102e-11, -1.61393276e-07,  3.32600044e-04, -3.28468195e-01,  1.28328844e+02],
        [ 1.14521506e-11, -5.81980336e-08,  1.12518434e-04, -1.19072484e-01,  5.37601559e+01]]),
    851: np.array([
        [ 3.15774215e-11, -1.71452934e-07,  3.50316512e-04, -3.40098861e-01,  1.31064501e+02],
        [ 5.36341958e-11, -2.92533156e-07,  6.00574534e-04, -5.71083140e-01,  2.10547161e+02],
        [ 3.69445588e-11, -1.97731342e-07,  3.98203522e-04, -3.78338599e-01,  1.41894119e+02]])
}

def timFactorFromTable(voltage, photonEnergy, mcp=1, table=None):
    ''' Returns an energy calibration factor for TIM integrated peak signal (APD)
        according to calibration from March 2019, proposal 900074, semester 201930, 
        runs 69 - 111 (Ni edge):  https://in.xfel.eu/elog/SCS+Beamline/2323
        runs 113 - 153 (Co edge): https://in.xfel.eu/elog/SCS+Beamline/2334
        runs 163 - 208 (Fe edge): https://in.xfel.eu/elog/SCS+Beamline/2349
        Uses the tim_calibration_table declared above unless another table is given.
        
        Inputs:
            voltage: MCP voltage in volts (scalar or array).
            photonEnergy: FEL photon energy in eV. Calibration factor is linearly
                interpolated between the known values from the calibration table.
                Energies outside the table range are clamped to the nearest
                tabulated energy.
            mcp: MCP channel (1, 2, or 3).
            table: optional dict with the same layout as tim_calibration_table
                (photon energy -> one row of np.poly1d coefficients per MCP).
                Defaults to tim_calibration_table.
            
        Output:
            f: calibration factor in microjoule per APD signal

        Raises:
            ValueError: if mcp does not correspond to a row of the table.
    '''
    if table is None:
        table = tim_calibration_table
    energies = np.sort(np.array(list(table.keys()), dtype=float))
    nmcp = len(table[energies[0]])
    # Guard against mcp=0 silently selecting the last row via negative indexing.
    if not 1 <= mcp <= nmcp:
        raise ValueError('mcp must be between 1 and {}, got {}'.format(nmcp, mcp))

    def factorAt(energy):
        # The table fits log(-F) vs voltage, hence F = -exp(P(V)).
        poly = np.poly1d(table[energy][mcp-1])
        return -np.exp(poly(voltage))

    # Outside the calibrated range, use the nearest tabulated energy.
    photonEnergy = min(max(photonEnergy, energies[0]), energies[-1])
    if photonEnergy in energies:
        return factorAt(photonEnergy)
    # Strictly between energies[idx] and energies[idx+1]: linear interpolation.
    idx = np.searchsorted(energies, photonEnergy) - 1
    eA, eB = energies[idx], energies[idx+1]
    fA, fB = factorAt(eA), factorAt(eB)
    return fA + (fB-fA)/(eB-eA)*(photonEnergy - eA)


def checkTimApdWindow(data, mcp=1, use_apd=True, intstart=None, intstop=None):
    ''' Plot the first and last pulses in MCP trace together with 
        the window of integration to check if the pulse integration
        is properly calculated. The MCP trace is averaged over all trains
        where the number of SASE3 pulses was maximum.
        
        Inputs:
            data: xarray Dataset
            mcp: MCP channel (1, 2, 3 or 4)
            use_apd: if True, gets the APD parameters from the digitizer
                device. If False, uses intstart and intstop as boundaries
                and uses the bunch pattern to determine the separation
                between two pulses.
            intstart: trace index of integration start of the first pulse
            intstop: trace index of integration stop of the first pulse
            
        Output:
            Plot

        Raises:
            ValueError: if use_apd is False and intstart or intstop is missing.
    '''
    mcpToChannel = {1:'D', 2:'C', 3:'B', 4:'A'}
    apdChannels = {1:3, 2:2, 3:1, 4:0}
    npulses_max = int(data['npulses_sase3'].max().values)
    # trainIds of all trains with the maximum number of SASE3 pulses
    tid = data['npulses_sase3'].where(data['npulses_sase3'] == npulses_max,
                                      drop=True).trainId.values
    rawKey = 'MCP{}raw'.format(mcp)
    if rawKey not in data:
        print('no raw data for MCP{}. Loading average trace from MCP{}'.format(mcp, mcp))
        trace = data.attrs['run'].get_array(
                'SCS_UTC1_ADQ/ADC/1:network',
                'digitizers.channel_1_{}.raw.samples'.format(mcpToChannel[mcp])
                ).sel({'trainId':tid}).mean(dim='trainId')
    else:
        trace = data[rawKey].sel({'trainId':tid}).mean(dim='trainId')
    if use_apd:
        run = data.attrs['run']
        pulseStart = int(run.get_array(
            'SCS_UTC1_ADQ/ADC/1', 
            'board1.apd.channel_{}.pulseStart.value'.format(apdChannels[mcp]))[0].values)
        pulseStop = int(run.get_array(
            'SCS_UTC1_ADQ/ADC/1', 
            'board1.apd.channel_{}.pulseStop.value'.format(apdChannels[mcp]))[0].values)
    else:
        if intstart is None or intstop is None:
            raise ValueError('intstart and intstop are required when use_apd is False')
        pulseStart = int(intstart)
        pulseStop = int(intstop)
    nsamples = 0
    if npulses_max > 1:
        # number of bunch-pattern ids between two SASE3 pulses (step)
        sa3 = data['sase3'].where(data['sase3']>1)
        step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
        step = int(step[1] - step[0])
        # assumes 440 digitizer samples per bunch-pattern id
        nsamples = 440 * step

    fig, ax = plt.subplots(figsize=(5,3))
    ax.plot(trace[:pulseStop+25], color='C1', label='first pulse')
    ax.axvspan(pulseStart, pulseStop, color='k', alpha=0.1, label='APD region')
    ax.axvline(pulseStart, color='gray', ls='--')
    ax.axvline(pulseStop, color='gray', ls='--')
    ax.set_xlim(pulseStart-25, pulseStop+25)
    ax.locator_params(axis='x', nbins=4)
    ax.set_ylabel('MCP{} Voltage [V]'.format(mcp))
    ax.set_xlabel('First pulse sample #')
    if npulses_max > 1:
        # shift the window to the last pulse of the train
        pulseStart = pulseStart + nsamples*(npulses_max-1)
        pulseStop = pulseStop + nsamples*(npulses_max-1)
        ax2 = ax.twiny()
        ax2.plot(range(pulseStart-25,pulseStop+25), trace[pulseStart-25:pulseStop+25],
                color='C4', label='last pulse')
        ax2.locator_params(axis='x', nbins=4)
        ax2.set_xlabel('Last pulse sample #')
        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc=0)
    else:
        ax.legend(loc='lower left')
    plt.tight_layout()
    
def matchXgmTimPulseId(data, use_apd=True, intstart=None, intstop=None,
                       bkgstart=None, bkgstop=None, t_offset=None, 
                       npulses=None, sase3First=True, stride=1):
    ''' Function to match XGM pulse Id with TIM pulse Id.
        Inputs:
            data: xarray Dataset containing XGM and TIM data
            use_apd: bool. If True, uses the digitizer APD ('MCP[1,2,3,4]apd').
                     If False, peak integration is performed from raw traces.
                     All following parameters are needed in this case.
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            t_offset: index separation between two pulses
            npulses: number of pulses to compute. Required if no bunch
                pattern info is available
            sase3First: bool, needed if bunch pattern is missing.
            stride: int, used to select pulses in the TIM APD array if
                no bunch pattern info is available.

        Output:
            xr DataSet containing XGM and TIM signals with the shared
            dimension 'sa3_pId'. Raw traces, raw XGM and raw APD are dropped.
    '''
    dropList = []
    mergeList = []
    ndata = cleanXGMdata(data, npulses, sase3First)
    for mcp in range(1,5):
        rawKey = 'MCP{}raw'.format(mcp)
        apdKey = 'MCP{}apd'.format(mcp)
        if apdKey not in data and rawKey not in data:
            continue
        MCPapd = getTIMapd(ndata, mcp=mcp, use_apd=use_apd, intstart=intstart,
                           intstop=intstop, bkgstart=bkgstart, bkgstop=bkgstop,
                           t_offset=t_offset, npulses=npulses,
                           stride=stride).rename(apdKey)
        # The pulse dimension is named differently depending on the source;
        # align both on 'sa3_pId' so they merge with the XGM data.
        if use_apd:
            MCPapd = MCPapd.rename({'apdId':'sa3_pId'})
        else:
            MCPapd = MCPapd.rename({'MCP{}fromRaw'.format(mcp):'sa3_pId'})
        mergeList.append(MCPapd)
        # Drop the raw trace and the original APD variable (which would clash
        # with the renamed result); only drop names actually present in ndata.
        dropList.extend(k for k in (rawKey, apdKey) if k in ndata)
    mergeList.append(ndata.drop(dropList))
    subset = xr.merge(mergeList, join='inner')
    for k in ndata.attrs.keys():
        subset.attrs[k] = ndata.attrs[k]
    return subset
# Fast ADC
def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, period=None, npulses=None):
    ''' Computes peak integration from raw FastADC traces.
    
        Inputs:
            data: xarray Dataset containing FastADC raw traces (e.g. 'FastADC1raw')
            channel: FastADC channel number
            intstart: trace index of integration start
            intstop: trace index of integration stop
            bkgstart: trace index of background start
            bkgstop: trace index of background stop
            period: number of samples between two pulses. Needed if bunch
                pattern info is not available. If None, checks the pulse
                pattern and determine the period assuming a resolution of
                9.23 ns per sample which leads to 24 samples between
                two bunches @ 4.5 MHz. 
            npulses: number of pulses. If None, takes the maximum number of
                pulses according to the bunch patter (field 'npulses_sase3')
            
        Output:
            results: DataArray with dims trainId x max(sase3 pulses) 
            
    '''
    keyraw = 'FastADC{}raw'.format(channel)
    if keyraw not in data:
        raise ValueError("Source not found: {}!".format(keyraw))
    if npulses is None:
        npulses = int(data['npulses_sase3'].max().values)
    if period is None:
        sa3 = data['sase3'].where(data['sase3']>1)
        if npulses > 1:
            #Calculate the number of pulses between two lasing pulses (step)
            step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
            step = int(step[1] - step[0])
            #multiply by elementary pulse length (221.5 ns / 9.23 ns = 24 samples)