diff --git a/src/toolbox_scs/base/knife_edge.py b/src/toolbox_scs/base/knife_edge.py
index 7e9fbca8e5dc30619746879af11d6ada56cd5052..45d354a5537fa9f69bb2ce186a1de6e2bb976530 100644
--- a/src/toolbox_scs/base/knife_edge.py
+++ b/src/toolbox_scs/base/knife_edge.py
@@ -107,13 +107,7 @@ def prepare_arrays(arrX: np.ndarray, arrY: np.ndarray,
     3. Retrieve finite values.
     """
     # Convert both arrays to 1D of the same size
-    assert arrX.shape[0] == arrY.shape[0]
-    arrX = arrX.flatten()
-    arrY = arrY.flatten()
-    if len(arrX) > len(arrY):
-        arrY = np.repeat(arrY, len(arrX) // len(arrY))
-    else:
-        arrX = np.repeat(arrX, len(arrY) // len(arrX))
+    arrX, arrY = arrays_to1d(arrX, arrY)
 
     # Select ranges
     if xRange is not None:
@@ -133,6 +127,18 @@ def prepare_arrays(arrX: np.ndarray, arrY: np.ndarray,
     return arrX, arrY
 
 
+def arrays_to1d(arrX: np.ndarray, arrY: np.ndarray):
+    """Flatten two arrays and matches their sizes
+    """
+    assert arrX.shape[0] == arrY.shape[0]
+    arrX, arrY = arrX.flatten(), arrY.flatten()
+    if len(arrX) > len(arrY):
+        arrY = np.repeat(arrY, len(arrX) // len(arrY))
+    else:
+        arrX = np.repeat(arrX, len(arrY) // len(arrX))
+    return arrX, arrY
+
+
 def range_mask(array, minimum=None, maximum=None):
     """Retrieve the resulting array from the given minimum and maximum
     """
diff --git a/src/toolbox_scs/routines/knife_edge.py b/src/toolbox_scs/routines/knife_edge.py
index fccc153f0ede2287051c06afd051677d8bf5d781..1b22dc8187ba4e5f676a18ee880390d2ba7ba2ef 100644
--- a/src/toolbox_scs/routines/knife_edge.py
+++ b/src/toolbox_scs/routines/knife_edge.py
@@ -1,30 +1,27 @@
 """ Toolbox for SCS.
 
     Various utilities function to quickly process data measured
-    at the SCS instruments.
+    at the SCS instrument.
 
     Copyright (2019-) SCS Team.
 """
 import matplotlib.pyplot as plt
 import numpy as np
-from scipy.special import erfc
-from scipy.optimize import curve_fit
-import bisect
+from toolbox_scs.base.knife_edge import knife_edge_base, erfc, arrays_to1d
 
 __all__ = [
     'knife_edge'
 ]
 
 
-def knife_edge(ds, axisKey='scannerX',
-               signalKey='FastADC4peaks',
-               axisRange=[None, None], p0=None,
-               full=False, plot=False):
+def knife_edge(ds, axisKey='scannerX', signalKey='FastADC4peaks',
+               axisRange=[None, None], p0=None, full=False, plot=False,
+               display=False):
     """
     Calculates the beam radius at 1/e^2 from a knife-edge scan by
-    fitting with erfc function: f(a,b,u) = a*erfc(u) + b or
-    f(a,b,u) = a*erfc(-u) + b where u = sqrt(2)*(x-x0)/w0 with w0
-    the beam radius at 1/e^2 and x0 the beam center.
+    fitting with erfc function:
+    f(x, x0, w0, a, b) = a*erfc(np.sqrt(2)*(x-x0)/w0) + b
+    with w0 the beam radius at 1/e^2 and x0 the beam center.
 
     Parameters
     ----------
@@ -38,103 +35,62 @@ def knife_edge(ds, axisKey='scannerX',
         edges of the scanning axis between which to apply the fit.
     p0: list of floats, numpy 1D array
         initial parameters used for the fit: x0, w0, a, b. If None, a beam
-        radius of 100 um is assumed.
+        radius of 100 micrometers is assumed.
     full: bool
         If False, returns the beam radius and standard error.
         If True, returns the popt, pcov list of parameters and covariance
-        matrix from scipy.optimize.curve_fit as well as the fitting function.
+        matrix from scipy.optimize.curve_fit.
     plot: bool
-        If True, plots the data and the result of the fit.
+        If True, plots the data and the result of the fit. Default is False.
+    display: bool
+        If True, displays info on the fit. True when plot is True, default is
+        False.
 
     Returns
     -------
-    If full is False, ndarray with beam radius at 1/e^2 in mm and standard
+    If full is False, tuple with beam radius at 1/e^2 in mm and standard
         error from the fit in mm. If full is True, returns parameters and
         covariance matrix from scipy.optimize.curve_fit function.
     """
-    def stepUp(x, x0, w0, a, b):
-        return a*erfc(-np.sqrt(2)*(x-x0)/w0) + b
-
-    def stepDown(x, x0, w0, a, b):
-        return a*erfc(np.sqrt(2)*(x-x0)/w0) + b
-
-    # get the number of pulses per train from the signal source:
-    dim = [k for k in ds[signalKey].dims if k != 'trainId'][0]
-    # duplicate motor position values to match signal shape
-    # this is faster than using ds.stack()
-    positions = np.repeat(ds[axisKey].values,
-                          len(ds[dim])).astype(ds[signalKey].dtype)
-    # sort the data to decide which fitting function to use
-    sortIdx = np.argsort(positions)
-    positions = positions[sortIdx]
-    intensities = ds[signalKey].values.flatten()[sortIdx]
-
-    if axisRange[0] is None or axisRange[0] < positions[0]:
-        idxMin = 0
-    else:
-        if axisRange[0] >= positions[-1]:
-            raise ValueError('The minimum value of axisRange is too large')
-        idxMin = bisect.bisect(positions, axisRange[0])
-    if axisRange[1] is None or axisRange[1] > positions[-1]:
-        idxMax = None
-    else:
-        if axisRange[1] <= positions[0]:
-            raise ValueError('The maximum value of axisRange is too small')
-        idxMax = bisect.bisect(positions, axisRange[1]) + 1
-    pos_sel = positions[idxMin:idxMax]
-    int_sel = intensities[idxMin:idxMax]
-    no_nan = ~np.isnan(int_sel)
-    pos_sel = pos_sel[no_nan]
-    int_sel = int_sel[no_nan]
+    popt, pcov = knife_edge_base(ds[axisKey].values, ds[signalKey].values,
+                                 axisRange=axisRange, p0=p0)
+    if plot:
+        positions, intensities = arrays_to1d(ds[axisKey].values,
+                                             ds[signalKey].values)
+        plot_knife_edge(positions, intensities, popt, pcov[1, 1]**0.5,
+                        ds.attrs['runFolder'], axisKey, signalKey)
+        display = True
 
-    # estimate a linear slope fitting the data to determine which function
-    # to fit
-    slope = np.cov(pos_sel, int_sel)[0][1]/np.var(pos_sel)
-    if slope < 0:
-        func = stepDown
+    if display:
         funcStr = 'a*erfc(np.sqrt(2)*(x-x0)/w0) + b'
-    else:
-        func = stepUp
-        funcStr = 'a*erfc(-np.sqrt(2)*(x-x0)/w0) + b'
-    if p0 is None:
-        p0 = [np.mean(pos_sel), 0.1, np.max(int_sel)/2, 0]
-    try:
-        popt, pcov = curve_fit(func, pos_sel, int_sel, p0=p0)
         print('fitting function:', funcStr)
-        print('w0 = (%.1f +/- %.1f) um' % (popt[1]*1e3, pcov[1, 1]**0.5*1e3))
+        print('w0 = (%.1f +/- %.1f) um' % (np.abs(popt[1])*1e3,
+                                           pcov[1, 1]**0.5*1e3))
         print('x0 = (%.3f +/- %.3f) mm' % (popt[0], pcov[0, 0]**0.5))
         print('a = %e +/- %e ' % (popt[2], pcov[2, 2]**0.5))
         print('b = %e +/- %e ' % (popt[3], pcov[3, 3]**0.5))
-        fitSuccess = True
-    except Exception as e:
-        print(f'Could not fit the data with erfc function: {e}.' +
-              ' Try adjusting the axisRange and the initial parameters p0')
-        fitSuccess = False
 
-    if plot:
-        plt.figure(figsize=(7, 4))
-        plt.scatter(positions, intensities, color='C1',
-                    label='exp', s=2, alpha=0.1)
-        if fitSuccess:
-            xfit = np.linspace(positions.min(), positions.max(), 1000)
-            yfit = func(xfit, *popt)
-            plt.plot(xfit, yfit, color='C4',
-                 label=r'fit $\rightarrow$ $w_0=$(%.1f $\pm$ %.1f) $\mu$m' % (
-                                            popt[1]*1e3, pcov[1, 1]**0.5*1e3))
-        leg = plt.legend()
-        for lh in leg.legendHandles:
-            lh.set_alpha(1)
-        plt.ylabel(signalKey)
-        plt.xlabel(axisKey + ' position [mm]')
-        plt.title(ds.attrs['runFolder'])
-        plt.tight_layout()
     if full:
-        if fitSuccess:
-            return popt, pcov, func
-        else:
-            return np.zeros(4), np.zeros(2), None
+        return popt, pcov
     else:
-        if fitSuccess:
-            return np.array([popt[1], pcov[1, 1]**0.5])
-        else:
-            return np.zeros(2)
+        return np.abs(popt[1]), pcov[1, 1]**0.5
+
+
+def plot_knife_edge(positions, intensities, fit_params, rel_err, title,
+                    axisKey, signalKey):
+
+    plt.figure(figsize=(7, 4))
+    plt.scatter(positions, intensities, color='C1',
+                label='measured', s=2, alpha=0.1)
+    xfit = np.linspace(positions.min(), positions.max(), 1000)
+    yfit = erfc(xfit, *fit_params)
+    plt.plot(xfit, yfit, color='C4',
+             label=r'fit $\rightarrow$ $w_0=$(%.1f $\pm$ %.1f) $\mu$m' % (
+                                    np.abs(fit_params[1])*1e3, rel_err*1e3))
+    leg = plt.legend()
+    for lh in leg.legendHandles:
+        lh.set_alpha(1)
+    plt.ylabel(signalKey)
+    plt.xlabel(axisKey + ' position [mm]')
+    plt.title(title)
+    plt.tight_layout()