diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py index 5d917bb65ff7a5aee52ce0e67f9eadb6a76fc462..30e963ece4fa1bfe3c0fe33f9c22c0f9036250de 100644 --- a/src/toolbox_scs/detectors/digitizers.py +++ b/src/toolbox_scs/detectors/digitizers.py @@ -164,7 +164,7 @@ def get_peaks(run, indices=None, ): """ - Extract peaks from digitizer data. + Extract peaks from one source (channel) of a digitizer. Parameters ---------- @@ -530,12 +530,13 @@ def get_peak_params(run, key, raw_trace=None, plot=False): return params -def digitizer_peaks_from_keys(run, key, digitizer='ADQ412', - bunchPattern='sase3', integParams=None, - keepAllSase=False): +def get_peaks_from_list(run, dataList, digitizer='ADQ412', + bunchPattern='sase3', integParams=None, + keepAllSase=False): """ Automatically loads and computes digitizer peaks from ToolBox - mnemonics. Sources can be raw traces (e.g. "MCP3raw") or peak- + mnemonics and/or data arrays (raw traces or peak-integrated values). + Sources can be raw traces (e.g. "MCP3raw") or peak- integrated data (e.g. "MCP2apd"). The bunch pattern table is used to assign the pulse id coordinates. @@ -545,7 +546,7 @@ def digitizer_peaks_from_keys(run, key, digitizer='ADQ412', DataCollection containing the digitizer data. If provided, data must be a str or list of str. If None, data must be an xarray Dataset containing actual data. - key: str or list of str + key: list of str and/or DataArray A mnemonics for TIM, e.g. "MCP2apd" or "FastADC5raw". bunchPattern: str 'sase1' or 'sase3' or 'scs_ppl', bunch pattern @@ -566,16 +567,21 @@ def digitizer_peaks_from_keys(run, key, digitizer='ADQ412', the peak caclulated values (e.g. "MCP2raw" becomes "MCP2peaks"). """ - keys = key if isinstance(key, list) else [key] autoFind = True if integParams is not None: autoFind = False - peakDict = {} - for key in keys: + keys = [] + vals = [] + for d in dataList: + if isinstance(d, str): + key = d + log.debug(f'Loading peaks from {key}...') + if isinstance(d, xr.DataArray): + key = d.name + log.debug(f'Retrieving peaks from {key} array...') useRaw = True if 'raw' in key else False - log.debug(f'Retrieving peaks from {key}...') m = _mnemonics[key] - peaks = get_peaks(run, key, + peaks = get_peaks(run, d, source=m['source'], key=m['key'], digitizer=digitizer, @@ -583,11 +589,12 @@ def digitizer_peaks_from_keys(run, key, digitizer='ADQ412', autoFind=autoFind, integParams=integParams, bunchPattern=bunchPattern) - newKey = key.replace('raw', 'peaks').replace('apd', 'peaks') - peakDict[newKey] = peaks - ds = xr.Dataset() + key = key.replace('raw', 'peaks').replace('apd', 'peaks') + keys.append(key) + vals.append(peaks) join = 'outer' if keepAllSase else 'inner' - ds = ds.merge(peakDict, join=join) + aligned_vals = xr.align(*vals, join=join) + ds = xr.Dataset(dict(zip(keys, aligned_vals))) return ds @@ -734,13 +741,15 @@ def get_digitizer_peaks(run, key, digitizer, merge_with, dig_keys += [f'MCP{c}peaks' for c in range(1, 5)] if merge_with is None: if key is None: - key = default_key + keys = [default_key] + else: + keys = key if isinstance(key, list) else [key] log.info(f'Loading peaks from {key} data.') - return digitizer_peaks_from_keys(run, key, - digitizer=digitizer, - bunchPattern=bunchPattern, - integParams=integParams, - keepAllSase=keepAllSase) + return get_peaks_from_list(run, keys, + digitizer=digitizer, + bunchPattern=bunchPattern, + integParams=integParams, + keepAllSase=keepAllSase) mw_keys = [k for k in merge_with if k in dig_keys] mw_to_process = [k for k in mw_keys if len(merge_with[k].coords) != 2] mw_channels = [re.split(r'(\d+)', k)[1] for k in mw_keys] @@ -748,7 +757,6 @@ def get_digitizer_peaks(run, key, digitizer, merge_with, key = default_key if key is None: keys = [] - ds_dig = xr.Dataset() else: keys = key if isinstance(key, list) else [key] duplicate = [k for k in keys if @@ -757,31 +765,15 @@ def get_digitizer_peaks(run, key, digitizer, merge_with, log.info(f'{duplicate} keys already in merge_with dataset. ' + 'Skipping.') keys = [k for k in keys if k not in duplicate] - ds_dig = digitizer_peaks_from_keys(run, keys, - digitizer=digitizer, - bunchPattern=bunchPattern, - integParams=integParams, - keepAllSase=keepAllSase) log.info(f'Loading peaks from {keys+mw_to_process}.') - autoFind = True - if integParams is not None: - autoFind = False - peakDict = {} - toRemove = [] - for key in mw_to_process: - useRaw = True if 'raw' in key else False - log.debug(f'Retrieving peaks from {key}...') - m = _mnemonics[key] - peaks = get_peaks(run, merge_with[key], source=m['source'], - key=m['key'], digitizer=digitizer, - useRaw=useRaw, autoFind=autoFind, - integParams=integParams, - bunchPattern=bunchPattern) - newKey = key.replace('raw', 'peaks').replace('apd', 'peaks') - peakDict[newKey] = peaks - toRemove.append(key) - ds = merge_with.drop(toRemove) + dataList = keys + dataList += [merge_with[k] for k in mw_to_process] + ds_dig = get_peaks_from_list(run, dataList, + digitizer=digitizer, + bunchPattern=bunchPattern, + integParams=integParams, + keepAllSase=keepAllSase) + ds = merge_with.drop(mw_to_process) join = 'outer' if keepAllSase else 'inner' - ds = ds.merge(peakDict, join=join) - ds = ds_dig.merge(ds, join=join) + ds = ds.merge(ds_dig, join=join) return ds diff --git a/src/toolbox_scs/detectors/xgm.py b/src/toolbox_scs/detectors/xgm.py index 32f94f44f736db2d830469b9c174d3fb4ce71c87..309f4a9f7020a7e4440dbf0df22cef6664570187 100644 --- a/src/toolbox_scs/detectors/xgm.py +++ b/src/toolbox_scs/detectors/xgm.py @@ -94,8 +94,12 @@ def get_xgm(run, key=None, merge_with=None, keepAllSase=False, """ # check if bunch pattern table exists - if _mnemonics['bunchPatternTable']['source'] in run.all_sources: + if bool(merge_with) and 'bunchPatternTable' in merge_with: + bpt = merge_with['bunchPatternTable'] + log.debug('Using bpt from merge_with dataset.') + elif _mnemonics['bunchPatternTable']['source'] in run.all_sources: bpt = run.get_array(*_mnemonics['bunchPatternTable'].values()) + log.debug('Loaded bpt from extra_data run.') else: bpt = None if key is None and merge_with is None: @@ -307,12 +311,11 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', return xgmData#.assign_coords(XGMbunchId=np.arange(npulses)) #2. case where bunch pattern is provided - mask_sa1 = is_sase_1(bpt) - mask_sa3 = is_sase_3(bpt) + mask_sa1 = is_sase_1(bpt).sel(trainId=data.trainId) + mask_sa3 = is_sase_3(bpt).sel(trainId=data.trainId) mask_all = xr.ufuncs.logical_or(mask_sa1, mask_sa3) tid = mask_all.where(mask_all.sum(dim='pulse_slot')>0, drop=True).trainId - tid = np.intersect1d(tid, data.trainId) mask_sa1 = mask_sa1.sel(trainId=tid) mask_sa3 = mask_sa3.sel(trainId=tid) mask_all = mask_all.sel(trainId=tid)