Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • SCS/ToolBox
  • kluyvert/ToolBox
2 results
Show changes
Showing
with 4762 additions and 0 deletions
from .bunch_pattern import *
from .bunch_pattern_external import *
from .laser_utils import *
from .undulator import *
__all__ = (
bunch_pattern.__all__
+ bunch_pattern_external.__all__
+ laser_utils.__all__
+ undulator.__all__
)
# -*- coding: utf-8 -*-
""" Toolbox for SCS.
Various utilities function to quickly process data
measured at the SCS instruments.
Copyright (2019) SCS Team.
"""
import os
import logging
import numpy as np
import xarray as xr
from extra_data.read_machinery import find_proposal
from extra_data import RunDirectory
# import and hide variable, such that it does not alter namespace.
from ..constants import mnemonics as _mnemonics_bp
from ..mnemonics_machinery import mnemonics_for_run
from .bunch_pattern_external import is_pulse_at
__all__ = [
'extractBunchPattern',
'get_sase_pId',
'npulses_has_changed',
'pulsePatternInfo',
'repRate',
]
log = logging.getLogger(__name__)
def npulses_has_changed(run, loc='sase3', run_mnemonics=None):
"""
Checks if the number of pulses has changed during the run for
a specific location `loc` (='sase1', 'sase3', 'scs_ppl' or 'laser')
If the source is not found in the run, returns True.
Parameters
----------
run: extra_data.DataCollection
DataCollection containing the data.
loc: str
The location where to check: {'sase1', 'sase3', 'scs_ppl'}
run_mnemonics: dict
the mnemonics for the run (see `menonics_for_run`)
Returns
-------
ret: bool
True if the number of pulses has changed or the source was not
found, False if the number of pulses did not change.
"""
sase_list = ['sase1', 'sase3', 'laser', 'scs_ppl']
if loc not in sase_list:
raise ValueError(f"Unknow sase location '{loc}'. Expected one in"
f"{sase_list}")
if run_mnemonics is None:
run_mnemonics = mnemonics_for_run(run)
if loc == 'scs_ppl':
loc = 'laser'
if loc not in run_mnemonics:
return True
if run_mnemonics[loc]['key'] not in run[run_mnemonics[loc]['source']].keys():
log.info(f'Mnemonic {loc} not found in run.')
return True
npulses = run.get_array(*run_mnemonics['npulses_'+loc].values())
if len(np.unique(npulses)) == 1:
return False
return True
def get_unique_sase_pId(run, loc='sase3', run_mnemonics=None):
"""
Assuming that the number of pulses did not change during the run,
returns the pulse Ids as the run value of the sase mnemonic.
Parameters
----------
run: extra_data.DataCollection
DataCollection containing the data.
loc: str
The location where to check: {'sase1', 'sase3', 'scs_ppl'}
run_mnemonics: dict
the mnemonics for the run (see `menonics_for_run`)
Returns
-------
pulseIds: np.array
the pulse ids at the specified location. Returns None if the
mnemonic is not in the run.
"""
if run_mnemonics is None:
run_mnemonics = mnemonics_for_run(run)
if loc == 'scs_ppl':
loc = 'laser'
if loc not in run_mnemonics:
# bunch pattern not recorded
return None
npulses = run.get_run_value(run_mnemonics['npulses_'+loc]['source'],
run_mnemonics['npulses_'+loc]['key'])
pulseIds = run.get_run_value(run_mnemonics[loc]['source'],
run_mnemonics[loc]['key'])[:npulses]
return pulseIds
def get_sase_pId(run, loc='sase3', run_mnemonics=None,
bpt=None, merge_with=None):
"""
Returns the pulse Ids of the specified `loc` during a run.
If the number of pulses has changed during the run, it loads the
bunch pattern table and extract all pulse Ids used.
Parameters
----------
run: extra_data.DataCollection
DataCollection containing the data.
loc: str
The location where to check: {'sase1', 'sase3', 'scs_ppl'}
run_mnemonics: dict
the mnemonics for the run (see `menonics_for_run`)
bpt: 2D-array
The bunch pattern table. Used only if the number of pulses
has changed. If None, it is loaded on the fly.
merge_with: xarray.Dataset
dataset that may contain the bunch pattern table to use in
case the number of pulses has changed. If merge_with does
not contain the bunch pattern table, it is loaded and added
as a variable 'bunchPatternTable' to merge_with.
Returns
-------
pulseIds: np.array
the pulse ids at the specified location. Returns None if the
mnemonic is not in the run.
"""
if npulses_has_changed(run, loc, run_mnemonics) is False:
return get_unique_sase_pId(run, loc, run_mnemonics)
if bpt is None:
bpt = load_bpt(run, merge_with, run_mnemonics)
if bpt is not None:
mask = is_pulse_at(bpt, loc)
return np.unique(np.nonzero(mask.values)[1])
return None
def load_bpt(run, merge_with=None, run_mnemonics=None):
"""
Load the bunch pattern table. It returns the one contained in
merge_with if possible. Or, it adds it to merge_with once it is
loaded.
Parameters
----------
run: extra_data.DataCollection
DataCollection containing the data.
merge_with: xarray.Dataset
dataset that may contain the bunch pattern table or to which
add the bunch pattern table once loaded.
run_mnemonics: dict
the mnemonics for the run (see `menonics_for_run`)
Returns
-------
bpt: xarray.Dataset
the bunch pattern table as specified by the mnemonics
'bunchPatternTable'
"""
if run_mnemonics is None:
run_mnemonics = mnemonics_for_run(run)
for key in ['bunchPatternTable', 'bunchPatternTable_SA3']:
if merge_with is not None and key in merge_with:
log.debug(f'Using {key} from merge_with dataset.')
return merge_with[key]
if key in run_mnemonics:
bpt = run.get_array(*run_mnemonics[key].values(),
name='bunchPatternTable')
log.debug(f'Loaded {key} from DataCollection.')
if merge_with is not None:
merge_with.update(merge_with.merge(bpt, join='inner'))
return bpt
log.debug('Could not find bunch pattern table.')
return None
def extractBunchPattern(bp_table=None, key='sase3', runDir=None):
''' generate the bunch pattern and number of pulses of a source directly from the
bunch pattern table and not using the MDL device BUNCH_DECODER. This is
inspired by the euxfel_bunch_pattern package,
https://git.xfel.eu/gitlab/karaboDevices/euxfel_bunch_pattern
Inputs:
bp_table: DataArray corresponding to the mnemonics "bunchPatternTable".
If None, the bunch pattern table is loaded using runDir.
key: str, ['sase1', 'sase2', 'sase3', 'scs_ppl']
runDir: extra-data DataCollection. Required only if bp_table is None.
Outputs:
bunchPattern: DataArray containing indices of the sase/laser pulses for
each train
npulses: DataArray containing the number of pulses for each train
matched: 2-D DataArray mask (trainId x 2700), True where 'key' has pulses
'''
keys=['sase1', 'sase2', 'sase3', 'scs_ppl']
if key not in keys:
raise ValueError(f'Invalid key "{key}", possible values are {keys}')
if bp_table is None:
if runDir is None:
raise ValueError('bp_table and runDir cannot both be None')
bp_mnemo = _mnemonics_bp['bunchPatternTable']
if bp_mnemo['source'] not in runDir.all_sources:
raise ValueError('Source {} not found in run'.format(
bp_mnemo['source']))
else:
bp_table = runDir.get_array(bp_mnemo['source'],bp_mnemo['key'],
extra_dims=bp_mnemo['dim'])
# define relevant masks, see euxfel_bunch_pattern package for details
DESTINATION_MASK = 0xf << 18
DESTINATION_T4D = 4 << 18 # SASE1/3 dump
DESTINATION_T5D = 2 << 18 # SASE2 dump
PHOTON_LINE_DEFLECTION = 1 << 27 # Soft kick (e.g. SA3)
LASER_SEED6 = 1 << 13
if 'sase' in key:
sase = int(key[4])
destination = DESTINATION_T5D if (sase == 2) else DESTINATION_T4D
matched = (bp_table & DESTINATION_MASK) == destination
if sase == 1:
# Pulses to SASE 1 when soft kick is off
matched &= (bp_table & PHOTON_LINE_DEFLECTION) == 0
elif sase == 3:
# Pulses to SASE 3 when soft kick is on
matched &= (bp_table & PHOTON_LINE_DEFLECTION) != 0
elif key=='scs_ppl':
matched = (bp_table & LASER_SEED6) != 0
# create table of indices where bunch pattern and mask match
nz = np.nonzero(matched.values)
dim_pId = matched.shape[1]
bunchPattern = np.ones(matched.shape, dtype=np.uint64)*dim_pId
bunchPattern[nz] = nz[1]
bunchPattern = np.sort(bunchPattern)
npulses = np.count_nonzero(bunchPattern<dim_pId, axis=1)
bunchPattern[bunchPattern == dim_pId] = 0
bunchPattern = xr.DataArray(bunchPattern[:,:1000], dims=['trainId', 'bunchId'],
coords={'trainId':matched.trainId},
name=key)
npulses = xr.DataArray(npulses, dims=['trainId'],
coords={'trainId':matched.trainId},
name=f'npulses_{key}')
return bunchPattern, npulses, matched
def pulsePatternInfo(data, plot=False):
''' display general information on the pulse patterns operated by SASE1 and SASE3.
This is useful to track changes of number of pulses or mode of operation of
SASE1 and SASE3. It also determines which SASE comes first in the train and
the minimum separation between the two SASE sub-trains.
Inputs:
data: xarray Dataset containing pulse pattern info from the bunch decoder MDL:
{'sase1, sase3', 'npulses_sase1', 'npulses_sase3'}
plot: bool enabling/disabling the plotting of the pulse patterns
Outputs:
print of pulse pattern info. If plot==True, plot of the pulse pattern.
'''
#Which SASE comes first?
npulses_sa3 = data['npulses_sase3']
npulses_sa1 = data['npulses_sase1']
dedicated = False
if np.all(npulses_sa1.where(npulses_sa3 !=0, drop=True) == 0):
dedicated = True
print('No SASE 1 pulses during SASE 3 operation')
if np.all(npulses_sa3.where(npulses_sa1 !=0, drop=True) == 0):
dedicated = True
print('No SASE 3 pulses during SASE 1 operation')
if dedicated==False:
pulseIdmin_sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1']>1).min().values
pulseIdmax_sa1 = data['sase1'].where(npulses_sa1 != 0).where(data['sase1']>1).max().values
pulseIdmin_sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3']>1).min().values
pulseIdmax_sa3 = data['sase3'].where(npulses_sa3 != 0).where(data['sase3']>1).max().values
#print(pulseIdmin_sa1, pulseIdmax_sa1, pulseIdmin_sa3, pulseIdmax_sa3)
if pulseIdmin_sa1 > pulseIdmax_sa3:
t = 0.220*(pulseIdmin_sa1 - pulseIdmax_sa3 + 1)
print('SASE 3 pulses come before SASE 1 pulses (minimum separation %.1f µs)'%t)
elif pulseIdmin_sa3 > pulseIdmax_sa1:
t = 0.220*(pulseIdmin_sa3 - pulseIdmax_sa1 + 1)
print('SASE 1 pulses come before SASE 3 pulses (minimum separation %.1f µs)'%t)
else:
print('Interleaved mode')
#What is the pulse pattern of each SASE?
for key in['sase3', 'sase1']:
print('\n*** %s pulse pattern: ***'%key.upper())
npulses = data['npulses_%s'%key]
sase = data[key]
if not np.all(npulses == npulses[0]):
print('Warning: number of pulses per train changed during the run!')
#take the derivative along the trainId to track changes in pulse number:
diff = npulses.diff(dim='trainId')
#only keep trainIds where a change occured:
diff = diff.where(diff !=0, drop=True)
#get a list of indices where a change occured:
idx_change = np.argwhere(np.isin(npulses.trainId.values,
diff.trainId.values, assume_unique=True))[:,0]
#add index 0 to get the initial pulse number per train:
idx_change = np.insert(idx_change, 0, 0)
print('npulses\tindex From\tindex To\ttrainId From\ttrainId To\trep. rate [kHz]')
for i,idx in enumerate(idx_change):
n = npulses[idx]
idxFrom = idx
trainIdFrom = npulses.trainId[idx]
if i < len(idx_change)-1:
idxTo = idx_change[i+1]-1
else:
idxTo = npulses.shape[0]-1
trainIdTo = npulses.trainId[idxTo]
if n <= 1:
print('%i\t%i\t\t%i\t\t%i\t%i'%(n, idxFrom, idxTo, trainIdFrom, trainIdTo))
else:
f = 1/((sase[idxFrom,1] - sase[idxFrom,0])*222e-6)
print('%i\t%i\t\t%i\t\t%i\t%i\t%.0f'%(n, idxFrom, idxTo, trainIdFrom, trainIdTo, f))
print('\n')
if plot:
plt.figure(figsize=(6,3))
plt.plot(data['npulses_sase3'].trainId, data['npulses_sase3'], 'o-',
ms=3, label='SASE 3')
plt.xlabel('trainId')
plt.ylabel('pulses per train')
plt.plot(data['npulses_sase1'].trainId, data['npulses_sase1'], '^-',
ms=3, color='C2', label='SASE 1')
plt.legend()
def repRate(data=None, runNB=None, proposalNB=None, key='sase3'):
''' Calculates the pulse repetition rate (in kHz) in sase
according to the bunch pattern and assuming a grid of
4.5 MHz.
Inputs:
-------
data: xarray Dataset containing pulse pattern, needed if runNB is none
runNB: int or str, run number. Needed if data is None
proposal: int or str, proposal where to find the run. Needed if data is None
key: str in [sase1, sase2, sase3, scs_ppl], source for which the
repetition rate is calculated
Output:
-------
f: repetition rate in kHz
'''
if runNB is None and data is None:
raise ValueError('Please provide either the runNB + proposal or the data argument.')
if runNB is not None and proposalNB is None:
raise ValueError('Proposal is missing.')
if runNB is not None:
if isinstance(runNB, int):
runNB = 'r{:04d}'.format(runNB)
if isinstance(proposalNB,int):
proposalNB = 'p{:06d}'.format(proposalNB)
runFolder = os.path.join(find_proposal(proposalNB), 'raw', runNB)
runDir = RunDirectory(runFolder)
bp_mnemo = _mnemonics_bp['bunchPatternTable']
if bp_mnemo['source'] not in runDir.all_sources:
raise ValueError('Source {} not found in run'.format(
bp_mnemo['source']))
else:
bp_table = runDir.get_array(bp_mnemo['source'],bp_mnemo['key'],
extra_dims=bp_mnemo['dim'])
a, b, mask = extractBunchPattern(bp_table, key=key)
else:
if key not in ['sase1', 'sase3']:
a, b, mask = extractBunchPattern(key=key, runDir=data.attrs['run'])
else:
a = data[key]
b = data[f'npulses_{key}']
a = a.where(b > 1, drop = True).values
if len(a)==0:
print('Not enough pulses to extract repetition rate')
return 0
f = 1/((a[0,1] - a[0,0])*12e-3/54.1666667)
return f
"""
A collection of wrappers around the the euxfel_bunch_pattern pkg
The euxfel_bunch_pattern package provides generic methods to extract
information from the bunch pattern tables. To ease its use from within
the toolbox some of its methods are wrapped. Like this they show up in
the users namespace in a self-explanatory way.
"""
import logging
import euxfel_bunch_pattern as ebp
__all__ = [
'is_sase_3',
'is_sase_1',
'is_ppl',
'is_pulse_at',
]
PPL_SCS = ebp.LASER_SEED6
log = logging.getLogger(__name__)
def _convert_data(bpt_dec):
bpt_conv = bpt_dec
if type(bpt_dec).__module__ == 'xarray.core.dataarray':
bpt_conv = bpt_dec.where(bpt_dec.values == True, other=0)
elif type(bpt_dec).__module__ == 'numpy':
bpt_conv = bpt_dec.astype(int)
else:
dtype = type(bpt_dec).__module__
log.warning(f"Could not convert data type {dtype}."
"Return raw euxfel_bp table.")
return bpt_conv
def is_pulse_at(bpt, loc):
"""
Check for prescence of a pulse at the location provided.
Parameters
----------
bpt : numpy array, xarray DataArray
The bunch pattern data.
loc : str
The location where to check: {'sase1', 'sase3', 'scs_ppl'}
Returns
-------
boolean : numpy array, xarray DataArray
true if a pulse is present at *loc*.
"""
if loc == 'sase3':
bpt_dec = ebp.is_sase(bpt, 3)
elif loc == 'sase1':
bpt_dec = ebp.is_sase(bpt, 1)
elif loc == 'scs_ppl':
bpt_dec = ebp.is_laser(bpt, laser=PPL_SCS)
else:
raise ValueError(f'loc argument is {loc}, expected "sase1", ' +
'"sase3" or "scs_ppl"')
return _convert_data(bpt_dec)
def is_sase_3(bpt):
"""
Check for prescence of a SASE3 pulse.
Parameters
----------
bpt : numpy array, xarray DataArray
The bunch pattern data.
Returns
-------
boolean : numpy array, xarray DataArray
true if SASE3 pulse is present.
"""
bpt_dec = ebp.is_sase(bpt, 3)
return _convert_data(bpt_dec)
def is_sase_1(bpt):
"""
Check for prescence of a SASE1 pulse.
Parameters
----------
bpt : numpy array, xarray DataArray
The bunch pattern data.
Returns
-------
boolean : numpy array, xarray DataArray
true if SASE1 pulse is present.
"""
bpt_dec = ebp.is_sase(bpt, 1)
return _convert_data(bpt_dec)
def is_ppl(bpt):
"""
Check for prescence of pp-laser pulse.
Parameters
----------
bpt : numpy array, xarray DataArray
The bunch pattern data.
Returns
-------
boolean : numpy array, xarray DataArray
true if pp-laser pulse is present.
"""
bpt_dec = ebp.is_laser(bpt, laser=PPL_SCS)
return _convert_data(bpt_dec)
__all__ = [
'degToRelPower',
'positionToDelay',
'delayToPosition',
'fluenceCalibration',
'align_ol_to_fel_pId'
]
import numpy as np
import matplotlib.pyplot as plt
def positionToDelay(pos, origin=0, invert=True, reflections=1):
''' converts a motor position in mm into optical delay in picosecond
Inputs:
pos: array-like delay stage motor position
origin: motor position of time zero in mm
invert: bool, inverts the sign of delay if True
reflections: number of bounces in the delay stage
Output:
delay in picosecond
'''
c_ = 299792458 * 1e-9 # speed of light in mm/ps
x = -1 if invert else 1
return 2*reflections*(pos-origin)*x/c_
def delayToPosition(delay, origin=0, invert=True, reflections=1):
''' converts an optical delay in picosecond into a motor position in mm
Inputs:
delay: array-like delay in ps
origin: motor position of time zero in mm
invert: bool, inverts the sign of delay if True
reflections: number of bounces in the delay stage
Output:
delay in picosecond
'''
c_ = 299792458 * 1e-9 # speed of light in mm/ps
x = -1 if invert else 1
return origin + 0.5 * x * delay * c_ / reflections
def degToRelPower(x, theta0=0):
''' converts a half-wave plate position in degrees into relative power
between 0 and 1.
Inputs:
x: array-like positions of half-wave plate, in degrees
theta0: position for which relative power is zero
Output:
array-like relative power
'''
return np.sin(2*(x-theta0)*np.pi/180)**2
def fluenceCalibration(hwp, power_mW, npulses, w0x, w0y=None,
train_rep_rate=10, fit_order=1,
plot=True, xlabel='HWP [%]'):
"""
Given a measurement of relative powers or half wave plate angles
and averaged powers in mW, this routine calculates the corresponding
fluence and fits a polynomial to the data.
Parameters
----------
hwp: array-like (N)
angle or relative power from the half wave plate
power_mW: array-like (N)
measured power in mW by powermeter
npulses: int
number of pulses per train during power measurement
w0x: float
radius at 1/e^2 in x-axis in meter
w0y: float, optional
radius at 1/e^2 in y-axis in meter. If None, w0y=w0x is assumed.
train_rep_rate: float
repetition rate of the FEL, by default equals to 10 Hz.
fit_order: int
order of the polynomial fit
plot: bool
Plot the results if True
xlabel: str
xlabel for the plot
Output
------
F: ndarray (N)
fluence in mJ/cm^2
fit_F: ndarray
coefficients of the fluence polynomial fit
E: ndarray (N)
pulse energy in microJ
fit_E: ndarray
coefficients of the fluence polynomial fit
"""
power = np.array(power_mW)
hwp = np.array(hwp)
E = power/(train_rep_rate*npulses)*1e-3 # pulse energy in J
if w0y is None:
w0y = w0x
F = 2*E/(np.pi*w0x*w0y) # fluence in J/m^2
fit_E = np.polyfit(hwp, E*1e6, fit_order)
fit_F = np.polyfit(hwp, F*1e-1, fit_order)
x = np.linspace(hwp.min(), hwp.max(), 100)
if plot:
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_title(f'w0x = {w0x*1e6:.0f} $\mu$m, w0y = {w0y*1e6:.0f} $\mu$m')
ax.plot(hwp, F*1e-1, 'o', label='data')
fit_label = 'F = '
for i in range(len(fit_F)-1, 1, -1):
fit_label += f'{fit_F[i]:.2g}x$^{i}$ + '
if i % 2 == 0:
fit_label += '\n'
fit_label += f'{fit_F[-2]:.2g}x + {fit_F[-1]:.2g}'
ax.plot(x, np.poly1d(fit_F)(x), label=fit_label)
ax.set_ylabel('Fluence [mJ/cm$^2$]')
ax.set_xlabel(xlabel)
ax.legend()
ax.grid()
def eTf(x):
return 1e-7*2*x/(np.pi*w0x*w0y)
def fTe(x):
return 1e7*x*np.pi*w0x*w0y/2
ax2 = ax.secondary_yaxis('right', functions=(fTe, eTf))
ax2.set_ylabel(r'Pulse energy [$\mu$J]')
return F*1e-1, fit_F, E*1e6, fit_E
def align_ol_to_fel_pId(ds, ol_dim='ol_pId', fel_dim='sa3_pId',
offset=0, fill_value=np.nan):
'''
Aligns the optical laser (OL) pulse Ids to the FEL pulse Ids.
The new OL coordinates are calculated as ds[ol_dim] +
ds[fel_dim][0] + offset. The ol_dim is then removed, and if the number
of OL and FEL pulses are different, the missing values are replaced by
fill_value (NaN by default).
Parameters
----------
ds: xarray.Dataset
Dataset containing both OL and FEL dimensions
ol_dim: str
name of the OL dimension
fel_dim: str
name of the FEL dimension
offset: int
offset added to the OL pulse Ids.
fill_value: (scalar or dict-like, optional)
Value to use for newly missing values. If a dict-like, maps variable
names to fill values. Use a data array’s name to refer to its values.
Output
------
ds: xarray.Dataset
The newly aligned dataset
'''
fel_vars = [v for v in ds if fel_dim in ds[v].dims]
ol_vars = [v for v in ds if ol_dim in ds[v].dims] + [ol_dim]
if len(set.intersection(set(fel_vars), set(ol_vars))) > 0:
raise ValueError('Variables share ol and fel dimensions: no alignment'
' possible.')
ds_fel = ds.drop(ol_vars)
ds_ol = ds[ol_vars]
ds_ol = ds_ol.assign_coords({ol_dim: ds[ol_dim] +
ds[fel_dim][0].values + offset})
ds_ol = ds_ol.rename({ol_dim: fel_dim})
ds = ds_fel.merge(ds_ol, join='outer', fill_value=fill_value)
return ds
__all__ = [
'get_undulator_config',
]
import numpy as np
import xarray as xr
from toolbox_scs.load import load_run_values
import matplotlib.pyplot as plt
def get_undulator_cells(run, park_pos=62.0):
rvalues = run.get_run_values('SA3_XTD4_UND/DOOCS/UNDULATOR_CELLS')
rvalues = {k: rvalues[k] for k in rvalues.keys()
if '.gapApplied.value' in k or '.kApplied.value' in k}
cells = list(range(2, 13)) + list(range(14, 24))
keys = np.unique([k.replace('.gapApplied.value',
'').replace('.kApplied.value',
'') for k in rvalues])
assert len(keys) == len(cells)
result = [[], [], []]
names = ['gap', 'K', 'cell_name']
for i, k in enumerate(keys):
result[0].append(rvalues[k + '.gapApplied.value'])
result[1].append(rvalues[k + '.kApplied.value'])
result[2].append(k)
result = xr.merge([xr.DataArray(result[i], dims='cell',
coords={'cell': cells},
name=names[i]) for i in range(3)])
result['closed'] = result['gap'] < park_pos
return result
def plot_undulator_config(ds, park_pos):
fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(ds.cell, ds.gap - park_pos-1, bottom=park_pos+1, alpha=0.5)
for c in ds.cell:
ax.text(c-.25, park_pos/2, f"K={ds.sel(cell=c).K.values:.4f}",
rotation='vertical')
ax.set_ylim(0, park_pos+1)
ax.invert_yaxis()
ax.set_xlabel('CELL #')
ax.set_ylabel('gap size')
def get_undulator_config(run, park_pos=62.0, plot=True):
'''
Extract the undulator cells configuration from a given run.
The gap size and K factor as well as the magnetic chicane delay and photon
energy of colors 1, 2 and 3 are compiled into an xarray Dataset.
Note:
This function looks at run control values, it does not reflect any change
of values during the run. Do not use to extract configuration when scanning
the undulator.
Parameters
----------
run: EXtra-Data DataCollection
The run containing the undulator information
park_pos: float, optional
The parked position of a cell (i.e. when fully opened)
plot: bool, optional
If True, plot the undulator cells configuration
Returns
-------
cells: xarray Dataset
The resulting dataset of the undulator configuration
'''
ds = get_undulator_cells(run, park_pos)
rvalues = load_run_values(run)
attrs = {}
if 'UND' in rvalues:
attrs = {f'color_{i+1}_keV': rvalues[k] for
i, k in enumerate(['UND', 'UND2', 'UND3'])}
if 'MAG_CHICANE_DELAY' in rvalues:
attrs['MAG_CHICANE_DELAY'] = rvalues['MAG_CHICANE_DELAY']
for k in attrs:
ds.attrs[k] = attrs[k]
if plot:
plot_undulator_config(ds, park_pos)
return ds
""" Handling ToolBox mnemonics sub-routines
Copyright (2021) SCS Team.
(contributions preferrably comply with pep8 code structure
guidelines.)
"""
import logging
from .constants import mnemonics as _mnemonics
from extra_data import open_run
from copy import deepcopy
__all__ = [
'mnemonics_for_run'
]
log = logging.getLogger(__name__)
def mnemonics_for_run(prop_or_run, runNB=None):
"""
Returns the availble ToolBox mnemonics for a give extra_data
DataCollection, or a given proposal + run number.
Parameters
----------
prop_or_run: extra_data DataCollection or int
The run (DataCollection) to check for mnemonics.
Alternatively, the proposal number (int), for which the runNB
is also required.
runNB: int
The run number. Only used if the first argument is the proposal
number.
Returns
-------
mnemonics: dict
The dictionnary of mnemonics that are available in the run.
Example
-------
>>> import toolbox_scs as tb
>>> tb.mnemonics_for_run(2212, 213)
"""
run = prop_or_run
if runNB is not None:
run = open_run(prop_or_run, runNB)
result = {}
for m in _mnemonics:
version = mnemo_version_index(run, m)
if version != -1:
result[m] = _mnemonics[m][version]
return result
def mnemo_version_index(run, mnemonic):
"""
Given a mnemonic and a DataCollection, checks the existence of the
mnemonic versions within the tuple value of the mnemonic and returns the
valid version index.
Parameters
----------
run: extra_data DataCollection
The run that contains the data source corresponding to the mnemonic
menmonic: str
A ToolBox mnemonic
Returns
-------
index: int
The index of the tuple. If no valid version is found, returns -1.
"""
if len(_mnemonics[mnemonic]) == 1:
if _mnemonics[mnemonic][0]['source'] in run.all_sources:
return 0
return -1
for i, v in enumerate(_mnemonics[mnemonic]):
if (v['source'] in run.all_sources and
v['key'] in run.keys_for_source(v['source'])):
log.debug(f'Found version {i} for mnemonic "{mnemonic}": {v}')
return i
return -1
def mnemonics_to_process(mnemo_list, merge_with, detector, func=None):
"""
Finds the list of mnemonics, within mnemo_list and merge_with, that
correspond to arrays that are not yet loaded and/or processed by a
detector function. Removes the mnemonics of the already processed
arrays from the list.
Parameters
----------
mnemo_list: str or list of str
ToolBox mnemonics of pulse-resolved detector arrays
merge_with: xarray Dataset
Dataset that may contain non-processed arrays
detector: str
One in {'ADQ412', 'FastADC', 'FastADC2', 'XGM', 'BAM', 'PES'}
func: function
function that takes one argument, an unprocessed mnemonic string,
and converts it into a processed one, i.e. from 'MCP2apd' to
'MCP2peaks'. If None, the function returns the input mnemonic.
Returns
-------
mnemonics: list of str
the mnemonics to process
"""
if func is None:
def func(x):
return x
det_list = ['ADQ412', 'FastADC', 'FastADC2', 'XGM', 'BAM', 'PES']
if detector not in det_list:
raise ValueError(f"Detector not supported. Expecting one in {det_list}")
if detector == 'BAM':
det_mnemos = [m for m in _mnemonics if 'BAM' in m]
default_mnemo = 'BAM1932M'
default_processed = 'BAM1932M'
if detector == 'XGM':
det_mnemos = ['XTD10_XGM', 'XTD10_XGM_sigma', 'XTD10_SA3',
'XTD10_SA3_sigma', 'XTD10_SA1', 'XTD10_SA1_sigma',
'SCS_XGM', 'SCS_XGM_sigma', 'SCS_SA1', 'SCS_SA1_sigma',
'SCS_SA3', 'SCS_SA3_sigma']
default_mnemo = 'SCS_SA3'
default_processed = 'SCS_SA3'
if detector == 'ADQ412':
det_mnemos = [m for m in _mnemonics if 'MCP' in m and
'XTD10_' not in m]
default_mnemo = 'MCP2apd'
default_processed = 'MCP2peaks'
if detector == 'FastADC':
det_mnemos = [m for m in _mnemonics if 'FastADC' in m\
and 'FastADC2_' not in m]
default_mnemo = 'FastADC5raw'
default_processed = 'FastADC5peaks'
if detector == 'FastADC2':
det_mnemos = [m for m in _mnemonics if 'FastADC2_' in m]
default_mnemo = 'FastADC2_5raw'
default_processed = 'FastADC2_5peaks'
if detector == 'PES':
det_mnemos = [m for m in _mnemonics if 'PES' in m and 'raw' in m]
default_mnemo = 'PES_W_raw'
default_processed = 'PES_W_tof'
dig_dims = list(set([_mnemonics[m][0]['dim'][0] for m in det_mnemos]))
processed_mnemos = list(set([func(m) for m in det_mnemos]))
# create a list of mnemonics to process from the provided mnemonics and
# merge_with Dataset
mw_mnemos = []
mw_processed = []
if bool(merge_with):
mw_mnemos = [m for m in merge_with if m in det_mnemos and
any(dim in merge_with[m].dims for dim in dig_dims)]
mw_processed = [m for m in merge_with if m in processed_mnemos and
any(dim in merge_with[m].dims for dim in dig_dims)
is False]
if mnemo_list is None:
mnemonics = []
if len(mw_mnemos) == 0 and default_processed not in mw_processed:
mnemonics = [default_mnemo]
else:
mnemonics = [mnemo_list] if isinstance(mnemo_list, str) else mnemo_list
mnemonics = list(set(mnemonics + mw_mnemos))
for m in mnemonics[:]:
if func(m) in mw_processed:
mnemonics.remove(m)
return mnemonics
""" Toolbox for SCS.
Various utilities function to quickly process data measured
at the SCS instrument.
Copyright (2019-) SCS Team.
"""
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import re
from toolbox_scs.misc.laser_utils import positionToDelay as pTd
from toolbox_scs.routines.XAS import xas
__all__ = [
'reflectivity'
]
def prepare_reflectivity_ds(ds, Iokey, Irkey, alternateTrains,
pumpOnEven, pumpedOnly):
"""
Sorts the dataset according to the bunch pattern:
Alternating pumped/unpumped pulses, alternating pumped/unpumped
trains, or pumped only.
"""
assert ds[Iokey].dims == ds[Irkey].dims, \
f"{Iokey} and {Irkey} do not share dimensions."
if alternateTrains:
p = 0 if pumpOnEven else 1
pumped_tid = ds['trainId'].where(ds.trainId % 2 == p, drop=True)
unpumped_tid = ds['trainId'].where(ds.trainId % 2 == int(not p),
drop=True)
max_size = min(pumped_tid.size, unpumped_tid.size)
pumped = ds.sel(trainId=pumped_tid[:max_size])
unpumped = ds.sel(trainId=unpumped_tid[:max_size]
).assign_coords(trainId=pumped.trainId)
for v in [Iokey, Irkey]:
pumped[v+'_unpumped'] = unpumped[v].rename(v+'_unpumped')
ds = pumped
elif pumpedOnly is False:
# check that number of pulses is even with pumped / unpumped pattern
dim_name = [dim for dim in ds[Iokey].dims if dim != 'trainId'][0]
if ds[dim_name].size % 2 == 1:
ds = ds.isel({dim_name: slice(0, -1)})
print('The dataset contains an odd number of pulses '
'per train. Ignoring the last pulse.')
pumped = ds.isel({dim_name: slice(0, None, 2)})
unpumped = ds.isel({dim_name: slice(1, None, 2)}).assign_coords(
{dim_name: pumped[dim_name]})
for v in [Iokey, Irkey]:
pumped[v+'_unpumped'] = unpumped[v].rename(v+'_unpumped')
ds = pumped
return ds
def reflectivity(data, Iokey='FastADC5peaks', Irkey='FastADC3peaks',
delaykey='PP800_DelayLine', binWidth=0.05,
positionToDelay=True, origin=None, invert=False,
pumpedOnly=False, alternateTrains=False, pumpOnEven=True,
Ioweights=False, plot=True, plotErrors=True, units='mm'
):
"""
Computes the reflectivity R = 100*(Ir/Io[pumped] / Ir/Io[unpumped] - 1)
as a function of delay. Delay can be a motor position in mm or an
optical delay in ps, with possibility to convert from position to delay.
The default scheme is alternating pulses pumped/unpumped/... in each
train, also possible are alternating trains and pumped only.
If fitting is enabled, attempts a double exponential (default) or step
function fit.
Parameters
----------
data: xarray Dataset
Dataset containing the Io, Ir and delay data
Iokey: str
Name of the Io variable
Irkey: str
Name of the Ir variable
delaykey: str
Name of the delay variable (motor position in mm or
optical delay in ps)
binWidth: float
width of bin in units of delay variable
positionToDelay: bool
If True, adds a time axis converted from position axis according
to origin and invert parameters. Ignored if origin is None.
origin: float
Position of time overlap, shown as a vertical line.
Used if positionToDelay is True to convert position to time axis.
invert: bool
Used if positionToDelay is True to convert position to time axis.
pumpedOnly: bool
Assumes that all trains and pulses are pumped. In this case,
Delta R is defined as Ir/Io.
alternateTrains: bool
If True, assumes that trains alternate between pumped and
unpumped data.
pumpOnEven: bool
Only used if alternateTrains=True. If True, even trains are pumped,
if False, odd trains are pumped.
Ioweights: bool
If True, computes the ratio of the means instead of the mean of
the ratios Irkey/Iokey. Useful when dealing with large intensity
variations.
plot: bool
If True, plots the results.
plotErrors: bool
If True, plots the 95% confidence interval.
Output
------
xarray Dataset containing the binned Delta R, standard deviation,
standard error, counts and delays, and the fitting results if full
is True.
"""
# select relevant variables from dataset
variables = [Iokey, Irkey, delaykey]
ds = data[variables]
# prepare dataset according to pulse pattern
ds = prepare_reflectivity_ds(ds, Iokey, Irkey, alternateTrains,
pumpOnEven, pumpedOnly)
if (len(ds[delaykey].dims) > 1) and (ds[delaykey].dims !=
ds[Iokey].dims):
raise ValueError("Dimensions mismatch: delay variable has dims "
f"{ds[delaykey].dims} but (It, Io) variables have "
f"dims {ds[Iokey].dims}.")
bin_delays = binWidth * np.round(ds[delaykey] / binWidth)
ds[delaykey+'_binned'] = bin_delays
counts = xr.ones_like(ds[Iokey]).groupby(bin_delays).sum(...)
if Ioweights is False:
ds['deltaR'] = ds[Irkey]/ds[Iokey]
if pumpedOnly is False:
ds['deltaR'] = 100*(ds['deltaR'] /
(ds[Irkey+'_unpumped']/ds[Iokey+'_unpumped']) - 1)
groupBy = ds.groupby(bin_delays)
binned = groupBy.mean(...)
std = groupBy.std(...)
binned['deltaR_std'] = std['deltaR']
binned['deltaR_stderr'] = std['deltaR'] / np.sqrt(counts)
binned['counts'] = counts.astype(int)
else:
xas_pumped = xas(ds, Iokey=Iokey, Itkey=Irkey, nrjkey=delaykey,
fluorescence=True, bins=binWidth)
if pumpedOnly:
deltaR = xas_pumped['muA']
stddev = xas_pumped['sigmaA']
else:
xas_unpumped = xas(ds, Iokey=Iokey+'_unpumped',
Itkey=Irkey+'_unpumped', nrjkey=delaykey,
fluorescence=True, bins=binWidth)
deltaR = 100*(xas_pumped['muA'] / xas_unpumped['muA'])
stddev = np.abs(deltaR) * np.sqrt(
(xas_pumped['sigmaA']/xas_pumped['muA'])**2 +
(xas_unpumped['sigmaA']/xas_unpumped['muA'])**2)
deltaR -= 100
deltaR = xr.DataArray(deltaR, dims=delaykey, name='deltaR',
coords={delaykey: xas_pumped['nrj']})
stddev = xr.DataArray(stddev, dims=delaykey, name='deltaR_std',
coords={delaykey: xas_pumped['nrj']})
stderr = xr.DataArray(stddev / np.sqrt(xas_pumped['counts']),
dims=delaykey, name='deltaR_stderr',
coords={delaykey: xas_pumped['nrj']})
counts = xr.DataArray(xas_pumped['counts'], dims=delaykey,
name='counts',
coords={delaykey: xas_pumped['nrj']})
binned = xr.merge([deltaR, stddev, stderr, counts])
# copy attributes
for key, val in data.attrs.items():
binned.attrs[key] = val
binned = binned.rename({delaykey: 'delay'})
if plot:
plot_reflectivity(binned, delaykey, positionToDelay,
origin, invert, plotErrors, units)
return binned
def plot_reflectivity(data, delaykey, positionToDelay, origin,
invert, plotErrors, units):
fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
ax.plot(data['delay'], data['deltaR'], 'o-', color='C0')
xlabel = delaykey + f' [{units}]'
if plotErrors:
ax.fill_between(data['delay'],
data['deltaR'] - 1.96*data['deltaR_stderr'],
data['deltaR'] + 1.96*data['deltaR_stderr'],
color='C0', alpha=0.2)
ax2 = ax.twinx()
ax2.bar(data['delay'], data['counts'],
width=0.80*(data['delay'][1]-data['delay'][0]),
color='C1', alpha=0.2)
ax2.set_ylabel('counts', color='C1', fontsize=13)
ax2.set_ylim(0, data['counts'].max()*3)
if origin is not None:
ax.axvline(origin, color='grey', ls='--')
if positionToDelay:
ax3 = ax.twiny()
xmin, xmax = ax.get_xlim()
ax3.set_xlim(pTd(xmin, origin, invert),
pTd(xmax, origin, invert),)
ax3.set_xlabel('delay [ps]', fontsize=13)
try:
proposalNB = int(re.findall(r'p(\d{6})',
data.attrs['proposal'])[0])
runNB = data.attrs['runNB']
ax.set_title(f'run {runNB} p{proposalNB}', fontsize=14)
except Exception:
if 'plot_title' in data.attrs:
ax.set_title(data.attrs['plot_title'])
ax.set_xlabel(xlabel, fontsize=13)
ax.set_ylabel(r'$\Delta R$ [%]', color='C0', fontsize=13)
ax.grid()
return fig, ax
# -*- coding: utf-8 -*-
""" Toolbox for XAS experiments.
Based on the LCLS LO59 experiment libraries.
Time-resolved XAS and XMCD with uncertainties.
Copyright (2019-) SCS Team
Copyright (2017-2019) Loïc Le Guyader <loic.le.guyader@xfel.eu>
"""
from toolbox_scs.base.knife_edge import arrays_to1d
import numpy as np
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import re
__all__ = [
'xas',
'xasxmcd',
]
def absorption(T, Io, fluorescence=False):
""" Compute the absorption A = -ln(T/Io) (or A = T/Io
for fluorescence)
Inputs:
T: 1-D transmission value array of length N
Io: 1-D Io monitor value array of length N
fluorescence: boolean, if False, compute A as
negative log, if True, compute A as ratio
Output:
a structured array with:
muA: absorption mean
sigmaA: absorption standard deviation
weights: sum of Io values
muT: transmission mean
sigmaT: transmission standard deviation
muIo: Io mean
sigmaIo: Io standard deviation
p: correlation coefficient between T and Io
counts: length of T
"""
T = np.array(T)
Io = np.array(Io)
counts = len(T)
assert counts == len(Io), "T and Io must have the same length"
# remove not number from the data
good = np.logical_and(np.isfinite(T), np.isfinite(Io))
T = T[good]
Io = Io[good]
# return type of the structured array
fdtype = [('muA', 'f8'), ('sigmaA', 'f8'), ('weights', 'f8'),
('muT', 'f8'), ('sigmaT', 'f8'), ('muIo', 'f8'),
('sigmaIo', 'f8'), ('p', 'f8'), ('counts', 'i8')]
if counts == 0:
return np.array([(np.NaN, np.NaN, np.NaN, np.NaN, np.NaN, np.NaN,
np.NaN, np.NaN, 0)], dtype=fdtype)
muT = np.mean(T)
sigmaT = np.std(T)
muIo = np.mean(Io)
sigmaIo = np.std(Io)
weights = np.sum(Io)
p = np.corrcoef(T, Io)[0, 1]
# weighted average of T/Io with Io as weights
muA = muT / muIo
# Derivation of standard deviation
# 1. using biased weighted sample variance:
# sigmaA = np.sqrt(np.average((T/Io - muA)**2, weights=Io))
# 2. using unbiased weighted sample variance (reliablility weights):
V2 = np.sum(Io**2)
sigmaA = np.sqrt(np.sum(Io*(T/Io - muA)**2) / (weights - V2/weights))
# 3. using error propagation for correlated data:
# sigmaA = np.abs(muA)*(np.sqrt((sigmaT/muT)**2 +
# (sigmaIo/muIo)**2 - 2*p*sigmaIo*sigmaT/(muIo*muT)))
if not fluorescence:
sigmaA = sigmaA / np.abs(muA)
muA = -np.log(muA)
return np.array([(muA, sigmaA, weights, muT, sigmaT, muIo, sigmaIo,
p, counts)], dtype=fdtype)
def binning(x, data, func, bins=100, bin_length=None):
""" General purpose 1-dimension data binning
Inputs:
x: input vector of len N
data: structured array of len N
func: a function handle that takes data from a bin an return
a structured array of statistics computed in that bin
bins: array of bin-edges or number of desired bins
bin_length: if not None, the bin width covering the whole range
Outputs:
bins: the bins edges
res: a structured array of binned data
"""
if bin_length is not None:
bin_start = np.amin(x)
bin_end = np.amax(x)
bins = np.arange(bin_start, bin_end+bin_length, bin_length)
elif np.size(bins) == 1:
bin_start = np.amin(x)
bin_end = np.amax(x)
bins = np.linspace(bin_start, bin_end, bins)
bin_centers = (bins[1:]+bins[:-1])/2
nb_bins = len(bin_centers)
bin_idx = np.digitize(x, bins)
dummy = func([])
res = np.empty((nb_bins), dtype=dummy.dtype)
for k in range(nb_bins):
res[k] = func(data[k+1 == bin_idx])
return bins, res
def xas(nrun, bins=None, Iokey='SCS_SA3', Itkey='MCP3peaks', nrjkey='nrj',
Iooffset=0, plot=False, fluorescence=False):
""" Compute the XAS spectra from a xarray nrun.
Inputs:
nrun: xarray of SCS data
bins: an array of bin-edges or an integer number of
desired bins or a float for the desired bin width.
Iokey: string for the Io fields, typically 'SCS_XGM'
Itkey: string for the It fields, typically 'MCP3apd'
nrjkey: string for the nrj fields, typically 'nrj'
Iooffset: offset to apply on Io
plot: boolean, displays a XAS spectrum if True
fluorescence: boolean, if True, absorption is the ratio,
if False, absorption is negative log
Outputs:
a dictionnary containing:
nrj: the bin centers
muA: the absorption
sigmaA: standard deviation on the absorption
sterrA: standard error on the absorption
muIo: the mean of the Io
counts: the number of events in each bin
"""
Io = nrun[Iokey].values.flatten() + Iooffset
nrj, It = arrays_to1d(nrun[nrjkey].values, nrun[Itkey].values)
names_list = ['nrj', 'Io', 'It']
rundata = np.vstack((nrj, Io, It))
rundata = np.rec.fromarrays(rundata, names=names_list)
def whichIo(data):
""" Select which fields to use as I0 and which to use as I1
"""
if len(data) == 0:
return absorption([], [], fluorescence)
else:
Io_sign = np.sign(np.nanmean(data['Io']))
It_sign = np.sign(np.nanmean(data['It']))
return absorption(It_sign*data['It'], Io_sign*data['Io'],
fluorescence)
if bins is None:
num_bins = 80
energy_limits = [np.nanmin(nrj), np.nanmax(nrj)]
bins = np.linspace(energy_limits[0], energy_limits[1], num_bins+1)
elif type(bins) == int:
energy_limits = [np.nanmin(nrj), np.nanmax(nrj)]
bins = np.linspace(energy_limits[0], energy_limits[1], bins+1)
elif type(bins) == float:
energy_limits = [np.nanmin(nrj), np.nanmax(nrj)]
bins = np.arange(energy_limits[0], energy_limits[1], bins)
dummy, nosample = binning(rundata['nrj'], rundata, whichIo, bins)
muA = nosample['muA']
sterrA = nosample['sigmaA'] / np.sqrt(nosample['counts'])
bins_c = 0.5*(bins[1:] + bins[:-1])
if plot:
f = plt.figure(figsize=(6.5, 6))
gs = gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax1 = plt.subplot(gs[0])
ax1.plot(bins_c, muA, color='C1', label=r'$\sigma$')
if fluorescence:
ax1.set_ylabel('XAS (fluorescence)')
else:
ax1.set_ylabel('XAS (-log)')
ax1.set_xlabel('Energy (eV)')
ax1.legend()
ax1_twin = ax1.twinx()
ax1_twin.bar(bins_c, nosample['muIo'],
width=0.80*(bins_c[1]-bins_c[0]), color='C1', alpha=0.2)
ax1_twin.set_ylabel('Io')
try:
proposalNB = int(re.findall(r'p(\d{6})',
nrun.attrs['proposal'])[0])
runNB = nrun.attrs['runNB']
ax1.set_title(f'run {runNB} p{proposalNB}')
except:
if 'plot_title' in nrun.attrs:
f.suptitle(nrun.attrs['plot_title'])
ax2 = plt.subplot(gs[1])
ax2.bar(bins_c, nosample['counts'], width=0.80*(bins_c[1]-bins_c[0]),
color='C0', alpha=0.2)
ax2.set_xlabel('Energy (eV)')
ax2.set_ylabel('counts')
return {'nrj': bins_c, 'muA': muA, 'sterrA': sterrA,
'sigmaA': nosample['sigmaA'], 'muIo': nosample['muIo'],
'counts': nosample['counts']}
def xasxmcd(dataP, dataN):
""" Compute XAS and XMCD from data with both magnetic field direction
Inputs:
dataP: structured array for positive field
dataN: structured array for negative field
Outputs:
xas: structured array for the sum
xmcd: structured array for the difference
"""
assert len(dataP) == len(dataN), "binned datasets must be of same lengths"
assert not np.any(dataP['nrj'] - dataN['nrj']), "Energy points for " \
"dataP and dataN should be the same"
muXAS = dataP['muA'] + dataN['muA']
muXMCD = dataP['muA'] - dataN['muA']
# standard error is the same for XAS and XMCD
sigma = np.sqrt(dataP['sterrA']**2 + dataN['sterrA']**2)
res = np.empty(len(muXAS), dtype=[('nrj', 'f8'), ('muXAS', 'f8'),
('sigmaXAS', 'f8'), ('muXMCD', 'f8'),
('sigmaXMCD', 'f8')])
res['nrj'] = dataP['nrj']
res['muXAS'] = muXAS
res['muXMCD'] = muXMCD
res['sigmaXAS'] = sigma
res['sigmaXMCD'] = sigma
return res
from .XAS import *
from .boz import *
from .Reflectivity import *
# Module name is the same as a child function, we use alias to avoid conflict
import toolbox_scs.routines.knife_edge as knife_edge_module
from .knife_edge import *
__all__ = (
knife_edge_module.__all__
+ XAS.__all__
+ boz.__all__
+ Reflectivity.__all__
)
"""
Beam splitting Off-axis Zone plate analysis routines.
Copyright (2021, 2022, 2023, 2024) SCS Team.
"""
import time
import datetime
import json
import warnings
import numpy as np
import xarray as xr
import dask.array as da
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm
from matplotlib.patches import Polygon
from extra_data import open_run
from extra_geom import DSSC_1MGeometry
from toolbox_scs.routines.XAS import xas
try:
import cupy as cp
_can_use_gpu = True
except ModuleNotFoundError:
_can_use_gpu = False
print('Cupy is not installed in this environment, no access to the GPU')
except ImportError:
_can_use_gpu = False
print('Not currently running on a GPU node')
__all__ = [
'parameters',
'get_roi_pixel_pos',
'bad_pixel_map',
'inspect_dark',
'histogram_module',
'inspect_histogram',
'find_rois',
'find_rois_from_params',
'inspect_rois',
'compute_flat_field_correction',
'inspect_flat_field_domain',
'inspect_plane_fitting',
'plane_fitting_domain',
'plane_fitting',
'ff_refine_crit',
'ff_refine_fit',
'nl_domain',
'nl_lut',
'nl_crit',
'nl_crit_sk',
'nl_fit',
'inspect_nl_fit',
'snr',
'inspect_Fnl',
'inspect_correction',
'inspect_correction_sk',
'load_dssc_module',
'average_module',
'process_module',
'process',
'inspect_saturation'
]
class parameters():
"""Parameters contains all input parameters for the BOZ corrections.
This is used in beam splitting off-axis zone plate spectrocopy analysis as
well as the during the determination of correction parameters themselves to
ensure they can be reproduced.
Inputs
------
proposal: int, proposal number
darkrun: int, run number for the dark run
run: int, run number for the data run
module: int, DSSC module number
gain: float, number of ph per bin
drop_intra_darks: drop every second DSSC frame
"""
def __init__(self, proposal, darkrun, run, module, gain,
drop_intra_darks=True):
self.proposal = proposal
self.darkrun = darkrun
self.run = run
self.module = module
self.pixel_pos = _get_pixel_pos(self.module)
self.gain = gain
self.drop_intra_darks = drop_intra_darks
self.mask = None
self.mask_idx = None
self.mean_th = (None, None)
self.std_th = (None, None)
self.rois = None
self.rois_th = None
self.ff_type = 'plane'
self.flat_field = None
self.flat_field_prod_th = (5.0, np.PINF)
self.flat_field_ratio_th = (np.NINF, 1.2)
self.plane_guess_fit = None
self.use_hex = False
self.force_mirror = True
self.ff_alpha = None
self.ff_max_iter = None
self._using_gpu = False
self.Fnl = None
self.nl_alpha = None
self.sat_level = None
self.nl_max_iter = None
# temporary data
self.arr_dark = None
self.tid_dark = None
self.arr = None
self.tid = None
def dask_load_persistently(self, dark_data_size_Gb=None,
data_size_Gb=None):
"""Load dask data array in memory.
Inputs
------
dark_data_size_Gb: float, optional size of dark to load in memory,
in Gb
data_size_Gb: float, optional size of data to load in memory, in Gb
"""
self.arr_dark, self.tid_dark = load_dssc_module(self.proposal,
self.darkrun, self.module, drop_intra_darks=self.drop_intra_darks,
persist=True, data_size_Gb=dark_data_size_Gb)
self.arr, self.tid = load_dssc_module(self.proposal, self.run,
self.module, drop_intra_darks=self.drop_intra_darks,
persist=True, data_size_Gb=data_size_Gb)
# make sure to rechunk the arrays
self.arr = self.arr.rechunk(('auto', -1, -1, -1))
self.arr_dark = self.arr_dark.rechunk(('auto', -1, -1, -1))
def use_gpu(self):
assert _can_use_gpu, 'Failed to import cupy'
gpu_mem_gb = cp.cuda.Device().mem_info[1] / 2**30
if gpu_mem_gb < 30:
print(f'Warning: GPU memory ({gpu_mem_gb}GB) may be insufficient')
if self._using_gpu:
return
assert (
self.arr is not None and
self.arr_dark is not None
), "Must load data before switching to GPU"
if self.mask is not None:
self.mask = cp.array(self.mask)
# moving full data to GPU
limit = 2**30
self.arr = da.array(
cp.array(self.arr.compute())
).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
self.arr_dark = da.array(
cp.array(self.arr_dark.compute())
).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
self._using_gpu = True
def set_mask(self, arr):
"""Set mask of bad pixels.
Inputs
------
arr: either a boolean array of a DSSC module image or a list of bad
pixel indices
"""
if type(arr) is not list:
self.mask_idx = np.argwhere(arr == False).tolist()
self.mask = arr
else:
self.mask_idx = arr
mask = np.ones((128, 512), dtype=bool)
for k in self.mask_idx:
mask[k[0], k[1]] = False
self.mask = mask
if self._using_gpu:
self.mask = cp.array(self.mask)
def get_mask(self):
"""Get the boolean array bad pixel of a DSSC module."""
return self.mask
def get_mask_idx(self):
"""Get the list of bad pixel indices."""
return self.mask_idx
def flat_field_guess(self, guess=None):
"""Set the flat-field guess parameter for the fit and returns it.
Inputs
------
guess: a list of 8 floats, the 4 first to define the plane
ax+by+cz+d=0 for 'n' beam and the 4 last for the 'p' beam
in case mirror symmetry is disbaled
"""
if guess is not None:
self.plane_guess_fit = guess
return self.plane_guess_fit
if self.plane_guess_fit is None:
if self.use_hex:
self.plane_guess_fit = [
-20, 0.0, 1.5, -0.5, 20, 0, 1.5, -0.5 ]
else:
self.plane_guess_fit = [
-0.2, -0.1, 1, -0.54, 0.2, -0.1, 1, -0.54]
return self.plane_guess_fit
def set_flat_field(self, ff_params, ff_type='plane',
prod_th=None, ratio_th=None):
"""Set the flat-field plane definition.
Inputs
------
ff_params: list of parameters
ff_type: string identifying the type of flat field normalization,
default is 'plane'.
"""
self.ff_type = ff_type
if type(ff_params) is not list:
self.flat_field = ff_params.tolist()
else:
self.flat_field = ff_params
if prod_th is not None:
self.flat_field_prod_th = prod_th
if ratio_th is not None:
self.flat_field_ratio_th = ratio_th
def get_flat_field(self):
"""Get the flat-field plane definition."""
if self.flat_field is None:
return None
else:
return np.array(self.flat_field)
def set_Fnl(self, Fnl):
"""Set the non-linear correction function."""
if isinstance(Fnl, list):
self.Fnl = Fnl
else:
self.Fnl = Fnl.tolist()
def get_Fnl(self):
"""Get the non-linear correction function."""
if self.Fnl is None:
return None
else:
if self._using_gpu:
return cp.array(self.Fnl)
else:
return np.array(self.Fnl)
def save(self, path='./'):
"""Save the parameters as a JSON file.
Inputs
------
path: str, where to save the file, default to './'
"""
v = {}
v['proposal'] = self.proposal
v['darkrun'] = self.darkrun
v['run'] = self.run
v['module'] = self.module
v['gain'] = self.gain
v['drop_intra_darks'] = self.drop_intra_darks
v['mask'] = self.mask_idx
v['mean_th'] = self.mean_th
v['std_th'] = self.std_th
v['rois'] = self.rois
v['rois_th'] = self.rois_th
v['ff_type'] = self.ff_type
v['flat_field'] = self.flat_field
v['flat_field_prod_th'] = self.flat_field_prod_th
v['flat_field_ratio_th'] = self.flat_field_ratio_th
v['plane_guess_fit'] = self.plane_guess_fit
v['use_hex'] = self.use_hex
v['force_mirror'] = self.force_mirror
v['ff_alpha'] = self.ff_alpha
v['ff_max_iter'] = self.ff_max_iter
v['Fnl'] = self.Fnl
v['nl_alpha'] = self.nl_alpha
v['sat_level'] = self.sat_level
v['nl_max_iter'] = self.nl_max_iter
fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json'
with open(path + fname, 'w') as f:
json.dump(v, f)
print(path + fname)
@classmethod
def load(cls, fname):
"""Load parameters from a JSON file.
Inputs
------
fname: string, name a the JSON file to load
"""
with open(fname, 'r') as f:
v = json.load(f)
c = cls(v['proposal'], v['darkrun'], v['run'], v['module'], v['gain'],
v['drop_intra_darks'])
c.mean_th = v['mean_th']
c.std_th = v['std_th']
c.set_mask(v['mask'])
c.rois = v['rois']
c.rois_th = v['rois_th']
if 'ff_type' not in v:
v['ff_type'] = 'plane'
c.set_flat_field(v['flat_field'], v['ff_type'],
v['flat_field_prod_th'], v['flat_field_ratio_th'])
c.plane_guess_fit = v['plane_guess_fit']
c.use_hex = v['use_hex']
c.force_mirror = v['force_mirror']
c.ff_alpha = v['ff_alpha']
c.ff_max_iter = v['ff_max_iter']
c.set_Fnl(v['Fnl'])
c.nl_alpha = v['nl_alpha']
c.sat_level = v['sat_level']
c.nl_max_iter = v['nl_max_iter']
return c
def __str__(self):
f = f'proposal:{self.proposal} darkrun:{self.darkrun} run:{self.run}'
f += f' module:{self.module} gain:{self.gain} ph/bin\n'
f += f'drop intra darks:{self.drop_intra_darks}\n'
if self.mask_idx is not None:
f += f'mean threshold:{self.mean_th} std threshold:{self.std_th}\n'
f += f'mask:(#{len(self.mask_idx)}) {self.mask_idx}\n'
else:
f += 'mask:None\n'
f += f'rois threshold: {self.rois_th}\n'
f += f'rois: {self.rois}\n'
f += f'flat-field type: {self.ff_type}\n'
f += f'flat-field p: {self.flat_field} '
f += f'prod:{self.flat_field_prod_th} '
f += f'ratio:{self.flat_field_ratio_th}\n'
f += f'plane guess fit: {self.plane_guess_fit}\n'
f += f'use hexagons: {self.use_hex}\n'
f += f'enforce mirror symmetry: {self.force_mirror}\n'
f += f'ff alpha: {self.ff_alpha}, max. iter.: {self.ff_max_iter}\n'
if self.Fnl is not None:
f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
f += f'nl alpha:{self.nl_alpha}, sat. level:{self.sat_level}, '
f += f' nl max. iter.:{self.nl_max_iter}'
else:
f += 'Fnl: None'
return f
def ensure_on_host(arr):
# load data back from GPU - if it was on GPU
if hasattr(arr, "__cuda_array_interface__"): # avoid importing CuPy
return arr.get()
elif isinstance(arr, (da.Array,)):
return arr.map_blocks(ensure_on_host)
return arr
# Hexagonal pixels related function
def _get_pixel_pos(module):
"""Compute the pixel position on hexagonal lattice of DSSC module."""
# module pixel position
dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
# keeping only module 15 pixel X,Y position
return g.get_pixel_positions()[module][:, :, :2]
def get_roi_pixel_pos(roi, params):
"""Compute fake or real pixel position of an roi from roi center.
Inputs:
-------
roi: dictionnary
params: parameters
Returns:
--------
X, Y: 1-d array of pixel position.
"""
if params.use_hex:
# DSSC pixel position on hexagonal lattice
X = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], 0]
Y = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], 1]
else:
nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
X = np.arange(nX)/100
Y = np.arange(nY)[:, np.newaxis]/100
# center of ROI is put to 0,0
X -= np.mean(X)
Y -= np.mean(Y)
return X, Y
def _get_pixel_corners(module):
"""Compute the pixel corners of DSSC module."""
# module pixel position
dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
# corners are in z,y,x oder so we rop z, flip x & y
corners = g.to_distortion_array(allow_negative_xy=True)
corners = corners[(module*128):((module+1)*128), :, :, 1:][:, :, :, ::-1]
return corners
def _get_pixel_hexagons(module):
"""Compute DSSC pixel hexagons for plotting.
Parameters:
-----------
module: either int, for the module number or a 2-d array of corners to
get hexagons from
Returns:
--------
a 1-d list of hexagons where corners position are in mm
"""
hexes = []
if type(module) is int:
corners = _get_pixel_corners(module)
else:
corners = module
for y in range(corners.shape[0]):
for x in range(corners.shape[1]):
c = 1e3*corners[y, x, :, :] # convert to mm
hexes.append(Polygon(c))
return hexes
def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
"""Add a colobar on a new axes so it match the plot size.
Inputs
------
im: image plotted
ax: axes on which the image was plotted
loc: string, default 'right', location of the colorbar
size: string, default '5%', proportion of the colobar with respect to the
plotted image
pad: float, default 0.05, pad width between plot and colorbar
"""
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig = ax.figure
divider = make_axes_locatable(ax)
cax = divider.append_axes(loc, size=size, pad=pad)
cbar = fig.colorbar(im, cax=cax)
return cbar
# dark related functions
def bad_pixel_map(params):
"""Compute the bad pixels map.
Inputs
------
params: parameters
Returns
-------
bad pixel map
"""
assert params.arr_dark is not None, "Data not loaded"
# compute mean and std
dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
dark_std = params.arr_dark.std(axis=(0, 1)).compute()
mask = np.ones_like(dark_mean)
if params.mean_th[0] is not None:
mask *= dark_mean >= params.mean_th[0]
if params.mean_th[1] is not None:
mask *= dark_mean <= params.mean_th[1]
if params.std_th[0] is not None:
mask *= dark_std >= params.std_th[0]
if params.std_th[1] is not None:
mask *= dark_std >= params.std_th[1]
print(f'# bad pixel: {int(128*512-mask.sum())}')
return mask.astype(bool)
def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
"""Inspect dark run data and plot diagnostic.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
mean_th: tuple of threshold (low, high), default (None, None), to compute
a mask of good pixels for which the mean dark value lie inside this
range
std_th: tuple of threshold (low, high), default (None, None), to compute a
mask of bad pixels for which the dark std value lie inside this
range
Returns
-------
fig: matplotlib figure
"""
# compute mean and std
dark_mean = ensure_on_host(arr.mean(axis=(0, 1)).compute())
dark_std = ensure_on_host(arr.std(axis=(0, 1)).compute())
fig = plt.figure(figsize=(7, 2.7))
gs = fig.add_gridspec(2, 4)
ax1 = fig.add_subplot(gs[0, 1:])
ax1.set_xticklabels([])
ax1.set_yticklabels([])
ax11 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 1:])
ax2.set_xticklabels([])
ax2.set_yticklabels([])
ax22 = fig.add_subplot(gs[1, 0])
vmin = np.percentile(dark_mean.flatten(), 2)
vmax = np.percentile(dark_mean.flatten(), 98)
im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax)
ax1.invert_yaxis()
ax1.set_aspect('equal')
cbar1 = _add_colorbar(im1, ax=ax1, size='2%')
cbar1.ax.set_ylabel('dark mean')
ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1),
range=(vmin/2, vmax*2))
if mean_th[0] is not None:
ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--')
if mean_th[1] is not None:
ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--')
ax11.set_yscale('log')
vmin = np.percentile(dark_std.flatten(), 2)
vmax = np.percentile(dark_std.flatten(), 98)
im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax)
ax2.invert_yaxis()
ax2.set_aspect('equal')
cbar2 = _add_colorbar(im2, ax=ax2, size='2%')
cbar2.ax.set_ylabel('dark std')
ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2))
if std_th[0] is not None:
ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--')
if std_th[1] is not None:
ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--')
ax22.set_yscale('log')
return fig
# histogram related functions
def histogram_module(arr, mask=None):
"""Compute a histogram of the 9 bits raw pixel values over a module.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
mask: optional bad pixel mask
Returns
-------
histogram
"""
if mask is not None:
w = da.repeat(da.repeat(da.array(mask[None, None, :, :]),
arr.shape[1], axis=1), arr.shape[0], axis=0)
w = w.rechunk(arr.chunks)
return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute()
else:
return da.bincount(arr.ravel(), minlength=512).compute()
def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
"""Compute and plot a histogram of the 9 bits raw pixel values.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y)
mask: optional bad pixel mask
extra_lines: boolean, default False, plot extra lines at period values
Returns
-------
(h, hd): histogram of arr, arr_dark
figure
"""
from matplotlib.ticker import MultipleLocator
f = plt.figure(figsize=(6, 3))
ax = plt.gca()
h = ensure_on_host(histogram_module(arr, mask=mask))
Sum_h = np.sum(h)
ax.plot(np.arange(2**9), h/Sum_h, marker='o',
ms=3, markerfacecolor='none', lw=1)
if arr_dark is not None:
hd = ensure_on_host(histogram_module(arr_dark, mask=mask))
Sum_hd = np.sum(hd)
ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
else:
hd = None
if extra_lines:
for k in range(50, 271):
if not (k - 2) % 8:
ax.axvline(k, c='k', alpha=0.5, ls='--')
if not (k - 3) % 16:
ax.axvline(k, c='g', alpha=0.3, ls='--')
if not (k - 7) % 32:
ax.axvline(k, c='r', alpha=0.3, ls='--')
ax.axvline(271, c='C1', alpha=0.5, ls='--')
ax.set_xlim([0, 2**9-1])
ax.set_yscale('log')
ax.xaxis.set_minor_locator(MultipleLocator(10))
ax.set_xlabel('DSSC pixel value')
ax.set_ylabel('count frequency')
return (h, hd), f
# rois related function
def find_rois(data_mean, threshold, extended=False):
"""Find rois from 3 beams configuration.
Inputs
------
data_mean: dark corrected average image
threshold: threshold value to find beams
extended: boolean, True to define additional ASICS based rois
Returns
-------
rois: dictionnary of rois
"""
# compute vertical and horizontal projection
pX = data_mean.mean(axis=0)
pX = pX[:256] # half the ladder since there is a gap in the middle
pY = data_mean.mean(axis=1)
pX = pX/np.max(pX)
pY = pY/np.max(pY)
# along X
lowX = int(np.argmax(pX > threshold) - 1) # 1st occurrence returned
highX = int(pX.shape[0] -
np.argmax(pX[::-1] > threshold)) # last occ. returned
midX = int(0.5*(lowX+highX))
leftX2 = int(np.argmax(pX[lowX+5:midX-5] < threshold)) + lowX + 5
midX2 = int(np.argmax(pX[midX+5:highX-5] < threshold)) + midX + 5
midX1 = int(midX - 5 - np.argmax(pX[midX-5:lowX+5:-1] < threshold))
rightX1 = int(highX - 5 - np.argmax(pX[highX-5:midX+5:-1] < threshold))
# along Y
lowY = int(np.argmax(pY > threshold) - 1) # 1st occurrence returned
highY = int(pY.shape[0]
- np.argmax(pY[::-1] > threshold)) # last occ. returned
# define rois
rois = {}
# beam roi
rois['n'] = {'xl': lowX, 'xh': leftX2, 'yl': lowY, 'yh': highY}
rois['0'] = {'xl': midX1, 'xh': midX2, 'yl': lowY, 'yh': highY}
rois['p'] = {'xl': rightX1, 'xh': highX, 'yl': lowY, 'yh': highY}
# saturation roi
rois['sat'] = {'xl': lowX, 'xh': highX, 'yl': lowY, 'yh': highY}
if extended:
# baseline correction rois
for k in [0, 1, 2, 3]:
rois[f'b{k}'] = {'xl': k*64, 'xh': (k+1)*64, 'yl': 0, 'yh': lowY}
for k in [8, 9, 10, 11]:
rois[f'b{k}'] = {'xl': (k-8)*64, 'xh': (k+1-8)*64,
'yl': highY, 'yh': 128}
# ASICs splitted beam roi
rois['0X'] = {'xl': lowX, 'xh': 1*64, 'yl': lowY, 'yh': 64}
rois['1X1'] = {'xl': 64, 'xh': leftX, 'yl': lowY, 'yh': 64}
rois['1X2'] = {'xl': leftX, 'xh': 2*64, 'yl': lowY, 'yh': 64}
rois['2X1'] = {'xl': 2*64, 'xh': rightX, 'yl': lowY, 'yh': 64}
rois['2X2'] = {'xl': rightX, 'xh': 3*64, 'yl': lowY, 'yh': 64}
rois['3X'] = {'xl': 3*64, 'xh': highX, 'yl': lowY, 'yh': 64}
rois['8X'] = {'xl': lowX, 'xh': 1*64, 'yl': 64, 'yh': highY}
rois['9X1'] = {'xl': 64, 'xh': leftX, 'yl': 64, 'yh': highY}
rois['9X2'] = {'xl': leftX, 'xh': 2*64, 'yl': 64, 'yh': highY}
rois['10X1'] = {'xl': 2*64, 'xh': rightX, 'yl': 64, 'yh': highY}
rois['10X2'] = {'xl': rightX, 'xh': 3*64, 'yl': 64, 'yh': highY}
rois['11X'] = {'xl': 3*64, 'xh': highX, 'yl': 64, 'yh': highY}
return rois
def find_rois_from_params(params):
"""Find rois from 3 beams configuration.
Inputs
------
params: parameters
Returns
-------
rois: dictionnary of rois
"""
assert params.arr_dark is not None, "Data not loaded"
dark = average_module(params.arr_dark).compute()
assert params.arr is not None, "Data not loaded"
data = average_module(params.arr, dark=dark).compute()
data_mean = data.mean(axis=0) # mean over pulseId
threshold = params.rois_th
return find_rois(data_mean, threshold)
def inspect_rois(data_mean, rois, threshold=None, allrois=False):
"""Find rois from 3 beams configuration from mean module image.
Inputs
------
data_mean: mean module image
threshold: float, default None, threshold value used to detect beams
boundaries
allrois: boolean, default False, plot all rois defined in rois or only the
main ones (['n', '0', 'p'])
Returns
-------
matplotlib figure
"""
# compute vertical and horizontal projection
pX = data_mean.mean(axis=0)
pX = pX[:256] # half the ladder since there is a gap in the middle
pY = data_mean.mean(axis=1)
pX = pX/np.max(pX)
pY = pY/np.max(pY)
# Set up the axes with gridspec
fig = plt.figure(figsize=(5, 3))
grid = plt.GridSpec(2, 2, width_ratios=(1, 4), height_ratios=(2, 1),
# left=0.1, right=0.9, bottom=0.1, top=0.9,
wspace=0.05, hspace=0.05,
figure=fig)
main_ax = fig.add_subplot(grid[0, 1])
y = fig.add_subplot(grid[0, 0], xticklabels=[], sharey=main_ax)
x = fig.add_subplot(grid[1, 1], yticklabels=[], sharex=main_ax)
# scatter points on the main axes
Xs = np.arange(len(pX))
Ys = np.arange(len(pY))
main_ax.pcolormesh(Xs, Ys, np.flipud(data_mean[:, :256]),
cmap='Greys_r',
vmin=0,
vmax=np.percentile(data_mean[:, :256], 99))
main_ax.set_aspect('equal')
from matplotlib.patches import Rectangle
roi = rois['n']
main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']),
roi['xh'] - roi['xl'],
roi['yh'] - roi['yl'],
alpha=0.3, color='b'))
roi = rois['0']
main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']),
roi['xh'] - roi['xl'],
roi['yh'] - roi['yl'],
alpha=0.3, color='g'))
roi = rois['p']
main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']),
roi['xh'] - roi['xl'],
roi['yh'] - roi['yl'],
alpha=0.3, color='r'))
x.plot(Xs, pX)
x.invert_yaxis()
if threshold is not None:
x.axhline(threshold, c='k', alpha=.5, ls='--')
x.axvline(rois['n']['xl'], c='b', alpha=.3)
x.axvline(rois['n']['xh'], c='b', alpha=.3)
x.axvline(rois['0']['xl'], c='g', alpha=.3)
x.axvline(rois['0']['xh'], c='g', alpha=.3)
x.axvline(rois['p']['xl'], c='r', alpha=.3)
x.axvline(rois['p']['xh'], c='r', alpha=.3)
y.plot(pY, np.arange(len(pY)-1, -1, -1))
y.invert_xaxis()
if threshold is not None:
y.axvline(threshold, c='k', alpha=.5, ls='--')
y.axhline(127-rois['p']['yl'], c='r', alpha=.5)
y.axhline(127-rois['p']['yh'], c='r', alpha=.5)
return fig
# Flat-field related functions
def _plane_flat_field(p, roi, params):
"""Compute the p plane over the given roi.
Given the plane parameters p, compute the plane over the roi
size.
Parameters
----------
p: a vector of a, b, c, d plane parameter with the
plane given by ax+ by + cz + d = 0
roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
params: parameters
Returns
-------
the plane field given by p evaluated on the roi
extend.
"""
a, b, c, d = p
X, Y = get_roi_pixel_pos(roi, params)
Z = -(a*X + b*Y + d)/c
return Z
def compute_flat_field_correction(rois, params, plot=False):
if params.ff_type == 'plane':
return compute_plane_flat_field_correction(rois, params, plot)
elif params.ff_type == 'polyline':
return compute_polyline_flat_field_correction(rois, params, plot)
else:
raise ValueError(f'Uknown flat field type {params.ff_type}')
def compute_plane_flat_field_correction(rois, params, plot=False):
"""Compute the plane-field correction on beam rois.
Inputs
------
rois: dictionnary of beam rois['n', '0', 'p']
params: parameters
plot: boolean, True by default, diagnostic plot
Returns
-------
numpy 2D array of the flat-field correction evaluated over one DSSC ladder
(2 sensors)
"""
flat_field = np.ones((128, 512))
plane = params.get_flat_field()
force_mirror = params.force_mirror
r = rois['n']
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
_plane_flat_field(plane[:4], r, params)
r = rois['p']
if force_mirror:
a, b, c, d = plane[:4]
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
_plane_flat_field([-a, b, c, d], r, params)
else:
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
_plane_flat_field(plane[4:], r, params)
if plot:
f, ax = plt.subplots(1, 1, figsize=(6, 2))
img = ax.pcolormesh(
np.flipud(flat_field[:, :256]), cmap='Greys_r')
f.colorbar(img, ax=[ax], label='amplitude')
ax.set_xlabel('px')
ax.set_ylabel('px')
ax.set_aspect('equal')
return flat_field
def initialize_polyline_ff_correction(avg, rois, params, plot=False):
"""Initialize the polyline flat field correction.
Inputs
------
avg: 2D array, average module image
rois: dictionnary of ROIs.
plot: boolean, plot initialized polyline versus data projection
Returns
-------
fig: handle to figure or None
"""
refn = avg[rois['n']['yl']:rois['n']['yh'],
rois['n']['xl']:rois['n']['xh']]
refp = avg[rois['p']['yl']:rois['p']['yh'],
rois['p']['xl']:rois['p']['xh']]
mid = avg[rois['0']['yl']:rois['0']['yh'],
rois['0']['xl']:rois['0']['xh']]
mref = 0.5*(refn + refp)
inv_signal = mref/mid # normalization
H_projection = inv_signal[:, :].mean(axis=0)
x = np.arange(0, len(H_projection))
H_z = np.polyfit(x, H_projection, 6)
H_p = np.poly1d(H_z)
V_projection = (inv_signal/H_p(x))[:, :].mean(axis=1)
y = np.arange(0, len(V_projection))
V_z = np.polyfit(y, V_projection, 6)
if plot:
fig, axs = plt.subplots(2, 1, figsize=(4,6))
axs[0].plot(x, H_projection, label='data (n+p)/2x0')
axs[0].plot(x, H_p(x), label='poly')
axs[0].legend()
axs[0].set_xlabel('x (px)')
axs[0].set_ylabel('H projection')
axs[1].plot(y, V_projection, label='data (n+p)/2x0')
V_p = np.poly1d(V_z)
axs[1].plot(y, V_p(y), label='poly')
axs[1].legend()
axs[1].set_xlabel('y (px)')
axs[1].set_ylabel('V projection')
else:
fig = None
# scaling on polynom coefficients for better fitting
ff = np.array([H_z/np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0]),
V_z/np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])])
params.set_flat_field(ff.flatten())
params.ff_type = 'polyline'
return fig
def compute_polyline_flat_field_correction(rois, params, plot=False):
"""Compute the 1D polyline field correction on beam rois.
Inputs
------
rois: dictionnary of beam rois['n', '0', 'p']
params: parameters
plot: boolean, True by default, diagnostic plot
Returns
-------
numpy 2D array of the flat-field correction evaluated over one DSSC ladder
(2 sensors)
"""
flat_field = np.ones((128, 512))
z = np.array(params.get_flat_field()).reshape((2, -1))
H_z = z[0, :]
V_z = z[1, :]
coeffs = np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0])
H_p = np.poly1d(H_z*coeffs)
coeffs = np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])
V_p = np.poly1d(V_z*coeffs)
n = rois['n']
p = rois['p']
wn = n['xh']-n['xl']
wp = p['xh']-p['xl']
assert wn == wp, (\
f"For polyline flat field normalization, both 'n' and 'p' ROIs "
f"must have the same width {wn} and {wp}px"
)
x = np.arange(wn)
wn = n['yh']-n['yl']
y = np.arange(wn)
norm = V_p(y)[:, np.newaxis]*H_p(x)
n_int = flat_field[n['yl']:n['yh'], n['xl']:n['xh']]
flat_field[n['yl']:n['yh'], n['xl']:n['xh']] = \
norm*n_int
p_int = flat_field[p['yl']:p['yh'], p['xl']:p['xh']]
flat_field[p['yl']:p['yh'], p['xl']:p['xh']] = \
norm*p_int # not the mirror
if plot:
f, ax = plt.subplots(1, 1, figsize=(6, 2))
img = ax.pcolormesh(
np.flipud(flat_field[:, :256]), cmap='Greys_r')
f.colorbar(img, ax=[ax], label='amplitude')
ax.set_xlabel('px')
ax.set_ylabel('px')
ax.set_aspect('equal')
return flat_field
def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None,
vmax=None):
"""Extract beams roi from average image and compute the ratio.
Inputs
------
avg: module average image with no saturated shots for the flat-field
determination
rois: dictionnary or ROIs
prod_th, ratio_th: tuple of floats for low and high threshold on
product and ratio
vmin: imshow vmin level, default None will use 5 percentile value
vmax: imshow vmax level, default None will use 99.8 percentile value
Returns
-------
fig: matplotlib figure plotted
domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order
"""
if vmin is None:
vmin = np.percentile(avg, 5)
if vmax is None:
vmax = np.percentile(avg, 99.8)
fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9))
img_rois = {}
centers = {}
for k, r in enumerate(['n', '0', 'p']):
roi = rois[r]
centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
(roi['xl'] + roi['xh'])//2])
d = '0'
roi = rois[d]
for k, r in enumerate(['n', '0', 'p']):
img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[
roi['yl']:roi['yh'], roi['xl']:roi['xh']]
im = axs[0, k].imshow(img_rois[r],
vmin=vmin,
vmax=vmax)
n, n_m, p, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th)
prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
for k, r in enumerate(['n', '0', 'p']):
v = img_rois[r]*img_rois['0']
if prod_vmin is None:
prod_vmin = np.percentile(v, .5)
prod_vmax = np.percentile(v, 20) # we look for low intensity region
im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax, cmap='magma')
axs[1,k].contour(v, prod_th, cmap=cm.get_cmap(cm.cool, 2))
v = img_rois[r]/img_rois['0']
if ratio_vmin is None:
ratio_vmin = np.percentile(v, 5)
ratio_vmax = np.percentile(v, 99.8)
im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax,
cmap='RdBu_r')
axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2))
cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
cbar.ax.set_xlabel('data mean')
cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
cbar.ax.set_xlabel('product')
cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal")
cbar.ax.set_xlabel('ratio')
# fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')
domain = (n_m, p_m)
return fig, domain
def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None):
warnings.warn("This method is depreciated, use inspect_ff_fitting instead")
return inspect_ff_fitting(avg, rois, domain, vmin, vmax)
def inspect_ff_fitting(avg, rois, domain=None, vmin=None, vmax=None):
"""Extract beams roi from average image and compute the ratio.
Inputs
------
avg: module average image with no saturated shots for the flat-field
determination
rois: dictionnary of rois
domain: list of domain mask for the -1st and +1st order
vmin: imshow vmin level, default None will use 5 percentile value
vmax: imshow vmax level, default None will use 99.8 percentile value
Returns
-------
fig: matplotlib figure plotted
"""
if vmin is None:
vmin = np.percentile(avg, 5)
if vmax is None:
vmax = np.percentile(avg, 99.8)
fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6))
img_rois = {}
centers = {}
for k, r in enumerate(['n', '0', 'p']):
roi = rois[r]
centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
(roi['xl'] + roi['xh'])//2])
d = '0'
roi = rois[d]
for k, r in enumerate(['n', '0', 'p']):
img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[
roi['yl']:roi['yh'], roi['xl']:roi['xh']]
im = axs[0, k].imshow(img_rois[r],
vmin=vmin,
vmax=vmax)
for k, r in enumerate(['n', '0', 'p']):
v = img_rois[r]/img_rois['0']
im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')
if domain is not None:
n_m, p_m = domain
axs[1, 0].contour(n_m)
axs[1, 2].contour(p_m)
cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
cbar.ax.set_xlabel('data mean')
cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
cbar.ax.set_xlabel('ratio')
# fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')
return fig
def inspect_ff_fitting_sk(avg, rois, ff, domain=None, vmin=None, vmax=None):
"""Extract beams roi from average image and compute the ratio.
Inputs
------
avg: module average image with no saturated shots for the flat-field
determination
rois: dictionnary of rois
ff: 2D array, flat field normalization
domain: list of domain mask for the -1st and +1st order
vmin: imshow vmin level, default None will use 5 percentile value
vmax: imshow vmax level, default None will use 99.8 percentile value
Returns
-------
fig: matplotlib figure plotted
"""
if vmin is None:
vmin = np.percentile(avg, 5)
if vmax is None:
vmax = np.percentile(avg, 99.8)
refn = avg[rois['n']['yl']:rois['n']['yh'],
rois['n']['xl']:rois['n']['xh']]
refp = avg[rois['p']['yl']:rois['p']['yh'],
rois['p']['xl']:rois['p']['xh']]
mid = avg[rois['0']['yl']:rois['0']['yh'],
rois['0']['xl']:rois['0']['xh']]
mref = 0.5*(refn + refp)
ffn = ff[rois['n']['yl']:rois['n']['yh'],
rois['n']['xl']:rois['n']['xh']]
ffp = ff[rois['p']['yl']:rois['p']['yh'],
rois['p']['xl']:rois['p']['xh']]
ffmid = ff[rois['0']['yl']:rois['0']['yh'],
rois['0']['xl']:rois['0']['xh']]
np_norm = 0.5*(ffn+ffp)
mid_norm = ffmid
fig, axs = plt.subplots(3, 3, sharex=True, sharey=True,
figsize=(8, 4))
im = axs[0, 0].imshow(mref)
axs[0, 0].set_title('(n+p)/2')
fig.colorbar(im, ax=axs[0, 0])
im = axs[1, 0].imshow(mid)
axs[1, 0].set_title('0')
fig.colorbar(im, ax=axs[1, 0])
im = axs[2, 0].imshow(mid/mref-1, cmap='RdBu_r', vmin=-1, vmax=1)
axs[2, 0].set_title('2x0/(n+p) - 1')
fig.colorbar(im, ax=axs[2, 0])
im = axs[0, 1].imshow(np_norm)
axs[0, 1].set_title('norm: (n+p)/2')
fig.colorbar(im, ax=axs[0, 1])
im = axs[1, 1].imshow(mid_norm)
axs[1, 1].set_title('norm: 0')
fig.colorbar(im, ax=axs[1, 1])
im = axs[2, 1].imshow(mid_norm/np_norm-1, cmap='RdBu_r', vmin=-1, vmax=1)
axs[2, 1].set_title('norm: 2x0/(n+p) - 1')
fig.colorbar(im, ax=axs[2, 1])
im = axs[0, 2].imshow(mref/np_norm)
axs[0, 2].set_title('(n+p)/2 /norm')
fig.colorbar(im, ax=axs[0, 2])
im = axs[1, 2].imshow(mid/mid_norm)
axs[1, 2].set_title('0 /norm')
fig.colorbar(im, ax=axs[1, 2])
im = axs[2, 2].imshow((mid/mid_norm)/(mref/np_norm)-1,
cmap='RdBu_r', vmin=-1, vmax=1)
axs[2, 2].set_title('2x0/(n+p) - 1 /norm')
fig.colorbar(im, ax=axs[2, 2])
# fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')
return fig
def plane_fitting_domain(avg, rois, prod_th, ratio_th):
"""Extract beams roi, compute their ratio and the domain.
Inputs
------
avg: module average image with no saturated shots for the flat-field
determination
rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0'
as the reference beam in the middle
prod_th: float tuple, low and hight threshold level to determine the plane
fitting domain on the product image of the orders
ratio_th: float tuple, low and high threshold level to determine the plane
fitting domain on the ratio image of the orders
Returns
-------
n: img ratio 'n'/'0'
n_m: mask where the the product 'n'*'0' is higher than 5 indicting that the
img ratio 'n'/'0' is defined
p: img ratio 'p'/'0'
p_m: mask where the the product 'p'*'0' is higher than 5 indicting that the
img ratio 'p'/'0' is defined
"""
centers = {}
for k in ['n', '0', 'p']:
r = rois[k]
centers[k] = np.array([(r['yl'] + r['yh'])//2,
(r['xl'] + r['xh'])//2])
k = 'n'
r = rois[k]
num = avg[r['yl']:r['yh'], r['xl']:r['xh']]
d = '0'
denom = np.roll(avg, tuple(centers[k] - centers[d]))[
r['yl']:r['yh'], r['xl']:r['xh']]
n = num/denom
prod = num*denom
n_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
(n > ratio_th[0]) * (n < ratio_th[1]))
n_m[~np.isfinite(n)] = 0
n[~np.isfinite(n)] = 0
k = 'p'
r = rois[k]
num = avg[r['yl']:r['yh'], r['xl']:r['xh']]
d = '0'
denom = np.roll(avg, tuple(centers[k] - centers[d]))[
r['yl']:r['yh'], r['xl']:r['xh']]
p = num/denom
prod = num*denom
p_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
(p > ratio_th[0]) * (p < ratio_th[1]))
p_m[~np.isfinite(p)] = 0
p[~np.isfinite(p)] = 0
return n, n_m, p, p_m
def plane_fitting(params):
"""Fit the plane flat-field normalization.
Inputs
------
params: parameters
Returns
-------
res: the minimization result. The fitted vector res.x = [a, b, c, d]
defines the plane as a*x + b*y + c*z + d = 0
"""
assert params.arr_dark is not None, "Data not loaded"
dark = average_module(params.arr_dark).compute()
assert params.arr is not None, "Data not loaded"
data = average_module(params.arr, dark=dark,
ret='mean', mask=params.mask, sat_roi=params.rois['sat'],
sat_level=params.sat_level).compute()
data_mean = data.mean(axis=0) # mean over pulseId
n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois,
params.flat_field_prod_th, params.flat_field_ratio_th)
def _crit(x):
"""Fitting criteria for the plane field normalization.
Inputs
------
x: 2 vector [a, b, c, d] concatenated defining the plane as
a*x + b*y + c*z + d = 0
"""
a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x
num_n = a_n**2 + b_n**2 + c_n**2
roi = params.rois['n']
X, Y = get_roi_pixel_pos(roi, params)
d0_2 = np.sum(n_m*(a_n*X + b_n*Y + c_n*n + d_n)**2)/num_n
num_p = a_p**2 + b_p**2 + c_p**2
roi = params.rois['p']
X, Y = get_roi_pixel_pos(roi, params)
if params.force_mirror:
d2_2 = np.sum(p_m*(-a_n*X + b_n*Y + c_n*p + d_n)**2)/num_n
else:
d2_2 = np.sum(p_m*(a_p*X + b_p*Y + c_p*p + d_p)**2)/num_p
return 1e3*(d2_2 + d0_2)
res = minimize(_crit, params.flat_field_guess())
return res
def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
mask, sat_level=511):
"""Criteria for the ff_refine_fit.
Inputs
------
p: ff plane
params: parameters
arr_dark: dark data
arr: data
tid: train id of arr data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask fo good pixels
sat_level: integer, default 511, at which level pixel begin to saturate
Returns
-------
sum of standard deviation on binned 0th order intensity
"""
params.set_flat_field(p)
ff = compute_flat_field_correction(rois, params)
if np.any(ff < 0.0):
bad = 1e6
else:
bad = 0.0
data = process(None, arr_dark, arr, tid, rois, mask, ff,
sat_level, params._using_gpu)
# drop saturated shots
d = data.where(data['sat_sat'] == False, drop=True)
rn = xas(d, 40, Iokey='0', Itkey='n', nrjkey='0', fluorescence=True)
rp = xas(d, 40, Iokey='0', Itkey='p', nrjkey='0', fluorescence=True)
rd = xas(d, 40, Iokey='p', Itkey='n', nrjkey='0', fluorescence=True)
err_sigma = (np.nansum(rn['sigmaA']) + np.nansum(rp['sigmaA'])
+ np.nansum(rd['sigmaA']))
err_mean = ((1.0 - np.nanmean(rn['muA']))**2 +
(1.0 - np.nanmean(rp['muA']))**2 +
(1.0 - np.nanmean(rd['muA']))**2)
return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
def ff_refine_crit_sk(p, alpha, params, arr_dark, arr, tid, rois,
mask, sat_level=511):
"""Criteria for the ff_refine_fit, combining 'n' and 'p' as reference.
Inputs
------
p: ff plane
params: parameters
arr_dark: dark data
arr: data
tid: train id of arr data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask fo good pixels
sat_level: integer, default 511, at which level pixel begin to saturate
Returns
-------
sum of standard deviation on binned 0th order intensity
"""
params.set_flat_field(p, params.ff_type)
ff = compute_flat_field_correction(rois, params)
if np.any(ff < 0.0):
bad = 1e6
else:
bad = 0.0
data = process(None, arr_dark, arr, tid, rois, mask, ff,
sat_level, params._using_gpu)
# drop saturated shots
d = data.where(data['sat_sat'] == False, drop=True)
r = xas(d, 40, Iokey='np_mean_sk', Itkey='0', nrjkey='0', fluorescence=True)
err_sigma = np.nansum(r['sigmaA'])
err_mean = (1.0 - np.nanmean(r['muA']))**2
return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
def ff_refine_fit(params, crit=ff_refine_crit):
"""Refine the flat-field fit by minimizing data spread.
Inputs
------
params: parameters
Returns
-------
res: scipy minimize result. res.x is the optimized parameters
fitrres: iteration index arrays of criteria results for
[alpha=0, alpha, alpha=1]
"""
# load data
assert params.arr is not None, "Data not loaded"
assert params.arr_dark is not None, "Data not loaded"
# we only need few rois
fitrois = {}
for k in ['n', '0', 'p', 'sat']:
fitrois[k] = params.rois[k]
p0 = params.get_flat_field()
if p0 is None: # flat field was not yet fitted
p0 = params.flat_field_guess()
fixed_p = (params.ff_alpha, params, params.arr_dark, params.arr,
params.tid, fitrois, params.get_mask(), params.sat_level)
def fit_callback(x):
if not hasattr(fit_callback, "counter"):
fit_callback.counter = 0 # it doesn't exist yet, so initialize it
fit_callback.start = time.monotonic()
fit_callback.res = []
now = time.monotonic()
time_delta = datetime.timedelta(seconds=now-fit_callback.start)
fit_callback.counter += 1
temp = list(fixed_p)
Jalpha = crit(x, *temp)
temp[0] = 0
J0 = crit(x, *temp)
temp[0] = 1
J1 = crit(x, *temp)
fit_callback.res.append([J0, Jalpha, J1])
print(f'{fit_callback.counter-1}: {time_delta} '
f'(reg. term: {J0}, {Jalpha}, err. term: {J1}), {x}')
return False
fit_callback(p0)
res = minimize(crit, p0, fixed_p,
options={'disp': True, 'maxiter': params.ff_max_iter},
callback=fit_callback)
return res, fit_callback.res
# non-linearity related functions
def nl_domain(N, low, high):
"""Create the input domain where the non-linear correction defined.
Inputs
------
N: integer, number of control points or intervals
low: input values below or equal to low will not be corrected
high: input values higher or equal to high will not be corrected
Returns
-------
array of 2**9 integer values with N segments
"""
x = np.arange(2**9)
vx = x.copy()
eps = 1e-5
vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1)
vx[x <= low] = 0
vx[x >= high] = 0
return vx
def nl_lut(domain, dy):
"""Compute the non-linear correction.
Inputs
------
domain: input domain where dy is defined. For zero no correction is
defined. For non-zero value x, dy[x] is applied.
dy: a vector of deviation from linearity on control point homogeneously
dispersed over 9 bits.
Returns
-------
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
"""
x = np.arange(2**9)
ndy = np.insert(dy, 0, 0) # add zero to dy
f = x + ndy[domain]
return f
def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
sat_level=511, use_gpu=False):
"""Criteria for the non linear correction.
Inputs
------
p: vector of dy non linear correction
domain: domain over which the non linear correction is defined
alpha: float, coefficient scaling the cost of the correction function
in the criterion
arr_dark: dark data
arr: data
tid: train id of arr data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask fo good pixels
flat_field: zone plate flat-field correction
sat_level: integer, default 511, at which level pixel begin to saturate
Returns
-------
(1.0 - alpha)*err1 + alpha*err2, where err1 is the 1e8 times the mean of
error squared from a transmission of 1.0 and err2 is the sum of the square
of the deviation from the ideal detector response.
"""
Fmodel = nl_lut(domain, p)
data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark,
arr, tid, rois, mask, flat_field, sat_level, use_gpu)
# drop saturated shots
d = data.where(data['sat_sat'] == False, drop=True)
v_1 = snr(d['n'].values.flatten(), d['0'].values.flatten(),
methods=['weighted'])
err_1 = 1e8*v_1['weighted']['s']**2
v_2 = snr(d['p'].values.flatten(), d['0'].values.flatten(),
methods=['weighted'])
err_2 = 1e8*v_2['weighted']['s']**2
err_a = np.sum((Fmodel-np.arange(2**9))**2)
return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a
def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
sat_level=511, use_gpu=False):
"""Non linear correction criteria, combining 'n' and 'p' as reference.
Inputs
------
p: vector of dy non linear correction
domain: domain over which the non linear correction is defined
alpha: float, coefficient scaling the cost of the correction function
in the criterion
arr_dark: dark data
arr: data
tid: train id of arr data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask fo good pixels
flat_field: zone plate flat-field correction
sat_level: integer, default 511, at which level pixel begin to saturate
Returns
-------
(1.0 - alpha)*err1 + alpha*err2, where err1 is the 1e8 times the mean of
error squared from a transmission of 1.0 and err2 is the sum of the square
of the deviation from the ideal detector response.
"""
Fmodel = nl_lut(domain, p)
data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark,
arr, tid, rois, mask, flat_field, sat_level, use_gpu)
# drop saturated shots
d = data.where(data['sat_sat'] == False, drop=True)
v = snr(d['np_mean_sk'].values.flatten(), d['0'].values.flatten(),
methods=['weighted'])
err = 1e8*v['weighted']['s']**2
err_a = np.sum((Fmodel-np.arange(2**9))**2)
return (1.0 - alpha)*err + alpha*err_a
def nl_fit(params, domain, ff=None, crit=None):
"""Fit non linearities correction function.
Inputs
------
params: parameters
domain: array of index
ff: array, flat field correction
crit: function, criteria function
Returns
-------
res: scipy minimize result. res.x is the optimized parameters
fitrres: iteration index arrays of criteria results for
[alpha=0, alpha, alpha=1]
"""
# load data
assert params.arr is not None, "Data not loaded"
assert params.arr_dark is not None, "Data not loaded"
# we only need few rois
fitrois = {}
for k in ['n', '0', 'p', 'sat']:
fitrois[k] = params.rois[k]
# p0
N = np.unique(domain).shape[0] - 1
p0 = np.array([0]*N)
# flat flat_field
if ff is None:
ff = compute_flat_field_correction(params.rois, params)
if crit is None:
crit = nl_crit
fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr,
params.tid, fitrois, params.get_mask(), ff, params.sat_level,
params._using_gpu)
def fit_callback(x):
if not hasattr(fit_callback, "counter"):
fit_callback.counter = 0 # it doesn't exist yet, so initialize it
fit_callback.start = time.monotonic()
fit_callback.res = []
now = time.monotonic()
time_delta = datetime.timedelta(seconds=now-fit_callback.start)
fit_callback.counter += 1
temp = list(fixed_p)
Jalpha = crit(x, *temp)
temp[1] = 0
J0 = crit(x, *temp)
temp[1] = 1
J1 = crit(x, *temp)
fit_callback.res.append([J0, Jalpha, J1])
print(f'{fit_callback.counter-1}: {time_delta} '
f'({J0}, {Jalpha}, {J1}), {x}')
return False
fit_callback(p0)
res = minimize(crit, p0, fixed_p,
options={'disp': True, 'maxiter': params.nl_max_iter},
callback=fit_callback)
return res, fit_callback.res
def inspect_nl_fit(res_fit):
"""Plot the progress of the fit.
Inputs
------
res_fit:
Returns
-------
matplotlib figure
"""
r = np.array(res_fit)
f = plt.figure(figsize=(6, 4))
ax = f.gca()
ax2 = plt.twinx()
ax.plot(1.0/np.sqrt(1e-8*r[:, 0]), c='C0')
ax2.plot(r[:, 2], c='C1', ls='-.')
ax.set_xlabel('# iteration')
ax.set_ylabel('SNR', color='C0')
ax2.set_ylabel('correction cost', color='C1')
ax.set_yscale('log')
ax2.set_yscale('log')
return f
def snr(sig, ref, methods=None, verbose=False):
""" Compute mean, std and SNR from transmitted and I0 signals.
Inputs
------
sig: 1D signal samples
ref: 1D reference samples
methods: None by default or list of strings to select which methods to use.
Possible values are 'direct', 'weighted', 'diff'. In case of None, all
methods will be calculated.
verbose: booleand, if True prints calculated values
Returns
-------
dictionnary of [methods][value] where value is 'mu' for mean and 's' for
standard deviation.
"""
if methods is None:
methods = ['direct', 'weighted', 'diff']
w = ref
x = sig/ref
mask = np.isfinite(x) & np.isfinite(sig) & np.isfinite(ref)
w = w[mask]
sig = sig[mask]
ref = ref[mask]
x = x[mask]
res = {}
# direct mean and std
if 'direct' in methods:
mu = np.mean(x)
s = np.std(x)
if verbose:
print(f'mu: {mu}, s: {s}, snr: {mu/s}')
res['direct'] = {'mu': mu, 's':s}
# weighted mean and std
if 'weighted' in methods:
wmu = np.sum(sig)/np.sum(ref)
v1 = np.sum(w)
v2 = np.sum(w**2)
ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1))
if verbose:
print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}')
res['weighted'] = {'mu': wmu, 's':ws}
# noise from diff
if 'diff' in methods:
dmu = np.mean(x)
ds = np.std(np.diff(x))/np.sqrt(2)
if verbose:
print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}')
res['diff'] = {'mu': dmu, 's':ds}
return res
def inspect_Fnl(Fnl):
"""Plot the correction function Fnl.
Inputs
------
Fnl: non linear correction function lookup table
Returns
-------
matplotlib figure
"""
x = np.arange(2**9)
f = plt.figure(figsize=(6, 4))
plt.plot(x, Fnl - x)
# plt.axvline(40, c='k', ls='--')
# plt.axvline(280, c='k', ls='--')
plt.xlabel('input value')
plt.ylabel('output correction F(x)-x')
plt.xlim([0, 511])
return f
def inspect_correction(params, gain=None):
"""Comparison plot of the different corrections.
Inputs
------
params: parameters
gain: float, default None, DSSC gain in ph/bin
Returns
-------
matplotlib figure
"""
# load data
assert params.arr is not None, "Data not loaded"
assert params.arr_dark is not None, "Data not loaded"
# we only need few rois
fitrois = {}
for k in ['n', '0', 'p', 'sat']:
fitrois[k] = params.rois[k]
# flat flat_field
plane_ff = params.get_flat_field()
if plane_ff is None:
plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0]
ff = compute_flat_field_correction(params.rois, params)
# non linearities
Fnl = params.get_Fnl()
if Fnl is None:
Fnl = np.arange(2**9)
xp = np if not params._using_gpu else cp
# compute all levels of correction
data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level,
params._using_gpu)
data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
# for conversion to nb of photons
if gain is None:
g = 1
else:
g = gain
scale = 1e-6
f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True)
# nbins = np.linspace(0.01, 1.0, 100)
photon_scale = None
for k, d in enumerate([data, data_ff, data_ff_nl]):
for l, (n, r) in enumerate([('n', '0'), ('p', '0'), ('n', 'p')]):
if photon_scale is None:
lower = 0
upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
photon_scale = np.linspace(lower, upper, 150)
good_d = d.where(d['sat_sat'] == False, drop=True)
sat_d = d.where(d['sat_sat'], drop=True)
snr_v = snr(good_d[n].values.flatten(),
good_d[r].values.flatten(), verbose=True)
m = snr_v['direct']['mu']
h, xedges, yedges, img = axs[l, k].hist2d(
g*scale*good_d[r].values.flatten(),
good_d[n].values.flatten()/good_d[r].values.flatten(),
[photon_scale, np.linspace(0.95, 1.05, 150)*m],
cmap='Blues',
norm=LogNorm(vmin=0.2, vmax=200),
# alpha=0.5 # make the plot looks ugly with lots of white lines
)
h, xedges, yedges, img2 = axs[l, k].hist2d(
g*scale*sat_d[r].values.flatten(),
sat_d[n].values.flatten()/sat_d[r].values.flatten(),
[photon_scale, np.linspace(0.95, 1.05, 150)*m],
cmap='Reds',
norm=LogNorm(vmin=0.2, vmax=200),
# alpha=0.5 # make the plot looks ugly with lots of white lines
)
v = snr_v['direct']['mu']/snr_v['direct']['s']
axs[l, k].text(0.4, 0.15, f'SNR: {v:.0f}',
transform = axs[l, k].transAxes)
v = snr_v['weighted']['mu']/snr_v['weighted']['s']
axs[l, k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
transform = axs[l, k].transAxes)
#axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
#axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
axs[l, k].set_ylim([0.95*m, 1.05*m])
for k in range(3):
#for l in range(3):
# axs[l, k].set_ylim([0.95, 1.05])
if gain:
axs[2, k].set_xlabel('photons (10$^6$)')
else:
axs[2, k].set_xlabel('ADU (10$^6$)')
f.colorbar(img, ax=axs, label='events')
axs[0, 0].set_title('raw')
axs[0, 1].set_title('flat-field')
axs[0, 2].set_title('non-linear')
axs[0, 0].set_ylabel(r'-1$^\mathrm{st}$/0$^\mathrm{th}$ order')
axs[1, 0].set_ylabel(r'1$^\mathrm{st}$/0$^\mathrm{th}$ order')
axs[2, 0].set_ylabel(r'-1$^\mathrm{st}$/1$^\mathrm{th}$ order')
return f
def inspect_correction_sk(params, ff, gain=None):
"""Comparison plot of the different corrections, combining 'n' and 'p'.
Inputs
------
params: parameters
gain: float, default None, DSSC gain in ph/bin
Returns
-------
matplotlib figure
"""
# load data
assert params.arr is not None, "Data not loaded"
assert params.arr_dark is not None, "Data not loaded"
# we only need few rois
fitrois = {}
for k in ['n', '0', 'p', 'sat']:
fitrois[k] = params.rois[k]
# flat flat_field
#plane_ff = params.get_flat_field()
#if plane_ff is None:
# plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0]
#ff = compute_flat_field_correction(params.rois, params)
# non linearities
Fnl = params.get_Fnl()
if Fnl is None:
Fnl = np.arange(2**9)
xp = np if not params._using_gpu else cp
# compute all levels of correction
data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level,
params._using_gpu)
data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
# for conversion to nb of photons
if gain is None:
g = 1
else:
g = gain
scale = 1e-6
f, axs = plt.subplots(1, 3, figsize=(8, 2.5), sharex=True)
# nbins = np.linspace(0.01, 1.0, 100)
photon_scale = None
for k, d in enumerate([data, data_ff, data_ff_nl]):
if photon_scale is None:
lower = 0
upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
photon_scale = np.linspace(lower, upper, 150)
good_d = d.where(d['sat_sat'] == False, drop=True)
sat_d = d.where(d['sat_sat'], drop=True)
snr_v = snr(good_d['np_mean_sk'].values.flatten(),
good_d['0'].values.flatten(), verbose=True)
m = snr_v['direct']['mu']
h, xedges, yedges, img = axs[k].hist2d(
g*scale*good_d['0'].values.flatten(),
good_d['np_mean_sk'].values.flatten()/good_d['0'].values.flatten(),
[photon_scale, np.linspace(0.95, 1.05, 150)*m],
cmap='Blues',
norm=LogNorm(vmin=0.2, vmax=200),
# alpha=0.5 # make the plot looks ugly with lots of white lines
)
h, xedges, yedges, img2 = axs[k].hist2d(
g*scale*sat_d['0'].values.flatten(),
sat_d['np_mean_sk'].values.flatten()/sat_d['0'].values.flatten(),
[photon_scale, np.linspace(0.95, 1.05, 150)*m],
cmap='Reds',
norm=LogNorm(vmin=0.2, vmax=200),
# alpha=0.5 # make the plot looks ugly with lots of white lines
)
v = snr_v['direct']['mu']/snr_v['direct']['s']
axs[k].text(0.4, 0.15, f'SNR: {v:.0f}',
transform = axs[k].transAxes)
v = snr_v['weighted']['mu']/snr_v['weighted']['s']
axs[k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
transform = axs[k].transAxes)
# axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
# axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
axs[k].set_ylim([0.95*m, 1.05*m])
for k in range(3):
#for l in range(3):
# axs[l, k].set_ylim([0.95, 1.05])
if gain:
axs[k].set_xlabel('photons (10$^6$)')
else:
axs[k].set_xlabel('ADU (10$^6$)')
f.colorbar(img, ax=axs, label='events')
axs[0].set_title('raw')
axs[1].set_title('flat-field')
axs[2].set_title('non-linear')
axs[0].set_ylabel(r'np_mean/0')
return f
# data processing related functions
def load_dssc_module(proposalNB, runNB, moduleNB=15,
subset=slice(None), drop_intra_darks=True,
persist=False, data_size_Gb=None):
"""Load single module dssc data as dask array.
Inputs
------
proposalNB: proposal number
runNB: run number
moduleNB: default 15, module number
subset: default slice(None), subset of trains to load
drop_intra_darks: boolean, default True, remove intra darks from the data
persist: default False, load all data persistently in memory
data_size_Gb: float, if persist is True, can optionaly restrict
the amount of data loaded for dark data and run data in Gb
Returns
-------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
tid: array of train id number
"""
run = open_run(proposal=proposalNB, run=runNB)
# DSSC
source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
key = 'image.data'
arr = run[source, key][subset].dask_array()
# fix 256 value becoming spuriously 0 instead
arr[arr == 0] = 256
ppt = run[source, key][subset].data_counts()
# ignore train with no pulses, can happen in burst mode acquisition
ppt = ppt[ppt > 0]
tid = ppt.index.values
ppt = np.unique(ppt)
assert ppt.shape[0] == 1, "number of pulses changed during the run"
ppt = ppt[0]
# reshape in trainId, pulseId, 2d-image
arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])
# drop intra darks
if drop_intra_darks:
arr = arr[:, ::2, :, :]
# load data in memory
if persist:
if data_size_Gb is not None:
# keep only xGb of data
N = int(data_size_Gb*1024**3/(arr.shape[1]*128*512*2))
SLICE = slice(0, N)
arr = arr[SLICE]
tid = tid[SLICE]
arr = arr.persist()
return arr, tid
def average_module(arr, dark=None, ret='mean',
mask=None, sat_roi=None, sat_level=300, F_INL=None):
"""Compute the average or std over a module.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
dark: default None, dark to be substracted
ret: string, either 'mean' to compute the mean or 'std' to compute the
standard deviation
mask: default None, mask of bad pixels to ignore
sat_roi: roi over which to check for pixel with values larger than
sat_level to drop the image from the average or std
sat_level: int, minimum pixel value for a pixel to be considered saturated
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
Returns
-------
average or standard deviation image
"""
# F_INL
if F_INL is not None:
narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
else:
narr = arr
if mask is not None:
narr = narr*mask
if sat_roi is not None:
not_sat = da.repeat(
da.repeat(
da.all(
narr[
:,
:,
sat_roi["yl"] : sat_roi["yh"],
sat_roi["xl"] : sat_roi["xh"],
]
< sat_level,
axis=[2, 3],
keepdims=True,
),
128,
axis=2,
),
512,
axis=3,
)
if dark is not None:
narr = narr - dark
if ret == 'mean':
if sat_roi is not None:
return da.average(narr, axis=0, weights=not_sat)
else:
return narr.mean(axis=0)
elif ret == 'std':
return narr.std(axis=0)
else:
raise ValueError(f'ret={ret} not supported')
def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
flat_field=None, F_INL=None, use_gpu=False):
"""Process one module and extract roi intensity.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
tid: array of train id number
dark: pulse resolved dark image to remove
rois: dictionnary of rois
mask: default None, mask of ignored pixels
sat_level: integer, default 511, at which level pixel begin to saturate
flat_field: default None, flat-field correction
F_INL: default None, non-linear correction function given as a
lookup table with 9 bits integer input
Returns
-------
dataset of extracted pulse and train resolved roi intensities.
"""
# F_INL
if F_INL is not None:
narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
else:
narr = arr
# apply mask
if mask is not None:
narr = narr*mask
# crop rois
r = {}
rd = {}
for n in rois.keys():
r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
# find saturated shots
r_sat = {}
for n in rois.keys():
r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3))
# TODO: flat-field should not be applied on intra darks
# # change flat-field dimension to match data
# if flat_field is not None:
# temp = np.ones_like(dark)
# temp[::2, :, :] = flat_field[:, :]
# flat_field = temp
if use_gpu and flat_field is not None:
flat_field = cp.asarray(flat_field)
# compute dark corrected ROI values
v = {}
r_ff = {}
ff = {}
for n in rois.keys():
r[n] = r[n] - rd[n]
if flat_field is not None:
# TODO: flat-field should not be applied on intra darks
# ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
# rois[n]['xl']:rois[n]['xh']]
ff[n] = flat_field[rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
r_ff[n] = r[n]/ff[n]
else:
ff[n] = 1.0
r_ff[n] = r[n]
v[n] = r_ff[n].sum(axis=(2, 3))
# np_mean roi where we normalize the sum of flat_field
np_mean = (r['n'] + r['p'])/(ff['n'] + ff['p'])
v['np_mean_sk'] = np_mean.sum(axis=(2,3))
res = xr.Dataset()
dims = ['trainId', 'pulseId']
r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
for n in rois.keys():
res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims)
res[n + '_sat'] = xr.DataArray(ensure_on_host(r_sat[n][:, :]),
coords=r_coords, dims=dims)
res['np_mean_sk'] = xr.DataArray(ensure_on_host(v['np_mean_sk']),
coords=r_coords, dims=dims)
res['np_mean_sk_sat'] = res['n_sat'] + res['p_sat']
for n in rois.keys():
roi = rois[n]
res[n + '_area'] = xr.DataArray(np.array([
(roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))
res['np_mean_area'] = res['n_area'] + res['p_area']
return res
def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511,
use_gpu=False):
"""Process dark and run data with corrections.
Inputs
------
Fmodel: correction lookup table
arr_dark: dark data
arr: data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask of good pixels
flat_field: zone plate flat-field correction
sat_level: integer, default 511, at which level pixel begin to saturate
Returns
-------
roi extracted intensities
"""
# dark process
dark = average_module(arr_dark, F_INL=Fmodel).compute()
# data process
return process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
flat_field=flat_field, F_INL=Fmodel, use_gpu=use_gpu).compute()
def inspect_saturation(data, gain, Nbins=200):
"""Plot roi integrated histogram of the data with saturation
Inputs
------
data: xarray of roi integrated DSSC data
gain: nominal DSSC gain in ph/bin
Nbins: number of bins for the histogram, by default 200
Returns
-------
f: handle to the matplotlib figure
h: xarray of the histogram data
"""
d = data.where(data['sat_sat'] == False, drop=True)
s = data.where(data['sat_sat'] == True, drop=True)
# percentage of saturated shots
N_nonsat = d['n'].count()
N_all = data.dims['trainId'] * data.dims['pulseId']
sat_percent = ((N_all - N_nonsat)/N_all).values*100.0
# find the bin ranges
sum_v = {}
low = 0
high = 0
scale = 1e-6
for k in ['n', '0', 'p']:
v = data[k].values.ravel()*scale*gain
sum_v[k] = np.nansum(v)
v_low, v_high = np.nanmin(v), np.nanmax(v)
if v_low < low:
low = v_low
if v_high > high:
high = v_high
# compute bins edges, center and width
bins = np.linspace(low, high, Nbins+1)
bins_c = 0.5*(bins[:-1] + bins[1:])
w = bins[1] - bins[0]
fig, ax = plt.subplots(figsize=(6,4))
h = {}
for kk, k in enumerate(['n', '0', 'p']):
v_d = d[k].values.ravel()*scale*gain
v_s = s[k].values.ravel()*scale*gain
h[k+'_nosat'], bin_e = np.histogram(v_d, bins)
h[k+'_sat'], bin_e = np.histogram(v_s, bins)
# compute density normalization on all data
norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat']))
ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm,
h[k+'_nosat']/norm, facecolor=f"C{kk}",
edgecolor='none', alpha=0.2)
ax.plot(bins_c, h[k+'_nosat']/norm, label=k,
c=f'C{kk}', alpha=0.4)
ax.text(0.6, 0.9, f"saturation: {sat_percent:.2f}%",
color='r', alpha=0.5, transform=plt.gca().transAxes)
ax.legend()
ax.set_xlabel(r'10$^6$ ph')
ax.set_ylabel('density')
# save data as xarray dataset
dv = {}
for k in h.keys():
dv[k] = {"dims": "N", "data": h[k]}
ds = {
"coords": {"N": {"dims": "N", "data": bins_c,
"attrs": {"units": f"{scale:g} ph"}}},
"attrs": {"saturation (%)": sat_percent},
"dims": "N",
"data_vars": dv}
return fig, xr.Dataset.from_dict(ds)
""" Toolbox for SCS.
Various utilities function to quickly process data measured
at the SCS instrument.
Copyright (2019-) SCS Team.
"""
import matplotlib.pyplot as plt
import numpy as np
import re
from toolbox_scs.base.knife_edge import knife_edge_base, erfc, arrays_to1d
__all__ = [
'knife_edge'
]
def knife_edge(ds, axisKey='scannerX', signalKey='FastADC4peaks',
axisRange=None, p0=None, full=False, plot=False,
display=False):
"""
Calculates the beam radius at 1/e^2 from a knife-edge scan by
fitting with erfc function:
f(x, x0, w0, a, b) = a*erfc(np.sqrt(2)*(x-x0)/w0) + b
with w0 the beam radius at 1/e^2 and x0 the beam center.
Parameters
----------
ds: xarray Dataset
dataset containing the detector signal and the motor position.
axisKey: str
key of the axis against which the knife-edge is performed.
signalKey: str
key of the detector signal.
axisRange: list of floats
edges of the scanning axis between which to apply the fit.
p0: list of floats, numpy 1D array
initial parameters used for the fit: x0, w0, a, b. If None, a beam
radius of 100 micrometers is assumed.
full: bool
If False, returns the beam radius and standard error.
If True, returns the popt, pcov list of parameters and covariance
matrix from scipy.optimize.curve_fit.
plot: bool
If True, plots the data and the result of the fit. Default is False.
display: bool
If True, displays info on the fit. True when plot is True, default is
False.
Returns
-------
If full is False, tuple with beam radius at 1/e^2 in mm and standard
error from the fit in mm. If full is True, returns parameters and
covariance matrix from scipy.optimize.curve_fit function.
"""
popt, pcov = knife_edge_base(ds[axisKey].values, ds[signalKey].values,
axisRange=axisRange, p0=p0)
if plot:
positions, intensities = arrays_to1d(ds[axisKey].values,
ds[signalKey].values)
title = ''
if ds.attrs.get('proposal') and ds.attrs.get('runNB'):
proposalNB = int(re.findall(r'p(\d{6})',
ds.attrs['proposal'])[0])
runNB = ds.attrs['runNB']
title = f'run {runNB} p{proposalNB}'
plot_knife_edge(positions, intensities, popt, pcov[1, 1]**0.5,
title, axisKey, signalKey)
display = True
if display:
funcStr = 'a*erfc(np.sqrt(2)*(x-x0)/w0) + b'
print('fitting function:', funcStr)
print('w0 = (%.1f +/- %.1f) um' % (np.abs(popt[1])*1e3,
pcov[1, 1]**0.5*1e3))
print('x0 = (%.3f +/- %.3f) mm' % (popt[0], pcov[0, 0]**0.5))
print('a = %e +/- %e ' % (popt[2], pcov[2, 2]**0.5))
print('b = %e +/- %e ' % (popt[3], pcov[3, 3]**0.5))
if full:
return popt, pcov
else:
return np.abs(popt[1]), pcov[1, 1]**0.5
def plot_knife_edge(positions, intensities, fit_params, rel_err, title,
axisKey, signalKey):
plt.figure(figsize=(7, 4))
plt.scatter(positions, intensities, color='C1',
label='measured', s=2, alpha=0.1)
xfit = np.linspace(positions.min(), positions.max(), 1000)
yfit = erfc(xfit, *fit_params)
plt.plot(xfit, yfit, color='C4',
label=r'fit $\rightarrow$ $w_0=$(%.1f $\pm$ %.1f) $\mu$m' % (
np.abs(fit_params[1])*1e3, rel_err*1e3))
leg = plt.legend()
for lh in leg.legendHandles:
lh.set_alpha(1)
plt.ylabel(signalKey)
plt.xlabel(axisKey + ' position [mm]')
plt.title(title)
============
Test modules
============
Comments
========
The code below is intended to be executed from the command line in the directory:
toolbox\_scs/test.
The test suites directly import the toolbox\_scs/ package. The idea is that problems related to packaging come up immediately (changing folder structure, breaking relative dependencies between the subpackages ....).
Requirements to run the code are:
* loaded exfel-python environment
* local installation of toolbox\_scs using pip (pip install --user .)
*Comment*: During development, use the -e flag when installing via pip, such that changes become effective immediately.
Usage
=====
* **Help message**
.. code:: bash
python test_top_level --help
.. parsed-literal::
usage: test_top_level.py [-h] [--list-suites] [--run-suites S [S ...]]
optional arguments:
-h, --help show this help message and exit
--list-suites list possible test suites
--run-suites S [S ...]
a list of valid test suites
* **List available test suites**
.. code:: bash
python test_top_level --list-suites
.. parsed-literal::
Possible test suites:
-------------------------
packaging
load
-------------------------
* **Run selected test suites**
.. code:: bash
python3 test_top_level --run-suites packaging
.. parsed-literal::
test_constant (__main__.TestToolbox) ... INFO:extra_data.read_machinery:Found proposal dir '/gpfs/exfel/exp/SCS/201901/p002212' in 0.055 s
DEBUG:extra_data.run_files_map:Loaded cached files map from /gpfs/exfel/exp/SCS/201901/p002212/scratch/.karabo_data_maps/raw_r0235.json
DEBUG:extra_data.run_files_map:Loaded cached files map in 0.29 s
DEBUG:extra_data.reader:Opened run with 313 files in 0.22 s
ok
----------------------------------------------------------------------
Ran 1 test in 0.616s
OK
#!/bin/bash
python test_top_level.py --run-suite packaging load
python test_misc.py --run-suite bunch-pattern-decoding
python test_utils.py --run-suite ed-extensions
python test_dssc_cls.py --run-suite no-processing
\ No newline at end of file
import unittest
import logging
import os
import argparse
import shutil
from time import strftime
import numpy as np
import xarray as xr
import extra_data as ed
import toolbox_scs as tb
import toolbox_scs.detectors as tbdet
logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
suites = {"no-processing": (
"test_create",
"test_use_xgm_tim",
),
"processing": (
"test_processing_quick",
#"test_normalization_all",
)
}
_temp_dirs = ['tmp']
def setup_tmp_dir():
for d in _temp_dirs:
if not os.path.isdir(d):
os.mkdir(d)
def cleanup_tmp_dir():
for d in _temp_dirs:
shutil.rmtree(d, ignore_errors=True)
log_root.info(f'remove {d}')
class TestDSSC(unittest.TestCase):
@classmethod
def setUpClass(cls):
log_root.info("Start global setup")
setup_tmp_dir()
log_root.info("Finished global setup, start tests")
@classmethod
def tearDownClass(cls):
log_root.info("Clean up test environment....")
cleanup_tmp_dir()
def test_create(self):
proposal_nb = 2212
run_nb = 235
run = tb.load_run(proposal_nb, run_nb, include='*DA*')
run_info = tbdet.load_dssc_info(proposal_nb, run_nb)
bins_trainId = tb.get_array(run,
'PP800_PhaseShifter',
0.04)
bins_pulse = ['pumped', 'unpumped'] * 10
binner1 = tbdet.create_dssc_bins("trainId",
run_info['trainIds'],
bins_trainId.values)
binner2 = tbdet.create_dssc_bins("pulse",
np.linspace(0,19,20, dtype=int),
bins_pulse)
binners = {'trainId': binner1, 'pulse': binner2}
params = {'binners': binners}
# normal
run235 = tbdet.DSSCBinner(proposal_nb, run_nb)
del(run235)
run235 = tbdet.DSSCBinner(2212, 235, dssc_coords_stride=1)
run235.add_binner('trainId', binner1)
run235.add_binner('pulse', binner2)
xgm_threshold=(300.0, 8000.0)
run235.create_pulsemask('xgm', xgm_threshold)
self.assertIsNotNone(run235.get_xgm_binned())
self.assertEqual(run235.binners['trainId'].values[0],
np.float32(7585.52))
# expected fails
with self.assertRaises(FileNotFoundError) as cm:
run235 = tbdet.DSSCBinner(2212, 2354)
err_msg = "[Errno 2] No such file or directory: " \
"'/gpfs/exfel/exp/SCS/201901/p002212/raw/r2354'"
self.assertEqual(str(cm.exception), err_msg)
def test_use_xgm_tim(self):
proposal_nb = 2599
run_nb = 103
run_info = tbdet.load_dssc_info(proposal_nb, run_nb)
fpt = run_info['frames_per_train']
n_trains = run_info['number_of_trains']
trainIds = run_info['trainIds']
buckets_train = ['chunk1']*n_trains
buckets_pulse = ['image_unpumped', 'dark',
'image_pumped', 'dark']
binner1 = tbdet.create_dssc_bins("trainId",trainIds,buckets_train)
binner2 = tbdet.create_dssc_bins("pulse",
np.linspace(0,fpt-1,fpt, dtype=int),
buckets_pulse)
binners = {'trainId': binner1, 'pulse': binner2}
dssc_frame_coords = np.linspace(0,2,2, dtype=np.uint64)
bin_obj = tbdet.DSSCBinner(proposal_nb, run_nb,
binners=binners,
dssc_coords_stride=dssc_frame_coords)
bin_obj.load_xgm()
bin_obj.load_tim()
xgm_binned = bin_obj.get_xgm_binned()
tim_binned = bin_obj.get_tim_binned()
self.assertIsNotNone(xgm_binned)
self.assertIsNotNone(tim_binned)
def test_processing_quick(self):
proposal_nb = 2530
module_list=[2]
run_nb = 49
run_info = tbdet.load_dssc_info(proposal_nb, run_nb)
fpt = run_info['frames_per_train']
n_trains = run_info['number_of_trains']
trainIds = run_info['trainIds']
buckets_train = np.zeros(n_trains)
buckets_pulse = ['image', 'dark'] * 99 + ['image_last']
binner1 = tbdet.create_dssc_bins("trainId",
trainIds,
buckets_train)
binner2 = tbdet.create_dssc_bins("pulse",
np.linspace(0,fpt-1,fpt, dtype=int),
buckets_pulse)
binners = {'trainId': binner1, 'pulse': binner2}
bin_obj = tbdet.DSSCBinner(proposal_nb, run_nb, binners=binners)
bin_obj.process_data(
modules=module_list, filepath='./tmp/', chunksize=248)
filename = f'./tmp/run_{run_nb}_module{module_list[0]}.h5'
self.assertTrue(os.path.isfile(filename))
run_formatted = tbdet.DSSCFormatter('./tmp/')
run_formatted.combine_files()
attrs = {'run_type':'useful description',
'comment':'blabla',
'run_number':run_nb}
run_formatted.add_attributes(attrs)
run_formatted.save_formatted_data(
f'./tmp/run_{run_nb}_formatted.h5')
data = tbdet.load_xarray(f'./tmp/run_{run_nb}_formatted.h5')
self.assertIsNotNone(data)
def test_normalization_all(self):
proposal_nb = 2530
module_list=[2]
# dark
run_nb = 49
run_info = tbdet.load_dssc_info(proposal_nb, run_nb)
fpt = run_info['frames_per_train']
n_trains = run_info['number_of_trains']
trainIds = run_info['trainIds']
buckets_train = np.zeros(n_trains)
binner1 = tbdet.create_dssc_bins("trainId",
trainIds,
buckets_train)
binner2 = tbdet.create_dssc_bins("pulse",
np.linspace(0,fpt-1,fpt, dtype=int),
np.linspace(0,fpt-1,fpt, dtype=int))
binners = {'trainId': binner1, 'pulse': binner2}
bin_obj = tbdet.DSSCBinner(proposal_nb, run_nb, binners=binners)
bin_obj.process_data(
modules=module_list, filepath='./tmp/', chunksize=248)
filename = f'./tmp/run_{run_nb}_module{module_list[0]}.h5'
self.assertTrue(os.path.isfile(filename))
run_formatted = tbdet.DSSCFormatter('./tmp/')
run_formatted.combine_files()
attrs = {'run_type':'useful description',
'comment':'blabla',
'run_number':run_nb}
run_formatted.add_attributes(attrs)
run_formatted.save_formatted_data(
f'./tmp/run_{run_nb}_formatted.h5')
# main run
run_nb = 50
run_info = tbdet.load_dssc_info(proposal_nb, run_nb)
fpt = run_info['frames_per_train']
n_trains = run_info['number_of_trains']
trainIds = run_info['trainIds']
buckets_train = np.zeros(n_trains)
buckets_pulse = ['image', 'dark'] * 99 + ['image_last']
binner1 = tbdet.create_dssc_bins("trainId",
trainIds,
buckets_train)
binner2 = tbdet.create_dssc_bins("pulse",
np.linspace(0,fpt-1,fpt, dtype=int),
buckets_pulse)
binners = {'trainId': binner1, 'pulse': binner2}
bin_obj = tbdet.DSSCBinner(proposal_nb, run_nb, binners=binners)
dark = tbdet.load_xarray('./tmp/run_49_formatted.h5')
bin_params = {'modules':module_list,
'chunksize':248,
'filepath':'./tmp/',
'xgm_normalization':True,
'normevery':2,
'dark_image':dark['data']
}
bin_obj.process_data(**bin_params)
filename = f'./tmp/run_{run_nb}_module{module_list[0]}.h5'
self.assertTrue(os.path.isfile(filename))
def list_suites():
print("\nPossible test suites:\n" + "-" * 79)
for key in suites:
print(key)
print("-" * 79 + "\n")
def suite(*tests):
suite = unittest.TestSuite()
for test in tests:
suite.addTest(TestDSSC(test))
return suite
def main(*cliargs):
try:
for test_suite in cliargs:
if test_suite in suites:
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite(*suites[test_suite]))
else:
log_root.warning(
"Unknown suite: '{}'".format(test_suite))
pass
except Exception as err:
log_root.error("Unecpected error: {}".format(err),
exc_info=True)
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--list-suites',
action='store_true',
help='list possible test suites')
parser.add_argument('--run-suites', metavar='S',
nargs='+', action='store',
help='a list of valid test suites')
args = parser.parse_args()
if args.list_suites:
list_suites()
if args.run_suites:
main(*args.run_suites)
import unittest
import numpy as np
from numpy.testing import assert_array_equal
import xarray as xa
from .. import hRIXS
class TestHRIXS(unittest.TestCase):
def test_integration(self):
data = xa.Dataset()
img = np.arange(100 * 200 * 2)
img.shape = (2, 100, 200)
data['hRIXS_det'] = (('trainId', 'x', 'y'), img)
h = hRIXS()
h.CURVE_A = 0.1
h.CURVE_B = 0.01
h.Y_RANGE = slice(30, 170)
r = h.integrate(data)
self.assertIs(r, data)
self.assertEqual(data['spectrum'][1, 50].values[()],
28517.704705882363)
self.assertEqual(data['spectrum'][1, 50].coords['energy'], 90)
h.dark_image = xa.DataArray(np.ones((100, 200)), dims=('x', 'y'))
h.USE_DARK = True
h.integrate(data)
self.assertEqual(data['spectrum'][1, 50].values[()],
28516.704705882363)
def test_centroid(self):
data = xa.Dataset()
img = np.array([
[[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 1, 1, 0, 0, 0,],
[0, 0, 1, 1, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],],
[[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 1, 1, 2, 0, 0,],
[0, 0, 1, 7, 2, 0, 0,],
[0, 0, 1, 1, 2, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],],
])
data['hRIXS_det'] = (('trainId', 'x', 'y'), img)
h = hRIXS()
h.Y_RANGE = slice(0, 7)
h.THRESHOLD = 0.5
h.BINS = 10
data = h.centroid(data)
assert_array_equal(data['spectrum'], [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
])
h.CURVE_A = 0.1
h.CURVE_B = 0.01
r = h.centroid(data)
self.assertIs(r, data)
assert_array_equal(data['spectrum'], [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
])
def test_getparam(self):
# this is just a smoke test
h = hRIXS()
d = h.get_params()
self.assertEqual(d['bins'], 100)
if __name__ == "__main__":
unittest.main()
import unittest
import logging
import os
import sys
import argparse
import toolbox_scs as tb
import toolbox_scs.misc as tbm
from toolbox_scs.util.exceptions import ToolBoxPathError
# -----------------------------------------------------------------------------
# global test settings
# -----------------------------------------------------------------------------
proposalNB = 2511
runNB = 176
# -----------------------------------------------------------------------------
suites = {"bunch-pattern-decoding": (
"test_isppl",
"test_issase1",
"test_issase3",
"test_extractBunchPattern",
"test_pulsePatternInfo",
)
}
class TestDataAccess(unittest.TestCase):
@classmethod
def setUpClass(cls):
run = tb.load_run(proposalNB, runNB)
mnemonic = tb.mnemonics["bunchPatternTable"]
cls._bpt = run.get_array(*mnemonic.values())
@classmethod
def tearDownClass(cls):
pass
def setUp(self):
pass
def tearDown(self):
pass
def test_isppl(self):
cls = self.__class__
bpt_decoded = tbm.is_ppl(cls._bpt)
self.assertEqual(bpt_decoded.values[0][0],1)
def test_issase1(self):
cls = self.__class__
bpt_decoded = tbm.is_sase_3(cls._bpt)
self.assertEqual(bpt_decoded.values[0][0],0)
def test_issase3(self):
cls = self.__class__
bpt_decoded = tbm.is_sase_3(cls._bpt)
self.assertEqual(bpt_decoded.values[0][0],0)
def test_extractBunchPattern(self):
cls = self.__class__
bpt_decoded = tbm.extractBunchPattern(cls._bpt,
'scs_ppl')
self.assertIsNotNone(bpt_decoded)
self.assertEqual(bpt_decoded[0].values[0,1],80)
def test_pulsePatternInfo(self):
pass
def list_suites():
print("\nPossible test suites:\n" + "-" * 79)
for key in suites:
print(key)
print("-" * 79 + "\n")
def suite(*tests):
suite = unittest.TestSuite()
for test in tests:
suite.addTest(TestDataAccess(test))
return suite
def start_tests(*cliargs):
logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
try:
for test_suite in cliargs:
if test_suite in suites:
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite(*suites[test_suite]))
else:
log_root.warning(
"Unknown suite: '{}'".format(test_suite))
pass
except Exception as err:
log_root.error("Unecpected error: {}".format(err),
exc_info=True)
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--list-suites',
action='store_true',
help='list possible test suites')
parser.add_argument('--run-suites', metavar='S',
nargs='+', action='store',
help='a list of valid test suites')
args = parser.parse_args()
if args.list_suites:
list_suites()
if args.run_suites:
start_tests(*args.run_suites)
import unittest
import logging
import os
import sys
import argparse
import toolbox_scs as tb
from toolbox_scs.util.exceptions import *
import extra_data as ed
logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
suites = {"packaging": (
"test_constant",
),
"load": (
"test_load",
"test_openrun",
"test_openrunpath",
"test_loadbinnedarray",
)
}
class TestToolbox(unittest.TestCase):
@classmethod
def setUpClass(cls):
log_root.info("Start global setup")
cls._mnentry = 'SCS_RR_UTC/TSYS/TIMESERVER'
cls._ed_run = ed.open_run(2212, 235)
log_root.info("Finished global setup, start tests")
@classmethod
def tearDownClass(cls):
pass
def setUp(self):
pass
def tearDown(self):
pass
def test_constant(self):
cls = self.__class__
self.assertEqual(tb.mnemonics['bunchPatternTable']['source'],cls._mnentry)
def test_load(self):
fields = ["SCS_XGM"]
# normal behavior
run_tb = None
proposalNB = 2511
runNB = 176
run_tb, data = tb.load(proposalNB, runNB, fields)
self.assertEqual(data['bunchPatternTable'].values[0, 0], 2113321)
# exception raised
run_tb = None
proposalNB = 2511
runNB = 1766
with self.assertRaises(ToolBoxPathError) as cm:
run_tb, data = tb.load(proposalNB, runNB, fields)
tb_exception = cm.exception
constr_path = f'/gpfs/exfel/exp/SCS/202001/p002511/raw/r{runNB}'
exp_msg = f"Invalid path: {constr_path}. " + \
"The constructed path does not exist."
self.assertEqual(tb_exception.message, exp_msg)
def test_openrun(self):
run, _ = tb.load(2212, 235)
src = 'SCS_DET_DSSC1M-1/DET/0CH0:xtdf'
self.assertTrue(src in run.all_sources)
def test_openrunpath(self):
run = tb.run_by_path(
"/gpfs/exfel/exp/SCS/201901/p002212/raw/r0235")
src = 'SCS_DET_DSSC1M-1/DET/0CH0:xtdf'
self.assertTrue(src in run.all_sources)
def test_loadbinnedarray(self):
cls = self.__class__
# Normal use
mnemonic = 'PP800_PhaseShifter'
data = tb.get_array(cls._ed_run, mnemonic, 0.5)
self.assertTrue = (data)
# unknown mnemonic
mnemonic = 'blabla'
with self.assertRaises(ToolBoxValueError) as cm:
scan_variable = tb.get_array(cls._ed_run, mnemonic, 0.5)
excp = cm.exception
self.assertEqual(excp.value, mnemonic)
def list_suites():
print("\nPossible test suites:\n" + "-" * 79)
for key in suites:
print(key)
print("-" * 79 + "\n")
def suite(*tests):
suite = unittest.TestSuite()
for test in tests:
suite.addTest(TestToolbox(test))
return suite
def main(*cliargs):
try:
for test_suite in cliargs:
if test_suite in suites:
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite(*suites[test_suite]))
else:
log_root.warning(
"Unknown suite: '{}'".format(test_suite))
pass
except Exception as err:
log_root.error("Unecpected error: {}".format(err),
exc_info=True)
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--list-suites',
action='store_true',
help='list possible test suites')
parser.add_argument('--run-suites', metavar='S',
nargs='+', action='store',
help='a list of valid test suites')
args = parser.parse_args()
if args.list_suites:
list_suites()
if args.run_suites:
main(*args.run_suites)
import unittest
import logging
import os
import sys
import argparse
from toolbox_scs.util.data_access import (
find_run_dir,
)
from toolbox_scs.util.exceptions import ToolBoxPathError
suites = {"ed-extensions": (
"test_rundir1",
"test_rundir2",
"test_rundir3",
)
}
def list_suites():
print("""\nPossible test suites:\n-------------------------""")
for key in suites:
print(key)
print("-------------------------\n")
class TestDataAccess(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@classmethod
def tearDownClass(cls):
pass
def setUp(self):
pass
def tearDown(self):
pass
def test_rundir1(self):
Proposal = 2212
Run = 235
Dir = find_run_dir(Proposal, Run)
self.assertEqual(Dir,
"/gpfs/exfel/exp/SCS/201901/p002212/raw/r0235")
def test_rundir2(self):
Proposal = 23678
Run = 235
with self.assertRaises(Exception) as cm:
find_run_dir(Proposal, Run)
exp = cm.exception
self.assertEqual(str(exp), "Couldn't find proposal dir for 'p023678'")
def test_rundir3(self):
Proposal = 2212
Run = 800
with self.assertRaises(ToolBoxPathError) as cm:
find_run_dir(Proposal, Run)
exp_msg = cm.exception.message
print(exp_msg)
path = f'/gpfs/exfel/exp/SCS/201901/p00{Proposal}/raw/r0{Run}'
err_msg = f"Invalid path: {path}. " \
"The constructed path does not exist."
self.assertEqual(exp_msg, err_msg)
def suite(*tests):
suite = unittest.TestSuite()
for test in tests:
suite.addTest(TestDataAccess(test))
return suite
def main(*cliargs):
logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
try:
for test_suite in cliargs:
if test_suite in suites:
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite(*suites[test_suite]))
else:
log_root.warning(
"Unknown suite: '{}'".format(test_suite))
pass
except Exception as err:
log_root.error("Unecpected error: {}".format(err),
exc_info=True)
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--list-suites',
action='store_true',
help='list possible test suites')
parser.add_argument('--run-suites', metavar='S',
nargs='+', action='store',
help='a list of valid test suites')
args = parser.parse_args()
if args.list_suites:
list_suites()
if args.run_suites:
main(*args.run_suites)