From 59abb97a70d3d5ecfc64000bffa2ec40b0b9d639 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Thu, 22 Sep 2022 09:52:04 +0200
Subject: [PATCH] Cumulative updates from beamtime 2776 (van Kuiken)

---
 src/toolbox_scs/detectors/hrixs.py | 212 ++++++++++++++++-------------
 1 file changed, 121 insertions(+), 91 deletions(-)

diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py
index 59d51e2..6282d85 100644
--- a/src/toolbox_scs/detectors/hrixs.py
+++ b/src/toolbox_scs/detectors/hrixs.py
@@ -1,6 +1,8 @@
 from functools import lru_cache
 
 import numpy as np
+import matplotlib.pyplot as plt
+from scipy.optimize import leastsq
 from scipy.optimize import curve_fit
 from scipy.signal import fftconvolve
 
@@ -56,6 +58,26 @@ def find_curvature(image, frangex=None, frangey=None,
     return curv[:-1][::-1]
 
 
+def find_curvature(img, args, plot=False, **kwargs):
+    def parabola(x, a, b, c, s=0, h=0, o=0):
+        return (a*x + b)*x + c
+    def gauss(y, x, a, b, c, s, h, o=0):
+        return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o
+    x = np.arange(img.shape[1])[None, :]
+    y = np.arange(img.shape[0])[:, None]
+
+    if plot:
+        plt.figure(figsize=(10,10))
+        plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs)
+        plt.plot(x[0, :], parabola(x[0, :], *args))
+
+    args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), args)
+
+    if plot:
+        plt.plot(x[0, :], parabola(x[0, :], *args))
+    return args
+
+
 def correct_curvature(image, factor=None, axis=1):
     if factor is None:
         return
@@ -175,7 +197,7 @@ def _esrf_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)):
     return res
 
 
-def _new_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)):
+def _new_centroid(image, threshold=THRESHOLD, std_threshold=3.5, curvature=(CURVE_A, CURVE_B)):
     """find the position of photons with sub-pixel precision
 
     A photon is supposed to have hit the detector if the intensity within a
@@ -186,7 +208,8 @@ def _new_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)):
     """
     base = image.mean()
     corners = image[1:, 1:] + image[:-1, 1:] + image[1:, :-1] + image[:-1, :-1]
-    threshold = corners.mean() + 3.5 * corners.std()
+    if threshold is None:
+        threshold = corners.mean() + std_threshold * corners.std()
     middle = corners[1:-1, 1:-1]
     candidates = (
             (middle > threshold)
@@ -257,11 +280,12 @@ class hRIXS:
     PROPOSAL = 2769
 
     # image range
-    X_RANGE = np.s_[1300:-100]
+    X_RANGE = np.s_[:]
     Y_RANGE = np.s_[:]
 
     # centroid
-    THRESHOLD = THRESHOLD  # pixel counts above which a hit candidate is assumed
+    THRESHOLD = None  # pixel counts above which a hit candidate is assumed
+    STD_THRESHOLD = 3.5  # same as THRESHOLD, in standard deviations
     CURVE_A = CURVE_A  # curvature parameters as determined elsewhere
     CURVE_B = CURVE_B
 
@@ -271,101 +295,107 @@ class hRIXS:
     BINS = abs(np.subtract(*RANGE)) * FACTOR
 
     METHOD = 'centroid'  # ['centroid', 'integral']
+    USE_DARK = False
 
-    @classmethod
-    def set_params(cls, **params):
-        for key, value in params.items():
-            setattr(cls, key.upper(), value)
+    ENERGY_INTERCEPT = 0
+    ENERGY_SLOPE = 1
 
-    def __set_params(self, **params):
-        self.__class__.set_params(**params)
-        self.refresh()
+    FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm']
+
+    def set_params(self, **params):
+        for key, value in params.items():
+            setattr(self, key.upper(), value)
 
-    @classmethod
-    def get_params(cls, *params):
+    def get_params(self, *params):
         if not params:
             params = ('proposal', 'x_range', 'y_range',
                       'threshold', 'curve_a', 'curve_b',
                       'factor', 'range', 'bins',
-                      'method')
-        return {param: getattr(cls, param.upper()) for param in params}
+                      'method', 'fields')
+        return {param: getattr(self, param.upper()) for param in params}
 
-    def refresh(self):
-        cls = self.__class__
-        for cached in ['_centroid', '_integral']:
-            getattr(cls, cached).fget.cache_clear()
-
-    def __init__(self, images, norm=None):
-        self.images = images
-        self.norm = norm
-
-        # class/instance method compatibility
-        self.set_params = self.__set_params
-
-    @classmethod
-    def from_run(cls, runNB, proposal=None, first_wrong=False):
+    def from_run(self, runNB, proposal=None, extra_fields=()):
         if proposal is None:
