diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py index 8c7c12d643c1f418d3664ae005ff37a5b30c60fa..88aaeeca58c5cfaac81d07cdb85d37047f618774 100644 --- a/src/toolbox_scs/detectors/hrixs.py +++ b/src/toolbox_scs/detectors/hrixs.py @@ -333,11 +333,11 @@ class hRIXS: # If runNB cannot be iterated over, we assume it's a single run data = self.from_run(runNB, proposal) self.dark_image = data['hRIXS_det'].mean(dim='trainId') + self.dark_im_array = self.dark_image.to_numpy() self.USE_DARK = True def find_curvature(self, runNB, proposal=None, plot=True, args=None, **kwargs): data = self.from_run(runNB, proposal) - image = data['hRIXS_det'].sum(dim='trainId') \ .to_numpy()[self.X_RANGE, self.Y_RANGE].T if args is None: @@ -356,8 +356,12 @@ class hRIXS: hit_y = [] ret = np.zeros((len(data["hRIXS_det"]), bins)) for image, r in zip(data["hRIXS_det"], ret): + if self.USE_DARK: + use_image = image.to_numpy() - self.dark_im_array + else: + use_image = image.to_numpy() c = centroid( - image.to_numpy()[self.X_RANGE, self.Y_RANGE].T, + use_image[self.X_RANGE, self.Y_RANGE].T, threshold=self.THRESHOLD, std_threshold=self.STD_THRESHOLD, curvature=(self.CURVE_A, self.CURVE_B))