From 834e53f96f0e5286740511032d36df3a2b710419 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Thu, 19 May 2022 17:21:53 +0200
Subject: [PATCH] Add regularization term for flat field refinement

---
 src/toolbox_scs/routines/boz.py | 56 +++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 20 deletions(-)

diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 603c387..ca0dd87 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -70,10 +70,13 @@ class parameters():
         self.plane_guess_fit = None
         self.use_hex = False
         self.force_mirror = True
+        self.ff_alpha = None
+        self.ff_max_iter = None
+
         self.Fnl = None
-        self.alpha = None
+        self.nl_alpha = None
         self.sat_level = None
-        self.max_iter = None
+        self.nl_max_iter = None
 
         # temporary data
         self.arr_dark = None
@@ -178,11 +181,13 @@ class parameters():
         v['plane_guess_fit'] = self.plane_guess_fit
         v['use_hex'] = self.use_hex
         v['force_mirror'] = self.force_mirror
+        v['ff_alpha'] = self.ff_alpha
+        v['ff_max_iter'] = self.ff_max_iter
 
         v['Fnl'] = self.Fnl
-        v['alpha'] = self.alpha
+        v['nl_alpha'] = self.nl_alpha
         v['sat_level'] = self.sat_level
-        v['max_iter'] = self.max_iter
+        v['nl_max_iter'] = self.nl_max_iter
 
         fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json'
 
@@ -213,11 +218,13 @@ class parameters():
         c.plane_guess_fit = v['plane_guess_fit']
         c.use_hex = v['use_hex']
         c.force_mirror = v['force_mirror']
+        c.ff_alpha = v['ff_alpha']
+        c.ff_max_iter = v['ff_max_iter']
 
         c.set_Fnl(v['Fnl'])
-        c.alpha = v['alpha']
+        c.nl_alpha = v['nl_alpha']
         c.sat_level = v['sat_level']
-        c.max_iter = v['max_iter']
+        c.nl_max_iter = v['nl_max_iter']
 
         return c
 
@@ -238,11 +245,12 @@ class parameters():
         f += f'plane guess fit: {self.plane_guess_fit}\n'
         f += f'use hexagons: {self.use_hex}\n'
         f += f'enforce mirror symmetry: {self.force_mirror}\n'
+        f += f'ff alpha: {self.ff_alpha}, max. iter.: {self.ff_max_iter}\n'
 
         if self.Fnl is not None:
             f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
-            f += f'alpha:{self.alpha}, sat. level:{self.sat_level}, '
-            f += f' max. iter.:{self.max_iter}'
+            f += f'nl alpha:{self.nl_alpha}, sat. level:{self.sat_level}, '
+            f += f' nl max. iter.:{self.nl_max_iter}'
         else:
             f += 'Fnl: None'
 
@@ -1022,7 +1030,8 @@ def plane_fitting(params):
     return res
 
 
-def ff_refine_crit(p, params, arr_dark, arr, tid, rois, mask, sat_level=511):
+def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
+    mask, sat_level=511):
     """Criteria for the ff_refine_fit.
 
     Inputs
@@ -1054,8 +1063,11 @@ def ff_refine_crit(p, params, arr_dark, arr, tid, rois, mask, sat_level=511):
     rd = xas(d, 40, Iokey='p', Itkey='n', nrjkey='0')
 
     err = np.nansum(rn['sigmaA']) + np.nansum(rp['sigmaA']) + np.nansum(rd['sigmaA'])
+    mean = ((1.0 - np.nanmean(rn['muA']))**2 +
+            (1.0 - np.nanmean(rp['muA']))**2 +
+            (1.0 - np.nanmean(rd['muA']))**2)
 
-    return 1e3*err
+    return 1e3*(err*alpha + (1-alpha)*mean)
 
 
 def ff_refine_fit(params):
@@ -1069,8 +1081,8 @@ def ff_refine_fit(params):
     -------
     res: scipy minimize result. res.x is the optimized parameters
 
-    firres: iteration index arrays of criteria results for
-        [criteria]
+    fitres: iteration index arrays of criteria results for
+        [alpha=0, alpha, alpha=1]
     """
     # load data
     assert params.arr is not None, "Data not loaded"
@@ -1083,8 +1095,8 @@ def ff_refine_fit(params):
 
     p0 = params.get_flat_field()
 
-    fixed_p = (params, params.arr_dark, params.arr, params.tid,
-        fitrois, params.get_mask(), params.sat_level)
+    fixed_p = (params.ff_alpha, params, params.arr_dark, params.arr,
+        params.tid, fitrois, params.get_mask(), params.sat_level)
 
     def fit_callback(x):
         if not hasattr(fit_callback, "counter"):
@@ -1098,15 +1110,19 @@ def ff_refine_fit(params):
 
         temp = list(fixed_p)
         Jalpha = ff_refine_crit(x, *temp)
-        fit_callback.res.append([Jalpha])
+        temp[0] = 0
+        J0 = ff_refine_crit(x, *temp)
+        temp[0] = 1
+        J1 = ff_refine_crit(x, *temp)
+        fit_callback.res.append([J0, Jalpha, J1])
         print(f'{fit_callback.counter-1}: {time_delta} '
-                f'({Jalpha}), {x}')
+                f'({J0}, {Jalpha}, {J1}), {x}')
 
         return False
 
     fit_callback(p0)
     res = minimize(ff_refine_crit, p0, fixed_p,
-        options={'disp': True, 'maxiter': params.max_iter},
+        options={'disp': True, 'maxiter': params.ff_max_iter},
         callback=fit_callback)
 
     return res, fit_callback.res
@@ -1216,7 +1232,7 @@ def nl_fit(params, domain):
     -------
     res: scipy minimize result. res.x is the optimized parameters
 
-    firres: iteration index arrays of criteria results for
+    fitres: iteration index arrays of criteria results for
         [alpha=0, alpha, alpha=1]
     """
     # load data
@@ -1235,7 +1251,7 @@ def nl_fit(params, domain):
     # flat flat_field
     ff = compute_flat_field_correction(params.rois, params.get_flat_field())
 
-    fixed_p = (domain, params.alpha, params.arr_dark, params.arr, params.tid,
+    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
         fitrois, params.get_mask(), ff, params.sat_level)
 
     def fit_callback(x):
@@ -1262,7 +1278,7 @@ def nl_fit(params, domain):
 
     fit_callback(p0)
     res = minimize(nl_crit, p0, fixed_p,
-        options={'disp': True, 'maxiter': params.max_iter},
+        options={'disp': True, 'maxiter': params.nl_max_iter},
         callback=fit_callback)
 
     return res, fit_callback.res
-- 
GitLab