diff --git a/setup.py b/setup.py
index 6d8a692f7ed61795a4aabb19139fc194dde9a14a..5bd5110aef41384518e56f2673a0a2ac28a2b927 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ setup(name='toolbox_scs',
       packages=find_packages('src'),
       package_data={},
       install_requires=[
-        'xarray', 'numpy', 'matplotlib',
+        'xarray>=0.13.0', 'numpy', 'matplotlib',
         'pandas', 'scipy', 'h5py', 'extra_data',
         'euxfel_bunch_pattern',
       ],
diff --git a/src/toolbox_scs/detectors/__init__.py b/src/toolbox_scs/detectors/__init__.py
index e0f177e30f233c246ccadef14ec67b9342187ff9..22a9bc6c8ded4e1758c279fce50ea7cb983234f7 100644
--- a/src/toolbox_scs/detectors/__init__.py
+++ b/src/toolbox_scs/detectors/__init__.py
@@ -1,11 +1,13 @@
-from .detectors import (
-    DSSC, process_one_module,
-    )
+from .xgm import (
+    load_xgm, cleanXGMdata, matchXgmTimPulseId)
+from .tim import (
+    load_TIM,)
 
 __all__ = (
     # Functions
-    "process_one_module",
+    "load_xgm",
+    "load_TIM",
+    "matchXgmTimPulseId",
     # Classes
-    "DSSC",
     # Variables
 )
diff --git a/src/toolbox_scs/detectors/dssc_process.py b/src/toolbox_scs/detectors/dssc_process.py
index 4c0ef474d9d843a3b0ae3274b95241e73f7c88a3..4f3c56a5712740dc55e9f3741bc83e8491a31e0a 100644
--- a/src/toolbox_scs/detectors/dssc_process.py
+++ b/src/toolbox_scs/detectors/dssc_process.py
@@ -22,42 +22,6 @@ import pandas as pd
 import extra_data as ed
 
 
-def load_xgm(run, print_info=False):
-    '''Returns the XGM data from loaded karabo_data.DataCollection'''
-    nbunches = run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER', 'sase3.nPulses.value')
-    nbunches = np.unique(nbunches)
-    if len(nbunches) == 1:
-        nbunches = nbunches[0]
-    else:
-        warnings.warn('not all trains have same length DSSC data')
-        print('nbunches: ', nbunches)
-        nbunches = max(nbunches)
-    if print_info:
-        print('SASE3 bunches per train:', nbunches)
-
-    xgm = run.get_array('SCS_BLU_XGM/XGM/DOOCS:output', 'data.intensitySa3TD',
-                        roi=kd.by_index[:nbunches], extra_dims=['pulse'])
-    return xgm
-
-
-def load_TIM(run, apd='MCP2apd'):
-    '''
-    Load TIM traces and match them to SASE3 pulses. "run" is a karabo_data.RunDirectory instance.
-    Uses SCS ToolBox.
-    '''
-    import ToolBox as tb
-
-    fields = ["sase1", "sase3", "npulses_sase3", "npulses_sase1", apd, "SCS_SA3", "nrj"]
-    timdata = xr.Dataset()
-    for f in fields:
-        m = tb.mnemonics[f]
-        timdata[f] = run.get_array(m['source'], m['key'], extra_dims=m['dim'])
-
-    timdata.attrs['run'] = run
-    timdata = tb.matchXgmTimPulseId(timdata)
-    return timdata.rename({'sa3_pId': 'pulse'})[apd]
-
-
 def prepare_module_empty(scan_variable, framepattern):
     '''Create empty (zero-valued) DataArray for a single DSSC module to iteratively add data to'''
     len_scan = len(np.unique(scan_variable))
diff --git a/src/toolbox_scs/detectors/tim.py b/src/toolbox_scs/detectors/tim.py
new file mode 100644
index 0000000000000000000000000000000000000000..71aa54019b0019510bf1e7ff6ccba26f5ce5abab
--- /dev/null
+++ b/src/toolbox_scs/detectors/tim.py
@@ -0,0 +1,54 @@
+""" TIM related sub-routines
+
+    Copyright (2019) SCS Team.
+    contributions preferably comply with pep8 code structure
+    guidelines.
+"""
+
+import logging
+
+import xarray as xr
+
+from .xgm import matchXgmTimPulseId
+from ..constants import mnemonics as _mnemonics_tim
+
+
+log = logging.getLogger(__name__)
+
+
+def load_TIM(run, apd='MCP2apd'):
+    """
+    Load TIM traces and match them to SASE3 pulses.
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection, extra_data.RunDirectory
+
+    Returns
+    -------
+    timdata : xarray.DataArray
+        xarray DataArray containing the TIM data
+
+    Example
+    -------
+    >>> import toolbox_scs as tb
+    >>> import toolbox_scs.detectors as tbdet
+    >>> run = tb.run_by_proposal(2212, 235)
+    >>> data = tbdet.load_TIM(run)
+    """
+
+    fields = ["sase1", "sase3", "npulses_sase3",
+              "npulses_sase1", apd, "SCS_SA3", "nrj"]
+    timdata = xr.Dataset()
+
+    for f in fields:
+        m = _mnemonics_tim[f]
+        timdata[f] = run.get_array(m['source'],
+                                   m['key'],
+                                   extra_dims=m['dim'])
+
+    timdata.attrs['run'] = run
+    timdata = matchXgmTimPulseId(timdata)
+
+    return timdata.rename({'sa3_pId': 'pulse'})[apd]
\ No newline at end of file
diff --git a/src/toolbox_scs/detectors/xgm.py b/src/toolbox_scs/detectors/xgm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d71e0fcb10c2303d54447d1457cc0e858d6049d
--- /dev/null
+++ b/src/toolbox_scs/detectors/xgm.py
@@ -0,0 +1,1090 @@
+""" XGM related sub-routines
+
+    Copyright (2019) SCS Team.
+
+    (contributions preferably comply with pep8 code structure
+    guidelines.)
+"""
+
+import logging
+
+import numpy as np
+import xarray as xr
+import matplotlib.pyplot as plt
+from scipy.signal import find_peaks
+import extra_data as ed
+
+
+log = logging.getLogger(__name__)
+
+
+def load_xgm(run):
+    """
+    Loads XGM data from an extra_data.DataCollection
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+
+    Returns
+    -------
+    xgm : xarray.DataArray
+        xarray DataArray containing the XGM data
+
+    Example
+    -------
+    >>> import toolbox_scs as tb
+    >>> import toolbox_scs.detectors as tbdet
+    >>> run = tb.run_by_proposal(2212, 235)
+    >>> xgm_data = tbdet.load_xgm(run)
+    """
+    nbunches = run.get_array('SCS_RR_UTC/MDL/BUNCH_DECODER',
+                             'sase3.nPulses.value')
+    nbunches = np.unique(nbunches)
+    if len(nbunches) == 1:
+        nbunches = nbunches[0]
+    else:
+        log.warning('not all trains have same length DSSC data, '
+                    f'nbunches: {nbunches}')
+        nbunches = max(nbunches)
+
+    log.info(f'SASE3 bunches per train: {nbunches}')
+    xgm = run.get_array('SCS_BLU_XGM/XGM/DOOCS:output',
+                        'data.intensitySa3TD',
+                        roi=ed.by_index[:nbunches],
+                        extra_dims=['pulse'])
+    return xgm
+
+
+def cleanXGMdata(data, npulses=None, sase3First=True):
+    ''' Cleans the XGM data arrays obtained from the load() function.
+    The XGM "TD" data arrays have an arbitrary size of 1000 and default
+    value 1.0 when there is no pulse. This function sorts the SASE 1 and
+    SASE 3 pulses. For DAQ runs after April 2019, sase-resolved arrays can
+    be used. For older runs, the function selectSASEinXGM is used to
+    extract sase-resolved pulses.
+    Inputs:
+        data: xarray Dataset containing XGM TD arrays.
+        npulses: number of pulses, needed if pulse pattern not available.
+        sase3First: bool, needed if pulse pattern not available.
+ + Output: + xarray Dataset containing sase- and pulse-resolved XGM data, with + dimension names 'sa1_pId' and 'sa3_pId' + ''' + dropList = [] + mergeList = [] + load_sa1 = True + if 'sase3' in data: + if np.all(data['npulses_sase1'].where(data['npulses_sase3'] !=0, + drop=True) == 0): + print('Dedicated trains, skip loading SASE 1 data.') + load_sa1 = False + npulses_sa1 = None + else: + print('Missing bunch pattern info!') + if npulses is None: + raise TypeError('npulses argument is required when bunch pattern ' + + 'info is missing.') + #pulse-resolved signals from XGMs + keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1", + "XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma", + "SCS_XGM", "SCS_SA3", "SCS_SA1", + "SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"] + + for whichXgm in ['SCS', 'XTD10']: + load_sa1 = True + if (f"{whichXgm}_SA3" not in data and f"{whichXgm}_XGM" in data): + #no SASE-resolved arrays available + if not 'sase3' in data: + npulses_xgm = data[f'{whichXgm}_XGM'].where(data[f'{whichXgm}_XGM'], drop=True).shape[1] + npulses_sa1 = npulses_xgm - npulses + if npulses_sa1==0: + load_sa1 = False + if npulses_sa1 < 0: + raise ValueError(f'npulses = {npulses} is larger than the total number' + +f' of pulses per train = {npulses_xgm}') + sa3 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM', sase='sase3', npulses=npulses, + sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename(f"{whichXgm}_SA3") + mergeList.append(sa3) + if f"{whichXgm}_XGM_sigma" in data: + sa3_sigma = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM_sigma', sase='sase3', npulses=npulses, + sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'}).rename(f"{whichXgm}_SA3_sigma") + mergeList.append(sa3_sigma) + dropList.append(f'{whichXgm}_XGM_sigma') + if load_sa1: + sa1 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM', sase='sase1', + npulses=npulses_sa1, sase3First=sase3First).rename( + {'XGMbunchId':'sa1_pId'}).rename(f"{whichXgm}_SA1") + mergeList.append(sa1) + if f"{whichXgm}_XGM_sigma" in data: + sa1_sigma = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM_sigma', sase='sase1', npulses=npulses_sa1, + sase3First=sase3First).rename({'XGMbunchId':'sa1_pId'}).rename(f"{whichXgm}_SA1_sigma") + mergeList.append(sa1_sigma) + dropList.append(f'{whichXgm}_XGM') + keys.remove(f'{whichXgm}_XGM') + + for key in keys: + if key not in data: + continue + if "sa3" in key.lower(): + sase = 'sa3_' + elif "sa1" in key.lower(): + sase = 'sa1_' + if not load_sa1: + dropList.append(key) + continue + else: + dropList.append(key) + continue + res = data[key].where(data[key] != 1.0, drop=True).rename( + {'XGMbunchId':'{}pId'.format(sase)}).rename(key) + res = res.assign_coords( + {f'{sase}pId':np.arange(res[f'{sase}pId'].shape[0])}) + + dropList.append(key) + mergeList.append(res) + mergeList.append(data.drop(dropList)) + subset = xr.merge(mergeList, join='inner') + for k in data.attrs.keys(): + subset.attrs[k] = data.attrs[k] + return subset + + +def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=None): + ''' Given an array containing both SASE1 and SASE3 data, extracts SASE1- + or SASE3-only XGM data. The function tracks the changes of bunch patterns + in sase 1 and sase 3 and applies a mask to the XGM array to extract the + relevant pulses. This way, all complicated patterns are accounted for. + + Inputs: + data: xarray Dataset containing xgm data + sase: key of sase to select: {'sase1', 'sase3'} + xgm: key of xgm to select: {'XTD10_XGM[_sigma]', 'SCS_XGM[_sigma]'} + sase3First: bool, optional. 
Used in case no bunch pattern was recorded + npulses: int, optional. Required in case no bunch pattern was recorded. + + Output: + DataArray that has all trainIds that contain a lasing + train in sase, with dimension equal to the maximum number of pulses of + that sase in the run. The missing values, in case of change of number of pulses, + are filled with NaNs. + ''' + #1. case where bunch pattern is missing: + if sase not in data: + print('Retrieving {} SASE {} pulses assuming that '.format(npulses, sase[4]) + +'SASE {} pulses come first.'.format('3' if sase3First else '1')) + #in older version of DAQ, non-data numbers were filled with 0.0. + xgmData = data[xgm].where(data[xgm]!=0.0, drop=True) + xgmData = xgmData.fillna(0.0).where(xgmData!=1.0, drop=True) + if (sase3First and sase=='sase3') or (not sase3First and sase=='sase1'): + return xgmData[:,:npulses].assign_coords(XGMbunchId=np.arange(npulses)) + else: + if xr.ufuncs.isnan(xgmData).any(): + raise Exception('The number of pulses changed during the run. ' + 'This is not supported yet.') + else: + start=xgmData.shape[1]-npulses + return xgmData[:,start:start+npulses].assign_coords(XGMbunchId=np.arange(npulses)) + + #2. case where bunch pattern is provided + #2.1 Merge sase1 and sase3 bunch patterns to get indices of all pulses + xgm_arr = data[xgm].where(data[xgm] != 1., drop=True) + sa3 = data['sase3'].where(data['sase3'] > 1, drop=True) + sa3_val=np.unique(sa3) + sa3_val = sa3_val[~np.isnan(sa3_val)] + sa1 = data['sase1'].where(data['sase1'] > 1, drop=True) + sa1_val=np.unique(sa1) + sa1_val = sa1_val[~np.isnan(sa1_val)] + sa_all = xr.concat([sa1, sa3], dim='bunchId').rename('sa_all') + sa_all = xr.DataArray(np.sort(sa_all)[:,:xgm_arr['XGMbunchId'].shape[0]], + dims=['trainId', 'bunchId'], + coords={'trainId':data.trainId}, + name='sase_all') + if sase=='sase3': + idxListSase = np.unique(sa3) + newName = xgm.split('_')[0] + '_SA3' + else: + idxListSase = np.unique(sa1) + newName = xgm.split('_')[0] + '_SA1' + + #2.2 track the changes of pulse patterns and the indices at which they occured (invAll) + idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True) + + #2.3 define a mask, loop over the different patterns and update the mask for each pattern + mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), + dims=['trainId', 'XGMbunchId'], + coords={'trainId':data.trainId, + 'XGMbunchId':sa_all['bunchId'].values}, + name='mask') + + big_sase = [] + for i,idxXGM in enumerate(idxListAll): + mask.values = np.zeros(mask.shape, dtype=bool) + idxXGM = np.isin(idxXGM, idxListSase) + idxTid = invAll==i + mask[idxTid, idxXGM] = True + sa_arr = xgm_arr.where(mask, drop=True) + if sa_arr.trainId.size > 0: + sa_arr = sa_arr.assign_coords(XGMbunchId=np.arange(sa_arr.XGMbunchId.size)) + big_sase.append(sa_arr) + if len(big_sase) > 0: + da_sase = xr.concat(big_sase, dim='trainId').rename(newName) + else: + da_sase = xr.DataArray([], dims=['trainId'], name=newName) + return da_sase + +def saseContribution(data, sase='sase1', xgm='XTD10_XGM'): + ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses + for each train in the run. Supports fresh bunch, dedicated trains + and pulse on demand modes. 
+ + Inputs: + data: xarray Dataset containing xgm data + sase: key of sase for which the contribution is computed: {'sase1', 'sase3'} + xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'} + + Output: + 1D DataArray equal to sum(sase)/sum(sase1+sase3) + + ''' + xgm_sa1 = selectSASEinXGM(data, 'sase1', xgm=xgm) + xgm_sa3 = selectSASEinXGM(data, 'sase3', xgm=xgm) + #Fill missing train ids with 0 + r = xr.align(*[xgm_sa1, xgm_sa3], join='outer', exclude=['XGMbunchId']) + xgm_sa1 = r[0].fillna(0) + xgm_sa3 = r[1].fillna(0) + + contrib = xgm_sa1.sum(axis=1)/(xgm_sa1.sum(axis=1) + xgm_sa3.sum(axis=1)) + if sase=='sase1': + return contrib + else: + return 1 - contrib + +def calibrateXGMs(data, allPulses=False, plot=False, display=False): + ''' Calibrate the fast (pulse-resolved) signals of the XTD10 and SCS XGM + (read in intensityTD property) to the respective slow ion signal + (photocurrent read by Keithley, key 'pulseEnergy.photonFlux.value'). + If the sase-resolved averaged signals ("slowTrain", introduced in May + 2019) are recorded, the calibration is defined as the mean ratio + between the photonFlux and the slowTrain signal. Otherwise, the + averaged fast signals are computed using a rolling average. + + Inputs: + data: xarray Dataset + allPulses: if True, uses "XTD10_XGM" and "SCS_XGM" arrays and + computes the relative contributions of SASE1 and SASE3 to + photonFlux. This should be more accurate in cases where the + number of SASE1 pulses is large and/or the SASE1 pulse + intensity cannot be neglected. + plot: bool, plot the calibration output + display: bool, displays info if True + + Output: + ndarray with [XTD10 calibration factor, SCS calibration factor] + ''' + if allPulses: + return calibrateXGMsFromAllPulses(data, plot) + hasSlowTrain=[True,True] + results = np.array([np.nan, np.nan], dtype=float) + slowTrainData = [] + for i,whichXgm in enumerate(['XTD10', 'SCS']): + #1. Try to load fast data averages (in DAQ since May 2019) + if f'{whichXgm}_slowTrain' in data: + if display: + print(f'Using fast data averages (slowTrain) for {whichXgm}') + slowTrainData.append(data[f'{whichXgm}_slowTrain']) + else: + mnemo = tb.mnemonics[f'{whichXgm}_slowTrain'] + if mnemo['key'] in data.attrs['run'].keys_for_source(mnemo['source']): + if display: + print(f'Using fast data averages (slowTrain) for {whichXgm}') + slowTrainData.append(data.attrs['run'].get_array(mnemo['source'], mnemo['key'])) + #2. 
Calculate fast data average from fast data + else: + if display: + print(f'No averages of fast data (slowTrain) available for {whichXgm}.'+ + ' Attempting calibration from fast data.') + if f'{whichXgm}_SA3' in data: + if display: + print(f'Calculating slowTrain from SA3 for {whichXgm}') + slowTrainData.append(data[f'{whichXgm}_SA3'].rolling(trainId=200 + ).mean().mean(axis=1)) + elif f'{whichXgm}_XGM' in data: + sa3 = selectSASEinXGM(data, xgm=f'{whichXgm}_XGM') + slowTrainData.append(sa3.rolling(trainId=200).mean().mean(axis=1)) + else: + hasSlowTrain[i]=False + if hasSlowTrain[i]: + results[i] = np.mean(data[f'{whichXgm}_photonFlux']/slowTrainData[i]) + if display: + print(f'Calibration factor {whichXgm} XGM: {results[i]}') + if plot: + plt.figure(figsize=(8,4)) + plt.subplot(211) + plt.plot(data['XTD10_photonFlux'], label='XTD10 photon flux') + plt.plot(slowTrainData[0]*results[0], label='calibrated XTD10 fast signal') + plt.ylabel(r'Energy [$\mu$J]') + plt.legend(fontsize=8, loc='upper left') + plt.twinx() + plt.plot(slowTrainData[0], label='uncalibrated XTD10 fast signal', color='C4') + plt.ylabel(r'Uncalibrated energy') + plt.legend(fontsize=8, loc='upper right') + plt.subplot(212) + plt.plot(data['SCS_photonFlux'], label='SCS photon flux') + plt.plot(slowTrainData[1]*results[1], label='calibrated SCS fast signal') + plt.ylabel(r'Energy [$\mu$J]') + plt.xlabel('train Id') + plt.legend(fontsize=8, loc='upper left') + plt.twinx() + plt.plot(slowTrainData[1], label='uncalibrated SCS fast signal', color='C4') + plt.ylabel(r'Uncalibrated energy') + plt.legend(fontsize=8, loc='upper right') + plt.tight_layout() + return results + +def calibrateXGMsFromAllPulses(data, plot=False): + ''' Calibrate the fast (pulse-resolved) signals of the XTD10 and SCS XGM + (read in intensityTD property) to the respective slow ion signal + (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value'). + One has to take into account the possible signal created by SASE1 pulses. In the + tunnel, this signal is usually large enough to be read by the XGM and the relative + contribution C of SASE3 pulses to the overall signal is computed. + In the tunnel, the calibration F is defined as: + F = E_slow / E_fast_avg, where + E_fast_avg is the rolling average (with window rollingWindow) of the fast signal. + In SCS XGM, the signal from SASE1 is usually in the noise, so we calculate the + average over the pulse-resolved signal of SASE3 pulses only and calibrate it to the + slow signal modulated by the SASE3 contribution: + F = (N1+N3) * E_avg * C/(N3 * E_fast_avg_sase3), where N1 and N3 are the number + of pulses in SASE1 and SASE3, E_fast_avg_sase3 is the rolling average (with window + rollingWindow) of the SASE3-only fast signal. + + Inputs: + data: xarray Dataset + rollingWindow: length of running average to calculate E_fast_avg + plot: boolean, plot the calibration output + + Output: + factors: numpy ndarray of shape 1 x 2 containing + [XTD10 calibration factor, SCS calibration factor] + ''' + XTD10_factor = np.nan + SCS_factor = np.nan + noSCS = noXTD10 = False + if 'SCS_XGM' not in data: + print('no SCS XGM data. Skipping calibration for SCS XGM') + noSCS = True + if 'XTD10_XGM' not in data: + print('no XTD10 XGM data. 
Skipping calibration for XTD10 XGM') + noXTD10 = True + if noSCS and noXTD10: + return np.array([XTD10_factor, SCS_factor]) + if not noSCS and noXTD10: + print('XTD10 data is needed to calibrate SCS XGM.') + return np.array([XTD10_factor, SCS_factor]) + start = 0 + stop = None + npulses = data['npulses_sase3'] + ntrains = npulses.shape[0] + rollingWindow=200 + # First, in case of change in number of pulses, locate a region where + # the number of pulses is maximum. + if not np.all(npulses == npulses[0]): + print('Warning: Number of pulses per train changed during the run!') + start = np.argmax(npulses.values) + stop = ntrains + np.argmax(npulses.values[::-1]) - 1 + if stop - start < rollingWindow: + print('not enough consecutive data points with the largest number of pulses per train') + start += rollingWindow + stop = np.min((ntrains, stop+rollingWindow)) + + # Calculate SASE3 slow data + sa3contrib = saseContribution(data, 'sase3', 'XTD10_XGM') + SA3_SLOW = data['XTD10_photonFlux']*(data['npulses_sase3']+data['npulses_sase1'])*sa3contrib/data['npulses_sase3'] + SA1_SLOW = data['XTD10_photonFlux']*(data['npulses_sase3']+data['npulses_sase1'])*(1-sa3contrib)/data['npulses_sase1'] + + # Calibrate XTD10 XGM with all signal from SASE1 and SASE3 + if not noXTD10: + xgm_avg = selectSASEinXGM(data, 'sase3', 'XTD10_XGM').mean(axis=1) + rolling_sa3_xgm = xgm_avg.rolling(trainId=rollingWindow).mean() + ratio = SA3_SLOW/rolling_sa3_xgm + XTD10_factor = ratio[start:stop].mean().values + print('calibration factor XTD10 XGM: %f'%XTD10_factor) + + # Calibrate SCS XGM with SASE3-only contribution + if not noSCS: + SCS_SLOW = data['SCS_photonFlux']*(data['npulses_sase3']+data['npulses_sase1'])*sa3contrib/data['npulses_sase3'] + scs_sase3_fast = selectSASEinXGM(data, 'sase3', 'SCS_XGM').mean(axis=1) + meanFast = scs_sase3_fast.rolling(trainId=rollingWindow).mean() + ratio = SCS_SLOW/meanFast + SCS_factor = ratio[start:stop].median().values + print('calibration factor SCS XGM: %f'%SCS_factor) + + if plot: + if noSCS ^ noXTD10: + plt.figure(figsize=(8,4)) + else: + plt.figure(figsize=(8,8)) + plt.subplot(211) + plt.title('E[uJ] = %.2f x IntensityTD' %(XTD10_factor)) + plt.plot(SA3_SLOW, label='SA3 slow', color='C1') + plt.plot(rolling_sa3_xgm*XTD10_factor, + label='SA3 fast signal rolling avg', color='C4') + plt.plot(xgm_avg*XTD10_factor, label='SA3 fast signal train avg', alpha=0.2, color='C4') + plt.ylabel('Energy [uJ]') + plt.xlabel('train in run') + plt.legend(loc='upper left', fontsize=10) + plt.twinx() + plt.plot(SA1_SLOW, label='SA1 slow', alpha=0.2, color='C2') + plt.ylabel('SA1 slow signal [uJ]') + plt.legend(loc='lower right', fontsize=10) + + plt.subplot(212) + plt.title('E[uJ] = %.2g x HAMP' %SCS_factor) + plt.plot(SCS_SLOW, label='SCS slow', color='C1') + plt.plot(meanFast*SCS_factor, label='SCS HAMP rolling avg', color='C2') + plt.ylabel('Energy [uJ]') + plt.xlabel('train in run') + plt.plot(scs_sase3_fast*SCS_factor, label='SCS HAMP train avg', alpha=0.2, color='C2') + plt.legend(loc='upper left', fontsize=10) + plt.tight_layout() + + return np.array([XTD10_factor, SCS_factor]) + + +# TIM +def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=1, t_offset=None, npulses=None): + ''' Computes peak integration from raw MCP traces. + + Inputs: + data: xarray Dataset containing MCP raw traces (e.g. 
'MCP1raw') + intstart: trace index of integration start + intstop: trace index of integration stop + bkgstart: trace index of background start + bkgstop: trace index of background stop + mcp: MCP channel number + t_offset: index separation between two pulses. Needed if bunch + pattern info is not available. If None, checks the pulse + pattern and determine the t_offset assuming mininum pulse + separation of 220 ns and digitizer resolution of 2 GHz. + npulses: number of pulses. If None, takes the maximum number of + pulses according to the bunch pattern (field 'npulses_sase3') + + Output: + results: DataArray with dims trainId x max(sase3 pulses) + + ''' + keyraw = 'MCP{}raw'.format(mcp) + if keyraw not in data: + raise ValueError("Source not found: {}!".format(keyraw)) + if npulses is None: + npulses = int(data['npulses_sase3'].max().values) + if t_offset is None: + sa3 = data['sase3'].where(data['sase3']>1) + if npulses > 1: + #Calculate the number of pulses between two lasing pulses (step) + step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values + step = int(step[1] - step[0]) + #multiply by elementary samples length (220 ns @ 2 GHz = 440) + t_offset = 440 * step + else: + t_offset = 1 + results = xr.DataArray(np.zeros((data.trainId.shape[0], npulses)), coords=data[keyraw].coords, + dims=['trainId', 'MCP{}fromRaw'.format(mcp)]) + for i in range(npulses): + a = intstart + t_offset*i + b = intstop + t_offset*i + bkga = bkgstart + t_offset*i + bkgb = bkgstop + t_offset*i + if b > data.dims['samplesId']: + break + bg = np.outer(np.median(data[keyraw][:,bkga:bkgb], axis=1), np.ones(b-a)) + results[:,i] = np.trapz(data[keyraw][:,a:b] - bg, axis=1) + return results + + +def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None, + bkgstart=None, bkgstop=None, t_offset=None, npulses=None, + stride=1): + ''' Extract peak-integrated data from TIM where pulses are from SASE3 only. + If use_apd is False it calculates integration from raw traces. + The missing values, in case of change of number of pulses, are filled + with NaNs. If no bunch pattern info is available, the function assumes + that SASE 3 comes first and that the number of pulses is fixed in both + SASE 1 and 3. + + Inputs: + data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw') + intstart: trace index of integration start + intstop: trace index of integration stop + bkgstart: trace index of background start + bkgstop: trace index of background stop + t_offset: number of ADC samples between two pulses + mcp: MCP channel number + npulses: int, optional. Number of pulses to compute. Required if + no bunch pattern info is available. + stride: int, optional. Used to select pulses in the APD array if + no bunch pattern info is available. + Output: + tim: DataArray of shape trainId only for SASE3 pulses x N + with N=max(number of pulses per train) + ''' + #1. case where no bunch pattern is available: + if 'sase3' not in data: + print('Missing bunch pattern info!\n') + if npulses is None: + raise TypeError('npulses argument is required when bunch pattern ' + + 'info is missing.') + print('Retrieving {} SASE 3 pulses assuming that '.format(npulses) + + 'SASE 3 pulses come first.') + if use_apd: + tim = data[f'MCP{mcp}apd'][:,:npulses:stride] + else: + tim = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp, + t_offset=t_offset, npulses=npulses) + return tim + + #2. 
If bunch pattern available, define a mask that corresponds to the SASE 3 pulses
+    sa3 = data['sase3'].where(data['sase3']>1, drop=True)
+    sa3 -= sa3[0,0]
+    #2.1 case where apd is used:
+    if use_apd:
+        pulseId = 'apdId'
+        pulseIdDim = data.dims['apdId']
+        initialDelay = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.initialDelay.value')[0].values
+        upperLimit = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.upperLimit.value')[0].values
+        #440 = samples between two pulses @4.5 MHz with ADQ412 digitizer:
+        period = int((upperLimit - initialDelay)/440)
+        #display some warnings if apd parameters do not match pulse pattern:
+        period_from_bunch_pattern = int(np.nanmin(np.diff(sa3)))
+        if period > period_from_bunch_pattern:
+            print(f'Warning: apd parameter was set to record 1 pulse out of {period} @ 4.5 MHz '
+                  + f'but XFEL delivered 1 pulse out of {period_from_bunch_pattern}.')
+        maxPulses = data['npulses_sase3'].max().values
+        if period*pulseIdDim < period_from_bunch_pattern*(maxPulses-1):
+            print(f'Warning: Number of pulses and/or rep. rate in apd parameters were set '
+                  + f'too low ({pulseIdDim}) to record the {maxPulses} SASE 3 pulses')
+        peaks = data[f'MCP{mcp}apd']
+
+    #2.2 case where integration is performed on raw trace:
+    else:
+        pulseId = f'MCP{mcp}fromRaw'
+        pulseIdDim = int(np.max(sa3).values) + 1
+        period = int(np.nanmin(np.diff(sa3)))
+        peaks = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp, t_offset=period*440,
+                         npulses=pulseIdDim)
+
+    sa3 = sa3/period
+    #2.3 track the changes of pulse patterns and the indices at which they occurred (invAll)
+    idxList, inv = np.unique(sa3, axis=0, return_inverse=True)
+    mask = xr.DataArray(np.zeros((data.dims['trainId'], pulseIdDim), dtype=bool),
+                        dims=['trainId', pulseId],
+                        coords={'trainId':data.trainId,
+                                pulseId:np.arange(pulseIdDim)})
+    mask = mask.sel(trainId=sa3.trainId)
+    for i,idxApd in enumerate(idxList):
+        idxApd = idxApd[idxApd>=0].astype(int)
+        idxTid = inv==i
+        mask[idxTid, idxApd] = True
+
+    peaks = peaks.where(mask, drop=True)
+    peaks = peaks.assign_coords({pulseId:np.arange(peaks[pulseId].shape[0])})
+    return peaks
+
+
+def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intstart=None,
+                 intstop=None, bkgstart=None, bkgstop=None, t_offset=None, npulses_apd=None):
+    ''' Calibrate TIM signal (peak-integrated signal) to the slow ion signal of SCS_XGM
+    (photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
+    The aim is to find F so that E_tim_peak[uJ] = F x TIM_peak. For this, we want to
+    match the SASE3-only average TIM pulse peak per train (TIM_avg) to the slow XGM
+    signal E_slow.
+    Since E_slow is the average energy per pulse over all SASE1 and SASE3
+    pulses (N1 and N3), we first extract the relative contribution C of the SASE3 pulses
+    by looking at the pulse-resolved signals of the SA3_XGM in the tunnel.
+    There, the signal of SASE1 is usually strong enough to be above noise level.
+    Let TIM_avg be the average of the TIM pulses (SASE3 only).
+    The calibration factor is then defined as: F = E_slow * C * (N1+N3) / ( N3 * TIM_avg ).
+    If N3 changes during the run, we locate the indices for which N3 is maximum and define
+    a window in which to apply the calibration (indices start/stop).
+
+    Warning: the calibration does not include the transmission by the KB mirrors!
+
+    Inputs:
+        data: xarray Dataset
+        rollingWindow: length of running average to calculate TIM_avg
+        mcp: MCP channel
+        plot: boolean.
If True, plot calibration results. + use_apd: boolean. If False, the TIM pulse peaks are extract from raw traces using + getTIMapd + intstart: trace index of integration start + intstop: trace index of integration stop + bkgstart: trace index of background start + bkgstop: trace index of background stop + t_offset: index separation between two pulses + npulses_apd: number of pulses + + Output: + F: float, TIM calibration factor. + + ''' + start = 0 + stop = None + npulses = data['npulses_sase3'] + ntrains = npulses.shape[0] + if not np.all(npulses == npulses[0]): + start = np.argmax(npulses.values) + stop = ntrains + np.argmax(npulses.values[::-1]) - 1 + if stop - start < rollingWindow: + print('not enough consecutive data points with the largest number of pulses per train') + start += rollingWindow + stop = np.min((ntrains, stop+rollingWindow)) + filteredTIM = getTIMapd(data, mcp, use_apd, intstart, intstop, bkgstart, bkgstop, t_offset, npulses_apd) + sa3contrib = saseContribution(data, 'sase3', 'XTD10_XGM') + avgFast = filteredTIM.mean(axis=1).rolling(trainId=rollingWindow).mean() + ratio = ((data['npulses_sase3']+data['npulses_sase1']) * + data['SCS_photonFlux'] * sa3contrib) / (avgFast*data['npulses_sase3']) + F = float(ratio[start:stop].median().values) + + if plot: + fig = plt.figure(figsize=(8,5)) + ax = plt.subplot(211) + ax.set_title('E[uJ] = {:2e} x TIM (MCP{})'.format(F, mcp)) + ax.plot(data['SCS_photonFlux'], label='SCS XGM slow (all SASE)', color='C0') + slow_avg_sase3 = data['SCS_photonFlux']*(data['npulses_sase1'] + +data['npulses_sase3'])*sa3contrib/data['npulses_sase3'] + ax.plot(slow_avg_sase3, label='SCS XGM slow (SASE3 only)', color='C1') + ax.plot(avgFast*F, label='Calibrated TIM rolling avg', color='C2') + ax.legend(loc='upper left', fontsize=8) + ax.set_ylabel('Energy [$\mu$J]', size=10) + ax.plot(filteredTIM.mean(axis=1)*F, label='Calibrated TIM train avg', alpha=0.2, color='C9') + ax.legend(loc='best', fontsize=8, ncol=2) + plt.xlabel('train in run') + + ax = plt.subplot(234) + xgm_fast = selectSASEinXGM(data) + ax.scatter(filteredTIM, xgm_fast, s=5, alpha=0.1, rasterized=True) + fit, cov = np.polyfit(filteredTIM.values.flatten(),xgm_fast.values.flatten(),1, cov=True) + y=np.poly1d(fit) + x=np.linspace(filteredTIM.min(), filteredTIM.max(), 10) + ax.plot(x, y(x), lw=2, color='r') + ax.set_ylabel('Raw HAMP [$\mu$J]', size=10) + ax.set_xlabel('TIM (MCP{}) signal'.format(mcp), size=10) + ax.annotate(s='y(x) = F x + A\n'+ + 'F = %.3e\n$\Delta$F/F = %.2e\n'%(fit[0],np.abs(np.sqrt(cov[0,0])/fit[0]))+ + 'A = %.3e'%fit[1], + xy=(0.5,0.6), xycoords='axes fraction', fontsize=10, color='r') + print('TIM calibration factor: %e'%(F)) + + ax = plt.subplot(235) + ax.hist(filteredTIM.values.flatten()*F, bins=50, rwidth=0.8) + ax.set_ylabel('number of pulses', size=10) + ax.set_xlabel('Pulse energy MCP{} [uJ]'.format(mcp), size=10) + ax.set_yscale('log') + + ax = plt.subplot(236) + if not use_apd: + pulseStart = intstart + pulseStop = intstop + else: + pulseStart = data.attrs['run'].get_array( + 'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStart.value')[0].values + pulseStop = data.attrs['run'].get_array( + 'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStop.value')[0].values + + if 'MCP{}raw'.format(mcp) not in data: + tid, data = data.attrs['run'].train_from_index(0) + trace = data['SCS_UTC1_ADQ/ADC/1:network']['digitizers.channel_1_D.raw.samples'] + print('no raw data for MCP{}. 
Loading trace from MCP1'.format(mcp)) + label_trace='MCP1 Voltage [V]' + else: + trace = data['MCP{}raw'.format(mcp)][0] + label_trace='MCP{} Voltage [V]'.format(mcp) + ax.plot(trace[:pulseStop+25], 'o-', ms=2, label='trace') + ax.axvspan(pulseStart, pulseStop, color='C2', alpha=0.2, label='APD region') + ax.axvline(pulseStart, color='gray', ls='--') + ax.axvline(pulseStop, color='gray', ls='--') + ax.set_xlim(pulseStart - 25, pulseStop + 25) + ax.set_ylabel(label_trace, size=10) + ax.set_xlabel('sample #', size=10) + ax.legend(fontsize=8) + plt.tight_layout() + + return F + + +''' TIM calibration table + Dict with key= photon energy and value= array of polynomial coefficients for each MCP (1,2,3). + The polynomials correspond to a fit of the logarithm of the calibration factor as a function + of MCP voltage. If P is a polynomial and V the MCP voltage, the calibration factor (in microjoule + per APD signal) is given by -exp(P(V)). + This table was generated from the calibration of March 2019, proposal 900074, semester 201930, + runs 69 - 111 (Ni edge): https://in.xfel.eu/elog/SCS+Beamline/2323 + runs 113 - 153 (Co edge): https://in.xfel.eu/elog/SCS+Beamline/2334 + runs 163 - 208 (Fe edge): https://in.xfel.eu/elog/SCS+Beamline/2349 +''' +tim_calibration_table = { + 705.5: np.array([ + [-6.85344690e-12, 5.00931986e-08, -1.27206912e-04, 1.15596821e-01, -3.15215367e+01], + [ 1.25613942e-11, -5.41566381e-08, 8.28161004e-05, -7.27230153e-02, 3.10984925e+01], + [ 1.14094964e-12, 7.72658935e-09, -4.27504907e-05, 4.07253378e-02, -7.00773062e+00]]), + 779: np.array([ + [ 4.57610777e-12, -2.33282497e-08, 4.65978738e-05, -6.43305156e-02, 3.73958623e+01], + [ 2.96325102e-11, -1.61393276e-07, 3.32600044e-04, -3.28468195e-01, 1.28328844e+02], + [ 1.14521506e-11, -5.81980336e-08, 1.12518434e-04, -1.19072484e-01, 5.37601559e+01]]), + 851: np.array([ + [ 3.15774215e-11, -1.71452934e-07, 3.50316512e-04, -3.40098861e-01, 1.31064501e+02], + [5.36341958e-11, -2.92533156e-07, 6.00574534e-04, -5.71083140e-01, 2.10547161e+02], + [ 3.69445588e-11, -1.97731342e-07, 3.98203522e-04, -3.78338599e-01, 1.41894119e+02]]) +} + +def timFactorFromTable(voltage, photonEnergy, mcp=1): + ''' Returns an energy calibration factor for TIM integrated peak signal (APD) + according to calibration from March 2019, proposal 900074, semester 201930, + runs 69 - 111 (Ni edge): https://in.xfel.eu/elog/SCS+Beamline/2323 + runs 113 - 153 (Co edge): https://in.xfel.eu/elog/SCS+Beamline/2334 + runs 163 - 208 (Fe edge): https://in.xfel.eu/elog/SCS+Beamline/2349 + Uses the tim_calibration_table declared above. + + Inputs: + voltage: MCP voltage in volts. + photonEnergy: FEL photon energy in eV. Calibration factor is linearly + interpolated between the known values from the calibration table. + mcp: MCP channel (1, 2, or 3). 
+ + Output: + f: calibration factor in microjoule per APD signal + ''' + energies = np.sort([key for key in tim_calibration_table]) + if photonEnergy not in energies: + if photonEnergy > energies.max(): + photonEnergy = energies.max() + elif photonEnergy < energies.min(): + photonEnergy = energies.min() + else: + idx = np.searchsorted(energies, photonEnergy) - 1 + polyA = np.poly1d(tim_calibration_table[energies[idx]][mcp-1]) + polyB = np.poly1d(tim_calibration_table[energies[idx+1]][mcp-1]) + fA = -np.exp(polyA(voltage)) + fB = -np.exp(polyB(voltage)) + f = fA + (fB-fA)/(energies[idx+1]-energies[idx])*(photonEnergy - energies[idx]) + return f + poly = np.poly1d(tim_calibration_table[photonEnergy][mcp-1]) + f = -np.exp(poly(voltage)) + return f + + +def checkTimApdWindow(data, mcp=1, use_apd=True, intstart=None, intstop=None): + ''' Plot the first and last pulses in MCP trace together with + the window of integration to check if the pulse integration + is properly calculated. If the number of pulses changed during + the run, it selects a train where the number of pulses was + maximum. + + Inputs: + data: xarray Dataset + mcp: MCP channel (1, 2, 3 or 4) + use_apd: if True, gets the APD parameters from the digitizer + device. If False, uses intstart and intstop as boundaries + and uses the bunch pattern to determine the separation + between two pulses. + intstart: trace index of integration start of the first pulse + intstop: trace index of integration stop of the first pulse + + Output: + Plot + ''' + mcpToChannel={1:'D', 2:'C', 3:'B', 4:'A'} + apdChannels={1:3, 2:2, 3:1, 4:0} + npulses_max = data['npulses_sase3'].max().values + tid = data['npulses_sase3'].where(data['npulses_sase3'] == npulses_max, + drop=True).trainId.values + if 'MCP{}raw'.format(mcp) not in data: + print('no raw data for MCP{}. 
Loading average trace from MCP{}'.format(mcp, mcp))
+        trace = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1:network',
+            'digitizers.channel_1_{}.raw.samples'.format(mcpToChannel[mcp])
+            ).sel({'trainId':tid}).mean(dim='trainId')
+    else:
+        trace = data['MCP{}raw'.format(mcp)].sel({'trainId':tid}).mean(dim='trainId')
+    if use_apd:
+        pulseStart = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1',
+            'board1.apd.channel_{}.pulseStart.value'.format(apdChannels[mcp]))[0].values
+        pulseStop = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1',
+            'board1.apd.channel_{}.pulseStop.value'.format(apdChannels[mcp]))[0].values
+        initialDelay = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1',
+            'board1.apd.channel_{}.initialDelay.value'.format(apdChannels[mcp]))[0].values
+        upperLimit = data.attrs['run'].get_array(
+            'SCS_UTC1_ADQ/ADC/1',
+            'board1.apd.channel_{}.upperLimit.value'.format(apdChannels[mcp]))[0].values
+    else:
+        pulseStart = intstart
+        pulseStop = intstop
+    if npulses_max > 1:
+        sa3 = data['sase3'].where(data['sase3']>1)
+        step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
+        step = int(step[1] - step[0])
+        nsamples = 440 * step
+    else:
+        nsamples = 0
+
+    fig, ax = plt.subplots(figsize=(5,3))
+    ax.plot(trace[:pulseStop+25], color='C1', label='first pulse')
+    ax.axvspan(pulseStart, pulseStop, color='k', alpha=0.1, label='APD region')
+    ax.axvline(pulseStart, color='gray', ls='--')
+    ax.axvline(pulseStop, color='gray', ls='--')
+    ax.set_xlim(pulseStart-25, pulseStop+25)
+    ax.locator_params(axis='x', nbins=4)
+    ax.set_ylabel('MCP{} Voltage [V]'.format(mcp))
+    ax.set_xlabel('First pulse sample #')
+    if npulses_max > 1:
+        pulseStart = pulseStart + nsamples*(npulses_max-1)
+        pulseStop = pulseStop + nsamples*(npulses_max-1)
+        ax2 = ax.twiny()
+        ax2.plot(range(pulseStart-25,pulseStop+25), trace[pulseStart-25:pulseStop+25],
+                 color='C4', label='last pulse')
+        ax2.locator_params(axis='x', nbins=4)
+        ax2.set_xlabel('Last pulse sample #')
+        lines, labels = ax.get_legend_handles_labels()
+        lines2, labels2 = ax2.get_legend_handles_labels()
+        ax2.legend(lines + lines2, labels + labels2, loc=0)
+    else:
+        ax.legend(loc='lower left')
+    plt.tight_layout()
+
+
+def matchXgmTimPulseId(data, use_apd=True, intstart=None, intstop=None,
+                       bkgstart=None, bkgstop=None, t_offset=None,
+                       npulses=None, sase3First=True, stride=1):
+    ''' Function to match XGM pulse Id with TIM pulse Id.
+    Inputs:
+        data: xarray Dataset containing XGM and TIM data
+        use_apd: bool. If True, uses the digitizer APD ('MCP[1,2,3,4]apd').
+            If False, peak integration is performed from raw traces.
+            All following parameters are needed in this case.
+        intstart: trace index of integration start
+        intstop: trace index of integration stop
+        bkgstart: trace index of background start
+        bkgstop: trace index of background stop
+        t_offset: index separation between two pulses
+        npulses: number of pulses to compute. Required if no bunch
+            pattern info is available
+        sase3First: bool, needed if bunch pattern is missing.
+        stride: int, used to select pulses in the TIM APD array if
+            no bunch pattern info is available.
+
+    Output:
+        xr DataSet containing XGM and TIM signals with the shared
+        dimension 'sa3_pId'. Raw traces, raw XGM and raw APD are dropped.
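+
+    Example
+    -------
+    A minimal sketch mirroring the calls in test_detectors.py
+    (proposal 2212, run 235; field list as used there):
+
+    >>> import toolbox_scs as tb
+    >>> import toolbox_scs.detectors as tbdet
+    >>> fields = ["sase1", "sase3", "npulses_sase3",
+    ...           "npulses_sase1", "MCP2apd", "SCS_SA3", "nrj"]
+    >>> data = tb.load(fields, 235, 2212)
+    >>> matched = tbdet.matchXgmTimPulseId(data)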
+ ''' + + dropList = [] + mergeList = [] + ndata = cleanXGMdata(data, npulses, sase3First) + for mcp in range(1,5): + if 'MCP{}apd'.format(mcp) in data or 'MCP{}raw'.format(mcp) in data: + MCPapd = getTIMapd(data, mcp=mcp, use_apd=use_apd, intstart=intstart, + intstop=intstop,bkgstart=bkgstart, bkgstop=bkgstop, + t_offset=t_offset, npulses=npulses, + stride=stride).rename('MCP{}apd'.format(mcp)) + if use_apd: + MCPapd = MCPapd.rename({'apdId':'sa3_pId'}) + else: + MCPapd = MCPapd.rename({'MCP{}fromRaw'.format(mcp):'sa3_pId'}) + mergeList.append(MCPapd) + if 'MCP{}raw'.format(mcp) in ndata: + dropList.append('MCP{}raw'.format(mcp)) + if 'MCP{}apd'.format(mcp) in data: + dropList.append('MCP{}apd'.format(mcp)) + mergeList.append(ndata.drop(dropList)) + subset = xr.merge(mergeList, join='inner') + for k in ndata.attrs.keys(): + subset.attrs[k] = ndata.attrs[k] + return subset + + +# Fast ADC +def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, + period=None, npulses=None, source='scs_ppl', + usePeakValue=False, peakType='pos'): + ''' Computes peak integration from raw FastADC traces. + + Inputs: + data: xarray Dataset containing FastADC raw traces (e.g. 'FastADC1raw') + channel: FastADC channel number + intstart: trace index of integration start + intstop: trace index of integration stop + bkgstart: trace index of background start + bkgstop: trace index of background stop + period: number of samples between two pulses. Needed if bunch + pattern info is not available. If None, checks the pulse + pattern and determine the period assuming a resolution of + 9.23 ns per sample which leads to 24 samples between + two bunches @ 4.5 MHz. + npulses: number of pulses. If None, takes the maximum number of + pulses according to the bunch patter (field 'npulses_sase3') + source: str, nature of the pulses, 'sase[1,2 or 3]', or 'scs_ppl', + or any name. Used to give name to the peak Id dimension. + usePeakValue: bool, if True takes the peak value of the signal, + otherwise integrates over integration region. + peakType: str, 'pos' or 'neg'. Used if usePeakValue is True to + indicate if min or max value should be extracted. 
+ + + Output: + results: DataArray with dims trainId x max(sase3 pulses) + + ''' + keyraw = 'FastADC{}raw'.format(channel) + if keyraw not in data: + raise ValueError("Source not found: {}!".format(keyraw)) + if npulses is None or period is None: + indices, npulses_bp, mask = tb.extractBunchPattern(runDir=data.attrs['run'], + key=source) + if npulses is None: + npulses = int(npulses_bp.max().values) + if period is None: + indices = indices_bp.where(indices_bp>1) + if npulses > 1: + #Calculate the number of pulses between two lasing pulses (step) + step = indices.where(npulses_bp>1, drop=True)[0,:2].values + step = int(step[1] - step[0]) + #multiply by elementary pulse length (221.5 ns / 9.23 ns = 24 samples) + period = 24 * step + else: + period = 1 + pulseId = source + if source=='scs_ppl': + pulseId = 'ol_pId' + if 'sase' in source: + pulseId = f'sa{source[4]}_pId' + results = xr.DataArray(np.empty((data.trainId.shape[0], npulses)), coords=data[keyraw].coords, + dims=['trainId', pulseId]) + for i in range(npulses): + a = intstart + period*i + b = intstop + period*i + bkga = bkgstart + period*i + bkgb = bkgstop + period*i + bg = np.outer(np.median(data[keyraw][:,bkga:bkgb], axis=1), np.ones(b-a)) + if usePeakValue: + if peakType=='pos': + val = np.max(data[keyraw][:,a:b] - bg, axis=1) + if peakType=='neg': + val = np.min(data[keyraw][:,a:b] - bg, axis=1) + else: + val = np.trapz(data[keyraw][:,a:b] - bg, axis=1) + results[:,i] = val + return results + +def autoFindFastAdcPeaks(data, channel=5, window='large', usePeakValue=False, + source='scs_ppl', display=False, plot=False): + ''' Automatically finds peaks in channel of Fast ADC trace, a minimum width of 4 + samples. The find_peaks function and determination of the peak integration + region and baseline subtraction is optimized for typical photodiode signals + of the SCS instrument (ILH, FFT reflectometer, FFT diag stage). + Inputs: + data: xarray Dataset containing Fast ADC traces + key: data key of the array of traces + window: 'small' or 'large': defines the width of the integration region + centered on the peak. + usePeakValue: bool, if True takes the peak value of the signal, + otherwise integrates over integration region. 
+ display: bool, displays info on the pulses found + plot: bool, plots regions of integration of the first pulse in the trace + Output: + peaks: DataArray of the integrated peaks + ''' + + key = f'FastADC{channel}raw' + if key not in data: + raise ValueError(f'{key} not found in data set') + #average over the 100 first traces to get at least one train with signal + trace = data[key].isel(trainId=slice(0,100)).mean(dim='trainId').values + if plot: + trace_plot = np.copy(trace) + #subtract baseline and check if peaks are positive or negative + bl = np.median(trace) + trace_no_bl = trace - bl + if np.max(trace_no_bl) >= np.abs(np.min(trace_no_bl)): + posNeg = 'positive' + else: + posNeg = 'negative' + trace_no_bl *= -1 + trace = bl + trace_no_bl + threshold = bl + np.max(trace_no_bl) / 2 + #find peaks + centers, peaks = find_peaks(trace, height=threshold, width=(4, None)) + c = centers[0] + w = np.average(peaks['widths']).astype(int) + period = np.median(np.diff(centers)).astype(int) + npulses = centers.shape[0] + if window not in ['small', 'large']: + raise ValueError(f"'window argument should be either 'small' or 'large', not {window}") + if window=='small': + intstart = int(c - w/4) + 1 + intstop = int(c + w/4) + 1 + if window=='large': + intstart = int(peaks['left_ips'][0]) + intstop = int(peaks['right_ips'][0]) + w + bkgstop = int(peaks['left_ips'][0])-5 + bkgstart = bkgstop - 10 + if display: + print(f'Found {npulses} {posNeg} pulses, avg. width={w}, period={period} samples, ' + + f'rep. rate={1e6/(9.230769*period):.3f} kHz') + fAdcPeaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop, + bkgstart=bkgstart, bkgstop=bkgstop, period=period, npulses=npulses, + source=source, usePeakValue=usePeakValue, peakType=posNeg[:3]) + if plot: + plt.figure() + plt.plot(trace_plot, 'o-', ms=3) + for i in range(npulses): + plt.axvline(intstart+i*period, ls='--', color='g') + plt.axvline(intstop+i*period, ls='--', color='r') + plt.axvline(bkgstart+i*period, ls='--', color='lightgrey') + plt.axvline(bkgstop+i*period, ls='--', color='grey') + plt.title(f'Fast ADC {channel} trace') + plt.xlim(bkgstart-10, intstop + 50) + return fAdcPeaks + +def mergeFastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, + period=None, npulses=None, dim='lasPulseId'): + ''' Calculates the peaks from Fast ADC raw traces with fastAdcPeaks() + and merges the results in Dataset. + Inputs: + data: xr Dataset with 'FastADC[channel]raw' traces + channel: Fast ADC channel + intstart: trace index of integration start + intstop: trace index of integration stop + bkgstart: trace index of background start + bkgstop: trace index of background stop + period: Number of ADC samples between two pulses. Needed + if bunch pattern info is not available. If None, checks the + pulse pattern and determine the period assuming a resolution + of 9.23 ns per sample = 24 samples between two pulses @ 4.5 MHz. + npulses: number of pulses. 
If None, takes the maximum number of + pulses according to the bunch patter (field 'npulses_sase3') + dim: name of the xr dataset dimension along the peaks + + ''' + peaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop, + bkgstart=bkgstart, bkgstop=bkgstop, period=period, + npulses=npulses) + + key = 'FastADC{}peaks'.format(channel) + if key in data: + s = data.drop(key) + else: + s = data + peaks = peaks.rename(key).rename({'peakId':dim}) + subset = xr.merge([s, peaks], join='inner') + for k in data.attrs.keys(): + subset.attrs[k] = data.attrs[k] + return subset + diff --git a/src/toolbox_scs/load.py b/src/toolbox_scs/load.py index 172917d99bcb86efec00c0511538833ce2e6a935..74f726d0608151baa98f056f0134ef4fe4ab4746 100644 --- a/src/toolbox_scs/load.py +++ b/src/toolbox_scs/load.py @@ -125,30 +125,6 @@ def load(fields, runNB, proposalNB, subFolder='raw', display=False, validate=Fal return result -def concatenateRuns(runs): - """ Sorts and concatenate a list of runs with identical data variables along the - trainId dimension. - - Input: - runs: (list) the xarray Datasets to concatenate - Output: - a concatenated xarray Dataset - """ - firstTid = {i: int(run.trainId[0].values) for i,run in enumerate(runs)} - orderedDict = dict(sorted(firstTid.items(), key=lambda t: t[1])) - orderedRuns = [runs[i] for i in orderedDict] - keys = orderedRuns[0].keys() - for run in orderedRuns[1:]: - if run.keys() != keys: - print('data fields between different runs are not identical. Cannot combine runs.') - return - - result = xr.concat(orderedRuns, dim='trainId') - for k in orderedRuns[0].attrs.keys(): - result.attrs[k] = [run.attrs[k] for run in orderedRuns] - return result - - def run_by_proposal(proposal, run): """ Get run in given proposal @@ -193,6 +169,30 @@ def run_by_path(path): return RunDirectory(path) +def concatenateRuns(runs): + """ Sorts and concatenate a list of runs with identical data variables along the + trainId dimension. + + Input: + runs: (list) the xarray Datasets to concatenate + Output: + a concatenated xarray Dataset + """ + firstTid = {i: int(run.trainId[0].values) for i,run in enumerate(runs)} + orderedDict = dict(sorted(firstTid.items(), key=lambda t: t[1])) + orderedRuns = [runs[i] for i in orderedDict] + keys = orderedRuns[0].keys() + for run in orderedRuns[1:]: + if run.keys() != keys: + print('data fields between different runs are not identical. Cannot combine runs.') + return + + result = xr.concat(orderedRuns, dim='trainId') + for k in orderedRuns[0].attrs.keys(): + result.attrs[k] = [run.attrs[k] for run in orderedRuns] + return result + + def load_scan_variable(run, mnemonic, stepsize=None): """ Loads the given scan variable and rounds scan positions to integer diff --git a/src/toolbox_scs/misc/__init__.py b/src/toolbox_scs/misc/__init__.py index 155856b1b69b726b0bb43ba2e31b7cf034437dd1..dad19fbd3fc3288e71aae1f3823b54e4663d622f 100644 --- a/src/toolbox_scs/misc/__init__.py +++ b/src/toolbox_scs/misc/__init__.py @@ -1,16 +1,12 @@ -from .bunch_pattern import ( - extractBunchPattern, pulsePatternInfo, repRate, sortBAMdata, - ) -from .bunch_pattern_external import ( - is_sase_3, is_sase_1, is_ppl, get_index_ppl, get_index_sase1, - get_index_sase3, - ) -from .laser_utils import ( - positionToDelay, degToRelPower, - ) -from . 
azimuthal_integrator import (
-    AzimutalIntegrator,
-    )
+from .bunch_pattern import (extractBunchPattern, pulsePatternInfo,
+                            repRate, sortBAMdata,
+                            )
+from .bunch_pattern_external import (is_sase_3, is_sase_1, is_ppl,
+                            get_index_ppl, get_index_sase1, get_index_sase3,
+                            )
+from .laser_utils import positionToDelay, degToRelPower
+from .azimuthal_integrator import AzimutalIntegrator
+
 
 __all__ = (
     # Functions
diff --git a/src/toolbox_scs/test/test_detectors.py b/src/toolbox_scs/test/test_detectors.py
new file mode 100644
index 0000000000000000000000000000000000000000..404b3fc611017faa3668d722f7ed68c27a4c4f57
--- /dev/null
+++ b/src/toolbox_scs/test/test_detectors.py
@@ -0,0 +1,110 @@
+import unittest
+import logging
+import os
+import sys
+import argparse
+
+
+import toolbox_scs as tb
+import toolbox_scs.detectors as tbdet
+from toolbox_scs.util.exceptions import *
+
+logging.basicConfig(level=logging.DEBUG)
+log_root = logging.getLogger(__name__)
+
+suites = {"packaging": (
+                "test_init",
+                ),
+          "xgm": (
+                "test_loadxgm",
+                "test_cleanxgm",
+                "test_matchxgmtim",
+                ),
+          "tim": (
+                "test_loadtim",
+                )
+          }
+
+
+def list_suites():
+    print("""\nPossible test suites:\n-------------------------""")
+    for key in suites:
+        print(key)
+    print("-------------------------\n")
+
+
+class TestDetectors(unittest.TestCase):
+    def setUp(self):
+        self.run = tb.run_by_proposal(2212, 235)
+
+        fields = ["sase1", "sase3", "npulses_sase3",
+                  "npulses_sase1", "MCP2apd", "SCS_SA3", "nrj"]
+        self.tb_data = tb.load(fields, 235, 2212)
+
+        log_root.info("Finished setup, start tests.")
+
+    def tearDown(self):
+        pass
+
+    def test_init(self):
+        self.assertEqual(tbdet.__name__, "toolbox_scs.detectors")
+
+    def test_loadxgm(self):
+        xgm_data = tbdet.load_xgm(self.run)
+        self.assertTrue(xgm_data.values[0][-1])
+
+    def test_cleanxgm(self):
+        data = tbdet.cleanXGMdata(self.tb_data)
+        self.assertEqual(data['sa3_pId'].values[-1], 19)
+
+    def test_matchxgmtim(self):
+        data = tbdet.matchXgmTimPulseId(self.tb_data)
+        self.assertEqual(data['npulses_sase3'].values[0], 20)
+
+    def test_loadtim(self):
+        data = tbdet.load_TIM(self.run)
+        self.assertEqual(data.name, 'MCP2apd')
+
+
+def suite(*tests):
+    suite = unittest.TestSuite()
+    for test in tests:
+        suite.addTest(TestDetectors(test))
+    return suite
+
+
+def main(*cliargs):
+    try:
+        for test_suite in cliargs:
+            if test_suite in suites:
+                runner = unittest.TextTestRunner(verbosity=2)
+                runner.run(suite(*suites[test_suite]))
+            else:
+                log_root.warning(
+                    "Unknown suite: '{}'".format(test_suite))
+    except Exception as err:
+        log_root.error("Unexpected error: {}".format(err),
+                       exc_info=True)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--list-suites',
+                        action='store_true',
+                        help='list possible test suites')
+    parser.add_argument('--run-suites', metavar='S',
+                        nargs='+', action='store',
+                        help='a list of valid test suites')
+    args = parser.parse_args()
+
+    if args.list_suites:
+        list_suites()
+
+    if args.run_suites:
+        main(*args.run_suites)
diff --git a/src/toolbox_scs/test/test_top_level.py b/src/toolbox_scs/test/test_top_level.py
index d696f4369f0bb056f1e3cde62a1e2d23590378e1..31e7f899a86d34d38f27a2b50582dcb3651fc463 100644
--- a/src/toolbox_scs/test/test_top_level.py
+++ b/src/toolbox_scs/test/test_top_level.py
@@ -11,7 +11,6 @@ import extra_data as ed
 
 suites = {"packaging": (
                 "test_constant",
-
                 ),
           "load": (
                 "test_load",
diff --git a/src/toolbox_scs/test/test_utils.py b/src/toolbox_scs/test/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..451782adf3341f6ed5be7a0b7af83f6f3ec49e0f
--- /dev/null
+++ b/src/toolbox_scs/test/test_utils.py
@@ -0,0 +1,102 @@
+import unittest
+import logging
+import os
+import sys
+import argparse
+
+
+from toolbox_scs.util.data_access import (
+    find_run_dir,
+    )
+from toolbox_scs.util.exceptions import ToolBoxPathError
+
+suites = {"ed-extensions": (
+                "test_rundir1",
+                "test_rundir2",
+                "test_rundir3",
+                )
+          }
+
+
+def list_suites():
+    print("""\nPossible test suites:\n-------------------------""")
+    for key in suites:
+        print(key)
+    print("-------------------------\n")
+
+
+class TestDataAccess(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
+
+    def test_rundir1(self):
+        Proposal = 2212
+        Run = 235
+        Dir = find_run_dir(Proposal, Run)
+        self.assertEqual(Dir,
+                         "/gpfs/exfel/exp/SCS/201901/p002212/raw/r0235")
+
+    def test_rundir2(self):
+        Proposal = 23678
+        Run = 235
+        Dir = find_run_dir(Proposal, Run)
+        self.assertEqual(Dir, None)
+
+    def test_rundir3(self):
+        Proposal = 2212
+        Run = 2325
+        with self.assertRaises(ToolBoxPathError) as cm:
+            find_run_dir(Proposal, Run)
+        the_exception = cm.exception
+        path = '/gpfs/exfel/exp/SCS/201901/p002212/raw/r2325'
+        err_msg = f"The constructed path '{path}' does not exist"
+        self.assertEqual(the_exception.message, err_msg)
+
+
+def suite(*tests):
+    suite = unittest.TestSuite()
+    for test in tests:
+        suite.addTest(TestDataAccess(test))
+    return suite
+
+
+def main(*cliargs):
+    logging.basicConfig(level=logging.DEBUG)
+    log_root = logging.getLogger(__name__)
+    try:
+        for test_suite in cliargs:
+            if test_suite in suites:
+                runner = unittest.TextTestRunner(verbosity=2)
+                runner.run(suite(*suites[test_suite]))
+            else:
+                log_root.warning(
+                    "Unknown suite: '{}'".format(test_suite))
+    except Exception as err:
+        log_root.error("Unexpected error: {}".format(err),
+                       exc_info=True)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--list-suites',
+                        action='store_true',
+                        help='list possible test suites')
+    parser.add_argument('--run-suites', metavar='S',
+                        nargs='+', action='store',
+                        help='a list of valid test suites')
+    args = parser.parse_args()
+
+    if args.list_suites:
+        list_suites()
+
+    if args.run_suites:
+        main(*args.run_suites)
diff --git a/src/toolbox_scs/util/data_access.py b/src/toolbox_scs/util/data_access.py
new file mode 100644
index 0000000000000000000000000000000000000000..b088bac354f18ba9af69af09c93f126113661974
--- /dev/null
+++ b/src/toolbox_scs/util/data_access.py
@@ -0,0 +1,61 @@
+'''
+Extensions to the extra_data package.
+
+contributions should comply with pep8 code structure guidelines.
+'''
+
+import os
+import logging
+
+import extra_data as ed
+from extra_data.read_machinery import find_proposal
+
+from ..util.exceptions import ToolBoxPathError
+
+log = logging.getLogger(__name__)
+
+
+def find_run_dir(proposal, run):
+    """
+    Get run directory for given run.
+
+    This method is an extension to the extra_data method
+    'find_proposal' and should eventually be transferred over.
+
+    Parameters
+    ----------
+    proposal: str, int
+        Proposal number
+    run: str, int
+        Run number
+
+    Returns
+    -------
+    rdir : str
+        Run directory as a string
+
+    Raises
+    ------
+    ToolBoxPathError: Exception
+        Error raised if the constructed path does not exist. This may
+        happen when entering a non-valid run number, or if the folder has
+        been renamed/removed.
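+
+    Example
+    -------
+    Illustrative only; the expected path is the one asserted in
+    test_utils.py (proposal 2212, run 235):
+
+    >>> find_run_dir(2212, 235)
+    '/gpfs/exfel/exp/SCS/201901/p002212/raw/r0235'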
+ + """ + rdir = None + + try: + pdir = find_proposal(f'p{proposal:06d}') + rdir = os.path.join(pdir, f'raw/r{run:04d}') + if os.path.isdir(rdir) is False: + log.warning("Invalid directory: raise ToolBoxPathError.") + msg = f"The constructed path '{rdir}' does not exist" + raise ToolBoxPathError(msg, rdir) + + except ToolBoxPathError: + raise + except Exception as err: + log.error("Unexpected error:", exc_info=True) + log.warning("Unexpected error orrured, return None") + pass + + return rdir \ No newline at end of file