From 0229c1b9aa0cf887153e999f1629f30bc7fb9235 Mon Sep 17 00:00:00 2001 From: Laurent Mercadier <laurent.mercadier@xfel.eu> Date: Mon, 12 Dec 2022 11:34:29 +0100 Subject: [PATCH] Move load_bpt and get_unique_sase_pId to bunchpattern.py, add npulses_has_changed() and get_sase_pId() functions --- src/toolbox_scs/load.py | 61 +++------- src/toolbox_scs/misc/bunch_pattern.py | 157 +++++++++++++++++++++++++- 2 files changed, 173 insertions(+), 45 deletions(-) diff --git a/src/toolbox_scs/load.py b/src/toolbox_scs/load.py index a5bec21..0ce3c1c 100644 --- a/src/toolbox_scs/load.py +++ b/src/toolbox_scs/load.py @@ -20,6 +20,8 @@ from .constants import mnemonics as _mnemonics from .mnemonics_machinery import mnemonics_for_run from .util.exceptions import ToolBoxValueError import toolbox_scs.detectors as tbdet +from .misc.bunch_pattern import (npulses_has_changed, + get_unique_sase_pId, load_bpt) __all__ = [ 'concatenateRuns', @@ -55,7 +57,7 @@ def load(proposalNB=None, runNB=None, laser_bp=None, ): """ - Load a run and extract the data. Output is an xarray with aligned + Load a run and extract the data. Output is an xarray with aligned trainIds. Parameters @@ -97,9 +99,9 @@ def load(proposalNB=None, runNB=None, 'FastADC3peaks') and aligns the pulse Id according to the fadc_bp bunch pattern. extract_fadc2: bool - If True, extracts the peaks from FastADC2 variables (e.g. 'FastADC2_5raw', - 'FastADC2_3peaks') and aligns the pulse Id according to the fadc2_bp bunch - pattern. + If True, extracts the peaks from FastADC2 variables (e.g. + 'FastADC2_5raw', 'FastADC2_3peaks') and aligns the pulse Id according + to the fadc2_bp bunch pattern. extract_xgm: bool If True, extracts the values from XGM variables (e.g. 'SCS_SA3', 'XTD10_XGM') and aligns the pulse Id with the sase1 / sase3 bunch @@ -153,8 +155,10 @@ def load(proposalNB=None, runNB=None, data_arrays = [] run_mnemonics = mnemonics_for_run(run) # load pulse pattern info only if number of sase 3 pulses changed - sase3, sase3_changed = get_sase_pId(run) - if sase3_changed: + sase3 = None + if npulses_has_changed(run, run_mnemonics=run_mnemonics) is False: + sase3 = get_unique_sase_pId(run, run_mnemonics=run_mnemonics) + else: log.warning('Number of pulses changed during the run. ' 'Loading bunch pattern table.') bpt = load_bpt(run, run_mnemonics=run_mnemonics) @@ -220,7 +224,7 @@ def load(proposalNB=None, runNB=None, data = xr.merge(data_arrays, join='inner') data.attrs['runFolder'] = runFolder - + # backward compatibility with old-defined variables: if extract_tim is not None: extract_adq412 = extract_tim @@ -230,13 +234,14 @@ def load(proposalNB=None, runNB=None, adq412_bp = tim_bp if laser_bp is not None: fadc_bp = laser_bp - + adq412 = [k for k in run_mnemonics if 'MCP' in k and k in data] if extract_adq412 and len(adq412) > 0: - data = tbdet.get_digitizer_peaks(run, mnemonics=adq412, merge_with=data, - bunchPattern=adq412_bp) + data = tbdet.get_digitizer_peaks(run, mnemonics=adq412, + merge_with=data, + bunchPattern=adq412_bp) - fadc = [k for k in run_mnemonics if ('FastADC' in k) + fadc = [k for k in run_mnemonics if ('FastADC' in k) and ('FastADC2_' not in k) and (k in data)] if extract_fadc and len(fadc) > 0: data = tbdet.get_digitizer_peaks(run, mnemonics=fadc, merge_with=data, @@ -251,7 +256,8 @@ def load(proposalNB=None, runNB=None, 'SCS_SA1', 'SCS_SA1_sigma', 'SCS_SA3', 'SCS_SA3_sigma'] xgm = [k for k in xgm if k in data] if extract_xgm and len(xgm) > 0: - data = tbdet.get_xgm(run, mnemonics=xgm, merge_with=data) + data = tbdet.get_xgm(run, mnemonics=xgm, merge_with=data, + sase3=sase3) bam = [k for k in run_mnemonics if 'BAM' in k and k in data] if extract_bam and len(bam) > 0: @@ -492,34 +498,3 @@ def concatenateRuns(runs): for k in orderedRuns[0].attrs.keys(): result.attrs[k] = [run.attrs[k] for run in orderedRuns] return result - - -def load_bpt(run, merge_with=None, run_mnemonics=None): - if run_mnemonics is None: - run_mnemonics = mnemonics_for_run(run) - - for key in ['bunchPatternTable', 'bunchPatternTable_SA3']: - if bool(merge_with) and key in merge_with: - log.debug(f'Using {key} from merge_with dataset.') - return merge_with[key] - if key in run_mnemonics: - bpt = run.get_array(*run_mnemonics[key].values(), - name='bunchPatternTable') - log.debug(f'Loaded {key} from DataCollection.') - return bpt - log.debug('Could not find bunch pattern table.') - return None - - -def get_sase_pId(run, sase='sase3'): - mnemonics = mnemonics_for_run(run) - if sase not in mnemonics: - # bunch pattern not recorded - return [], True - npulse_sase = np.unique(get_array(run, 'npulses_' + sase)) - if len(npulse_sase) == 1: - return np.unique(load_run_values(run)[sase])[1:], False - # number of pulses changed during the run - return np.unique(get_array(run, sase))[1:], True - - diff --git a/src/toolbox_scs/misc/bunch_pattern.py b/src/toolbox_scs/misc/bunch_pattern.py index 90deb34..34ba18c 100644 --- a/src/toolbox_scs/misc/bunch_pattern.py +++ b/src/toolbox_scs/misc/bunch_pattern.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- """ Toolbox for SCS. - Various utilities function to quickly process data measured at the SCS instruments. + Various utilities function to quickly process data + measured at the SCS instruments. Copyright (2019) SCS Team. """ import os +import logging import numpy as np import xarray as xr @@ -15,13 +17,164 @@ from extra_data import RunDirectory # import and hide variable, such that it does not alter namespace. from ..constants import mnemonics as _mnemonics_bp +from ..mnemonics_machinery import mnemonics_for_run +from .bunch_pattern_external import is_pulse_at __all__ = [ 'extractBunchPattern', + 'get_sase_pId', + 'npulses_has_changed', 'pulsePatternInfo', - 'repRate' + 'repRate', ] +log = logging.getLogger(__name__) + + +def npulses_has_changed(run, sase='sase3', run_mnemonics=None): + """ + Checks if the number of pulses has changed during the run for + a specific location `sase` (='sase1', 'sase3', 'scs_ppl' or 'laser') + If the source is not found, returns True. + + Parameters + ---------- + run: extra_data.DataCollection + DataCollection containing the data. + sase: str + The location where to check: {'sase1', 'sase3', 'scs_ppl'} + run_mnemonics: dict + the mnemonics for the run (see `menonics_for_run`) + + Returns + ------- + ret: bool + True if the number of pulses has changed or the source was not + found, False if the number of pulses did not change. + """ + if run_mnemonics is None: + run_mnemonics = mnemonics_for_run(run) + if sase == 'scs_ppl': + sase = 'laser' + if sase not in run_mnemonics: + return True + npulses = run.get_array(*run_mnemonics['npulses_'+sase].values()) + if len(np.unique(npulses)) == 1: + return False + return True + + +def get_unique_sase_pId(run, sase='sase3', run_mnemonics=None): + """ + Assuming that the number of pulses did not change during the run, + returns the pulse Ids as the run value of the sase mnemonic. + + Parameters + ---------- + run: extra_data.DataCollection + DataCollection containing the data. + sase: str + The location where to check: {'sase1', 'sase3', 'scs_ppl'} + run_mnemonics: dict + the mnemonics for the run (see `menonics_for_run`) + + Returns + ------- + pulseIds: np.array + the pulse ids at the specified location. Returns None if the + mnemonic is not in the run. + """ + if run_mnemonics is None: + run_mnemonics = mnemonics_for_run(run) + if sase == 'scs_ppl': + sase = 'laser' + if sase not in run_mnemonics: + # bunch pattern not recorded + return None + npulses = run.get_run_value(run_mnemonics['npulses_'+sase]['source'], + run_mnemonics['npulses_'+sase]['key']) + pulseIds = run.get_run_value(run_mnemonics[sase]['source'], + run_mnemonics[sase]['key'])[:npulses] + return pulseIds + + +def get_sase_pId(run, sase='sase3', run_mnemonics=None, + bpt=None, merge_with=None): + """ + Returns the pulse Ids of the specified `sase` during a run. + If the number of pulses has changed during the run, it loads the + bunch pattern table and extract all pulse Ids used + Parameters + ---------- + run: extra_data.DataCollection + DataCollection containing the data. + sase: str + The location where to check: {'sase1', 'sase3', 'scs_ppl'} + run_mnemonics: dict + the mnemonics for the run (see `menonics_for_run`) + bpt: 2D-array + The bunch pattern table. Used only if the number of pulses + has changed. If None, it is loaded on the fly. + merge_with: xarray.Dataset + dataset that may contain the bunch pattern table to use in + case the number of pulses has changed. If merge_with does + not contain the bunch pattern table, it is loaded and added + as a variable 'bunchPatternTable' to merge_with. + + Returns + ------- + pulseIds: np.array + the pulse ids at the specified location. Returns None if the + mnemonic is not in the run. + """ + if npulses_has_changed(run, sase, run_mnemonics) is False: + return get_unique_sase_pId(run, sase, run_mnemonics) + if bpt is None: + bpt = load_bpt(run, merge_with, run_mnemonics) + if bpt is not None: + mask = is_pulse_at(bpt, sase) + return np.unique(np.nonzero(mask.values)[1]) + return None + + +def load_bpt(run, merge_with=None, run_mnemonics=None): + """ + Load the bunch pattern table. It returns the one contained in + merge_with if possible. Or, it adds it to merge_with once it is + loaded. + + Parameters + ---------- + run: extra_data.DataCollection + DataCollection containing the data. + merge_with: xarray.Dataset + dataset that may contain the bunch pattern table or to which + add the bunch pattern table once loaded. + run_mnemonics: dict + the mnemonics for the run (see `menonics_for_run`) + + Returns + ------- + bpt: xarray.Dataset + the bunch pattern table as specified by the mnemonics + 'bunchPatternTable' + """ + if run_mnemonics is None: + run_mnemonics = mnemonics_for_run(run) + for key in ['bunchPatternTable', 'bunchPatternTable_SA3']: + if merge_with is not None and key in merge_with: + log.debug(f'Using {key} from merge_with dataset.') + return merge_with[key] + if key in run_mnemonics: + bpt = run.get_array(*run_mnemonics[key].values(), + name='bunchPatternTable') + log.debug(f'Loaded {key} from DataCollection.') + if merge_with is not None: + merge_with = merge_with.merge(bpt, join='inner') + return bpt + log.debug('Could not find bunch pattern table.') + return None + def extractBunchPattern(bp_table=None, key='sase3', runDir=None): ''' generate the bunch pattern and number of pulses of a source directly from the -- GitLab