From 4c7161422c931bfc6cfdf0697049c5a944f94c4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Wed, 28 Jul 2021 13:48:16 +0200
Subject: [PATCH] Inspect nl correction with unscaled data

---
 src/toolbox_scs/routines/boz.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index f3ebbe2..916d29a 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -1254,7 +1254,7 @@ def inspect_correction(params, gain=None):
 
     scale = 1e-6
 
-    f, axs = plt.subplots(3, 3, figsize=(6, 6), sharex=True, sharey=True)
+    f, axs = plt.subplots(3, 3, figsize=(6, 6), sharex=True)
 
     # nbins = np.linspace(0.01, 1.0, 100)
 
@@ -1276,16 +1276,11 @@ def inspect_correction(params, gain=None):
             snr_v = snr(good_d[n].values.flatten(),
                         good_d[r].values.flatten(), verbose=True)
 
-            if k == 0:
-                m = np.nanmean(good_d[n].values.flatten()
-                    /good_d[r].values.flatten())
-            else:
-                m = 1
-
+            m = snr_v['direct']['mu']
             h, xedges, yedges, img = axs[l, k].hist2d(
                 g*scale*good_d[r].values.flatten(),
-                good_d[n].values.flatten()/good_d[r].values.flatten()/m,
-                [photon_scale, np.linspace(0.95, 1.05, 150)],
+                good_d[n].values.flatten()/good_d[r].values.flatten(),
+                [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Blues',
                 vmax=200,
                 norm=LogNorm(),
@@ -1293,13 +1288,14 @@ def inspect_correction(params, gain=None):
                 )
             h, xedges, yedges, img2 = axs[l, k].hist2d(
                 g*scale*sat_d[r].values.flatten(),
-                sat_d[n].values.flatten()/sat_d[r].values.flatten()/m,
-                [photon_scale, np.linspace(0.95, 1.05, 150)],
+                sat_d[n].values.flatten()/sat_d[r].values.flatten(),
+                [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Reds',
                 vmax=200,
                 norm=LogNorm(),
                 # alpha=0.5 # make  the plot looks ugly with lots of white lines
                 )
+
             v = snr_v['direct']['mu']/snr_v['direct']['s']
             axs[l, k].text(0.4, 0.15, f'SNR: {v:.0f}',
                             transform = axs[l, k].transAxes)
@@ -1310,9 +1306,11 @@ def inspect_correction(params, gain=None):
             # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
             # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
 
+            axs[l, k].set_ylim([0.95*m, 1.05*m])
+
     for k in range(3):
-        for l in range(3):
-            axs[l, k].set_ylim([0.95, 1.05])
+        #for l in range(3):
+        #    axs[l, k].set_ylim([0.95, 1.05])
         if gain:
             axs[2, k].set_xlabel('#ph (10$^6$)')
         else:
-- 
GitLab