Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • SCS/ToolBox
  • kluyvert/ToolBox
2 results
Show changes
import unittest
import logging
import os
import sys
import argparse
import toolbox_scs as tb
from toolbox_scs.util.exceptions import *
import extra_data as ed
logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
suites = {"packaging": (
"test_constant",
),
"load": (
"test_load",
"test_openrun",
"test_openrunpath",
"test_loadbinnedarray",
)
}
class TestToolbox(unittest.TestCase):
@classmethod
def setUpClass(cls):
log_root.info("Start global setup")
cls._mnentry = 'SCS_RR_UTC/TSYS/TIMESERVER'
cls._ed_run = ed.open_run(2212, 235)
log_root.info("Finished global setup, start tests")
@classmethod
def tearDownClass(cls):
pass
def setUp(self):
pass
def tearDown(self):
pass
def test_constant(self):
cls = self.__class__
self.assertEqual(tb.mnemonics['bunchPatternTable']['source'],cls._mnentry)
def test_load(self):
fields = ["SCS_XGM"]
# normal behavior
run_tb = None
proposalNB = 2511
runNB = 176
run_tb, data = tb.load(proposalNB, runNB, fields)
self.assertEqual(data['bunchPatternTable'].values[0, 0], 2113321)
# exception raised
run_tb = None
proposalNB = 2511
runNB = 1766
with self.assertRaises(ToolBoxPathError) as cm:
run_tb, data = tb.load(proposalNB, runNB, fields)
tb_exception = cm.exception
constr_path = f'/gpfs/exfel/exp/SCS/202001/p002511/raw/r{runNB}'
exp_msg = f"Invalid path: {constr_path}. " + \
"The constructed path does not exist."
self.assertEqual(tb_exception.message, exp_msg)
def test_openrun(self):
run, _ = tb.load(2212, 235)
src = 'SCS_DET_DSSC1M-1/DET/0CH0:xtdf'
self.assertTrue(src in run.all_sources)
def test_openrunpath(self):
run = tb.run_by_path(
"/gpfs/exfel/exp/SCS/201901/p002212/raw/r0235")
src = 'SCS_DET_DSSC1M-1/DET/0CH0:xtdf'
self.assertTrue(src in run.all_sources)
def test_loadbinnedarray(self):
cls = self.__class__
# Normal use
mnemonic = 'PP800_PhaseShifter'
data = tb.get_array(cls._ed_run, mnemonic, 0.5)
self.assertTrue = (data)
# unknown mnemonic
mnemonic = 'blabla'
with self.assertRaises(ToolBoxValueError) as cm:
scan_variable = tb.get_array(cls._ed_run, mnemonic, 0.5)
excp = cm.exception
self.assertEqual(excp.value, mnemonic)
def list_suites():
print("\nPossible test suites:\n" + "-" * 79)
for key in suites:
print(key)
print("-" * 79 + "\n")
def suite(*tests):
suite = unittest.TestSuite()
for test in tests:
suite.addTest(TestToolbox(test))
return suite
def main(*cliargs):
try:
for test_suite in cliargs:
if test_suite in suites:
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite(*suites[test_suite]))
else:
log_root.warning(
"Unknown suite: '{}'".format(test_suite))
pass
except Exception as err:
log_root.error("Unecpected error: {}".format(err),
exc_info=True)
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--list-suites',
action='store_true',
help='list possible test suites')
parser.add_argument('--run-suites', metavar='S',
nargs='+', action='store',
help='a list of valid test suites')
args = parser.parse_args()
if args.list_suites:
list_suites()
if args.run_suites:
main(*args.run_suites)
import unittest
import logging
import os
import sys
import argparse
from toolbox_scs.util.data_access import (
find_run_dir,
)
from toolbox_scs.util.exceptions import ToolBoxPathError
suites = {"ed-extensions": (
"test_rundir1",
"test_rundir2",
"test_rundir3",
)
}
def list_suites():
print("""\nPossible test suites:\n-------------------------""")
for key in suites:
print(key)
print("-------------------------\n")
class TestDataAccess(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@classmethod
def tearDownClass(cls):
pass
def setUp(self):
pass
def tearDown(self):
pass
def test_rundir1(self):
Proposal = 2212
Run = 235
Dir = find_run_dir(Proposal, Run)
self.assertEqual(Dir,
"/gpfs/exfel/exp/SCS/201901/p002212/raw/r0235")
def test_rundir2(self):
Proposal = 23678
Run = 235
with self.assertRaises(Exception) as cm:
find_run_dir(Proposal, Run)
exp = cm.exception
self.assertEqual(str(exp), "Couldn't find proposal dir for 'p023678'")
def test_rundir3(self):
Proposal = 2212
Run = 800
with self.assertRaises(ToolBoxPathError) as cm:
find_run_dir(Proposal, Run)
exp_msg = cm.exception.message
print(exp_msg)
path = f'/gpfs/exfel/exp/SCS/201901/p00{Proposal}/raw/r0{Run}'
err_msg = f"Invalid path: {path}. " \
"The constructed path does not exist."
self.assertEqual(exp_msg, err_msg)
def suite(*tests):
suite = unittest.TestSuite()
for test in tests:
suite.addTest(TestDataAccess(test))
return suite
def main(*cliargs):
logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
try:
for test_suite in cliargs:
if test_suite in suites:
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite(*suites[test_suite]))
else:
log_root.warning(
"Unknown suite: '{}'".format(test_suite))
pass
except Exception as err:
log_root.error("Unecpected error: {}".format(err),
exc_info=True)
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--list-suites',
action='store_true',
help='list possible test suites')
parser.add_argument('--run-suites', metavar='S',
nargs='+', action='store',
help='a list of valid test suites')
args = parser.parse_args()
if args.list_suites:
list_suites()
if args.run_suites:
main(*args.run_suites)
class ToolBoxError(Exception):
"""
Parent Toolbox exception. (to be defined)
"""
pass
class ToolBoxPathError(ToolBoxError):
def __init__(self, message = "", path = ""):
self.path = path
self.message = f'Invalid path: {path}. ' + message
class ToolBoxTypeError(ToolBoxError):
def __init__(self, msg = "", dtype = ''):
self.dtype = dtype
self.message = "Unknown data type: " + dtype + " \n" + msg
class ToolBoxValueError(ToolBoxError):
def __init__(self, msg = "", val = None):
self.value = val
self.message = msg + " unknown value: " + str(val)
class ToolBoxFileError(ToolBoxError):
def __init__(self, msg = "", val = ''):
self.value = val
self.message = f"file: {val}, {msg}"
\ No newline at end of file
import os
def get_version():
release_tag = os.popen('git describe --tags').read()
return release_tag.strip("\n")
\ No newline at end of file
# -*- coding: utf-8 -*-
""" Toolbox for SCS.
Various utilities function to quickly process data measured at the SCS instruments.
Copyright (2019) SCS Team.
"""
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy.signal import find_peaks
# XGM
def cleanXGMdata(data, npulses=None, sase3First=True):
''' Cleans the XGM data arrays obtained from load() function.
The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0
when there is no pulse. This function sorts the SASE 1 and SASE 3 pulses.
For DAQ runs after April 2019, sase-resolved arrays can be used. For older runs,
the function selectSASEinXGM can be used to extract sase-resolved pulses.
Inputs:
data: xarray Dataset containing XGM TD arrays.
npulses: number of pulses, needed if pulse pattern not available.
sase3First: bool, needed if pulse pattern not available.
Output:
xarray Dataset containing sase- and pulse-resolved XGM data, with
dimension names 'sa1_pId' and 'sa3_pId'
'''
dropList = []
mergeList = []
dedicated = False
if 'sase3' in data:
if np.all(data['npulses_sase1'].where(data['npulses_sase3'] !=0,
drop=True) == 0):
dedicated = True
print('Dedicated trains, skip loading SASE 1 data.')
#pulse-resolved signals from XGMs
keys = ["XTD10_XGM", "XTD10_SA3", "XTD10_SA1",
"XTD10_XGM_sigma", "XTD10_SA3_sigma", "XTD10_SA1_sigma",
"SCS_XGM", "SCS_SA3", "SCS_SA1",
"SCS_XGM_sigma", "SCS_SA3_sigma", "SCS_SA1_sigma"]
if ("SCS_SA3" not in data and "SCS_XGM" in data):
#no SASE-resolved arrays available
sa3 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase3', npulses=npulses,
sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'})
mergeList.append(sa3)
if not dedicated:
sa1 = selectSASEinXGM(data, xgm='SCS_XGM', sase='sase1',
npulses=npulses, sase3First=sase3First).rename(
{'XGMbunchId':'sa1_pId'})
mergeList.append(sa1)
dropList.append('SCS_XGM')
keys.remove('SCS_XGM')
if ("XTD10_SA3" not in data and "XTD10_XGM" in data):
#no SASE-resolved arrays available
sa3 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase3', npulses=npulses,
sase3First=sase3First).rename({'XGMbunchId':'sa3_pId'})
mergeList.append(sa3)
if not dedicated:
sa1 = selectSASEinXGM(data, xgm='XTD10_XGM', sase='sase1',
npulses=npulses, sase3First=sase3First).rename(
{'XGMbunchId':'sa1_pId'})
mergeList.append(sa1)
dropList.append('XTD10_XGM')
keys.remove('XTD10_XGM')
for key in keys:
if key not in data:
continue
if "sa3" in key.lower():
sase = 'sa3_'
elif "sa1" in key.lower():
sase = 'sa1_'
if dedicated:
dropList.append(key)
continue
else:
dropList.append(key)
continue
res = data[key].where(data[key] != 1.0, drop=True).rename(
{'XGMbunchId':'{}pId'.format(sase)}).rename(key)
res = res.assign_coords(
{f'{sase}pId':np.arange(res[f'{sase}pId'].shape[0])})
dropList.append(key)
mergeList.append(res)
mergeList.append(data.drop(dropList))
subset = xr.merge(mergeList, join='inner')
for k in data.attrs.keys():
subset.attrs[k] = data.attrs[k]
return subset
def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM', sase3First=True, npulses=None):
''' Given an array containing both SASE1 and SASE3 data, extracts SASE1-
or SASE3-only XGM data. The function tracks the changes of bunch patterns
in sase 1 and sase 3 and applies a mask to the XGM array to extract the
relevant pulses. This way, all complicated patterns are accounted for.
Inputs:
data: xarray Dataset containing xgm data
sase: key of sase to select: {'sase1', 'sase3'}
xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'}
sase3First: bool, optional. Used in case no bunch pattern was recorded
npulses: int, optional. Required in case no bunch pattern was recorded.
Output:
DataArray that has all trainIds that contain a lasing
train in sase, with dimension equal to the maximum number of pulses of
that sase in the run. The missing values, in case of change of number of pulses,
are filled with NaNs.
'''
#1. case where bunch pattern is missing:
if sase not in data:
print('Missing bunch pattern info!')
if npulses is None:
raise TypeError('npulses argument is required when bunch pattern ' +
'info is missing.')
print('Retrieving {} SASE {} pulses assuming that '.format(npulses, sase[4])
+'SASE {} pulses come first.'.format('3' if sase3First else '1'))
#in older version of DAQ, non-data numbers were filled with 0.0.
xgmData = data[xgm].where(data[xgm]!=0.0, drop=True)
xgmData = xgmData.fillna(0.0).where(xgmData!=1.0, drop=True)
if (sase3First and sase=='sase3') or (not sase3First and sase=='sase1'):
return xgmData[:,:npulses]
else:
if xr.ufuncs.isnan(xgmData).any():
raise Exception('The number of pulses changed during the run. '
'This is not supported yet.')
else:
start=xgmData.shape[1]-npulses
return xgmData[:,start:start+npulses]
#2. case where bunch pattern is provided
#2.1 Merge sase1 and sase3 bunch patterns to get indices of all pulses
xgm_arr = data[xgm].where(data[xgm] != 1., drop=True)
sa3 = data['sase3'].where(data['sase3'] > 1, drop=True)
sa3_val=np.unique(sa3)
sa3_val = sa3_val[~np.isnan(sa3_val)]
sa1 = data['sase1'].where(data['sase1'] > 1, drop=True)
sa1_val=np.unique(sa1)
sa1_val = sa1_val[~np.isnan(sa1_val)]
sa_all = xr.concat([sa1, sa3], dim='bunchId').rename('sa_all')
sa_all = xr.DataArray(np.sort(sa_all)[:,:xgm_arr['XGMbunchId'].shape[0]],
dims=['trainId', 'bunchId'],
coords={'trainId':data.trainId},
name='sase_all')
if sase=='sase3':
idxListSase = np.unique(sa3)
newName = xgm.split('_')[0] + '_SA3'
else:
idxListSase = np.unique(sa1)
newName = xgm.split('_')[0] + '_SA1'
#2.2 track the changes of pulse patterns and the indices at which they occured (invAll)
idxListAll, invAll = np.unique(sa_all.fillna(-1), axis=0, return_inverse=True)
#2.3 define a mask, loop over the different patterns and update the mask for each pattern
mask = xr.DataArray(np.zeros((data.dims['trainId'], sa_all['bunchId'].shape[0]), dtype=bool),
dims=['trainId', 'XGMbunchId'],
coords={'trainId':data.trainId,
'XGMbunchId':sa_all['bunchId'].values},
name='mask')
big_sase = []
for i,idxXGM in enumerate(idxListAll):
mask.values = np.zeros(mask.shape, dtype=bool)
idxXGM = np.isin(idxXGM, idxListSase)
idxTid = invAll==i
mask[idxTid, idxXGM] = True
sa_arr = xgm_arr.where(mask, drop=True)
if sa_arr.trainId.size > 0:
sa_arr = sa_arr.assign_coords(XGMbunchId=np.arange(sa_arr.XGMbunchId.size))
big_sase.append(sa_arr)
if len(big_sase) > 0:
da_sase = xr.concat(big_sase, dim='trainId').rename(newName)
else:
da_sase = xr.DataArray([], dims=['trainId'], name=newName)
return da_sase
def saseContribution(data, sase='sase1', xgm='XTD10_XGM'):
''' Calculate the relative contribution of SASE 1 or SASE 3 pulses
for each train in the run. Supports fresh bunch, dedicated trains
and pulse on demand modes.
Inputs:
data: xarray Dataset containing xgm data
sase: key of sase for which the contribution is computed: {'sase1', 'sase3'}
xgm: key of xgm to select: {'XTD10_XGM', 'SCS_XGM'}
Output:
1D DataArray equal to sum(sase)/sum(sase1+sase3)
'''
xgm_sa1 = selectSASEinXGM(data, 'sase1', xgm=xgm)
xgm_sa3 = selectSASEinXGM(data, 'sase3', xgm=xgm)
#Fill missing train ids with 0
r = xr.align(*[xgm_sa1, xgm_sa3], join='outer', exclude=['XGMbunchId'])
xgm_sa1 = r[0].fillna(0)
xgm_sa3 = r[1].fillna(0)
contrib = xgm_sa1.sum(axis=1)/(xgm_sa1.sum(axis=1) + xgm_sa3.sum(axis=1))
if sase=='sase1':
return contrib
else:
return 1 - contrib
def calibrateXGMs(data, rollingWindow=200, plot=False):
''' Calibrate the fast (pulse-resolved) signals of the XTD10 and SCS XGM
(read in intensityTD property) to the respective slow ion signal
(photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
If the sase-resolved signal (introduced in May 2019) are recorded, the
calibration is defined as the mean ratio between the photocurrent and
the low-pass slowTrain signal. Otherwise, calibrateXGMsFromAllPulses()
is called.
Inputs:
data: xarray Dataset
rollingWindow: length of running average to calculate E_fast_avg
plot: boolean, plot the calibration output
Output:
factors: numpy ndarray of shape 1 x 2 containing
[XTD10 calibration factor, SCS calibration factor]
'''
XTD10_factor = np.nan
SCS_factor = np.nan
if "XTD10_slowTrain" in data or "SCS_slowTrain" in data:
if "XTD10_slowTrain" in data:
XTD10_factor = np.mean(data.XTD10_photonFlux/data.XTD10_slowTrain)
else:
print('no XTD10 XGM data. Skipping calibration for XTD10 XGM')
if "SCS_slowTrain" in data:
#XTD10_SA3_contrib = data.XTD10_slowTrain_SA3 * data.npulses_sase3 / (
# data.XTD10_slowTrain * (data.npulses_sase3+data.npulses_sase1))
#SCS_SA3_SLOW = data.SCS_photonFlux*(data.npulses_sase3+
# data.npulses_sase1)*XTD10_SA3_contrib/data.npulses_sase3
#SCS_factor = np.mean(SCS_SA3_SLOW/data.SCS_slowTrain_SA3)
SCS_factor = np.mean(data.SCS_photonFlux/data.SCS_slowTrain)
else:
print('no SCS XGM data. Skipping calibration for SCS XGM')
#TODO: plot the results of calibration
return np.array([XTD10_factor, SCS_factor])
else:
return calibrateXGMsFromAllPulses(data, rollingWindow, plot)
def calibrateXGMsFromAllPulses(data, rollingWindow=200, plot=False):
''' Calibrate the fast (pulse-resolved) signals of the XTD10 and SCS XGM
(read in intensityTD property) to the respective slow ion signal
(photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
One has to take into account the possible signal created by SASE1 pulses. In the
tunnel, this signal is usually large enough to be read by the XGM and the relative
contribution C of SASE3 pulses to the overall signal is computed.
In the tunnel, the calibration F is defined as:
F = E_slow / E_fast_avg, where
E_fast_avg is the rolling average (with window rollingWindow) of the fast signal.
In SCS XGM, the signal from SASE1 is usually in the noise, so we calculate the
average over the pulse-resolved signal of SASE3 pulses only and calibrate it to the
slow signal modulated by the SASE3 contribution:
F = (N1+N3) * E_avg * C/(N3 * E_fast_avg_sase3), where N1 and N3 are the number
of pulses in SASE1 and SASE3, E_fast_avg_sase3 is the rolling average (with window
rollingWindow) of the SASE3-only fast signal.
Inputs:
data: xarray Dataset
rollingWindow: length of running average to calculate E_fast_avg
plot: boolean, plot the calibration output
Output:
factors: numpy ndarray of shape 1 x 2 containing
[XTD10 calibration factor, SCS calibration factor]
'''
XTD10_factor = np.nan
SCS_factor = np.nan
noSCS = noXTD10 = False
if 'SCS_XGM' not in data:
print('no SCS XGM data. Skipping calibration for SCS XGM')
noSCS = True
if 'XTD10_XGM' not in data:
print('no XTD10 XGM data. Skipping calibration for XTD10 XGM')
noXTD10 = True
if noSCS and noXTD10:
return np.array([XTD10_factor, SCS_factor])
if not noSCS and noXTD10:
print('XTD10 data is needed to calibrate SCS XGM.')
return np.array([XTD10_factor, SCS_factor])
start = 0
stop = None
npulses = data['npulses_sase3']
ntrains = npulses.shape[0]
# First, in case of change in number of pulses, locate a region where
# the number of pulses is maximum.
if not np.all(npulses == npulses[0]):
print('Warning: Number of pulses per train changed during the run!')
start = np.argmax(npulses.values)
stop = ntrains + np.argmax(npulses.values[::-1]) - 1
if stop - start < rollingWindow:
print('not enough consecutive data points with the largest number of pulses per train')
start += rollingWindow
stop = np.min((ntrains, stop+rollingWindow))
# Calculate SASE3 slow data
sa3contrib = saseContribution(data, 'sase3', 'XTD10_XGM')
SA3_SLOW = data['XTD10_photonFlux']*(data['npulses_sase3']+data['npulses_sase1'])*sa3contrib/data['npulses_sase3']
SA1_SLOW = data['XTD10_photonFlux']*(data['npulses_sase3']+data['npulses_sase1'])*(1-sa3contrib)/data['npulses_sase1']
# Calibrate XTD10 XGM with all signal from SASE1 and SASE3
if not noXTD10:
xgm_avg = selectSASEinXGM(data, 'sase3', 'XTD10_XGM').mean(axis=1)
rolling_sa3_xgm = xgm_avg.rolling(trainId=rollingWindow).mean()
ratio = SA3_SLOW/rolling_sa3_xgm
XTD10_factor = ratio[start:stop].mean().values
print('calibration factor XTD10 XGM: %f'%XTD10_factor)
# Calibrate SCS XGM with SASE3-only contribution
if not noSCS:
SCS_SLOW = data['SCS_photonFlux']*(data['npulses_sase3']+data['npulses_sase1'])*sa3contrib/data['npulses_sase3']
scs_sase3_fast = selectSASEinXGM(data, 'sase3', 'SCS_XGM').mean(axis=1)
meanFast = scs_sase3_fast.rolling(trainId=rollingWindow).mean()
ratio = SCS_SLOW/meanFast
SCS_factor = ratio[start:stop].median().values
print('calibration factor SCS XGM: %f'%SCS_factor)
if plot:
if noSCS ^ noXTD10:
plt.figure(figsize=(8,4))
else:
plt.figure(figsize=(8,8))
plt.subplot(211)
plt.title('E[uJ] = %.2f x IntensityTD' %(XTD10_factor))
plt.plot(SA3_SLOW, label='SA3 slow', color='C1')
plt.plot(rolling_sa3_xgm*XTD10_factor,
label='SA3 fast signal rolling avg', color='C4')
plt.plot(xgm_avg*XTD10_factor, label='SA3 fast signal train avg', alpha=0.2, color='C4')
plt.ylabel('Energy [uJ]')
plt.xlabel('train in run')
plt.legend(loc='upper left', fontsize=10)
plt.twinx()
plt.plot(SA1_SLOW, label='SA1 slow', alpha=0.2, color='C2')
plt.ylabel('SA1 slow signal [uJ]')
plt.legend(loc='lower right', fontsize=10)
plt.subplot(212)
plt.title('E[uJ] = %.2g x HAMP' %SCS_factor)
plt.plot(SCS_SLOW, label='SCS slow', color='C1')
plt.plot(meanFast*SCS_factor, label='SCS HAMP rolling avg', color='C2')
plt.ylabel('Energy [uJ]')
plt.xlabel('train in run')
plt.plot(scs_sase3_fast*SCS_factor, label='SCS HAMP train avg', alpha=0.2, color='C2')
plt.legend(loc='upper left', fontsize=10)
plt.tight_layout()
return np.array([XTD10_factor, SCS_factor])
# TIM
def mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=1, t_offset=None, npulses=None):
''' Computes peak integration from raw MCP traces.
Inputs:
data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
intstart: trace index of integration start
intstop: trace index of integration stop
bkgstart: trace index of background start
bkgstop: trace index of background stop
mcp: MCP channel number
t_offset: index separation between two pulses. Needed if bunch
pattern info is not available. If None, checks the pulse
pattern and determine the t_offset assuming mininum pulse
separation of 220 ns and digitizer resolution of 2 GHz.
npulses: number of pulses. If None, takes the maximum number of
pulses according to the bunch pattern (field 'npulses_sase3')
Output:
results: DataArray with dims trainId x max(sase3 pulses)
'''
keyraw = 'MCP{}raw'.format(mcp)
if keyraw not in data:
raise ValueError("Source not found: {}!".format(keyraw))
if npulses is None:
npulses = int(data['npulses_sase3'].max().values)
if t_offset is None:
sa3 = data['sase3'].where(data['sase3']>1)
if npulses > 1:
#Calculate the number of pulses between two lasing pulses (step)
step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
step = int(step[1] - step[0])
#multiply by elementary samples length (220 ns @ 2 GHz = 440)
t_offset = 440 * step
else:
t_offset = 1
results = xr.DataArray(np.zeros((data.trainId.shape[0], npulses)), coords=data[keyraw].coords,
dims=['trainId', 'MCP{}fromRaw'.format(mcp)])
for i in range(npulses):
a = intstart + t_offset*i
b = intstop + t_offset*i
bkga = bkgstart + t_offset*i
bkgb = bkgstop + t_offset*i
if b > data.dims['samplesId']:
break
bg = np.outer(np.median(data[keyraw][:,bkga:bkgb], axis=1), np.ones(b-a))
results[:,i] = np.trapz(data[keyraw][:,a:b] - bg, axis=1)
return results
def getTIMapd(data, mcp=1, use_apd=True, intstart=None, intstop=None,
bkgstart=None, bkgstop=None, t_offset=None, npulses=None,
stride=1):
''' Extract peak-integrated data from TIM where pulses are from SASE3 only.
If use_apd is False it calculates integration from raw traces.
The missing values, in case of change of number of pulses, are filled
with NaNs. If no bunch pattern info is available, the function assumes
that SASE 3 comes first and that the number of pulses is fixed in both
SASE 1 and 3.
Inputs:
data: xarray Dataset containing MCP raw traces (e.g. 'MCP1raw')
intstart: trace index of integration start
intstop: trace index of integration stop
bkgstart: trace index of background start
bkgstop: trace index of background stop
t_offset: number of ADC samples between two pulses
mcp: MCP channel number
npulses: int, optional. Number of pulses to compute. Required if
no bunch pattern info is available.
stride: int, optional. Used to select pulses in the APD array if
no bunch pattern info is available.
Output:
tim: DataArray of shape trainId only for SASE3 pulses x N
with N=max(number of pulses per train)
'''
#1. case where no bunch pattern is available:
if 'sase3' not in data:
print('Missing bunch pattern info!\n')
if npulses is None:
raise TypeError('npulses argument is required when bunch pattern ' +
'info is missing.')
print('Retrieving {} SASE 3 pulses assuming that '.format(npulses) +
'SASE 3 pulses come first.')
if use_apd:
tim = data[f'MCP{mcp}apd'][:,:npulses:stride]
else:
tim = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp,
t_offset=t_offset, npulses=npulses)
return tim
#2. If bunch pattern available, define a mask that corresponds to the SASE 3 pulses
sa3 = data['sase3'].where(data['sase3']>1, drop=True)
sa3 -= sa3[0,0]
#2.1 case where apd is used:
if use_apd:
pulseId = 'apdId'
pulseIdDim = data.dims['apdId']
initialDelay = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.initialDelay.value')[0].values
upperLimit = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.upperLimit.value')[0].values
#440 = samples between two pulses @4.5 MHz with ADQ412 digitizer:
period = int((upperLimit - initialDelay)/440)
#display some warnings if apd parameters do not match pulse pattern:
period_from_bunch_pattern = int(np.nanmin(np.diff(sa3)))
if period > period_from_bunch_pattern:
print(f'Warning: apd parameter was set to record 1 pulse out of {period} @ 4.5 MHz ' +
f'but XFEL delivered 1 pulse out of {period_from_bunch_pattern}.')
maxPulses = data['npulses_sase3'].max().values
if period*pulseIdDim < period_from_bunch_pattern*(maxPulses-1):
print(f'Warning: Number of pulses and/or rep. rate in apd parameters were set ' +
f'too low ({pulseIdDim})to record the {maxPulses} SASE 3 pulses')
peaks = data[f'MCP{mcp}apd']
#2.2 case where integration is performed on raw trace:
else:
pulseId = f'MCP{mcp}fromRaw'
pulseIdDim = int(np.max(sa3).values) + 1
period = int(np.nanmin(np.diff(sa3)))
peaks = mcpPeaks(data, intstart, intstop, bkgstart, bkgstop, mcp=mcp, t_offset=period*440,
npulses=pulseIdDim)
sa3 = sa3/period
#2.3 track the changes of pulse patterns and the indices at which they occured (invAll)
idxList, inv = np.unique(sa3, axis=0, return_inverse=True)
mask = xr.DataArray(np.zeros((data.dims['trainId'], pulseIdDim), dtype=bool),
dims=['trainId', pulseId],
coords={'trainId':data.trainId,
pulseId:np.arange(pulseIdDim)})
mask = mask.sel(trainId=sa3.trainId)
for i,idxApd in enumerate(idxList):
idxApd = idxApd[idxApd>=0].astype(int)
idxTid = inv==i
mask[idxTid, idxApd] = True
peaks = peaks.where(mask, drop=True)
peaks = peaks.assign_coords({pulseId:np.arange(peaks[pulseId].shape[0])})
return peaks
def calibrateTIM(data, rollingWindow=200, mcp=1, plot=False, use_apd=True, intstart=None,
intstop=None, bkgstart=None, bkgstop=None, t_offset=None, npulses_apd=None):
''' Calibrate TIM signal (Peak-integrated signal) to the slow ion signal of SCS_XGM
(photocurrent read by Keithley, channel 'pulseEnergy.photonFlux.value').
The aim is to find F so that E_tim_peak[uJ] = F x TIM_peak. For this, we want to
match the SASE3-only average TIM pulse peak per train (TIM_avg) to the slow XGM
signal E_slow.
Since E_slow is the average energy per pulse over all SASE1 and SASE3
pulses (N1 and N3), we first extract the relative contribution C of the SASE3 pulses
by looking at the pulse-resolved signals of the SA3_XGM in the tunnel.
There, the signal of SASE1 is usually strong enough to be above noise level.
Let TIM_avg be the average of the TIM pulses (SASE3 only).
The calibration factor is then defined as: F = E_slow * C * (N1+N3) / ( N3 * TIM_avg ).
If N3 changes during the run, we locate the indices for which N3 is maximum and define
a window where to apply calibration (indices start/stop).
Warning: the calibration does not include the transmission by the KB mirrors!
Inputs:
data: xarray Dataset
rollingWindow: length of running average to calculate TIM_avg
mcp: MCP channel
plot: boolean. If True, plot calibration results.
use_apd: boolean. If False, the TIM pulse peaks are extract from raw traces using
getTIMapd
intstart: trace index of integration start
intstop: trace index of integration stop
bkgstart: trace index of background start
bkgstop: trace index of background stop
t_offset: index separation between two pulses
npulses_apd: number of pulses
Output:
F: float, TIM calibration factor.
'''
start = 0
stop = None
npulses = data['npulses_sase3']
ntrains = npulses.shape[0]
if not np.all(npulses == npulses[0]):
start = np.argmax(npulses.values)
stop = ntrains + np.argmax(npulses.values[::-1]) - 1
if stop - start < rollingWindow:
print('not enough consecutive data points with the largest number of pulses per train')
start += rollingWindow
stop = np.min((ntrains, stop+rollingWindow))
filteredTIM = getTIMapd(data, mcp, use_apd, intstart, intstop, bkgstart, bkgstop, t_offset, npulses_apd)
sa3contrib = saseContribution(data, 'sase3', 'XTD10_XGM')
avgFast = filteredTIM.mean(axis=1).rolling(trainId=rollingWindow).mean()
ratio = ((data['npulses_sase3']+data['npulses_sase1']) *
data['SCS_photonFlux'] * sa3contrib) / (avgFast*data['npulses_sase3'])
F = float(ratio[start:stop].median().values)
if plot:
fig = plt.figure(figsize=(8,5))
ax = plt.subplot(211)
ax.set_title('E[uJ] = {:2e} x TIM (MCP{})'.format(F, mcp))
ax.plot(data['SCS_photonflux'], label='SCS XGM slow (all SASE)', color='C0')
slow_avg_sase3 = data['SCS_photonflux']*(data['npulses_sase1']
+data['npulses_sase3'])*sa3contrib/data['npulses_sase3']
ax.plot(slow_avg_sase3, label='SCS XGM slow (SASE3 only)', color='C1')
ax.plot(avgFast*F, label='Calibrated TIM rolling avg', color='C2')
ax.legend(loc='upper left', fontsize=8)
ax.set_ylabel('Energy [$\mu$J]', size=10)
ax.plot(filteredTIM.mean(axis=1)*F, label='Calibrated TIM train avg', alpha=0.2, color='C9')
ax.legend(loc='best', fontsize=8, ncol=2)
plt.xlabel('train in run')
ax = plt.subplot(234)
xgm_fast = selectSASEinXGM(data)
ax.scatter(filteredTIM, xgm_fast, s=5, alpha=0.1, rasterized=True)
fit, cov = np.polyfit(filteredTIM.values.flatten(),xgm_fast.values.flatten(),1, cov=True)
y=np.poly1d(fit)
x=np.linspace(filteredTIM.min(), filteredTIM.max(), 10)
ax.plot(x, y(x), lw=2, color='r')
ax.set_ylabel('Raw HAMP [$\mu$J]', size=10)
ax.set_xlabel('TIM (MCP{}) signal'.format(mcp), size=10)
ax.annotate(s='y(x) = F x + A\n'+
'F = %.3e\n$\Delta$F/F = %.2e\n'%(fit[0],np.abs(np.sqrt(cov[0,0])/fit[0]))+
'A = %.3e'%fit[1],
xy=(0.5,0.6), xycoords='axes fraction', fontsize=10, color='r')
print('TIM calibration factor: %e'%(F))
ax = plt.subplot(235)
ax.hist(filteredTIM.values.flatten()*F, bins=50, rwidth=0.8)
ax.set_ylabel('number of pulses', size=10)
ax.set_xlabel('Pulse energy MCP{} [uJ]'.format(mcp), size=10)
ax.set_yscale('log')
ax = plt.subplot(236)
if not use_apd:
pulseStart = intstart
pulseStop = intstop
else:
pulseStart = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStart.value')[0].values
pulseStop = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1', 'board1.apd.channel_0.pulseStop.value')[0].values
if 'MCP{}raw'.format(mcp) not in data:
tid, data = data.attrs['run'].train_from_index(0)
trace = data['SCS_UTC1_ADQ/ADC/1:network']['digitizers.channel_1_D.raw.samples']
print('no raw data for MCP{}. Loading trace from MCP1'.format(mcp))
label_trace='MCP1 Voltage [V]'
else:
trace = data['MCP{}raw'.format(mcp)][0]
label_trace='MCP{} Voltage [V]'.format(mcp)
ax.plot(trace[:pulseStop+25], 'o-', ms=2, label='trace')
ax.axvspan(pulseStart, pulseStop, color='C2', alpha=0.2, label='APD region')
ax.axvline(pulseStart, color='gray', ls='--')
ax.axvline(pulseStop, color='gray', ls='--')
ax.set_xlim(pulseStart - 25, pulseStop + 25)
ax.set_ylabel(label_trace, size=10)
ax.set_xlabel('sample #', size=10)
ax.legend(fontsize=8)
plt.tight_layout()
return F
''' TIM calibration table
Dict with key= photon energy and value= array of polynomial coefficients for each MCP (1,2,3).
The polynomials correspond to a fit of the logarithm of the calibration factor as a function
of MCP voltage. If P is a polynomial and V the MCP voltage, the calibration factor (in microjoule
per APD signal) is given by -exp(P(V)).
This table was generated from the calibration of March 2019, proposal 900074, semester 201930,
runs 69 - 111 (Ni edge): https://in.xfel.eu/elog/SCS+Beamline/2323
runs 113 - 153 (Co edge): https://in.xfel.eu/elog/SCS+Beamline/2334
runs 163 - 208 (Fe edge): https://in.xfel.eu/elog/SCS+Beamline/2349
'''
tim_calibration_table = {
705.5: np.array([
[-6.85344690e-12, 5.00931986e-08, -1.27206912e-04, 1.15596821e-01, -3.15215367e+01],
[ 1.25613942e-11, -5.41566381e-08, 8.28161004e-05, -7.27230153e-02, 3.10984925e+01],
[ 1.14094964e-12, 7.72658935e-09, -4.27504907e-05, 4.07253378e-02, -7.00773062e+00]]),
779: np.array([
[ 4.57610777e-12, -2.33282497e-08, 4.65978738e-05, -6.43305156e-02, 3.73958623e+01],
[ 2.96325102e-11, -1.61393276e-07, 3.32600044e-04, -3.28468195e-01, 1.28328844e+02],
[ 1.14521506e-11, -5.81980336e-08, 1.12518434e-04, -1.19072484e-01, 5.37601559e+01]]),
851: np.array([
[ 3.15774215e-11, -1.71452934e-07, 3.50316512e-04, -3.40098861e-01, 1.31064501e+02],
[5.36341958e-11, -2.92533156e-07, 6.00574534e-04, -5.71083140e-01, 2.10547161e+02],
[ 3.69445588e-11, -1.97731342e-07, 3.98203522e-04, -3.78338599e-01, 1.41894119e+02]])
}
def timFactorFromTable(voltage, photonEnergy, mcp=1):
''' Returns an energy calibration factor for TIM integrated peak signal (APD)
according to calibration from March 2019, proposal 900074, semester 201930,
runs 69 - 111 (Ni edge): https://in.xfel.eu/elog/SCS+Beamline/2323
runs 113 - 153 (Co edge): https://in.xfel.eu/elog/SCS+Beamline/2334
runs 163 - 208 (Fe edge): https://in.xfel.eu/elog/SCS+Beamline/2349
Uses the tim_calibration_table declared above.
Inputs:
voltage: MCP voltage in volts.
photonEnergy: FEL photon energy in eV. Calibration factor is linearly
interpolated between the known values from the calibration table.
mcp: MCP channel (1, 2, or 3).
Output:
f: calibration factor in microjoule per APD signal
'''
energies = np.sort([key for key in tim_calibration_table])
if photonEnergy not in energies:
if photonEnergy > energies.max():
photonEnergy = energies.max()
elif photonEnergy < energies.min():
photonEnergy = energies.min()
else:
idx = np.searchsorted(energies, photonEnergy) - 1
polyA = np.poly1d(tim_calibration_table[energies[idx]][mcp-1])
polyB = np.poly1d(tim_calibration_table[energies[idx+1]][mcp-1])
fA = -np.exp(polyA(voltage))
fB = -np.exp(polyB(voltage))
f = fA + (fB-fA)/(energies[idx+1]-energies[idx])*(photonEnergy - energies[idx])
return f
poly = np.poly1d(tim_calibration_table[photonEnergy][mcp-1])
f = -np.exp(poly(voltage))
return f
def checkTimApdWindow(data, mcp=1, use_apd=True, intstart=None, intstop=None):
''' Plot the first and last pulses in MCP trace together with
the window of integration to check if the pulse integration
is properly calculated. If the number of pulses changed during
the run, it selects a train where the number of pulses was
maximum.
Inputs:
data: xarray Dataset
mcp: MCP channel (1, 2, 3 or 4)
use_apd: if True, gets the APD parameters from the digitizer
device. If False, uses intstart and intstop as boundaries
and uses the bunch pattern to determine the separation
between two pulses.
intstart: trace index of integration start of the first pulse
intstop: trace index of integration stop of the first pulse
Output:
Plot
'''
mcpToChannel={1:'D', 2:'C', 3:'B', 4:'A'}
apdChannels={1:3, 2:2, 3:1, 4:0}
npulses_max = data['npulses_sase3'].max().values
tid = data['npulses_sase3'].where(data['npulses_sase3'] == npulses_max,
drop=True).trainId.values
if 'MCP{}raw'.format(mcp) not in data:
print('no raw data for MCP{}. Loading average trace from MCP{}'.format(mcp, mcp))
trace = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1:network',
'digitizers.channel_1_{}.raw.samples'.format(mcpToChannel[mcp])
).sel({'trainId':tid}).mean(dim='trainId')
else:
trace = data['MCP{}raw'.format(mcp)].sel({'trainId':tid}).mean(dim='trainId')
if use_apd:
pulseStart = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1',
'board1.apd.channel_{}.pulseStart.value'.format(apdChannels[mcp]))[0].values
pulseStop = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1',
'board1.apd.channel_{}.pulseStop.value'.format(apdChannels[mcp]))[0].values
initialDelay = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1',
'board1.apd.channel_{}.initialDelay.value'.format(apdChannels[mcp]))[0].values
upperLimit = data.attrs['run'].get_array(
'SCS_UTC1_ADQ/ADC/1',
'board1.apd.channel_{}.upperLimit.value'.format(apdChannels[mcp]))[0].values
else:
pulseStart = intstart
pulseStop = intstop
if npulses_max > 1:
sa3 = data['sase3'].where(data['sase3']>1)
step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
step = int(step[1] - step[0])
nsamples = 440 * step
else:
nsamples = 0
fig, ax = plt.subplots(figsize=(5,3))
ax.plot(trace[:pulseStop+25], color='C1', label='first pulse')
ax.axvspan(pulseStart, pulseStop, color='k', alpha=0.1, label='APD region')
ax.axvline(pulseStart, color='gray', ls='--')
ax.axvline(pulseStop, color='gray', ls='--')
ax.set_xlim(pulseStart-25, pulseStop+25)
ax.locator_params(axis='x', nbins=4)
ax.set_ylabel('MCP{} Voltage [V]'.format(mcp))
ax.set_xlabel('First pulse sample #')
if npulses_max > 1:
pulseStart = pulseStart + nsamples*(npulses_max-1)
pulseStop = pulseStop + nsamples*(npulses_max-1)
ax2 = ax.twiny()
ax2.plot(range(pulseStart-25,pulseStop+25), trace[pulseStart-25:pulseStop+25],
color='C4', label='last pulse')
ax2.locator_params(axis='x', nbins=4)
ax2.set_xlabel('Last pulse sample #')
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc=0)
else:
ax.legend(loc='lower left')
plt.tight_layout()
def matchXgmTimPulseId(data, use_apd=True, intstart=None, intstop=None,
bkgstart=None, bkgstop=None, t_offset=None,
npulses=None, sase3First=True, stride=1):
''' Function to match XGM pulse Id with TIM pulse Id.
Inputs:
data: xarray Dataset containing XGM and TIM data
use_apd: bool. If True, uses the digitizer APD ('MCP[1,2,3,4]apd').
If False, peak integration is performed from raw traces.
All following parameters are needed in this case.
intstart: trace index of integration start
intstop: trace index of integration stop
bkgstart: trace index of background start
bkgstop: trace index of background stop
t_offset: index separation between two pulses
npulses: number of pulses to compute. Required if no bunch
pattern info is available
sase3First: bool, needed if bunch pattern is missing.
stride: int, used to select pulses in the TIM APD array if
no bunch pattern info is available.
Output:
xr DataSet containing XGM and TIM signals with the share d
dimension 'sa3_pId'. Raw traces, raw XGM and raw APD are dropped.
'''
dropList = []
mergeList = []
ndata = cleanXGMdata(data, npulses, sase3First)
for mcp in range(1,5):
if 'MCP{}apd'.format(mcp) in data or 'MCP{}raw'.format(mcp) in data:
MCPapd = getTIMapd(data, mcp=mcp, use_apd=use_apd, intstart=intstart,
intstop=intstop,bkgstart=bkgstart, bkgstop=bkgstop,
t_offset=t_offset, npulses=npulses,
stride=stride).rename('MCP{}apd'.format(mcp))
if use_apd:
MCPapd = MCPapd.rename({'apdId':'sa3_pId'})
else:
MCPapd = MCPapd.rename({'MCP{}fromRaw'.format(mcp):'sa3_pId'})
mergeList.append(MCPapd)
if 'MCP{}raw'.format(mcp) in ndata:
dropList.append('MCP{}raw'.format(mcp))
if 'MCP{}apd'.format(mcp) in data:
dropList.append('MCP{}apd'.format(mcp))
mergeList.append(ndata.drop(dropList))
subset = xr.merge(mergeList, join='inner')
for k in ndata.attrs.keys():
subset.attrs[k] = ndata.attrs[k]
return subset
# Fast ADC
def fastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop, period=None, npulses=None):
''' Computes peak integration from raw FastADC traces.
Inputs:
data: xarray Dataset containing FastADC raw traces (e.g. 'FastADC1raw')
channel: FastADC channel number
intstart: trace index of integration start
intstop: trace index of integration stop
bkgstart: trace index of background start
bkgstop: trace index of background stop
period: number of samples between two pulses. Needed if bunch
pattern info is not available. If None, checks the pulse
pattern and determine the period assuming a resolution of
9.23 ns per sample which leads to 24 samples between
two bunches @ 4.5 MHz.
npulses: number of pulses. If None, takes the maximum number of
pulses according to the bunch patter (field 'npulses_sase3')
Output:
results: DataArray with dims trainId x max(sase3 pulses)
'''
keyraw = 'FastADC{}raw'.format(channel)
if keyraw not in data:
raise ValueError("Source not found: {}!".format(keyraw))
if npulses is None:
npulses = int(data['npulses_sase3'].max().values)
if period is None:
sa3 = data['sase3'].where(data['sase3']>1)
if npulses > 1:
#Calculate the number of pulses between two lasing pulses (step)
step = sa3.where(data['npulses_sase3']>1, drop=True)[0,:2].values
step = int(step[1] - step[0])
#multiply by elementary pulse length (221.5 ns / 9.23 ns = 24 samples)
period = 24 * step
else:
period = 1
results = xr.DataArray(np.empty((data.trainId.shape[0], npulses)), coords=data[keyraw].coords,
dims=['trainId', 'peakId'.format(channel)])
for i in range(npulses):
a = intstart + period*i
b = intstop + period*i
bkga = bkgstart + period*i
bkgb = bkgstop + period*i
bg = np.outer(np.median(data[keyraw][:,bkga:bkgb], axis=1), np.ones(b-a))
integ = np.trapz(data[keyraw][:,a:b] - bg, axis=1)
results[:,i] = integ
return results
def autoFindFastAdcPeaks(data, channel=5, threshold=35000, display=False, plot=False):
''' Automatically finds positive peaks in channel of Fast ADC trace, assuming
a minimum absolute height of 'threshold' counts and a minimum width of 4
samples. The find_peaks function and determination of the peak region and
baseline subtraction is optimized for typical photodiode signals of the
SCS instrument (ILH, FFT reflectometer, FFT diag stage).
Inputs:
data: xarray Dataset containing Fast ADC traces
key: data key of the array of traces
threshold: minimum height of the peaks
display: bool, displays info on the pulses found
plot: plots regions of integration of the first pulse in the trace
Output:
peaks: DataArray of the integrated peaks
'''
key = f'FastADC{channel}raw'
if key not in data:
raise ValueError(f'{key} not found in data set')
trace = data[key].where(data['npulses_sase3']>0, drop=True).isel(trainId=0).values
centers, peaks = find_peaks(trace, height=threshold, width=(4, None))
c = centers[0]
w = np.average(peaks['widths']).astype(int)
period = np.median(np.diff(centers)).astype(int)
npulses = centers.shape[0]
intstart = int(c - w/4) + 1
intstop = int(c + w/4) + 1
bkgstop = int(peaks['left_ips'][0])-5
bkgstart = bkgstop - 10
if display:
print(f'Found {npulses} pulses, avg. width={w}, period={period} samples, ' +
f'rep. rate={1e6/(9.230769*period):.3f} kHz')
fAdcPeaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop,
bkgstart=bkgstart, bkgstop=bkgstop, period=period, npulses=npulses)
if plot:
plt.figure()
plt.plot(trace, 'o-', ms=3)
for i in range(npulses):
plt.axvline(intstart+i*period, ls='--', color='g')
plt.axvline(intstop+i*period, ls='--', color='r')
plt.axvline(bkgstart+i*period, ls='--', color='lightgrey')
plt.axvline(bkgstop+i*period, ls='--', color='grey')
plt.title(f'Fast ADC {channel} trace')
plt.xlim(bkgstart-10, intstop + 50)
return fAdcPeaks
def mergeFastAdcPeaks(data, channel, intstart, intstop, bkgstart, bkgstop,
period=None, npulses=None, dim='lasPulseId'):
''' Calculates the peaks from Fast ADC raw traces with fastAdcPeaks()
and merges the results in Dataset.
Inputs:
data: xr Dataset with 'FastADC[channel]raw' traces
channel: Fast ADC channel
intstart: trace index of integration start
intstop: trace index of integration stop
bkgstart: trace index of background start
bkgstop: trace index of background stop
period: Number of ADC samples between two pulses. Needed
if bunch pattern info is not available. If None, checks the
pulse pattern and determine the period assuming a resolution
of 9.23 ns per sample = 24 samples between two pulses @ 4.5 MHz.
npulses: number of pulses. If None, takes the maximum number of
pulses according to the bunch patter (field 'npulses_sase3')
dim: name of the xr dataset dimension along the peaks
'''
peaks = fastAdcPeaks(data, channel=channel, intstart=intstart, intstop=intstop,
bkgstart=bkgstart, bkgstop=bkgstop, period=period,
npulses=npulses)
key = 'FastADC{}peaks'.format(channel)
if key in data:
s = data.drop(key)
else:
s = data
peaks = peaks.rename(key).rename({'peakId':dim})
subset = xr.merge([s, peaks], join='inner')
for k in data.attrs.keys():
subset.attrs[k] = data.attrs[k]
return subset