From 28f65b7567bab3ed004ed4120cee8d480a3e6a11 Mon Sep 17 00:00:00 2001 From: Laurent Mercadier <laurent.mercadier@xfel.eu> Date: Tue, 13 Apr 2021 22:53:40 +0200 Subject: [PATCH] Added BAM detectors function, added is_pulse_at in bunch_pattern_external --- src/toolbox_scs/constants.py | 14 +-- src/toolbox_scs/detectors/__init__.py | 3 + src/toolbox_scs/detectors/bam_detectors.py | 86 +++++++++++++++++++ src/toolbox_scs/detectors/digitizers.py | 18 +--- src/toolbox_scs/detectors/load_detectors.py | 11 ++- src/toolbox_scs/misc/__init__.py | 2 +- src/toolbox_scs/misc/bunch_pattern.py | 48 ----------- .../misc/bunch_pattern_external.py | 35 +++++++- 8 files changed, 142 insertions(+), 75 deletions(-) create mode 100644 src/toolbox_scs/detectors/bam_detectors.py diff --git a/src/toolbox_scs/constants.py b/src/toolbox_scs/constants.py index 2b1789d..72eee76 100644 --- a/src/toolbox_scs/constants.py +++ b/src/toolbox_scs/constants.py @@ -30,15 +30,15 @@ mnemonics = { 'dim': ['pulse_slot']}, # Bunch Arrival Monitors - "BAM5": {'source': 'SCS_ILH_LAS/DOOCS/BAM_414_B2:output', - 'key': 'data.lowChargeArrivalTime', - 'dim': ['BAMbunchId']}, - "BAM6": {'source': 'SCS_ILH_LAS/DOOCS/BAM_1932M_TL:output', - 'key': 'data.lowChargeArrivalTime', - 'dim': ['BAMbunchId']}, - "BAM7": {'source': 'SCS_ILH_LAS/DOOCS/BAM_1932S_TL:output', + "BAM414": {'source': 'SCS_ILH_LAS/DOOCS/BAM_414_B2:output', 'key': 'data.lowChargeArrivalTime', 'dim': ['BAMbunchId']}, + "BAM1932M": {'source': 'SCS_ILH_LAS/DOOCS/BAM_1932M_TL:output', + 'key': 'data.lowChargeArrivalTime', + 'dim': ['BAMbunchId']}, + "BAM1932S": {'source': 'SCS_ILH_LAS/DOOCS/BAM_1932S_TL:output', + 'key': 'data.lowChargeArrivalTime', + 'dim': ['BAMbunchId']}, # SA3 "nrj": {'source': 'SA3_XTD10_MONO/MDL/PHOTON_ENERGY', diff --git a/src/toolbox_scs/detectors/__init__.py b/src/toolbox_scs/detectors/__init__.py index c5a95de..716b46b 100644 --- a/src/toolbox_scs/detectors/__init__.py +++ b/src/toolbox_scs/detectors/__init__.py @@ -5,6 +5,7 @@ from .xgm import ( from .digitizers import ( get_peaks, get_tim_peaks, get_laser_peaks, get_digitizer_peaks, check_peak_params) +from .bam_detectors import get_bam from .dssc_data import ( save_xarray, load_xarray, get_data_formatted, save_attributes_h5) from .dssc_misc import ( @@ -27,6 +28,7 @@ __all__ = ( "get_laser_peaks", "get_digitizer_peaks", "check_peak_params", + "get_bam", "save_xarray", "load_xarray", "get_data_formatted", @@ -68,6 +70,7 @@ clean_ns = [ 'xgm', 'digitizers', 'load_detectors' + 'bam_detectors' ] diff --git a/src/toolbox_scs/detectors/bam_detectors.py b/src/toolbox_scs/detectors/bam_detectors.py new file mode 100644 index 0000000..36fef4a --- /dev/null +++ b/src/toolbox_scs/detectors/bam_detectors.py @@ -0,0 +1,86 @@ +""" Beam Arrival Monitor related sub-routines + + Copyright (2021) SCS Team. + + (contributions preferrably comply with pep8 code structure + guidelines.) +""" + +import logging + +import numpy as np +import xarray as xr +import matplotlib.pyplot as plt +from scipy.signal import find_peaks +import re + +from ..constants import mnemonics as _mnemonics +from ..misc.bunch_pattern_external import is_pulse_at +from ..util.exceptions import ToolBoxValueError + + +log = logging.getLogger(__name__) + + +def get_bam(run, key=None, merge_with=None, bunchPattern='sase3'): + """ + blablabla + """ + + # check if bunch pattern table exists + if bool(merge_with) and 'bunchPatternTable' in merge_with: + bpt = merge_with['bunchPatternTable'] + log.debug('Using bpt from merge_with dataset.') + elif _mnemonics['bunchPatternTable']['source'] in run.all_sources: + bpt = run.get_array(*_mnemonics['bunchPatternTable'].values()) + log.debug('Loaded bpt from extra_data run.') + else: + bpt = None + if key is None: + keys = ['BAM1932M'] if merge_with is None else [] + else: + keys = key if isinstance(key, list) else [key] + if bool(merge_with): + bam_keys = [k for k in _mnemonics if 'BAM' in k] + mw_to_process = [k for k in merge_with if k in bam_keys and + len(merge_with[k].coords) == 1] + keys = [k for k in keys if k not in merge_with] + else: + mw_to_process = [] + + ds = xr.Dataset() + roi = np.s_[:5400:2] + if len(keys) > 0: + vals = [] + for k in keys: + val = run.get_array(*_mnemonics[k].values(), roi=roi) + val[_mnemonics[k]['dim'][0]] = np.arange(2700) + vals.append(val) + aligned_vals = xr.align(*vals, join='inner') + ds = dict(zip(keys, aligned_vals)) + ds = xr.Dataset(ds) + if len(mw_to_process) > 0: + vals = [] + for k in mw_to_process: + val = merge_with[k].isel({_mnemonics[k]['dim'][0]: roi}) + val[_mnemonics[k]['dim'][0]] = np.arange(2700) + vals.append(val) + aligned_vals = xr.align(*vals, join='inner') + ds_mw = dict(zip(mw_to_process, aligned_vals)) + ds_mw = xr.Dataset(ds_mw) + ds = ds.merge(ds_mw) + + if bpt is not None and len(ds.variables) > 0: + dim_names = {'sase3': 'sa3_pId', 'sase1': 'sa1_pId', + 'scs_ppl': 'ol_pId'} + mask = is_pulse_at(bpt, bunchPattern) + mask = mask.rename({'pulse_slot': dim_names[bunchPattern]}) + ds = ds.rename({_mnemonics['BAM1932M']['dim'][0]: dim_names[bunchPattern]}) + ds = ds.where(mask, drop=True) + + if bool(merge_with): + result = merge_with.drop(mw_to_process) + result = result.merge(ds, join='inner') + return result + + return ds diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py index 642ace1..946f53d 100644 --- a/src/toolbox_scs/detectors/digitizers.py +++ b/src/toolbox_scs/detectors/digitizers.py @@ -15,7 +15,7 @@ from scipy.signal import find_peaks import re from ..constants import mnemonics as _mnemonics -from ..misc.bunch_pattern_external import is_sase_1, is_sase_3, is_ppl +from ..misc.bunch_pattern_external import is_pulse_at from ..util.exceptions import ToolBoxValueError @@ -127,12 +127,7 @@ def peaks_from_apd(array, params, digitizer, bpt, bunchPattern): npulses_apd = params['npulses'] dim_names = {'sase3': 'sa3_pId', 'sase1': 'sa1_pId', 'scs_ppl': 'ol_pId'} pulse_dim = dim_names[bunchPattern] - if bunchPattern == 'sase3': - mask = is_sase_3(bpt).rename({'pulse_slot': pulse_dim}) - if bunchPattern == 'sase1': - mask = is_sase_1(bpt).rename({'pulse_slot': pulse_dim}) - if bunchPattern == 'scs_ppl': - mask = is_ppl(bpt).rename({'pulse_slot': pulse_dim}) + mask = is_pulse_at(bpt, bunchPattern).rename({'pulse_slot': pulse_dim}) mask = mask.sel(trainId=array.trainId) mask = mask.assign_coords({pulse_dim: np.arange(2700)}) pid = np.sort(np.unique(np.where(mask)[1])) @@ -299,12 +294,7 @@ def get_peaks(run, dim_names = {'sase3': 'sa3_pId', 'sase1': 'sa1_pId', 'scs_ppl': 'ol_pId'} extra_dim = dim_names[pattern] valid_tid = np.intersect1d(arr.trainId, bpt.trainId, assume_unique=True) - if pattern == 'sase3': - mask = is_sase_3(bpt.sel(trainId=valid_tid)) - if pattern == 'sase1': - mask = is_sase_1(bpt.sel(trainId=valid_tid)) - if pattern == 'scs_ppl': - mask = is_ppl(bpt.sel(trainId=valid_tid)) + mask = is_pulse_at(bpt.sel(trainId=valid_tid), pattern) mask = mask.rename({'pulse_slot': extra_dim}) mask = mask.assign_coords({extra_dim: np.arange(2700)}) mask_on = mask.where(mask, drop=True).fillna(False).astype(bool) @@ -648,7 +638,7 @@ def check_peak_params(run, key, raw_trace=None, ntrains=200, params=None, sel = run.select_trains(np.s_[:ntrains]) bp_params = {} bpt = sel.get_array(*_mnemonics['bunchPatternTable'].values()) - mask = is_sase_3(bpt) if bunchPattern == 'sase3' else is_ppl(bpt) + mask = is_pulse_at(bpt, bunchPattern) pid = np.sort(np.unique(np.where(mask)[1])) bp_params['npulses'] = len(pid) if bp_params['npulses'] == 1: diff --git a/src/toolbox_scs/detectors/load_detectors.py b/src/toolbox_scs/detectors/load_detectors.py index 6da74b9..2ad3ed9 100644 --- a/src/toolbox_scs/detectors/load_detectors.py +++ b/src/toolbox_scs/detectors/load_detectors.py @@ -9,6 +9,7 @@ from ..constants import mnemonics as _mnemonics from .digitizers import get_tim_peaks, get_laser_peaks from .xgm import get_xgm +from .bam_detectors import get_bam def get_all_detectors(run, data, tim_bp='sase3', laser_bp='scs_ppl'): @@ -32,10 +33,16 @@ def get_all_detectors(run, data, tim_bp='sase3', laser_bp='scs_ppl'): adc peaks. See get_fast_adc_peaks for details. """ tim = [k for k in _mnemonics if 'MCP' in k and k in data] + ds = get_tim_peaks(run, key=tim, merge_with=data, bunchPattern=tim_bp) + laser = [k for k in _mnemonics if 'FastADC' in k and k in data] + ds = get_laser_peaks(run, key=laser, merge_with=ds, bunchPattern=laser_bp) + xgm = [k for k in _mnemonics if ('_SA3' in k or '_SA1' in k or '_XGM' in k) and k in data] - ds = get_tim_peaks(run, key=tim, merge_with=data, bunchPattern=tim_bp) - ds = get_laser_peaks(run, key=laser, merge_with=ds, bunchPattern=laser_bp) ds = get_xgm(run, key=xgm, merge_with=ds) + + bam = [k for k in _mnemonics if 'BAM' in k and k in data] + ds = get_bam(run, key=bam, merge_with=ds) + return ds diff --git a/src/toolbox_scs/misc/__init__.py b/src/toolbox_scs/misc/__init__.py index 9afb2d9..5def83e 100644 --- a/src/toolbox_scs/misc/__init__.py +++ b/src/toolbox_scs/misc/__init__.py @@ -1,5 +1,5 @@ from .bunch_pattern import (extractBunchPattern, pulsePatternInfo, - repRate, sortBAMdata, + repRate, ) from .bunch_pattern_external import is_sase_3, is_sase_1, is_ppl from .laser_utils import positionToDelay, degToRelPower diff --git a/src/toolbox_scs/misc/bunch_pattern.py b/src/toolbox_scs/misc/bunch_pattern.py index 2edf854..3f5939b 100644 --- a/src/toolbox_scs/misc/bunch_pattern.py +++ b/src/toolbox_scs/misc/bunch_pattern.py @@ -212,51 +212,3 @@ def repRate(data=None, runNB=None, proposalNB=None, key='sase3'): return 0 f = 1/((a[0,1] - a[0,0])*12e-3/54.1666667) return f - -def sortBAMdata(data, key='scs_ppl', sa3Offset=0): - ''' Extracts beam arrival monitor data from the raw arrays 'BAM6', 'BAM7', etc... - according to the bunchPatternTable. The BAM arrays contain 7220 values, which - corresponds to FLASH busrt length of 800 us @ 9 MHz. The bunchPatternTable - only has 2700 values, corresponding to XFEL 600 us burst length @ 4.5 MHz. - Hence, the BAM arrays are truncated to 5400 with a stride of 2 and matched - to the bunchPatternTable. If key is one of the sase, the given dimension name - of the bam arrays is 'sa[sase number]_pId', to match other data (XGM, TIM...). - If key is 'scs_ppl', the dimension is named 'ol_pId' - Inputs: - data: xarray Dataset containing BAM arrays - key: str, ['sase1', 'sase2', 'sase3', 'scs_ppl'] - sa3Offset: int, used if key=='scs_ppl'. Offset in number of pulse_id - between the first OL and FEL pulses. An offset of 40 means that - the first laser pulse comes 40 pulse_id later than the FEL on a - grid of 4.5 MHz. Negative values shift the laser pulse before - the FEL one. - Output: - ndata: xarray Dataset with same keys as input data (but new bam arrays) - ''' - a, b, mask = extractBunchPattern(key=key, runDir=data.attrs['run']) - if key == 'scs_ppl': - a3, b3, mask3 = extractBunchPattern(key='sase3', runDir=data.attrs['run']) - firstSa3_pId = a3.where(b3>0, drop=True)[0,0].values.astype(int) - mask = mask.roll(pulse_slot=firstSa3_pId+sa3Offset) - mask = mask.rename({'pulse_slot':'BAMbunchId'}) - ndata = data - dropList = [] - mergeList = [] - for k in data: - if 'BAM' in k: - dropList.append(k) - bam = data[k].isel(BAMbunchId=slice(0,5400,2)) - bam = bam.where(mask, drop=True) - if 'sase' in key: - name = f'sa{key[4]}_pId' - elif key=='scs_ppl': - name = 'ol_pId' - else: - name = 'bam_pId' - bam = bam.rename({'BAMbunchId':name}) - mergeList.append(bam) - mergeList.append(data.drop(dropList)) - ndata = xr.merge(mergeList, join='inner') - for k in data.attrs.keys(): - ndata.attrs[k] = data.attrs[k] - return ndata diff --git a/src/toolbox_scs/misc/bunch_pattern_external.py b/src/toolbox_scs/misc/bunch_pattern_external.py index 0a5ca3b..993b584 100644 --- a/src/toolbox_scs/misc/bunch_pattern_external.py +++ b/src/toolbox_scs/misc/bunch_pattern_external.py @@ -19,7 +19,7 @@ def _convert_data(bpt_dec): bpt_conv = bpt_dec if type(bpt_dec).__module__ == 'xarray.core.dataarray': - bpt_conv = bpt_dec.where(bpt_dec.values == True, other = 0) + bpt_conv = bpt_dec.where(bpt_dec.values == True, other=0) elif type(bpt_dec).__module__ == 'numpy': bpt_conv = bpt_dec.astype(int) else: @@ -30,13 +30,42 @@ def _convert_data(bpt_dec): return bpt_conv +def is_pulse_at(bpt, loc): + """ + Check for prescence of a pulse at the location provided. + + Parameters + ---------- + bpt : numpy array, xarray DataArray + The bunch pattern data. + loc : str + The location where to check: {'sase1', 'sase3', 'scs_ppl'} + + Returns + ------- + boolean : numpy array, xarray DataArray + true if a pulse is present at *loc*. + """ + if loc == 'sase3': + bpt_dec = ebp.is_sase(bpt, 3) + elif loc == 'sase1': + bpt_dec = ebp.is_sase(bpt, 1) + elif loc == 'scs_ppl': + bpt_dec = ebp.is_laser(bpt, laser=PPL_SCS) + else: + raise ValueError(f'loc argument is {loc}, expected "sase1", ' + + '"sase3" or "scs_ppl"') + + return _convert_data(bpt_dec) + + def is_sase_3(bpt): """ Check for prescence of a SASE3 pulse. Parameters ---------- - data : numpy array, xarray DataArray + bpt : numpy array, xarray DataArray The bunch pattern data. Returns @@ -54,7 +83,7 @@ def is_sase_1(bpt): Parameters ---------- - data : numpy array, xarray DataArray + bpt : numpy array, xarray DataArray The bunch pattern data. Returns -- GitLab