Commit ad999302 authored by Laurent Mercadier

Simplify digitizer functions and pulse ID coordinates assignment

parent 5b6da8a7
Merge request !100: Simplify digitizer functions and pulse ID coordinates assignment for XGM, digitizers
@@ -3,6 +3,8 @@ from .xgm import (
autoFindFastAdcPeaks)
from .tim import (
load_TIM,)
from .digitizers import (
get_peaks, get_tim_peaks, get_fast_adc_peaks, find_integ_params)
from .dssc_data import (
save_xarray, load_xarray, get_data_formatted, save_attributes_h5)
from .dssc_misc import (
@@ -20,6 +22,10 @@ __all__ = (
"cleanXGMdata",
"load_TIM",
"matchXgmTimPulseId",
"get_peaks",
"get_tim_peaks",
"get_fast_adc_peaks",
"find_integ_params",
"load_dssc_info",
"create_dssc_bins",
"calc_xgm_frame_indices",
@@ -60,6 +66,7 @@ clean_ns = [
'FastCCD',
'tim',
'xgm',
'digitizers'
]
""" Digitizers related sub-routines
Copyright (2021) SCS Team.
(contributions preferably comply with pep8 code structure
guidelines.)
"""
import logging
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import extra_data as ed
from ..constants import mnemonics as _mnemonics
from ..misc.bunch_pattern_external import is_sase_1, is_sase_3, is_ppl
log = logging.getLogger(__name__)
def peaks_from_raw_trace(traces,
intstart,
intstop,
bkgstart,
bkgstop,
period,
npulses,
extra_dim='pulseId'
):
"""
Computes peaks from raw digitizer traces by trapezoidal integration.
Parameters
----------
traces: xarray DataArray or numpy array containing raw traces. If
numpy array is provided, the second dimension is that of the samples.
intstart: int or list or 1D-numpy array
trace index of integration start. If 1d array, each value is the start
of one peak. The period and npulses parameters are ignored.
intstop: int
trace index of integration stop
bkgstart: int
trace index of background start
bkgstop: int
trace index of background stop
period: int
number of samples between two peaks
npulses: int
number of pulses
extra_dim: str
Name given to the dimension along the peaks
Returns
-------
xarray DataArray
"""
assert len(traces.shape)==2
if type(traces) is xr.DataArray:
ntid = traces.sizes['trainId']
coords = traces.coords
traces = traces.values
if traces.shape[0] != ntid:
traces = traces.T
else:
coords = None
if hasattr(intstart, '__len__'):
intstart = np.array(intstart)
pulses = intstart - intstart[0]
intstart = intstart[0]
else:
pulses = range(0, npulses*period, period)
results = xr.DataArray(np.empty((traces.shape[0], len(pulses))),
coords=coords,
dims=['trainId', extra_dim])
for i,p in enumerate(pulses):
a = intstart + p
b = intstop + p
bkga = bkgstart + p
bkgb = bkgstop + p
if b > traces.shape[1]:
break
bg = np.outer(np.median(traces[:,bkga:bkgb], axis=1),
np.ones(b-a))
results[:,i] = np.trapz(traces[:,a:b] - bg, axis=1)
return results
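
# Minimal usage sketch for peaks_from_raw_trace on synthetic data (all
# numbers below are hypothetical, not from a real run):
#
#     trace = np.zeros((5, 400))                   # 5 trains, 400 samples
#     for p in range(3):                           # 3 pulses, period 100
#         trace[:, 20 + 100*p : 40 + 100*p] = 10.
#     traces = xr.DataArray(trace, dims=['trainId', 'samples'],
#                           coords={'trainId': np.arange(5)})
#     peaks = peaks_from_raw_trace(traces, intstart=15, intstop=45,
#                                  bkgstart=50, bkgstop=90,
#                                  period=100, npulses=3)
#     # peaks has dims ('trainId', 'pulseId') and shape (5, 3)
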
def get_peaks(proposal,
runNB,
data=None,
source=None,
key=None,
digitizer='ADQ412',
useRaw=True,
autoFind=True,
integParams=None,
bunchPattern='sase3',
extra_dim=None,
indices=None,
display=False
):
"""
Extract peaks from digitizer data.
Parameters
----------
proposal: int or str
proposal number
runNB: int or str
run number
data: xarray DataArray or str
array containing the raw traces or peak-integrated values from the
digitizer. If str, must be one of the ToolBox mnemonics. If None,
the data is loaded via the source and key arguments.
source: str
Name of digitizer source, e.g. 'SCS_UTC1_ADQ/ADC/1:network'. Only
required if data is a DataArray or None.
key: str
Key for digitizer data, e.g. 'digitizers.channel_1_A.raw.samples'.
        Only required if data is a DataArray or None.
digitizer: string
name of digitizer, e.g. 'FastADC' or 'ADQ412'. Used to determine
the sampling rate when useRaw is True.
useRaw: bool
If True, extract peaks from raw traces. If False, uses the APD (or
peaks) data from the digitizer.
autoFind: bool
If True, finds integration parameters by inspecting the average raw
trace. Only valid if useRaw is True.
integParams: dict
        dictionary containing the integration parameters for raw trace
integration: 'intstart', 'intstop', 'bkgstart', 'bkgstop', 'period',
'npulses'. Not used if autoFind is True. All keys are required when
bunch pattern is missing.
bunchPattern: string or dict
        match the peaks to the bunch pattern: 'sase1', 'sase3', 'scs_ppl'.
Alternatively, a dict with source, key and pattern can be provided,
e.g. {'source':'SCS_RR_UTC/TSYS/TIMESERVER',
'key':'bunchPatternTable.value', 'pattern':'sase3'}
extra_dim: str
Name given to the dimension along the peaks.
indices: array, slice
indices from the peak-integrated data to retrieve. Only required
when bunch pattern is missing and useRaw is False.
display: bool
displays info if True
Returns
-------
xarray.DataArray
"""
if data is None and (source is None or key is None):
raise ValueError('At least data or source + key arguments ' +
'are required.')
# load data
run = ed.open_run(proposal, runNB)
if data is None:
arr = run.get_array(source, key)
elif type(data) is str:
arr = run.get_array(*_mnemonics[data].values())
source = _mnemonics[data]['source']
key = _mnemonics[data]['key']
else:
arr = data
    dim = [d for d in arr.dims if d != 'trainId'][0]
#check if bunch pattern is provided
bpt = None
if type(bunchPattern) is dict:
bpt = run.get_array(bunchPattern['source'],
bunchPattern['key']).rename({'dim_0':'pulse_slot'})
pattern = bunchPattern['pattern']
elif _mnemonics['bunchPatternTable']['source'] in run.all_sources:
bpt = run.get_array(*_mnemonics['bunchPatternTable'].values())
pattern = bunchPattern
if bpt is not None:
#load mask and extract pulse Id:
        if pattern == 'sase3':
            mask = is_sase_3(bpt)
            extra_dim = 'sa3_pId'
        elif pattern == 'sase1':
            mask = is_sase_1(bpt)
            extra_dim = 'sa1_pId'
        elif pattern == 'scs_ppl':
            mask = is_ppl(bpt)
            extra_dim = 'ol_pId'
        else:
            raise ValueError(f'Unknown bunch pattern: {pattern}')
valid_tid = mask.where(mask.sum(dim='pulse_slot')>0,
drop=True).trainId
mask_on = mask.sel(trainId=valid_tid)
        if not (mask_on == mask_on[0]).all().values:
if display:
print('Pattern changed during the run!')
pid = np.unique(np.where(mask_on)[1])
npulses = len(pid)
extra_coords = pid
mask_final = mask_on.where(mask_on, drop=True).fillna(False)
mask_final = mask_final.astype(bool).rename({'pulse_slot':extra_dim})
if display:
print(f'Bunch pattern: {npulses} pulses for {pattern}.')
if extra_dim is None:
extra_dim = 'pulseId'
# 1. Use peak-integrated data from digitizer
if useRaw is False:
#1.1 No bunch pattern provided
if bpt is None:
if indices is None:
            raise TypeError('indices argument must be provided '
                            'when bunch pattern info is missing.')
return arr.isel({dim:indices}).rename({dim:'pulseId'})
#1.2 Bunch pattern is provided
if npulses == 1:
return arr.sel({dim:0}).expand_dims(extra_dim).T.assign_coords(
{extra_dim:extra_coords})
# verify period used by APD and match it to pid from bunchPattern
        if digitizer == 'FastADC':
adc_source = source.split(':')[0]
enable_key = (source.split(':')[1].split('.')[0]
+ '.enablePeakComputation.value')
            if not run.get_array(adc_source, enable_key)[0]:
                raise ValueError('The digitizer did not record '+
                                 'peak-integrated data.')
period_key = (source.split(':')[1].split('.')[0] +
'.pulsePeriod.value')
period = run.get_array(adc_source, period_key)[0].values/24
        if digitizer == 'ADQ412':
board_source = source.split(':')[0]
board = key[19]
channel = key[21]
channel_to_number = {'A':0, 'B':1, 'C':2, 'D':3}
channel = channel_to_number[channel]
in_del_key = (f'board{board}.apd.channel_{channel}'+
'.initialDelay.value')
initialDelay = run.get_array(board_source, in_del_key)[0].values
up_lim_key = (f'board{board}.apd.channel_{channel}'+
'.upperLimit.value')
upperLimit = run.get_array(board_source, up_lim_key)[0].values
period = (upperLimit - initialDelay)/440
stride = (pid[1] - pid[0]) / period
if period < 1:
print('Warning: the pulse period in digitizer was smaller '+
'than the separation of two pulses at 4.5 MHz.')
stride = 1
if stride < 1:
raise ValueError('Pulse period in digitizer was too large '+
'compared to actual pulse separation. Some pulses '+
'were not recorded.')
stride = int(stride)
if display:
print('period', period, 'stride', stride)
if npulses*stride > arr.sizes[dim]:
raise ValueError('The number of pulses recorded by digitizer '+
f'that correspond to actual {pattern} pulses '+
'is too small.')
peaks = arr.isel({dim:slice(0,npulses*stride,stride)}).rename(
{dim:extra_dim})
peaks = peaks.where(mask_final, drop=True)
return peaks.assign_coords({extra_dim:extra_coords})
# 2. Use raw data from digitizer
#minimum samples between two pulses, according to digitizer type
min_distance = 1
    if digitizer == 'FastADC':
        min_distance = 24
    if digitizer == 'ADQ412':
        min_distance = 440
if autoFind:
integParams = find_integ_params(arr.mean(dim='trainId'),
min_distance=min_distance)
if display:
print('auto find peaks result:', integParams)
# 2.1. No bunch pattern provided
if bpt is None:
required_keys=['intstart', 'intstop', 'bkgstart',
'bkgstop', 'period', 'npulses']
if integParams is None or not all(name in integParams
for name in required_keys):
raise TypeError('All keys of integParams argument '+
f'{required_keys} are required when '+
'bunch pattern info is missing.')
print(f'Retrieving {integParams["npulses"]} pulses.')
if extra_dim is None:
extra_dim = 'pulseId'
return peaks_from_raw_trace(arr, integParams['intstart'],
integParams['intstop'],
integParams['bkgstart'],
integParams['bkgstop'],
integParams['period'],
integParams['npulses'],
extra_dim=extra_dim)
    # 2.2 Bunch pattern is provided
    if integParams is None:
        raise TypeError('integParams are required when autoFind is False.')
    sample_id = (pid-pid[0])*min_distance
    #override auto find parameters
    if isinstance(integParams['intstart'], (int, np.integer)):
        integParams['intstart'] = sample_id + integParams['intstart']
        integParams['period'] = integParams['npulses'] = None
aligned_arr, _ = xr.align(arr, valid_tid)
peaks = peaks_from_raw_trace(aligned_arr, integParams['intstart'],
integParams['intstop'],
integParams['bkgstart'],
integParams['bkgstop'],
integParams['period'],
integParams['npulses'],
extra_dim=extra_dim)
peaks = peaks.where(mask_final, drop=True)
return peaks.assign_coords({extra_dim:extra_coords})
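
# Usage sketch for get_peaks (proposal, run number and integration
# parameters below are hypothetical; 'MCP2raw' is a ToolBox mnemonic):
#
#     params = {'intstart': 500, 'intstop': 600, 'bkgstart': 50,
#               'bkgstop': 450, 'period': 440, 'npulses': 10}
#     peaks = get_peaks(2212, 235, data='MCP2raw', digitizer='ADQ412',
#                       useRaw=True, autoFind=False, integParams=params,
#                       bunchPattern='sase3', display=True)
#     # -> DataArray with dims ('trainId', 'sa3_pId') if the bunch pattern
#     # table is in the run, ('trainId', 'pulseId') otherwise.
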
def find_integ_params(trace, min_distance=1, height=None, width=1):
"""
find integration parameters necessary for peak integration of a raw
digitizer trace.
Parameters
----------
trace: numpy array or xarray DataArray
        The digitizer raw trace used to find peaks
min_distance: int
minimum number of samples between two peaks
height: int
minimum threshold for peak determination
width: int
minimum width of peak
Returns
-------
dict with keys 'intstart', 'intstop', 'bkgstart', 'bkgstop', 'period',
'npulses' and values in number of samples.
"""
if type(trace) is xr.DataArray:
trace = trace.values
bl = np.median(trace)
trace_no_bl = trace - bl
if np.max(trace_no_bl) < np.abs(np.min(trace_no_bl)):
posNeg = 'negative'
trace_no_bl *= -1
trace = bl + trace_no_bl
noise = trace[:100]
noise_ptp = np.max(noise) - np.min(noise)
if height is None:
height = trace_no_bl.max()/20
centers, peaks = find_peaks(trace_no_bl, distance=min_distance,
height=height, width=width)
npulses = len(centers)
if npulses==0:
raise ValueError('Could not automatically find peaks.')
elif npulses==1:
period = 0
else:
period = np.median(np.diff(centers)).astype(int)
intstart = np.round(peaks['left_ips'][0]
- 0.5*np.mean(peaks['widths'])).astype(int)
intstop = np.round(peaks['right_ips'][0]
+ 0.5*np.mean(peaks['widths'])).astype(int)
bkgstop = intstart - int(0.5*np.mean(peaks['widths']))
bkgstart = bkgstop - 100
result = {'intstart':intstart, 'intstop':intstop,
'bkgstart':bkgstart, 'bkgstop':bkgstop,
'period':period, 'npulses':npulses}
return result
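
# Minimal sketch of find_integ_params on a synthetic trace (hypothetical
# numbers): two 20-sample square pulses separated by 200 samples.
#
#     trace = np.zeros(1000)
#     trace[300:320] = 5.
#     trace[500:520] = 5.
#     params = find_integ_params(trace, min_distance=100)
#     # params['npulses'] -> 2, params['period'] -> 200; integration and
#     # background windows are derived from the detected peak widths.
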
def get_tim_peaks(data, bunchPattern='sase3',
integParams=None, keepAllSase=False,
display=False):
"""
Automatically computes TIM peaks from sources in data. Sources
can be raw traces (e.g. "MCP2raw") or peak-integrated data
(e.g. "MCP2apd"). The bunch pattern table is used to assign
the pulse id coordinates.
Parameters
----------
data: xarray Dataset containing TIM data
bunchPattern: str
'sase1' or 'sase3' or 'scs_ppl', bunch pattern
used to extract peaks.
integParams: dict
        dictionary for raw trace integration, e.g.
{'intstart':100, 'intstop':200, 'bkgstart':50,
'bkgstop':99, 'period':24, 'npulses':500}.
If None, integration parameters are computed
automatically.
keepAllSase: bool
Only relevant in case of sase-dedicated trains. If
True, the trains for SASE 1 are kept, else they are
dropped.
display: bool
If True, displays information
Returns
-------
xarray Dataset with all TIM variables substituted by
    the calculated peak values (e.g. "MCP2raw" becomes
"MCP2peaks").
"""
proposal = data.attrs['runFolder'].split('/')[-3]
runNB = data.attrs['runFolder'].split('/')[-1]
peakDict = {}
toRemove = []
autoFind=True
if integParams is not None:
autoFind = False
for c in range(1,5):
key = f'MCP{c}raw'
useRaw = True
if key not in data:
key = f'MCP{c}apd'
useRaw = False
if key not in data:
continue
if display:
print(f'Retrieving TIM peaks from {key}...')
mnemonic = _mnemonics[key]
peaks = get_peaks(proposal, runNB, data[key],
source=mnemonic['source'],
key=mnemonic['key'],
digitizer='ADQ412',
useRaw=useRaw,
autoFind=autoFind,
integParams=integParams,
bunchPattern=bunchPattern,
display=display)
peakDict[f'MCP{c}peaks'] = peaks
toRemove.append(key)
ds = data.drop(toRemove)
join = 'outer' if keepAllSase else 'inner'
ds = ds.merge(peakDict, join=join)
return ds
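
# Usage sketch for get_tim_peaks (assumes `ds` was loaded with the ToolBox
# load() function, contains e.g. "MCP2apd" and has a 'runFolder' attribute):
#
#     ds = get_tim_peaks(ds, bunchPattern='sase3', display=True)
#     # "MCP2apd" is replaced by "MCP2peaks" with sa3_pId pulse coordinates.
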
def get_fast_adc_peaks(data, bunchPattern='scs_ppl',
integParams=None, keepAllSase=False,
display=False):
"""
Automatically computes Fast ADC peaks from sources in data.
Sources can be raw traces (e.g. "FastADC2raw") or peak-
integrated data (e.g. "FastADC2peaks"). The bunch pattern
table is used to assign the pulse id coordinates.
Parameters
----------
    data: xarray Dataset containing Fast ADC data
bunchPattern: str
'sase1' or 'sase3' or 'scs_ppl', bunch pattern
used to extract peaks.
integParams: dict
        dictionary for raw trace integration, e.g.
{'intstart':100, 'intstop':200, 'bkgstart':50,
'bkgstop':99, 'period':24, 'npulses':500}.
If None, integration parameters are computed
automatically.
keepAllSase: bool
Only relevant in case of sase-dedicated trains. If
True, the trains for SASE 1 are kept, else they are
dropped.
display: bool
If True, displays information
Returns
-------
xarray Dataset with all Fast ADC variables substituted by
    the calculated peak values (e.g. "FastADC2raw" becomes
"FastADC2peaks").
"""
proposal = data.attrs['runFolder'].split('/')[-3]
runNB = data.attrs['runFolder'].split('/')[-1]
peakDict = {}
toRemove = []
autoFind=True
if integParams is not None:
autoFind = False
for c in range(1,10):
key = f'FastADC{c}raw'
useRaw = True
if key not in data:
key = f'FastADC{c}peaks'
useRaw = False
if key not in data:
continue
if display:
print(f'Retrieving Fast ADC peaks from {key}...')
mnemonic = _mnemonics[key]
peaks = get_peaks(proposal, runNB, data[key],
source=mnemonic['source'],
key=mnemonic['key'],
digitizer='FastADC',
useRaw=useRaw,
autoFind=autoFind,
integParams=integParams,
bunchPattern=bunchPattern,
display=display)
peakDict[f'FastADC{c}peaks'] = peaks
toRemove.append(key)
ds = data.drop(toRemove)
join = 'outer' if keepAllSase else 'inner'
ds = ds.merge(peakDict, join=join)
return ds
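
# Usage sketch for get_fast_adc_peaks (assumes `ds` contains "FastADC5raw";
# the channel number is hypothetical):
#
#     ds = get_fast_adc_peaks(ds, bunchPattern='scs_ppl', display=True)
#     # "FastADC5raw" is replaced by "FastADC5peaks" with ol_pId coordinates.
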
@@ -15,6 +15,7 @@ from scipy.signal import find_peaks
import extra_data as ed
from ..constants import mnemonics as _mnemonics
from ..misc.bunch_pattern_external import is_sase_1, is_sase_3, is_ppl
log = logging.getLogger(__name__)
@@ -54,7 +55,7 @@ def load_xgm(run, xgm_mnemonic='SCS_SA3'):
return xgm
def cleanXGMdata(data, npulses=None, sase3First=True, keepAllSase=False):
''' Cleans the XGM data arrays obtained from load() function.
The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0
when there is no pulse. This function sorts the SASE 1 and SASE 3 pulses.
@@ -62,93 +63,140 @@ def cleanXGMdata(data, npulses=None, sase3First=True):
the function selectSASEinXGM is used to extract sase-resolved pulses.
Inputs:
data: xarray Dataset containing XGM TD arrays.
        npulses: number of SASE 3 pulses, needed if pulse pattern not
            available.
        sase3First: bool, needed if pulse pattern not available.
        keepAllSase: bool, only relevant in case of sase-dedicated trains.
            If True, the trains for SASE 1 are kept, else they are dropped.
Output:
xarray Dataset containing sase- and pulse-resolved XGM data, with
dimension names 'sa1_pId' and 'sa3_pId'
'''
    proposal = data.attrs['runFolder'].split('/')[-3]
    runNB = data.attrs['runFolder'].split('/')[-1]
    run = ed.open_run(proposal, runNB)
    load_sa1 = True
    #check bunch pattern table
    if _mnemonics['bunchPatternTable']['source'] in run.all_sources:
        bpt = run.get_array(*_mnemonics['bunchPatternTable'].values())
        mask_sa1 = is_sase_1(bpt).sel(trainId=data.trainId)
        mask_sa3 = is_sase_3(bpt).sel(trainId=data.trainId)
        valid_tid_sa1 = mask_sa1.where(mask_sa1.sum(dim='pulse_slot')>0,
                                       drop=True).trainId
        valid_tid_sa3 = mask_sa3.where(mask_sa3.sum(dim='pulse_slot')>0,
                                       drop=True).trainId
        if valid_tid_sa1.size==0:
            load_sa1 = False
            npulses_sa1 = None
        if np.intersect1d(valid_tid_sa1, valid_tid_sa3,
                          assume_unique=True).size==0:
            load_sa1 = keepAllSase
            if keepAllSase:
                print('Dedicated trains, loading both '+
                      'SASE 1 and SASE 3.')
            else:
                print('Dedicated trains, only loading SASE 3.')
        isBunchPattern = True
    else:
        print('Missing bunch pattern info!')
        bpt = None
        isBunchPattern = False
    dropList = []
    mergeList = []
    keys_all_sase = ["XTD10_XGM", "XTD10_XGM_sigma",
                     "SCS_XGM", "SCS_XGM_sigma"]
    keys_all_sase = list(set(keys_all_sase) & set(data.keys()))
    keys_dedicated = ["XTD10_SA3", "XTD10_SA1",
                      "XTD10_SA3_sigma", "XTD10_SA1_sigma",
                      "SCS_SA3", "SCS_SA1",
                      "SCS_SA3_sigma", "SCS_SA1_sigma"]
    keys_dedicated = list(set(keys_dedicated) & set(data.keys()))
#split non-dedicated sase data in SASE 1 and SASE 3 arrays
npulses_sa1 = None
for key in keys_all_sase:
        if not isBunchPattern:
if npulses is None:
raise TypeError('npulses argument is required '
'when retrieving non-dedicated XGM arrays '+
'without bunch pattern info. Consider using '+
'SASE 1 or SASE 3 dedicated data arrays.')
npulses_xgm = data[key].where(data[key]!=1.,
drop=True).sizes['XGMbunchId']
npulses_sa1 = npulses_xgm - npulses
            if npulses_sa1 < 0:
                raise ValueError('Required number of pulses is too large. '+
                                 f'The XGM array contains only {npulses_xgm} '+
                                 f'pulses but npulses={npulses}.')
if npulses_sa1 == 0:
load_sa1 = False
if load_sa1:
sa1 = selectSASEinXGM(data, xgm=key, sase='sase1',
bpt=bpt,npulses=npulses_sa1,
sase3First=sase3First).rename(
{'XGMbunchId':'sa1_pId'})
sa1 = sa1.rename(key.replace('XGM', 'SA1'))
mergeList.append(sa1)
sa3 = selectSASEinXGM(data, xgm=key, sase='sase3',
bpt=bpt, npulses=npulses,
sase3First=sase3First).rename(
{'XGMbunchId':'sa3_pId'})
sa3 = sa3.rename(key.replace('XGM', 'SA3'))
mergeList.append(sa3)
dropList.append(key)
#dedicated sase data
for key in keys_dedicated:
if "sa3" in key.lower():
sase = 'sa3_'
elif "sa1" in key.lower():
else:
sase = 'sa1_'
if not load_sa1:
if load_sa1 is False:
dropList.append(key)
continue
else:
dropList.append(key)
continue
        res = data[key].where(data[key] != 1.0, drop=True).rename(
            {'XGMbunchId':'{}pId'.format(sase)}).rename(key)
        dropList.append(key)
        mergeList.append(res)
    mergeList.append(data.drop(dropList))
    ds = xr.merge(mergeList, join='outer')
#assign pulse Id coordinates if bunch pattern is provided
if isBunchPattern:
for sase in ['sa1_pId', 'sa3_pId']:
            if sase == 'sa1_pId':
mask = mask_sa1
sase_keys = [key for key in ds if "SA1" in key]
else:
mask = mask_sa3
sase_keys = [key for key in ds if "SA3" in key]
mask_flat = mask.values.flatten()
mask_flat_arg = np.argwhere(mask_flat)
for key in sase_keys:
xgm_flat = np.hstack((ds[key].fillna(1.),
np.ones((ds[key].sizes['trainId'],
2700-ds[key].sizes[sase])))).flatten()
xgm_flat_arg = np.argwhere(xgm_flat!=1.0)
                if xgm_flat_arg.shape != mask_flat_arg.shape:
print('The XGM data does not match the bunch pattern! '+
'Cannot assign pulse id coordinates.')
break
new_xgm_flat = np.ones(xgm_flat.shape)
new_xgm_flat[mask_flat_arg] = xgm_flat[xgm_flat_arg]
new_xgm = new_xgm_flat.reshape((ds[key].sizes['trainId'], 2700))
                new_xgm = xr.DataArray(new_xgm, dims=['trainId', sase],
                                       coords={'trainId':ds[key].trainId,
                                               sase:np.arange(2700)},
                                       name=key)
new_xgm = new_xgm.where(new_xgm!=1, drop=True)
ds[key] = new_xgm
#copy original dataset attributes
    for k in data.attrs.keys():
        ds.attrs[k] = data.attrs[k]
    return ds
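
# Usage sketch for the extended cleanXGMdata (assumes `data` was loaded with
# the ToolBox load() function and carries the 'runFolder' attribute):
#
#     clean = cleanXGMdata(data, keepAllSase=True)
#     # -> Dataset with e.g. 'SCS_SA3' over ('trainId', 'sa3_pId') and, for
#     # dedicated trains, the SASE 1 arrays kept as well.
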
def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM',
                    bpt=None, sase3First=True, npulses=None):
''' Given an array containing both SASE1 and SASE3 data, extracts SASE1-
or SASE3-only XGM data. The function tracks the changes of bunch patterns
in sase 1 and sase 3 and applies a mask to the XGM array to extract the
@@ -158,6 +206,7 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
data: xarray Dataset containing xgm data
sase: key of sase to select: {'sase1', 'sase3'}
xgm: key of xgm to select: {'XTD10_XGM[_sigma]', 'SCS_XGM[_sigma]'}
bpt: xarray DataArray or numpy array: bunch pattern table
sase3First: bool, optional. Used in case no bunch pattern was recorded
npulses: int, optional. Required in case no bunch pattern was recorded.
@@ -168,7 +217,7 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
are filled with NaNs.
'''
#1. case where bunch pattern is missing:
    if bpt is None:
print('Retrieving {} SASE {} pulses assuming that '.format(npulses, sase[4])
+'SASE {} pulses come first.'.format('3' if sase3First else '1'))
#in older version of DAQ, non-data numbers were filled with 0.0.
@@ -182,54 +231,63 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=
'This is not supported yet.')
else:
start=xgmData.shape[1]-npulses
        ds = xgmData[:,start:start+npulses]
        return ds.assign_coords(XGMbunchId=np.arange(npulses))
#2. case where bunch pattern is provided
mask_sa1 = is_sase_1(bpt)
mask_sa3 = is_sase_3(bpt)
mask_all = xr.ufuncs.logical_or(mask_sa1, mask_sa3)
valid_tid = mask_all.where(mask_all.sum(dim='pulse_slot')>0,
drop=True).trainId
valid_tid = np.intersect1d(valid_tid, data.trainId)
npulses_max = mask_all.sum(dim='pulse_slot').max().values
mask_all = mask_all.sel(trainId=valid_tid)
xgm_arr = data[xgm].sel(trainId=valid_tid).isel(
XGMbunchId=slice(0,npulses_max))
pid_arr = xr.DataArray(np.outer(np.ones(valid_tid.shape),
np.arange(2700)).astype(int),
dims=['trainId', 'pulse_slot'],
coords={'trainId':valid_tid}
)
    if sase == 'sase1':
        mask = mask_sa1
        mask_alt = mask_sa3
        name = 'SA1'
    else:
        mask = mask_sa3
        mask_alt = mask_sa1
        name = 'SA3'
pid = pid_arr.where(mask, drop=True).fillna(2700)
pid_alt = pid_arr.where(mask_alt, drop=True).fillna(2700)
pid_all = pid_arr.where(mask_all, drop=True).fillna(2700)
    #list of various pulse patterns across the run
patterns = np.unique(pid_all, axis=0)
subsets = []
    #for each pattern, select trains of sase, check if
    #alt-sase is present, select pulses of sase and
    #append the sub-array to the list
for pattern in patterns:
tid = pid.where((pid_all==pattern).all(axis=1),
drop=True).trainId
if tid.size == 0:
continue
p = pid.sel(trainId=tid[0]).values
tid_alt = pid_alt.where((pid_all==pattern).all(axis=1),
drop=True).trainId
if tid_alt.size > 0:
p_alt = pid_alt.sel(trainId=tid[0]).values
p_all = np.sort(np.concatenate((p,p_alt)))
pulses = np.atleast_1d(np.argwhere(
np.in1d(p_all, p)).squeeze())
else:
pulses = slice(np.sum(p<2700))
arr = xgm_arr.sel(trainId=tid).isel(XGMbunchId=pulses)
subsets.append(arr)
#finally, concatenate the sub-arrays with join='outer'
da = xr.concat(subsets, dim='trainId', join='outer')
da = da.rename(da.name.replace('XGM', name))
return da
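
# Usage sketch for selectSASEinXGM with a bunch pattern table (`bpt` loaded
# as in cleanXGMdata above; values hypothetical):
#
#     bpt = run.get_array(*_mnemonics['bunchPatternTable'].values())
#     sa3 = selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', bpt=bpt)
#     # -> 'SCS_SA3' DataArray containing only SASE 3 pulses.
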
def saseContribution(data, sase='sase1', xgm='XTD10_XGM'):
''' Calculate the relative contribution of SASE 1 or SASE 3 pulses
@@ -292,7 +350,7 @@ def calibrateXGMs(data, allPulses=False, plot=False, display=False):
print(f'Using fast data averages (slowTrain) for {whichXgm}')
slowTrainData.append(data[f'{whichXgm}_slowTrain'])
else:
            mnemo = tb.mnemonics[f'{whichXgm}_slowTrain']
if mnemo['key'] in data.attrs['run'].keys_for_source(mnemo['source']):
if display:
print(f'Using fast data averages (slowTrain) for {whichXgm}')
@@ -451,7 +509,8 @@ def calibrateXGMsFromAllPulses(data, plot=False):
# TIM
def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=1, t_offset=None,
             npulses=None):
''' Computes peak integration from raw MCP traces.
Inputs:
@@ -487,7 +546,8 @@ def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=1, t_offset=None, n
t_offset = 440 * step
else:
t_offset = 1
    results = xr.DataArray(np.zeros((data.trainId.shape[0], npulses)),
                           coords=data[keyraw].coords,
dims=['trainId', 'MCP{}fromRaw'.format(mcp)])
for i in range(npulses):
a = intstart + t_offset*i