diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 3cd9472f7b6ce7b5c09b28c047367f4c14d9d4a8..f118fb80e1b5cd49545d3f4b8429becc6aceea4f 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -342,11 +342,11 @@ def find_rois(data_mean, threshold):
     pY = data_mean.sum(axis=1)
 
     # along X
-    lowX = int(np.argmax(pX[:128] > threshold))  # 1st occurrence returned
-    highX = int(np.argmax(pX[128:] < threshold) + 128)  # 1st occ. returned
-    midX = (lowX + highX)//2
-    leftX = int(np.argmin(pX[(lowX+20):midX]) + lowX + 20)
-    rightX = int(np.argmin(pX[midX:highX-20]) + midX)
+    lowX = int(np.argmax(pX[:64] > threshold))  # 1st occurrence returned
+    highX = int(np.argmax(pX[192:] <= threshold) + 192)  # 1st occ. returned
+
+    leftX = int(np.argmin(pX[64:128]) + 64)
+    rightX = int(np.argmin(pX[128:192]) + 128)
 
     # along Y
     lowY = int(np.argmax(pY[:64] > threshold))  # 1st occurrence returned
@@ -795,26 +795,24 @@ def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
     fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6))
 
     img_rois = {}
+    centers = {}
+
+    for k, r in enumerate(['n', '0', 'p']):
+        roi = rois[r]
+        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
+                      (roi['xl'] + roi['xh'])//2])
 
+    d = '0'
+    roi = rois[d]
     for k, r in enumerate(['n', '0', 'p']):
-        img_rois[r] = avg[rois[r]['yl']:rois[r]['yh'],
-                          rois[r]['xl']:rois[r]['xh']]
+        img_rois[r] = np.roll(avg, tuple(centers[r] - centers[d]))[
+        roi['yl']:roi['yh'], roi['xl']:roi['xh']]
         im = axs[0, k].imshow(img_rois[r],
                               vmin=vmin,
                               vmax=vmax)
 
     for k, r in enumerate(['n', '0', 'p']):
-        if img_rois[r].shape[1] != img_rois['0'].shape[1]:
-            if k == 0:
-                n1 = img_rois[r].shape[1]
-                n = img_rois['0'].shape[1]
-                v = img_rois[r][:, (n1-n):]/img_rois['0']
-            else:
-                n1 = img_rois[r].shape[1]
-                n = img_rois['0'].shape[1]
-                v = img_rois[r][:, :-(n1-n)]/img_rois['0']
-        else:
-            v = img_rois[r]/img_rois['0']
+        v = img_rois[r]/img_rois['0']
         im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')
 
     cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
@@ -1073,16 +1071,17 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
     # drop saturated shots
     d = data.where(data['sat_sat'] == False, drop=True)
 
-    # calculated error from transmission of 1.0
-    v = d['n'].values.flatten()/d['0'].values.flatten()
-    err1 = 1e8*np.nanmean((v - 1.0)**2)
+    v_1 = snr(d['n'].values.flatten(), d['0'].values.flatten(),
+            methods=['weighted'])
+    err_1 = 1e8*v_1['weighted']['s']**2
 
-    err2 = np.sum((Fmodel-np.arange(2**9))**2)
+    v_2 = snr(d['p'].values.flatten(), d['0'].values.flatten(),
+            methods=['weighted'])
+    err_2 = 1e8*v_2['weighted']['s']**2
 
-    # print(f'{err}: {p}')
-    # logging.info(f'{err}: {p}')
+    err_a = np.sum((Fmodel-np.arange(2**9))**2)
 
-    return (1.0 - alpha)*err1 + alpha*err2
+    return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a
 
 
 def nl_fit(params, domain):
@@ -1175,44 +1174,68 @@ def inspect_nl_fit(res_fit):
     return f
 
 
