diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 5703ce3c1c0d055ec7118ce71eb352de07f70c78..807f8d200635b30da90bdbe8dc8c3e47a4a3f697 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -16,10 +16,14 @@ from scipy.optimize import minimize
 import matplotlib.pyplot as plt
 from matplotlib.colors import LogNorm
 from matplotlib import cm
+from matplotlib.patches import Polygon
+from matplotlib.collections import PatchCollection
 
 from extra_data import open_run
 from extra_geom import DSSC_1MGeometry
 
+from toolbox_scs import xas
+
 __all__ = [
     'load_dssc_module',
     'inspect_dark',
@@ -53,7 +57,9 @@ class parameters():
         self.darkrun = darkrun
         self.run = run
         self.module = module
+        self.pixel_pos = _get_pixel_pos(self.module)
         self.gain = gain
+
         self.mask_idx = None
         self.mean_th = (None, None)
         self.std_th = (None, None)
@@ -63,6 +69,8 @@ class parameters():
         self.flat_field_prod_th = (5.0, np.PINF)
         self.flat_field_ratio_th = (np.NINF, 1.2)
         self.plane_guess_fit = None
+        self.use_hex = False
+        self.force_mirror = True
         self.Fnl = None
         self.alpha = None
         self.sat_level = None
@@ -169,6 +177,8 @@ class parameters():
         v['flat_field_prod_th'] = self.flat_field_prod_th
         v['flat_field_ratio_th'] = self.flat_field_ratio_th
         v['plane_guess_fit'] = self.plane_guess_fit
+        v['use_hex'] = self.use_hex
+        v['force_mirror'] = self.force_mirror
 
         v['Fnl'] = self.Fnl
         v['alpha'] = self.alpha
@@ -202,6 +212,8 @@ class parameters():
 
         c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th'])
         c.plane_guess_fit = v['plane_guess_fit']
+        c.use_hex = v['use_hex']
+        c.force_mirror = v['force_mirror']
 
         c.set_Fnl(v['Fnl'])
         c.alpha = v['alpha']
@@ -225,6 +237,8 @@ class parameters():
 
         f += f'flat field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n'
         f += f'plane guess fit: {self.plane_guess_fit}\n'
+        f += f'use hexagons: {self.use_hex}\n'
+        f += f'enforce mirror symmetry: {self.force_mirror}\n'
 
         if self.Fnl is not None:
             f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
@@ -235,124 +249,248 @@ class parameters():
 
         return f
 
-def _get_pixel_pos():
-    """Compute the pixel position on hexagonal lattice of DSSC module 15"""
+# Hexagonal pixels related function
+
+def _get_pixel_pos(module):
+    """Compute the pixel position on hexagonal lattice of DSSC module."""
     # module pixel position
     dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
     g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
 
     # keeping only module 15 pixel X,Y position
-    return g.get_pixel_positions()[15][:, :, :2]
+    return g.get_pixel_positions()[module][:, :, :2]
 
+def _get_pixel_corners(module):
+    """Compute the pixel corners of DSSC module."""
+    # module pixel position
+    dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
+    g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
 
-def _plane_flat_field(p, roi):
-    """Compute the p plane over the given roi.
+    # corners are in z,y,x oder so we rop z, flip x & y
+    corners = g.to_distortion_array(allow_negative_xy=True)
+    corners = corners[(module*128):((module+1)*128), :, :, 1:][:, :, :, ::-1]
+    
+    return corners
 
-    Given the plane parameters p, compute the plane over the roi
-    size.
+def _get_pixel_hexagons(module):
+    """Compute DSSC pixel hexagons for plotting.
+    
+    Parameters:
+    -----------
+    module: int, module number
 
-    Parameters
-    ----------
-    p: a vector of a, b, c, d plane parameter with the
-       plane given by ax+ by + cz + d = 0
-    roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
+    Returns:
+    matplotlib PatchCollection of hexagons
+    """
+    hexes = []
+    
+    corners = _get_pixel_corners(module)
+    for y in range(128):
+        for x in range(512):
+            c = 1e3*corners[y, x, :, :] # convert to mm
+            hexes.append(Polygon(c))
+
+    return PatchCollection(hexes)
+
+def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
+    """Add a colobar on a new axes so it match the plot size.
+
+    Inputs
+    ------
+    im: image plotted
+    ax: axes on which the image was plotted
+    loc: string, default 'right', location of the colorbar
+    size: string, default '5%', proportion of the colobar with respect to the
+          plotted image
+    pad: float, default 0.05, pad width between plot and colorbar
+    """
+    from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+    fig = ax.figure
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes(loc, size=size, pad=pad)
+    cbar = fig.colorbar(im, cax=cax)
+
+    return cbar
+
+# dark related functions
+
+def bad_pixel_map(params):
+    """Compute the bad pixels map.
+
+    Inputs
+    ------
+    params: parameters
 
     Returns
     -------
-    the plane field given by p evaluated on the roi
-    extend.
+    bad pixel map
     """
-    a, b, c, d = p
+    assert params.arr_dark is not None, "Data not loaded"
+
+    # compute mean and std
+    dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
+    dark_std = params.arr_dark.std(axis=(0, 1)).compute()
 
-    # DSSC pixel position on hexagonal lattice
-    pixel_pos = _get_pixel_pos()
-    pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
+    mask = np.ones_like(dark_mean)
+    if params.mean_th[0] is not None:
+        mask *= dark_mean >= params.mean_th[0]
+    if params.mean_th[1] is not None:
+        mask *= dark_mean <= params.mean_th[1]
+    if params.std_th[0] is not None:
+        mask *= dark_std >= params.std_th[0]
+    if params.std_th[1] is not None:
+        mask *= dark_std >= params.std_th[1]
 
-    Z = -(a*pos[:, :, 0] + b*pos[:, :, 1] + d)/c
+    print(f'# bad pixel: {int(128*512-mask.sum())}')
 
-    return Z
+    return mask.astype(bool)
 
 
-def compute_flat_field_correction(rois, plane, plot=False):
-    """Compute the plane field correction on beam rois.
+def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
+    """Inspect dark run data and plot diagnostic.
 
     Inputs
     ------
-    rois: dictionnary of beam rois['n', '0', 'p']
-    plane: 2 plane vector concatenated
-    plot: boolean, True by default, diagnostic plot
+    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
+    mean_th: tuple of threshold (low, high), default (None, None), to compute
+            a mask of good pixels for which the mean dark value lie inside this
+            range
+    std_th: tuple of threshold (low, high), default (None, None), to compute a
+            mask of bad pixels for which the dark std value lie inside this
+            range
 
     Returns
     -------
-    numpy 2D array of the flat field correction evaluated over one DSSC ladder
-    (2 sensors)
+    fig: matplotlib figure
     """
-    flat_field = np.ones((128, 512))
+    # compute mean and std
+    dark_mean = arr.mean(axis=(0, 1)).compute()
+    dark_std = arr.std(axis=(0, 1)).compute()
 
-    r = rois['n']
-    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
-                       _plane_flat_field(plane[:4], r)
-    r = rois['p']
-    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
-                       _plane_flat_field(plane[4:], r)
+    fig = plt.figure(figsize=(7, 2.7))
+    gs = fig.add_gridspec(2, 4)
+    ax1 = fig.add_subplot(gs[0, 1:])
+    ax1.set_xticklabels([])
+    ax1.set_yticklabels([])
+    ax11 = fig.add_subplot(gs[0, 0])
 
-    if plot:
-        f, ax = plt.subplots(1, 1, figsize=(6, 2))
-        img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
-        f.colorbar(img, ax=[ax], label='amplitude')
-        ax.set_xlabel('px')
-        ax.set_ylabel('px')
-        ax.set_aspect('equal')
+    ax2 = fig.add_subplot(gs[1, 1:])
+    ax2.set_xticklabels([])
+    ax2.set_yticklabels([])
+    ax22 = fig.add_subplot(gs[1, 0])
 
-    return flat_field
+    vmin = np.percentile(dark_mean.flatten(), 2)
+    vmax = np.percentile(dark_mean.flatten(), 98)
+    im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax)
+    ax1.invert_yaxis()
+    ax1.set_aspect('equal')
+    cbar1 = _add_colorbar(im1, ax=ax1, size='2%')
+    cbar1.ax.set_ylabel('dark mean')
 
