From 0157f0afb2b6174031ba17761f9c5a735b7b1b6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Thu, 16 May 2024 13:03:51 +0200
Subject: [PATCH] Improved formatting

---
 src/toolbox_scs/routines/boz.py | 46 ++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 18 deletions(-)

diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 613b2dc..348d993 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -1,7 +1,7 @@
 """
 Beam splitting Off-axis Zone plate analysis routines.
 
-Copyright (2021) SCS Team.
+Copyright (2021, 2022, 2023, 2024) SCS Team.
 """
 
 import time
@@ -126,12 +126,14 @@ class parameters():
         self.arr = None
         self.tid = None
 
-    def dask_load_persistently(self, dark_data_size_Gb=None, data_size_Gb=None):
+    def dask_load_persistently(self, dark_data_size_Gb=None,
+                               data_size_Gb=None):
         """Load dask data array in memory.
 
         Inputs
         ------
-        dark_data_size_Gb: float, optional size of dark to load in memory, in Gb
+        dark_data_size_Gb: float, optional size of dark to load in memory,
+                           in Gb
         data_size_Gb: float, optional size of data to load in memory, in Gb
         """
         self.arr_dark, self.tid_dark = load_dssc_module(self.proposal,
@@ -361,7 +363,9 @@ class parameters():
         f += f'rois: {self.rois}\n'
 
         f += f'flat-field type: {self.ff_type}\n'
-        f += f'flat-field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n'
+        f += f'flat-field p: {self.flat_field} '
+        f += f'prod:{self.flat_field_prod_th} '
+        f += f'ratio:{self.flat_field_ratio_th}\n'
         f += f'plane guess fit: {self.plane_guess_fit}\n'
         f += f'use hexagons: {self.use_hex}\n'
         f += f'enforce mirror symmetry: {self.force_mirror}\n'
@@ -683,7 +687,8 @@ def find_rois(data_mean, threshold, extended=False):
 
     # along X
     lowX = int(np.argmax(pX > threshold) - 1)  # 1st occurrence returned
-    highX = int(pX.shape[0] - np.argmax(pX[::-1] > threshold))  # last occ. returned
+    highX = int(pX.shape[0] -
+                np.argmax(pX[::-1] > threshold))  # last occ. returned
 
     midX = int(0.5*(lowX+highX))
 
@@ -695,7 +700,8 @@ def find_rois(data_mean, threshold, extended=False):
 
     # along Y
     lowY = int(np.argmax(pY > threshold) - 1)  # 1st occurrence returned
-    highY = int(pY.shape[0] - np.argmax(pY[::-1] > threshold))  # last occ. returned
+    highY = int(pY.shape[0]
+                - np.argmax(pY[::-1] > threshold))  # last occ. returned
 
     # define rois
     rois = {}
@@ -972,7 +978,8 @@ def compute_polyline_flat_field_correction(rois, params, plot=False):
 
     return flat_field
 
-def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None):
+def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None,
+                              vmax=None):
     """Extract beams roi from average image and compute the ratio.
 
     Inputs
@@ -1029,7 +1036,8 @@ def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None
         if ratio_vmin is None:
             ratio_vmin = np.percentile(v, 5)
             ratio_vmax = np.percentile(v, 99.8)
-        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r')
+        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax,
+                               cmap='RdBu_r')
         axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2))
 
     cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
@@ -1545,7 +1553,7 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
 
 def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
     sat_level=511, use_gpu=False):
-    """Criteria for the non linear correction, combining 'n' and 'p' as reference
+    """Non linear correction criteria, combining 'n' and 'p' as reference.
 
     Inputs
     ------
@@ -1619,8 +1627,9 @@ def nl_fit(params, domain, ff=None, crit=None):
     if crit is None:
         crit = nl_crit
 
-    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
-        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr,
+               params.tid, fitrois, params.get_mask(), ff, params.sat_level,
+               params._using_gpu)
 
     def fit_callback(x):
         if not hasattr(fit_callback, "counter"):
@@ -1679,7 +1688,7 @@ def inspect_nl_fit(res_fit):
 
 
 def snr(sig, ref, methods=None, verbose=False):
-    """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref.
+    """ Compute mean, std and SNR from transmitted and I0 signals.
 
     Inputs
     ------
@@ -1844,7 +1853,7 @@ def inspect_correction(params, gain=None):
                 [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Blues',
                 norm=LogNorm(vmin=0.2, vmax=200),
-                # alpha=0.5 # make  the plot looks ugly with lots of white lines
+                # alpha=0.5 # make the plot looks ugly with lots of white lines
                 )
             h, xedges, yedges, img2 = axs[l, k].hist2d(
                 g*scale*sat_d[r].values.flatten(),
@@ -1852,7 +1861,7 @@ def inspect_correction(params, gain=None):
                 [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                 cmap='Reds',
                 norm=LogNorm(vmin=0.2, vmax=200),
-                # alpha=0.5 # make  the plot looks ugly with lots of white lines
+                # alpha=0.5 # make the plot looks ugly with lots of white lines
                 )
 
             v = snr_v['direct']['mu']/snr_v['direct']['s']
@@ -1862,8 +1871,8 @@ def inspect_correction(params, gain=None):
             axs[l, k].text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
                             transform = axs[l, k].transAxes)
 
-            # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
-            # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+            #axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
+            #axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
 
             axs[l, k].set_ylim([0.95*m, 1.05*m])
 
@@ -2313,8 +2322,9 @@ def inspect_saturation(data, gain, Nbins=200):
         # compute density normalization on all data
         norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat']))
     
-        ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm, h[k+'_nosat']/norm,
-                        facecolor=f"C{kk}", edgecolor='none', alpha=0.2)
+        ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm,
+                        h[k+'_nosat']/norm, facecolor=f"C{kk}",
+                        edgecolor='none', alpha=0.2)
 
         ax.plot(bins_c, h[k+'_nosat']/norm, label=k,
                 c=f'C{kk}', alpha=0.4)
-- 
GitLab