diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py index 5703ce3c1c0d055ec7118ce71eb352de07f70c78..807f8d200635b30da90bdbe8dc8c3e47a4a3f697 100644 --- a/src/toolbox_scs/routines/boz.py +++ b/src/toolbox_scs/routines/boz.py @@ -16,10 +16,14 @@ from scipy.optimize import minimize import matplotlib.pyplot as plt from matplotlib.colors import LogNorm from matplotlib import cm +from matplotlib.patches import Polygon +from matplotlib.collections import PatchCollection from extra_data import open_run from extra_geom import DSSC_1MGeometry +from toolbox_scs import xas + __all__ = [ 'load_dssc_module', 'inspect_dark', @@ -53,7 +57,9 @@ class parameters(): self.darkrun = darkrun self.run = run self.module = module + self.pixel_pos = _get_pixel_pos(self.module) self.gain = gain + self.mask_idx = None self.mean_th = (None, None) self.std_th = (None, None) @@ -63,6 +69,8 @@ class parameters(): self.flat_field_prod_th = (5.0, np.PINF) self.flat_field_ratio_th = (np.NINF, 1.2) self.plane_guess_fit = None + self.use_hex = False + self.force_mirror = True self.Fnl = None self.alpha = None self.sat_level = None @@ -169,6 +177,8 @@ class parameters(): v['flat_field_prod_th'] = self.flat_field_prod_th v['flat_field_ratio_th'] = self.flat_field_ratio_th v['plane_guess_fit'] = self.plane_guess_fit + v['use_hex'] = self.use_hex + v['force_mirror'] = self.force_mirror v['Fnl'] = self.Fnl v['alpha'] = self.alpha @@ -202,6 +212,8 @@ class parameters(): c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th']) c.plane_guess_fit = v['plane_guess_fit'] + c.use_hex = v['use_hex'] + c.force_mirror = v['force_mirror'] c.set_Fnl(v['Fnl']) c.alpha = v['alpha'] @@ -225,6 +237,8 @@ class parameters(): f += f'flat field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n' f += f'plane guess fit: {self.plane_guess_fit}\n' + f += f'use hexagons: {self.use_hex}\n' + f += f'enforce mirror symmetry: 
{self.force_mirror}\n' if self.Fnl is not None: f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n' @@ -235,124 +249,248 @@ class parameters(): return f -def _get_pixel_pos(): - """Compute the pixel position on hexagonal lattice of DSSC module 15""" +# Hexagonal pixels related function + +def _get_pixel_pos(module): + """Compute the pixel position on hexagonal lattice of DSSC module.""" # module pixel position dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)] g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos) # keeping only module 15 pixel X,Y position - return g.get_pixel_positions()[15][:, :, :2] + return g.get_pixel_positions()[module][:, :, :2] +def _get_pixel_corners(module): + """Compute the pixel corners of DSSC module.""" + # module pixel position + dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)] + g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos) -def _plane_flat_field(p, roi): - """Compute the p plane over the given roi. + # corners are in z,y,x oder so we rop z, flip x & y + corners = g.to_distortion_array(allow_negative_xy=True) + corners = corners[(module*128):((module+1)*128), :, :, 1:][:, :, :, ::-1] + + return corners - Given the plane parameters p, compute the plane over the roi - size. +def _get_pixel_hexagons(module): + """Compute DSSC pixel hexagons for plotting. + + Parameters: + ----------- + module: int, module number - Parameters - ---------- - p: a vector of a, b, c, d plane parameter with the - plane given by ax+ by + cz + d = 0 - roi: a dictionnary roi['yh', 'yl', 'xh', 'xl'] + Returns: + matplotlib PatchCollection of hexagons + """ + hexes = [] + + corners = _get_pixel_corners(module) + for y in range(128): + for x in range(512): + c = 1e3*corners[y, x, :, :] # convert to mm + hexes.append(Polygon(c)) + + return PatchCollection(hexes) + +def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05): + """Add a colobar on a new axes so it match the plot size. 
+ + Inputs + ------ + im: image plotted + ax: axes on which the image was plotted + loc: string, default 'right', location of the colorbar + size: string, default '5%', proportion of the colobar with respect to the + plotted image + pad: float, default 0.05, pad width between plot and colorbar + """ + from mpl_toolkits.axes_grid1 import make_axes_locatable + + fig = ax.figure + divider = make_axes_locatable(ax) + cax = divider.append_axes(loc, size=size, pad=pad) + cbar = fig.colorbar(im, cax=cax) + + return cbar + +# dark related functions + +def bad_pixel_map(params): + """Compute the bad pixels map. + + Inputs + ------ + params: parameters Returns ------- - the plane field given by p evaluated on the roi - extend. + bad pixel map """ - a, b, c, d = p + assert params.arr_dark is not None, "Data not loaded" + + # compute mean and std + dark_mean = params.arr_dark.mean(axis=(0, 1)).compute() + dark_std = params.arr_dark.std(axis=(0, 1)).compute() - # DSSC pixel position on hexagonal lattice - pixel_pos = _get_pixel_pos() - pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] + mask = np.ones_like(dark_mean) + if params.mean_th[0] is not None: + mask *= dark_mean >= params.mean_th[0] + if params.mean_th[1] is not None: + mask *= dark_mean <= params.mean_th[1] + if params.std_th[0] is not None: + mask *= dark_std >= params.std_th[0] + if params.std_th[1] is not None: + mask *= dark_std >= params.std_th[1] - Z = -(a*pos[:, :, 0] + b*pos[:, :, 1] + d)/c + print(f'# bad pixel: {int(128*512-mask.sum())}') - return Z + return mask.astype(bool) -def compute_flat_field_correction(rois, plane, plot=False): - """Compute the plane field correction on beam rois. +def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)): + """Inspect dark run data and plot diagnostic. 
Inputs ------ - rois: dictionnary of beam rois['n', '0', 'p'] - plane: 2 plane vector concatenated - plot: boolean, True by default, diagnostic plot + arr: dask array of reshaped dssc data (trainId, pulseId, x, y) + mean_th: tuple of threshold (low, high), default (None, None), to compute + a mask of good pixels for which the mean dark value lie inside this + range + std_th: tuple of threshold (low, high), default (None, None), to compute a + mask of bad pixels for which the dark std value lie inside this + range Returns ------- - numpy 2D array of the flat field correction evaluated over one DSSC ladder - (2 sensors) + fig: matplotlib figure """ - flat_field = np.ones((128, 512)) + # compute mean and std + dark_mean = arr.mean(axis=(0, 1)).compute() + dark_std = arr.std(axis=(0, 1)).compute() - r = rois['n'] - flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \ - _plane_flat_field(plane[:4], r) - r = rois['p'] - flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \ - _plane_flat_field(plane[4:], r) + fig = plt.figure(figsize=(7, 2.7)) + gs = fig.add_gridspec(2, 4) + ax1 = fig.add_subplot(gs[0, 1:]) + ax1.set_xticklabels([]) + ax1.set_yticklabels([]) + ax11 = fig.add_subplot(gs[0, 0]) - if plot: - f, ax = plt.subplots(1, 1, figsize=(6, 2)) - img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r') - f.colorbar(img, ax=[ax], label='amplitude') - ax.set_xlabel('px') - ax.set_ylabel('px') - ax.set_aspect('equal') + ax2 = fig.add_subplot(gs[1, 1:]) + ax2.set_xticklabels([]) + ax2.set_yticklabels([]) + ax22 = fig.add_subplot(gs[1, 0]) - return flat_field + vmin = np.percentile(dark_mean.flatten(), 2) + vmax = np.percentile(dark_mean.flatten(), 98) + im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax) + ax1.invert_yaxis() + ax1.set_aspect('equal') + cbar1 = _add_colorbar(im1, ax=ax1, size='2%') + cbar1.ax.set_ylabel('dark mean') + ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1), + range=(vmin/2, vmax*2)) + if mean_th[0] is not None: + 
ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--') + if mean_th[1] is not None: + ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--') + ax11.set_yscale('log') -def nl_domain(N, low, high): - """Create the input domain where the non-linear correction defined. + vmin = np.percentile(dark_std.flatten(), 2) + vmax = np.percentile(dark_std.flatten(), 98) + im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax) + ax2.invert_yaxis() + ax2.set_aspect('equal') + cbar2 = _add_colorbar(im2, ax=ax2, size='2%') + cbar2.ax.set_ylabel('dark std') + + ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2)) + if std_th[0] is not None: + ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--') + if std_th[1] is not None: + ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--') + ax22.set_yscale('log') + + return fig + + +# histogram related functions + +def histogram_module(arr, mask=None): + """Compute a histogram of the 9 bits raw pixel values over a module. Inputs ------ - N: integer, number of control points or intervals - low: input values below or equal to low will not be corrected - high: input values higher or equal to high will not be corrected + arr: dask array of reshaped dssc data (trainId, pulseId, x, y) + mask: optional bad pixel mask Returns ------- - array of 2**9 integer values with N segments + histogram """ - x = np.arange(2**9) - vx = x.copy() - eps = 1e-5 - vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1) - vx[x <= low] = 0 - vx[x >= high] = 0 - - return vx + if mask is not None: + w = da.repeat(da.repeat(da.array(mask[None, None, :, :]), + arr.shape[1], axis=1), arr.shape[0], axis=0) + w = w.rechunk((100, -1, -1, -1)) + return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute() + else: + return da.bincount(arr.ravel(), minlength=512).compute() -def nl_lut(domain, dy): - """Compute the non-linear correction. 
+def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False): + """Compute and plot a histogram of the 9 bits raw pixel values. Inputs ------ - domain: input domain where dy is defined. For zero no correction is - defined. For non-zero value x, dy[x] is applied. - dy: a vector of deviation from linearity on control point homogeneously - dispersed over 9 bits. + arr: dask array of reshaped dssc data (trainId, pulseId, x, y) + arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y) + mask: optional bad pixel mask + extra_lines: boolean, default False, plot extra lines at period values Returns ------- - F_INL: default None, non linear correction function given as a - lookup table with 9 bits integer input + (h, hd): histogram of arr, arr_dark + figure """ - x = np.arange(2**9) - ndy = np.insert(dy, 0, 0) # add zero to dy + from matplotlib.ticker import MultipleLocator - f = x + ndy[domain] + f = plt.figure(figsize=(6, 3)) + ax = plt.gca() + h = histogram_module(arr, mask=mask) + Sum_h = np.sum(h) + ax.plot(np.arange(2**9), h/Sum_h, marker='o', + ms=3, markerfacecolor='none', lw=1) - return f + if arr_dark is not None: + hd = histogram_module(arr_dark, mask=mask) + Sum_hd = np.sum(hd) + ax.plot(np.arange(2**9), hd/Sum_hd, marker='o', + ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5) + else: + hd = None + + if extra_lines: + for k in range(50, 271): + if not (k - 2) % 8: + ax.axvline(k, c='k', alpha=0.5, ls='--') + if not (k - 3) % 16: + ax.axvline(k, c='g', alpha=0.3, ls='--') + if not (k - 7) % 32: + ax.axvline(k, c='r', alpha=0.3, ls='--') + + ax.axvline(271, c='C1', alpha=0.5, ls='--') + + ax.set_xlim([0, 2**9-1]) + ax.set_yscale('log') + ax.xaxis.set_minor_locator(MultipleLocator(10)) + ax.set_xlabel('DSSC pixel value') + ax.set_ylabel('count frequency') + + return (h, hd), f +# rois related function + def find_rois(data_mean, threshold): """Find rois from 3 beams configuration. 
@@ -517,300 +655,92 @@ def inspect_rois(data_mean, rois, threshold=None, allrois=False): return fig -def histogram_module(arr, mask=None): - """Compute a histogram of the 9 bits raw pixel values over a module. +# Flat field related functions - Inputs - ------ - arr: dask array of reshaped dssc data (trainId, pulseId, x, y) - mask: optional bad pixel mask +def _plane_flat_field(p, roi, pixel_pos, use_hex=False): + """Compute the p plane over the given roi. + + Given the plane parameters p, compute the plane over the roi + size. + + Parameters + ---------- + p: a vector of a, b, c, d plane parameter with the + plane given by ax+ by + cz + d = 0 + roi: a dictionnary roi['yh', 'yl', 'xh', 'xl'] + pixel_pos: array of DSSC pixel position on hexagonal lattice + use_hex: boolean, use actual DSSC pixel position from pixel_pos + or fake cartesian pixel position Returns ------- - histogram + the plane field given by p evaluated on the roi + extend. """ - if mask is not None: - w = da.repeat(da.repeat(da.array(mask[None, None, :, :]), - arr.shape[1], axis=1), arr.shape[0], axis=0) - w = w.rechunk((100, -1, -1, -1)) - return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute() + a, b, c, d = p + + if use_hex: + # DSSC pixel position on hexagonal lattice + pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] + X = pos[:, :, 0] + Y = pos[:, :, 1] else: - return da.bincount(arr.ravel(), minlength=512).compute() + nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl'] + X = np.arange(nX)/100 + Y = np.arange(nY)[:, np.newaxis]/100 + # center of ROI is put to 0,0 + X -= np.mean(X) + Y -= np.mean(Y) -def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False): - """Compute and plot a histogram of the 9 bits raw pixel values. 
+ Z = -(a*X + b*Y + d)/c - Inputs - ------ - arr: dask array of reshaped dssc data (trainId, pulseId, x, y) - arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y) - mask: optional bad pixel mask - extra_lines: boolean, default False, plot extra lines at period values - - Returns - ------- - (h, hd): histogram of arr, arr_dark - figure - """ - from matplotlib.ticker import MultipleLocator - - f = plt.figure(figsize=(6, 3)) - ax = plt.gca() - h = histogram_module(arr, mask=mask) - Sum_h = np.sum(h) - ax.plot(np.arange(2**9), h/Sum_h, marker='o', - ms=3, markerfacecolor='none', lw=1) - - if arr_dark is not None: - hd = histogram_module(arr_dark, mask=mask) - Sum_hd = np.sum(hd) - ax.plot(np.arange(2**9), hd/Sum_hd, marker='o', - ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5) - else: - hd = None - - if extra_lines: - for k in range(50, 271): - if not (k - 2) % 8: - ax.axvline(k, c='k', alpha=0.5, ls='--') - if not (k - 3) % 16: - ax.axvline(k, c='g', alpha=0.3, ls='--') - if not (k - 7) % 32: - ax.axvline(k, c='r', alpha=0.3, ls='--') - - ax.axvline(271, c='C1', alpha=0.5, ls='--') - - ax.set_xlim([0, 2**9-1]) - ax.set_yscale('log') - ax.xaxis.set_minor_locator(MultipleLocator(10)) - ax.set_xlabel('DSSC pixel value') - ax.set_ylabel('count frequency') - - return (h, hd), f - - -def load_dssc_module(proposalNB, runNB, moduleNB=15, - subset=slice(None), drop_intra_darks=True, persist=False): - """Load single module dssc data as dask array. 
- - Inputs - ------ - proposalNB: proposal number - runNB: run number - moduleNB: default 15, module number - subset: default slice(None), subset of trains to load - drop_intra_darks: boolean, default True, remove intra darks from the data - persist: default False, load all data persistently in memory - - Returns - ------- - arr: dask array of reshaped dssc data (trainId, pulseId, x, y) - tid: array of train id number - """ - run = open_run(proposal=proposalNB, run=runNB) - - # DSSC - source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf' - key = 'image.data' - - arr = run[source, key][subset].dask_array() - # fix 256 value becoming spuriously 0 instead - arr[arr == 0] = 256 - - ppt = run[source, key][subset].data_counts() - # ignore train with no pulses, can happen in burst mode acquisition - ppt = ppt[ppt > 0] - tid = ppt.index.to_numpy() - - ppt = np.unique(ppt) - assert ppt.shape[0] == 1, "number of pulses changed during the run" - ppt = ppt[0] - - # reshape in trainId, pulseId, 2d-image - arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3]) - - # drop intra darks - if drop_intra_darks: - arr = arr[:, ::2, :, :] - - # load data in memory - if persist: - arr = arr.persist() - - return arr, tid - - -def average_module(arr, dark=None, ret='mean', - mask=None, sat_roi=None, sat_level=300, F_INL=None): - """Compute the average or std over a module. 
- - Inputs - ------ - arr: dask array of reshaped dssc data (trainId, pulseId, x, y) - dark: default None, dark to be substracted - ret: string, either 'mean' to compute the mean or 'std' to compute the - standard deviation - mask: default None, mask of bad pixels to ignore - sat_roi: roi over which to check for pixel with values larger than - sat_level to drop the image from the average or std - sat_level: int, minimum pixel value for a pixel to be considered saturated - F_INL: default None, non linear correction function given as a - lookup table with 9 bits integer input - - Returns - ------- - average or standard deviation image - """ - # F_INL - if F_INL is not None: - narr = arr.map_blocks(lambda x: F_INL[x]) - else: - narr = arr - - if mask is not None: - narr = narr*mask - - if sat_roi is not None: - temp = (da.logical_not(da.any( - narr[:, :, sat_roi['yl']:sat_roi['yh'], - sat_roi['xl']:sat_roi['xh']] >= sat_level, - axis=[2, 3], keepdims=True))) - not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3) - - if dark is not None: - narr = narr - dark - - if ret == 'mean': - if sat_roi is not None: - return da.average(narr, axis=0, weights=not_sat) - else: - return narr.mean(axis=0) - elif ret == 'std': - return narr.std(axis=0) - else: - raise ValueError(f'ret={ret} not supported') - - -def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05): - """Add a colobar on a new axes so it match the plot size. 
- - Inputs - ------ - im: image plotted - ax: axes on which the image was plotted - loc: string, default 'right', location of the colorbar - size: string, default '5%', proportion of the colobar with respect to the - plotted image - pad: float, default 0.05, pad width between plot and colorbar - """ - from mpl_toolkits.axes_grid1 import make_axes_locatable - - fig = ax.figure - divider = make_axes_locatable(ax) - cax = divider.append_axes(loc, size=size, pad=pad) - cbar = fig.colorbar(im, cax=cax) - - return cbar + return Z -def bad_pixel_map(params): - """Compute the bad pixels map. +def compute_flat_field_correction(rois, params, plot=False): + """Compute the plane field correction on beam rois. Inputs ------ + rois: dictionnary of beam rois['n', '0', 'p'] params: parameters + plot: boolean, True by default, diagnostic plot Returns ------- - bad pixel map - """ - assert params.arr_dark is not None, "Data not loaded" - - # compute mean and std - dark_mean = params.arr_dark.mean(axis=(0, 1)).compute() - dark_std = params.arr_dark.std(axis=(0, 1)).compute() - - mask = np.ones_like(dark_mean) - if params.mean_th[0] is not None: - mask *= dark_mean >= params.mean_th[0] - if params.mean_th[1] is not None: - mask *= dark_mean <= params.mean_th[1] - if params.std_th[0] is not None: - mask *= dark_std >= params.std_th[0] - if params.std_th[1] is not None: - mask *= dark_std >= params.std_th[1] - - print(f'# bad pixel: {int(128*512-mask.sum())}') - - return mask.astype(bool) - - -def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)): - """Inspect dark run data and plot diagnostic. 
- - Inputs - ------ - arr: dask array of reshaped dssc data (trainId, pulseId, x, y) - mean_th: tuple of threshold (low, high), default (None, None), to compute - a mask of good pixels for which the mean dark value lie inside this - range - std_th: tuple of threshold (low, high), default (None, None), to compute a - mask of bad pixels for which the dark std value lie inside this - range - - Returns - ------- - fig: matplotlib figure + numpy 2D array of the flat field correction evaluated over one DSSC ladder + (2 sensors) """ - # compute mean and std - dark_mean = arr.mean(axis=(0, 1)).compute() - dark_std = arr.std(axis=(0, 1)).compute() - - fig = plt.figure(figsize=(7, 2.7)) - gs = fig.add_gridspec(2, 4) - ax1 = fig.add_subplot(gs[0, 1:]) - ax1.set_xticklabels([]) - ax1.set_yticklabels([]) - ax11 = fig.add_subplot(gs[0, 0]) - - ax2 = fig.add_subplot(gs[1, 1:]) - ax2.set_xticklabels([]) - ax2.set_yticklabels([]) - ax22 = fig.add_subplot(gs[1, 0]) + flat_field = np.ones((128, 512)) - vmin = np.percentile(dark_mean.flatten(), 2) - vmax = np.percentile(dark_mean.flatten(), 98) - im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax) - ax1.invert_yaxis() - ax1.set_aspect('equal') - cbar1 = _add_colorbar(im1, ax=ax1, size='2%') - cbar1.ax.set_ylabel('dark mean') + plane = params.get_flat_field() + use_hex = params.use_hex + pixel_pos = params.pixel_pos + force_mirror = params.force_mirror - ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1), - range=(vmin/2, vmax*2)) - if mean_th[0] is not None: - ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--') - if mean_th[1] is not None: - ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--') - ax11.set_yscale('log') - - vmin = np.percentile(dark_std.flatten(), 2) - vmax = np.percentile(dark_std.flatten(), 98) - im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax) - ax2.invert_yaxis() - ax2.set_aspect('equal') - cbar2 = _add_colorbar(im2, ax=ax2, size='2%') - cbar2.ax.set_ylabel('dark std') + r = rois['n'] + 
flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \ + _plane_flat_field(plane[:4], r, pixel_pos, use_hex) + + r = rois['p'] + if force_mirror: + a, b, c, d = plane[:4] + flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \ + _plane_flat_field([-a, b, c, d], r, pixel_pos, use_hex) + else: + flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \ + _plane_flat_field(plane[4:], r, pixel_pos, use_hex) - ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2)) - if std_th[0] is not None: - ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--') - if std_th[1] is not None: - ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--') - ax22.set_yscale('log') + if plot: + f, ax = plt.subplots(1, 1, figsize=(6, 2)) + img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r') + f.colorbar(img, ax=[ax], label='amplitude') + ax.set_xlabel('px') + ax.set_ylabel('px') + ax.set_aspect('equal') - return fig + return flat_field def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None): @@ -888,6 +818,7 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None return fig, domain + def inspect_plane_fitting(avg, rois, domain, vmin=None, vmax=None): """Extract beams roi from average image and compute the ratio. 
@@ -1039,24 +970,51 @@ def plane_fitting(params): num_n = a_n**2 + b_n**2 + c_n**2 roi = params.rois['n'] - pixel_pos = _get_pixel_pos() - # DSSC pixel position on hexagonal lattice - pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] - d0_2 = np.sum(n_m*(a_n*pos[:, :, 0] + b_n*pos[:, :, 1] - + c_n*n + d_n)**2)/num_n + if params.use_hex: + # DSSC pixel position on hexagonal lattice + pos = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] + X = pos[:, :, 0] + Y = pos[:, :, 1] + else: + nY, nX = n.shape + X = np.arange(nX)/100 + Y = np.arange(nY)[:, np.newaxis]/100 + + # center of ROI is put to 0,0 + X -= np.mean(X) + Y -= np.mean(Y) + + d0_2 = np.sum(n_m*(a_n*X + b_n*Y + c_n*n + d_n)**2)/num_n num_p = a_p**2 + b_p**2 + c_p**2 roi = params.rois['p'] - # DSSC pixel position on hexagonal lattice - pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] - d2_2 = np.sum(p_m*(a_p*pos[:, :, 0] + b_p*pos[:, :, 1] - + c_p*p + d_p)**2)/num_p + if params.use_hex: + # DSSC pixel position on hexagonal lattice + pos = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :] + X = pos[:, :, 0] + Y = pos[:, :, 1] + else: + nY, nX = p.shape + X = np.arange(nX)/100 + Y = np.arange(nY)[:, np.newaxis]/100 + + # center of ROI is put to 0,0 + X -= np.mean(X) + Y -= np.mean(Y) + + if params.force_mirror: + d2_2 = np.sum(p_m*(-a_n*X + b_n*Y + c_n*p + d_n)**2)/num_n + else: + d2_2 = np.sum(p_m*(a_p*X + b_p*Y + c_p*p + d_p)**2)/num_p return 1e3*(d2_2 + d0_2) if params.plane_guess_fit is None: - p_guess_fit = [-20, 0.0, 1.5, -0.5, -20, 0, 1.5, -0.5 ] + if params.use_hex: + p_guess_fit = [-20, 0.0, 1.5, -0.5, 20, 0, 1.5, -0.5 ] + else: + p_guess_fit = [-0.2, -0.1, 1, -0.54, 0.2, -0.1, 1, -0.54] else: p_guess_fit = params.plane_guess_fit @@ -1065,117 +1023,142 @@ def plane_fitting(params): return res -def process_module(arr, tid, dark, rois, mask=None, sat_level=511, - flat_field=None, F_INL=None): - """Process one module and extract roi intensity. 
+def ff_refine_crit(p, params, arr_dark, arr, tid, rois, mask, sat_level=511): + """Criteria for the ff_refine_fit. Inputs ------ - arr: dask array of reshaped dssc data (trainId, pulseId, x, y) - tid: array of train id number - dark: pulse resolved dark image to remove - rois: dictionnary of rois - mask: default None, mask of ignored pixels + p: ff plane + params: parameters + arr_dark: dark data + arr: data + tid: train id of arr data + rois: ['n', '0', 'p', 'sat'] rois + mask: mask fo good pixels sat_level: integer, default 511, at which level pixel begin to saturate - flat_field: default None, flat field correction - F_INL: default None, non linear correction function given as a - lookup table with 9 bits integer input Returns ------- - dataset of extracted pulse and train resolved roi intensities. + sum of standard deviation on binned 0th order intensity """ - # F_INL - if F_INL is not None: - narr = arr.map_blocks(lambda x: F_INL[x]) - else: - narr = arr + params.set_flat_field(p) + ff = compute_flat_field_correction(rois, params) + + data = process(np.arange(2**9), arr_dark, arr, tid, rois, mask, ff, + sat_level) - # apply mask - if mask is not None: - narr = narr*mask + # drop saturated shots + d = data.where(data['sat_sat'] == False, drop=True) + + rn = xas(d, 40, Iokey='0', Itkey='n', nrjkey='0') + rp = xas(d, 40, Iokey='0', Itkey='p', nrjkey='0') + rd = xas(d, 40, Iokey='p', Itkey='n', nrjkey='0') - # crop rois - r = {} - rd = {} - for n in rois.keys(): - r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'], - rois[n]['xl']:rois[n]['xh']] - rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'], - rois[n]['xl']:rois[n]['xh']] + err = np.nansum(rn['sigmaA']) + np.nansum(rp['sigmaA']) + np.nansum(rd['sigmaA']) - # find saturated shots - r_sat = {} - for n in rois.keys(): - r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3)) + return 1e3*err - # TODO: flat field should not be applied on intra darks - # # change flat field dimension to match data - # if flat_field is not None: 
- # temp = np.ones_like(dark) - # temp[::2, :, :] = flat_field[:, :] - # flat_field = temp - # compute dark corrected ROI values - v = {} - for n in rois.keys(): +def ff_refine_fit(params): + """Refine the flat field fit by minimizing data spread. - r[n] = r[n] - rd[n] + Inputs + ------ + params: parameters - if flat_field is not None: - # TODO: flat field should not be applied on intra darks - # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'], - # rois[n]['xl']:rois[n]['xh']] - ff = flat_field[rois[n]['yl']:rois[n]['yh'], - rois[n]['xl']:rois[n]['xh']] - r[n] = r[n]/ff + Returns + ------- + res: scipy minimize result. res.x is the optimized parameters - v[n] = r[n].sum(axis=(2, 3)) + firres: iteration index arrays of criteria results for + [criteria] + """ + # load data + assert params.arr is not None, "Data not loaded" + assert params.arr_dark is not None, "Data not loaded" - res = xr.Dataset() + # we only need few rois + fitrois = {} + for k in ['n', '0', 'p', 'sat']: + fitrois[k] = params.rois[k] - dims = ['trainId', 'pulseId'] - r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])} - for n in rois.keys(): - res[n + '_sat'] = xr.DataArray(r_sat[n][:, :], - coords=r_coords, dims=dims) - res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims) + p0 = params.get_flat_field() - for n in rois.keys(): - roi = rois[n] - res[n + '_area'] = xr.DataArray(np.array([ - (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])])) + fixed_p = (params, params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), params.sat_level) - return res + def fit_callback(x): + if not hasattr(fit_callback, "counter"): + fit_callback.counter = 0 # it doesn't exist yet, so initialize it + fit_callback.start = time.monotonic() + fit_callback.res = [] + now = time.monotonic() + time_delta = datetime.timedelta(seconds=now-fit_callback.start) + fit_callback.counter += 1 -def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511): - """Process dark and run data 
with corrections. + temp = list(fixed_p) + Jalpha = ff_refine_crit(x, *temp) + fit_callback.res.append([Jalpha]) + print(f'{fit_callback.counter-1}: {time_delta} ' + f'({Jalpha}), {x}') + + return False + + fit_callback(p0) + res = minimize(ff_refine_crit, p0, fixed_p, + options={'disp': True, 'maxiter': params.max_iter}, + callback=fit_callback) + + return res, fit_callback.res + + +# non-linearity related functions + +def nl_domain(N, low, high): + """Create the input domain where the non-linear correction defined. Inputs ------ - Fmodel: correction lookup table - arr_dark: dark data - arr: data - rois: ['n', '0', 'p', 'sat'] rois - mask: mask of good pixels - flat_field: zone plate flat field correction - sat_level: integer, default 511, at which level pixel begin to saturate + N: integer, number of control points or intervals + low: input values below or equal to low will not be corrected + high: input values higher or equal to high will not be corrected + + Returns + ------- + array of 2**9 integer values with N segments + """ + x = np.arange(2**9) + vx = x.copy() + eps = 1e-5 + vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1) + vx[x <= low] = 0 + vx[x >= high] = 0 + + return vx + + +def nl_lut(domain, dy): + """Compute the non-linear correction. + + Inputs + ------ + domain: input domain where dy is defined. For zero no correction is + defined. For non-zero value x, dy[x] is applied. + dy: a vector of deviation from linearity on control point homogeneously + dispersed over 9 bits. 
Returns ------- - roi extracted intensities + F_INL: default None, non linear correction function given as a + lookup table with 9 bits integer input """ - # dark process - res = average_module(arr_dark, F_INL=Fmodel) - dark = res.compute() + x = np.arange(2**9) + ndy = np.insert(dy, 0, 0) # add zero to dy - # data process - proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level, - flat_field=flat_field, F_INL=Fmodel) - data = proc.compute() + f = x + ndy[domain] - return data + return f def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, @@ -1377,6 +1360,29 @@ def snr(sig, ref, methods=None, verbose=False): return res +def inspect_Fnl(Fnl): + """Plot the correction function Fnl. + + Inputs + ------ + Fnl: non linear correction function lookup table + + Returns + ------- + matplotlib figure + """ + x = np.arange(2**9) + f = plt.figure(figsize=(6, 4)) + + plt.plot(x, Fnl - x) + # plt.axvline(40, c='k', ls='--') + # plt.axvline(280, c='k', ls='--') + plt.xlabel('input value') + plt.ylabel('output correction F(x)-x') + plt.xlim([0, 511]) + + return f + def inspect_correction(params, gain=None): """Criteria for the non linear correction. @@ -1403,7 +1409,7 @@ def inspect_correction(params, gain=None): plane_ff = params.get_flat_field() if plane_ff is None: plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0] - ff = compute_flat_field_correction(params.rois, plane_ff) + ff = compute_flat_field_correction(params.rois, params) # non linearities Fnl = params.get_Fnl() @@ -1497,28 +1503,221 @@ def inspect_correction(params, gain=None): return f -def inspect_Fnl(Fnl): - """Plot the correction function Fnl. +# data processing related functions + +def load_dssc_module(proposalNB, runNB, moduleNB=15, + subset=slice(None), drop_intra_darks=True, persist=False): + """Load single module dssc data as dask array. 
Inputs ------ - Fnl: non linear correction function lookup table + proposalNB: proposal number + runNB: run number + moduleNB: default 15, module number + subset: default slice(None), subset of trains to load + drop_intra_darks: boolean, default True, remove intra darks from the data + persist: default False, load all data persistently in memory Returns ------- - matplotlib figure + arr: dask array of reshaped dssc data (trainId, pulseId, x, y) + tid: array of train id number """ - x = np.arange(2**9) - f = plt.figure(figsize=(6, 4)) + run = open_run(proposal=proposalNB, run=runNB) - plt.plot(x, Fnl - x) - # plt.axvline(40, c='k', ls='--') - # plt.axvline(280, c='k', ls='--') - plt.xlabel('input value') - plt.ylabel('output correction F(x)-x') - plt.xlim([0, 511]) + # DSSC + source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf' + key = 'image.data' - return f + arr = run[source, key][subset].dask_array() + # fix 256 value becoming spuriously 0 instead + arr[arr == 0] = 256 + + ppt = run[source, key][subset].data_counts() + # ignore train with no pulses, can happen in burst mode acquisition + ppt = ppt[ppt > 0] + tid = ppt.index.to_numpy() + + ppt = np.unique(ppt) + assert ppt.shape[0] == 1, "number of pulses changed during the run" + ppt = ppt[0] + + # reshape in trainId, pulseId, 2d-image + arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3]) + + # drop intra darks + if drop_intra_darks: + arr = arr[:, ::2, :, :] + + # load data in memory + if persist: + arr = arr.persist() + + return arr, tid + + +def average_module(arr, dark=None, ret='mean', + mask=None, sat_roi=None, sat_level=300, F_INL=None): + """Compute the average or std over a module. 
def average_module(arr, dark=None, ret='mean',
                   mask=None, sat_roi=None, sat_level=300, F_INL=None):
    """Compute the mean or standard deviation image over a module.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    dark: default None, dark to be subtracted
    ret: string, either 'mean' to compute the mean or 'std' to compute the
        standard deviation
    mask: default None, mask of bad pixels to ignore
    sat_roi: roi over which to check for pixel with values larger than
        sat_level to drop the image from the average or std
    sat_level: int, minimum pixel value for a pixel to be considered saturated
    F_INL: default None, non linear correction function given as a
        lookup table with 9 bits integer input

    Returns
    -------
    average or standard deviation image
    """
    # apply the non linear lookup table correction first
    data = arr if F_INL is None else arr.map_blocks(lambda x: F_INL[x])

    # zero out bad pixels
    if mask is not None:
        data = data * mask

    # weights flagging images without any saturated pixel in sat_roi,
    # evaluated before dark subtraction
    weights = None
    if sat_roi is not None:
        clean = (da.logical_not(da.any(
            data[:, :, sat_roi['yl']:sat_roi['yh'],
                 sat_roi['xl']:sat_roi['xh']] >= sat_level,
            axis=[2, 3], keepdims=True)))
        weights = da.repeat(da.repeat(clean, 128, axis=2), 512, axis=3)

    if dark is not None:
        data = data - dark

    if ret == 'mean':
        if weights is None:
            return data.mean(axis=0)
        return da.average(data, axis=0, weights=weights)
    if ret == 'std':
        return data.std(axis=0)
    raise ValueError(f'ret={ret} not supported')