+    ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1),
+            range=(vmin/2, vmax*2))
+    if mean_th[0] is not None:
+        ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--')
+    if mean_th[1] is not None:
+        ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--')
+    ax11.set_yscale('log')
 
-def nl_domain(N, low, high):
-    """Create the input domain where the non-linear correction defined.
+    vmin = np.percentile(dark_std.flatten(), 2)
+    vmax = np.percentile(dark_std.flatten(), 98)
+    im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax)
+    ax2.invert_yaxis()
+    ax2.set_aspect('equal')
+    cbar2 = _add_colorbar(im2, ax=ax2, size='2%')
+    cbar2.ax.set_ylabel('dark std')
+
+    ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2))
+    if std_th[0] is not None:
+        ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--')
+    if std_th[1] is not None:
+        ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--')
+    ax22.set_yscale('log')
+
+    return fig
+
+
+# histogram related functions
+
+def histogram_module(arr, mask=None):
+    """Compute a histogram of the 9 bits raw pixel values over a module.
 
     Inputs
     ------
-    N: integer, number of control points or intervals
-    low: input values below or equal to low will not be corrected
-    high: input values higher or equal to high will not be corrected
+    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
+    mask: optional bad pixel mask
 
     Returns
     -------
-    array of 2**9 integer values with N segments
+    histogram
     """
-    x = np.arange(2**9)
-    vx = x.copy()
-    eps = 1e-5
-    vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1)
-    vx[x <= low] = 0
-    vx[x >= high] = 0
-
-    return vx
+    if mask is not None:
+        w = da.repeat(da.repeat(da.array(mask[None, None, :, :]),
+                arr.shape[1], axis=1), arr.shape[0], axis=0)
+        w = w.rechunk((100, -1, -1, -1))
+        return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute()
+    else:
+        return da.bincount(arr.ravel(), minlength=512).compute()
 
 
-def nl_lut(domain, dy):
-    """Compute the non-linear correction.
+def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
+    """Compute and plot a histogram of the 9 bits raw pixel values.
 
     Inputs
     ------
-    domain: input domain where dy is defined. For zero no correction is
-        defined. For non-zero value x, dy[x] is applied.
-    dy: a vector of deviation from linearity on control point homogeneously
-        dispersed over 9 bits.
+    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
+    arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y)
+    mask: optional bad pixel mask
+    extra_lines: boolean, default False, plot extra lines at period values
 
     Returns
     -------
-    F_INL: default None, non linear correction function given as a
-           lookup table with 9 bits integer input
+    (h, hd): histogram of arr, arr_dark
+    figure
     """
