""" Beam splitting Off-axis Zone plate analysis routines. Copyright (2021) SCS Team. """ import time import datetime import json import numpy as np import xarray as xr import dask.array as da from scipy.optimize import minimize import matplotlib.pyplot as plt from matplotlib.colors import LogNorm from matplotlib import cm from extra_data import open_run from extra_geom import DSSC_1MGeometry __all__ = [ 'load_dssc_module', 'inspect_dark', 'average_module', 'plane_fitting', 'compute_flat_field_correction', 'nl_crit', 'inspect_correction', 'find_rois', 'nl_domain' ] class parameters(): """Parameters contains all input parameters for the BOZ corrections. This is used in beam splitting off-axis zone plate spectrocopy analysis as well as the during the determination of correction parameters themselves to ensure they can be reproduced. Inputs ------ proposal: int, proposal number darkrun: int, run number for the dark run run: int, run number for the data run module: int, DSSC module number gain: float, number of ph per bin """ def __init__(self, proposal, darkrun, run, module, gain): self.proposal = proposal self.darkrun = darkrun self.run = run self.module = module self.gain = gain self.mask_idx = None self.mean_th = (None, None) self.std_th = (None, None) self.rois = None self.rois_th = None self.flat_field = None self.flat_field_prod_th = (5.0, np.PINF) self.flat_field_ratio_th = (np.NINF, 1.2) self.plane_guess_fit = None self.Fnl = None self.alpha = None self.sat_level = None self.max_iter = None # temporary data self.arr_dark = None self.tid_dark = None self.arr = None self.tid = None def dask_load_persistently(self): """Load dask data array in memory.""" self.arr_dark, self.tid_dark = load_dssc_module(self.proposal, self.darkrun, self.module, drop_intra_darks=True, persist=True) self.arr, self.tid = load_dssc_module(self.proposal, self.run, self.module, drop_intra_darks=True, persist=True) # make sure to rechunk the arrays self.arr = self.arr.rechunk((100, -1, -1, -1)) self.arr_dark = self.arr_dark.rechunk((100, -1, -1, -1)) def set_mask(self, arr): """Set mask of bad pixels. Inputs ------ arr: either a boolean array of a DSSC module image or a list of bad pixel indices """ if type(arr) is not list: self.mask_idx = np.argwhere(arr == False).tolist() self.mask = arr else: self.mask_idx = arr mask = np.ones((128, 512), dtype=bool) for k in self.mask_idx: mask[k[0], k[1]] = False self.mask = mask def get_mask(self): """Get the boolean array bad pixel of a DSSC module.""" return self.mask def get_mask_idx(self): """Get the list of bad pixel indices.""" return self.mask_idx def set_flat_field(self, plane, prod_th=None, ratio_th=None): """Set the flat field plane definition.""" if type(plane) is not list: self.flat_field = plane.tolist() else: self.flat_field = plane if prod_th is not None: self.flat_field_prod_th = prod_th if ratio_th is not None: self.flat_field_ratio_th = ratio_th def get_flat_field(self): """Get the flat field plane definition.""" if self.flat_field is None: return None else: return np.array(self.flat_field) def set_Fnl(self, Fnl): """Set the non-linear correction function.""" if isinstance(Fnl, list): self.Fnl = Fnl else: self.Fnl = Fnl.tolist() def get_Fnl(self): """Get the non-linear correction function.""" if self.Fnl is None: return None else: return np.array(self.Fnl) def save(self, path='./'): """Save the parameters as a JSON file. Inputs ------ path: str, where to save the file, default to './' """ v = {} v['proposal'] = self.proposal v['darkrun'] = self.darkrun v['run'] = self.run v['module'] = self.module v['gain'] = self.gain v['mask'] = self.mask_idx v['mean_th'] = self.mean_th v['std_th'] = self.std_th v['rois'] = self.rois v['rois_th'] = self.rois_th v['flat_field'] = self.flat_field v['flat_field_prod_th'] = self.flat_field_prod_th v['flat_field_ratio_th'] = self.flat_field_ratio_th v['plane_guess_fit'] = self.plane_guess_fit v['Fnl'] = self.Fnl v['alpha'] = self.alpha v['sat_level'] = self.sat_level v['max_iter'] = self.max_iter fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json' with open(path + fname, 'w') as f: json.dump(v, f) print(path + fname) @classmethod def load(cls, fname): """Load parameters from a JSON file. Inputs ------ fname: string, name a the JSON file to load """ with open(fname, 'r') as f: v = json.load(f) c = cls(v['proposal'], v['darkrun'], v['run'], v['module'], v['gain']) c.mean_th = v['mean_th'] c.std_th = v['std_th'] c.set_mask(v['mask']) c.rois = v['rois'] c.rois_th = v['rois_th'] c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th']) c.plane_guess_fit = v['plane_guess_fit'] c.set_Fnl(v['Fnl']) c.alpha = v['alpha'] c.sat_level = v['sat_level'] c.max_iter = v['max_iter'] return c def __str__(self): f = f'proposal:{self.proposal} darkrun:{self.darkrun} run:{self.run}' f += f' module:{self.module} gain:{self.gain} ph/bin\n' if self.mask_idx is not None: f += f'mean threshold:{self.mean_th} std threshold:{self.std_th}\n' f += f'mask:(#{len(self.mask_idx)}) {self.mask_idx}\n' else: f += 'mask:None\n' f += f'rois threshold: {self.rois_th}\n' f += f'rois: {self.rois}\n' f += f'flat field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n' f += f'plane guess fit: {self.plane_guess_fit}\n' if self.Fnl is not None: f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n' f += f'alpha:{self.alpha}, sat. level:{self.sat_level}, ' f += f' max. iter.:{self.max_iter}' else: f += 'Fnl: None' return f def _get_pixel_pos(): """Compute the pixel position on hexagonal lattice of DSSC module 15""" # module pixel position dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)] g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos) # keeping only module 15 pixel X,Y position return g.get_pixel_positions()[15][:, :, :2] def _plane_flat_field(p, roi): """Compute the p plane over the given roi. Given the plane parameters p, compute the plane over the roi size. Parameters ---------- p: a vector of a, b, c, d plane parameter with the plane given by ax+ by + cz + d = 0 roi: a dictionnary roi['yh', 'yl', 'xh', 'xl'] Returns ------- the plane field given by p evaluated on the roi extend. """ a, b, c, d = p # DSSC pixel position on hexagonal lattice pixel_pos = _get_pixel_pos() pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] Z = -(a*pos[:, :, 0] + b*pos[:, :, 1] + d)/c return Z def compute_flat_field_correction(rois, p, plot=False): """Compute the plane field correction on beam rois. Inputs ------ rois: dictionnary of beam rois['n', '0', 'p'] p: plane vector plot: boolean, True by default, diagnostic plot Returns ------- numpy 2D array of the flat field correction evaluated over one DSSC ladder (2 sensors) """ flat_field = np.ones((128, 512)) r = 'n' flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \ _plane_flat_field(p[:4], rois[r]) r = 'p' flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \ _plane_flat_field(p[4:], rois[r]) if plot: f, ax = plt.subplots(1, 1, figsize=(6, 2)) img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r') f.colorbar(img, ax=[ax], label='amplitude') ax.set_xlabel('px') ax.set_ylabel('px') ax.set_aspect('equal') return flat_field def nl_domain(N, low, high): """Create the input domain where the non-linear correction defined. Inputs ------ N: integer, number of control points or intervals low: input values below or equal to low will not be corrected high: input values higher or equal to high will not be corrected Returns ------- array of 2**9 integer values with N segments """ x = np.arange(2**9) vx = x.copy() eps = 1e-5 vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1) vx[x <= low] = 0 vx[x >= high] = 0 return vx def nl_lut(domain, dy): """Compute the non-linear correction. Inputs ------ domain: input domain where dy is defined. For zero no correction is defined. For non-zero value x, dy[x] is applied. dy: a vector of deviation from linearity on control point homogeneously dispersed over 9 bits. Returns ------- F_INL: default None, non linear correction function given as a lookup table with 9 bits integer input """ x = np.arange(2**9) ndy = np.insert(dy, 0, 0) # add zero to dy f = x + ndy[domain] return f def find_rois(data_mean, threshold): """Find rois from 3 beams configuration. Inputs ------ data_mean: dark corrected average image threshold: threshold value to find beams Returns ------- rois: dictionnary of rois """ # compute vertical and horizontal projection pX = data_mean.mean(axis=0) pX = pX[:256] # half the ladder since there is a gap in the middle pY = data_mean.mean(axis=1) # along X lowX = int(np.argmax(pX[:64] > threshold)) # 1st occurrence returned highX = int(np.argmax(pX[192:] <= threshold) + 192) # 1st occ. returned leftX = int(np.argmin(pX[64:128]) + 64) rightX = int(np.argmin(pX[128:192]) + 128) # along Y lowY = int(np.argmax(pY[:64] > threshold)) # 1st occurrence returned highY = int(np.argmax(pY[64:] < threshold) + 64) # 1st occ. returned # define rois rois = {} # baseline correction rois for k in [0, 1, 2, 3]: rois[f'b{k}'] = {'xl': k*64, 'xh': (k+1)*64, 'yl': 0, 'yh': lowY} for k in [8, 9, 10, 11]: rois[f'b{k}'] = {'xl': (k-8)*64, 'xh': (k+1-8)*64, 'yl': highY, 'yh': 128} # beam roi rois['n'] = {'xl': lowX, 'xh': leftX, 'yl': lowY, 'yh': highY} rois['0'] = {'xl': leftX, 'xh': rightX, 'yl': lowY, 'yh': highY} rois['p'] = {'xl': rightX, 'xh': highX, 'yl': lowY, 'yh': highY} # saturation roi rois['sat'] = {'xl': lowX, 'xh': highX, 'yl': lowY, 'yh': highY} # ASICs splitted beam roi rois['0X'] = {'xl': lowX, 'xh': 1*64, 'yl': lowY, 'yh': 64} rois['1X1'] = {'xl': 64, 'xh': leftX, 'yl': lowY, 'yh': 64} rois['1X2'] = {'xl': leftX, 'xh': 2*64, 'yl': lowY, 'yh': 64} rois['2X1'] = {'xl': 2*64, 'xh': rightX, 'yl': lowY, 'yh': 64} rois['2X2'] = {'xl': rightX, 'xh': 3*64, 'yl': lowY, 'yh': 64} rois['3X'] = {'xl': 3*64, 'xh': highX, 'yl': lowY, 'yh': 64} rois['8X'] = {'xl': lowX, 'xh': 1*64, 'yl': 64, 'yh': highY} rois['9X1'] = {'xl': 64, 'xh': leftX, 'yl': 64, 'yh': highY} rois['9X2'] = {'xl': leftX, 'xh': 2*64, 'yl': 64, 'yh': highY} rois['10X1'] = {'xl': 2*64, 'xh': rightX, 'yl': 64, 'yh': highY} rois['10X2'] = {'xl': rightX, 'xh': 3*64, 'yl': 64, 'yh': highY} rois['11X'] = {'xl': 3*64, 'xh': highX, 'yl': 64, 'yh': highY} return rois def find_rois_from_params(params): """Find rois from 3 beams configuration. Inputs ------ params: parameters Returns ------- rois: dictionnary of rois """ assert params.arr_dark is not None, "Data not loaded" dark = average_module(params.arr_dark).compute() assert params.arr is not None, "Data not loaded" data = average_module(params.arr, dark=dark).compute() data_mean = data.mean(axis=0) # mean over pulseId threshold = params.rois_th return find_rois(data_mean, threshold) def inspect_rois(data_mean, rois, threshold=None, allrois=False): """Find rois from 3 beams configuration from mean module image. Inputs ------ data_mean: mean module image threshold: float, default None, threshold value used to detect beams boundaries allrois: boolean, default False, plot all rois defined in rois or only the main ones (['n', '0', 'p']) Returns ------- matplotlib figure """ # compute vertical and horizontal projection pX = data_mean.mean(axis=0) pX = pX[:256] # half the ladder since there is a gap in the middle pY = data_mean.mean(axis=1) # Set up the axes with gridspec fig = plt.figure(figsize=(5, 3)) grid = plt.GridSpec(2, 2, width_ratios=(1, 4), height_ratios=(2, 1), # left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.05, hspace=0.05) main_ax = fig.add_subplot(grid[0, 1]) y = fig.add_subplot(grid[0, 0], xticklabels=[], sharey=main_ax) x = fig.add_subplot(grid[1, 1], yticklabels=[], sharex=main_ax) # scatter points on the main axes Xs = np.arange(len(pX)) Ys = np.arange(len(pY)) main_ax.pcolormesh(Xs, Ys, np.flipud(data_mean[:, :256]), cmap='Greys_r', vmin=0, vmax=np.percentile(data_mean[:, :256], 99)) main_ax.set_aspect('equal') from matplotlib.patches import Rectangle roi = rois['n'] main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']), roi['xh'] - roi['xl'], roi['yh'] - roi['yl'], alpha=0.3, color='b')) roi = rois['0'] main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']), roi['xh'] - roi['xl'], roi['yh'] - roi['yl'], alpha=0.3, color='g')) roi = rois['p'] main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']), roi['xh'] - roi['xl'], roi['yh'] - roi['yl'], alpha=0.3, color='r')) x.plot(Xs, pX) x.invert_yaxis() if threshold is not None: x.axhline(threshold, c='k', alpha=.5) x.axvline(rois['n']['xl'], c='r', alpha=.5) x.axvline(rois['0']['xl'], c='r', alpha=.6) x.axvline(rois['p']['xl'], c='r', alpha=.7) x.axvline(rois['p']['xh'], c='r', alpha=.8) y.plot(pY, np.arange(len(pY)-1, -1, -1)) y.invert_xaxis() if threshold is not None: y.axvline(threshold, c='k', alpha=.5) y.axhline(127-rois['p']['yl'], c='r', alpha=.5) y.axhline(127-rois['p']['yh'], c='r', alpha=.6) return fig def histogram_module(arr, mask=None): """Compute a histogram of the 9 bits raw pixel values over a module. Inputs ------ arr: dask array of reshaped dssc data (trainId, pulseId, x, y) mask: optional bad pixel mask Returns ------- histogram """ if mask is not None: w = da.repeat(da.repeat(da.array(mask[None, None, :, :]), arr.shape[1], axis=1), arr.shape[0], axis=0) w = w.rechunk((100, -1, -1, -1)) return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute() else: return da.bincount(arr.ravel(), minlength=512).compute() def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False): """Compute and plot a histogram of the 9 bits raw pixel values. Inputs ------ arr: dask array of reshaped dssc data (trainId, pulseId, x, y) arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y) mask: optional bad pixel mask extra_lines: boolean, default False, plot extra lines at period values Returns ------- (h, hd): histogram of arr, arr_dark figure """ from matplotlib.ticker import MultipleLocator f = plt.figure(figsize=(6, 3)) ax = plt.gca() h = histogram_module(arr, mask=mask) Sum_h = np.sum(h) ax.plot(np.arange(2**9), h/Sum_h, marker='o', ms=3, markerfacecolor='none', lw=1) if arr_dark is not None: hd = histogram_module(arr_dark, mask=mask) Sum_hd = np.sum(hd) ax.plot(np.arange(2**9), hd/Sum_hd, marker='o', ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5) else: hd = None if extra_lines: for k in range(50, 271): if not (k - 2) % 8: ax.axvline(k, c='k', alpha=0.5, ls='--') if not (k - 3) % 16: ax.axvline(k, c='g', alpha=0.3, ls='--') if not (k - 7) % 32: ax.axvline(k, c='r', alpha=0.3, ls='--') ax.axvline(271, c='C1', alpha=0.5, ls='--') ax.set_xlim([0, 2**9-1]) ax.set_yscale('log') ax.xaxis.set_minor_locator(MultipleLocator(10)) ax.set_xlabel('DSSC pixel value') ax.set_ylabel('count frequency') return (h, hd), f def load_dssc_module(proposalNB, runNB, moduleNB=15, subset=slice(None), drop_intra_darks=True, persist=False): """Load single module dssc data as dask array. Inputs ------ proposalNB: proposal number runNB: run number moduleNB: default 15, module number subset: default slice(None), subset of trains to load drop_intra_darks: boolean, default True, remove intra darks from the data persist: default False, load all data persistently in memory Returns ------- arr: dask array of reshaped dssc data (trainId, pulseId, x, y) tid: array of train id number """ run = open_run(proposal=proposalNB, run=runNB) # DSSC source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf' key = 'image.data' arr = run[source, key][subset].dask_array() # fix 256 value becoming spuriously 0 instead arr[arr == 0] = 256 ppt = run[source, key][subset].data_counts() # ignore train with no pulses, can happen in burst mode acquisition ppt = ppt[ppt > 0] tid = ppt.index.to_numpy() ppt = np.unique(ppt) assert ppt.shape[0] == 1, "number of pulses changed during the run" ppt = ppt[0] # reshape in trainId, pulseId, 2d-image arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3]) # drop intra darks if drop_intra_darks: arr = arr[:, ::2, :, :] # load data in memory if persist: arr = arr.persist() return arr, tid def average_module(arr, dark=None, ret='mean', mask=None, sat_roi=None, sat_level=300, F_INL=None): """Compute the average or std over a module. Inputs ------ arr: dask array of reshaped dssc data (trainId, pulseId, x, y) dark: default None, dark to be substracted ret: string, either 'mean' to compute the mean or 'std' to compute the standard deviation mask: default None, mask of bad pixels to ignore sat_roi: roi over which to check for pixel with values larger than sat_level to drop the image from the average or std sat_level: int, minimum pixel value for a pixel to be considered saturated F_INL: default None, non linear correction function given as a lookup table with 9 bits integer input Returns ------- average or standard deviation image """ # F_INL if F_INL is not None: narr = arr.map_blocks(lambda x: F_INL[x]) else: narr = arr if mask is not None: narr = narr*mask if sat_roi is not None: temp = (da.logical_not(da.any( narr[:, :, sat_roi['yl']:sat_roi['yh'], sat_roi['xl']:sat_roi['xh']] >= sat_level, axis=[2, 3], keepdims=True))) not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3) if dark is not None: narr = narr - dark if ret == 'mean': if sat_roi is not None: return da.average(narr, axis=0, weights=not_sat) else: return narr.mean(axis=0) elif ret == 'std': return narr.std(axis=0) else: raise ValueError(f'ret={ret} not supported') def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05): """Add a colobar on a new axes so it match the plot size. Inputs ------ im: image plotted ax: axes on which the image was plotted loc: string, default 'right', location of the colorbar size: string, default '5%', proportion of the colobar with respect to the plotted image pad: float, default 0.05, pad width between plot and colorbar """ from mpl_toolkits.axes_grid1 import make_axes_locatable fig = ax.figure divider = make_axes_locatable(ax) cax = divider.append_axes(loc, size=size, pad=pad) cbar = fig.colorbar(im, cax=cax) return cbar def bad_pixel_map(params): """Compute the bad pixels map. Inputs ------ params: parameters Returns ------- bad pixel map """ assert params.arr_dark is not None, "Data not loaded" # compute mean and std dark_mean = params.arr_dark.mean(axis=(0, 1)).compute() dark_std = params.arr_dark.std(axis=(0, 1)).compute() mask = np.ones_like(dark_mean) if params.mean_th[0] is not None: mask *= dark_mean >= params.mean_th[0] if params.mean_th[1] is not None: mask *= dark_mean <= params.mean_th[1] if params.std_th[0] is not None: mask *= dark_std >= params.std_th[0] if params.std_th[1] is not None: mask *= dark_std >= params.std_th[1] print(f'# bad pixel: {int(128*512-mask.sum())}') return mask.astype(bool) def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)): """Inspect dark run data and plot diagnostic. Inputs ------ arr: dask array of reshaped dssc data (trainId, pulseId, x, y) mean_th: tuple of threshold (low, high), default (None, None), to compute a mask of good pixels for which the mean dark value lie inside this range std_th: tuple of threshold (low, high), default (None, None), to compute a mask of bad pixels for which the dark std value lie inside this range Returns ------- fig: matplotlib figure """ # compute mean and std dark_mean = arr.mean(axis=(0, 1)).compute() dark_std = arr.std(axis=(0, 1)).compute() fig = plt.figure(figsize=(7, 2.7)) gs = fig.add_gridspec(2, 4) ax1 = fig.add_subplot(gs[0, 1:]) ax1.set_xticklabels([]) ax1.set_yticklabels([]) ax11 = fig.add_subplot(gs[0, 0]) ax2 = fig.add_subplot(gs[1, 1:]) ax2.set_xticklabels([]) ax2.set_yticklabels([]) ax22 = fig.add_subplot(gs[1, 0]) vmin = np.percentile(dark_mean.flatten(), 2) vmax = np.percentile(dark_mean.flatten(), 98) im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax) ax1.invert_yaxis() ax1.set_aspect('equal') cbar1 = _add_colorbar(im1, ax=ax1, size='2%') cbar1.ax.set_ylabel('dark mean') ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1), range=(vmin/2, vmax*2)) if mean_th[0] is not None: ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--') if mean_th[1] is not None: ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--') ax11.set_yscale('log') vmin = np.percentile(dark_std.flatten(), 2) vmax = np.percentile(dark_std.flatten(), 98) im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax) ax2.invert_yaxis() ax2.set_aspect('equal') cbar2 = _add_colorbar(im2, ax=ax2, size='2%') cbar2.ax.set_ylabel('dark std') ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2)) if std_th[0] is not None: ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--') if std_th[1] is not None: ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--') ax22.set_yscale('log') return fig def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None): """Extract beams roi from average image and compute the ratio. Inputs ------ avg: module average image with no saturated shots for the flat field determination rois: dictionnary or ROIs prod_th, ratio_th: tuple of floats for low and high threshold on product and ratio vmin: imshow vmin level, default None will use 5 percentile value vmax: imshow vmax level, default None will use 99.8 percentile value Returns ------- fig: matplotlib figure plotted domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order """ if vmin is None: vmin = np.percentile(avg, 5) if vmax is None: vmax = np.percentile(avg, 99.8) fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9)) img_rois = {} centers = {} for k, r in enumerate(['n', '0', 'p']): roi = rois[r] centers[r] = np.array([(roi['yl'] + roi['yh'])//2, (roi['xl'] + roi['xh'])//2]) d = '0' roi = rois[d] for k, r in enumerate(['n', '0', 'p']): img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[ roi['yl']:roi['yh'], roi['xl']:roi['xh']] im = axs[0, k].imshow(img_rois[r], vmin=vmin, vmax=vmax) n, n_m, p, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th) prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4 for k, r in enumerate(['n', '0', 'p']): v = img_rois[r]*img_rois['0'] if prod_vmin is None: prod_vmin = np.percentile(v, .5) prod_vmax = np.percentile(v, 20) # we look for low intensity region im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax, cmap='magma') axs[1,k].contour(v, prod_th, cmap=cm.get_cmap(cm.cool, 2)) v = img_rois[r]/img_rois['0'] if ratio_vmin is None: ratio_vmin = np.percentile(v, 5) ratio_vmax = np.percentile(v, 99.8) im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r') axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2)) cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal") cbar.ax.set_xlabel('data mean') cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal") cbar.ax.set_xlabel('product') cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal") cbar.ax.set_xlabel('ratio') # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}') domain = (n_m, p_m) return fig, domain def inspect_plane_fitting(avg, rois, domain, vmin=None, vmax=None): """Extract beams roi from average image and compute the ratio. Inputs ------ avg: module average image with no saturated shots for the flat field determination rois: dictionnary of rois vmin: imshow vmin level, default None will use 5 percentile value vmax: imshow vmax level, default None will use 99.8 percentile value Returns ------- fig: matplotlib figure plotted """ if vmin is None: vmin = np.percentile(avg, 5) if vmax is None: vmax = np.percentile(avg, 99.8) fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6)) img_rois = {} centers = {} for k, r in enumerate(['n', '0', 'p']): roi = rois[r] centers[r] = np.array([(roi['yl'] + roi['yh'])//2, (roi['xl'] + roi['xh'])//2]) d = '0' roi = rois[d] for k, r in enumerate(['n', '0', 'p']): img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[ roi['yl']:roi['yh'], roi['xl']:roi['xh']] im = axs[0, k].imshow(img_rois[r], vmin=vmin, vmax=vmax) for k, r in enumerate(['n', '0', 'p']): v = img_rois[r]/img_rois['0'] im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r') n_m, p_m = domain axs[1, 0].contour(n_m) axs[1, 2].contour(p_m) cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal") cbar.ax.set_xlabel('data mean') cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal") cbar.ax.set_xlabel('ratio') # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}') return fig def plane_fitting_domain(avg, rois, prod_th, ratio_th): """Extract beams roi, compute their ratio and the domain. Inputs ------ avg: module average image with no saturated shots for the flat field determination rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0' as the reference beam in the middle prod_th: float tuple, low and hight threshold level to determine the plane fitting domain on the product image of the orders ratio_th: float tuple, low and high threshold level to determine the plane fitting domain on the ratio image of the orders Returns ------- n: img ratio 'n'/'0' n_m: mask where the the product 'n'*'0' is higher than 5 indicting that the img ratio 'n'/'0' is defined p: img ratio 'p'/'0' p_m: mask where the the product 'p'*'0' is higher than 5 indicting that the img ratio 'p'/'0' is defined """ centers = {} for k, r in enumerate(['n', '0', 'p']): centers[r] = np.array([(rois[r]['yl'] + rois[r]['yh'])//2, (rois[r]['xl'] + rois[r]['xh'])//2]) k = 'n' num = avg[rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']] d = '0' denom = np.roll(avg, tuple(centers[k] - centers[d]))[ rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']] n = num/denom prod = num*denom n_m = ((prod > prod_th[0]) * (prod < prod_th[1]) * (n > ratio_th[0]) * (n < ratio_th[1])) n_m[~np.isfinite(n)] = 0 n[~np.isfinite(n)] = 0 k = 'p' num = avg[rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']] d = '0' denom = np.roll(avg, tuple(centers[k] - centers[d]))[ rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']] p = num/denom prod = num*denom p_m = ((prod > prod_th[0]) * (prod < prod_th[1]) * (p > ratio_th[0]) * (p < ratio_th[1])) p_m[~np.isfinite(p)] = 0 p[~np.isfinite(p)] = 0 return n, n_m, p, p_m def plane_fitting(params): """Fit the plane flat field normalization. Inputs ------ params: parameters Returns ------- res: the minimization result. The fitted vector res.x = [a, b, c, d] defines the plane as a*x + b*y + c*z + d = 0 """ assert params.arr_dark is not None, "Data not loaded" dark = average_module(params.arr_dark).compute() assert params.arr is not None, "Data not loaded" data = average_module(params.arr, dark=dark, ret='mean', mask=params.mask, sat_roi=params.rois['sat'], sat_level=params.sat_level).compute() data_mean = data.mean(axis=0) # mean over pulseId n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois, params.flat_field_prod_th, params.flat_field_ratio_th) def _crit(x): """Fitting criteria for the plane field normalization. Inputs ------ x: 2 vector [a, b, c, d] concatenated defining the plane as a*x + b*y + c*z + d = 0 """ a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x num_n = a_n**2 + b_n**2 + c_n**2 roi = params.rois['n'] pixel_pos = _get_pixel_pos() # DSSC pixel position on hexagonal lattice pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] d0_2 = np.sum(n_m*(a_n*pos[:, :, 0] + b_n*pos[:, :, 1] + c_n*n + d_n)**2)/num_n num_p = a_p**2 + b_p**2 + c_p**2 roi = params.rois['p'] # DSSC pixel position on hexagonal lattice pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] d2_2 = np.sum(p_m*(a_p*pos[:, :, 0] + b_p*pos[:, :, 1] + c_p*p + d_p)**2)/num_p return 1e3*(d2_2 + d0_2) if params.plane_guess_fit is None: p_guess_fit = [-20, 0.0, 1.5, -0.5, -20, 0, 1.5, -0.5 ] else: p_guess_fit = params.plane_guess_fit res = minimize(_crit, p_guess_fit) return res def process_module(arr, tid, dark, rois, mask=None, sat_level=511, flat_field=None, F_INL=None): """Process one module and extract roi intensity. Inputs ------ arr: dask array of reshaped dssc data (trainId, pulseId, x, y) tid: array of train id number dark: pulse resolved dark image to remove rois: dictionnary of rois mask: default None, mask of ignored pixels sat_level: integer, default 511, at which level pixel begin to saturate flat_field: default None, flat field correction F_INL: default None, non linear correction function given as a lookup table with 9 bits integer input Returns ------- dataset of extracted pulse and train resolved roi intensities. """ # F_INL if F_INL is not None: narr = arr.map_blocks(lambda x: F_INL[x]) else: narr = arr # apply mask if mask is not None: narr = narr*mask # crop rois r = {} rd = {} for n in rois.keys(): r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'], rois[n]['xl']:rois[n]['xh']] rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'], rois[n]['xl']:rois[n]['xh']] # find saturated shots r_sat = {} for n in rois.keys(): r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3)) # TODO: flat field should not be applied on intra darks # # change flat field dimension to match data # if flat_field is not None: # temp = np.ones_like(dark) # temp[::2, :, :] = flat_field[:, :] # flat_field = temp # compute dark corrected ROI values v = {} for n in rois.keys(): r[n] = r[n] - rd[n] if flat_field is not None: # TODO: flat field should not be applied on intra darks # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'], # rois[n]['xl']:rois[n]['xh']] ff = flat_field[rois[n]['yl']:rois[n]['yh'], rois[n]['xl']:rois[n]['xh']] r[n] = r[n]/ff v[n] = r[n].sum(axis=(2, 3)) res = xr.Dataset() dims = ['trainId', 'pulseId'] r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])} for n in rois.keys(): res[n + '_sat'] = xr.DataArray(r_sat[n][:, :], coords=r_coords, dims=dims) res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims) for n in rois.keys(): roi = rois[n] res[n + '_area'] = xr.DataArray(np.array([ (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])])) return res def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511): """Process dark and run data with corrections. Inputs ------ Fmodel: correction lookup table arr_dark: dark data arr: data rois: ['n', '0', 'p', 'sat'] rois mask: mask of good pixels flat_field: zone plate flat field correction sat_level: integer, default 511, at which level pixel begin to saturate Returns ------- roi extracted intensities """ # dark process res = average_module(arr_dark, F_INL=Fmodel) dark = res.compute() # data process proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level, flat_field=flat_field, F_INL=Fmodel) data = proc.compute() return data def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511): """Criteria for the non linear correction. Inputs ------ p: vector of dy non linear correction domain: domain over which the non linear correction is defined alpha: float, coefficient scaling the cost of the correction function in the criterion arr_dark: dark data arr: data tid: train id of arr data rois: ['n', '0', 'p', 'sat'] rois mask: mask fo good pixels flat_field: zone plate flat field correction sat_level: integer, default 511, at which level pixel begin to saturate Returns ------- (1.0 - alpha)*err1 + alpha*err2, where err1 is the 1e8 times the mean of error squared from a transmission of 1.0 and err2 is the sum of the square of the deviation from the ideal detector response. """ Fmodel = nl_lut(domain, p) data = process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level) # drop saturated shots d = data.where(data['sat_sat'] == False, drop=True) v_1 = snr(d['n'].values.flatten(), d['0'].values.flatten(), methods=['weighted']) err_1 = 1e8*v_1['weighted']['s']**2 v_2 = snr(d['p'].values.flatten(), d['0'].values.flatten(), methods=['weighted']) err_2 = 1e8*v_2['weighted']['s']**2 err_a = np.sum((Fmodel-np.arange(2**9))**2) return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a def nl_fit(params, domain): """Fit non linearities correction function. Inputs ------ params: parameters domain: array of index Returns ------- res: scipy minimize result. res.x is the optimized parameters firres: iteration index arrays of criteria results for [alpha=0, alpha, alpha=1] """ # load data assert params.arr is not None, "Data not loaded" assert params.arr_dark is not None, "Data not loaded" # we only need few rois fitrois = {} for k in ['n', '0', 'p', 'sat']: fitrois[k] = params.rois[k] # p0 N = np.unique(domain).shape[0] - 1 p0 = np.array([0]*N) # flat flat_field ff = compute_flat_field_correction(params.rois, params.get_flat_field()) fixed_p = (domain, params.alpha, params.arr_dark, params.arr, params.tid, fitrois, params.get_mask(), ff, params.sat_level) def fit_callback(x): if not hasattr(fit_callback, "counter"): fit_callback.counter = 0 # it doesn't exist yet, so initialize it fit_callback.start = time.monotonic() fit_callback.res = [] now = time.monotonic() time_delta = datetime.timedelta(seconds=now-fit_callback.start) fit_callback.counter += 1 temp = list(fixed_p) Jalpha = nl_crit(x, *temp) temp[1] = 0 J0 = nl_crit(x, *temp) temp[1] = 1 J1 = nl_crit(x, *temp) fit_callback.res.append([J0, Jalpha, J1]) print(f'{fit_callback.counter-1}: {time_delta} ' f'({J0}, {Jalpha}, {J1}), {x}') return False fit_callback(p0) res = minimize(nl_crit, p0, fixed_p, options={'disp': True, 'maxiter': params.max_iter}, callback=fit_callback) return res, fit_callback.res def inspect_nl_fit(res_fit): """Plot the progress of the fit. Inputs ------ res_fit: Returns ------- matplotlib figure """ r = np.array(res_fit) f = plt.figure(figsize=(6, 4)) ax = f.gca() ax2 = plt.twinx() ax.plot(1.0/np.sqrt(1e-8*r[:, 0]), c='C0') ax2.plot(r[:, 2], c='C1', ls='-.') ax.set_xlabel('# iteration') ax.set_ylabel('SNR', color='C0') ax2.set_ylabel('correction cost', color='C1') ax.set_yscale('log') ax2.set_yscale('log') return f def snr(sig, ref, methods=None, verbose=False): """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref. Inputs ------ sig: 1D signal samples ref: 1D reference samples methods: None by default or list of strings to select which methods to use. Possible values are 'direct', 'weighted', 'diff'. In case of None, all methods will be calculated. verbose: booleand, if True prints calculated values Returns ------- dictionnary of [methods][value] where value is 'mu' for mean and 's' for standard deviation. """ if methods is None: methods = ['direct', 'weighted', 'diff'] w = ref x = sig/ref mask = np.isfinite(x) & np.isfinite(sig) & np.isfinite(ref) w = w[mask] sig = sig[mask] ref = ref[mask] x = x[mask] res = {} # direct mean and std if 'direct' in methods: mu = np.mean(x) s = np.std(x) if verbose: print(f'mu: {mu}, s: {s}, snr: {mu/s}') res['direct'] = {'mu': mu, 's':s} # weighted mean and std if 'weighted' in methods: wmu = np.sum(sig)/np.sum(ref) v1 = np.sum(w) v2 = np.sum(w**2) ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1)) if verbose: print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}') res['weighted'] = {'mu': wmu, 's':ws} # noise from diff if 'diff' in methods: dmu = np.mean(x) ds = np.std(np.diff(x))/np.sqrt(2) if verbose: print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}') res['diff'] = {'mu': dmu, 's':ds} return res def inspect_correction(params, gain=None): """Criteria for the non linear correction. Inputs ------ params: parameters gain: float, default None, DSSC gain in ph/bin Returns ------- matplotlib figure """ # load data assert params.arr is not None, "Data not loaded" assert params.arr_dark is not None, "Data not loaded" # we only need few rois fitrois = {} for k in ['n', '0', 'p', 'sat']: fitrois[k] = params.rois[k] # flat flat_field plane_ff = params.get_flat_field() if plane_ff is None: plane_ff = [0.0, 0.0, 1.0, -1.0] ff = compute_flat_field_correction(params.rois, plane_ff) # non linearities Fnl = params.get_Fnl() if Fnl is None: Fnl = np.arange(2**9) # compute all levels of correction data = process(np.arange(2**9), params.arr_dark, params.arr, params.tid, fitrois, params.get_mask(), np.ones_like(ff), params.sat_level) data_ff = process(np.arange(2**9), params.arr_dark, params.arr, params.tid, fitrois, params.get_mask(), ff, params.sat_level) data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid, fitrois, params.get_mask(), ff, params.sat_level) # for conversion to nb of photons if gain is None: g = 1 else: g = gain scale = 1e-6 f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True) # nbins = np.linspace(0.01, 1.0, 100) photon_scale = None for k, d in enumerate([data, data_ff, data_ff_nl]): for l, (n, r) in enumerate([('n', '0'), ('p', '0'), ('n', 'p')]): if photon_scale is None: lower = 0 upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9) photon_scale = np.linspace(lower, upper, 150) good_d = d.where(d['sat_sat'] == False, drop=True) sat_d = d.where(d['sat_sat'], drop=True) snr_v = snr(good_d[n].values.flatten(), good_d[r].values.flatten(), verbose=True) m = snr_v['direct']['mu'] h, xedges, yedges, img = axs[l, k].hist2d( g*scale*good_d[r].values.flatten(), good_d[n].values.flatten()/good_d[r].values.flatten(), [photon_scale, np.linspace(0.95, 1.05, 150)*m], cmap='Blues', norm=LogNorm(vmin=0.2, vmax=200), # alpha=0.5 # make the plot looks ugly with lots of white lines ) h, xedges, yedges, img2 = axs[l, k].hist2d( g*scale*sat_d[r].values.flatten(), sat_d[n].values.flatten()/sat_d[r].values.flatten(), [photon_scale, np.linspace(0.95, 1.05, 150)*m], cmap='Reds', norm=LogNorm(vmin=0.2, vmax=200), # alpha=0.5 # make the plot looks ugly with lots of white lines ) v = snr_v['direct']['mu']/snr_v['direct']['s'] axs[l, k].text(0.4, 0.15, f'SNR: {v:.0f}', transform = axs[l, k].transAxes) v = snr_v['weighted']['mu']/snr_v['weighted']['s'] axs[l, k].text(0.4, 0.05, f'wSNR: {v:.0f}', transform = axs[l, k].transAxes) # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--') # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--') axs[l, k].set_ylim([0.95*m, 1.05*m]) for k in range(3): #for l in range(3): # axs[l, k].set_ylim([0.95, 1.05]) if gain: axs[2, k].set_xlabel('#ph (10$^6$)') else: axs[2, k].set_xlabel('ADU (10$^6$)') f.colorbar(img, ax=axs, label='counts') axs[0, 0].set_title('raw') axs[0, 1].set_title('flat field') axs[0, 2].set_title('non-linear') axs[0, 0].set_ylabel(r'-1$^\mathrm{st}$/0$^\mathrm{th}$ order') axs[1, 0].set_ylabel(r'1$^\mathrm{st}$/0$^\mathrm{th}$ order') axs[2, 0].set_ylabel(r'-1$^\mathrm{st}$/1$^\mathrm{th}$ order') return f def inspect_Fnl(Fnl): """Plot the correction function Fnl. Inputs ------ Fnl: non linear correction function lookup table Returns ------- matplotlib figure """ x = np.arange(2**9) f = plt.figure(figsize=(6, 4)) plt.plot(x, Fnl - x) # plt.axvline(40, c='k', ls='--') # plt.axvline(280, c='k', ls='--') plt.xlabel('input value') plt.ylabel('output correction F(x)-x') plt.xlim([0, 511]) return f def inspect_saturation(data, gain, Nbins=200): """Plot roi integrated histogram of the data with saturation Inputs ------ data: xarray of roi integrated DSSC data gain: nominal DSSC gain in ph/bin Nbins: number of bins for the histogram, by default 200 Returns ------- f: handle to the matplotlib figure h: xarray of the histogram data """ d = data.where(data['sat_sat'] == False, drop=True) s = data.where(data['sat_sat'] == True, drop=True) # percentage of saturated shots N_nonsat = d['n'].count() N_all = data.dims['trainId'] * data.dims['pulseId'] sat_percent = ((N_all - N_nonsat)/N_all).values*100.0 # find the bin ranges sum_v = {} low = 0 high = 0 scale = 1e-6 for k in ['n', '0', 'p']: v = data[k].values.ravel()*scale*gain sum_v[k] = np.nansum(v) v_low, v_high = np.nanmin(v), np.nanmax(v) if v_low < low: low = v_low if v_high > high: high = v_high # compute bins edges, center and width bins = np.linspace(low, high, Nbins+1) bins_c = 0.5*(bins[:-1] + bins[1:]) w = bins[1] - bins[0] fig, ax = plt.subplots(figsize=(6,4)) h = {} for kk, k in enumerate(['n', '0', 'p']): v_d = d[k].values.ravel()*scale*gain v_s = s[k].values.ravel()*scale*gain h[k+'_nosat'], bin_e = np.histogram(v_d, bins) h[k+'_sat'], bin_e = np.histogram(v_s, bins) # compute density normalization on all data norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat'])) ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm, h[k+'_nosat']/norm, facecolor=f"C{kk}", edgecolor='none', alpha=0.2) ax.plot(bins_c, h[k+'_nosat']/norm, label=k, c=f'C{kk}', alpha=0.4) ax.text(0.6, 0.9, f"saturation: {sat_percent:.2f}%", color='r', alpha=0.5, transform=plt.gca().transAxes) ax.legend() ax.set_xlabel(r'10$^6$ ph') ax.set_ylabel('density') # save data as xarray dataset dv = {} for k in h.keys(): dv[k] = {"dims": "N", "data": h[k]} ds = { "coords": {"N": {"dims": "N", "data": bins_c, "attrs": {"units": f"{scale:g} ph"}}}, "attrs": {"saturation (%)": sat_percent}, "dims": "N", "data_vars": dv} return fig, xr.Dataset.from_dict(ds)