From 27c9380a851334f3710be44012ea3ef50b5f0e45 Mon Sep 17 00:00:00 2001 From: Laurent Mercadier <laurent.mercadier@xfel.eu> Date: Sun, 6 Oct 2024 11:54:54 +0200 Subject: [PATCH] update pes_get_tof(), add saving and loading averge traces functions --- src/toolbox_scs/detectors/pes.py | 290 ++++++++++++++++--------------- 1 file changed, 150 insertions(+), 140 deletions(-) diff --git a/src/toolbox_scs/detectors/pes.py b/src/toolbox_scs/detectors/pes.py index be5ee30..33c1e30 100644 --- a/src/toolbox_scs/detectors/pes.py +++ b/src/toolbox_scs/detectors/pes.py @@ -11,29 +11,36 @@ import numpy as np import xarray as xr import extra_data as ed import re +import os +from pathlib import Path +from multiprocessing import Pool from extra.components import XrayPulses, AdqRawChannel from ..misc.bunch_pattern_external import is_sase_3 from ..mnemonics_machinery import (mnemonics_to_process, mnemonics_for_run) +from extra_data.read_machinery import find_proposal from ..constants import mnemonics as _mnemonics __all__ = [ 'get_pes_params', 'get_pes_tof', - 'extract_pes_spectra', + 'save_pes_avg_traces', + 'load_pes_avg_traces' ] log = logging.getLogger(__name__) -def extract_pes_spectra(proposal, runNB, mnemonic, - start=0, origin=None, width=None): +def get_pes_tof(proposal, runNB, mnemonic, start=0, origin=None, + width=None, subtract_baseline=False, + baseStart=None, baseWidth=40,merge_with=None): """ Extracts time-of-flight spectra from raw digitizer traces. The spectra - are aligned by pulse Id using the SASE 3 bunch pattern, and have time - coordinates in nanoseconds. + are aligned by pulse Id using the SASE 3 bunch pattern. If origin is + not None, a time coordinate in nanoseconds 'time_ns' is computed and + added to the DataArray. Parameters ---------- @@ -46,12 +53,24 @@ def extract_pes_spectra(proposal, runNB, mnemonic, start: int starting sample of the first spectrum in the raw trace. origin: int - sample of the raw trace that corresponds to time-of-flight origin. - Used to compute the 'time_ns' coordinates. - If None, computation of 'time_ns' is skipped. + sample of the spectrum that corresponds to time-of-flight origin. + This is relative to the start, so the position in the raw trace + corresponds to start + origin. Used to compute the 'time_ns' + coordinates. If None, computation of 'time_ns' is skipped. width: int number of samples per spectra. If None, the number of samples for 4.5 MHz repetition rate is used. + subtract_baseline: bool + If True, subtract baseline defined by baseStart and baseWidth to each + spectrum. + baseStart: int + starting sample of the baseline. + baseWidth: int + number of samples to average (starting from baseStart) for baseline + calculation. + merge_with: xarray Dataset + If provided, the resulting Dataset will be merged with this + one. Returns ------- @@ -63,7 +82,8 @@ def extract_pes_spectra(proposal, runNB, mnemonic, >>> import toolbox_scs as tb >>> import toolbox_scs.detectors as tbdet >>> proposal, runNB = 900447, 12 - >>> pes = tbdet.get_pes_tof(proposal, runNB, 'PES_2Araw') + >>> pes = tbdet.get_pes_tof(proposal, runNB, 'PES_2Araw', + >>> start=2557, origin=76) """ run = ed.open_run(proposal, runNB) all_mnemonics = mnemonics_for_run(run) @@ -106,6 +126,11 @@ def extract_pes_spectra(proposal, runNB, mnemonic, coords={'trainId': kd.train_id_coordinates(), 'sa3_pId': pulse_ids[:npulses_trace], 'sampleId': np.arange(period)}) + if subtract_baseline: + if baseStart is None: + baseStart = 0 + spectra = spectra - spectra.isel( + sampleId=slice(baseStart, baseStart + baseWidth)).mean('sampleId') if width is None: width = PULSE_PERIOD spectra = spectra.isel(sampleId=slice(0, width), drop=True) @@ -122,149 +147,27 @@ def extract_pes_spectra(proposal, runNB, mnemonic, - origin)/sample_rate * 1e9 spectra = spectra.assign_coords(time_ns=('sampleId', time_ns)) spectra.attrs['origin'] = origin + spectra = spectra.rename(mnemonic.replace('raw', 'spectrum')) + if merge_with is not None: + return merge_with.merge(spectra.to_dataset(promote_attrs=True), + join='inner') return spectra.rename(mnemonic.replace('raw', 'spectrum')) -def get_pes_tof(run, mnemonics=None, merge_with=None, - start=31390, width=300, origin=None, width_ns=None, - subtract_baseline=True, - baseStart=None, baseWidth=80, - sample_rate=2e9): - """ - Extracts time-of-flight spectra from raw digitizer traces. The - tracesvare either loaded via ToolBox mnemonics or those in the - optionally provided merge_with dataset. The spectra are aligned - by pulse Id using the SASE 3 bunch pattern, and have time coordinates - in nanoseconds. - - Parameters - ---------- - run: extra_data.DataCollection - DataCollection containing the digitizer data - mnemonics: str or list of str - mnemonics for PES, e.g. "PES_W_raw" or ["PES_W_raw", "PES_ENE_raw"]. - If None and no merge_with dataset is provided, defaults to "PES_W_raw". - merge_with: xarray Dataset - If provided, the resulting Dataset will be merged with this - one. The PES variables of merge_with (if any) will also be - computed and merged. - start: int - starting sample of the first spectrum in the raw trace. - width: int - number of samples per spectra. - origin: int - sample of the raw trace that corresponds to time-of-flight origin. - If None, origin is equal to start. - width_ns: float - time window for one spectrum. If None, the time window is defined by - width / sample rate. - subtract_baseline: bool - If True, subtract baseline defined by baseStart and baseWidth to each - spectrum. - baseStart: int - starting sample of the baseline. - baseWidth: int - number of samples to average (starting from baseStart) for baseline - calculation. - sample_rate: float - sample rate of the digitizer. - - Returns - ------- - pes: xarray Dataset - Dataset containing the PES time-of-flight spectra (e.g. "PES_W_tof"), - merged with optionally provided merg_with dataset. - - Example - ------- - >>> import toolbox_scs as tb - >>> import toolbox_scs.detectors as tbdet - >>> run, ds = tb.load(2927, 100, "PES_W_raw") - >>> pes = tbdet.get_pes_tof(run, merge_with=ds) - """ - def to_processed_name(name): - return name.replace('raw', 'tof') - to_process = mnemonics_to_process(mnemonics, merge_with, - 'PES', to_processed_name) - run_mnemonics = mnemonics_for_run(run) - # check if bunch pattern table exists - if bool(merge_with) and 'bunchPatternTable' in merge_with: - bpt = merge_with['bunchPatternTable'] - elif 'bunchPatternTable' in run_mnemonics: - bpt = run.get_array(*run_mnemonics['bunchPatternTable'].values()) - elif 'bunchPatternTable_SA3' in run_mnemonics: - bpt = run.get_array(*run_mnemonics['bunchPatternTable_SA3'].values()) - else: - bpt = None - - mask = is_sase_3(bpt).assign_coords({'pulse_slot': np.arange(2700)}) - mask_on = mask.where(mask, drop=True) - npulses = mask.sum(dim='pulse_slot')[0].values - if npulses > 1: - period = mask_on['pulse_slot'].diff(dim='pulse_slot')[0].values - else: - period = 0 - if origin is None: - origin = start - if baseStart is None: - baseStart = start - if width_ns is not None: - width = int(sample_rate * width_ns * 1e-9) - time_ns = 1e9 * (np.arange(start, start + width) - origin) / sample_rate - - ds = xr.Dataset() - for m in to_process: - if bool(merge_with) and m in merge_with: - arr = merge_with[m] - else: - arr = run.get_array(*run_mnemonics[m].values(), name=m) - if arr.sizes['PESsampleId'] < npulses*period*440 + start + width: - log.warning('Not all pulses were recorded. The number of samples ' - f'on the digitizer {arr.sizes["PESsampleId"]} is not ' - f'enough to cover the {npulses} spectra. Missing ' - 'spectra will be filled with NaNs.') - spectra = [] - for p in range(npulses): - begin = p*period*440 + start - end = begin + width - if end > arr.sizes['PESsampleId']: - break - pes = arr.isel(PESsampleId=slice(begin, end)) - if subtract_baseline: - baseBegin = p*period*440 + baseStart - baseEnd = baseBegin + baseWidth - bl = arr.isel( - PESsampleId=slice(baseBegin, baseEnd)).mean(dim='PESsampleId') - pes = pes - bl - spectra.append(pes) - spectra = xr.concat(spectra, - dim='sa3_pId').rename(m.replace('raw', 'tof')) - ds = ds.merge(spectra) - if len(ds.variables) > 0: - ds = ds.assign_coords( - {'sa3_pId': mask_on['pulse_slot'][:ds.sizes['sa3_pId']].values}) - ds = ds.rename({'PESsampleId': 'time_ns'}) - ds = ds.assign_coords({'time_ns': time_ns}) - if bool(merge_with): - ds = merge_with.drop(to_process, - errors='ignore').merge(ds, join='left') - - return ds - - def get_pes_params(run, channel=None): """ Extract PES parameters for a given extra_data DataCollection. - Parameters are gas, binding energy, voltages of the MPOD. + Parameters are gas, binding energy, retardation voltages or all + voltages of the MPOD. Parameters ---------- run: extra_data.DataCollection DataCollection containing the digitizer data channel: str - Channel name, e.g. '2A'. If None, or if the channel is not - found in the data, the retardation voltage for all channels is - retrieved. + Channel name or PES mnemonic, e.g. '2A' or 'PES_1Craw'. + If None, or if the channel is not found in the data, + the retardation voltage for all channels is retrieved. Returns ------- params: dict @@ -289,6 +192,8 @@ def get_pes_params(run, channel=None): channel = [f'{a//4 + 1}{b}' for a, b in enumerate( ['A', 'B', 'C', 'D']*4)] else: + if 'raw' in channel: + channel = channel.split('raw')[0].split('_')[1] channel = [channel] for c in channel: rv = get_pes_rv(run, c, mpod_mapper) @@ -342,3 +247,108 @@ def get_pes_voltages(run, device='SA3_XTD10_PES/MDL/DAQ_MPOD'): if len(a.findall(k)) == 1: voltages[k.split('.')[0]] = da[device][k] return voltages + + +def calculate_average(run, name, mnemo): + return run[mnemo['source'], mnemo['key']].xarray( + name=name, extra_dims=mnemo['dim']).mean('trainId') + + +def save_pes_avg_traces(proposal, runNB, channels=None, + subdir='usr/processed_runs'): + ''' + Save average traces of PES into an h5 file. + + Parameters + ---------- + proposal:int + The proposal number. + runNB: int + The run number. + channels: str or list + The PES channels or mnemonics, e.g. '2A', ['2A', '3C'], + ['PES_1Araw', 'PES_4Draw', '3B'] + subdir: str + subdirectory. The data is stored in + <proposal path>/<subdir>/r{runNB:04d}/f'r{runNB:04d}-pes-data.h5' + + Output + ------ + xarray Dataset saved in a h5 file containing the PES average traces. + ''' + root = find_proposal(f'p{proposal:06d}') + path = os.path.join(root, subdir + f'/r{runNB:04d}/') + Path(path).mkdir(parents=True, exist_ok=True) + fname = path + f'r{runNB:04d}-pes-data.h5' + if channels is None: + channels = [f'{a//4 + 1}{b}' for a, b in enumerate( + ['A', 'B', 'C', 'D']*4)] + if isinstance(channels, str): + channels = [channels] + run = ed.open_run(proposal, runNB, parallelize=False) + all_mnemos = mnemonics_for_run(run) + # use multiprocessing.Pool + args = [] + for c in channels: + m = c + if 'raw' not in c: + m = f'PES_{c}raw' + if m not in all_mnemos: + continue + args.append([run, m.replace('raw', 'avg'), all_mnemos[m]]) + if len(args) == 0: + log.warning('No pes average trace to save. Skipping') + return + with Pool(len(args)) as pool: + avg_traces = pool.starmap(calculate_average, args) + avg_traces = xr.merge(avg_traces) + ds = xr.Dataset() + if os.path.isfile(fname): + ds = xr.load_dataset(fname) + ds = ds.drop_vars(channels, errors='ignore') + ds = ds.merge(avg_traces) + ds.to_netcdf(fname, format='NETCDF4') + return + + +def load_pes_avg_traces(proposal, runNB, channels=None, + subdir='usr/processed_runs'): + ''' + Load existing PES average traces. + + Parameters + ---------- + proposal:int + The proposal number. + runNB: int + The run number. + channels: str or list + The PES channels or mnemonics, e.g. '2A', ['2A', '3C'], + ['PES_1Araw', 'PES_4Draw', '3B'] + subdir: str + subdirectory. The data is stored in + <proposal path>/<subdir>/r{runNB:04d}/f'r{runNB:04d}-pes-data.h5' + + Output + ------ + ds: xarray Dataset + dataset containing the PES average traces. + ''' + root = find_proposal(f'p{proposal:06d}') + path = os.path.join(root, subdir + f'/r{runNB:04d}/') + fname = path + f'r{runNB:04d}-pes-data.h5' + if channels is None: + channels = [f'PES_{a//4 + 1}{b}avg' for a, b in + enumerate(['A', 'B', 'C', 'D']*4)] + if isinstance(channels, str): + channels = [channels] + for i, c in enumerate(channels): + if 'PES_' not in c and 'avg' not in c: + channels[i] = f'PES_{c}avg' + if os.path.isfile(fname): + ds = xr.load_dataset(fname) + channels = [c for c in channels if c in ds] + ds = ds[channels] + return ds + else: + log.warning(f'{fname} is not a valid file.') -- GitLab