diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py index dad9404edc9430a0f71e71fba42a5d85ad31fd37..b5389bab81febfce7fe1818fd94757be2e180fad 100644 --- a/src/toolbox_scs/detectors/digitizers.py +++ b/src/toolbox_scs/detectors/digitizers.py @@ -7,6 +7,7 @@ """ import logging +import os import numpy as np import xarray as xr @@ -18,6 +19,9 @@ from ..misc.bunch_pattern_external import is_pulse_at from ..util.exceptions import ToolBoxValueError from ..mnemonics_machinery import (mnemonics_to_process, mnemonics_for_run) +from extra_data import open_run +from extra_data.read_machinery import find_proposal +from extra.components import XrayPulses, OpticalLaserPulses __all__ = [ 'check_peak_params', @@ -26,7 +30,10 @@ __all__ = [ 'get_peaks', 'get_tim_peaks', 'digitizer_signal_description', - 'get_dig_avg_trace' + 'get_dig_avg_trace', + 'extract_digitizer_peaks', + 'load_processed_peaks', + 'check_processed_peak_params' ] log = logging.getLogger(__name__) @@ -268,21 +275,20 @@ def get_peaks(run, min_distance = 24 if digitizer == 'ADQ412': min_distance = 440 - params = integParams.copy() if autoFind: stride = int(np.max([1, np.floor(arr.sizes['trainId']/200)])) trace = arr.isel(trainId=slice(0, None, stride)).mean(dim='trainId') try: - params = find_integ_params(trace) + integParams = find_integ_params(trace) except ValueError as err: log.warning(f'{err}, trying with averaged trace over all trains.') trace = arr.mean(dim='trainId') - params = find_integ_params(trace) - log.debug(f'Auto find peaks result: {params}') + integParams = find_integ_params(trace) + log.debug(f'Auto find peaks result: {integParams}') required_keys = ['pulseStart', 'pulseStop', 'baseStart', 'baseStop', 'period', 'npulses'] - if params is None or not all(name in params + if integParams is None or not all(name in integParams for name in required_keys): raise TypeError('All keys of integParams argument ' f'{required_keys} are required when ' @@ -291,10 +297,16 @@ def get_peaks(run, # 2.1. No bunch pattern provided if bpt is None: log.info('Missing bunch pattern info.') - log.debug(f'Retrieving {params["npulses"]} pulses.') + log.debug(f'Retrieving {integParams["npulses"]} pulses.') if extra_dim is None: extra_dim = 'pulseId' - return peaks_from_raw_trace(arr, **params, extra_dim=extra_dim) + return peaks_from_raw_trace(arr, integParams['pulseStart'], + integParams['pulseStop'], + integParams['baseStart'], + integParams['baseStop'], + integParams['period'], + integParams['npulses'], + extra_dim=extra_dim) # 2.2 Bunch pattern is provided # load mask and extract pulse Id: @@ -322,7 +334,7 @@ def get_peaks(run, period_bpt = 0 else: period_bpt = np.min(np.diff(pid)) - if autoFind and period_bpt*min_distance != params['period']: + if autoFind and period_bpt*min_distance != integParams['period']: log.warning('The period from the bunch pattern is different than ' 'that found by the peak-finding algorithm. Either ' 'the algorithm failed or the bunch pattern source ' @@ -330,9 +342,15 @@ def get_peaks(run, # create array of sample indices for peak integration sample_id = (pid-pid[0])*min_distance # override auto find parameters - if isinstance(params['pulseStart'], (int, np.integer)): - params['pulseStart'] = params['pulseStart'] + sample_id - peaks = peaks_from_raw_trace(valid_arr, **params, extra_dim=extra_dim) + if isinstance(integParams['pulseStart'], (int, np.integer)): + integParams['pulseStart'] = integParams['pulseStart'] + sample_id + peaks = peaks_from_raw_trace(valid_arr, integParams['pulseStart'], + integParams['pulseStop'], + integParams['baseStart'], + integParams['baseStop'], + integParams['period'], + integParams['npulses'], + extra_dim) if pattern_changed: peaks = peaks.where(mask_on, drop=True) return peaks.assign_coords({extra_dim: pid}) @@ -719,7 +737,7 @@ def check_peak_params(run, mnemonic, raw_trace=None, ntrains=200, params=None, return params -def plotPeakIntegrationWindow(raw_trace, params, bp_params, show_all=False): +def plotPeakIntegrationWindow(raw_trace, params, bp_params=None, show_all=False): if show_all: fig, ax = plt.subplots(figsize=(6, 3), constrained_layout=True) n = params['npulses'] @@ -1234,3 +1252,218 @@ def timFactorFromTable(voltage, photonEnergy, mcp=1): poly = np.poly1d(tim_calibration_table[photonEnergy][mcp-1]) f = -np.exp(poly(voltage)) return f + +############################################################################################# +############################################################################################# +############################################################################################# + +def extract_digitizer_peaks(proposal, runNB, mnemonic, bunchPattern=None, + integParams=None, autoFind=True, save=True, + subdir='usr/processed_runs'): + if integParams is None and autoFind is False: + log.warning('integParams not provided and autoFind is False. ' + 'Cannot compute peak integration parameters.') + return xr.DataArray() + + run = open_run(proposal, runNB) + run_mnemonics = mnemonics_for_run(run) + mnemo = run_mnemonics.get(mnemonic) + if mnemo is None: + log.warning('Mnemonic not found. Skipping.') + return xr.DataArray() + source, key = mnemo['source'], mnemo['key'] + extra_dim = {'sase3': 'sa3_pId', 'scs_ppl': 'ol_pId'}.get(bunchPattern) + if extra_dim is None: + extra_dim = 'pulseId' + digitizer = digitizer_type(run, source) + if digitizer == 'FastADC': + pulse_period = 24 + if digitizer == 'ADQ412': + pulse_period = 440 + + pattern = None + regular = True + try: + if bunchPattern == 'sase3': + pattern = XrayPulses(run) + if bunchPattern == 'scs_ppl': + pattern = OpticalLaserPulses(run) + except Exception as e: + print(e) + bunchPattern = None + + if integParams is not None: + required_keys = ['pulseStart', 'pulseStop', 'baseStart', + 'baseStop', 'period', 'npulses'] + if not all(name in integParams for name in required_keys): + raise TypeError('All keys of integParams argument ' + f'{required_keys} are required.') + params = integParams.copy() + autoFind = False + if pattern is not None: + # use period and npulses from pattern + pulse_ids = pattern.peek_pulse_ids(labelled=False) + npulses_from_bp = len(pulse_ids) + period_from_bp = 0 + if npulses_from_bp > 1: + period_from_bp = min(np.diff(pulse_ids)) * pulse_period + if (npulses_from_bp != params['npulses'] or + period_from_bp != params['period']): + print(f'Integration parameters (npulses={params["npulses"]}, ' + f'period={params["period"]}) do not match ' + f'the bunch pattern (npulses={npulses_from_bp}, ' + f'period={period_from_bp}). Using bunch pattern parameters.') + params['npulses'] = npulses_from_bp + params['period'] = period_from_bp + else: + period = params['period'] + npulses = params['npulses'] + pulse_ids = np.arange(params['npulses'], dtype=np.uint64) + + elif pattern is not None: + if pattern.is_constant_pattern() is False: + print('The number of pulses changed during the run.') + pulse_ids = np.unique(pattern.pulse_ids(labelled=False, copy=False)) + regular = False + else: + pulse_ids = pattern.peek_pulse_ids(labelled=False) + npulses = len(pulse_ids) + period = 0 + if npulses > 1: + periods = np.diff(pulse_ids) + if len(np.unique(periods)) > 1: + regular = False + period = min(periods) + npulses = int((max(pulse_ids) - min(pulse_ids)) / period) + 1 + period *= pulse_period + else: + pulse_ids = npulses = period = None + + # generate average trace + traces = run[source, key].xarray(name=mnemonic.replace('raw', 'avg'), + extra_dims=mnemo['dim']) + trace = traces.mean('trainId') + # find peak integration parameters + if autoFind == True: + #params = find_integration_params(trace, period, npulses) + params = find_integ_params(trace) + if (period is not None and params['period'] != period + or npulses is not None and params['npulses'] != npulses): + log.warning(f'Bunch pattern (npulses={npulses}, period={period}) and ' + f'found integration parameters (npulses={params["npulses"]}, ' + f'period={params["period"]}) do not match. Using bunch ' + 'pattern parameters.') + params['period'] = period + params['npulses'] = min(len(trace) // period, npulses) + print(params['npulses']) + if pulse_ids is None: + pulse_ids = np.arange(params['npulses'], dtype=np.uint64) + + if params is None: + print('Could not find peak integration parameters.') + return xr.DataArray() + + # extract peaks + data = peaks_from_raw_trace(traces, **params, extra_dim=extra_dim) + data = data.rename(mnemonic.replace('raw', 'peaks')) + data = data.assign_coords({extra_dim: pulse_ids}) + + if regular is False: + period = int(period / pulse_period) + mask = pattern.pulse_mask(labelled=False) + mask = xr.DataArray(mask, dims=['trainId', extra_dim], + coords={'trainId': run[source, key].train_ids, + extra_dim: np.arange(mask.shape[1])}) + mask = mask.sel({extra_dim: slice(pulse_ids[0], + pulse_ids[0] + npulses*period,period)}) + data = data.where(mask, drop=True) + + data.attrs['params_keys'] = list(params.keys()) + data.attrs[f'params_{data.name}'] = list(params.values()) + if save: + save_peaks(proposal, runNB, data, trace, subdir) + return data + +def save_peaks(proposal, runNB, peaks, avg_trace, subdir): + root = find_proposal(f'p{proposal:06d}') + path = os.path.join(root, subdir + f'/r{runNB:04d}/') + os.makedirs(path, exist_ok=True) + fname = path + f'r{runNB:04d}-digitizers-data.h5' + ds_peaks = peaks.to_dataset(promote_attrs=True) + + if os.path.isfile(fname): + ds = xr.load_dataset(fname) + ds = ds.drop_vars([peaks.name, avg_trace.name], errors='ignore') + for dim in ds.dims: + if all(dim not in ds[v].dims for v in ds): + ds = ds.drop_dims(dim) + dim_name = 'sampleId' + if 'sampleId' in ds.dims and ds.sizes['sampleId'] != len(avg_trace): + dim_name = 'sampleId2' + avg_trace = avg_trace.rename({avg_trace.dims[0]: dim_name}) + if f'params_{peaks.name}' in ds.attrs: + del ds.attrs[f'params_{peaks.name}'] + ds = xr.merge([ds, ds_peaks, avg_trace], + combine_attrs='drop_conflicts', join='inner') + else: + ds = ds_peaks.merge(avg_trace.rename({avg_trace.dims[0]: 'sampleId'})) + ds.to_netcdf(fname, format='NETCDF4') + print(f'saved data into {fname}.') + +def load_processed_peaks(proposal, runNB, mnemonic=None, + data='usr/processed_runs', merge_with=None): + if mnemonic is None: + return load_all_processed_peaks(proposal, runNB, data, merge_with) + root = find_proposal(f'p{proposal:06d}') + path = os.path.join(root, data + f'/r{runNB:04d}/') + fname = path + f'r{runNB:04d}-digitizers-data.h5' + if os.path.isfile(fname): + ds = xr.load_dataset(fname) + if mnemonic not in ds: + print(f'Mnemonic {mnemonic} not found in {fname}') + return {} + da = ds[mnemonic] + da.attrs['params_keys'] = ds.attrs['params_keys'] + da.attrs[f'params_{mnemonic}'] = ds.attrs[f'params_{mnemonic}'] + if merge_with is not None: + return merge_with.merge(da.to_dataset(promote_attrs=True), + combine_attrs='drop_conflicts', join='inner') + else: + return da + else: + print(f'Mnemonic {mnemonic} not found in {fname}') + return merge_with + +def load_all_processed_peaks(proposal, runNB, data='usr/processed_runs', + merge_with=None): + root = find_proposal(f'p{proposal:06d}') + path = os.path.join(root, data + f'/r{runNB:04d}/') + fname = path + f'r{runNB:04d}-digitizers-data.h5' + if os.path.isfile(fname): + if merge_with is not None: + return merge_with.merge(xr.load_dataset(fname), + combine_attrs='drop_conflicts', join='inner') + return xr.load_dataset(fname) + else: + print(f'{fname} not found.') + return merge_with + +def check_processed_peak_params(proposal, runNB, mnemonic, data='usr/processed_runs', + plot=True, show_all=False): + root = find_proposal(f'p{proposal:06d}') + path = os.path.join(root, data + f'/r{runNB:04d}/') + fname = path + f'r{runNB:04d}-digitizers-data.h5' + if os.path.isfile(fname): + ds = xr.load_dataset(fname) + if mnemonic.replace('raw', 'peaks') not in ds: + print(f'Mnemonic {mnemonic} not found in {fname}') + return {} + da = ds[mnemonic] + params = dict(zip(ds.attrs['params_keys'], ds.attrs[f'params_{mnemonic}'])) + if plot: + plotPeakIntegrationWindow(ds[mnemonic.replace('peaks', 'avg')], + params, show_all=show_all) + return params + else: + print(f'{fname} not found.') + return {} \ No newline at end of file