-    x = np.arange(2**9)
-    ndy = np.insert(dy, 0, 0)  # add zero to dy
+    from matplotlib.ticker import MultipleLocator
 
-    f = x + ndy[domain]
+    f = plt.figure(figsize=(6, 3))
+    ax = plt.gca()
+    h = histogram_module(arr, mask=mask)
+    Sum_h = np.sum(h)
+    ax.plot(np.arange(2**9), h/Sum_h, marker='o',
+            ms=3, markerfacecolor='none', lw=1)
 
-    return f
+    if arr_dark is not None:
+        hd = histogram_module(arr_dark, mask=mask)
+        Sum_hd = np.sum(hd)
+        ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
+                ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
+    else:
+        hd = None
+
+    if extra_lines:
+        for k in range(50, 271):
+            if not (k - 2) % 8:
+                ax.axvline(k, c='k', alpha=0.5, ls='--')
+            if not (k - 3) % 16:
+                ax.axvline(k, c='g', alpha=0.3, ls='--')
+            if not (k - 7) % 32:
+                ax.axvline(k, c='r', alpha=0.3, ls='--')
+
+    ax.axvline(271, c='C1', alpha=0.5, ls='--')
+
+    ax.set_xlim([0, 2**9-1])
+    ax.set_yscale('log')
+    ax.xaxis.set_minor_locator(MultipleLocator(10))
+    ax.set_xlabel('DSSC pixel value')
+    ax.set_ylabel('count frequency')
+
+    return (h, hd), f
 
 
+# rois related function
+
 def find_rois(data_mean, threshold):
     """Find rois from 3 beams configuration.
 
@@ -517,300 +655,92 @@ def inspect_rois(data_mean, rois, threshold=None, allrois=False):
     return fig
 
 
-def histogram_module(arr, mask=None):
-    """Compute a histogram of the 9 bits raw pixel values over a module.
+# Flat field related functions
 
-    Inputs
-    ------
-    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
-    mask: optional bad pixel mask
+def _plane_flat_field(p, roi, pixel_pos, use_hex=False):
+    """Compute the p plane over the given roi.
+
+    Given the plane parameters p, compute the plane over the roi
+    size.
+
+    Parameters
+    ----------
+    p: a vector of a, b, c, d plane parameter with the
+       plane given by ax+ by + cz + d = 0
+    roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
+    pixel_pos: array of DSSC pixel position on hexagonal lattice
+    use_hex: boolean, use actual DSSC pixel position from pixel_pos
+        or fake cartesian pixel position
 
     Returns
     -------
-    histogram
+    the plane field given by p evaluated on the roi
+    extend.
     """
-    if mask is not None:
-        w = da.repeat(da.repeat(da.array(mask[None, None, :, :]),
-                arr.shape[1], axis=1), arr.shape[0], axis=0)
-        w = w.rechunk((100, -1, -1, -1))
-        return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute()
+    a, b, c, d = p
+
+    if use_hex:
+        # DSSC pixel position on hexagonal lattice
+        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
+        X = pos[:, :, 0]
+        Y = pos[:, :, 1]
     else:
-        return da.bincount(arr.ravel(), minlength=512).compute()
+        nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
+        X = np.arange(nX)/100
+        Y = np.arange(nY)[:, np.newaxis]/100
 
+    # center of ROI is put to 0,0
+    X -= np.mean(X)
+    Y -= np.mean(Y)
 
