diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py
index 30e963ece4fa1bfe3c0fe33f9c22c0f9036250de..3168aa3d9ac3874d4754660f2aa545131b4ba0c3 100644
--- a/src/toolbox_scs/detectors/digitizers.py
+++ b/src/toolbox_scs/detectors/digitizers.py
@@ -318,6 +318,86 @@ def get_peaks(run,
     return peaks.assign_coords({extra_dim: pid})
 
 
+def get_peaks_from_list(run, dataList, digitizer='ADQ412',
+                        bunchPattern='sase3', integParams=None,
+                        keepAllSase=False):
+    """
+    Automatically loads and computes digitizer peaks from ToolBox
+    mnemonics and/or data arrays (raw traces or peak-integrated values).
+    Sources can be raw traces (e.g. "MCP3raw") or peak-
+    integrated data (e.g. "MCP2apd"). The bunch pattern table is
+    used to assign the pulse ID coordinates.
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+        DataCollection containing the digitizer data.
+    dataList: list of str and/or DataArray
+        Mnemonics for TIM, e.g. "MCP2apd" or "FastADC5raw", and/or
+        data arrays of raw traces or peak-integrated values.
+    digitizer: str
+        Digitizer type, 'ADQ412' (default) or 'FastADC'.
+    bunchPattern: str
+        'sase1' or 'sase3' or 'scs_ppl', bunch pattern
+        used to extract peaks. The pulse ID dimension will be named
+        'sa1_pId', 'sa3_pId' or 'ol_pId', respectively.
+    integParams: dict
+        dictionary for raw trace integration, e.g.
+        {'intstart': 100, 'intstop': 200, 'bkgstart': 50,
+         'bkgstop': 99, 'period': 440, 'npulses': 500}.
+        If None, integration parameters are computed automatically.
+    keepAllSase: bool
+        Only relevant in case of sase-dedicated trains. If True, all
+        trains are kept, else only those of the bunchPattern are kept.
+
+    Returns
+    -------
+    xarray Dataset with all TIM variables substituted by
+    the peak calculated values (e.g. "MCP2raw" becomes
+    "MCP2peaks").
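+
+    Example
+    -------
+    A minimal usage sketch; the run number and mnemonic are illustrative
+    and assume "MCP2apd" data is present:
+
+    >>> import extra_data as ed
+    >>> from toolbox_scs.detectors.digitizers import get_peaks_from_list
+    >>> run = ed.open_run(2212, 213)
+    >>> peaks = get_peaks_from_list(run, ['MCP2apd'], digitizer='ADQ412',
+    ...                             bunchPattern='sase3')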
+ """ + autoFind = True + if integParams is not None: + autoFind = False + keys = [] + vals = [] + for d in dataList: + if isinstance(d, str): + key = d + log.debug(f'Loading peaks from {key}...') + if isinstance(d, xr.DataArray): + key = d.name + log.debug(f'Retrieving peaks from {key} array...') + useRaw = True if 'raw' in key else False + m = _mnemonics[key] + peaks = get_peaks(run, d, + source=m['source'], + key=m['key'], + digitizer=digitizer, + useRaw=useRaw, + autoFind=autoFind, + integParams=integParams, + bunchPattern=bunchPattern) + key = key.replace('raw', 'peaks').replace('apd', 'peaks') + keys.append(key) + vals.append(peaks) + join = 'outer' if keepAllSase else 'inner' + aligned_vals = xr.align(*vals, join=join) + ds = xr.Dataset(dict(zip(keys, aligned_vals))) + return ds + + def channel_peak_params(run, source, key=None, digitizer=None, channel=None, board=None): """ @@ -458,26 +526,6 @@ def find_integ_params(trace, min_distance=1, height=None, width=1): return result -def plotPeakIntegrationWindow(raw_trace, params): - xmin = np.max([0, params['baseStart']-100]) - xmax = np.min([params['pulseStop']+100, raw_trace.size]) - fig, ax = plt.subplots(figsize=(5, 3)) - ax.axvline(params['baseStart'], ls='--', color='k') - ax.axvline(params['baseStop'], ls='--', color='k') - ax.axvspan(params['baseStart'], params['baseStop'], - alpha=0.5, color='grey', label='baseline') - ax.axvline(params['pulseStart'], ls='--', color='r') - ax.axvline(params['pulseStop'], ls='--', color='r') - ax.axvspan(params['pulseStart'], params['pulseStop'], - alpha=0.2, color='r', label='peak') - ax.plot(np.arange(xmin, xmax), raw_trace[xmin:xmax], - label='raw trace') - ax.set_xlabel('digitizer samples') - ax.legend(fontsize=8) - fig.tight_layout() - return ax - - def get_peak_params(run, key, raw_trace=None, plot=False): """ Get the peak region and baseline region of a raw @@ -530,72 +578,24 @@ def get_peak_params(run, key, raw_trace=None, plot=False): return params -def get_peaks_from_list(run, dataList, digitizer='ADQ412', - bunchPattern='sase3', integParams=None, - keepAllSase=False): - """ - Automatically loads and computes digitizer peaks from ToolBox - mnemonics and/or data arrays (raw traces or peak-integrated values). - Sources can be raw traces (e.g. "MCP3raw") or peak- - integrated data (e.g. "MCP2apd"). The bunch pattern table is - used to assign the pulse id coordinates. - - Parameters - ---------- - run: extra_data.DataCollection - DataCollection containing the digitizer data. If provided, - data must be a str or list of str. If None, data must be - an xarray Dataset containing actual data. - key: list of str and/or DataArray - A mnemonics for TIM, e.g. "MCP2apd" or "FastADC5raw". - bunchPattern: str - 'sase1' or 'sase3' or 'scs_ppl', bunch pattern - used to extract peaks. The pulse ID dimension will be named - 'sa1_pId', 'sa3_pId' or 'ol_pId', respectively. - integParams: dict - dictionnary for raw trace integration, e.g. - {'intstart':100, 'intstop':200, 'bkgstart':50, - 'bkgstop':99, 'period':440, 'npulses':500}. - If None, integration parameters are computed automatically. - keepAllSase: bool - Only relevant in case of sase-dedicated trains. If True, all - trains are kept, else only those of the bunchPattern are kept. - - Returns - ------- - xarray Dataset with all TIM variables substituted by - the peak caclulated values (e.g. "MCP2raw" becomes - "MCP2peaks"). 
- """ - autoFind = True - if integParams is not None: - autoFind = False - keys = [] - vals = [] - for d in dataList: - if isinstance(d, str): - key = d - log.debug(f'Loading peaks from {key}...') - if isinstance(d, xr.DataArray): - key = d.name - log.debug(f'Retrieving peaks from {key} array...') - useRaw = True if 'raw' in key else False - m = _mnemonics[key] - peaks = get_peaks(run, d, - source=m['source'], - key=m['key'], - digitizer=digitizer, - useRaw=useRaw, - autoFind=autoFind, - integParams=integParams, - bunchPattern=bunchPattern) - key = key.replace('raw', 'peaks').replace('apd', 'peaks') - keys.append(key) - vals.append(peaks) - join = 'outer' if keepAllSase else 'inner' - aligned_vals = xr.align(*vals, join=join) - ds = xr.Dataset(dict(zip(keys, aligned_vals))) - return ds +def plotPeakIntegrationWindow(raw_trace, params): + xmin = np.max([0, params['baseStart']-100]) + xmax = np.min([params['pulseStop']+100, raw_trace.size]) + fig, ax = plt.subplots(figsize=(5, 3)) + ax.axvline(params['baseStart'], ls='--', color='k') + ax.axvline(params['baseStop'], ls='--', color='k') + ax.axvspan(params['baseStart'], params['baseStop'], + alpha=0.5, color='grey', label='baseline') + ax.axvline(params['pulseStart'], ls='--', color='r') + ax.axvline(params['pulseStop'], ls='--', color='r') + ax.axvspan(params['pulseStart'], params['pulseStop'], + alpha=0.2, color='r', label='peak') + ax.plot(np.arange(xmin, xmax), raw_trace[xmin:xmax], + label='raw trace') + ax.set_xlabel('digitizer samples') + ax.legend(fontsize=8) + fig.tight_layout() + return ax def get_tim_peaks(run, key=None, merge_with=None, diff --git a/src/toolbox_scs/detectors/xgm.py b/src/toolbox_scs/detectors/xgm.py index 309f4a9f7020a7e4440dbf0df22cef6664570187..cde1950450494f29cd4bd4dd0e7d1acd2a10ec8c 100644 --- a/src/toolbox_scs/detectors/xgm.py +++ b/src/toolbox_scs/detectors/xgm.py @@ -1,8 +1,8 @@ """ XGM related sub-routines Copyright (2019) SCS Team. - - (contributions preferrably comply with pep8 code structure + + (contributions preferrably comply with pep8 code structure guidelines.) """ @@ -15,7 +15,7 @@ from scipy.signal import find_peaks import extra_data as ed from ..constants import mnemonics as _mnemonics -from ..misc.bunch_pattern_external import is_sase_1, is_sase_3, is_ppl +from ..misc.bunch_pattern_external import is_sase_1, is_sase_3 log = logging.getLogger(__name__) @@ -56,7 +56,7 @@ def load_xgm(run, xgm_mnemonic='SCS_SA3'): def get_xgm(run, key=None, merge_with=None, keepAllSase=False, - indices=slice(0,None)): + indices=slice(0, None)): """ Load and/or computes XGM data. Sources can be loaded on the fly via the key argument, or processed from an existing data set @@ -84,14 +84,14 @@ def get_xgm(run, key=None, merge_with=None, keepAllSase=False, ------- xarray Dataset with pulse-resolved XGM variables aligned, merged with Dataset *merge_with* if provided. 
-
+
     Example
     -------
     >>> import extra_data as ed
     >>> import toolbox_scs.detectors as tbdet
     >>> run = ed.open_run(2212, 213)
     >>> xgm = tbdet.get_xgm(run)
-
+
     """
     # check if bunch pattern table exists
     if bool(merge_with) and 'bunchPatternTable' in merge_with:
@@ -102,42 +102,130 @@ def get_xgm(run, key=None, merge_with=None, keepAllSase=False,
         log.debug('Loaded bpt from extra_data run.')
     else:
         bpt = None
+
     if key is None and merge_with is None:
-        key = 'SCS_SA3'
-    if key is None:
+        keys = ['SCS_SA3']
+    elif key is None:
         keys = []
     else:
         keys = key if isinstance(key, list) else [key]
+
+    # create a list of arrays to process from the provided keys and merge_with
+    dataList = []
+    for k in keys:
+        xgm = run.get_array(*_mnemonics[k].values())
+        dataList.append(xgm.rename(k))
     if merge_with is None:
-        vals = []
-        for key in keys:
-            xgm = run.get_array(*_mnemonics[key].values())
-            vals.append(xgm)
-        aligned_vals = xr.align(*vals, join='inner')
-        ds = xr.Dataset(dict(zip(keys, aligned_vals)))
-        return cleanXGMdata(ds, bpt=bpt, keepAllSase=keepAllSase,
-                            indices=indices)
-    mw_xgm_keys = [k for k in merge_with
-                   if 'XGMbunchId' in merge_with[k].dims]
-    if len(mw_xgm_keys) == 0 and len(keys) == 0:
-        keys = ['SCS_SA3']
-    keys = [key for key in keys if key not in merge_with]
-    if len(keys) == 0:
-        return cleanXGMdata(merge_with, bpt=bpt,
-                            keepAllSase=keepAllSase, indices=indices)
-    vals = []
-    for key in keys:
-        xgm = run.get_array(*_mnemonics[key].values())
-        vals.append(xgm)
-    aligned_vals = xr.align(*vals, join='inner')
-    ds = xr.Dataset(dict(zip(keys, aligned_vals)))
-    ds = ds.merge(merge_with, join='inner')
-    return cleanXGMdata(ds, bpt=bpt, keepAllSase=keepAllSase,
-                        indices=indices)
-
-
-def cleanXGMdata(data, bpt=None, keepAllSase=False,
-                 indices=slice(0,None)):
+        mw_ds = xr.Dataset()
+    else:
+        mw_xgm_keys = [k for k in merge_with
+                       if 'XGMbunchId' in merge_with[k].dims]
+        dataList += [merge_with[k] for k in mw_xgm_keys]
+        mw_ds = merge_with.drop(mw_xgm_keys)
+
+    # assign pulse ID to each array in dataList and merge
+    ds = mw_ds
+    for d in dataList:
+        if bpt is not None:
+            arr = align_xgm_array(d, bpt)
+        else:
+            arr = d.where(d != 1., drop=True).sel(XGMbunchId=indices)
+        ds = ds.merge(arr, join='inner')
+    return ds
+
+
+def align_xgm_array(xgm_arr, bpt):
+    """
+    Assigns pulse ID coordinates to a pulse-resolved XGM array, according to
+    the bunch pattern table. If the array contains both SASE 1 and SASE 3
+    data, it is split into two arrays.
+
+    Parameters
+    ----------
+    xgm_arr: xarray DataArray
+        array containing pulse-resolved XGM data, with dims ['trainId',
+        'XGMbunchId']
+    bpt: xarray DataArray
+        bunch pattern table
+
+    Returns
+    -------
+    xarray Dataset with pulse ID coordinates. For SASE 1 data, the coordinate
+    name is sa1_pId, for SASE 3 data, the coordinate name is sa3_pId.
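+
+    Example
+    -------
+    A minimal sketch, assuming the "SCS_SA3" and "bunchPatternTable"
+    ToolBox mnemonics are available for the run:
+
+    >>> xgm = run.get_array(*_mnemonics['SCS_SA3'].values())
+    >>> bpt = run.get_array(*_mnemonics['bunchPatternTable'].values())
+    >>> ds = align_xgm_array(xgm.rename('SCS_SA3'), bpt)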
+ """ + key = xgm_arr.name + compute_sa1 = False + compute_sa3 = False + # get the relevant masks for SASE 1 and/or SASE3 + if "SA1" in key or "SA3" in key: + if "SA1" in key: + mask = is_sase_1(bpt.sel(trainId=xgm_arr.trainId)) + compute_sa1 = True + else: + mask = is_sase_3(bpt.sel(trainId=xgm_arr.trainId)) + compute_sa3 = True + tid = mask.where(mask.sum(dim='pulse_slot') > 0, drop=True).trainId + mask = mask.sel(trainId=tid) + mask_sa1 = mask.rename({'pulse_slot': 'sa1_pId'}) + mask_sa3 = mask.rename({'pulse_slot': 'sa3_pId'}) + if "XGM" in key: + compute_sa1 = True + compute_sa3 = True + mask_sa1 = is_sase_1(bpt.sel(trainId=xgm_arr.trainId)) + mask_sa3 = is_sase_3(bpt.sel(trainId=xgm_arr.trainId)) + mask = xr.ufuncs.logical_or(mask_sa1, mask_sa3) + tid = mask.where(mask.sum(dim='pulse_slot') > 0, + drop=True).trainId + mask_sa1 = mask_sa1.sel(trainId=tid).rename({'pulse_slot': 'sa1_pId'}) + mask_sa3 = mask_sa3.sel(trainId=tid).rename({'pulse_slot': 'sa3_pId'}) + mask = mask.sel(trainId=tid) + + npulses_max = mask.sum(dim='pulse_slot').max().values + xgm_arr = xgm_arr.sel(trainId=tid).isel( + XGMbunchId=slice(0, npulses_max)) + # pad the xgm array to match the bpt dims, flatten and + # reorder xgm array to match the indices of the mask + xgm_flat = np.hstack((xgm_arr.fillna(1.), + np.ones((xgm_arr.sizes['trainId'], + 2700-npulses_max)))).flatten() + xgm_flat_arg = np.argwhere(xgm_flat != 1.0) + mask_flat = mask.values.flatten() + mask_flat_arg = np.argwhere(mask_flat) + if(xgm_flat_arg.shape != mask_flat_arg.shape): + log.warning(f'{key}: XGM data and bunch pattern do not match.') + new_xgm_flat = np.ones(xgm_flat.shape) + new_xgm_flat[mask_flat_arg] = xgm_flat[xgm_flat_arg] + new_xgm = new_xgm_flat.reshape((xgm_arr.sizes['trainId'], 2700)) + + # create a dataset with new_xgm array masked by SASE 1 or SASE 3 + xgm_dict = {} + if compute_sa1: + sa1_xgm = xr.DataArray(new_xgm, dims=['trainId', 'sa1_pId'], + coords={'trainId': xgm_arr.trainId, + 'sa1_pId': np.arange(2700)}, + name=key.replace('XGM', 'SA1')) + xgm_dict[sa1_xgm.name] = sa1_xgm.where(mask_sa1, drop=True) + if compute_sa3: + sa3_xgm = xr.DataArray(new_xgm, dims=['trainId', 'sa3_pId'], + coords={'trainId': xgm_arr.trainId, + 'sa3_pId': np.arange(2700)}, + name=key.replace('XGM', 'SA3')) + xgm_dict[sa3_xgm.name] = sa3_xgm.where(mask_sa3, drop=True) + ds = xr.Dataset(xgm_dict) + return ds + + +def cleanXGMdata(data, bpt=None, keepAllSase=False, + indices=slice(0, None)): """ Cleans the XGM data arrays obtained from load() function. The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0