from euxfel_bunch_pattern import indices_at_sase  # Installable from PyPI
import numpy as np
import string, os, re
from scipy.interpolate import interp1d
from pathlib import Path

def detectors(run):
    """
    Provides p2828 detector information when given a run number.

    The ``configurations`` directory located next to this module is scanned
    for files whose name contains ``<first>-<last>.txt``. The first such file
    (in sorted name order) whose inclusive run range contains `run` is parsed
    with `numpy.genfromtxt`. Files whose name does not carry a numeric run
    range are ignored.

    Parameters
    ----------
    run : unsigned int
        Run number within proposal 2828

    Returns
    -------
    detinfo : structured ndarray
        Keys: name (detector name),
              digitizer,
              channel,
              angle (looking along the beam, 0 is right, increasing counter-clockwise),
              sample_rate (GS/s)

    Raises
    ------
    Exception
        If no configuration file covers `run`.
    """
    conf_dir = Path(__file__).parent / 'configurations'
    # Raw string avoids invalid-escape warnings; `+` rejects empty bounds
    # (which would crash in int('')) and the dot is matched literally.
    run_range = re.compile(r'(\d+)-(\d+)\.txt')

    # os.listdir() order is arbitrary; sort so the lookup is deterministic.
    for fname in sorted(os.listdir(conf_dir)):
        found = run_range.search(fname)
        if found is None:
            continue
        first, last = (int(bound) for bound in found.group(1, 2))
        if first <= run <= last:
            return np.genfromtxt(conf_dir / fname, names=True,
                                 dtype=('|U5', '|U4', '|U3', '<f8', '<i4'))
    raise Exception('Did not find run file')


def correct_adq_common_mode(trace, region=np.s_[1000:], sym=8):
    """Baseline subtraction based on common mode.

    Since ADQ digitizers always interleave multiple ADCs per channel to sample
    a trace, regular baseline subtraction will cause an ADC-dependent common
    mode to appear. This correction uses a per-ADC baseline in the
    subtraction instead.

    Empirical testing has shown that a symmetry of 8 (should be 4 for
    non-interleaved) regularly shows better results in suppressing high
    frequency signals from the common mode. Going 16 or higher is not
    recommended in most cases, as additional artifacts appear.

    Parameters
    ----------
    trace : array_like
        Vector to correct. It is converted to a float32 copy, so the input
        itself is never modified.
    region : slice, optional
        Region (along the last axis) to use for computing baseline. The
        per-ADC phase is counted from the start of this region, so its start
        should be a multiple of `sym` for the phases to line up with the
        trace.
    sym : int, optional
        Periodic symmetry of ADC common mode.

    Returns
    -------
    trace : ndarray
        Corrected float32 copy of the input, truncated to at most 600000
        entries along its first axis.
    """
    # Plain sequences have no .astype(); objects that do keep their own
    # conversion so existing callers see exactly the same result type.
    if not hasattr(trace, 'astype'):
        trace = np.asarray(trace)

    # astype() copies, so the caller's data stays untouched.
    # NOTE(review): the 600000 cut applies to the FIRST axis while everything
    # below works on the last one - identical for 1-D vectors; confirm intent
    # for multi-dimensional input.
    corrected = trace.astype(np.float32)[:600000]

    for adc in range(sym):
        # Every `sym`-th sample shares an ADC: subtract that ADC's own mean
        # level as measured inside the baseline region.
        corrected[..., adc::sym] -= corrected[..., region][..., adc::sym].mean()

    return corrected