-def snr(sig, ref, verbose=False):
-    """ Compute mean, std and SNR with and without weight from transmitted signal sig
-        and I0 signal ref
+def snr(sig, ref, methods=None, verbose=False):
+    """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref.
+
+    Inputs
+    ------
+    sig: 1D signal samples
+    ref: 1D reference samples
+    methods: None by default or list of strings to select which methods to use.
+        Possible values are 'direct', 'weighted', 'diff'. In case of None, all
+        methods will be calculated.
+    verbose: booleand, if True prints calculated values
+
+    Returns
+    -------
+    dictionnary of [methods][value] where value is 'mu' for mean and 's' for
+    standard deviation.
+
     """
+    if methods is None:
+        methods = ['direct', 'weighted', 'diff']
+
     w = ref
     x = sig/ref
 
     mask = np.isfinite(x) & np.isfinite(sig) & np.isfinite(ref)
-    
+
     w = w[mask]
     sig = sig[mask]
     ref = ref[mask]
     x = x[mask]
 
-    # direct mean and std
-    mu = np.mean(x)
-    s = np.std(x)
-    if verbose:
-        print(f'mu: {mu}, s: {s}, snr: {mu/s}')
     res = {}
-    res['direct'] = {'mu': mu, 's':s}
+
+    # direct mean and std
+    if 'direct' in methods:
+        mu = np.mean(x)
+        s = np.std(x)
+        if verbose:
+            print(f'mu: {mu}, s: {s}, snr: {mu/s}')
+
+        res['direct'] = {'mu': mu, 's':s}
 
     # weighted mean and std
-    wmu = np.sum(sig)/np.sum(ref)
-    v1 = np.sum(w)
-    v2 = np.sum(w**2)
-    ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1))
+    if 'weighted' in methods:
+        wmu = np.sum(sig)/np.sum(ref)
+        v1 = np.sum(w)
+        v2 = np.sum(w**2)
+        ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1))
+
+        if verbose:
+            print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}')
 
-    if verbose:
-        print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}')
-    res['weighted'] = {'mu': wmu, 's':ws}
+        res['weighted'] = {'mu': wmu, 's':ws}
 
     # noise from diff
-    dmu = np.mean(x)
-    ds = np.std(np.diff(x))/np.sqrt(2)
-    if verbose:
-        print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}')
-    res['diff'] = {'mu': dmu, 's':ds}
+    if 'diff' in methods:
+        dmu = np.mean(x)
+        ds = np.std(np.diff(x))/np.sqrt(2)
+        if verbose:
+            print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}')
+
+        res['diff'] = {'mu': dmu, 's':ds}
 
     return res
 
@@ -1256,7 +1279,7 @@ def inspect_correction(params, gain=None):
 
     scale = 1e-6
 
-    f, axs = plt.subplots(3, 3, figsize=(6, 6), sharex=True, sharey=True)
+    f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True)
 
     # nbins = np.linspace(0.01, 1.0, 100)
 
@@ -1278,16 +1301,11 @@ def inspect_correction(params, gain=None):
             snr_v = snr(good_d[n].values.flatten(),
                         good_d[r].values.flatten(), verbose=True)
 
-            if k == 0:
-                m = np.nanmean(good_d[n].values.flatten()
-                    /good_d[r].values.flatten())
-            else:
-                m = 1
-
+            m = snr_v['direct']['mu']
             h, xedges, yedges, img = axs[l, k].hist2d(
                 g*scale*good_d[r].values.flatten(),
-                good_d[n].values.flatten()/good_d[r].values.flatten()/m,
-                [photon_scale, np.linspace(0.95, 1.05, 150)],
+                good_d[n].values.flatten()/good_d[r].values.flatten(),
+                [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Blues',
                 vmax=200,
                 norm=LogNorm(),
@@ -1295,13 +1313,14 @@ def inspect_correction(params, gain=None):
                 )
             h, xedges, yedges, img2 = axs[l, k].hist2d(
                 g*scale*sat_d[r].values.flatten(),
-                sat_d[n].values.flatten()/sat_d[r].values.flatten()/m,
-                [photon_scale, np.linspace(0.95, 1.05, 150)],
+                sat_d[n].values.flatten()/sat_d[r].values.flatten(),
+                [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Reds',
                 vmax=200,
                 norm=LogNorm(),
                 # alpha=0.5 # make  the plot looks ugly with lots of white lines
                 )
+
             v = snr_v['direct']['mu']/snr_v['direct']['s']
             axs[l, k].text(0.4, 0.15, f'SNR: {v:.0f}',
                             transform = axs[l, k].transAxes)
@@ -1312,9 +1331,11 @@ def inspect_correction(params, gain=None):
             # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
             # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
 
+            axs[l, k].set_ylim([0.95*m, 1.05*m])
+
     for k in range(3):
-        for l in range(3):
-            axs[l, k].set_ylim([0.95, 1.05])
+        #for l in range(3):
+        #    axs[l, k].set_ylim([0.95, 1.05])
         if gain:
             axs[2, k].set_xlabel('#ph (10$^6$)')
         else: