diff --git a/doc/changelog.rst b/doc/changelog.rst
index 13b8746efd18d822b71cc4c2862b5a7f2c3d450a..f6f685a29dbf62de735c8cd24e4af59c8be1d411 100644
--- a/doc/changelog.rst
+++ b/doc/changelog.rst
@@ -23,6 +23,7 @@ unreleased
 - **New Features**
 
     - fix :issue:`75` add frame counts when aggregating RIXS data :mr:`274`
+    - BOZ normalization using 2D flat field polylines for S K-edge :mr:`293`
 
 1.7.0
 -----
diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 81635de62c19209b220316e80b6fe7cf1599e3de..d280b309a12a37ab1d2bb089d674ec6f9e7acf4e 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -1,12 +1,13 @@
 """
 Beam splitting Off-axis Zone plate analysis routines.
 
-Copyright (2021) SCS Team.
+Copyright (2021, 2022, 2023, 2024) SCS Team.
 """
 
 import time
 import datetime
 import json
+import warnings
 
 import numpy as np
 import xarray as xr
@@ -54,11 +55,13 @@ __all__ = [
     'nl_domain',
     'nl_lut',
     'nl_crit',
+    'nl_crit_sk',
     'nl_fit',
     'inspect_nl_fit',
     'snr',
     'inspect_Fnl',
     'inspect_correction',
+    'inspect_correction_sk',
     'load_dssc_module',
     'average_module',
     'process_module',
@@ -99,6 +102,8 @@ class parameters():
         self.std_th = (None, None)
         self.rois = None
         self.rois_th = None
+
+        self.ff_type = 'plane'
         self.flat_field = None
         self.flat_field_prod_th = (5.0, np.PINF)
         self.flat_field_ratio_th = (np.NINF, 1.2)
@@ -121,12 +126,14 @@ class parameters():
         self.arr = None
         self.tid = None
 
-    def dask_load_persistently(self, dark_data_size_Gb=None, data_size_Gb=None):
+    def dask_load_persistently(self, dark_data_size_Gb=None,
+                               data_size_Gb=None):
         """Load dask data array in memory.
 
         Inputs
         ------
-        dark_data_size_Gb: float, optional size of dark to load in memory, in Gb
+        dark_data_size_Gb: float, optional size of dark to load in memory,
+                           in Gb
         data_size_Gb: float, optional size of data to load in memory, in Gb
         """
         self.arr_dark, self.tid_dark = load_dssc_module(self.proposal,
@@ -216,13 +223,22 @@ class parameters():
         return self.plane_guess_fit
 
 
-    def set_flat_field(self, plane,
+    def set_flat_field(self, ff_params, ff_type='plane',
             prod_th=None, ratio_th=None):
-        """Set the flat-field plane definition."""
-        if type(plane) is not list:
-            self.flat_field = plane.tolist()
+        """Set the flat-field plane definition.
+
+        Inputs
+        ------
+        ff_params: list of parameters
+        ff_type: string identifying the type of flat field normalization,
+            default is 'plane'.
+        """
+        self.ff_type = ff_type
+        if type(ff_params) is not list:
+            self.flat_field = ff_params.tolist()
         else:
-            self.flat_field = plane
+            self.flat_field = ff_params
+
         if prod_th is not None:
             self.flat_field_prod_th = prod_th
         if ratio_th is not None:
@@ -274,6 +290,7 @@ class parameters():
         v['rois'] = self.rois
         v['rois_th'] = self.rois_th
 
+        v['ff_type'] = self.ff_type
         v['flat_field'] = self.flat_field
         v['flat_field_prod_th'] = self.flat_field_prod_th
         v['flat_field_ratio_th'] = self.flat_field_ratio_th
@@ -314,7 +331,10 @@ class parameters():
         c.rois = v['rois']
         c.rois_th = v['rois_th']
 
-        c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th'])
+        if 'ff_type' not in v:
+            v['ff_type'] = 'plane'
+        c.set_flat_field(v['flat_field'], v['ff_type'],
+            v['flat_field_prod_th'], v['flat_field_ratio_th'])
         c.plane_guess_fit = v['plane_guess_fit']
         c.use_hex = v['use_hex']
         c.force_mirror = v['force_mirror']
@@ -342,7 +362,10 @@ class parameters():
         f += f'rois threshold: {self.rois_th}\n'
         f += f'rois: {self.rois}\n'
 
-        f += f'flat-field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n'
+        f += f'flat-field type: {self.ff_type}\n'
+        f += f'flat-field p: {self.flat_field} '
+        f += f'prod:{self.flat_field_prod_th} '
+        f += f'ratio:{self.flat_field_ratio_th}\n'
         f += f'plane guess fit: {self.plane_guess_fit}\n'
         f += f'use hexagons: {self.use_hex}\n'
         f += f'enforce mirror symmetry: {self.force_mirror}\n'
@@ -664,7 +687,8 @@ def find_rois(data_mean, threshold, extended=False):
 
     # along X
     lowX = int(np.argmax(pX > threshold) - 1)  # 1st occurrence returned
-    highX = int(pX.shape[0] - np.argmax(pX[::-1] > threshold))  # last occ. returned
+    highX = int(pX.shape[0] -
+                np.argmax(pX[::-1] > threshold))  # last occ. returned
 
     midX = int(0.5*(lowX+highX))
 
@@ -676,7 +700,8 @@ def find_rois(data_mean, threshold, extended=False):
 
     # along Y
     lowY = int(np.argmax(pY > threshold) - 1)  # 1st occurrence returned
-    highY = int(pY.shape[0] - np.argmax(pY[::-1] > threshold))  # last occ. returned
+    highY = int(pY.shape[0]
+                - np.argmax(pY[::-1] > threshold))  # last occ. returned
 
     # define rois
     rois = {}
@@ -854,6 +879,14 @@ def _plane_flat_field(p, roi, params):
 
 
 def compute_flat_field_correction(rois, params, plot=False):
+    if params.ff_type == 'plane':
+        return compute_plane_flat_field_correction(rois, params, plot)
+    elif params.ff_type == 'polyline':
+        return compute_polyline_flat_field_correction(rois, params, plot)
+    else:
+        raise ValueError(f'Uknown flat field type {params.ff_type}')
+
+def compute_plane_flat_field_correction(rois, params, plot=False):
     """Compute the plane-field correction on beam rois.
 
     Inputs
@@ -896,8 +929,123 @@ def compute_flat_field_correction(rois, params, plot=False):
 
     return flat_field
 
+def initialize_polyline_ff_correction(avg, rois, params, plot=False):
+    """Initialize the polyline flat field correction.
+
+    Inputs
+    ------
+    avg: 2D array, average module image
+    rois: dictionnary of ROIs.
+    plot: boolean, plot initialized polyline versus data projection
+
+    Returns
+    -------
+    fig: handle to figure or None
+    """
+    refn = avg[rois['n']['yl']:rois['n']['yh'],
+             rois['n']['xl']:rois['n']['xh']]
+    refp = avg[rois['p']['yl']:rois['p']['yh'],
+             rois['p']['xl']:rois['p']['xh']]
+    mid = avg[rois['0']['yl']:rois['0']['yh'],
+             rois['0']['xl']:rois['0']['xh']]
+
+    mref = 0.5*(refn + refp)
+
+    inv_signal = mref/mid # normalization
+    H_projection = inv_signal[:, :].mean(axis=0)
+    x = np.arange(0, len(H_projection))
+    H_z = np.polyfit(x, H_projection, 6)
+    H_p = np.poly1d(H_z)
+
+    V_projection = (inv_signal/H_p(x))[:, :].mean(axis=1)
+    y = np.arange(0, len(V_projection))
+    V_z = np.polyfit(y, V_projection, 6)
+
+    if plot:
+        fig, axs = plt.subplots(2, 1, figsize=(4,6))
+        axs[0].plot(x, H_projection, label='data (n+p)/2x0')
+        axs[0].plot(x, H_p(x), label='poly')
+        axs[0].legend()
+        axs[0].set_xlabel('x (px)')
+        axs[0].set_ylabel('H projection')
+
+        axs[1].plot(y, V_projection, label='data (n+p)/2x0')
+        V_p = np.poly1d(V_z)
+        axs[1].plot(y, V_p(y), label='poly')
+        axs[1].legend()
+        axs[1].set_xlabel('y (px)')
+        axs[1].set_ylabel('V projection')
+    else:
+        fig = None
+
+    # scaling on polynom coefficients for better fitting
+    ff = np.array([H_z/np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0]),
+                   V_z/np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])])
+
+    params.set_flat_field(ff.flatten())
+    params.ff_type = 'polyline'
+
+    return fig
 
-def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None):
+def compute_polyline_flat_field_correction(rois, params, plot=False):
+    """Compute the 1D polyline field correction on beam rois.
+
+    Inputs
+    ------
+    rois: dictionnary of beam rois['n', '0', 'p']
+    params: parameters
+    plot: boolean, True by default, diagnostic plot
+
+    Returns
+    -------
+    numpy 2D array of the flat-field correction evaluated over one DSSC ladder
+    (2 sensors)
+    """
+    flat_field = np.ones((128, 512))
+
+    z = np.array(params.get_flat_field()).reshape((2, -1))
+    H_z = z[0, :]
+    V_z = z[1, :]
+    
+    coeffs = np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0])
+    H_p = np.poly1d(H_z*coeffs)
+    coeffs = np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])
+    V_p = np.poly1d(V_z*coeffs)
+
+    n = rois['n']
+    p = rois['p']
+    wn = n['xh']-n['xl']
+    wp = p['xh']-p['xl']
+    assert wn == wp, (\
+        f"For polyline flat field normalization, both 'n' and 'p' ROIs "
+        f"must have the same width {wn} and {wp}px"
+    )
+    x = np.arange(wn)
+    wn = n['yh']-n['yl']
+    y = np.arange(wn)
+    norm = V_p(y)[:, np.newaxis]*H_p(x)
+    
+    n_int = flat_field[n['yl']:n['yh'], n['xl']:n['xh']]
+    flat_field[n['yl']:n['yh'], n['xl']:n['xh']]  = \
+                        norm*n_int
+    
+    p_int = flat_field[p['yl']:p['yh'], p['xl']:p['xh']]
+    flat_field[p['yl']:p['yh'], p['xl']:p['xh']] = \
+                        norm*p_int  # not the mirror
+
+    if plot:
+        f, ax = plt.subplots(1, 1, figsize=(6, 2))
+        img = ax.pcolormesh(
+            np.flipud(flat_field[:, :256]), cmap='Greys_r')
+        f.colorbar(img, ax=[ax], label='amplitude')
+        ax.set_xlabel('px')
+        ax.set_ylabel('px')
+        ax.set_aspect('equal')
+
+    return flat_field
+
+def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None,
+                              vmax=None):
     """Extract beams roi from average image and compute the ratio.
 
     Inputs
@@ -954,7 +1102,8 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None
         if ratio_vmin is None:
             ratio_vmin = np.percentile(v, 5)
             ratio_vmax = np.percentile(v, 99.8)
-        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r')
+        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax,
+                               cmap='RdBu_r')
         axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2))
 
     cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
