""" Toolbox for SCS.

    Various utility functions to quickly process data measured
    at the SCS instrument.

    Copyright (2019-) SCS Team.
"""
import matplotlib.pyplot as plt
import numpy as np
import re
from toolbox_scs.base.knife_edge import knife_edge_base, erfc, arrays_to1d

# Names exported by ``from <module> import *``; plot_knife_edge stays internal.
__all__ = ['knife_edge']


def _run_title(attrs):
    """
    Build a 'run <runNB> p<proposal>' plot title from dataset attributes.

    Returns an empty string when 'proposal' or 'runNB' is missing or falsy.
    When the proposal attribute contains no 'p' followed by six digits, the
    title falls back to 'run <runNB>' instead of raising an error.
    """
    proposal = attrs.get('proposal')
    runNB = attrs.get('runNB')
    if not (proposal and runNB):
        return ''
    match = re.search(r'p(\d{6})', str(proposal))
    if match is None:
        # Unexpected proposal format: still label the plot with the run.
        return f'run {runNB}'
    # int() drops the leading zeros of the 6-digit proposal number.
    return f'run {runNB} p{int(match.group(1))}'


def knife_edge(ds, axisKey='scannerX', signalKey='FastADC4peaks',
               axisRange=None, p0=None, full=False, plot=False,
               display=False):
    """
    Calculates the beam radius at 1/e^2 from a knife-edge scan by
    fitting with erfc function:
    f(x, x0, w0, a, b) = a*erfc(np.sqrt(2)*(x-x0)/w0) + b
    with w0 the beam radius at 1/e^2 and x0 the beam center.

    Parameters
    ----------
    ds: xarray Dataset
        dataset containing the detector signal and the motor position.
    axisKey: str
        key of the axis against which the knife-edge is performed.
    signalKey: str
        key of the detector signal.
    axisRange: list of floats
        edges of the scanning axis between which to apply the fit.
    p0: list of floats, numpy 1D array
        initial parameters used for the fit: x0, w0, a, b. If None, a beam
        radius of 100 micrometers is assumed.
    full: bool
        If False, returns the beam radius and standard error.
        If True, returns the popt, pcov list of parameters and covariance
        matrix from scipy.optimize.curve_fit.
    plot: bool
        If True, plots the data and the result of the fit. Default is False.
        The plot title is built from the 'proposal' and 'runNB' attributes
        of ds when both are present.
    display: bool
        If True, displays info on the fit. True when plot is True, default is
        False.

    Returns
    -------
    If full is False, tuple with beam radius at 1/e^2 in mm and standard
        error from the fit in mm. If full is True, returns parameters and
        covariance matrix from scipy.optimize.curve_fit function.
    """
    popt, pcov = knife_edge_base(ds[axisKey].values, ds[signalKey].values,
                                 axisRange=axisRange, p0=p0)
    # Standard error of w0 (index 1 of x0, w0, a, b) from the covariance.
    w0_err = pcov[1, 1]**0.5

    if plot:
        positions, intensities = arrays_to1d(ds[axisKey].values,
                                             ds[signalKey].values)
        plot_knife_edge(positions, intensities, popt, w0_err,
                        _run_title(ds.attrs), axisKey, signalKey)
        # Plotting always comes with the printed fit summary.
        display = True

    if display:
        funcStr = 'a*erfc(np.sqrt(2)*(x-x0)/w0) + b'
        print('fitting function:', funcStr)
        # w0 is reported in um (axis assumed in mm); sign of w0 is irrelevant.
        print('w0 = (%.1f +/- %.1f) um' % (np.abs(popt[1])*1e3,
                                           w0_err*1e3))
        print('x0 = (%.3f +/- %.3f) mm' % (popt[0], pcov[0, 0]**0.5))
        print('a = %e +/- %e ' % (popt[2], pcov[2, 2]**0.5))
        print('b = %e +/- %e ' % (popt[3], pcov[3, 3]**0.5))

    if full:
        return popt, pcov
    return np.abs(popt[1]), w0_err


def plot_knife_edge(positions, intensities, fit_params, rel_err, title,
                    axisKey, signalKey):
    """
    Plot knife-edge scan data together with the fitted erfc curve.

    Parameters
    ----------
    positions: numpy 1D array
        motor positions of the scan, in mm.
    intensities: numpy 1D array
        detector signal at each position.
    fit_params: sequence of floats
        fitted parameters x0, w0, a, b, passed on to ``erfc``.
    rel_err: float
        standard error of the beam radius w0, in mm. Despite the name, this
        is an absolute error, not a relative one.
    title: str
        figure title.
    axisKey: str
        name of the scanned axis, used in the x-axis label.
    signalKey: str
        name of the detector signal, used as the y-axis label.
    """
    plt.figure(figsize=(7, 4))
    plt.scatter(positions, intensities, color='C1',
                label='measured', s=2, alpha=0.1)
    # nanmin/nanmax so that a NaN position does not make the whole fit
    # curve NaN; identical to min/max when no NaN is present.
    xfit = np.linspace(np.nanmin(positions), np.nanmax(positions), 1000)
    yfit = erfc(xfit, *fit_params)
    plt.plot(xfit, yfit, color='C4',
             label=r'fit $\rightarrow$ $w_0=$(%.1f $\pm$ %.1f) $\mu$m' % (
                 np.abs(fit_params[1])*1e3, rel_err*1e3))
    leg = plt.legend()
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old name; support both old and new releases.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    # Scatter points are drawn semi-transparent; make legend markers opaque.
    for handle in handles:
        handle.set_alpha(1)
    plt.ylabel(signalKey)
    plt.xlabel(axisKey + ' position [mm]')
    plt.title(title)