From c91837dae4694b61a78cab053718951e06ad86be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu> Date: Tue, 7 May 2024 14:34:29 +0200 Subject: [PATCH] Initial S K-edge normalization --- src/toolbox_scs/routines/boz.py | 179 ++++++++++++++++++++++++++++++-- 1 file changed, 173 insertions(+), 6 deletions(-) diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py index 81635de..f8b6018 100644 --- a/src/toolbox_scs/routines/boz.py +++ b/src/toolbox_scs/routines/boz.py @@ -54,11 +54,13 @@ __all__ = [ 'nl_domain', 'nl_lut', 'nl_crit', + 'nl_crit_sk', 'nl_fit', 'inspect_nl_fit', 'snr', 'inspect_Fnl', 'inspect_correction', + 'inspect_correction_sk', 'load_dssc_module', 'average_module', 'process_module', @@ -1345,13 +1347,56 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a -def nl_fit(params, domain): +def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field, + sat_level=511, use_gpu=False): + """Criteria for the non linear correction, combining 'n' and 'p' as reference + + Inputs + ------ + p: vector of dy non linear correction + domain: domain over which the non linear correction is defined + alpha: float, coefficient scaling the cost of the correction function + in the criterion + arr_dark: dark data + arr: data + tid: train id of arr data + rois: ['n', '0', 'p', 'sat'] rois + mask: mask fo good pixels + flat_field: zone plate flat-field correction + sat_level: integer, default 511, at which level pixel begin to saturate + + Returns + ------- + (1.0 - alpha)*err1 + alpha*err2, where err1 is the 1e8 times the mean of + error squared from a transmission of 1.0 and err2 is the sum of the square + of the deviation from the ideal detector response. + """ + Fmodel = nl_lut(domain, p) + data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark, + arr, tid, rois, mask, flat_field, sat_level, use_gpu) + + # drop saturated shots + d = data.where(data['sat_sat'] == False, drop=True) + + d['np_mean'] = 0.5*(d['n'] + d['p']) + + v = snr(d['np_mean'].values.flatten(), d['0'].values.flatten(), + methods=['weighted']) + err = 1e8*v['weighted']['s']**2 + + err_a = np.sum((Fmodel-np.arange(2**9))**2) + + return (1.0 - alpha)*err + alpha*err_a + +def nl_fit(params, domain, ff=None, crit=None): """Fit non linearities correction function. Inputs ------ params: parameters domain: array of index + ff: array, flat field correction + crit: function, criteria function Returns ------- @@ -1374,7 +1419,11 @@ def nl_fit(params, domain): p0 = np.array([0]*N) # flat flat_field - ff = compute_flat_field_correction(params.rois, params) + if ff is None: + ff = compute_flat_field_correction(params.rois, params) + + if crit is None: + crit = nl_crit fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid, fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu) @@ -1390,11 +1439,11 @@ def nl_fit(params, domain): fit_callback.counter += 1 temp = list(fixed_p) - Jalpha = nl_crit(x, *temp) + Jalpha = crit(x, *temp) temp[1] = 0 - J0 = nl_crit(x, *temp) + J0 = crit(x, *temp) temp[1] = 1 - J1 = nl_crit(x, *temp) + J1 = crit(x, *temp) fit_callback.res.append([J0, Jalpha, J1]) print(f'{fit_callback.counter-1}: {time_delta} ' f'({J0}, {Jalpha}, {J1}), {x}') @@ -1402,7 +1451,7 @@ def nl_fit(params, domain): return False fit_callback(p0) - res = minimize(nl_crit, p0, fixed_p, + res = minimize(crit, p0, fixed_p, options={'disp': True, 'maxiter': params.nl_max_iter}, callback=fit_callback) @@ -1644,6 +1693,124 @@ def inspect_correction(params, gain=None): return f +def inspect_correction_sk(params, ff, gain=None): + """Criteria for the non linear correction. + + Inputs + ------ + params: parameters + gain: float, default None, DSSC gain in ph/bin + + Returns + ------- + matplotlib figure + """ + # load data + assert params.arr is not None, "Data not loaded" + assert params.arr_dark is not None, "Data not loaded" + + # we only need few rois + fitrois = {} + for k in ['n', '0', 'p', 'sat']: + fitrois[k] = params.rois[k] + + # flat flat_field + #plane_ff = params.get_flat_field() + #if plane_ff is None: + # plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0] + #ff = compute_flat_field_correction(params.rois, params) + + # non linearities + Fnl = params.get_Fnl() + if Fnl is None: + Fnl = np.arange(2**9) + + xp = np if not params._using_gpu else cp + # compute all levels of correction + data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level, + params._using_gpu) + data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu) + data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid, + fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu) + + # for conversion to nb of photons + if gain is None: + g = 1 + else: + g = gain + + scale = 1e-6 + + f, axs = plt.subplots(1, 3, figsize=(8, 2), sharex=True) + + # nbins = np.linspace(0.01, 1.0, 100) + + photon_scale = None + + for k, d in enumerate([data, data_ff, data_ff_nl]): + if photon_scale is None: + lower = 0 + upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9) + photon_scale = np.linspace(lower, upper, 150) + + good_d = d.where(d['sat_sat'] == False, drop=True) + sat_d = d.where(d['sat_sat'], drop=True) + + good_d['np_mean'] = 0.5*(good_d['n']+good_d['p']) + sat_d['np_mean'] = 0.5*(sat_d['n']+sat_d['p']) + + snr_v = snr(good_d['np_mean'].values.flatten(), + good_d['0'].values.flatten(), verbose=True) + + m = snr_v['direct']['mu'] + h, xedges, yedges, img = axs[k].hist2d( + g*scale*good_d['0'].values.flatten(), + good_d['np_mean'].values.flatten()/good_d['0'].values.flatten(), + [photon_scale, np.linspace(0.95, 1.05, 150)*m], + cmap='Blues', + norm=LogNorm(vmin=0.2, vmax=200), + # alpha=0.5 # make the plot looks ugly with lots of white lines + ) + h, xedges, yedges, img2 = axs[k].hist2d( + g*scale*sat_d['0'].values.flatten(), + sat_d['np_mean'].values.flatten()/sat_d['0'].values.flatten(), + [photon_scale, np.linspace(0.95, 1.05, 150)*m], + cmap='Reds', + norm=LogNorm(vmin=0.2, vmax=200), + # alpha=0.5 # make the plot looks ugly with lots of white lines + ) + + v = snr_v['direct']['mu']/snr_v['direct']['s'] + axs[k].text(0.4, 0.15, f'SNR: {v:.0f}', + transform = axs[k].transAxes) + v = snr_v['weighted']['mu']/snr_v['weighted']['s'] + axs[k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}', + transform = axs[k].transAxes) + + # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--') + + axs[k].set_ylim([0.95*m, 1.05*m]) + + for k in range(3): + #for l in range(3): + # axs[l, k].set_ylim([0.95, 1.05]) + if gain: + axs[k].set_xlabel('photons (10$^6$)') + else: + axs[k].set_xlabel('ADU (10$^6$)') + + f.colorbar(img, ax=axs, label='events') + + axs[0].set_title('raw') + axs[1].set_title('flat-field') + axs[2].set_title('non-linear') + + axs[0].set_ylabel(r'np_mean/0') + + return f # data processing related functions -- GitLab