-def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
-    """Compute and plot a histogram of the 9 bits raw pixel values.
+    Z = -(a*X + b*Y + d)/c
 
-    Inputs
-    ------
-    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
-    arr: dask array of reshaped dssc dark data (trainId, pulseId, x, y)
-    mask: optional bad pixel mask
-    extra_lines: boolean, default False, plot extra lines at period values
-
-    Returns
-    -------
-    (h, hd): histogram of arr, arr_dark
-    figure
-    """
-    from matplotlib.ticker import MultipleLocator
-
-    f = plt.figure(figsize=(6, 3))
-    ax = plt.gca()
-    h = histogram_module(arr, mask=mask)
-    Sum_h = np.sum(h)
-    ax.plot(np.arange(2**9), h/Sum_h, marker='o',
-            ms=3, markerfacecolor='none', lw=1)
-
-    if arr_dark is not None:
-        hd = histogram_module(arr_dark, mask=mask)
-        Sum_hd = np.sum(hd)
-        ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
-                ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
-    else:
-        hd = None
-
-    if extra_lines:
-        for k in range(50, 271):
-            if not (k - 2) % 8:
-                ax.axvline(k, c='k', alpha=0.5, ls='--')
-            if not (k - 3) % 16:
-                ax.axvline(k, c='g', alpha=0.3, ls='--')
-            if not (k - 7) % 32:
-                ax.axvline(k, c='r', alpha=0.3, ls='--')
-
-    ax.axvline(271, c='C1', alpha=0.5, ls='--')
-
-    ax.set_xlim([0, 2**9-1])
-    ax.set_yscale('log')
-    ax.xaxis.set_minor_locator(MultipleLocator(10))
-    ax.set_xlabel('DSSC pixel value')
-    ax.set_ylabel('count frequency')
-
-    return (h, hd), f
-
-
-def load_dssc_module(proposalNB, runNB, moduleNB=15,
-                     subset=slice(None), drop_intra_darks=True, persist=False):
-    """Load single module dssc data as dask array.
-
-    Inputs
-    ------
-    proposalNB: proposal number
-    runNB: run number
-    moduleNB: default 15, module number
-    subset: default slice(None), subset of trains to load
-    drop_intra_darks: boolean, default True, remove intra darks from the data
-    persist: default False, load all data persistently in memory
-
-    Returns
-    -------
-    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
-    tid: array of train id number
-    """
-    run = open_run(proposal=proposalNB, run=runNB)
-
-    # DSSC
-    source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
-    key = 'image.data'
-
-    arr = run[source, key][subset].dask_array()
-    # fix 256 value becoming spuriously 0 instead
-    arr[arr == 0] = 256
-
-    ppt = run[source, key][subset].data_counts()
-    # ignore train with no pulses, can happen in burst mode acquisition
-    ppt = ppt[ppt > 0]
-    tid = ppt.index.to_numpy()
-
-    ppt = np.unique(ppt)
-    assert ppt.shape[0] == 1, "number of pulses changed during the run"
-    ppt = ppt[0]
-
-    # reshape in trainId, pulseId, 2d-image
-    arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])
-
-    # drop intra darks
-    if drop_intra_darks:
-        arr = arr[:, ::2, :, :]
-
-    # load data in memory
-    if persist:
-        arr = arr.persist()
-
-    return arr, tid
-
-
-def average_module(arr, dark=None, ret='mean',
-                   mask=None, sat_roi=None, sat_level=300, F_INL=None):
-    """Compute the average or std over a module.
-
-    Inputs
-    ------
-    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
-    dark: default None, dark to be substracted
-    ret: string, either 'mean' to compute the mean or 'std' to compute the
-         standard deviation
-    mask: default None, mask of bad pixels to ignore
-    sat_roi: roi over which to check for pixel with values larger than
-             sat_level to drop the image from the average or std
-    sat_level: int, minimum pixel value for a pixel to be considered saturated
-    F_INL: default None, non linear correction function given as a
-           lookup table with 9 bits integer input
-
-    Returns
-    -------
-    average or standard deviation image
-    """
-    # F_INL
-    if F_INL is not None:
-        narr = arr.map_blocks(lambda x: F_INL[x])
-    else:
-        narr = arr
-
-    if mask is not None:
-        narr = narr*mask
-
-    if sat_roi is not None:
-        temp = (da.logical_not(da.any(
-                narr[:, :, sat_roi['yl']:sat_roi['yh'],
-                           sat_roi['xl']:sat_roi['xh']] >= sat_level,
-                               axis=[2, 3], keepdims=True)))
-        not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)
-
-    if dark is not None:
-        narr = narr - dark
-
-    if ret == 'mean':
-        if sat_roi is not None:
-            return da.average(narr, axis=0, weights=not_sat)
-        else:
-            return narr.mean(axis=0)
-    elif ret == 'std':
-        return narr.std(axis=0)
-    else:
-        raise ValueError(f'ret={ret} not supported')
-
-
-def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
-    """Add a colobar on a new axes so it match the plot size.
-
-    Inputs
-    ------
-    im: image plotted
-    ax: axes on which the image was plotted
-    loc: string, default 'right', location of the colorbar
-    size: string, default '5%', proportion of the colobar with respect to the
-          plotted image
-    pad: float, default 0.05, pad width between plot and colorbar
-    """
-    from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-    fig = ax.figure
-    divider = make_axes_locatable(ax)
-    cax = divider.append_axes(loc, size=size, pad=pad)
-    cbar = fig.colorbar(im, cax=cax)
-
-    return cbar
+    return Z
 
 
-def bad_pixel_map(params):
-    """Compute the bad pixels map.
+def compute_flat_field_correction(rois, params, plot=False):
+    """Compute the plane field correction on beam rois.
 
     Inputs
     ------
+    rois: dictionnary of beam rois['n', '0', 'p']
     params: parameters
+    plot: boolean, True by default, diagnostic plot
 
     Returns
     -------
