From 660045ed0cc20fdcce528c56e36c6cdca4d7b8ab Mon Sep 17 00:00:00 2001 From: Laurent Mercadier <laurent.mercadier@xfel.eu> Date: Tue, 29 Oct 2019 09:46:35 +0100 Subject: [PATCH] Adds documentation, simplifies selectSASEinXGM() --- xgm.py | 97 +++++++++++++++++++++++++--------------------------------- 1 file changed, 41 insertions(+), 56 deletions(-) diff --git a/xgm.py b/xgm.py index 85e0058..61eba7e 100644 --- a/xgm.py +++ b/xgm.py @@ -199,11 +199,9 @@ def cleanXGMdata(data, npulses=None, sase3First=True): def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=None): ''' Given an array containing both SASE1 and SASE3 data, extracts SASE1- - or SASE3-only XGM data. - There are various cases depending on i) the mode of operation (10 Hz - with fresh bunch, dedicated trains to one SASE, pulse on demand), - ii) the potential change of number of pulses per train in each SASE - and iii) the order (SASE1 first, SASE3 first, interleaved mode). + or SASE3-only XGM data. The function tracks the changes of bunch patterns + in sase 1 and sase 3 and applies a mask to the XGM array to extract the + relevant pulses. This way, all complicated patterns are accounted for. Inputs: data: xarray Dataset containing xgm data @@ -240,6 +238,7 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses= return xgmData[:,start:start+npulses] #2. 
case where bunch pattern is provided + #2.1 Merge sase1 and sase3 bunch patterns to get indices of all pulses xgm_arr = data[xgm].where(data[xgm] != 1., drop=True) sa3 = data['sase3'].where(data['sase3'] > 1, drop=True) sa3_val=np.unique(sa3) @@ -252,51 +251,38 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses= dims=['trainId', 'bunchId'], coords={'trainId':data.trainId}, name='sase_all') - idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True) - idxList3 = np.unique(sa3) - idxList1 = np.unique(sa1) - if sase=='sase3': - big_sa3 = [] - for i,idxXGM in enumerate(idxListAll): - idxXGM = np.isin(idxXGM, idxList3) - idxTid = invAll==i - mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), - dims=['trainId', 'XGMbunchId'], - coords={'trainId':data.trainId, - 'XGMbunchId':sa_all['bunchId'].values}, - name='mask') - mask[idxTid, idxXGM] = True - sa3 = xgm_arr.where(mask, drop=True) - if sa3.trainId.size > 0: - sa3 = sa3.assign_coords(XGMbunchId=np.arange(sa3.XGMbunchId.size)) - big_sa3.append(sa3) - if len(big_sa3) > 0: - da_sa3 = xr.concat(big_sa3, dim='trainId').rename('{}_SA3'.format(xgm.split('_')[0])) - else: - da_sa3 = xr.DataArray([], dims=['trainId'], name='{}_SA3'.format(xgm.split('_')[0])) - return da_sa3 - - if sase=='sase1': - big_sa1 = [] - for i,idxXGM in enumerate(idxListAll): - idxXGM = np.isin(idxXGM, idxList1) - idxTid = invAll==i - mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), - dims=['trainId', 'XGMbunchId'], - coords={'trainId':data.trainId, - 'XGMbunchId':sa_all['bunchId'].values}, - name='mask') - mask[idxTid, idxXGM] = True - sa1 = xgm_arr.where(mask, drop=True) - if sa1.trainId.size > 0: - sa1 = sa1.assign_coords(XGMbunchId=np.arange(sa1.XGMbunchId.size)) - big_sa1.append(sa1) - if len(big_sa1) > 0: - da_sa1 = xr.concat(big_sa1, dim='trainId').rename('{}_SA1'.format(xgm.split('_')[0])) - else: - da_sa1 = 
xr.DataArray([], dims=['trainId'], name='{}_SA1'.format(xgm.split('_')[0])) - return da_sa1 + idxListSase = np.unique(sa3) + newName = xgm.split('_')[0] + '_SA3' + else: + idxListSase = np.unique(sa1) + newName = xgm.split('_')[0] + '_SA1' + + #2.2 track the changes of pulse patterns and the indices at which they occurred (invAll) + idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True) + + #2.3 define a mask, loop over the different patterns and update the mask for each pattern + mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool), + dims=['trainId', 'XGMbunchId'], + coords={'trainId':data.trainId, + 'XGMbunchId':sa_all['bunchId'].values}, + name='mask') + + big_sase = [] + for i,idxXGM in enumerate(idxListAll): + mask.values = np.zeros(mask.shape, dtype=bool) + idxXGM = np.isin(idxXGM, idxListSase) + idxTid = invAll==i + mask[idxTid, idxXGM] = True + sa_arr = xgm_arr.where(mask, drop=True) + if sa_arr.trainId.size > 0: + sa_arr = sa_arr.assign_coords(XGMbunchId=np.arange(sa_arr.XGMbunchId.size)) + big_sase.append(sa_arr) + if len(big_sase) > 0: + da_sase = xr.concat(big_sase, dim='trainId').rename(newName) + else: + da_sase = xr.DataArray([], dims=['trainId'], name=newName) + return da_sase def saseContribution(data, sase='sase1', xgm='XTD10_XGM'): ''' Calculate the relative contribution of SASE 1 or SASE 3 pulses @@ -533,9 +519,8 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None, ''' Extract peak-integrated data from TIM where pulses are from SASE3 only. If use_apd is False it calculates integration from raw traces. The missing values, in case of change of number of pulses, are filled - with NaNs. 
If no bunch pattern info is available, the function assumes + that SASE 3 comes first and that the number of pulses is fixed in both SASE 1 and 3. Inputs: @@ -544,7 +529,7 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None, intstop: trace index of integration stop bkgstart: trace index of background start bkgstop: trace index of background stop - t_offset: index separation between two pulses + t_offset: number of ADC samples between two pulses mcp: MCP channel number npulses: int, optional. Number of pulses to compute. Required if no bunch pattern info is available. @@ -602,6 +587,7 @@ def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None, npulses=pulseIdDim) sa3 = sa3/period + #2.3 track the changes of pulse patterns and the indices at which they occurred (invAll) idxList, inv = np.unique(sa3, axis=0, return_inverse=True) mask = xr.DataArray(np.zeros((data.dims['trainId'], pulseIdDim), dtype=bool), dims=['trainId', pulseId], @@ -993,11 +979,10 @@ def mergeFastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, intstop: trace index of integration stop bkgstart: trace index of background start bkgstop: trace index of background stop - period: Number of samples separation between two pulses. Needed + period: Number of ADC samples between two pulses. Needed if bunch pattern info is not available. If None, checks the pulse pattern and determine the period assuming a resolution - of 9.23 ns per sample which leads to 24 samples between - two bunches @ 4.5 MHz. + of 9.23 ns per sample = 24 samples between two pulses @ 4.5 MHz. npulses: number of pulses. If None, takes the maximum number of pulses according to the bunch patter (field 'npulses_sase3') dim: name of the xr dataset dimension along the peaks -- GitLab