From ca5154e5f693adcee9d2d5df7f4203dee2052122 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Thu, 12 May 2022 13:21:21 +0200
Subject: [PATCH] Add low and high threshold on product and ratio of ROIs to
 calculate flat field

---
 src/toolbox_scs/routines/boz.py | 120 ++++++++++++++++++++++++++++----
 1 file changed, 108 insertions(+), 12 deletions(-)

diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 0b11c5b..0f6b7d5 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -14,6 +14,8 @@ import dask.array as da
 from scipy.optimize import minimize
 
 import matplotlib.pyplot as plt
+from matplotlib.colors import LogNorm
+from matplotlib import cm
 
 from extra_data import open_run
 
@@ -57,6 +59,8 @@ class parameters():
         self.rois = None
         self.rois_th = None
         self.flat_field = None
+        self.flat_field_prod_th = (5.0, np.PINF)
+        self.flat_field_ratio_th = (np.NINF, 1.2)
         self.Fnl = None
         self.alpha = None
         self.sat_level = None
@@ -105,12 +109,17 @@ class parameters():
         """Get the list of bad pixel indices."""
         return self.mask_idx
 
-    def set_flat_field(self, plane):
+    def set_flat_field(self, plane,
+            prod_th=None, ratio_th=None):
         """Set the flat field plane definition."""
         if type(plane) is not list:
             self.flat_field = plane.tolist()
         else:
             self.flat_field = plane
+        if prod_th is not None:
+            self.flat_field_prod_th = prod_th
+        if ratio_th is not None:
+            self.flat_field_ratio_th = ratio_th
 
     def get_flat_field(self):
         """Get the flat field plane definition."""
@@ -155,6 +164,8 @@ class parameters():
         v['rois_th'] = self.rois_th
 
         v['flat_field'] = self.flat_field
+        v['flat_field_prod_th'] = self.flat_field_prod_th
+        v['flat_field_ratio_th'] = self.flat_field_ratio_th
 
         v['Fnl'] = self.Fnl
         v['alpha'] = self.alpha
@@ -186,7 +197,7 @@ class parameters():
         c.rois = v['rois']
         c.rois_th = v['rois_th']
 
-        c.set_flat_field(v['flat_field'])
+        c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th'])
 
         c.set_Fnl(v['Fnl'])
         c.alpha = v['alpha']
@@ -208,7 +219,7 @@ class parameters():
         f += f'rois threshold: {self.rois_th}\n'
         f += f'rois: {self.rois}\n'
 
-        f += f'flat field: {self.flat_field}\n'
+        f += f'flat field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n'
 
         if self.Fnl is not None:
             f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
@@ -793,15 +804,89 @@ def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
     return fig
 
 