-    bad pixel map
-    """
-    assert params.arr_dark is not None, "Data not loaded"
-
-    # compute mean and std
-    dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
-    dark_std = params.arr_dark.std(axis=(0, 1)).compute()
-
-    mask = np.ones_like(dark_mean)
-    if params.mean_th[0] is not None:
-        mask *= dark_mean >= params.mean_th[0]
-    if params.mean_th[1] is not None:
-        mask *= dark_mean <= params.mean_th[1]
-    if params.std_th[0] is not None:
-        mask *= dark_std >= params.std_th[0]
-    if params.std_th[1] is not None:
-        mask *= dark_std >= params.std_th[1]
-
-    print(f'# bad pixel: {int(128*512-mask.sum())}')
-
-    return mask.astype(bool)
-
-
-def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
-    """Inspect dark run data and plot diagnostic.
-
-    Inputs
-    ------
-    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
-    mean_th: tuple of threshold (low, high), default (None, None), to compute
-            a mask of good pixels for which the mean dark value lie inside this
-            range
-    std_th: tuple of threshold (low, high), default (None, None), to compute a
-            mask of bad pixels for which the dark std value lie inside this
-            range
-
-    Returns
-    -------
-    fig: matplotlib figure
+    numpy 2D array of the flat field correction evaluated over one DSSC ladder
+    (2 sensors)
     """
-    # compute mean and std
-    dark_mean = arr.mean(axis=(0, 1)).compute()
-    dark_std = arr.std(axis=(0, 1)).compute()
-
-    fig = plt.figure(figsize=(7, 2.7))
-    gs = fig.add_gridspec(2, 4)
-    ax1 = fig.add_subplot(gs[0, 1:])
-    ax1.set_xticklabels([])
-    ax1.set_yticklabels([])
-    ax11 = fig.add_subplot(gs[0, 0])
-
-    ax2 = fig.add_subplot(gs[1, 1:])
-    ax2.set_xticklabels([])
-    ax2.set_yticklabels([])
-    ax22 = fig.add_subplot(gs[1, 0])
+    flat_field = np.ones((128, 512))
 
-    vmin = np.percentile(dark_mean.flatten(), 2)
-    vmax = np.percentile(dark_mean.flatten(), 98)
-    im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax)
-    ax1.invert_yaxis()
-    ax1.set_aspect('equal')
-    cbar1 = _add_colorbar(im1, ax=ax1, size='2%')
-    cbar1.ax.set_ylabel('dark mean')
+    plane = params.get_flat_field()
+    use_hex = params.use_hex
+    pixel_pos = params.pixel_pos
+    force_mirror = params.force_mirror
 
-    ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1),
-            range=(vmin/2, vmax*2))
-    if mean_th[0] is not None:
-        ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--')
-    if mean_th[1] is not None:
-        ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--')
-    ax11.set_yscale('log')
-
-    vmin = np.percentile(dark_std.flatten(), 2)
-    vmax = np.percentile(dark_std.flatten(), 98)
-    im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax)
-    ax2.invert_yaxis()
-    ax2.set_aspect('equal')
-    cbar2 = _add_colorbar(im2, ax=ax2, size='2%')
-    cbar2.ax.set_ylabel('dark std')
+    r = rois['n']
+    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
+                       _plane_flat_field(plane[:4], r, pixel_pos, use_hex)
+    
+    r = rois['p']
+    if force_mirror:
+        a, b, c, d = plane[:4]
+        flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
+                       _plane_flat_field([-a, b, c, d], r, pixel_pos, use_hex)
+    else:
+        flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
+                       _plane_flat_field(plane[4:], r, pixel_pos, use_hex)
 
-    ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2))
-    if std_th[0] is not None:
-        ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--')
-    if std_th[1] is not None:
-        ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--')
-    ax22.set_yscale('log')
+    if plot:
+        f, ax = plt.subplots(1, 1, figsize=(6, 2))
+        img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
+        f.colorbar(img, ax=[ax], label='amplitude')
+        ax.set_xlabel('px')
+        ax.set_ylabel('px')
+        ax.set_aspect('equal')
 
-    return fig
+    return flat_field
 
 
 def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None):
@@ -888,6 +818,7 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None
 
     return fig, domain
 
+
 def inspect_plane_fitting(avg, rois, domain, vmin=None, vmax=None):
     """Extract beams roi from average image and compute the ratio.
 
@@ -1039,24 +970,51 @@ def plane_fitting(params):
         num_n = a_n**2 + b_n**2 + c_n**2
 
         roi = params.rois['n']
-        pixel_pos = _get_pixel_pos()
-        # DSSC pixel position on hexagonal lattice
-        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
-        d0_2 = np.sum(n_m*(a_n*pos[:, :, 0] + b_n*pos[:, :, 1]
-            + c_n*n + d_n)**2)/num_n
+        if params.use_hex:
+            # DSSC pixel position on hexagonal lattice
+            pos = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
+            X = pos[:, :, 0]
+            Y = pos[:, :, 1]
+        else:
+            nY, nX = n.shape
+            X = np.arange(nX)/100
+            Y = np.arange(nY)[:, np.newaxis]/100
+
+        # center of ROI is put to 0,0
+        X -= np.mean(X)
+        Y -= np.mean(Y)
+
+        d0_2 = np.sum(n_m*(a_n*X + b_n*Y + c_n*n + d_n)**2)/num_n
 
         num_p = a_p**2 + b_p**2 + c_p**2
 
         roi = params.rois['p']
-        # DSSC pixel position on hexagonal lattice
-        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
-        d2_2 = np.sum(p_m*(a_p*pos[:, :, 0] + b_p*pos[:, :, 1]
-            + c_p*p + d_p)**2)/num_p
+        if params.use_hex:
+            # DSSC pixel position on hexagonal lattice
+            pos = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
+            X = pos[:, :, 0]
+            Y = pos[:, :, 1]
+        else:
+            nY, nX = p.shape
+            X = np.arange(nX)/100
+            Y = np.arange(nY)[:, np.newaxis]/100
+
+        # center of ROI is put to 0,0
+        X -= np.mean(X)
+        Y -= np.mean(Y)
+
+        if params.force_mirror:
+            d2_2 = np.sum(p_m*(-a_n*X + b_n*Y + c_n*p + d_n)**2)/num_n
+        else:
+            d2_2 = np.sum(p_m*(a_p*X + b_p*Y + c_p*p + d_p)**2)/num_p
 
         return 1e3*(d2_2 + d0_2)
 
     if params.plane_guess_fit is None:
