diff --git a/src/toolbox_scs/detectors/jf_hrixs.py b/src/toolbox_scs/detectors/jf_hrixs.py
index d4f5099d16145c65120cd7614639f0f57a1d602e..f0cabd510c5bd86cd1fdccaf198aae7ac3609746 100644
--- a/src/toolbox_scs/detectors/jf_hrixs.py
+++ b/src/toolbox_scs/detectors/jf_hrixs.py
@@ -128,7 +128,7 @@ class JF_hRIXS:
             data = data.isel(trainId=slice(1, None))
         return data
 
-    def find_curvature(self, img, args, plot=False, **kwargs):
+    def find_curvature(self, img, args=None, plot=False, **kwargs):
         """Find the curvature correction coefficients.
 
         The hRIXS has some abberations which leads to the spectroscopic lines
@@ -155,23 +155,35 @@ class JF_hRIXS:
         def parabola(x, a, b, c, s=0, h=0, o=0):
             return (a*x + b)*x + c
 
-        def gauss(y, x, a, b, c, s, h, o=0):
+        def gauss(x, y, a, b, c, s, h, o=0):
             return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o
 
-        x = np.arange(img.shape[1])[None, :]
-        y = np.arange(img.shape[0])[:, None]
+        xx, yy = np.meshgrid(img['x'], img['y'])
 
-        if plot:
-            plt.figure(figsize=(10, 10))
-            plt.imshow(img, cmap='gray', aspect='auto',
-                       interpolation='nearest', **kwargs)
-            plt.plot(x[0, :], parabola(x[0, :], *args))
+        if args is None:
+            spec = img.mean('x').values
+            x0 = spec.argmax()
+            b = 0.25
+            args = (0.0, b, (x0 - b*img.sizes['x']/2),
+                    3, spec.max(), 0)
+            print(args)
 
-        args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(),
+        if plot:
+            plt.figure()
+            plt.imshow(img, cmap='magma_r', aspect=1/9,
+                       interpolation='none', **kwargs)
+            plt.plot(xx[0, :], parabola(xx[0, :], *args),
+                     ls='--', c='C0', alpha=0.5)
+
+        args, _ = leastsq(lambda args: (gauss(xx, yy, *args)
+                                        - img).values.flatten(),
                           args)
 
         if plot:
-            plt.plot(x[0, :], parabola(x[0, :], *args))
+            plt.plot(xx[0, :], parabola(xx[0, :], *args), ls='-', c='C0')
+
+        self.CURVE_B, self.CURVE_A, *_ = args
+
         return args
 
     def parabola(self, x):