@@ -972,8 +1121,11 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None
 
     return fig, domain
 
-
 def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None):
+    warnings.warn("This method is depreciated, use inspect_ff_fitting instead")
+    return inspect_ff_fitting(avg, rois, domain, vmin, vmax)
+
+def inspect_ff_fitting(avg, rois, domain=None, vmin=None, vmax=None):
     """Extract beams roi from average image and compute the ratio.
 
     Inputs
@@ -1032,6 +1184,89 @@ def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None):
 
     return fig
 
+def inspect_ff_fitting_sk(avg, rois, ff, domain=None, vmin=None, vmax=None):
+    """Extract beams roi from average image and compute the ratio.
+
+    Inputs
+    ------
+    avg: module average image with no saturated shots for the flat-field
+         determination
+    rois: dictionnary of rois
+    ff: 2D array, flat field normalization
+    domain: list of domain mask for the -1st and +1st order
+    vmin: imshow vmin level, default None will use 5 percentile value
+    vmax: imshow vmax level, default None will use 99.8 percentile value
+
+    Returns
+    -------
+    fig: matplotlib figure plotted
+    """
+    if vmin is None:
+        vmin = np.percentile(avg, 5)
+    if vmax is None:
+        vmax = np.percentile(avg, 99.8)
+
+    refn = avg[rois['n']['yl']:rois['n']['yh'],
+               rois['n']['xl']:rois['n']['xh']]
+    refp = avg[rois['p']['yl']:rois['p']['yh'],
+               rois['p']['xl']:rois['p']['xh']]
+    mid = avg[rois['0']['yl']:rois['0']['yh'],
+              rois['0']['xl']:rois['0']['xh']]
+    mref = 0.5*(refn + refp)
+
+    ffn = ff[rois['n']['yl']:rois['n']['yh'],
+             rois['n']['xl']:rois['n']['xh']]
+    ffp = ff[rois['p']['yl']:rois['p']['yh'],
+             rois['p']['xl']:rois['p']['xh']]
+    ffmid = ff[rois['0']['yl']:rois['0']['yh'],
+               rois['0']['xl']:rois['0']['xh']]
+    np_norm = 0.5*(ffn+ffp)
+    mid_norm = ffmid
+
+    fig, axs = plt.subplots(3, 3, sharex=True, sharey=True,
+        figsize=(8, 4))
+    im = axs[0, 0].imshow(mref)
+    axs[0, 0].set_title('(n+p)/2')
+    fig.colorbar(im, ax=axs[0, 0])
+
+    im = axs[1, 0].imshow(mid)
+    axs[1, 0].set_title('0')
+    fig.colorbar(im, ax=axs[1, 0])
+
+    im = axs[2, 0].imshow(mid/mref-1, cmap='RdBu_r', vmin=-1, vmax=1)
+    axs[2, 0].set_title('2x0/(n+p) - 1')
+    fig.colorbar(im, ax=axs[2, 0])
+
+    im = axs[0, 1].imshow(np_norm)
+    axs[0, 1].set_title('norm: (n+p)/2')
+    fig.colorbar(im, ax=axs[0, 1])
+
+    im = axs[1, 1].imshow(mid_norm)
+    axs[1, 1].set_title('norm: 0')
+    fig.colorbar(im, ax=axs[1, 1])
+
+    im = axs[2, 1].imshow(mid_norm/np_norm-1, cmap='RdBu_r', vmin=-1, vmax=1)
+    axs[2, 1].set_title('norm: 2x0/(n+p) - 1')
+    fig.colorbar(im, ax=axs[2, 1])
+
+
+    im = axs[0, 2].imshow(mref/np_norm)
+    axs[0, 2].set_title('(n+p)/2 /norm')
+    fig.colorbar(im, ax=axs[0, 2])
+
+    im = axs[1, 2].imshow(mid/mid_norm)
+    axs[1, 2].set_title('0 /norm')
+    fig.colorbar(im, ax=axs[1, 2])
+
+    im = axs[2, 2].imshow((mid/mid_norm)/(mref/np_norm)-1,
+                          cmap='RdBu_r', vmin=-1, vmax=1)
+    axs[2, 2].set_title('2x0/(n+p) - 1 /norm')
+    fig.colorbar(im, ax=axs[2, 2])
+
+    # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')
+
+    return fig
+
 
 def plane_fitting_domain(avg, rois, prod_th, ratio_th):
     """Extract beams roi, compute their ratio and the domain.