-        p_guess_fit = [-20, 0.0, 1.5, -0.5, -20, 0, 1.5, -0.5 ]
+        if params.use_hex:
+            p_guess_fit = [-20, 0.0, 1.5, -0.5, 20, 0, 1.5, -0.5 ]
+        else:
+            p_guess_fit = [-0.2, -0.1, 1, -0.54, 0.2, -0.1, 1, -0.54]
     else:
         p_guess_fit = params.plane_guess_fit
 
@@ -1065,117 +1023,142 @@ def plane_fitting(params):
     return res
 
 
-def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
-                   flat_field=None, F_INL=None):
-    """Process one module and extract roi intensity.
+def ff_refine_crit(p, params, arr_dark, arr, tid, rois, mask, sat_level=511):
+    """Criteria for the ff_refine_fit.
 
     Inputs
     ------
-    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
-    tid: array of train id number
-    dark: pulse resolved dark image to remove
-    rois: dictionnary of rois
-    mask: default None, mask of ignored pixels
+    p: ff plane
+    params: parameters
+    arr_dark: dark data
+    arr: data
+    tid: train id of arr data
+    rois: ['n', '0', 'p', 'sat'] rois
+    mask: mask fo good pixels
     sat_level: integer, default 511, at which level pixel begin to saturate
-    flat_field: default None, flat field correction
-    F_INL: default None, non linear correction function given as a
-           lookup table with 9 bits integer input
 
     Returns
     -------
-    dataset of extracted pulse and train resolved roi intensities.
+    sum of standard deviation on binned 0th order intensity
     """
-    # F_INL
-    if F_INL is not None:
-        narr = arr.map_blocks(lambda x: F_INL[x])
-    else:
-        narr = arr
+    params.set_flat_field(p)
+    ff = compute_flat_field_correction(rois, params)
+    
+    data = process(np.arange(2**9), arr_dark, arr, tid, rois, mask, ff,
+        sat_level)
 
-    # apply mask
-    if mask is not None:
-        narr = narr*mask
+    # drop saturated shots
+    d = data.where(data['sat_sat'] == False, drop=True)
+    
+    rn = xas(d, 40, Iokey='0', Itkey='n', nrjkey='0')
+    rp = xas(d, 40, Iokey='0', Itkey='p', nrjkey='0')
+    rd = xas(d, 40, Iokey='p', Itkey='n', nrjkey='0')
 
-    # crop rois
-    r = {}
-    rd = {}
-    for n in rois.keys():
-        r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'],
-                    rois[n]['xl']:rois[n]['xh']]
-        rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'],
-                     rois[n]['xl']:rois[n]['xh']]
+    err = np.nansum(rn['sigmaA']) + np.nansum(rp['sigmaA']) + np.nansum(rd['sigmaA'])
 
-    # find saturated shots
-    r_sat = {}
-    for n in rois.keys():
-        r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3))
+    return 1e3*err
 
-    # TODO: flat field should not be applied on intra darks
-    # # change flat field dimension to match data
-    # if flat_field is not None:
-    #     temp = np.ones_like(dark)
-    #     temp[::2, :, :] = flat_field[:, :]
-    #    flat_field = temp
 
-    # compute dark corrected ROI values
-    v = {}
-    for n in rois.keys():
+def ff_refine_fit(params):
+    """Refine the flat field fit by minimizing data spread.
 
-        r[n] = r[n] - rd[n]
+    Inputs
+    ------
+    params: parameters
 
-        if flat_field is not None:
-            # TODO:  flat field should not be applied on intra darks
-            # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
-            #                 rois[n]['xl']:rois[n]['xh']]
-            ff = flat_field[rois[n]['yl']:rois[n]['yh'],
-                            rois[n]['xl']:rois[n]['xh']]
-            r[n] = r[n]/ff
+    Returns
+    -------
+    res: scipy minimize result. res.x is the optimized parameters
 
