From 59abb97a70d3d5ecfc64000bffa2ec40b0b9d639 Mon Sep 17 00:00:00 2001 From: Martin Teichmann <martin.teichmann@xfel.eu> Date: Thu, 22 Sep 2022 09:52:04 +0200 Subject: [PATCH] Cumulative updates from beamtime 2776 (van Kuiken) --- src/toolbox_scs/detectors/hrixs.py | 212 ++++++++++++++++------------- 1 file changed, 121 insertions(+), 91 deletions(-) diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py index 59d51e2..6282d85 100644 --- a/src/toolbox_scs/detectors/hrixs.py +++ b/src/toolbox_scs/detectors/hrixs.py @@ -1,6 +1,8 @@ from functools import lru_cache import numpy as np +import matplotlib.pyplot as plt +from scipy.optimize import leastsq from scipy.optimize import curve_fit from scipy.signal import fftconvolve @@ -56,6 +58,26 @@ def find_curvature(image, frangex=None, frangey=None, return curv[:-1][::-1] +def find_curvature(img, args, plot=False, **kwargs): + def parabola(x, a, b, c, s=0, h=0, o=0): + return (a*x + b)*x + c + def gauss(y, x, a, b, c, s, h, o=0): + return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o + x = np.arange(img.shape[1])[None, :] + y = np.arange(img.shape[0])[:, None] + + if plot: + plt.figure(figsize=(10,10)) + plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs) + plt.plot(x[0, :], parabola(x[0, :], *args)) + + args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), args) + + if plot: + plt.plot(x[0, :], parabola(x[0, :], *args)) + return args + + def correct_curvature(image, factor=None, axis=1): if factor is None: return @@ -175,7 +197,7 @@ def _esrf_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)): return res -def _new_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)): +def _new_centroid(image, threshold=THRESHOLD, std_threshold=3.5, curvature=(CURVE_A, CURVE_B)): """find the position of photons with sub-pixel precision A photon is supposed to have hit the detector if the intensity within a @@ -186,7 +208,8 @@ def _new_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)): """ base = image.mean() corners = image[1:, 1:] + image[:-1, 1:] + image[1:, :-1] + image[:-1, :-1] - threshold = corners.mean() + 3.5 * corners.std() + if threshold is None: + threshold = corners.mean() + std_threshold * corners.std() middle = corners[1:-1, 1:-1] candidates = ( (middle > threshold) @@ -257,11 +280,12 @@ class hRIXS: PROPOSAL = 2769 # image range - X_RANGE = np.s_[1300:-100] + X_RANGE = np.s_[:] Y_RANGE = np.s_[:] # centroid - THRESHOLD = THRESHOLD # pixel counts above which a hit candidate is assumed + THRESHOLD = None # pixel counts above which a hit candidate is assumed + STD_THRESHOLD = 3.5 # same as THRESHOLD, in standard deviations CURVE_A = CURVE_A # curvature parameters as determined elsewhere CURVE_B = CURVE_B @@ -271,101 +295,107 @@ class hRIXS: BINS = abs(np.subtract(*RANGE)) * FACTOR METHOD = 'centroid' # ['centroid', 'integral'] + USE_DARK = False - @classmethod - def set_params(cls, **params): - for key, value in params.items(): - setattr(cls, key.upper(), value) + ENERGY_INTERCEPT = 0 + ENERGY_SLOPE = 1 - def __set_params(self, **params): - self.__class__.set_params(**params) - self.refresh() + FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm'] + + def set_params(self, **params): + for key, value in params.items(): + setattr(self, key.upper(), value) - @classmethod - def get_params(cls, *params): + def get_params(self, *params): if not params: params = ('proposal', 'x_range', 'y_range', 'threshold', 'curve_a', 'curve_b', 'factor', 'range', 'bins', - 'method') - return {param: getattr(cls, param.upper()) for param in params} + 'method', 'fields') + return {param: getattr(self, param.upper()) for param in params} - def refresh(self): - cls = self.__class__ - for cached in ['_centroid', '_integral']: - getattr(cls, cached).fget.cache_clear() - - def __init__(self, images, norm=None): - self.images = images - self.norm = norm - - # class/instance method compatibility - self.set_params = self.__set_params - - @classmethod - def from_run(cls, runNB, proposal=None, first_wrong=False): + def from_run(self, runNB, proposal=None, extra_fields=()): if proposal is None: - proposal = cls.PROPOSAL - - run, data = tb.load(proposal, runNB=runNB, fields=['hRIXS_det']) - - # Get slow train data - mnemo = tb.mnemonics_for_run(run)['SCS_slowTrain'] - slow_train = run[mnemo['source'], mnemo['key']].ndarray().sum() - - return cls(images=data['hRIXS_det'][1 if first_wrong else 0:].data, - norm=slow_train) - - @property - @lru_cache() - def _centroid(self): - return sum((centroid(image[self.Y_RANGE, self.X_RANGE].T, - threshold=self.THRESHOLD, - curvature=(self.CURVE_A, self.CURVE_B), ) - for image in self.images), []) - - def _centroid_spectrum(self, bins=None, range=None, normalize=True): + proposal = self.PROPOSAL + run, data = tb.load(proposal, runNB=runNB, + fields=self.FIELDS + list(extra_fields)) + + return data + + def load_dark(self, runNB, proposal=None): + data = self.from_run(runNB, proposal) + self.dark_image = data['hRIXS_det'].mean(dim='trainId') + self.USE_DARK = True + + def find_curvature(self, runNB, proposal=None, plot=True, args=None, **kwargs): + data = self.from_run(runNB, proposal) + + image = data['hRIXS_det'].sum(dim='trainId') \ + .to_numpy()[self.X_RANGE, self.Y_RANGE].T + if args is None: + spec = (image - image[:10, :].mean()).mean(axis=1) + mean = np.average(np.arange(len(spec)), weights=spec) + args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3, + spec.max(), image.mean()) + args = find_curvature(image, args, plot=plot, **kwargs) + self.CURVE_B, self.CURVE_A, *_ = args + return self.CURVE_A, self.CURVE_B + + def centroid(self, data, bins=None): if bins is None: bins = self.BINS - if range is None: - range = self.RANGE - - r = np.array(self._centroid) - hy, hx = np.histogram(r[:, 0], bins=bins, range=range) - if normalize and self.norm is not None: - hy = hy / self.norm - - return (hx[:-1] + hx[1:]) / 2, hy - - @property - @lru_cache() - def _integral(self): - return sum((integrate(image[self.Y_RANGE, self.X_RANGE].T, - factor=self.FACTOR, - range=self.RANGE, - curvature=(self.CURVE_A, self.CURVE_B)) - for image in self.images)) - - def _integral_spectrum(self, normalize=True): - values = self._integral - if normalize and self.norm is not None: - values = values / self.norm - return np.arange(values.size), values - - @property - def corrected(self): - return decentroid(self._centroid) - - def spectrum(self, normalize=True): - spec_func = (self._centroid_spectrum if self.METHOD.lower() == 'centroid' - else self._integral_spectrum) - return spec_func(normalize=normalize) - - def __sub__(self, other): - px, py = self.spectrum() - mx, my = other.spectrum() - return (px + mx) / 2, py - my - - def __add__(self, other): - return self.__class__(images=list(self.images) + list(other.images), - norm=self.norm + other.norm) + ret = np.zeros((len(data["hRIXS_det"]), bins)) + for image, r in zip(data["hRIXS_det"], ret): + c = centroid( + image.to_numpy()[self.X_RANGE, self.Y_RANGE].T, + threshold=self.THRESHOLD, + std_threshold=self.STD_THRESHOLD, + curvature=(self.CURVE_A, self.CURVE_B)) + if not len(c): + continue + rc = np.array(c) + hy, hx = np.histogram( + rc[:, 0], bins=bins, + range=(0, self.Y_RANGE.stop - self.Y_RANGE.start)) + r[:] = hy + + data = data.assign_coords( + energy=np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, bins) + * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT) + return data.assign(spectrum=(("trainId", "energy"), ret)) + + def integrate(self, data): + bins = self.Y_RANGE.stop - self.Y_RANGE.start + ret = np.zeros((len(data["hRIXS_det"]), bins - 20)) + for image, r in zip(data["hRIXS_det"], ret): + if self.USE_DARK: + image = image - self.dark_image + r[:] = integrate(image.to_numpy()[self.X_RANGE, self.Y_RANGE].T, factor=1, + range=(10, bins - 10), + curvature=(self.CURVE_A, self.CURVE_B)) + data = data.assign_coords( + energy=np.arange(self.Y_RANGE.start + 10, self.Y_RANGE.stop - 10) + * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT) + return data.assign(spectrum=(("trainId", "energy"), ret)) + + aggregators = dict( + hRIXS_det=lambda x, dim: x.sum(dim=dim), + Delay=lambda x, dim: x.mean(dim=dim), + hRIXS_delay=lambda x, dim: x.mean(dim=dim), + hRIXS_norm=lambda x, dim: x.sum(dim=dim), + spectrum=lambda x, dim: x.sum(dim=dim), + ) + + def aggregator(self, da, dim): + agg = self.aggregators.get(da.name) + if agg is None: + return None + return agg(da, dim=dim) + + def aggregate(self, ds, dim="trainId"): + ret = ds.map(self.aggregator, dim=dim) + ret = ret.drop_vars([n for n in ret if n not in self.aggregators]) + return ret + + def normalize(self, data): + return data.assign(normalized=data["spectrum"] / data["hRIXS_norm"]) -- GitLab