From c0259d3243830ccbe992f817b6c64107b0c34f45 Mon Sep 17 00:00:00 2001
From: Johannes Niskanen <niskanen@max-exfl099.desy.de>
Date: Tue, 31 May 2022 10:28:22 +0200
Subject: [PATCH] Modified and commented dark_load Centroid function modified
 to always output spectra

---
 src/toolbox_scs/detectors/hrixs.py | 55 ++++++++++++++++++++++++------
 1 file changed, 44 insertions(+), 11 deletions(-)

diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py
index 28d3d88..d8f94ad 100644
--- a/src/toolbox_scs/detectors/hrixs.py
+++ b/src/toolbox_scs/detectors/hrixs.py
@@ -328,30 +328,56 @@ class hRIXS:
 
         return data
 
+    
     def load_dark(self, runNB, proposal=None, use_dark=True, mask=True,
                   mask_threshold=None):
+        #*************************************************************
+        # Loads a dark image and assigns it to the hRIXS instance
+        # In addition sets attributes whether or not 
+        # - hot pixels are identified and masked out
+        # - the dark image is to be used in background subtraction
+        # In addition a threshold value for hot pixel mask generation
+        # can be given.
+        #*************************************************************
         if mask_threshold == None:
             mask_threshold = self.DARK_MASK_THRESHOLD
-        try:
-            # Checks to see if runNB can be iterated over (is list-like)
+        #*************************************************************
+        # If given a list of runs, iterate over them. 
+        # Otherwise read one. Give an exception if neither is the case.
+        #*************************************************************
+        if type(runNB) == type([]):
             data_list = []
             for run in runNB:
                 data_list.append(self.from_run(run, proposal))
-            data = xr.concat(data_list, dim='trainId') 
-        except TypeError:
-            # If runNB cannot be iterated over, we assume it's a single run
+            data = xr.concat(data_list, dim='trainId')
+        elif type(runNB) == type(1):
             data = self.from_run(runNB, proposal)
+        else:
+            raise Exception('load_dark() expects a list of indeces or an integer.')      
+        #*************************************************************
+        # Store the dark image (mean over aqs.) in two formats
+        #*************************************************************            
         self.dark_image = data['hRIXS_det'].mean(dim='trainId')
         self.dark_im_array = self.dark_image.to_numpy()
+        #*************************************************************
+        # Set a flag whether the dark image is to be used later
+        #*************************************************************  
         if use_dark:
             self.USE_DARK = True
+        #*************************************************************
+        # If hot/dead pixel masking is requested, find the mask and
+        # set a flag in the instance. Set the masked dark values to 
+        # mean intensity.
+        #*************************************************************  
         if mask:
             dark_avg = np.mean(self.dark_im_array[self.MASK_AVG_Y,
                                                               self.MASK_AVG_X], (0, 1))
-            self.dark_mask = np.abs(self.dark_im_array - dark_avg) > mask_threshold
+            self.dark_mask = self.dark_im_array > dark_avg + mask_threshold
             self.dark_im_array_m = np.array(self.dark_im_array)
             self.dark_im_array_m[self.dark_mask] = dark_avg
             self.USE_DARK_MASK = True
+        return
+    
 
     def find_curvature(self, runNB, proposal=None, plot=True, args=None, **kwargs):
         data = self.from_run(runNB, proposal)
@@ -367,6 +393,7 @@ class hRIXS:
         return self.CURVE_A, self.CURVE_B
 
     def centroid(self, data, bins=None, return_hits=False):
+        print('jee')
         if bins is None:
             bins = self.BINS
         hit_x = []
@@ -406,12 +433,18 @@ class hRIXS:
         data = data.assign_coords(
             energy=np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, bins)
             * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
+        #**********************************************
+        # If hits were requested, assign them to data
+        #**********************************************
         if return_hits:
-            return data.assign(hits=(("trainId"), hits),
-                               xhits=(("trainId"), hit_x),
-                               yhits=(("trainId"), hit_y))
-        else:
-            return data.assign(spectrum=(("trainId", "energy"), ret))
+            data = data.assign(hits=(("trainId"), hits),
+                        xhits=(("trainId"), hit_x),
+                        yhits=(("trainId"), hit_y))
+        #**********************************************
+        # Always assign the spectrum to data
+        #**********************************************
+        data = data.assign(spectrum=(("trainId", "energy"), ret))
+        return data
 
     def integrate(self, data):
         bins = self.Y_RANGE.stop - self.Y_RANGE.start
-- 
GitLab