From c91837dae4694b61a78cab053718951e06ad86be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Tue, 7 May 2024 14:34:29 +0200
Subject: [PATCH] Initial S K-edge normalization

---
 src/toolbox_scs/routines/boz.py | 179 ++++++++++++++++++++++++++++++--
 1 file changed, 173 insertions(+), 6 deletions(-)

diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 81635de..f8b6018 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -54,11 +54,13 @@ __all__ = [
     'nl_domain',
     'nl_lut',
     'nl_crit',
+    'nl_crit_sk',
     'nl_fit',
     'inspect_nl_fit',
     'snr',
     'inspect_Fnl',
     'inspect_correction',
+    'inspect_correction_sk',
     'load_dssc_module',
     'average_module',
     'process_module',
@@ -1345,13 +1347,56 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
     return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a
 
 
-def nl_fit(params, domain):
+def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
+    sat_level=511, use_gpu=False):
+    """Criteria for the non linear correction, combining 'n' and 'p' as reference
+
+    Inputs
+    ------
+    p: vector of dy non linear correction
+    domain: domain over which the non linear correction is defined
+    alpha: float, coefficient scaling the cost of the correction function
+        in the criterion
+    arr_dark: dark data
+    arr: data
+    tid: train id of arr data
+    rois: ['n', '0', 'p', 'sat'] rois
+    mask: mask of good pixels
+    flat_field: zone plate flat-field correction
+    sat_level: integer, default 511, at which level pixels begin to saturate
+
+    Returns
+    -------
+    (1.0 - alpha)*err + alpha*err_a, where err is 1e8 times the square of the
+    weighted noise estimate of the 'np_mean'/'0' transmission and err_a is the
+    sum of the squares of the deviation from the ideal detector response.
+    """
+    Fmodel = nl_lut(domain, p)
+    data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark,
+        arr, tid, rois, mask, flat_field, sat_level, use_gpu)
+
+    # drop saturated shots
+    d = data.where(data['sat_sat'] == False, drop=True)
+
+    d['np_mean'] = 0.5*(d['n'] + d['p'])
+
+    v = snr(d['np_mean'].values.flatten(), d['0'].values.flatten(),
+            methods=['weighted'])
+    err = 1e8*v['weighted']['s']**2
+
+    err_a = np.sum((Fmodel-np.arange(2**9))**2)
+
+    return (1.0 - alpha)*err + alpha*err_a
+
+def nl_fit(params, domain, ff=None, crit=None):
     """Fit non linearities correction function.
 
     Inputs
     ------
     params: parameters
     domain: array of index
+    ff: array, flat field correction
+    crit: function, criteria function
 
     Returns
     -------
@@ -1374,7 +1419,11 @@ def nl_fit(params, domain):
     p0 = np.array([0]*N)
 
     # flat flat_field
-    ff = compute_flat_field_correction(params.rois, params)
+    if ff is None:
+        ff = compute_flat_field_correction(params.rois, params)
+
+    if crit is None:
+        crit = nl_crit
 
     fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
         fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
@@ -1390,11 +1439,11 @@ def nl_fit(params, domain):
         fit_callback.counter += 1
 
         temp = list(fixed_p)
-        Jalpha = nl_crit(x, *temp)
+        Jalpha = crit(x, *temp)
         temp[1] = 0
-        J0 = nl_crit(x, *temp)
+        J0 = crit(x, *temp)
         temp[1] = 1
-        J1 = nl_crit(x, *temp)
+        J1 = crit(x, *temp)
         fit_callback.res.append([J0, Jalpha, J1])
         print(f'{fit_callback.counter-1}: {time_delta} '
                 f'({J0}, {Jalpha}, {J1}), {x}')
@@ -1402,7 +1451,7 @@ def nl_fit(params, domain):
         return False
 
     fit_callback(p0)
-    res = minimize(nl_crit, p0, fixed_p,
+    res = minimize(crit, p0, fixed_p,
         options={'disp': True, 'maxiter': params.nl_max_iter},
         callback=fit_callback)
 
@@ -1644,6 +1693,124 @@ def inspect_correction(params, gain=None):
 
     return f
 
+def inspect_correction_sk(params, ff, gain=None):
+    """Criteria for the non linear correction.
+
+    Inputs
+    ------
+    params: parameters; ff: array, flat-field correction
+    gain: float, default None, DSSC gain in ph/bin
+
+    Returns
+    -------
+    matplotlib figure
+    """
+    # load data
+    assert params.arr is not None, "Data not loaded"
+    assert params.arr_dark is not None, "Data not loaded"
+
+    # we only need few rois
+    fitrois = {}
+    for k in ['n', '0', 'p', 'sat']:
+        fitrois[k] = params.rois[k]
+
+    # flat-field correction is passed in as the ff argument
+    #plane_ff = params.get_flat_field()
+    #if plane_ff is None:
+    #    plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0]
+    #ff = compute_flat_field_correction(params.rois, params)
+
+    # non linearities
+    Fnl = params.get_Fnl()
+    if Fnl is None:
+        Fnl = np.arange(2**9)
+
+    xp = np if not params._using_gpu else cp
+    # compute all levels of correction
+    data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level,
+        params._using_gpu)
+    data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+    data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+
+    # for conversion to nb of photons
+    if gain is None:
+        g = 1
+    else:
+        g = gain
+
+    scale = 1e-6
+
+    f, axs = plt.subplots(1, 3, figsize=(8, 2), sharex=True)
+
+    # nbins = np.linspace(0.01, 1.0, 100)
+
+    photon_scale = None
+
+    for k, d in enumerate([data, data_ff, data_ff_nl]):
+        if photon_scale is None:
+            lower = 0
+            upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
+            photon_scale = np.linspace(lower, upper, 150)
+
+        good_d = d.where(d['sat_sat'] == False, drop=True)
+        sat_d = d.where(d['sat_sat'], drop=True)
+
+        good_d['np_mean'] = 0.5*(good_d['n']+good_d['p'])
+        sat_d['np_mean'] = 0.5*(sat_d['n']+sat_d['p'])
+
+        snr_v = snr(good_d['np_mean'].values.flatten(),
+                    good_d['0'].values.flatten(), verbose=True)
+
+        m = snr_v['direct']['mu']
+        h, xedges, yedges, img = axs[k].hist2d(
+            g*scale*good_d['0'].values.flatten(),
+            good_d['np_mean'].values.flatten()/good_d['0'].values.flatten(),
+            [photon_scale, np.linspace(0.95, 1.05, 150)*m],
+            cmap='Blues',
+            norm=LogNorm(vmin=0.2, vmax=200),
+            # alpha=0.5 # makes the plot look ugly with lots of white lines
+            )
+        h, xedges, yedges, img2 = axs[k].hist2d(
+            g*scale*sat_d['0'].values.flatten(),
+            sat_d['np_mean'].values.flatten()/sat_d['0'].values.flatten(),
+            [photon_scale, np.linspace(0.95, 1.05, 150)*m],
+            cmap='Reds',
+            norm=LogNorm(vmin=0.2, vmax=200),
+            # alpha=0.5 # makes the plot look ugly with lots of white lines
+            )
+
+        v = snr_v['direct']['mu']/snr_v['direct']['s']
+        axs[k].text(0.4, 0.15, f'SNR: {v:.0f}',
+                        transform = axs[k].transAxes)
+        v = snr_v['weighted']['mu']/snr_v['weighted']['s']
+        axs[k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
+                        transform = axs[k].transAxes)
+
+        # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+        # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+
+        axs[k].set_ylim([0.95*m, 1.05*m])
+
+    for k in range(3):
+        #for l in range(3):
+        #    axs[l, k].set_ylim([0.95, 1.05])
+        if gain:
+            axs[k].set_xlabel('photons (10$^6$)')
+        else:
+            axs[k].set_xlabel('ADU (10$^6$)')
+
+    f.colorbar(img, ax=axs, label='events')
+
+    axs[0].set_title('raw')
+    axs[1].set_title('flat-field')
+    axs[2].set_title('non-linear')
+
+    axs[0].set_ylabel(r'np_mean/0')
+
+    return f
 
 # data processing related functions
 
-- 
GitLab