@@ -1192,8 +1427,47 @@ def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
 
     return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
 
+def ff_refine_crit_sk(p, alpha, params, arr_dark, arr, tid, rois,
+    mask, sat_level=511):
+    """Criteria for the ff_refine_fit, combining 'n' and 'p' as reference.
+
+    Inputs
+    ------
+    p: ff plane
+    params: parameters
+    arr_dark: dark data
+    arr: data
+    tid: train id of arr data
+    rois: ['n', '0', 'p', 'sat'] rois
+    mask: mask fo good pixels
+    sat_level: integer, default 511, at which level pixel begin to saturate
+
+    Returns
+    -------
+    sum of standard deviation on binned 0th order intensity
+    """
+    params.set_flat_field(p, params.ff_type)
+    ff = compute_flat_field_correction(rois, params)
+    if np.any(ff < 0.0):
+        bad = 1e6
+    else:
+        bad = 0.0
+
+    data = process(None, arr_dark, arr, tid, rois, mask, ff,
+        sat_level, params._using_gpu)
+
+    # drop saturated shots
+    d = data.where(data['sat_sat'] == False, drop=True)
+
+    r = xas(d, 40, Iokey='np_mean', Itkey='0', nrjkey='0', fluorescence=True)
 
-def ff_refine_fit(params):
+    err_sigma = np.nansum(r['sigmaA'])
+    err_mean = (1.0 - np.nanmean(r['muA']))**2
+
+    return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
+
+
+def ff_refine_fit(params, crit=ff_refine_crit):
     """Refine the flat-field fit by minimizing data spread.
 
     Inputs
@@ -1234,19 +1508,19 @@ def ff_refine_fit(params):
         fit_callback.counter += 1
 
         temp = list(fixed_p)
-        Jalpha = ff_refine_crit(x, *temp)
+        Jalpha = crit(x, *temp)
         temp[0] = 0
-        J0 = ff_refine_crit(x, *temp)
+        J0 = crit(x, *temp)
         temp[0] = 1
-        J1 = ff_refine_crit(x, *temp)
+        J1 = crit(x, *temp)
         fit_callback.res.append([J0, Jalpha, J1])
         print(f'{fit_callback.counter-1}: {time_delta} '
-                f'({J0}, {Jalpha}, {J1}), {x}')
+                f'(reg. term: {J0}, {Jalpha}, err. term: {J1}), {x}')
 
         return False
 
     fit_callback(p0)
-    res = minimize(ff_refine_crit, p0, fixed_p,
+    res = minimize(crit, p0, fixed_p,
         options={'disp': True, 'maxiter': params.ff_max_iter},
         callback=fit_callback)
 
@@ -1345,13 +1619,54 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
     return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a
 
 
-def nl_fit(params, domain):
+def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
+    sat_level=511, use_gpu=False):
+    """Non linear correction criteria, combining 'n' and 'p' as reference.
+
+    Inputs
+    ------
+    p: vector of dy non linear correction
+    domain: domain over which the non linear correction is defined
+    alpha: float, coefficient scaling the cost of the correction function
+        in the criterion
+    arr_dark: dark data
+    arr: data
+    tid: train id of arr data
+    rois: ['n', '0', 'p', 'sat'] rois
+    mask: mask fo good pixels
+    flat_field: zone plate flat-field correction
+    sat_level: integer, default 511, at which level pixel begin to saturate
+
+    Returns
+    -------
+    (1.0 - alpha)*err1 + alpha*err2, where err1 is the 1e8 times the mean of
+    error squared from a transmission of 1.0 and err2 is the sum of the square
+    of the deviation from the ideal detector response.
+    """
+    Fmodel = nl_lut(domain, p)
+    data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark,
+        arr, tid, rois, mask, flat_field, sat_level, use_gpu)
+
+    # drop saturated shots
+    d = data.where(data['sat_sat'] == False, drop=True)
+
+    v = snr(d['np_mean'].values.flatten(), d['0'].values.flatten(),
+            methods=['weighted'])
+    err = 1e8*v['weighted']['s']**2
+
+    err_a = np.sum((Fmodel-np.arange(2**9))**2)
+
+    return (1.0 - alpha)*err + alpha*err_a
+
+def nl_fit(params, domain, ff=None, crit=None):
     """Fit non linearities correction function.
 
     Inputs
     ------
     params: parameters
     domain: array of index
