From 60c5e70c209735f683fcb8efec5897637fc0f6b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu>
Date: Fri, 17 Nov 2023 15:43:10 +0100
Subject: [PATCH] cleanup and minimize code

---
 src/toolbox_scs/detectors/jf_hrixs.py | 320 +++-----------------------
 1 file changed, 29 insertions(+), 291 deletions(-)
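
Usage sketch (illustrative only): with the run-based find_curvature()
wrapper gone, the caller averages an image and primes the fit directly.
The run number, file name and starting coefficients below are
placeholders:

    from toolbox_scs.detectors.jf_hrixs import JF_hRIXS

    h = JF_hRIXS(proposalNB=3145)
    h.Y_RANGE = slice(700, 900)

    # average the detector images of a reference run (placeholder number)
    data = h.from_run(156)
    img = (data['hRIXS_det'].sum(dim='trainId')
           .values[h.X_RANGE, h.Y_RANGE].T)

    # (a, b, c, s, height, offset) starting values, roughly as the
    # removed wrapper primed them; tune to the data at hand
    start = (-2e-7, 0.02, img.shape[0] / 2, 3, img.max(), img.mean())
    h.CURVE_B, h.CURVE_A, *_ = h.find_curvature(img, start)

    # bin previously centroided hits (placeholder file name) into a
    # spectrum with an 'energy' coordinate
    spec = h.spectrum('run156_hits.nc')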

diff --git a/src/toolbox_scs/detectors/jf_hrixs.py b/src/toolbox_scs/detectors/jf_hrixs.py
index 62d701d..d4f5099 100644
--- a/src/toolbox_scs/detectors/jf_hrixs.py
+++ b/src/toolbox_scs/detectors/jf_hrixs.py
@@ -1,11 +1,8 @@
-from functools import lru_cache
 import xarray as xr
 
 import numpy as np
 import matplotlib.pyplot as plt
 from scipy.optimize import leastsq
-from scipy.optimize import curve_fit
-from scipy.signal import fftconvolve
 
 import toolbox_scs as tb
 
@@ -14,111 +11,13 @@ __all__ = [
     'JF_hRIXS',
 ]
 
-
-# -----------------------------------------------------------------------------
-# Curvature
-
-def correct_curvature(image, factor=None, axis=1):
-    if factor is None:
-        return
-
-    if axis == 1:
-        image = image.T
-
-    ydim, xdim = image.shape
-    x = np.arange(xdim + 1)
-    y = np.arange(ydim + 1)
-    xx, yy = np.meshgrid(x[:-1] + 0.5, y[:-1] + 0.5)
-    xxn = xx - factor[0] * yy - factor[1] * yy ** 2
-    ret = np.histogramdd((xxn.flatten(), yy.flatten()),
-                         bins=[x, y],
-                         weights=image.flatten())[0]
-
-    return ret if axis == 1 else ret.T
-
-
-def get_spectrum(image, factor=None, axis=0,
-                 pixel_range=None, energy_range=None, ):
-    start, stop = (0, image.shape[axis - 1])
-    if pixel_range is not None:
-        start = max(pixel_range[0] or start, start)
-        stop = min(pixel_range[1] or stop, stop)
-
-    edge = image.sum(axis=axis)[start:stop]
-    bins = np.arange(start, stop + 1)
-    centers = (bins[1:] + bins[:-1]) * 0.5
-    if factor is not None:
-        centers, edge = calibrate(centers, edge,
-                                  factor=factor,
-                                  range_=energy_range)
-
-    return centers, edge
-
-
-# -----------------------------------------------------------------------------
-# Energy calibration
-
-
-def energy_calibration(channels, energies):
-    return np.polyfit(channels, energies, deg=1)
-
-
-def calibrate(x, y=None, factor=None, range_=None):
-    if factor is not None:
-        x = np.polyval(factor, x)
-
-    if y is not None and range_ is not None:
-        start = np.argmin(np.abs((x - range_[0])))
-        stop = np.argmin(np.abs((x - range_[1])))
-        # Calibrated energies have a different direction
-        x, y = x[stop:start], y[stop:start]
-
-    return x, y
-
-
-# -----------------------------------------------------------------------------
-# Gaussian-related functions
-
-
 FWHM_COEFF = 2 * np.sqrt(2 * np.log(2))
 
 
-def gaussian_fit(x_data, y_data, offset=0):
-    """
-    Centre-of-mass and width. Lifted from image_processing.imageCentreofMass()
-    """
-
-    x0 = np.average(x_data, weights=y_data)
-    sx = np.sqrt(np.average((x_data - x0) ** 2, weights=y_data))
-
-    # Gaussian fit
-    baseline = y_data.min()
-    p_0 = (y_data.max(), x0 + offset, sx, baseline)
-    try:
-        p_f, _ = curve_fit(gauss1d, x_data, y_data, p_0, maxfev=10000)
-        return p_f
-    except (RuntimeError, TypeError) as e:
-        print(e)
-        return None
-
-
-def gauss1d(x, height, x0, sigma, offset):
-    return height * np.exp(-0.5 * ((x - x0) / sigma) ** 2) + offset
-
-
 def to_fwhm(sigma):
     return abs(sigma * FWHM_COEFF)
 
 
-def decentroid(res):
-    res = np.array(res)
-    ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int))
-    for cy, cx in res:
-        if cx > 0 and cy > 0:
-            ret[int(cy), int(cx)] += 1
-    return ret
-
-
 class JF_hRIXS:
     """The JUNGFRAU hRIXS analysis, especially curvature correction
 
@@ -144,7 +43,7 @@ class JF_hRIXS:
     STD_THRESHOLD:
         same as THRESHOLD, in standard deviations.
     DBL_THRESHOLD:
-        threshold controling whether a detected hit is considered to be a 
+        threshold controlling whether a detected hit is considered to be a
         double hit.
     BINS: int
         the number of bins used in centroiding
@@ -161,7 +60,7 @@ class JF_hRIXS:
 
     Example
     -------
-        
+
         proposal = 3145
         h = JF_hRIXS(proposal)
         h.Y_RANGE = slice(700, 900)
@@ -172,7 +71,7 @@ class JF_hRIXS:
     """
     def __init__(self, proposalNB):
 
-        self.PROPOSAL=proposalNB
+        self.PROPOSAL = proposalNB
 
         # image range
         self.X_RANGE = np.s_[:]
@@ -187,7 +86,8 @@ class JF_hRIXS:
         self.ENERGY_INTERCEPT = 0
         self.ENERGY_SLOPE = 1
 
-        self.FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm', 'nrj']
+        self.FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay',
+                       'hRIXS_norm', 'nrj']
 
     def set_params(self, **params):
         for key, value in params.items():
@@ -200,14 +100,16 @@ class JF_hRIXS:
                       'bins', 'fields')
         return {param: getattr(self, param.upper()) for param in params}
 
-    def from_run(self, runNB, proposal=None, extra_fields=(), drop_first=False):
-        """load a run
+    def from_run(self, runNB, proposal=None, extra_fields=(),
+                 drop_first=False):
+        """Load a run.
 
         Load the run `runNB`. A thin wrapper around `toolbox.load`.
         Parameters
         ----------
             drop_first: bool
-                if True, the first image in the run is removed from the dataset.
+                if True, the first image in the run is removed from the
+                dataset.
 
         Example
         -------
@@ -220,55 +122,14 @@ class JF_hRIXS:
         """
         if proposal is None:
             proposal = self.PROPOSAL
-        run, data = tb.load(proposal, runNB=runNB,
-                            fields=self.FIELDS + list(extra_fields))
+        _, data = tb.load(proposal, runNB=runNB,
+                          fields=self.FIELDS + list(extra_fields))
         if drop_first is True:
             data = data.isel(trainId=slice(1, None))
         return data
 
-    def find_curvature(self, runNB, proposal=None, plot=True, args=None,
-                       **kwargs):
-        """find the curvature correction coefficients
-
-        The hRIXS has some abberations which leads to the spectroscopic lines
-        being curved on the detector. We approximate these abberations with
-        a parabola for later correction.
-
-        Load a run and determine the curvature. The curvature is set in `self`,
-        and returned as a pair of floats.
-
-        Parameters
-        ----------
-
-        runNB: int
-            the run number to use
-        proposal: int
-            the proposal to use, default to the current proposal
-        plot: bool
-            whether to plot the found curvature onto the data
-        args: pair of float, optional
-            a starting value to prime the fitting routine
-
-        Example
-        -------
-
-            h.find_curvature(155)  # use run 155 to fit the curvature
-        """
-        data = self.from_run(runNB, proposal)
-
-        image = data['hRIXS_det'].sum(dim='trainId') \
-                .values[self.X_RANGE, self.Y_RANGE].T
-        if args is None:
-            spec = (image - image[:10, :].mean()).mean(axis=1)
-            mean = np.average(np.arange(len(spec)), weights=spec)
-            args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3,
-                    spec.max(), image.mean())
-        args = _find_curvature(image, args, plot=plot, **kwargs)
-        self.CURVE_B, self.CURVE_A, *_ = args
-        return self.CURVE_A, self.CURVE_B
-
-    def find_curvature(img, args, plot=False, **kwargs):
-        """find the curvature correction coefficients
+    def find_curvature(self, img, args, plot=False, **kwargs):
+        """Find the curvature correction coefficients.
 
         The hRIXS has some aberrations which lead to the spectroscopic lines
         being curved on the detector. We approximate these aberrations with
@@ -279,7 +140,6 @@ class JF_hRIXS:
 
         Parameters
         ----------
-
         img: array
             2D average image
         args: (a, b, c, s, h, o) initial coefficients
@@ -287,51 +147,57 @@ class JF_hRIXS:
             h the height and o an offset
         plot: bool
             whether to plot the found curvature onto the data
+
         Example
         -------
-
             h.find_curvature(img, args)  # fit curvature on an averaged image
         """
-        
         def parabola(x, a, b, c, s=0, h=0, o=0):
             return (a*x + b)*x + c
+
         def gauss(y, x, a, b, c, s, h, o=0):
             return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o
+
         x = np.arange(img.shape[1])[None, :]
         y = np.arange(img.shape[0])[:, None]
 
         if plot:
-            plt.figure(figsize=(10,10))
-            plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs)
+            plt.figure(figsize=(10, 10))
+            plt.imshow(img, cmap='gray', aspect='auto',
+                       interpolation='nearest', **kwargs)
             plt.plot(x[0, :], parabola(x[0, :], *args))
 
-        args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), args)
+        args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(),
+                          args)
 
         if plot:
             plt.plot(x[0, :], parabola(x[0, :], *args))
         return args
 
+    def parabola(self, x):
+        return (self.CURVE_B * x + self.CURVE_A) * x
+
     def spectrum(self, fname):
         """Bin photon hit data into spectrum.
 
         Parameters
         ----------
-
         fname: string
             file name of the data to load.
         """
-
         data_interp = xr.load_dataset(fname)
+
         def hist_curv(x, y):
             H, _ = np.histogram(
                 x - self.parabola(y), bins=self.BINS,
                 range=(0, self.Y_RANGE.stop - self.Y_RANGE.start))
-    
+
             return H
 
         energy = (np.linspace(self.Y_RANGE.start,
                               self.Y_RANGE.stop,
-                              self.BINS) * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
+                              self.BINS) * self.ENERGY_SLOPE
+                  + self.ENERGY_INTERCEPT)
 
         spectrum = xr.apply_ufunc(hist_curv,
                                   data_interp['y'],
@@ -347,131 +213,3 @@ class JF_hRIXS:
         spectrum['energy'] = energy
 
         return spectrum
-
-    def parabola(self, x):
-        return (self.CURVE_B * x + self.CURVE_A) * x
-
-    def integrate(self, data):
-        """calculate a spectrum by integration
-
-        This takes the `xarray` `data` and returns a copy of it, with a new
-        dataarray named `spectrum` added, which contains the energy spectrum
-        calculated for each hRIXS image.
-
-        First the energy that corresponds to each pixel is calculated.
-        Then all pixels within an energy range are summed, where the intensity
-        of one pixel is distributed among the two energy ranges the pixel
-        spans, proportionally to the overlap between the pixel and bin energy
-        ranges.
-
-        The resulting data is normalized to one pixel, so the average
-        intensity that arrived on one pixel.
-
-        Example
-        -------
-
-            h.integrate(data)  # create spectrum by summing pixels
-            data.spectrum[0, :].plot()  # plot the spectrum of the first image
-        """
-        bins = self.Y_RANGE.stop - self.Y_RANGE.start
-        margin = 10
-        ret = np.zeros((len(data["hRIXS_det"]), bins - 2 * margin))
-        if self.USE_DARK:
-            dark_image = self.dark_image.values[self.X_RANGE, self.Y_RANGE]
-        images = data["hRIXS_det"].values[:, self.X_RANGE, self.Y_RANGE]
-
-        x, y = np.ogrid[:images.shape[1], :images.shape[2]]
-        quo, rem = divmod(y - self.parabola(x), 1)
-        quo = np.array([quo, quo + 1])
-        rem = np.array([rem, 1 - rem])
-        wrong = (quo < margin) | (quo >= bins - margin)
-        quo[wrong] = margin
-        rem[wrong] = 0
-        quo = (quo - margin).astype(int).ravel()
-
-        for image, r in zip(images, ret):
-            if self.USE_DARK:
-                image = image - dark_image
-            r[:] = np.bincount(quo, weights=(rem * image).ravel())
-        ret /= np.bincount(quo, weights=rem.ravel())
-        data.coords["energy"] = (
-            np.arange(self.Y_RANGE.start + margin, self.Y_RANGE.stop - margin)
-            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
-        data['spectrum'] = (("trainId", "energy"), ret)
-        return data
-
-    aggregators = dict(
-        hRIXS_det=lambda x, dim: x.sum(dim=dim),
-        Delay=lambda x, dim: x.mean(dim=dim),
-        hRIXS_delay=lambda x, dim: x.mean(dim=dim),
-        hRIXS_norm=lambda x, dim: x.sum(dim=dim),
-        spectrum=lambda x, dim: x.sum(dim=dim),
-        dbl_spectrum=lambda x, dim: x.sum(dim=dim),
-        total_hits=lambda x, dim: x.sum(dim=dim),
-        dbl_hits=lambda x, dim: x.sum(dim=dim),
-        counts=lambda x, dim: x.sum(dim=dim)
-    )
-
-    def aggregator(self, da, dim):
-        agg = self.aggregators.get(da.name)
-        if agg is None:
-            return None
-        return agg(da, dim=dim)
-
-    def aggregate(self, ds, var=None, dim="trainId"):
-        """aggregate (i.e. mostly sum) all data within one dataset
-
-        take all images in a dataset and aggregate them and their metadata.
-        For images, spectra and normalizations that means adding them, for
-        others (e.g. delays) adding would not make sense, so we treat them
-        properly. The aggregation functions of each variable are defined
-        in the aggregators attribute of the class.
-        If var is specified, group the dataset by var prior to aggregation.
-        A new variable "counts" gives the number of frames aggregated in
-        each group.
-
-        Parameters
-        ----------
-        ds: xarray Dataset
-            the dataset containing RIXS data
-        var: string
-            One of the variables in the dataset. If var is specified, the
-            dataset is grouped by var prior to aggregation. This is useful
-            for sorting e.g. a dataset that contains multiple delays.
-        dim: string
-            the dimension over which to aggregate the data
-
-        Example
-        -------
-
-            h.centroid(data)  # create spectra from finding photons
-            agg = h.aggregate(data)  # sum all spectra
-            agg.spectrum.plot()  # plot the resulting spectrum
-
-            agg2 = h.aggregate(data, 'hRIXS_delay') # group data by delay
-            agg2.spectrum[0, :].plot()  # plot the spectrum for first value
-        """
-        ds["counts"] = xr.ones_like(ds[dim])
-        if var is not None:
-            groups = ds.groupby(var)
-            return groups.map(self.aggregate_ds, dim=dim)
-        return self.aggregate_ds(ds, dim)
-
-    def aggregate_ds(self, ds, dim='trainId'):
-        ret = ds.map(self.aggregator, dim=dim)
-        ret = ret.drop_vars([n for n in ret if n not in self.aggregators])
-        return ret
-
-    def normalize(self, data, which="hRIXS_norm"):
-        """ Adds a 'normalized' variable to the dataset defined as the
-        ration between 'spectrum' and 'which'
-
-        Parameters
-        ----------
-            data: xarray Dataset
-                the dataset containing hRIXS data
-            which: string, default="hRIXS_norm"
-                one of the variables of the dataset, usually "hRIXS_norm"
-                or "counts"
-        """
-        return data.assign(normalized=data["spectrum"] / data[which])
-- 
GitLab