diff --git a/doc/changelog.rst b/doc/changelog.rst index 13b8746efd18d822b71cc4c2862b5a7f2c3d450a..f6f685a29dbf62de735c8cd24e4af59c8be1d411 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -23,6 +23,7 @@ unreleased - **New Features** - fix :issue:`75` add frame counts when aggregating RIXS data :mr:`274` + - BOZ normalization using 2D flat field polylines for S K-edge :mr:`293` 1.7.0 ----- diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py index 81635de62c19209b220316e80b6fe7cf1599e3de..d280b309a12a37ab1d2bb089d674ec6f9e7acf4e 100644 --- a/src/toolbox_scs/routines/boz.py +++ b/src/toolbox_scs/routines/boz.py @@ -1,12 +1,13 @@ """ Beam splitting Off-axis Zone plate analysis routines. -Copyright (2021) SCS Team. +Copyright (2021, 2022, 2023, 2024) SCS Team. """ import time import datetime import json +import warnings import numpy as np import xarray as xr @@ -54,11 +55,13 @@ __all__ = [ 'nl_domain', 'nl_lut', 'nl_crit', + 'nl_crit_sk', 'nl_fit', 'inspect_nl_fit', 'snr', 'inspect_Fnl', 'inspect_correction', + 'inspect_correction_sk', 'load_dssc_module', 'average_module', 'process_module', @@ -99,6 +102,8 @@ class parameters(): self.std_th = (None, None) self.rois = None self.rois_th = None + + self.ff_type = 'plane' self.flat_field = None self.flat_field_prod_th = (5.0, np.PINF) self.flat_field_ratio_th = (np.NINF, 1.2) @@ -121,12 +126,14 @@ class parameters(): self.arr = None self.tid = None - def dask_load_persistently(self, dark_data_size_Gb=None, data_size_Gb=None): + def dask_load_persistently(self, dark_data_size_Gb=None, + data_size_Gb=None): """Load dask data array in memory. Inputs ------ - dark_data_size_Gb: float, optional size of dark to load in memory, in Gb + dark_data_size_Gb: float, optional size of dark to load in memory, + in Gb data_size_Gb: float, optional size of data to load in memory, in Gb """ self.arr_dark, self.tid_dark = load_dssc_module(self.proposal, @@ -216,13 +223,22 @@ class parameters(): return self.plane_guess_fit - def set_flat_field(self, plane, + def set_flat_field(self, ff_params, ff_type='plane', prod_th=None, ratio_th=None): - """Set the flat-field plane definition.""" - if type(plane) is not list: - self.flat_field = plane.tolist() + """Set the flat-field plane definition. + + Inputs + ------ + ff_params: list of parameters + ff_type: string identifying the type of flat field normalization, + default is 'plane'. + """ + self.ff_type = ff_type + if type(ff_params) is not list: + self.flat_field = ff_params.tolist() else: - self.flat_field = plane + self.flat_field = ff_params + if prod_th is not None: self.flat_field_prod_th = prod_th if ratio_th is not None: @@ -274,6 +290,7 @@ class parameters(): v['rois'] = self.rois v['rois_th'] = self.rois_th + v['ff_type'] = self.ff_type v['flat_field'] = self.flat_field v['flat_field_prod_th'] = self.flat_field_prod_th v['flat_field_ratio_th'] = self.flat_field_ratio_th @@ -314,7 +331,10 @@ class parameters(): c.rois = v['rois'] c.rois_th = v['rois_th'] - c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th']) + if 'ff_type' not in v: + v['ff_type'] = 'plane' + c.set_flat_field(v['flat_field'], v['ff_type'], + v['flat_field_prod_th'], v['flat_field_ratio_th']) c.plane_guess_fit = v['plane_guess_fit'] c.use_hex = v['use_hex'] c.force_mirror = v['force_mirror'] @@ -342,7 +362,10 @@ class parameters(): f += f'rois threshold: {self.rois_th}\n' f += f'rois: {self.rois}\n' - f += f'flat-field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n' + f += f'flat-field type: {self.ff_type}\n' + f += f'flat-field p: {self.flat_field} ' + f += f'prod:{self.flat_field_prod_th} ' + f += f'ratio:{self.flat_field_ratio_th}\n' f += f'plane guess fit: {self.plane_guess_fit}\n' f += f'use hexagons: {self.use_hex}\n' f += f'enforce mirror symmetry: {self.force_mirror}\n' @@ -664,7 +687,8 @@ def find_rois(data_mean, threshold, extended=False): # along X lowX = int(np.argmax(pX > threshold) - 1) # 1st occurrence returned - highX = int(pX.shape[0] - np.argmax(pX[::-1] > threshold)) # last occ. returned + highX = int(pX.shape[0] - + np.argmax(pX[::-1] > threshold)) # last occ. returned midX = int(0.5*(lowX+highX)) @@ -676,7 +700,8 @@ def find_rois(data_mean, threshold, extended=False): # along Y lowY = int(np.argmax(pY > threshold) - 1) # 1st occurrence returned - highY = int(pY.shape[0] - np.argmax(pY[::-1] > threshold)) # last occ. returned + highY = int(pY.shape[0] + - np.argmax(pY[::-1] > threshold)) # last occ. returned # define rois rois = {} @@ -854,6 +879,14 @@ def _plane_flat_field(p, roi, params): def compute_flat_field_correction(rois, params, plot=False): + if params.ff_type == 'plane': + return compute_plane_flat_field_correction(rois, params, plot) + elif params.ff_type == 'polyline': + return compute_polyline_flat_field_correction(rois, params, plot) + else: + raise ValueError(f'Uknown flat field type {params.ff_type}') + +def compute_plane_flat_field_correction(rois, params, plot=False): """Compute the plane-field correction on beam rois. Inputs @@ -896,8 +929,123 @@ def compute_flat_field_correction(rois, params, plot=False): return flat_field +def initialize_polyline_ff_correction(avg, rois, params, plot=False): + """Initialize the polyline flat field correction. + + Inputs + ------ + avg: 2D array, average module image + rois: dictionnary of ROIs. + plot: boolean, plot initialized polyline versus data projection + + Returns + ------- + fig: handle to figure or None + """ + refn = avg[rois['n']['yl']:rois['n']['yh'], + rois['n']['xl']:rois['n']['xh']] + refp = avg[rois['p']['yl']:rois['p']['yh'], + rois['p']['xl']:rois['p']['xh']] + mid = avg[rois['0']['yl']:rois['0']['yh'], + rois['0']['xl']:rois['0']['xh']] + + mref = 0.5*(refn + refp) + + inv_signal = mref/mid # normalization + H_projection = inv_signal[:, :].mean(axis=0) + x = np.arange(0, len(H_projection)) + H_z = np.polyfit(x, H_projection, 6) + H_p = np.poly1d(H_z) + + V_projection = (inv_signal/H_p(x))[:, :].mean(axis=1) + y = np.arange(0, len(V_projection)) + V_z = np.polyfit(y, V_projection, 6) + + if plot: + fig, axs = plt.subplots(2, 1, figsize=(4,6)) + axs[0].plot(x, H_projection, label='data (n+p)/2x0') + axs[0].plot(x, H_p(x), label='poly') + axs[0].legend() + axs[0].set_xlabel('x (px)') + axs[0].set_ylabel('H projection') + + axs[1].plot(y, V_projection, label='data (n+p)/2x0') + V_p = np.poly1d(V_z) + axs[1].plot(y, V_p(y), label='poly') + axs[1].legend() + axs[1].set_xlabel('y (px)') + axs[1].set_ylabel('V projection') + else: + fig = None + + # scaling on polynom coefficients for better fitting + ff = np.array([H_z/np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0]), + V_z/np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])]) + + params.set_flat_field(ff.flatten()) + params.ff_type = 'polyline' + + return fig -def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None): +def compute_polyline_flat_field_correction(rois, params, plot=False): + """Compute the 1D polyline field correction on beam rois. + + Inputs + ------ + rois: dictionnary of beam rois['n', '0', 'p'] + params: parameters + plot: boolean, True by default, diagnostic plot + + Returns + ------- + numpy 2D array of the flat-field correction evaluated over one DSSC ladder + (2 sensors) + """ + flat_field = np.ones((128, 512)) + + z = np.array(params.get_flat_field()).reshape((2, -1)) + H_z = z[0, :] + V_z = z[1, :] + + coeffs = np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0]) + H_p = np.poly1d(H_z*coeffs) + coeffs = np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0]) + V_p = np.poly1d(V_z*coeffs) + + n = rois['n'] + p = rois['p'] + wn = n['xh']-n['xl'] + wp = p['xh']-p['xl'] + assert wn == wp, (\ + f"For polyline flat field normalization, both 'n' and 'p' ROIs " + f"must have the same width {wn} and {wp}px" + ) + x = np.arange(wn) + wn = n['yh']-n['yl'] + y = np.arange(wn) + norm = V_p(y)[:, np.newaxis]*H_p(x) + + n_int = flat_field[n['yl']:n['yh'], n['xl']:n['xh']] + flat_field[n['yl']:n['yh'], n['xl']:n['xh']] = \ + norm*n_int + + p_int = flat_field[p['yl']:p['yh'], p['xl']:p['xh']] + flat_field[p['yl']:p['yh'], p['xl']:p['xh']] = \ + norm*p_int # not the mirror + + if plot: + f, ax = plt.subplots(1, 1, figsize=(6, 2)) + img = ax.pcolormesh( + np.flipud(flat_field[:, :256]), cmap='Greys_r') + f.colorbar(img, ax=[ax], label='amplitude') + ax.set_xlabel('px') + ax.set_ylabel('px') + ax.set_aspect('equal') + + return flat_field + +def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, + vmax=None): """Extract beams roi from average image and compute the ratio. Inputs @@ -954,7 +1102,8 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None if ratio_vmin is None: ratio_vmin = np.percentile(v, 5) ratio_vmax = np.percentile(v, 99.8) - im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r') + im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, + cmap='RdBu_r') axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2)) cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal") @@ -972,8 +1121,11 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None return fig, domain - def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None): + warnings.warn("This method is depreciated, use inspect_ff_fitting instead") + return inspect_ff_fitting(avg, rois, domain, vmin, vmax) + +def inspect_ff_fitting(avg, rois, domain=None, vmin=None, vmax=None): """Extract beams roi from average image and compute the ratio. Inputs @@ -1032,6 +1184,89 @@ def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None): return fig +def inspect_ff_fitting_sk(avg, rois, ff, domain=None, vmin=None, vmax=None): + """Extract beams roi from average image and compute the ratio. + + Inputs + ------ + avg: module average image with no saturated shots for the flat-field + determination + rois: dictionnary of rois + ff: 2D array, flat field normalization + domain: list of domain mask for the -1st and +1st order + vmin: imshow vmin level, default None will use 5 percentile value + vmax: imshow vmax level, default None will use 99.8 percentile value + + Returns + ------- + fig: matplotlib figure plotted + """ + if vmin is None: + vmin = np.percentile(avg, 5) + if vmax is None: + vmax = np.percentile(avg, 99.8) + + refn = avg[rois['n']['yl']:rois['n']['yh'], + rois['n']['xl']:rois['n']['xh']] + refp = avg[rois['p']['yl']:rois['p']['yh'], + rois['p']['xl']:rois['p']['xh']] + mid = avg[rois['0']['yl']:rois['0']['yh'], + rois['0']['xl']:rois['0']['xh']] + mref = 0.5*(refn + refp) + + ffn = ff[rois['n']['yl']:rois['n']['yh'], + rois['n']['xl']:rois['n']['xh']] + ffp = ff[rois['p']['yl']:rois['p']['yh'], + rois['p']['xl']:rois['p']['xh']] + ffmid = ff[rois['0']['yl']:rois['0']['yh'], + rois['0']['xl']:rois['0']['xh']] + np_norm = 0.5*(ffn+ffp) + mid_norm = ffmid + + fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, + figsize=(8, 4)) + im = axs[0, 0].imshow(mref) + axs[0, 0].set_title('(n+p)/2') + fig.colorbar(im, ax=axs[0, 0]) + + im = axs[1, 0].imshow(mid) + axs[1, 0].set_title('0') + fig.colorbar(im, ax=axs[1, 0]) + + im = axs[2, 0].imshow(mid/mref-1, cmap='RdBu_r', vmin=-1, vmax=1) + axs[2, 0].set_title('2x0/(n+p) - 1') + fig.colorbar(im, ax=axs[2, 0]) + + im = axs[0, 1].imshow(np_norm) + axs[0, 1].set_title('norm: (n+p)/2') + fig.colorbar(im, ax=axs[0, 1]) + + im = axs[1, 1].imshow(mid_norm) + axs[1, 1].set_title('norm: 0') + fig.colorbar(im, ax=axs[1, 1]) + + im = axs[2, 1].imshow(mid_norm/np_norm-1, cmap='RdBu_r', vmin=-1, vmax=1) + axs[2, 1].set_title('norm: 2x0/(n+p) - 1') + fig.colorbar(im, ax=axs[2, 1]) + + + im = axs[0, 2].imshow(mref/np_norm) + axs[0, 2].set_title('(n+p)/2 /norm') + fig.colorbar(im, ax=axs[0, 2]) + + im = axs[1, 2].imshow(mid/mid_norm) + axs[1, 2].set_title('0 /norm') + fig.colorbar(im, ax=axs[1, 2]) + + im = axs[2, 2].imshow((mid/mid_norm)/(mref/np_norm)-1, + cmap='RdBu_r', vmin=-1, vmax=1) + axs[2, 2].set_title('2x0/(n+p) - 1 /norm') + fig.colorbar(im, ax=axs[2, 2]) + + # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}') + + return fig + def plane_fitting_domain(avg, rois, prod_th, ratio_th): """Extract beams roi, compute their ratio and the domain. @@ -1192,8 +1427,47 @@ def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois, return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean) +def ff_refine_crit_sk(p, alpha, params, arr_dark, arr, tid, rois, + mask, sat_level=511): + """Criteria for the ff_refine_fit, combining 'n' and 'p' as reference. + + Inputs + ------ + p: ff plane + params: parameters + arr_dark: dark data + arr: data + tid: train id of arr data + rois: ['n', '0', 'p', 'sat'] rois + mask: mask fo good pixels + sat_level: integer, default 511, at which level pixel begin to saturate + + Returns + ------- + sum of standard deviation on binned 0th order intensity + """ + params.set_flat_field(p, params.ff_type) + ff = compute_flat_field_correction(rois, params) + if np.any(ff < 0.0): + bad = 1e6 + else: + bad = 0.0 + + data = process(None, arr_dark, arr, tid, rois, mask, ff, + sat_level, params._using_gpu) + + # drop saturated shots + d = data.where(data['sat_sat'] == False, drop=True) + + r = xas(d, 40, Iokey='np_mean', Itkey='0', nrjkey='0', fluorescence=True) -def ff_refine_fit(params): + err_sigma = np.nansum(r['sigmaA']) + err_mean = (1.0 - np.nanmean(r['muA']))**2 + + return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean) + + +def ff_refine_fit(params, crit=ff_refine_crit): """Refine the flat-field fit by minimizing data spread. Inputs @@ -1234,19 +1508,19 @@ def ff_refine_fit(params): fit_callback.counter += 1 temp = list(fixed_p) - Jalpha = ff_refine_crit(x, *temp) + Jalpha = crit(x, *temp) temp[0] = 0 - J0 = ff_refine_crit(x, *temp) + J0 = crit(x, *temp) temp[0] = 1 - J1 = ff_refine_crit(x, *temp) + J1 = crit(x, *temp) fit_callback.res.append([J0, Jalpha, J1]) print(f'{fit_callback.counter-1}: {time_delta} ' - f'({J0}, {Jalpha}, {J1}), {x}') + f'(reg. term: {J0}, {Jalpha}, err. term: {J1}), {x}') return False fit_callback(p0) - res = minimize(ff_refine_crit, p0, fixed_p, + res = minimize(crit, p0, fixed_p, options={'disp': True, 'maxiter': params.ff_max_iter}, callback=fit_callback) @@ -1345,13 +1619,54 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a -def nl_fit(params, domain): +def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, + sat_level=511, use_gpu=False): + """Non linear correction criteria, combining 'n' and 'p' as reference. + + Inputs + ------ + p: vector of dy non linear correction + domain: domain over which the non linear correction is defined + alpha: float, coefficient scaling the cost of the correction function + in the criterion + arr_dark: dark data + arr: data + tid: train id of arr data + rois: ['n', '0', 'p', 'sat'] rois + mask: mask fo good pixels + flat_field: zone plate flat-field correction + sat_level: integer, default 511, at which level pixel begin to saturate + + Returns + ------- + (1.0 - alpha)*err1 + alpha*err2, where err1 is the 1e8 times the mean of + error squared from a transmission of 1.0 and err2 is the sum of the square + of the deviation from the ideal detector response. + """ + Fmodel = nl_lut(domain, p) + data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark, + arr, tid, rois, mask, flat_field, sat_level, use_gpu) + + # drop saturated shots + d = data.where(data['sat_sat'] == False, drop=True) + + v = snr(d['np_mean'].values.flatten(), d['0'].values.flatten(), + methods=['weighted']) + err = 1e8*v['weighted']['s']**2 + + err_a = np.sum((Fmodel-np.arange(2**9))**2) + + return (1.0 - alpha)*err + alpha*err_a + +def nl_fit(params, domain, ff=None, crit=None): """Fit non linearities correction function. Inputs ------ params: parameters domain: array of index + ff: array, flat field correction + crit: function, criteria function Returns ------- @@ -1374,10 +1689,15 @@ def nl_fit(params, domain): p0 = np.array([0]*N) # flat flat_field - ff = compute_flat_field_correction(params.rois, params) + if ff is None: + ff = compute_flat_field_correction(params.rois, params) - fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid, - fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu) + if crit is None: + crit = nl_crit + + fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, + params.tid, fitrois, params.get_mask(), ff, params.sat_level, + params._using_gpu) def fit_callback(x): if not hasattr(fit_callback, "counter"): @@ -1390,11 +1710,11 @@ def nl_fit(params, domain): fit_callback.counter += 1 temp = list(fixed_p) - Jalpha = nl_crit(x, *temp) + Jalpha = crit(x, *temp) temp[1] = 0 - J0 = nl_crit(x, *temp) + J0 = crit(x, *temp) temp[1] = 1 - J1 = nl_crit(x, *temp) + J1 = crit(x, *temp) fit_callback.res.append([J0, Jalpha, J1]) print(f'{fit_callback.counter-1}: {time_delta} ' f'({J0}, {Jalpha}, {J1}), {x}') @@ -1402,7 +1722,7 @@ def nl_fit(params, domain): return False fit_callback(p0) - res = minimize(nl_crit, p0, fixed_p, + res = minimize(crit, p0, fixed_p, options={'disp': True, 'maxiter': params.nl_max_iter}, callback=fit_callback) @@ -1436,7 +1756,7 @@ def inspect_nl_fit(res_fit): def snr(sig, ref, methods=None, verbose=False): - """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref. + """ Compute mean, std and SNR from transmitted and I0 signals. Inputs ------ @@ -1525,7 +1845,7 @@ def inspect_Fnl(Fnl): def inspect_correction(params, gain=None): - """Criteria for the non linear correction. + """Comparison plot of the different corrections. Inputs ------ @@ -1601,7 +1921,7 @@ def inspect_correction(params, gain=None): [photon_scale, np.linspace(0.95, 1.05, 150)*m], cmap='Blues', norm=LogNorm(vmin=0.2, vmax=200), - # alpha=0.5 # make the plot looks ugly with lots of white lines + # alpha=0.5 # make the plot looks ugly with lots of white lines ) h, xedges, yedges, img2 = axs[l, k].hist2d( g*scale*sat_d[r].values.flatten(), @@ -1609,7 +1929,7 @@ def inspect_correction(params, gain=None): [photon_scale, np.linspace(0.95, 1.05, 150)*m], cmap='Reds', norm=LogNorm(vmin=0.2, vmax=200), - # alpha=0.5 # make the plot looks ugly with lots of white lines + # alpha=0.5 # make the plot looks ugly with lots of white lines ) v = snr_v['direct']['mu']/snr_v['direct']['s'] @@ -1619,8 +1939,8 @@ def inspect_correction(params, gain=None): axs[l, k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}', transform = axs[l, k].transAxes) - # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--') - # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + #axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + #axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--') axs[l, k].set_ylim([0.95*m, 1.05*m]) @@ -1644,6 +1964,124 @@ def inspect_correction(params, gain=None): return f +def inspect_correction_sk(params, ff, gain=None): + """Comparison plot of the different corrections, combining 'n' and 'p'. + + Inputs + ------ + params: parameters + gain: float, default None, DSSC gain in ph/bin + + Returns + ------- + matplotlib figure + """ + # load data + assert params.arr is not None, "Data not loaded" + assert params.arr_dark is not None, "Data not loaded" + + # we only need few rois + fitrois = {} + for k in ['n', '0', 'p', 'sat']: + fitrois[k] = params.rois[k] + + # flat flat_field + #plane_ff = params.get_flat_field() + #if plane_ff is None: + # plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0] + #ff = compute_flat_field_correction(params.rois, params) + + # non linearities + Fnl = params.get_Fnl() + if Fnl is None: + Fnl = np.arange(2**9) + + xp = np if not params._using_gpu else cp + # compute all levels of correction + data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level, + params._using_gpu) + data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu) + data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu) + + # for conversion to nb of photons + if gain is None: + g = 1 + else: + g = gain + + scale = 1e-6 + + f, axs = plt.subplots(1, 3, figsize=(8, 2), sharex=True) + + # nbins = np.linspace(0.01, 1.0, 100) + + photon_scale = None + + for k, d in enumerate([data, data_ff, data_ff_nl]): + if photon_scale is None: + lower = 0 + upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9) + photon_scale = np.linspace(lower, upper, 150) + + good_d = d.where(d['sat_sat'] == False, drop=True) + sat_d = d.where(d['sat_sat'], drop=True) + + good_d['np_mean'] = 0.5*(good_d['n']+good_d['p']) + sat_d['np_mean'] = 0.5*(sat_d['n']+sat_d['p']) + + snr_v = snr(good_d['np_mean'].values.flatten(), + good_d['0'].values.flatten(), verbose=True) + + m = snr_v['direct']['mu'] + h, xedges, yedges, img = axs[k].hist2d( + g*scale*good_d['0'].values.flatten(), + good_d['np_mean'].values.flatten()/good_d['0'].values.flatten(), + [photon_scale, np.linspace(0.95, 1.05, 150)*m], + cmap='Blues', + norm=LogNorm(vmin=0.2, vmax=200), + # alpha=0.5 # make the plot looks ugly with lots of white lines + ) + h, xedges, yedges, img2 = axs[k].hist2d( + g*scale*sat_d['0'].values.flatten(), + sat_d['np_mean'].values.flatten()/sat_d['0'].values.flatten(), + [photon_scale, np.linspace(0.95, 1.05, 150)*m], + cmap='Reds', + norm=LogNorm(vmin=0.2, vmax=200), + # alpha=0.5 # make the plot looks ugly with lots of white lines + ) + + v = snr_v['direct']['mu']/snr_v['direct']['s'] + axs[k].text(0.4, 0.15, f'SNR: {v:.0f}', + transform = axs[k].transAxes) + v = snr_v['weighted']['mu']/snr_v['weighted']['s'] + axs[k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}', + transform = axs[k].transAxes) + + # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + + axs[k].set_ylim([0.95*m, 1.05*m]) + + for k in range(3): + #for l in range(3): + # axs[l, k].set_ylim([0.95, 1.05]) + if gain: + axs[k].set_xlabel('photons (10$^6$)') + else: + axs[k].set_xlabel('ADU (10$^6$)') + + f.colorbar(img, ax=axs, label='events') + + axs[0].set_title('raw') + axs[1].set_title('flat-field') + axs[2].set_title('non-linear') + + axs[0].set_ylabel(r'np_mean/0') + + return f # data processing related functions @@ -1829,6 +2267,8 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511, # compute dark corrected ROI values v = {} + r_ff = {} + ff = {} for n in rois.keys(): r[n] = r[n] - rd[n] @@ -1837,25 +2277,36 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511, # TODO: flat-field should not be applied on intra darks # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'], # rois[n]['xl']:rois[n]['xh']] - ff = flat_field[rois[n]['yl']:rois[n]['yh'], - rois[n]['xl']:rois[n]['xh']] - r[n] = r[n]/ff + ff[n] = flat_field[rois[n]['yl']:rois[n]['yh'], + rois[n]['xl']:rois[n]['xh']] + r_ff[n] = r[n]/ff[n] + else: + ff[n] = 1.0 + r_ff[n] = r[n] - v[n] = r[n].sum(axis=(2, 3)) + v[n] = r_ff[n].sum(axis=(2, 3)) + + # np_mean roi where we normalize the sum of flat_field + np_mean = (r['n'] + r['p'])/(ff['n'] + ff['p']) + v['np_mean'] = np_mean.sum(axis=(2,3)) res = xr.Dataset() dims = ['trainId', 'pulseId'] r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])} for n in rois.keys(): + res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims) res[n + '_sat'] = xr.DataArray(ensure_on_host(r_sat[n][:, :]), coords=r_coords, dims=dims) - res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims) + res['np_mean'] = xr.DataArray(ensure_on_host(v['np_mean']), + coords=r_coords, dims=dims) + res['np_mean_sat'] = res['n_sat'] + res['p_sat'] for n in rois.keys(): roi = rois[n] res[n + '_area'] = xr.DataArray(np.array([ (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])])) + res['np_mean_area'] = res['n_area'] + res['p_area'] return res @@ -1939,8 +2390,9 @@ def inspect_saturation(data, gain, Nbins=200): # compute density normalization on all data norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat'])) - ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm, h[k+'_nosat']/norm, - facecolor=f"C{kk}", edgecolor='none', alpha=0.2) + ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm, + h[k+'_nosat']/norm, facecolor=f"C{kk}", + edgecolor='none', alpha=0.2) ax.plot(bins_c, h[k+'_nosat']/norm, label=k, c=f'C{kk}', alpha=0.4)