def separate_pulses(traces, ppt, adq_train_region=np.s_[2200:], adq_pulse_region=np.s_[:2000],
                    adq_sample_rate=4, sase=3):
    """Separate train into separate pulses using the pulse pattern table.

    The pulse spacing is derived from the first two pulses of the selected
    SASE only, i.e. the pulses are taken to be equally spaced.

    Parameters
    ----------
    traces : array_like
        Array of traces to be split up, samples along the last axis.
    ppt : array_like
        Pulse pattern table from time server device.
    adq_train_region : slice, optional
        Region containing actual signal for the entire train.
    adq_pulse_region : slice, optional
        Region after pulse separation containing actual signal for the pulse.
    adq_sample_rate : int or float, optional
        Sample rate of the digitizer channels in GS/s. A single value is
        required, as it determines one common pulse spacing.
    sase : int, optional
        Which SASE the experiments run at, 3 by default.

    Returns
    -------
    traces : array_like
        Separated traces of shape (..., pulses, samples), i.e. a pulse axis
        is inserted in front of the sample axis.

    Raises
    ------
    ValueError
        If the traces are too short to hold all pulses of the pattern.
    """
    traces = traces[..., adq_train_region]
    pulse_ids = indices_at_sase(ppt, sase=sase)
    num_pulses = len(pulse_ids)

    if num_pulses < 2:
        # No spacing can be derived from fewer than two pulses: treat the
        # whole train region as one pulse. The pulse axis goes in front of
        # the sample axis, matching the multi-pulse layout for any rank.
        return traces[..., np.newaxis, :][..., adq_pulse_region]

    # The code uses 220 digitizer samples per PPT slot for each GS/s of
    # sample rate (the original note quotes a clock factor of 440 relative
    # to the PPT for un-interleaved ADQ412s).
    slots_between_pulses = pulse_ids[1] - pulse_ids[0]
    # Coerce to a plain int so a float sample rate still yields a valid
    # slice bound / reshape dimension.
    pulse_spacing = int(round(float(slots_between_pulses * 220 * adq_sample_rate)))

    train_length = num_pulses * pulse_spacing
    if traces.shape[-1] < train_length:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(
            f'digitizer traces are too short for pulse pattern: need '
            f'{train_length} samples, got {traces.shape[-1]}')

    traces_aligned = traces[..., :train_length]
    per_pulse = traces_aligned.reshape(*traces_aligned.shape[:-1],
                                       num_pulses, pulse_spacing)
    return per_pulse[..., adq_pulse_region]


# Polynomial computation using Horner's scheme
def f_horner(t, *coeffs):
    """Evaluate a polynomial in 1/t with Horner's scheme.

    `coeffs` are ordered from the highest power of 1/t down to the constant
    term. Without coefficients the result is 0.
    """
    inv_t = 1 / t
    acc = 0
    for c in coeffs:
        # Shift everything accumulated so far up by one power of 1/t.
        acc = acc * inv_t + c
    return acc

def df_horner(t, *coeffs):
    """Derivative with respect to t of the polynomial in 1/t from `f_horner`.

    `coeffs` use the same ordering as in `f_horner`: highest power of 1/t
    first, constant term last.
    """
    inv_t = 1 / t
    top_power = len(coeffs) - 1
    acc = 0
    for position, c in enumerate(coeffs):
        # d/dt (1/t)**k = -k * (1/t)**(k + 1); the weight is -k for the term
        # of power k = top_power - position (0 for the constant term).
        acc = acc * inv_t + c * (position - top_power)
    # One extra factor of 1/t supplies the "+ 1" in the exponent.
    return inv_t * acc

# Linear functions
def f_linear(t, a, b):
    """Linear model: slope `a` times `t` plus offset `b`."""
    scaled = a * t
    return scaled + b

def df_linear(t, a, b):
    """Derivative of `f_linear` with respect to `t`: the constant slope `a`.

    `b` is unused; it is accepted so both functions share one argument list.
    """
    # Multiplying t by zero keeps the shape of t in the result.
    shaped_zero = 0 * t
    return shaped_zero + a

# Quadratic functions
def f_quad(t, a, b, c):
    """Quadratic polynomial in 1/t: a/t**2 + b/t + c."""
    second_order = a / t**2
    first_order = b / t
    return second_order + first_order + c

def df_quad(t, a, b, c):
    """Derivative of `f_quad` with respect to t: -2a/t**3 - b/t**2.

    `c` is unused; it is accepted so both functions share one argument list.
    """
    from_second_order = -2 * a / t**3
    from_first_order = b / t**2
    return from_second_order - from_first_order

# Cubic functions
def f_cubic(t, a, b, c, d):
    """Cubic polynomial in 1/t: a/t**3 + b/t**2 + c/t + d."""
    third_order = a / t**3
    second_order = b / t**2
    first_order = c / t
    return third_order + second_order + first_order + d

def df_cubic(t, a, b, c, d):
    """Derivative of `f_cubic` with respect to t: -3a/t**4 - 2b/t**3 - c/t**2.

    `d` is unused; it is accepted so both functions share one argument list.
    """
    from_third_order = -3 * a / t**4
    from_second_order = 2 * b / t**3
    from_first_order = c / t**2
    return from_third_order - from_second_order - from_first_order