-            proposal = cls.PROPOSAL
-
-        run, data = tb.load(proposal, runNB=runNB, fields=['hRIXS_det'])
-
-        # Get slow train data
-        mnemo = tb.mnemonics_for_run(run)['SCS_slowTrain']
-        slow_train = run[mnemo['source'], mnemo['key']].ndarray().sum()
-
-        return cls(images=data['hRIXS_det'][1 if first_wrong else 0:].data,
-                   norm=slow_train)
-
-    @property
-    @lru_cache()
-    def _centroid(self):
-        return sum((centroid(image[self.Y_RANGE, self.X_RANGE].T,
-                             threshold=self.THRESHOLD,
-                             curvature=(self.CURVE_A, self.CURVE_B), )
-                    for image in self.images), [])
-
-    def _centroid_spectrum(self, bins=None, range=None, normalize=True):
+            proposal = self.PROPOSAL
+        run, data = tb.load(proposal, runNB=runNB,
+                            fields=self.FIELDS + list(extra_fields))
+
+        return data
+
+    def load_dark(self, runNB, proposal=None):
+        data = self.from_run(runNB, proposal)
+        self.dark_image = data['hRIXS_det'].mean(dim='trainId')
+        self.USE_DARK = True
+
+    def find_curvature(self, runNB, proposal=None, plot=True, args=None, **kwargs):
+        data = self.from_run(runNB, proposal)
+
+        image = data['hRIXS_det'].sum(dim='trainId') \
+                .to_numpy()[self.X_RANGE, self.Y_RANGE].T
+        if args is None:
+            spec = (image - image[:10, :].mean()).mean(axis=1)
+            mean = np.average(np.arange(len(spec)), weights=spec)
+            args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3,
+                    spec.max(), image.mean())
+        args = find_curvature(image, args, plot=plot, **kwargs)
+        self.CURVE_B, self.CURVE_A, *_ = args
+        return self.CURVE_A, self.CURVE_B
+
+    def centroid(self, data, bins=None):
         if bins is None:
             bins = self.BINS
-        if range is None:
-            range = self.RANGE
-
-        r = np.array(self._centroid)
-        hy, hx = np.histogram(r[:, 0], bins=bins, range=range)
-        if normalize and self.norm is not None:
-            hy = hy / self.norm
-
-        return (hx[:-1] + hx[1:]) / 2, hy
-
-    @property
-    @lru_cache()
-    def _integral(self):
-        return sum((integrate(image[self.Y_RANGE, self.X_RANGE].T,
-                              factor=self.FACTOR,
-                              range=self.RANGE,
-                              curvature=(self.CURVE_A, self.CURVE_B))
-                    for image in self.images))
-
-    def _integral_spectrum(self, normalize=True):
-        values = self._integral
-        if normalize and self.norm is not None:
-            values = values / self.norm
-        return np.arange(values.size), values
-
-    @property
-    def corrected(self):
-        return decentroid(self._centroid)
-
-    def spectrum(self, normalize=True):
-        spec_func = (self._centroid_spectrum if self.METHOD.lower() == 'centroid'
-                     else self._integral_spectrum)
-        return spec_func(normalize=normalize)
-
-    def __sub__(self, other):
-        px, py = self.spectrum()
-        mx, my = other.spectrum()
-        return (px + mx) / 2, py - my
-
-    def __add__(self, other):
-        return self.__class__(images=list(self.images) + list(other.images),
-                              norm=self.norm + other.norm)
+        ret = np.zeros((len(data["hRIXS_det"]), bins))
+        for image, r in zip(data["hRIXS_det"], ret):
+            c = centroid(
+                image.to_numpy()[self.X_RANGE, self.Y_RANGE].T,
+                threshold=self.THRESHOLD,
+                std_threshold=self.STD_THRESHOLD,
+                curvature=(self.CURVE_A, self.CURVE_B))
+            if not len(c):
+                continue
+            rc = np.array(c)
+            hy, hx = np.histogram(
+                rc[:, 0], bins=bins,
+                range=(0, self.Y_RANGE.stop - self.Y_RANGE.start))
+            r[:] = hy
+
+        data = data.assign_coords(
+            energy=np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, bins)
+            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
+        return data.assign(spectrum=(("trainId", "energy"), ret))
+
+    def integrate(self, data):
+        bins = self.Y_RANGE.stop - self.Y_RANGE.start
+        ret = np.zeros((len(data["hRIXS_det"]), bins - 20))
+        for image, r in zip(data["hRIXS_det"], ret):
+            if self.USE_DARK:
+                image = image - self.dark_image
+            r[:] = integrate(image.to_numpy()[self.X_RANGE, self.Y_RANGE].T, factor=1,
+                             range=(10, bins - 10),
+                             curvature=(self.CURVE_A, self.CURVE_B))
+        data = data.assign_coords(
+            energy=np.arange(self.Y_RANGE.start + 10, self.Y_RANGE.stop - 10)
+            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
+        return data.assign(spectrum=(("trainId", "energy"), ret))
+
+    aggregators = dict(
+        hRIXS_det=lambda x, dim: x.sum(dim=dim),
+        Delay=lambda x, dim: x.mean(dim=dim),
+        hRIXS_delay=lambda x, dim: x.mean(dim=dim),
+        hRIXS_norm=lambda x, dim: x.sum(dim=dim),
+        spectrum=lambda x, dim: x.sum(dim=dim),
+    )
+
+    def aggregator(self, da, dim):
+        agg = self.aggregators.get(da.name)
+        if agg is None:
+            return None
+        return agg(da, dim=dim)
+
+    def aggregate(self, ds, dim="trainId"):
+        ret = ds.map(self.aggregator, dim=dim)
+        ret = ret.drop_vars([n for n in ret if n not in self.aggregators])
+        return ret
+
+    def normalize(self, data):
+        return data.assign(normalized=data["spectrum"] / data["hRIXS_norm"])
-- 
GitLab