diff --git a/src/toolbox_scs/detectors/jf_hrixs.py b/src/toolbox_scs/detectors/jf_hrixs.py index d4f5099d16145c65120cd7614639f0f57a1d602e..f0cabd510c5bd86cd1fdccaf198aae7ac3609746 100644 --- a/src/toolbox_scs/detectors/jf_hrixs.py +++ b/src/toolbox_scs/detectors/jf_hrixs.py @@ -128,7 +128,7 @@ class JF_hRIXS: data = data.isel(trainId=slice(1, None)) return data - def find_curvature(self, img, args, plot=False, **kwargs): + def find_curvature(self, img, args=None, plot=False, **kwargs): """Find the curvature correction coefficients. The hRIXS has some abberations which leads to the spectroscopic lines @@ -155,23 +155,35 @@ class JF_hRIXS: def parabola(x, a, b, c, s=0, h=0, o=0): return (a*x + b)*x + c - def gauss(y, x, a, b, c, s, h, o=0): + def gauss(x, y, a, b, c, s, h, o=0): return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o - x = np.arange(img.shape[1])[None, :] - y = np.arange(img.shape[0])[:, None] + xx, yy = np.meshgrid(img['x'], img['y']) - if plot: - plt.figure(figsize=(10, 10)) - plt.imshow(img, cmap='gray', aspect='auto', - interpolation='nearest', **kwargs) - plt.plot(x[0, :], parabola(x[0, :], *args)) + if args is None: + spec = img.mean('x').values + x0 = spec.argmax() + b = 0.25 + args = (0.0, b, (x0 - b*img.sizes['x']/2), + 3, spec.max(), 0) + print(args) - args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), + if plot: + plt.figure() + plt.imshow(img, cmap='magma_r', aspect=1/9, + interpolation='none', **kwargs) + plt.plot(xx[0, :], parabola(xx[0, :], *args), + ls='--', c='C0', alpha=0.5) + + args, _ = leastsq(lambda args: (gauss(xx, yy, *args) + - img).values.flatten(), args) if plot: - plt.plot(x[0, :], parabola(x[0, :], *args)) + plt.plot(xx[0, :], parabola(xx[0, :], *args), ls='-', c='C0') + + self.CURVE_B, self.CURVE_A, *_ = args + return args def parabola(self, x):