diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py index 40f24f821affcbc764b256730b953be8ef4a7da4..686db07e464070cfaac46ec697c5d941e4abfc1c 100644 --- a/src/toolbox_scs/detectors/hrixs.py +++ b/src/toolbox_scs/detectors/hrixs.py @@ -1,7 +1,11 @@ +from functools import lru_cache + import numpy as np from scipy.optimize import curve_fit from scipy.signal import fftconvolve +import toolbox_scs as tb + __all__ = [ 'find_curvature', @@ -142,3 +146,197 @@ def gauss1d(x, height, x0, sigma, offset): def to_fwhm(sigma): return abs(sigma * FWHM_COEFF) + + +# ----------------------------------------------------------------------------- +# Centroid + +THRESHOLD = 510 # pixel counts above which a hit candidate is assumed +CURVE_A = 2.19042931e-02 # curvature parameters as determined elsewhere +CURVE_B = -3.02191568e-07 + + +def centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)): + gs = 2 + base = image.mean() + cp = np.argwhere(image[gs // 2: -gs // 2, gs // 2: -gs // 2] > threshold) + np.array([gs // 2, gs // 2]) + if len(cp) > 100000: + raise RuntimeError('Threshold too low or acquisition time too long') + + res = [] + for cy, cx in cp: + spot = image[cy - gs // 2: cy + gs // 2 + 1, cx - gs // 2: cx + gs // 2 + 1] - base + spot[spot < 0] = 0 + if (spot > image[cy, cx]).sum() == 0: + mx = np.average(np.arange(cx - gs // 2, cx + gs // 2 + 1), weights=spot.sum(axis=0)) + my = np.average(np.arange(cy - gs // 2, cy + gs // 2 + 1), weights=spot.sum(axis=1)) + my -= (curvature[0] + curvature[1] * mx) * mx + res.append((my, mx)) + return res + + +def decentroid(res): + res = np.array(res) + ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int)) + for cy, cx in res: + if cx > 0 and cy > 0: + ret[int(cy), int(cx)] += 1 + return ret + + +# ----------------------------------------------------------------------------- +# Integral + +FACTOR = 3 +RANGE = [300, 400] +BINS = abs(np.subtract(*RANGE)) * FACTOR + + +def parabola(x, a, b, c=0): + return (a * x + b) * x + c + + +def integrate(image, factor=FACTOR, range=RANGE, curvature=(CURVE_A, CURVE_B), ): + image = image - image.mean() + x = np.arange(image.shape[1])[None, :] + y = np.arange(image.shape[0])[:, None] + ys = factor * (y - parabola(x, curvature[1], curvature[0])) + ysf = np.floor(ys) + rang = (factor * range[0], factor * range[1]) + bins = rang[1] - rang[0] + lhy, lhx = np.histogram(ysf.ravel(), bins=bins, weights=((ys - ysf) * image).ravel(), range=rang) + rhy, rhx = np.histogram((ysf + 1).ravel(), bins=bins, weights=(((ysf + 1) - ys) * image).ravel(), range=rang) + lvy, lvx = np.histogram(ysf.ravel(), bins=bins, weights=(ys - ysf).ravel(), range=rang) + rvy, rvx = np.histogram((ysf + 1).ravel(), bins=bins, weights=((ysf + 1) - ys).ravel(), range=rang) + return (lhy + rhy) / (lvy + rvy) + + +# ----------------------------------------------------------------------------- +# hRIXS class + + +class hRIXS: + # run + PROPOSAL = 2769 + + # image range + X_RANGE = np.s_[1300:-100] + Y_RANGE = np.s_[:] + + # centroid + THRESHOLD = THRESHOLD # pixel counts above which a hit candidate is assumed + CURVE_A = CURVE_A # curvature parameters as determined elsewhere + CURVE_B = CURVE_B + + # integral + FACTOR = FACTOR + RANGE = RANGE + BINS = abs(np.subtract(*RANGE)) * FACTOR + + METHOD = 'centroid' # ['centroid', 'integral'] + + @classmethod + def set_params(cls, **params): + for key, value in params.items(): + setattr(cls, key.upper(), value) + + def __set_params(self, **params): + self.__class__.set_params(**params) + self.refresh() + + @classmethod + def get_params(cls, *params): + if not params: + params = ('proposal', 'x_range', 'y_range', + 'threshold', 'curve_a', 'curve_b', + 'factor', 'range', 'bins', + 'method') + return {param: getattr(cls, param.upper()) for param in params} + + def refresh(self): + cls = self.__class__ + for cached in ['_centroid', '_integral']: + getattr(cls, cached).fget.cache_clear() + + def __init__(self, images, norm=None): + self.images = images + self.norm = norm + + # class/instance method compatibility + self.set_params = self.__set_params + + @classmethod + def from_run(cls, runNB, proposal=None, first_wrong=False): + if proposal is None: + proposal = cls.PROPOSAL + + run, data = tb.load(proposal, + runNB=runNB, + fields=['hRIXS_det', 'SCS_slowTrain']) + + # Get slow train data + mnemo = tb.mnemonics['SCS_slowTrain'][0] + slow_train = run[mnemo['source'], mnemo['key']].ndarray().sum() + + return cls(images=data['hRIXS_det'][1 if first_wrong else 0:].data, + norm=slow_train) + + @property + @lru_cache() + def _centroid(self): + return sum((centroid(image[self.Y_RANGE, self.X_RANGE].T, + threshold=self.THRESHOLD, + curvature=(self.CURVE_A, self.CURVE_B), ) + for image in self.images), []) + + def _centroid_spectrum(self, bins=None, range=None, normalize=True): + if bins is None: + bins = self.BINS + if range is None: + range = self.RANGE + + r = np.array(self._centroid) + hy, hx = np.histogram(r[:, 0], bins=bins, range=range) + if normalize and self.norm is not None: + hy = hy / self.norm + + return (hx[:-1] + hx[1:]) / 2, hy + + @property + @lru_cache() + def _integral(self): + return sum((integrate(image[self.Y_RANGE, self.X_RANGE].T, + factor=self.FACTOR, + range=self.RANGE, + curvature=(self.CURVE_A, self.CURVE_B)) + for image in self.images)) + + def _integral_spectrum(self, normalize=True): + values = self._integral + if normalize and self.norm is not None: + values = values / self.norm + return np.arange(values.size), values + + @property + def corrected(self): + return decentroid(self._centroid) + + def spectrum(self, normalize=True): + spec_func = (self._centroid_spectrum if self.METHOD.lower() == 'centroid' + else self._integral_spectrum) + return spec_func(normalize=normalize) + + def __sub__(self, other): + px, py = self.spectrum() + mx, my = other.spectrum() + return (px + mx) / 2, py - my + + def __add__(self, other): + ix, iy = self.spectrum(normalize=False) + jx, jy = other.spectrum(normalize=False) + + i_n = self.norm or 0 + j_n = other.norm or 0 + norm = ((i_n + j_n) or 1) + + return ix, (iy + jy) / norm