-        v[n] = r[n].sum(axis=(2, 3))
+    firres: iteration index arrays of criteria results for
+        [criteria]
+    """
+    # load data
+    assert params.arr is not None, "Data not loaded"
+    assert params.arr_dark is not None, "Data not loaded"
 
-    res = xr.Dataset()
+    # we only need few rois
+    fitrois = {}
+    for k in ['n', '0', 'p', 'sat']:
+        fitrois[k] = params.rois[k]
 
-    dims = ['trainId', 'pulseId']
-    r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
-    for n in rois.keys():
-        res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
-                                       coords=r_coords, dims=dims)
-        res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)
+    p0 = params.get_flat_field()
 
-    for n in rois.keys():
-        roi = rois[n]
-        res[n + '_area'] = xr.DataArray(np.array([
-            (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))
+    fixed_p = (params, params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), params.sat_level)
 
-    return res
+    def fit_callback(x):
+        if not hasattr(fit_callback, "counter"):
+            fit_callback.counter = 0  # it doesn't exist yet, so initialize it
+            fit_callback.start = time.monotonic()
+            fit_callback.res = []
 
+        now = time.monotonic()
+        time_delta = datetime.timedelta(seconds=now-fit_callback.start)
+        fit_callback.counter += 1
 
-def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
-    """Process dark and run data with corrections.
+        temp = list(fixed_p)
+        Jalpha = ff_refine_crit(x, *temp)
+        fit_callback.res.append([Jalpha])
+        print(f'{fit_callback.counter-1}: {time_delta} '
+                f'({Jalpha}), {x}')
+
+        return False
+
+    fit_callback(p0)
+    res = minimize(ff_refine_crit, p0, fixed_p,
+        options={'disp': True, 'maxiter': params.max_iter},
+        callback=fit_callback)
+
+    return res, fit_callback.res
+
+
+# non-linearity related functions
+
+def nl_domain(N, low, high):
+    """Create the input domain where the non-linear correction defined.
 
     Inputs
     ------
-    Fmodel: correction lookup table
-    arr_dark: dark data
-    arr: data
-    rois: ['n', '0', 'p', 'sat'] rois
-    mask: mask of good pixels
-    flat_field: zone plate flat field correction
-    sat_level: integer, default 511, at which level pixel begin to saturate
+    N: integer, number of control points or intervals
+    low: input values below or equal to low will not be corrected
+    high: input values higher or equal to high will not be corrected
+
+    Returns
+    -------
+    array of 2**9 integer values with N segments
+    """
+    x = np.arange(2**9)
+    vx = x.copy()
+    eps = 1e-5
+    vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1)
+    vx[x <= low] = 0
+    vx[x >= high] = 0
+
+    return vx
+
+
+def nl_lut(domain, dy):
+    """Compute the non-linear correction.
+
+    Inputs
+    ------
+    domain: input domain where dy is defined. For zero no correction is
+        defined. For non-zero value x, dy[x] is applied.
+    dy: a vector of deviation from linearity on control point homogeneously
+        dispersed over 9 bits.
 
     Returns
     -------
-    roi extracted intensities
+    F_INL: default None, non linear correction function given as a
+           lookup table with 9 bits integer input
     """
-    # dark process
-    res = average_module(arr_dark, F_INL=Fmodel)
-    dark = res.compute()
+    x = np.arange(2**9)
+    ndy = np.insert(dy, 0, 0)  # add zero to dy
 
-    # data process
-    proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
-                          flat_field=flat_field, F_INL=Fmodel)
-    data = proc.compute()
+    f = x + ndy[domain]
 
-    return data
+    return f
 
 
 def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
@@ -1377,6 +1360,29 @@ def snr(sig, ref, methods=None, verbose=False):
 
     return res
 
+def inspect_Fnl(Fnl):
+    """Plot the correction function Fnl.
+
+    Inputs
+    ------
+    Fnl: non linear correction function lookup table
+
+    Returns
+    -------
+    matplotlib figure
+    """
+    x = np.arange(2**9)
+    f = plt.figure(figsize=(6, 4))
+
+    plt.plot(x, Fnl - x)
+    # plt.axvline(40, c='k', ls='--')
+    # plt.axvline(280, c='k', ls='--')
+    plt.xlabel('input value')
+    plt.ylabel('output correction F(x)-x')
+    plt.xlim([0, 511])
+
+    return f
+
 
 def inspect_correction(params, gain=None):
     """Criteria for the non linear correction.
@@ -1403,7 +1409,7 @@ def inspect_correction(params, gain=None):
     plane_ff = params.get_flat_field()
     if plane_ff is None:
         plane_ff = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 1.0, -1.0]
-    ff = compute_flat_field_correction(params.rois, plane_ff)
+    ff = compute_flat_field_correction(params.rois, params)
 
     # non linearities
     Fnl = params.get_Fnl()
@@ -1497,28 +1503,221 @@ def inspect_correction(params, gain=None):
     return f
 
 
-def inspect_Fnl(Fnl):
-    """Plot the correction function Fnl.
+# data processing related functions
+
+def load_dssc_module(proposalNB, runNB, moduleNB=15,
+                     subset=slice(None), drop_intra_darks=True, persist=False):
+    """Load single module dssc data as dask array.
 
     Inputs
     ------
-    Fnl: non linear correction function lookup table
+    proposalNB: proposal number
+    runNB: run number
+    moduleNB: default 15, module number
+    subset: default slice(None), subset of trains to load
+    drop_intra_darks: boolean, default True, remove intra darks from the data
+    persist: default False, load all data persistently in memory
 
     Returns
     -------
-    matplotlib figure
+    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
+    tid: array of train id number
     """
-    x = np.arange(2**9)
-    f = plt.figure(figsize=(6, 4))
+    run = open_run(proposal=proposalNB, run=runNB)
 
-    plt.plot(x, Fnl - x)
-    # plt.axvline(40, c='k', ls='--')
-    # plt.axvline(280, c='k', ls='--')
-    plt.xlabel('input value')
-    plt.ylabel('output correction F(x)-x')
-    plt.xlim([0, 511])
+    # DSSC
+    source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
+    key = 'image.data'
 
