From 65c39c2dd6fe3d8328dd25a7e73199f744e7fc28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Thu, 12 May 2022 17:24:44 +0200
Subject: [PATCH] Adds hexgonal DSSC pixel lattice geometry for flat field
 fitting

---
 src/toolbox_scs/routines/boz.py | 69 +++++++++++++++++++--------------
 1 file changed, 39 insertions(+), 30 deletions(-)

diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index bee93c7..b07ac7f 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -18,6 +18,7 @@ from matplotlib.colors import LogNorm
 from matplotlib import cm
 
 from extra_data import open_run
+from extra_geom import DSSC_1MGeometry
 
 __all__ = [
     'load_dssc_module',
@@ -235,6 +236,15 @@ class parameters():
 
         return f
 
+def _get_pixel_pos():
+    """Compute the pixel position on hexagonal lattice of DSSC module 15"""
+    # module pixel position
+    dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
+    g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)
+
+    # keeping only module 15 pixel X,Y position
+    return g.get_pixel_positions()[15][:, :, :2]
+
 
 def _plane_flat_field(p, roi):
     """Compute the p plane over the given roi.
@@ -252,19 +262,14 @@ def _plane_flat_field(p, roi):
     -------
     the plane field given by p evaluated on the roi
     extend.
-
-    TODO
-    ----
-    the hexagonal lattice is currently ignored.
     """
     a, b, c, d = p
 
-    nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
+    # DSSC pixel position on hexagonal lattice
+    pixel_pos = _get_pixel_pos()
+    pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
 
-    X = np.arange(nX)/100
-    Y = np.arange(nY)[:, np.newaxis]/100
-
-    Z = -(a*X + b*Y + d)/c
+    Z = -(a*pos[:, :, 0] + b*pos[:, :, 1] + d)/c
 
     return Z
 
@@ -287,10 +292,10 @@ def compute_flat_field_correction(rois, p, plot=False):
 
     r = 'n'
     flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \
-                       _plane_flat_field(p, rois[r])
+                       _plane_flat_field(p[:4], rois[r])
     r = 'p'
     flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \
-                       np.fliplr(_plane_flat_field(p, rois[r]))
+                       _plane_flat_field(p[4:], rois[r])
 
     if plot:
         f, ax = plt.subplots(1, 1, figsize=(6, 2))
@@ -809,14 +814,16 @@ def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
     return fig
 
 
-def inspect_flat_field_domain(avg, params, vmin=None, vmax=None):
+def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None):
     """Extract beams roi from average image and compute the ratio.
 
     Inputs
     ------
     avg: module average image with no saturated shots for the flat field
          determination
-    params: dictionnary, boz parameters
+    rois: dictionnary or ROIs
+    prod_th, ratio_th: tuple of floats for low and high threshold on
+        product and ratio
     vmin: imshow vmin level, default None will use 5 percentile value
     vmax: imshow vmax level, default None will use 99.8 percentile value
 
@@ -834,7 +841,6 @@ def inspect_flat_field_domain(avg, params, vmin=None, vmax=None):
 
     img_rois = {}
     centers = {}
-    rois = params.rois
 
     for k, r in enumerate(['n', '0', 'p']):
         roi = rois[r]
@@ -850,8 +856,7 @@ def inspect_flat_field_domain(avg, params, vmin=None, vmax=None):
                               vmin=vmin,
                               vmax=vmax)
 
-    n, n_m, p, p_m = plane_fitting_domain(avg, rois,
-        params.flat_field_prod_th, params.flat_field_ratio_th)
+    n, n_m, p, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th)
 
     prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
     for k, r in enumerate(['n', '0', 'p']):
@@ -860,14 +865,14 @@ def inspect_flat_field_domain(avg, params, vmin=None, vmax=None):
             prod_vmin = np.percentile(v, .5)
             prod_vmax = np.percentile(v, 20) # we look for low intensity region
         im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax, cmap='magma')
-        axs[1,k].contour(v, params.flat_field_prod_th, cmap=cm.get_cmap(cm.cool, 2))
+        axs[1,k].contour(v, prod_th, cmap=cm.get_cmap(cm.cool, 2))
 
         v = img_rois[r]/img_rois['0']
         if ratio_vmin is None:
             ratio_vmin = np.percentile(v, 5)
             ratio_vmax = np.percentile(v, 99.8)
         im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r')
-        axs[2,k].contour(v, params.flat_field_ratio_th, cmap=cm.get_cmap(cm.cool, 2))
+        axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2))
 
     cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
     cbar.ax.set_xlabel('data mean')
@@ -1026,24 +1031,28 @@ def plane_fitting(params):
 
         Inputs
         ------
-        x: vector [a, b, c, d] defining the plane as
+        x: 2 vector [a, b, c, d] concatenated defining the plane as
                 a*x + b*y + c*z + d = 0
         """
-        a, b, c, d = x
 
-        num = a**2 + b**2 + c**2
+        a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x
+
+        num_n = a_n**2 + b_n**2 + c_n**2
 
-        nY, nX = n.shape
-        X = np.arange(nX)/100
-        Y = np.arange(nY)[:, np.newaxis]/100
-        d0_2 = np.sum(n_m*(a*X + b*Y + c*n + d)**2)/num
+        roi = params.rois['n']
+        pixel_pos = _get_pixel_pos()
+        # DSSC pixel position on hexagonal lattice
+        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
+        d0_2 = np.sum(n_m*(a_n*pos[:, :, 0] + b_n*pos[:, :, 1]
+            + c_n*n + d_n)**2)/num_n
 
-        nY, nX = p.shape
-        X = np.arange(nX)/100
-        Y = np.arange(nY)[:, np.newaxis]/100
-        d2_2 = np.sum(np.fliplr(p_m)*(a*X + b*Y + c*np.fliplr(p) + d)**2)/num
+        num_p = a_p**2 + b_p**2 + c_p**2
 
-        return d2_2 + d0_2
+        roi = params.rois['p']
+        # DSSC pixel position on hexagonal lattice
+        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
+        d2_2 = np.sum(p_m*(a_p*pos[:, :, 0] + b_p*pos[:, :, 1]
+            + c_p*p + d_p)**2)/num_p
 
         return 1e3*(d2_2 + d0_2)
 
-- 
GitLab