def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
                   flat_field=None, F_INL=None):
    """Process one module and extract roi intensity.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    dark: pulse resolved dark image to remove
    rois: dictionary of rois
    mask: default None, mask of ignored pixels
    sat_level: integer, default 511, at which level pixels begin to saturate
    flat_field: default None, flat field correction
    F_INL: default None, non linear correction function given as a
        lookup table with 9 bits integer input

    Returns
    -------
    dataset of extracted pulse and train resolved roi intensities.
    """
    # non linear lookup table correction
    data = arr if F_INL is None else arr.map_blocks(lambda x: F_INL[x])

    # zero out ignored pixels
    if mask is not None:
        data = data * mask

    # crop data and dark down to each roi
    cropped = {}
    cropped_dark = {}
    for name, roi in rois.items():
        cropped[name] = data[:, :, roi['yl']:roi['yh'],
                             roi['xl']:roi['xh']]
        cropped_dark[name] = dark[:, roi['yl']:roi['yh'],
                                  roi['xl']:roi['xh']]

    # flag shots with saturated pixels, before dark subtraction
    saturated = {}
    for name in rois.keys():
        saturated[name] = da.any(cropped[name] >= sat_level, axis=(2, 3))

    # TODO: flat field should not be applied on intra darks

    # dark corrected, flat field normalized, roi integrated values
    values = {}
    for name, roi in rois.items():
        corrected = cropped[name] - cropped_dark[name]

        if flat_field is not None:
            ff = flat_field[roi['yl']:roi['yh'],
                            roi['xl']:roi['xh']]
            corrected = corrected/ff

        values[name] = corrected.sum(axis=(2, 3))

    res = xr.Dataset()
    dims = ['trainId', 'pulseId']
    coords = {'trainId': tid, 'pulseId': np.arange(0, data.shape[1])}
    for name in rois.keys():
        res[name + '_sat'] = xr.DataArray(saturated[name][:, :],
                                          coords=coords, dims=dims)
        res[name] = xr.DataArray(values[name], coords=coords, dims=dims)

    # roi pixel areas, useful to normalize intensities
    for name, roi in rois.items():
        res[name + '_area'] = xr.DataArray(np.array([
            (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))

    return res


def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
    """Process dark and run data with corrections.

    Inputs
    ------
    Fmodel: correction lookup table
    arr_dark: dark data
    arr: data
    tid: array of train id number
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat field correction
    sat_level: integer, default 511, at which level pixels begin to saturate

    Returns
    -------
    roi extracted intensities
    """
    # pulse resolved dark image from the dark run
    dark = average_module(arr_dark, F_INL=Fmodel).compute()

    # corrected run data with roi extraction
    data = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
                          flat_field=flat_field, F_INL=Fmodel).compute()

    return data