From 60c5e70c209735f683fcb8efec5897637fc0f6b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Le=20Guyader?= <loic.le.guyader@xfel.eu> Date: Fri, 17 Nov 2023 15:43:10 +0100 Subject: [PATCH] cleanup and minimize code --- src/toolbox_scs/detectors/jf_hrixs.py | 320 +++----------------------- 1 file changed, 29 insertions(+), 291 deletions(-) diff --git a/src/toolbox_scs/detectors/jf_hrixs.py b/src/toolbox_scs/detectors/jf_hrixs.py index 62d701d..d4f5099 100644 --- a/src/toolbox_scs/detectors/jf_hrixs.py +++ b/src/toolbox_scs/detectors/jf_hrixs.py @@ -1,11 +1,8 @@ -from functools import lru_cache import xarray as xr import numpy as np import matplotlib.pyplot as plt from scipy.optimize import leastsq -from scipy.optimize import curve_fit -from scipy.signal import fftconvolve import toolbox_scs as tb @@ -14,111 +11,13 @@ __all__ = [ 'JF_hRIXS', ] - -# ----------------------------------------------------------------------------- -# Curvature - -def correct_curvature(image, factor=None, axis=1): - if factor is None: - return - - if axis == 1: - image = image.T - - ydim, xdim = image.shape - x = np.arange(xdim + 1) - y = np.arange(ydim + 1) - xx, yy = np.meshgrid(x[:-1] + 0.5, y[:-1] + 0.5) - xxn = xx - factor[0] * yy - factor[1] * yy ** 2 - ret = np.histogramdd((xxn.flatten(), yy.flatten()), - bins=[x, y], - weights=image.flatten())[0] - - return ret if axis == 1 else ret.T - - -def get_spectrum(image, factor=None, axis=0, - pixel_range=None, energy_range=None, ): - start, stop = (0, image.shape[axis - 1]) - if pixel_range is not None: - start = max(pixel_range[0] or start, start) - stop = min(pixel_range[1] or stop, stop) - - edge = image.sum(axis=axis)[start:stop] - bins = np.arange(start, stop + 1) - centers = (bins[1:] + bins[:-1]) * 0.5 - if factor is not None: - centers, edge = calibrate(centers, edge, - factor=factor, - range_=energy_range) - - return centers, edge - - -# ----------------------------------------------------------------------------- -# Energy calibration - - -def energy_calibration(channels, energies): - return np.polyfit(channels, energies, deg=1) - - -def calibrate(x, y=None, factor=None, range_=None): - if factor is not None: - x = np.polyval(factor, x) - - if y is not None and range_ is not None: - start = np.argmin(np.abs((x - range_[0]))) - stop = np.argmin(np.abs((x - range_[1]))) - # Calibrated energies have a different direction - x, y = x[stop:start], y[stop:start] - - return x, y - - -# ----------------------------------------------------------------------------- -# Gaussian-related functions - - FWHM_COEFF = 2 * np.sqrt(2 * np.log(2)) -def gaussian_fit(x_data, y_data, offset=0): - """ - Centre-of-mass and width. Lifted from image_processing.imageCentreofMass() - """ - - x0 = np.average(x_data, weights=y_data) - sx = np.sqrt(np.average((x_data - x0) ** 2, weights=y_data)) - - # Gaussian fit - baseline = y_data.min() - p_0 = (y_data.max(), x0 + offset, sx, baseline) - try: - p_f, _ = curve_fit(gauss1d, x_data, y_data, p_0, maxfev=10000) - return p_f - except (RuntimeError, TypeError) as e: - print(e) - return None - - -def gauss1d(x, height, x0, sigma, offset): - return height * np.exp(-0.5 * ((x - x0) / sigma) ** 2) + offset - - def to_fwhm(sigma): return abs(sigma * FWHM_COEFF) -def decentroid(res): - res = np.array(res) - ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int)) - for cy, cx in res: - if cx > 0 and cy > 0: - ret[int(cy), int(cx)] += 1 - return ret - - class JF_hRIXS: """The JUNGFRAU hRIXS analysis, especially curvature correction @@ -144,7 +43,7 @@ class JF_hRIXS: STD_THRESHOLD: same as THRESHOLD, in standard deviations. DBL_THRESHOLD: - threshold controling whether a detected hit is considered to be a + threshold controling whether a detected hit is considered to be a double hit. BINS: int the number of bins used in centroiding @@ -161,7 +60,7 @@ class JF_hRIXS: Example ------- - + proposal = 3145 h = hRIXS(proposal) h.Y_RANGE = slice(700, 900) @@ -172,7 +71,7 @@ class JF_hRIXS: """ def __init__(self, proposalNB): - self.PROPOSAL=proposalNB + self.PROPOSAL = proposalNB # image range self.X_RANGE = np.s_[:] @@ -187,7 +86,8 @@ class JF_hRIXS: self.ENERGY_INTERCEPT = 0 self.ENERGY_SLOPE = 1 - self.FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm', 'nrj'] + self.FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', + 'hRIXS_norm', 'nrj'] def set_params(self, **params): for key, value in params.items(): @@ -200,14 +100,16 @@ class JF_hRIXS: 'bins', 'fields') return {param: getattr(self, param.upper()) for param in params} - def from_run(self, runNB, proposal=None, extra_fields=(), drop_first=False): - """load a run + def from_run(self, runNB, proposal=None, extra_fields=(), + drop_first=False): + """Load a run. Load the run `runNB`. A thin wrapper around `toolbox.load`. Parameters ---------- drop_first: bool - if True, the first image in the run is removed from the dataset. + if True, the first image in the run is removed from the + dataset. Example ------- @@ -220,55 +122,14 @@ class JF_hRIXS: """ if proposal is None: proposal = self.PROPOSAL - run, data = tb.load(proposal, runNB=runNB, - fields=self.FIELDS + list(extra_fields)) + _, data = tb.load(proposal, runNB=runNB, + fields=self.FIELDS + list(extra_fields)) if drop_first is True: data = data.isel(trainId=slice(1, None)) return data - def find_curvature(self, runNB, proposal=None, plot=True, args=None, - **kwargs): - """find the curvature correction coefficients - - The hRIXS has some abberations which leads to the spectroscopic lines - being curved on the detector. We approximate these abberations with - a parabola for later correction. - - Load a run and determine the curvature. The curvature is set in `self`, - and returned as a pair of floats. - - Parameters - ---------- - - runNB: int - the run number to use - proposal: int - the proposal to use, default to the current proposal - plot: bool - whether to plot the found curvature onto the data - args: pair of float, optional - a starting value to prime the fitting routine - - Example - ------- - - h.find_curvature(155) # use run 155 to fit the curvature - """ - data = self.from_run(runNB, proposal) - - image = data['hRIXS_det'].sum(dim='trainId') \ - .values[self.X_RANGE, self.Y_RANGE].T - if args is None: - spec = (image - image[:10, :].mean()).mean(axis=1) - mean = np.average(np.arange(len(spec)), weights=spec) - args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3, - spec.max(), image.mean()) - args = _find_curvature(image, args, plot=plot, **kwargs) - self.CURVE_B, self.CURVE_A, *_ = args - return self.CURVE_A, self.CURVE_B - - def find_curvature(img, args, plot=False, **kwargs): - """find the curvature correction coefficients + def find_curvature(self, img, args, plot=False, **kwargs): + """Find the curvature correction coefficients. The hRIXS has some abberations which leads to the spectroscopic lines being curved on the detector. We approximate these abberations with @@ -279,7 +140,6 @@ class JF_hRIXS: Parameters ---------- - img: array 2D average image args: (a, b, c, s, h, o) initial coefficients @@ -287,51 +147,57 @@ class JF_hRIXS: h the height and o an offset plot: bool whether to plot the found curvature onto the data + Example ------- - h.find_curvature(155) # use run 155 to fit the curvature """ - def parabola(x, a, b, c, s=0, h=0, o=0): return (a*x + b)*x + c + def gauss(y, x, a, b, c, s, h, o=0): return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o + x = np.arange(img.shape[1])[None, :] y = np.arange(img.shape[0])[:, None] if plot: - plt.figure(figsize=(10,10)) - plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs) + plt.figure(figsize=(10, 10)) + plt.imshow(img, cmap='gray', aspect='auto', + interpolation='nearest', **kwargs) plt.plot(x[0, :], parabola(x[0, :], *args)) - args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), args) + args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), + args) if plot: plt.plot(x[0, :], parabola(x[0, :], *args)) return args + def parabola(self, x): + return (self.CURVE_B * x + self.CURVE_A) * x + def spectrum(self, fname): """Bin photon hit data into spectrum. Parameters ---------- - fname: string file name of the data to load. """ - data_interp = xr.load_dataset(fname) + def hist_curv(x, y): H, _ = np.histogram( x - self.parabola(y), bins=self.BINS, range=(0, self.Y_RANGE.stop - self.Y_RANGE.start)) - + return H energy = (np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, - self.BINS) * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT) + self.BINS) * self.ENERGY_SLOPE + + self.ENERGY_INTERCEPT) spectrum = xr.apply_ufunc(hist_curv, data_interp['y'], @@ -347,131 +213,3 @@ class JF_hRIXS: spectrum['energy'] = energy return spectrum - - def parabola(self, x): - return (self.CURVE_B * x + self.CURVE_A) * x - - def integrate(self, data): - """calculate a spectrum by integration - - This takes the `xarray` `data` and returns a copy of it, with a new - dataarray named `spectrum` added, which contains the energy spectrum - calculated for each hRIXS image. - - First the energy that corresponds to each pixel is calculated. - Then all pixels within an energy range are summed, where the intensity - of one pixel is distributed among the two energy ranges the pixel - spans, proportionally to the overlap between the pixel and bin energy - ranges. - - The resulting data is normalized to one pixel, so the average - intensity that arrived on one pixel. - - Example - ------- - - h.integrate(data) # create spectrum by summing pixels - data.spectrum[0, :].plot() # plot the spectrum of the first image - """ - bins = self.Y_RANGE.stop - self.Y_RANGE.start - margin = 10 - ret = np.zeros((len(data["hRIXS_det"]), bins - 2 * margin)) - if self.USE_DARK: - dark_image = self.dark_image.values[self.X_RANGE, self.Y_RANGE] - images = data["hRIXS_det"].values[:, self.X_RANGE, self.Y_RANGE] - - x, y = np.ogrid[:images.shape[1], :images.shape[2]] - quo, rem = divmod(y - self.parabola(x), 1) - quo = np.array([quo, quo + 1]) - rem = np.array([rem, 1 - rem]) - wrong = (quo < margin) | (quo >= bins - margin) - quo[wrong] = margin - rem[wrong] = 0 - quo = (quo - margin).astype(int).ravel() - - for image, r in zip(images, ret): - if self.USE_DARK: - image = image - dark_image - r[:] = np.bincount(quo, weights=(rem * image).ravel()) - ret /= np.bincount(quo, weights=rem.ravel()) - data.coords["energy"] = ( - np.arange(self.Y_RANGE.start + margin, self.Y_RANGE.stop - margin) - * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT) - data['spectrum'] = (("trainId", "energy"), ret) - return data - - aggregators = dict( - hRIXS_det=lambda x, dim: x.sum(dim=dim), - Delay=lambda x, dim: x.mean(dim=dim), - hRIXS_delay=lambda x, dim: x.mean(dim=dim), - hRIXS_norm=lambda x, dim: x.sum(dim=dim), - spectrum=lambda x, dim: x.sum(dim=dim), - dbl_spectrum=lambda x, dim: x.sum(dim=dim), - total_hits=lambda x, dim: x.sum(dim=dim), - dbl_hits=lambda x, dim: x.sum(dim=dim), - counts=lambda x, dim: x.sum(dim=dim) - ) - - def aggregator(self, da, dim): - agg = self.aggregators.get(da.name) - if agg is None: - return None - return agg(da, dim=dim) - - def aggregate(self, ds, var=None, dim="trainId"): - """aggregate (i.e. mostly sum) all data within one dataset - - take all images in a dataset and aggregate them and their metadata. - For images, spectra and normalizations that means adding them, for - others (e.g. delays) adding would not make sense, so we treat them - properly. The aggregation functions of each variable are defined - in the aggregators attribute of the class. - If var is specified, group the dataset by var prior to aggregation. - A new variable "counts" gives the number of frames aggregated in - each group. - - Parameters - ---------- - ds: xarray Dataset - the dataset containing RIXS data - var: string - One of the variables in the dataset. If var is specified, the - dataset is grouped by var prior to aggregation. This is useful - for sorting e.g. a dataset that contains multiple delays. - dim: string - the dimension over which to aggregate the data - - Example - ------- - - h.centroid(data) # create spectra from finding photons - agg = h.aggregate(data) # sum all spectra - agg.spectrum.plot() # plot the resulting spectrum - - agg2 = h.aggregate(data, 'hRIXS_delay') # group data by delay - agg2.spectrum[0, :].plot() # plot the spectrum for first value - """ - ds["counts"] = xr.ones_like(ds[dim]) - if var is not None: - groups = ds.groupby(var) - return groups.map(self.aggregate_ds, dim=dim) - return self.aggregate_ds(ds, dim) - - def aggregate_ds(self, ds, dim='trainId'): - ret = ds.map(self.aggregator, dim=dim) - ret = ret.drop_vars([n for n in ret if n not in self.aggregators]) - return ret - - def normalize(self, data, which="hRIXS_norm"): - """ Adds a 'normalized' variable to the dataset defined as the - ration between 'spectrum' and 'which' - - Parameters - ---------- - data: xarray Dataset - the dataset containing hRIXS data - which: string, default="hRIXS_norm" - one of the variables of the dataset, usually "hRIXS_norm" - or "counts" - """ - return data.assign(normalized=data["spectrum"] / data[which]) -- GitLab