Skip to content
Snippets Groups Projects

Improved BOZ flat field

Merged Loïc Le Guyader requested to merge boz_flat_field into master
1 file
+ 46
57
Compare changes
  • Side-by-side
  • Inline
+ 786
460
@@ -14,8 +14,14 @@ import dask.array as da
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm
from matplotlib.patches import Polygon
from extra_data import open_run
from extra_geom import DSSC_1MGeometry
from toolbox_scs.routines.XAS import xas
__all__ = [
'load_dssc_module',
@@ -50,17 +56,27 @@ class parameters():
self.darkrun = darkrun
self.run = run
self.module = module
self.pixel_pos = _get_pixel_pos(self.module)
self.gain = gain
self.mask_idx = None
self.mean_th = (None, None)
self.std_th = (None, None)
self.rois = None
self.rois_th = None
self.flat_field = None
self.flat_field_prod_th = (5.0, np.PINF)
self.flat_field_ratio_th = (np.NINF, 1.2)
self.plane_guess_fit = None
self.use_hex = False
self.force_mirror = True
self.ff_alpha = None
self.ff_max_iter = None
self.Fnl = None
self.alpha = None
self.nl_alpha = None
self.sat_level = None
self.max_iter = None
self.nl_max_iter = None
# temporary data
self.arr_dark = None
@@ -76,8 +92,8 @@ class parameters():
self.module, drop_intra_darks=True, persist=True)
# make sure to rechunk the arrays
self.arr = self.arr.rechunk((100, -1, -1, -1))
self.arr_dark = self.arr_dark.rechunk((100, -1, -1, -1))
self.arr = self.arr.rechunk(('auto', -1, -1, -1))
self.arr_dark = self.arr_dark.rechunk(('auto', -1, -1, -1))
def set_mask(self, arr):
"""Set mask of bad pixels.
@@ -105,12 +121,17 @@ class parameters():
"""Get the list of bad pixel indices."""
return self.mask_idx
def set_flat_field(self, plane):
def set_flat_field(self, plane,
prod_th=None, ratio_th=None):
"""Set the flat field plane definition."""
if type(plane) is not list:
self.flat_field = plane.tolist()
else:
self.flat_field = plane
if prod_th is not None:
self.flat_field_prod_th = prod_th
if ratio_th is not None:
self.flat_field_ratio_th = ratio_th
def get_flat_field(self):
"""Get the flat field plane definition."""
@@ -155,11 +176,18 @@ class parameters():
v['rois_th'] = self.rois_th
v['flat_field'] = self.flat_field
v['flat_field_prod_th'] = self.flat_field_prod_th
v['flat_field_ratio_th'] = self.flat_field_ratio_th
v['plane_guess_fit'] = self.plane_guess_fit
v['use_hex'] = self.use_hex
v['force_mirror'] = self.force_mirror
v['ff_alpha'] = self.ff_alpha
v['ff_max_iter'] = self.ff_max_iter
v['Fnl'] = self.Fnl
v['alpha'] = self.alpha
v['nl_alpha'] = self.nl_alpha
v['sat_level'] = self.sat_level
v['max_iter'] = self.max_iter
v['nl_max_iter'] = self.nl_max_iter
fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json'
@@ -186,12 +214,17 @@ class parameters():
c.rois = v['rois']
c.rois_th = v['rois_th']
c.set_flat_field(v['flat_field'])
c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th'])
c.plane_guess_fit = v['plane_guess_fit']
c.use_hex = v['use_hex']
c.force_mirror = v['force_mirror']
c.ff_alpha = v['ff_alpha']
c.ff_max_iter = v['ff_max_iter']
c.set_Fnl(v['Fnl'])
c.alpha = v['alpha']
c.nl_alpha = v['nl_alpha']
c.sat_level = v['sat_level']
c.max_iter = v['max_iter']
c.nl_max_iter = v['nl_max_iter']
return c
@@ -208,131 +241,298 @@ class parameters():
f += f'rois threshold: {self.rois_th}\n'
f += f'rois: {self.rois}\n'
f += f'flat field: {self.flat_field}\n'
f += f'flat field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n'
f += f'plane guess fit: {self.plane_guess_fit}\n'
f += f'use hexagons: {self.use_hex}\n'
f += f'enforce mirror symmetry: {self.force_mirror}\n'
f += f'ff alpha: {self.ff_alpha}, max. iter.: {self.ff_max_iter}\n'
if self.Fnl is not None:
f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
f += f'alpha:{self.alpha}, sat. level:{self.sat_level}, '
f += f' max. iter.:{self.max_iter}'
f += f'nl alpha:{self.nl_alpha}, sat. level:{self.sat_level}, '
f += f' nl max. iter.:{self.nl_max_iter}'
else:
f += 'Fnl: None'
return f
# Hexagonal pixels related function
def _plane_flat_field(p, roi):
"""Compute the p plane over the given roi.
def _get_pixel_pos(module):
"""Compute the pixel position on hexagonal lattice of DSSC module."""
# module pixel position
dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
Given the plane parameters p, compute the plane over the roi
size.
# keeping only module 15 pixel X,Y position
return g.get_pixel_positions()[module][:, :, :2]
Parameters
----------
p: a vector of a, b, c, d plane parameter with the
plane given by ax+ by + cz + d = 0
roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
def get_roi_pixel_pos(roi, params):
"""Compute fake or real pixel position of an roi from roi center.
Returns
Inputs:
-------
the plane field given by p evaluated on the roi
extend.
roi: dictionnary
params: parameters
TODO
----
the hexagonal lattice is currently ignored.
Returns:
--------
X, Y: 1-d array of pixel position.
"""
a, b, c, d = p
if params.use_hex:
# DSSC pixel position on hexagonal lattice
X = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], 0]
Y = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], 1]
else:
nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
X = np.arange(nX)/100
Y = np.arange(nY)[:, np.newaxis]/100
# center of ROI is put to 0,0
X -= np.mean(X)
Y -= np.mean(Y)
nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
return X, Y
X = np.arange(nX)/100
Y = np.arange(nY)[:, np.newaxis]/100
Z = -(a*X + b*Y + d)/c
def _get_pixel_corners(module):
"""Compute the pixel corners of DSSC module."""
# module pixel position
dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
return Z
# corners are in z,y,x oder so we rop z, flip x & y
corners = g.to_distortion_array(allow_negative_xy=True)
corners = corners[(module*128):((module+1)*128), :, :, 1:][:, :, :, ::-1]
return corners
def _get_pixel_hexagons(module):
"""Compute DSSC pixel hexagons for plotting.
Parameters:
-----------
module: either int, for the module number or a 2-d array of corners to
get hexagons from
Returns:
--------
a 1-d list of hexagons where corners position are in mm
"""
hexes = []
if type(module) is int:
corners = _get_pixel_corners(module)
else:
corners = module
for y in range(corners.shape[0]):
for x in range(corners.shape[1]):
c = 1e3*corners[y, x, :, :] # convert to mm
hexes.append(Polygon(c))
def compute_flat_field_correction(rois, p, plot=False):
"""Compute the plane field correction on beam rois.
return hexes
def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
"""Add a colobar on a new axes so it match the plot size.
Inputs
------
rois: dictionnary of beam rois['n', '0', 'p']
p: plane vector
plot: boolean, True by default, diagnostic plot
im: image plotted
ax: axes on which the image was plotted
loc: string, default 'right', location of the colorbar
size: string, default '5%', proportion of the colobar with respect to the
plotted image
pad: float, default 0.05, pad width between plot and colorbar
"""
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig = ax.figure
divider = make_axes_locatable(ax)
cax = divider.append_axes(loc, size=size, pad=pad)
cbar = fig.colorbar(im, cax=cax)
return cbar
# dark related functions
def bad_pixel_map(params):
"""Compute the bad pixels map.
Inputs
------
params: parameters
Returns
-------
numpy 2D array of the flat field correction evaluated over one DSSC ladder
(2 sensors)
bad pixel map
"""
flat_field = np.ones((128, 512))
assert params.arr_dark is not None, "Data not loaded"
# compute mean and std
dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
dark_std = params.arr_dark.std(axis=(0, 1)).compute()
r = 'n'
flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \
_plane_flat_field(p, rois[r])
r = 'p'
flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \
np.fliplr(_plane_flat_field(p, rois[r]))
mask = np.ones_like(dark_mean)
if params.mean_th[0] is not None:
mask *= dark_mean >= params.mean_th[0]
if params.mean_th[1] is not None:
mask *= dark_mean <= params.mean_th[1]
if params.std_th[0] is not None:
mask *= dark_std >= params.std_th[0]
if params.std_th[1] is not None:
mask *= dark_std >= params.std_th[1]
if plot:
f, ax = plt.subplots(1, 1, figsize=(6, 2))
img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
f.colorbar(img, ax=[ax], label='amplitude')
ax.set_xlabel('px')
ax.set_ylabel('px')
ax.set_aspect('equal')
print(f'# bad pixel: {int(128*512-mask.sum())}')
return flat_field
return mask.astype(bool)
def nl_domain(N, low, high):
"""Create the input domain where the non-linear correction defined.
def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
"""Inspect dark run data and plot diagnostic.
Inputs
------
N: integer, number of control points or intervals
low: input values below or equal to low will not be corrected
high: input values higher or equal to high will not be corrected
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
mean_th: tuple of threshold (low, high), default (None, None), to compute
a mask of good pixels for which the mean dark value lie inside this
range
std_th: tuple of threshold (low, high), default (None, None), to compute a
mask of bad pixels for which the dark std value lie inside this
range
Returns
-------
array of 2**9 integer values with N segments
fig: matplotlib figure
"""
x = np.arange(2**9)
vx = x.copy()
eps = 1e-5
vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1)
vx[x <= low] = 0
vx[x >= high] = 0
# compute mean and std
dark_mean = arr.mean(axis=(0, 1)).compute()
dark_std = arr.std(axis=(0, 1)).compute()
return vx
fig = plt.figure(figsize=(7, 2.7))
gs = fig.add_gridspec(2, 4)
ax1 = fig.add_subplot(gs[0, 1:])
ax1.set_xticklabels([])
ax1.set_yticklabels([])
ax11 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 1:])
ax2.set_xticklabels([])
ax2.set_yticklabels([])
ax22 = fig.add_subplot(gs[1, 0])
def nl_lut(domain, dy):
"""Compute the non-linear correction.
vmin = np.percentile(dark_mean.flatten(), 2)
vmax = np.percentile(dark_mean.flatten(), 98)
im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax)
ax1.invert_yaxis()
ax1.set_aspect('equal')
cbar1 = _add_colorbar(im1, ax=ax1, size='2%')
cbar1.ax.set_ylabel('dark mean')
ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1),
range=(vmin/2, vmax*2))
if mean_th[0] is not None:
ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--')
if mean_th[1] is not None:
ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--')
ax11.set_yscale('log')
vmin = np.percentile(dark_std.flatten(), 2)
vmax = np.percentile(dark_std.flatten(), 98)
im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax)
ax2.invert_yaxis()
ax2.set_aspect('equal')
cbar2 = _add_colorbar(im2, ax=ax2, size='2%')
cbar2.ax.set_ylabel('dark std')
ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2))
if std_th[0] is not None:
ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--')
if std_th[1] is not None:
ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--')
ax22.set_yscale('log')
return fig
# histogram related functions
def histogram_module(arr, mask=None):
"""Compute a histogram of the 9 bits raw pixel values over a module.
Inputs
------
domain: input domain where dy is defined. For zero no correction is
defined. For non-zero value x, dy[x] is applied.
dy: a vector of deviation from linearity on control point homogeneously
dispersed over 9 bits.
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
mask: optional bad pixel mask
Returns
-------
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
histogram
"""
x = np.arange(2**9)
ndy = np.insert(dy, 0, 0) # add zero to dy
if mask is not None:
w = da.repeat(da.repeat(da.array(mask[None, None, :, :]),
arr.shape[1], axis=1), arr.shape[0], axis=0)
w = w.rechunk(arr.chunks)
return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute()
else:
return da.bincount(arr.ravel(), minlength=512).compute()
f = x + ndy[domain]
return f
def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
"""Compute and plot a histogram of the 9 bits raw pixel values.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y)
mask: optional bad pixel mask
extra_lines: boolean, default False, plot extra lines at period values
Returns
-------
(h, hd): histogram of arr, arr_dark
figure
"""
from matplotlib.ticker import MultipleLocator
f = plt.figure(figsize=(6, 3))
ax = plt.gca()
h = histogram_module(arr, mask=mask)
Sum_h = np.sum(h)
ax.plot(np.arange(2**9), h/Sum_h, marker='o',
ms=3, markerfacecolor='none', lw=1)
if arr_dark is not None:
hd = histogram_module(arr_dark, mask=mask)
Sum_hd = np.sum(hd)
ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
else:
hd = None
if extra_lines:
for k in range(50, 271):
if not (k - 2) % 8:
ax.axvline(k, c='k', alpha=0.5, ls='--')
if not (k - 3) % 16:
ax.axvline(k, c='g', alpha=0.3, ls='--')
if not (k - 7) % 32:
ax.axvline(k, c='r', alpha=0.3, ls='--')
ax.axvline(271, c='C1', alpha=0.5, ls='--')
ax.set_xlim([0, 2**9-1])
ax.set_yscale('log')
ax.xaxis.set_minor_locator(MultipleLocator(10))
ax.set_xlabel('DSSC pixel value')
ax.set_ylabel('count frequency')
return (h, hd), f
# rois related function
def find_rois(data_mean, threshold):
"""Find rois from 3 beams configuration.
@@ -446,7 +646,8 @@ def inspect_rois(data_mean, rois, threshold=None, allrois=False):
fig = plt.figure(figsize=(5, 3))
grid = plt.GridSpec(2, 2, width_ratios=(1, 4), height_ratios=(2, 1),
# left=0.1, right=0.9, bottom=0.1, top=0.9,
wspace=0.05, hspace=0.05)
wspace=0.05, hspace=0.05,
figure=fig)
main_ax = fig.add_subplot(grid[0, 1])
y = fig.add_subplot(grid[0, 0], xticklabels=[], sharey=main_ax)
x = fig.add_subplot(grid[1, 1], yticklabels=[], sharex=main_ax)
@@ -497,311 +698,162 @@ def inspect_rois(data_mean, rois, threshold=None, allrois=False):
return fig
def histogram_module(arr, mask=None):
"""Compute a histogram of the 9 bits raw pixel values over a module.
# Flat field related functions
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
mask: optional bad pixel mask
def _plane_flat_field(p, roi, params):
"""Compute the p plane over the given roi.
Given the plane parameters p, compute the plane over the roi
size.
Parameters
----------
p: a vector of a, b, c, d plane parameter with the
plane given by ax+ by + cz + d = 0
roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
params: parameters
Returns
-------
histogram
the plane field given by p evaluated on the roi
extend.
"""
if mask is not None:
w = da.repeat(da.repeat(da.array(mask[None, None, :, :]),
arr.shape[1], axis=1), arr.shape[0], axis=0)
w = w.rechunk((100, -1, -1, -1))
return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute()
else:
return da.bincount(arr.ravel(), minlength=512).compute()
def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
"""Compute and plot a histogram of the 9 bits raw pixel values.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y)
mask: optional bad pixel mask
extra_lines: boolean, default False, plot extra lines at period values
Returns
-------
(h, hd): histogram of arr, arr_dark
figure
"""
from matplotlib.ticker import MultipleLocator
f = plt.figure(figsize=(6, 3))
ax = plt.gca()
h = histogram_module(arr, mask=mask)
Sum_h = np.sum(h)
ax.plot(np.arange(2**9), h/Sum_h, marker='o',
ms=3, markerfacecolor='none', lw=1)
if arr_dark is not None:
hd = histogram_module(arr_dark, mask=mask)
Sum_hd = np.sum(hd)
ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
else:
hd = None
if extra_lines:
for k in range(50, 271):
if not (k - 2) % 8:
ax.axvline(k, c='k', alpha=0.5, ls='--')
if not (k - 3) % 16:
ax.axvline(k, c='g', alpha=0.3, ls='--')
if not (k - 7) % 32:
ax.axvline(k, c='r', alpha=0.3, ls='--')
ax.axvline(271, c='C1', alpha=0.5, ls='--')
ax.set_xlim([0, 2**9-1])
ax.set_yscale('log')
ax.xaxis.set_minor_locator(MultipleLocator(10))
ax.set_xlabel('DSSC pixel value')
ax.set_ylabel('count frequency')
return (h, hd), f
def load_dssc_module(proposalNB, runNB, moduleNB=15,
subset=slice(None), drop_intra_darks=True, persist=False):
"""Load single module dssc data as dask array.
Inputs
------
proposalNB: proposal number
runNB: run number
moduleNB: default 15, module number
subset: default slice(None), subset of trains to load
drop_intra_darks: boolean, default True, remove intra darks from the data
persist: default False, load all data persistently in memory
Returns
-------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
tid: array of train id number
"""
run = open_run(proposal=proposalNB, run=runNB)
# DSSC
source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
key = 'image.data'
arr = run[source, key][subset].dask_array()
# fix 256 value becoming spuriously 0 instead
arr[arr == 0] = 256
ppt = run[source, key][subset].data_counts()
# ignore train with no pulses, can happen in burst mode acquisition
ppt = ppt[ppt > 0]
tid = ppt.index.to_numpy()
ppt = np.unique(ppt)
assert ppt.shape[0] == 1, "number of pulses changed during the run"
ppt = ppt[0]
# reshape in trainId, pulseId, 2d-image
arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])
a, b, c, d = p
# drop intra darks
if drop_intra_darks:
arr = arr[:, ::2, :, :]
X, Y = get_roi_pixel_pos(roi, params)
# load data in memory
if persist:
arr = arr.persist()
Z = -(a*X + b*Y + d)/c
return arr, tid
return Z
def average_module(arr, dark=None, ret='mean',
mask=None, sat_roi=None, sat_level=300, F_INL=None):
"""Compute the average or std over a module.
def compute_flat_field_correction(rois, params, plot=False):
"""Compute the plane field correction on beam rois.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
dark: default None, dark to be substracted
ret: string, either 'mean' to compute the mean or 'std' to compute the
standard deviation
mask: default None, mask of bad pixels to ignore
sat_roi: roi over which to check for pixel with values larger than
sat_level to drop the image from the average or std
sat_level: int, minimum pixel value for a pixel to be considered saturated
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
rois: dictionnary of beam rois['n', '0', 'p']
params: parameters
plot: boolean, True by default, diagnostic plot
Returns
-------
average or standard deviation image
numpy 2D array of the flat field correction evaluated over one DSSC ladder
(2 sensors)
"""
# F_INL
if F_INL is not None:
narr = arr.map_blocks(lambda x: F_INL[x])
else:
narr = arr
if mask is not None:
narr = narr*mask
if sat_roi is not None:
temp = (da.logical_not(da.any(
narr[:, :, sat_roi['yl']:sat_roi['yh'],
sat_roi['xl']:sat_roi['xh']] >= sat_level,
axis=[2, 3], keepdims=True)))
not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)
flat_field = np.ones((128, 512))
if dark is not None:
narr = narr - dark
plane = params.get_flat_field()
force_mirror = params.force_mirror
if ret == 'mean':
if sat_roi is not None:
return da.average(narr, axis=0, weights=not_sat)
else:
return narr.mean(axis=0)
elif ret == 'std':
return narr.std(axis=0)
r = rois['n']
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
_plane_flat_field(plane[:4], r, params)
r = rois['p']
if force_mirror:
a, b, c, d = plane[:4]
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
_plane_flat_field([-a, b, c, d], r, params)
else:
raise ValueError(f'ret={ret} not supported')
def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
"""Add a colobar on a new axes so it match the plot size.
Inputs
------
im: image plotted
ax: axes on which the image was plotted
loc: string, default 'right', location of the colorbar
size: string, default '5%', proportion of the colobar with respect to the
plotted image
pad: float, default 0.05, pad width between plot and colorbar
"""
from mpl_toolkits.axes_grid1 import make_axes_locatable
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
_plane_flat_field(plane[4:], r, params)
fig = ax.figure
divider = make_axes_locatable(ax)
cax = divider.append_axes(loc, size=size, pad=pad)
cbar = fig.colorbar(im, cax=cax)
if plot:
f, ax = plt.subplots(1, 1, figsize=(6, 2))
img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
f.colorbar(img, ax=[ax], label='amplitude')
ax.set_xlabel('px')
ax.set_ylabel('px')
ax.set_aspect('equal')
return cbar
return flat_field
def bad_pixel_map(params):
"""Compute the bad pixels map.
def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None):
"""Extract beams roi from average image and compute the ratio.
Inputs
------
params: parameters
avg: module average image with no saturated shots for the flat field
determination
rois: dictionnary or ROIs
prod_th, ratio_th: tuple of floats for low and high threshold on
product and ratio
vmin: imshow vmin level, default None will use 5 percentile value
vmax: imshow vmax level, default None will use 99.8 percentile value
Returns
-------
bad pixel map
fig: matplotlib figure plotted
domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order
"""
assert params.arr_dark is not None, "Data not loaded"
# compute mean and std
dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
dark_std = params.arr_dark.std(axis=(0, 1)).compute()
if vmin is None:
vmin = np.percentile(avg, 5)
if vmax is None:
vmax = np.percentile(avg, 99.8)
mask = np.ones_like(dark_mean)
if params.mean_th[0] is not None:
mask *= dark_mean >= params.mean_th[0]
if params.mean_th[1] is not None:
mask *= dark_mean <= params.mean_th[1]
if params.std_th[0] is not None:
mask *= dark_std >= params.std_th[0]
if params.std_th[1] is not None:
mask *= dark_std >= params.std_th[1]
fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9))
print(f'# bad pixel: {int(128*512-mask.sum())}')
img_rois = {}
centers = {}
return mask.astype(bool)
for k, r in enumerate(['n', '0', 'p']):
roi = rois[r]
centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
(roi['xl'] + roi['xh'])//2])
d = '0'
roi = rois[d]
for k, r in enumerate(['n', '0', 'p']):
img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[
roi['yl']:roi['yh'], roi['xl']:roi['xh']]
im = axs[0, k].imshow(img_rois[r],
vmin=vmin,
vmax=vmax)
def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
"""Inspect dark run data and plot diagnostic.
n, n_m, p, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th)
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
mean_th: tuple of threshold (low, high), default (None, None), to compute
a mask of good pixels for which the mean dark value lie inside this
range
std_th: tuple of threshold (low, high), default (None, None), to compute a
mask of bad pixels for which the dark std value lie inside this
range
prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
for k, r in enumerate(['n', '0', 'p']):
v = img_rois[r]*img_rois['0']
if prod_vmin is None:
prod_vmin = np.percentile(v, .5)
prod_vmax = np.percentile(v, 20) # we look for low intensity region
im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax, cmap='magma')
axs[1,k].contour(v, prod_th, cmap=cm.get_cmap(cm.cool, 2))
Returns
-------
fig: matplotlib figure
"""
# compute mean and std
dark_mean = arr.mean(axis=(0, 1)).compute()
dark_std = arr.std(axis=(0, 1)).compute()
v = img_rois[r]/img_rois['0']
if ratio_vmin is None:
ratio_vmin = np.percentile(v, 5)
ratio_vmax = np.percentile(v, 99.8)
im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r')
axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2))
fig = plt.figure(figsize=(7, 2.7))
gs = fig.add_gridspec(2, 4)
ax1 = fig.add_subplot(gs[0, 1:])
ax1.set_xticklabels([])
ax1.set_yticklabels([])
ax11 = fig.add_subplot(gs[0, 0])
cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
cbar.ax.set_xlabel('data mean')
ax2 = fig.add_subplot(gs[1, 1:])
ax2.set_xticklabels([])
ax2.set_yticklabels([])
ax22 = fig.add_subplot(gs[1, 0])
cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
cbar.ax.set_xlabel('product')
vmin = np.percentile(dark_mean.flatten(), 2)
vmax = np.percentile(dark_mean.flatten(), 98)
im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax)
ax1.invert_yaxis()
ax1.set_aspect('equal')
cbar1 = _add_colorbar(im1, ax=ax1, size='2%')
cbar1.ax.set_ylabel('dark mean')
cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal")
cbar.ax.set_xlabel('ratio')
ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1),
range=(vmin/2, vmax*2))
if mean_th[0] is not None:
ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--')
if mean_th[1] is not None:
ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--')
ax11.set_yscale('log')
# fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')
vmin = np.percentile(dark_std.flatten(), 2)
vmax = np.percentile(dark_std.flatten(), 98)
im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax)
ax2.invert_yaxis()
ax2.set_aspect('equal')
cbar2 = _add_colorbar(im2, ax=ax2, size='2%')
cbar2.ax.set_ylabel('dark std')
domain = (n_m, p_m)
ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2))
if std_th[0] is not None:
ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--')
if std_th[1] is not None:
ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--')
ax22.set_yscale('log')
return fig
return fig, domain
def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
def inspect_plane_fitting(avg, rois, domain, vmin=None, vmax=None):
"""Extract beams roi from average image and compute the ratio.
Inputs
------
avg: module average image with no saturated shots for the flat field
determination
rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0'
as the reference beam in the middle
rois: dictionnary of rois
vmin: imshow vmin level, default None will use 5 percentile value
vmax: imshow vmax level, default None will use 99.8 percentile value
@@ -837,6 +889,10 @@ def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
v = img_rois[r]/img_rois['0']
im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')
n_m, p_m = domain
axs[1, 0].contour(n_m)
axs[1, 2].contour(p_m)
cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
cbar.ax.set_xlabel('data mean')
@@ -848,7 +904,7 @@ def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
return fig
def plane_fitting_domain(avg, rois):
def plane_fitting_domain(avg, rois, prod_th, ratio_th):
"""Extract beams roi, compute their ratio and the domain.
Inputs
@@ -857,6 +913,10 @@ def plane_fitting_domain(avg, rois):
determination
rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0'
as the reference beam in the middle
prod_th: float tuple, low and hight threshold level to determine the plane
fitting domain on the product image of the orders
ratio_th: float tuple, low and high threshold level to determine the plane
fitting domain on the ratio image of the orders
Returns
-------
@@ -869,27 +929,34 @@ def plane_fitting_domain(avg, rois):
"""
centers = {}
for k, r in enumerate(['n', '0', 'p']):
centers[r] = np.array([(rois[r]['yl'] + rois[r]['yh'])//2,
(rois[r]['xl'] + rois[r]['xh'])//2])
for k in ['n', '0', 'p']:
r = rois[k]
centers[k] = np.array([(r['yl'] + r['yh'])//2,
(r['xl'] + r['xh'])//2])
k = 'n'
num = avg[rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
r = rois[k]
num = avg[r['yl']:r['yh'], r['xl']:r['xh']]
d = '0'
denom = np.roll(avg, tuple(centers[k] - centers[d]))[
rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
r['yl']:r['yh'], r['xl']:r['xh']]
n = num/denom
n_m = ((num*denom) > 5) * (num/denom < 1.2)
prod = num*denom
n_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
(n > ratio_th[0]) * (n < ratio_th[1]))
n_m[~np.isfinite(n)] = 0
n[~np.isfinite(n)] = 0
k = 'p'
num = avg[rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
r = rois[k]
num = avg[r['yl']:r['yh'], r['xl']:r['xh']]
d = '0'
denom = np.roll(avg, tuple(centers[k] - centers[d]))[
rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
r['yl']:r['yh'], r['xl']:r['xh']]
p = num/denom
p_m = ((num*denom) > 5) * (num/denom < 1.2)
prod = num*denom
p_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
(p > ratio_th[0]) * (p < ratio_th[1]))
p_m[~np.isfinite(p)] = 0
p[~np.isfinite(p)] = 0
@@ -916,150 +983,195 @@ def plane_fitting(params):
sat_level=params.sat_level).compute()
data_mean = data.mean(axis=0) # mean over pulseId
n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois)
n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois,
params.flat_field_prod_th, params.flat_field_ratio_th)
def _crit(x):
"""Fitting criteria for the plane field normalization.
Inputs
------
x: vector [a, b, c, d] defining the plane as
x: 2 vector [a, b, c, d] concatenated defining the plane as
a*x + b*y + c*z + d = 0
"""
a, b, c, d = x
num = a**2 + b**2 + c**2
a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x
nY, nX = n.shape
X = np.arange(nX)/100
Y = np.arange(nY)[:, np.newaxis]/100
d0_2 = np.sum(n_m*(a*X + b*Y + c*n + d)**2)/num
num_n = a_n**2 + b_n**2 + c_n**2
nY, nX = p.shape
X = np.arange(nX)/100
Y = np.arange(nY)[:, np.newaxis]/100
d2_2 = np.sum(np.fliplr(p_m)*(a*X + b*Y + c*np.fliplr(p) + d)**2)/num
roi = params.rois['n']
X, Y = get_roi_pixel_pos(roi, params)
d0_2 = np.sum(n_m*(a_n*X + b_n*Y + c_n*n + d_n)**2)/num_n
return d2_2 + d0_2
num_p = a_p**2 + b_p**2 + c_p**2
p_guess_fit = [-0.2, -0.1, 1, -0.54]
roi = params.rois['p']
X, Y = get_roi_pixel_pos(roi, params)
if params.force_mirror:
d2_2 = np.sum(p_m*(-a_n*X + b_n*Y + c_n*p + d_n)**2)/num_n
else:
d2_2 = np.sum(p_m*(a_p*X + b_p*Y + c_p*p + d_p)**2)/num_p
return 1e3*(d2_2 + d0_2)
if params.plane_guess_fit is None:
if params.use_hex:
p_guess_fit = [-20, 0.0, 1.5, -0.5, 20, 0, 1.5, -0.5 ]
else:
p_guess_fit = [-0.2, -0.1, 1, -0.54, 0.2, -0.1, 1, -0.54]
else:
p_guess_fit = params.plane_guess_fit
res = minimize(_crit, p_guess_fit)
return res
def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
flat_field=None, F_INL=None):
"""Process one module and extract roi intensity.
def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
mask, sat_level=511):
"""Criteria for the ff_refine_fit.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
tid: array of train id number
dark: pulse resolved dark image to remove
rois: dictionnary of rois
mask: default None, mask of ignored pixels
p: ff plane
params: parameters
arr_dark: dark data
arr: data
tid: train id of arr data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask fo good pixels
sat_level: integer, default 511, at which level pixel begin to saturate
flat_field: default None, flat field correction
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
Returns
-------
dataset of extracted pulse and train resolved roi intensities.
sum of standard deviation on binned 0th order intensity
"""
# F_INL
if F_INL is not None:
narr = arr.map_blocks(lambda x: F_INL[x])
else:
narr = arr
params.set_flat_field(p)
ff = compute_flat_field_correction(rois, params)
data = process(np.arange(2**9), arr_dark, arr, tid, rois, mask, ff,
sat_level)
# apply mask
if mask is not None:
narr = narr*mask
# drop saturated shots
d = data.where(data['sat_sat'] == False, drop=True)
rn = xas(d, 40, Iokey='0', Itkey='n', nrjkey='0', fluorescence=True)
rp = xas(d, 40, Iokey='0', Itkey='p', nrjkey='0', fluorescence=True)
rd = xas(d, 40, Iokey='p', Itkey='n', nrjkey='0', fluorescence=True)
# crop rois
r = {}
rd = {}
for n in rois.keys():
r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
err_sigma = (np.nansum(rn['sigmaA']) + np.nansum(rp['sigmaA'])
+ np.nansum(rd['sigmaA']))
err_mean = ((1.0 - np.nanmean(rn['muA']))**2 +
(1.0 - np.nanmean(rp['muA']))**2 +
(1.0 - np.nanmean(rd['muA']))**2)
# find saturated shots
r_sat = {}
for n in rois.keys():
r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3))
return 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
# TODO: flat field should not be applied on intra darks
# # change flat field dimension to match data
# if flat_field is not None:
# temp = np.ones_like(dark)
# temp[::2, :, :] = flat_field[:, :]
# flat_field = temp
# compute dark corrected ROI values
v = {}
for n in rois.keys():
def ff_refine_fit(params):
"""Refine the flat field fit by minimizing data spread.
r[n] = r[n] - rd[n]
Inputs
------
params: parameters
if flat_field is not None:
# TODO: flat field should not be applied on intra darks
# ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
# rois[n]['xl']:rois[n]['xh']]
ff = flat_field[rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
r[n] = r[n]/ff
Returns
-------
res: scipy minimize result. res.x is the optimized parameters
v[n] = r[n].sum(axis=(2, 3))
fitrres: iteration index arrays of criteria results for
[alpha=0, alpha, alpha=1]
"""
# load data
assert params.arr is not None, "Data not loaded"
assert params.arr_dark is not None, "Data not loaded"
res = xr.Dataset()
# we only need few rois
fitrois = {}
for k in ['n', '0', 'p', 'sat']:
fitrois[k] = params.rois[k]
dims = ['trainId', 'pulseId']
r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
for n in rois.keys():
res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
coords=r_coords, dims=dims)
res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)
p0 = params.get_flat_field()
for n in rois.keys():
roi = rois[n]
res[n + '_area'] = xr.DataArray(np.array([
(roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))
fixed_p = (params.ff_alpha, params, params.arr_dark, params.arr,
params.tid, fitrois, params.get_mask(), params.sat_level)
return res
def fit_callback(x):
if not hasattr(fit_callback, "counter"):
fit_callback.counter = 0 # it doesn't exist yet, so initialize it
fit_callback.start = time.monotonic()
fit_callback.res = []
now = time.monotonic()
time_delta = datetime.timedelta(seconds=now-fit_callback.start)
fit_callback.counter += 1
def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
"""Process dark and run data with corrections.
temp = list(fixed_p)
Jalpha = ff_refine_crit(x, *temp)
temp[0] = 0
J0 = ff_refine_crit(x, *temp)
temp[0] = 1
J1 = ff_refine_crit(x, *temp)
fit_callback.res.append([J0, Jalpha, J1])
print(f'{fit_callback.counter-1}: {time_delta} '
f'({J0}, {Jalpha}, {J1}), {x}')
return False
fit_callback(p0)
res = minimize(ff_refine_crit, p0, fixed_p,
options={'disp': True, 'maxiter': params.ff_max_iter},
callback=fit_callback)
return res, fit_callback.res
# non-linearity related functions
def nl_domain(N, low, high):
"""Create the input domain where the non-linear correction defined.
Inputs
------
Fmodel: correction lookup table
arr_dark: dark data
arr: data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask of good pixels
flat_field: zone plate flat field correction
sat_level: integer, default 511, at which level pixel begin to saturate
N: integer, number of control points or intervals
low: input values below or equal to low will not be corrected
high: input values higher or equal to high will not be corrected
Returns
-------
roi extracted intensities
array of 2**9 integer values with N segments
"""
x = np.arange(2**9)
vx = x.copy()
eps = 1e-5
vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1)
vx[x <= low] = 0
vx[x >= high] = 0
return vx
def nl_lut(domain, dy):
"""Compute the non-linear correction.
Inputs
------
domain: input domain where dy is defined. For zero no correction is
defined. For non-zero value x, dy[x] is applied.
dy: a vector of deviation from linearity on control point homogeneously
dispersed over 9 bits.
Returns
-------
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
"""
# dark process
res = average_module(arr_dark, F_INL=Fmodel)
dark = res.compute()
x = np.arange(2**9)
ndy = np.insert(dy, 0, 0) # add zero to dy
# data process
proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
flat_field=flat_field, F_INL=Fmodel)
data = proc.compute()
f = x + ndy[domain]
return data
return f
def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
@@ -1118,7 +1230,7 @@ def nl_fit(params, domain):
-------
res: scipy minimize result. res.x is the optimized parameters
firres: iteration index arrays of criteria results for
fitrres: iteration index arrays of criteria results for
[alpha=0, alpha, alpha=1]
"""
# load data
@@ -1135,9 +1247,9 @@ def nl_fit(params, domain):
p0 = np.array([0]*N)
# flat flat_field
ff = compute_flat_field_correction(params.rois, params.get_flat_field())
ff = compute_flat_field_correction(params.rois, params)
fixed_p = (domain, params.alpha, params.arr_dark, params.arr, params.tid,
fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
fitrois, params.get_mask(), ff, params.sat_level)
def fit_callback(x):
@@ -1164,7 +1276,7 @@ def nl_fit(params, domain):
fit_callback(p0)
res = minimize(nl_crit, p0, fixed_p,
options={'disp': True, 'maxiter': params.max_iter},
options={'disp': True, 'maxiter': params.nl_max_iter},
callback=fit_callback)
return res, fit_callback.res
@@ -1261,6 +1373,29 @@ def snr(sig, ref, methods=None, verbose=False):
return res
def inspect_Fnl(Fnl):
"""Plot the correction function Fnl.
Inputs
------
Fnl: non linear correction function lookup table
Returns
-------
matplotlib figure
"""
x = np.arange(2**9)
f = plt.figure(figsize=(6, 4))
plt.plot(x, Fnl - x)
# plt.axvline(40, c='k', ls='--')
# plt.axvline(280, c='k', ls='--')
plt.xlabel('input value')
plt.ylabel('output correction F(x)-x')
plt.xlim([0, 511])
return f
def inspect_correction(params, gain=None):
"""Criteria for the non linear correction.
@@ -1286,8 +1421,8 @@ def inspect_correction(params, gain=None):
# flat flat_field
plane_ff = params.get_flat_field()
if plane_ff is None:
plane_ff = [0.0, 0.0, 1.0, -1.0]
ff = compute_flat_field_correction(params.rois, plane_ff)
plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0]
ff = compute_flat_field_correction(params.rois, params)
# non linearities
Fnl = params.get_Fnl()
@@ -1314,8 +1449,6 @@ def inspect_correction(params, gain=None):
# nbins = np.linspace(0.01, 1.0, 100)
from matplotlib.colors import LogNorm
photon_scale = None
for k, d in enumerate([data, data_ff, data_ff_nl]):
@@ -1383,28 +1516,221 @@ def inspect_correction(params, gain=None):
return f
def inspect_Fnl(Fnl):
"""Plot the correction function Fnl.
# data processing related functions
def load_dssc_module(proposalNB, runNB, moduleNB=15,
subset=slice(None), drop_intra_darks=True, persist=False):
"""Load single module dssc data as dask array.
Inputs
------
Fnl: non linear correction function lookup table
proposalNB: proposal number
runNB: run number
moduleNB: default 15, module number
subset: default slice(None), subset of trains to load
drop_intra_darks: boolean, default True, remove intra darks from the data
persist: default False, load all data persistently in memory
Returns
-------
matplotlib figure
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
tid: array of train id number
"""
x = np.arange(2**9)
f = plt.figure(figsize=(6, 4))
run = open_run(proposal=proposalNB, run=runNB)
plt.plot(x, Fnl - x)
# plt.axvline(40, c='k', ls='--')
# plt.axvline(280, c='k', ls='--')
plt.xlabel('input value')
plt.ylabel('output correction F(x)-x')
plt.xlim([0, 511])
# DSSC
source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
key = 'image.data'
return f
arr = run[source, key][subset].dask_array()
# fix 256 value becoming spuriously 0 instead
arr[arr == 0] = 256
ppt = run[source, key][subset].data_counts()
# ignore train with no pulses, can happen in burst mode acquisition
ppt = ppt[ppt > 0]
tid = ppt.index.to_numpy()
ppt = np.unique(ppt)
assert ppt.shape[0] == 1, "number of pulses changed during the run"
ppt = ppt[0]
# reshape in trainId, pulseId, 2d-image
arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])
# drop intra darks
if drop_intra_darks:
arr = arr[:, ::2, :, :]
# load data in memory
if persist:
arr = arr.persist()
return arr, tid
def average_module(arr, dark=None, ret='mean',
mask=None, sat_roi=None, sat_level=300, F_INL=None):
"""Compute the average or std over a module.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
dark: default None, dark to be substracted
ret: string, either 'mean' to compute the mean or 'std' to compute the
standard deviation
mask: default None, mask of bad pixels to ignore
sat_roi: roi over which to check for pixel with values larger than
sat_level to drop the image from the average or std
sat_level: int, minimum pixel value for a pixel to be considered saturated
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
Returns
-------
average or standard deviation image
"""
# F_INL
if F_INL is not None:
narr = arr.map_blocks(lambda x: F_INL[x])
else:
narr = arr
if mask is not None:
narr = narr*mask
if sat_roi is not None:
temp = (da.logical_not(da.any(
narr[:, :, sat_roi['yl']:sat_roi['yh'],
sat_roi['xl']:sat_roi['xh']] >= sat_level,
axis=[2, 3], keepdims=True)))
not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)
if dark is not None:
narr = narr - dark
if ret == 'mean':
if sat_roi is not None:
return da.average(narr, axis=0, weights=not_sat)
else:
return narr.mean(axis=0)
elif ret == 'std':
return narr.std(axis=0)
else:
raise ValueError(f'ret={ret} not supported')
def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
flat_field=None, F_INL=None):
"""Process one module and extract roi intensity.
Inputs
------
arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
tid: array of train id number
dark: pulse resolved dark image to remove
rois: dictionnary of rois
mask: default None, mask of ignored pixels
sat_level: integer, default 511, at which level pixel begin to saturate
flat_field: default None, flat field correction
F_INL: default None, non linear correction function given as a
lookup table with 9 bits integer input
Returns
-------
dataset of extracted pulse and train resolved roi intensities.
"""
# F_INL
if F_INL is not None:
narr = arr.map_blocks(lambda x: F_INL[x])
else:
narr = arr
# apply mask
if mask is not None:
narr = narr*mask
# crop rois
r = {}
rd = {}
for n in rois.keys():
r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
# find saturated shots
r_sat = {}
for n in rois.keys():
r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3))
# TODO: flat field should not be applied on intra darks
# # change flat field dimension to match data
# if flat_field is not None:
# temp = np.ones_like(dark)
# temp[::2, :, :] = flat_field[:, :]
# flat_field = temp
# compute dark corrected ROI values
v = {}
for n in rois.keys():
r[n] = r[n] - rd[n]
if flat_field is not None:
# TODO: flat field should not be applied on intra darks
# ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
# rois[n]['xl']:rois[n]['xh']]
ff = flat_field[rois[n]['yl']:rois[n]['yh'],
rois[n]['xl']:rois[n]['xh']]
r[n] = r[n]/ff
v[n] = r[n].sum(axis=(2, 3))
res = xr.Dataset()
dims = ['trainId', 'pulseId']
r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
for n in rois.keys():
res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
coords=r_coords, dims=dims)
res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)
for n in rois.keys():
roi = rois[n]
res[n + '_area'] = xr.DataArray(np.array([
(roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))
return res
def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
"""Process dark and run data with corrections.
Inputs
------
Fmodel: correction lookup table
arr_dark: dark data
arr: data
rois: ['n', '0', 'p', 'sat'] rois
mask: mask of good pixels
flat_field: zone plate flat field correction
sat_level: integer, default 511, at which level pixel begin to saturate
Returns
-------
roi extracted intensities
"""
# dark process
res = average_module(arr_dark, F_INL=Fmodel)
dark = res.compute()
# data process
proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
flat_field=flat_field, F_INL=Fmodel)
data = proc.compute()
return data
def inspect_saturation(data, gain, Nbins=200):
"""Plot roi integrated histogram of the data with saturation
Loading