+    ff: array, flat field correction
+    crit: function, criteria function
 
     Returns
     -------
@@ -1374,10 +1689,15 @@ def nl_fit(params, domain):
     p0 = np.array([0]*N)
 
     # flat flat_field
-    ff = compute_flat_field_correction(params.rois, params)
+    if ff is None:
+        ff = compute_flat_field_correction(params.rois, params)
 
-    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
-        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+    if crit is None:
+        crit = nl_crit
+
+    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr,
+               params.tid, fitrois, params.get_mask(), ff, params.sat_level,
+               params._using_gpu)
 
     def fit_callback(x):
         if not hasattr(fit_callback, "counter"):
@@ -1390,11 +1710,11 @@ def nl_fit(params, domain):
         fit_callback.counter += 1
 
         temp = list(fixed_p)
-        Jalpha = nl_crit(x, *temp)
+        Jalpha = crit(x, *temp)
         temp[1] = 0
-        J0 = nl_crit(x, *temp)
+        J0 = crit(x, *temp)
         temp[1] = 1
-        J1 = nl_crit(x, *temp)
+        J1 = crit(x, *temp)
         fit_callback.res.append([J0, Jalpha, J1])
         print(f'{fit_callback.counter-1}: {time_delta} '
                 f'({J0}, {Jalpha}, {J1}), {x}')