# Quartic (4th-order) functions
def f_order4(t, a, b, c, d, e):
    """Fourth-order polynomial in 1/t: a/t**4 + b/t**3 + c/t**2 + d/t + e."""
    fourth_order = a / t**4
    third_order = b / t**3
    second_order = c / t**2
    first_order = d / t
    return fourth_order + third_order + second_order + first_order + e

def df_order4(t, a, b, c, d, e):
    """Derivative of `f_order4` with respect to t.

    Returns -4a/t**5 - 3b/t**4 - 2c/t**3 - d/t**2. The constant term `e` of
    `f_order4` has zero derivative and therefore does not contribute; it is
    only accepted so both functions share one argument list.
    """
    return - 4*a / t**5 - 3*b / t**4 - 2*c / t**3 - d / t**2

# Defaults: the linear model and its derivative. These two names are the
# default `model` / `model_derivative` arguments of `energy_calibration`.
f = f_linear
df = df_linear

def energy_calibration(traces, calib_params, sample_rate, valid_energies,
                       energy_nodes=None, model=f, model_derivative=df):
    """Calibrate time of flight traces to energy spectra using given coefficients.

    Parameters
    ----------
    traces : array_like
        Traces to be calibrated, shape (detectors, ..., samples).
    calib_params : structured array or mapping
        Needs an 'energy' entry holding the calibration parameters passed to
        the model, indexed as (detectors, function arguments), and - when
        `energy_nodes` is given - an 'enabled' entry with one flag per
        detector.
    sample_rate : array_like or int
        Sample rate for all digitizer channels in GS/s.
    valid_energies : slice
        Slice applied along the sample axis allowing to discard divergent
        values.
    energy_nodes : array_like, optional
        Energy values at which the spectra are additionally interpolated.
    model : function, optional
        Mapping from time to energy with arguments (t, *args).
        Defaults to the module-level `f`.
    model_derivative : function, optional
        Derivative of `model` with respect to time, same arguments.
        Defaults to the module-level `df`.

    Returns
    -------
    energy : ndarray
        Energy values corresponding to time of flight, one row per detector.
    spectra : ndarray
        Traces divided by the model derivative, negative values clipped to 0.
    resampled_spectra : ndarray or None
        Spectra resampled according to `energy_nodes` (all zeros for
        detectors that are not enabled), or None if `energy_nodes` is None.
    time : ndarray
        Times of flight corresponding to the selected samples, as a column.
    """
    # Convert digitizer channels into ToF [ns]; kept as a column so it
    # broadcasts against the per-detector calibration parameters.
    time = np.arange(traces.shape[-1])[:, np.newaxis] / sample_rate

    # Calibrate ToF to energy: one model argument per row after transposing.
    coeffs = calib_params['energy'].T
    energy = model(time, *coeffs)[valid_energies].T
    dE = model_derivative(time, *coeffs)[valid_energies].T

    # Give dE one singleton axis per intermediate trace axis so it lines up
    # with (detectors, ..., samples) for any trace rank, not only rank 3.
    jacobian = dE.reshape(dE.shape[0], *([1] * (traces.ndim - 2)), dE.shape[-1])
    # NOTE(review): the signed derivative is used and negative results are
    # clipped to zero afterwards - confirm the sign convention of the traces
    # matches the model.
    spectra = np.clip(traces[..., valid_energies] / jacobian, 0, None)

    if energy_nodes is None:
        # Nothing to resample; hand back a real None rather than wrapping
        # it in a 0-d object array.
        return energy, spectra, None, time[valid_energies]

    energy_nodes = np.asarray(energy_nodes)
    # Shape interp1d produces for one detector: the interpolation (last)
    # axis is replaced by the shape of the requested nodes.
    blank_shape = spectra.shape[1:-1] + energy_nodes.shape

    resampled_spectra = []
    for det_energy, det_spectra, enabled in zip(energy, spectra,
                                                calib_params['enabled']):
        if enabled:
            interp = interp1d(det_energy, det_spectra,
                              fill_value='extrapolate')(energy_nodes)
        else:
            # Disabled detectors contribute an all-zero placeholder.
            interp = np.zeros(blank_shape)
        resampled_spectra.append(interp)

    return energy, spectra, np.array(resampled_spectra), time[valid_energies]