-def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
+def inspect_flat_field_domain(avg, params, vmin=None, vmax=None):
     """Extract beams roi from average image and compute the ratio.
 
     Inputs
     ------
     avg: module average image with no saturated shots for the flat field
          determination
-    rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0'
-          as the reference beam in the middle
+    params: dictionnary, boz parameters
+    vmin: imshow vmin level, default None will use 5 percentile value
+    vmax: imshow vmax level, default None will use 99.8 percentile value
+
+    Returns
+    -------
+    fig: matplotlib figure plotted
+    domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order
+    """
+    if vmin is None:
+        vmin = np.percentile(avg, 5)
+    if vmax is None:
+        vmax = np.percentile(avg, 99.8)
+
+    fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9))
+
+    img_rois = {}
+    centers = {}
+    rois = params.rois
+
+    for k, r in enumerate(['n', '0', 'p']):
+        roi = rois[r]
+        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
+                      (roi['xl'] + roi['xh'])//2])
+
+    d = '0'
+    roi = rois[d]
+    for k, r in enumerate(['n', '0', 'p']):
+        img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[
+        roi['yl']:roi['yh'], roi['xl']:roi['xh']]
+        im = axs[0, k].imshow(img_rois[r],
+                              vmin=vmin,
+                              vmax=vmax)
+
+    n, n_m, p, p_m = plane_fitting_domain(avg, rois,
+        params.flat_field_prod_th, params.flat_field_ratio_th)
+
+    prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
+    for k, r in enumerate(['n', '0', 'p']):
+        v = img_rois[r]*img_rois['0']
+        if prod_vmin is None:
+            prod_vmin = np.percentile(v, .5)
+            prod_vmax = np.percentile(v, 20) # we look for low intensity region
+        im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax, cmap='magma')
+        axs[1,k].contour(v, params.flat_field_prod_th, cmap=cm.get_cmap(cm.cool, 2))
+
+        v = img_rois[r]/img_rois['0']
+        if ratio_vmin is None:
+            ratio_vmin = np.percentile(v, 5)
+            ratio_vmax = np.percentile(v, 99.8)
+        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r')
+        axs[2,k].contour(v, params.flat_field_ratio_th, cmap=cm.get_cmap(cm.cool, 2))
+
+    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
+    cbar.ax.set_xlabel('data mean')
+
+    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
+    cbar.ax.set_xlabel('product')
+
+    cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal")
+    cbar.ax.set_xlabel('ratio')
+
+    # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')
+
+    domain = (n_m, p_m)
+
+    return fig, domain
+
+def inspect_plane_fitting(avg, rois, domain, vmin=None, vmax=None):
+    """Extract beams roi from average image and compute the ratio.
+
+    Inputs
+    ------
+    avg: module average image with no saturated shots for the flat field
+         determination
+    rois: dictionnary of rois
     vmin: imshow vmin level, default None will use 5 percentile value
     vmax: imshow vmax level, default None will use 99.8 percentile value
 
@@ -837,6 +922,10 @@ def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
         v = img_rois[r]/img_rois['0']
         im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')
 
+    n_m, p_m = domain
+    axs[1, 0].contour(n_m)
+    axs[1, 2].contour(p_m)
+
     cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
     cbar.ax.set_xlabel('data mean')
 
@@ -848,7 +937,7 @@ def inspect_plane_fitting(avg, rois, vmin=None, vmax=None):
     return fig
 
 
-def plane_fitting_domain(avg, rois):
+def plane_fitting_domain(avg, rois, prod_th, ratio_th):
     """Extract beams roi, compute their ratio and the domain.
 
     Inputs
@@ -857,6 +946,10 @@ def plane_fitting_domain(avg, rois):
          determination
     rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0'
           as the reference beam in the middle
+    prod_th: float tuple, low and hight threshold level to determine the plane
+        fitting domain on the product image of the orders
+    ratio_th: float tuple, low and high threshold level to determine the plane
+        fitting domain on the ratio image of the orders
 
     Returns
     -------
@@ -879,7 +972,9 @@ def plane_fitting_domain(avg, rois):
     denom = np.roll(avg, tuple(centers[k] - centers[d]))[
         rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
     n = num/denom
-    n_m = ((num*denom) > 5) * (num/denom < 1.2)
+    prod = num*denom
+    n_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
+            (n > ratio_th[0]) * (n < ratio_th[1]))
     n_m[~np.isfinite(n)] = 0
     n[~np.isfinite(n)] = 0
 
@@ -889,7 +984,9 @@ def plane_fitting_domain(avg, rois):
     denom = np.roll(avg, tuple(centers[k] - centers[d]))[
         rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
     p = num/denom
-    p_m = ((num*denom) > 5) * (num/denom < 1.2)
+    prod = num*denom
+    p_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
+            (p > ratio_th[0]) * (p < ratio_th[1]))
     p_m[~np.isfinite(p)] = 0
     p[~np.isfinite(p)] = 0
 
@@ -916,7 +1013,8 @@ def plane_fitting(params):
         sat_level=params.sat_level).compute()
     data_mean = data.mean(axis=0)  # mean over pulseId
 
-    n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois)
+    n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois,
+        params.flat_field_prod_th, params.flat_field_ratio_th)
 
     def _crit(x):
         """Fitting criteria for the plane field normalization.
@@ -1314,8 +1412,6 @@ def inspect_correction(params, gain=None):
 
     # nbins = np.linspace(0.01, 1.0, 100)
 
-    from matplotlib.colors import LogNorm
-
     photon_scale = None
 
     for k, d in enumerate([data, data_ff, data_ff_nl]):
-- 
GitLab