-    return f
+    arr = run[source, key][subset].dask_array()
+    # fix 256 value becoming spuriously 0 instead
+    arr[arr == 0] = 256
+
+    ppt = run[source, key][subset].data_counts()
+    # ignore train with no pulses, can happen in burst mode acquisition
+    ppt = ppt[ppt > 0]
+    tid = ppt.index.to_numpy()
+
+    ppt = np.unique(ppt)
+    assert ppt.shape[0] == 1, "number of pulses changed during the run"
+    ppt = ppt[0]
+
+    # reshape in trainId, pulseId, 2d-image
+    arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])
+
+    # drop intra darks
+    if drop_intra_darks:
+        arr = arr[:, ::2, :, :]
+
+    # load data in memory
+    if persist:
+        arr = arr.persist()
+
+    return arr, tid
+
+
+def average_module(arr, dark=None, ret='mean',
+                   mask=None, sat_roi=None, sat_level=300, F_INL=None):
+    """Compute the average or std over a module.
+
+    Inputs
+    ------
+    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
+    dark: default None, dark to be substracted
+    ret: string, either 'mean' to compute the mean or 'std' to compute the
+         standard deviation
+    mask: default None, mask of bad pixels to ignore
+    sat_roi: roi over which to check for pixel with values larger than
+             sat_level to drop the image from the average or std
+    sat_level: int, minimum pixel value for a pixel to be considered saturated
+    F_INL: default None, non linear correction function given as a
+           lookup table with 9 bits integer input
+
+    Returns
+    -------
+    average or standard deviation image
+    """
+    # F_INL
+    if F_INL is not None:
+        narr = arr.map_blocks(lambda x: F_INL[x])
+    else:
+        narr = arr
+
+    if mask is not None:
+        narr = narr*mask
+
+    if sat_roi is not None:
+        temp = (da.logical_not(da.any(
+                narr[:, :, sat_roi['yl']:sat_roi['yh'],
+                           sat_roi['xl']:sat_roi['xh']] >= sat_level,
+                               axis=[2, 3], keepdims=True)))
+        not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)
+
+    if dark is not None:
+        narr = narr - dark
+
+    if ret == 'mean':
+        if sat_roi is not None:
+            return da.average(narr, axis=0, weights=not_sat)
+        else:
+            return narr.mean(axis=0)
+    elif ret == 'std':
+        return narr.std(axis=0)
+    else:
+        raise ValueError(f'ret={ret} not supported')
+
+
+def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
+                   flat_field=None, F_INL=None):
+    """Process one module and extract roi intensity.
+
+    Inputs
+    ------
+    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
+    tid: array of train id number
+    dark: pulse resolved dark image to remove
+    rois: dictionnary of rois
+    mask: default None, mask of ignored pixels
+    sat_level: integer, default 511, at which level pixel begin to saturate
+    flat_field: default None, flat field correction
+    F_INL: default None, non linear correction function given as a
+           lookup table with 9 bits integer input
+
+    Returns
+    -------
+    dataset of extracted pulse and train resolved roi intensities.
+    """
+    # F_INL
+    if F_INL is not None:
+        narr = arr.map_blocks(lambda x: F_INL[x])
+    else:
+        narr = arr
+
+    # apply mask
+    if mask is not None:
+        narr = narr*mask
+
+    # crop rois
+    r = {}
+    rd = {}
+    for n in rois.keys():
+        r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'],
+                    rois[n]['xl']:rois[n]['xh']]
+        rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'],
+                     rois[n]['xl']:rois[n]['xh']]
+
+    # find saturated shots
+    r_sat = {}
+    for n in rois.keys():
+        r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3))
+
+    # TODO: flat field should not be applied on intra darks
+    # # change flat field dimension to match data
+    # if flat_field is not None:
+    #     temp = np.ones_like(dark)
+    #     temp[::2, :, :] = flat_field[:, :]
+    #    flat_field = temp
+
+    # compute dark corrected ROI values
+    v = {}
+    for n in rois.keys():
+
+        r[n] = r[n] - rd[n]
+
+        if flat_field is not None:
+            # TODO:  flat field should not be applied on intra darks
+            # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
+            #                 rois[n]['xl']:rois[n]['xh']]
+            ff = flat_field[rois[n]['yl']:rois[n]['yh'],
+                            rois[n]['xl']:rois[n]['xh']]
+            r[n] = r[n]/ff
+
+        v[n] = r[n].sum(axis=(2, 3))
+
+    res = xr.Dataset()
+
+    dims = ['trainId', 'pulseId']
+    r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
+    for n in rois.keys():
+        res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
+                                       coords=r_coords, dims=dims)
+        res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)
+
+    for n in rois.keys():
+        roi = rois[n]
+        res[n + '_area'] = xr.DataArray(np.array([
+            (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))
+
+    return res
+
+
+def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
+    """Process dark and run data with corrections.
+
+    Inputs
+    ------
+    Fmodel: correction lookup table
+    arr_dark: dark data
+    arr: data
+    rois: ['n', '0', 'p', 'sat'] rois
+    mask: mask of good pixels
+    flat_field: zone plate flat field correction
+    sat_level: integer, default 511, at which level pixel begin to saturate
+
+    Returns
+    -------
+    roi extracted intensities
+    """
+    # dark process
+    res = average_module(arr_dark, F_INL=Fmodel)
+    dark = res.compute()
+
+    # data process
+    proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
+                          flat_field=flat_field, F_INL=Fmodel)
+    data = proc.compute()
+
+    return data
 
 def inspect_saturation(data, gain, Nbins=200):
     """Plot roi integrated histogram of the data with saturation