@@ -1402,7 +1722,7 @@ def nl_fit(params, domain):
         return False
 
     fit_callback(p0)
-    res = minimize(nl_crit, p0, fixed_p,
+    res = minimize(crit, p0, fixed_p,
         options={'disp': True, 'maxiter': params.nl_max_iter},
         callback=fit_callback)
 
@@ -1436,7 +1756,7 @@ def inspect_nl_fit(res_fit):
 
 
 def snr(sig, ref, methods=None, verbose=False):
-    """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref.
+    """ Compute mean, std and SNR from transmitted and I0 signals.
 
     Inputs
     ------
@@ -1525,7 +1845,7 @@ def inspect_Fnl(Fnl):
 
 
 def inspect_correction(params, gain=None):
-    """Criteria for the non linear correction.
+    """Comparison plot of the different corrections.
 
     Inputs
     ------
@@ -1601,7 +1921,7 @@ def inspect_correction(params, gain=None):
                 [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Blues',
                 norm=LogNorm(vmin=0.2, vmax=200),
-                # alpha=0.5 # make  the plot looks ugly with lots of white lines
+                # alpha=0.5 # make the plot looks ugly with lots of white lines
                 )
             h, xedges, yedges, img2 = axs[l, k].hist2d(
                 g*scale*sat_d[r].values.flatten(),
@@ -1609,7 +1929,7 @@ def inspect_correction(params, gain=None):
                 [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Reds',
                 norm=LogNorm(vmin=0.2, vmax=200),
-                # alpha=0.5 # make  the plot looks ugly with lots of white lines
+                # alpha=0.5 # make the plot looks ugly with lots of white lines
                 )
 
             v = snr_v['direct']['mu']/snr_v['direct']['s']
@@ -1619,8 +1939,8 @@ def inspect_correction(params, gain=None):
             axs[l, k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
                             transform = axs[l, k].transAxes)
 
-            # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
-            # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+            #axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+            #axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
 
             axs[l, k].set_ylim([0.95*m, 1.05*m])
 
@@ -1644,6 +1964,124 @@ def inspect_correction(params, gain=None):
 
     return f
 
+def inspect_correction_sk(params, ff, gain=None):
+    """Comparison plot of the different corrections, combining 'n' and 'p'.
+
+    Inputs
+    ------
+    params: parameters
+    gain: float, default None, DSSC gain in ph/bin
+
+    Returns
+    -------
+    matplotlib figure
+    """
+    # load data
+    assert params.arr is not None, "Data not loaded"
+    assert params.arr_dark is not None, "Data not loaded"
+
+    # we only need few rois
+    fitrois = {}
+    for k in ['n', '0', 'p', 'sat']:
+        fitrois[k] = params.rois[k]
+
+    # flat flat_field
+    #plane_ff = params.get_flat_field()
+    #if plane_ff is None:
+    #    plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0]
+    #ff = compute_flat_field_correction(params.rois, params)
+
+    # non linearities
+    Fnl = params.get_Fnl()
+    if Fnl is None:
+        Fnl = np.arange(2**9)
+
+    xp = np if not params._using_gpu else cp
+    # compute all levels of correction
+    data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level,
+        params._using_gpu)
+    data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+    data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+
+    # for conversion to nb of photons
+    if gain is None:
+        g = 1
+    else:
+        g = gain
+
+    scale = 1e-6
+
+    f, axs = plt.subplots(1, 3, figsize=(8, 2), sharex=True)
+
+    # nbins = np.linspace(0.01, 1.0, 100)
+
+    photon_scale = None
+
+    for k, d in enumerate([data, data_ff, data_ff_nl]):
+        if photon_scale is None:
+            lower = 0
+            upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
+            photon_scale = np.linspace(lower, upper, 150)
+
+        good_d = d.where(d['sat_sat'] == False, drop=True)
+        sat_d = d.where(d['sat_sat'], drop=True)
+
+        good_d['np_mean'] = 0.5*(good_d['n']+good_d['p'])
+        sat_d['np_mean'] = 0.5*(sat_d['n']+sat_d['p'])
+
+        snr_v = snr(good_d['np_mean'].values.flatten(),
+                    good_d['0'].values.flatten(), verbose=True)
+
+        m = snr_v['direct']['mu']
+        h, xedges, yedges, img = axs[k].hist2d(
+            g*scale*good_d['0'].values.flatten(),
+            good_d['np_mean'].values.flatten()/good_d['0'].values.flatten(),
+            [photon_scale, np.linspace(0.95, 1.05, 150)*m],
+            cmap='Blues',
+            norm=LogNorm(vmin=0.2, vmax=200),
+            # alpha=0.5 # make  the plot looks ugly with lots of white lines
+            )
+        h, xedges, yedges, img2 = axs[k].hist2d(
+            g*scale*sat_d['0'].values.flatten(),
+            sat_d['np_mean'].values.flatten()/sat_d['0'].values.flatten(),
+            [photon_scale, np.linspace(0.95, 1.05, 150)*m],
+            cmap='Reds',
+            norm=LogNorm(vmin=0.2, vmax=200),
+            # alpha=0.5 # make  the plot looks ugly with lots of white lines
+            )
+
+        v = snr_v['direct']['mu']/snr_v['direct']['s']
+        axs[k].text(0.4, 0.15, f'SNR: {v:.0f}',
+                        transform = axs[k].transAxes)
+        v = snr_v['weighted']['mu']/snr_v['weighted']['s']
+        axs[k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
+                        transform = axs[k].transAxes)
+
+        # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+        # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+
+        axs[k].set_ylim([0.95*m, 1.05*m])
+
+    for k in range(3):
+        #for l in range(3):
+        #    axs[l, k].set_ylim([0.95, 1.05])
+        if gain:
+            axs[k].set_xlabel('photons (10$^6$)')
+        else:
+            axs[k].set_xlabel('ADU (10$^6$)')
+
+    f.colorbar(img, ax=axs, label='events')
+
+    axs[0].set_title('raw')
+    axs[1].set_title('flat-field')
+    axs[2].set_title('non-linear')
+
+    axs[0].set_ylabel(r'np_mean/0')
+
+    return f
 
 # data processing related functions
 
@@ -1829,6 +2267,8 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
 
     # compute dark corrected ROI values
     v = {}
+    r_ff = {}
+    ff = {}
     for n in rois.keys():
 
         r[n] = r[n] - rd[n]
@@ -1837,25 +2277,36 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
             # TODO:  flat-field should not be applied on intra darks
             # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
             #                 rois[n]['xl']:rois[n]['xh']]
-            ff = flat_field[rois[n]['yl']:rois[n]['yh'],
-                            rois[n]['xl']:rois[n]['xh']]
-            r[n] = r[n]/ff
+            ff[n] = flat_field[rois[n]['yl']:rois[n]['yh'],
+                        rois[n]['xl']:rois[n]['xh']]
+            r_ff[n] = r[n]/ff[n]
+        else:
+            ff[n] = 1.0
+            r_ff[n] = r[n]
 
-        v[n] = r[n].sum(axis=(2, 3))
+        v[n] = r_ff[n].sum(axis=(2, 3))
+    
+    # np_mean roi where we normalize the sum of flat_field
+    np_mean  = (r['n'] + r['p'])/(ff['n'] + ff['p'])
+    v['np_mean'] = np_mean.sum(axis=(2,3))
 
     res = xr.Dataset()
 
     dims = ['trainId', 'pulseId']
     r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
     for n in rois.keys():
+        res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims)
         res[n + '_sat'] = xr.DataArray(ensure_on_host(r_sat[n][:, :]),
                                        coords=r_coords, dims=dims)
-        res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims)
+    res['np_mean'] = xr.DataArray(ensure_on_host(v['np_mean']),
+                                  coords=r_coords, dims=dims)
+    res['np_mean_sat'] = res['n_sat'] + res['p_sat']
 
     for n in rois.keys():
         roi = rois[n]
         res[n + '_area'] = xr.DataArray(np.array([
             (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))
+    res['np_mean_area'] = res['n_area'] + res['p_area']
 
     return res
 
@@ -1939,8 +2390,9 @@ def inspect_saturation(data, gain, Nbins=200):
         # compute density normalization on all data
         norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat']))
     
-        ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm, h[k+'_nosat']/norm,
-                        facecolor=f"C{kk}", edgecolor='none', alpha=0.2)
+        ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm,
+                        h[k+'_nosat']/norm, facecolor=f"C{kk}",
+                        edgecolor='none', alpha=0.2)
 
         ax.plot(bins_c, h[k+'_nosat']/norm, label=k,
                 c=f'C{kk}', alpha=0.4)