# -*- coding: utf-8 -*-
""" Toolbox for SCS.

    Various utility functions to quickly process data measured at the SCS instruments.

    Copyright (2019) SCS Team.
"""
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy.signal import find_peaks


def cleanXGMdata(data, npulses=None, sase3First=True):
    ''' Cleans the XGM data arrays obtained from load() function.
        The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0
        when there is no pulse. This function sorts the SASE 1 and SASE 3 pulses.
        For DAQ runs after April 2019, sase-resolved arrays can be used. For older runs,
        the function selectSASEinXGM is used to extract sase-resolved pulses.
        Inputs:
            data: xarray Dataset containing XGM TD arrays.
            npulses: number of pulses, needed if pulse pattern not available.
            sase3First: bool, needed if pulse pattern not available.
        
        Output:
            xarray Dataset containing sase- and pulse-resolved XGM data, with
                dimension names 'sa1_pId' and 'sa3_pId'                
    '''
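    # Example usage (illustrative sketch; assumes `data` is a Dataset returned by
    # the toolbox load() function, containing XGM TD arrays and bunch pattern info):
    #     data = cleanXGMdata(data)
    #     data['SCS_SA3']   # pulse-resolved SASE 3 data with dimension 'sa3_pId'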
    dropList = []
    mergeList = []
    if 'sase3' in data:
        if np.all(data['npulses_sase1'].where(data['npulses_sase3'] !=0,
                                              drop=True) == 0):
            print('Dedicated trains, skip loading SASE 1 data.')
    else:
        print('Missing bunch pattern info!')
        if npulses is None:
            raise TypeError('npulses argument is required when bunch pattern ' +
                             'info is missing.')
    #pulse-resolved signals from XGMs
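    # Candidate arrays to process: raw 'XGM' TD arrays, sase-resolved 'SA1'/'SA3'
    # arrays and their '_sigma' counterparts, for both the XTD10 and SCS XGMs.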
    keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1", 
            "XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma",
            "SCS_XGM", "SCS_SA3", "SCS_SA1",
            "SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"]
    
    for whichXgm in ['SCS', 'XTD10']:
        load_sa1 = True
        if (f"{whichXgm}_SA3" not in data and f"{whichXgm}_XGM" in data):
            #no SASE-resolved arrays available
            if 'sase3' not in data:
                npulses_xgm = data[f'{whichXgm}_XGM'].where(data[f'{whichXgm}_XGM'], drop=True).shape[1]
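                # The number of SASE 1 pulses is inferred as the total number of
                # pulses seen by the XGM minus the requested number of SASE 3 pulses.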
                npulses_sa1 = npulses_xgm - npulses
                if npulses_sa1==0:
                    load_sa1 = False
                if npulses_sa1 < 0:
                    raise ValueError(f'npulses = {npulses} is larger than the total number'
                                     +f' of pulses per train = {npulses_xgm}')
            sa3 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM', sase='sase3', npulses=npulses,
                   sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename(f"{whichXgm}_SA3")
            mergeList.append(sa3)
            if f"{whichXgm}_XGM_sigma" in data:
                sa3_sigma = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM_sigma', sase='sase3', npulses=npulses,
                       sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename(f"{whichXgm}_SA3_sigma")
                mergeList.append(sa3_sigma)
                dropList.append(f'{whichXgm}_XGM_sigma')
            if load_sa1:
                sa1 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM', sase='sase1',
                                      npulses=npulses_sa1, sase3First=sase3First).rename(
                                      {'XGMbunchId':'sa1_pId'}).rename(f"{whichXgm}_SA1")
                mergeList.append(sa1)
                if f"{whichXgm}_XGM_sigma" in data:
                    sa1_sigma = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM_sigma', sase='sase1', npulses=npulses_sa1,
                           sase3First=sase3First).rename({'XGMbunchId':'sa1_pId'}).rename(f"{whichXgm}_SA1_sigma")
                    mergeList.append(sa1_sigma)
            dropList.append(f'{whichXgm}_XGM')
            keys.remove(f'{whichXgm}_XGM')
        
    for key in keys:
        if key not in data:
            continue
        if "sa3" in key.lower():
            sase = 'sa3_'
        elif "sa1" in key.lower():
            sase = 'sa1_'
            if not load_sa1:
                dropList.append(key)
                continue
        else:
            dropList.append(key)
            continue
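        # Drop the 1.0 filler values (slots with no pulse), rename the generic
        # 'XGMbunchId' dimension to the sase-specific pulse id and re-index from 0.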
        res = data[key].where(data[key] != 1.0, drop=True).rename(
                {'XGMbunchId':'{}pId'.format(sase)}).rename(key)
        res = res.assign_coords(
                {f'{sase}pId':np.arange(res[f'{sase}pId'].shape[0])})
        
        dropList.append(key)
        mergeList.append(res)
    mergeList.append(data.drop(dropList))
    subset = xr.merge(mergeList, join='inner')
    for k in data.attrs.keys():
        subset.attrs[k] = data.attrs[k]
    return subset


def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=None):
    ''' Given an array containing both SASE1 and SASE3 data, extracts SASE1-
        or SASE3-only XGM data. The function tracks the changes of bunch patterns
        in sase 1 and sase 3 and applies a mask to the XGM array to extract the 
        relevant pulses. This way, all complicated patterns are accounted for.
        
        Inputs:
            data: xarray Dataset containing xgm data
            sase: key of sase to select: {'sase1', 'sase3'}
            xgm: key of xgm to select: {'XTD10_XGM[_sigma]', 'SCS_XGM[_sigma]'}
            sase3First: bool, optional. Used in case no bunch pattern was recorded
            npulses: int, optional. Required in case no bunch pattern was recorded.
            
        Output:
            DataArray that has all trainIds that contain a lasing
            train in sase, with dimension equal to the maximum number of pulses of 
            that sase in the run. The missing values, in case of change of number of pulses,
            are filled with NaNs.
    '''
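    # Example (illustrative sketch; assumes the Dataset holds an 'SCS_XGM' TD array
    # and, if recorded, the 'sase1'/'sase3' bunch pattern arrays):
    #     sa3_data = selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM')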
    #1. case where bunch pattern is missing:
    if sase not in data:
        print('Retrieving {} SASE {} pulses assuming that '.format(npulses, sase[4])
              +'SASE {} pulses come first.'.format('3' if sase3First else '1'))
        #in older version of DAQ, non-data numbers were filled with 0.0.
        xgmData = data[xgm].where(data[xgm]!=0.0, drop=True)
        xgmData = xgmData.fillna(0.0).where(xgmData!=1.0, drop=True)
        if (sase3First and sase=='sase3') or (not sase3First and sase=='sase1'):
            return xgmData[:,:npulses].assign_coords(XGMbunchId=np.arange(npulses))
        else:
            if xr.ufuncs.isnan(xgmData).any():
                raise Exception('The number of pulses changed during the run. '
                      'This is not supported yet.')
            else:
                start=xgmData.shape[1]-npulses
                return xgmData[:,start:start+npulses].assign_coords(XGMbunchId=np.arange(npulses))
    
    #2. case where bunch pattern is provided
    #2.1 Merge sase1 and sase3 bunch patterns to get indices of all pulses
    xgm_arr = data[xgm].where(data[xgm] != 1., drop=True)
    sa3 = data['sase3'].where(data['sase3'] > 1, drop=True)
    sa3_val=np.unique(sa3)
    sa3_val = sa3_val[~np.isnan(sa3_val)]
    sa1 = data['sase1'].where(data['sase1'] > 1, drop=True)
    sa1_val=np.unique(sa1)
    sa1_val = sa1_val[~np.isnan(sa1_val)]
    sa_all = xr.concat([sa1, sa3], dim='bunchId').rename('sa_all')
    sa_all = xr.DataArray(np.sort(sa_all)[:,:xgm_arr['XGMbunchId'].shape[0]],
                          dims=['trainId', 'bunchId'],
                          coords={'trainId':data.trainId},
                          name='sase_all')
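    # Sorting the concatenated pulse ids restores the machine ordering of the merged
    # SASE 1 / SASE 3 pattern; truncation matches the length of the XGM array.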
    if sase=='sase3':
        idxListSase = np.unique(sa3)
        newName = xgm.split('_')[0] + '_SA3'
    else:
        idxListSase = np.unique(sa1)
        newName = xgm.split('_')[0] + '_SA1'
        
    #2.2 track the changes of pulse patterns and the indices at which they occurred (invAll)
    idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True)
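    # idxListAll contains the distinct pulse patterns found in the run; invAll maps
    # each train to the index of its pattern in idxListAll.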
    
    #2.3 define a mask, loop over the different patterns and update the mask for each pattern
    mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), 
                    dims=['trainId', 'XGMbunchId'],
                    coords={'trainId':data.trainId, 
                            'XGMbunchId':sa_all['bunchId'].values}, 
                    name='mask')

    big_sase = []
    for i,idxXGM in enumerate(idxListAll):
        mask.values = np.zeros(mask.shape, dtype=bool)
        idxXGM = np.isin(idxXGM, idxListSase)
        idxTid = invAll==i
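        # Trains showing pattern i get True at the XGM slots belonging to the requested SASE.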
        mask[idxTid, idxXGM] = True
        sa_arr = xgm_arr.where(mask, drop=True)
        if sa_arr.trainId.size > 0:
            sa_arr = sa_arr.assign_coords(XGMbunchId=np.arange(sa_arr.XGMbunchId.size))
            big_sase.append(sa_arr)
    if len(big_sase) > 0:
        da_sase = xr.concat(big_sase, dim='trainId').rename(newName)
    else:
        da_sase = xr.DataArray([], dims=['trainId'], name=newName)
    return da_sase


def saseContribution(data, sase='sase1', xgm='XTD10_XGM'):
    ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses 
        for each train in the run. Supports fresh bunch, dedicated trains
        and pulse on demand modes.
        
        Inputs:
            data: xarray Dataset containing xgm data
            sase: key of sase for which the contribution is computed: {'sase1', 'sase3'}
            xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'}