import string
import warnings
from pathlib import Path

from euxfel_bunch_pattern import indices_at_sase  # Installable from PyPI
import numpy as np
from scipy.interpolate import interp1d

# Detector table, read at import time from ``detectors.txt`` located next to
# this module (importing the module raises if the file cannot be read).
# ``names=True`` takes the field names from the file's first line; the five
# columns are parsed as strings of at most 5, 4 and 3 characters followed by
# two float64 values.
# NOTE(review): the meaning of each column is defined by the file header,
# not by this code — confirm there before relying on it.
detectors = np.genfromtxt(Path(__file__).parent / 'detectors.txt', 
                          names=True, dtype=('|U5', '|U4', '|U3', '<f8', '<f8'))

def correct_adq_common_mode(trace, region, sym):
    """Baseline subtraction based on the ADC common mode.

    Since ADQ digitizers always interleave multiple ADCs per channel to sample
    a trace, regular baseline subtraction leaves an ADC-dependent common mode
    in the trace. This correction instead computes a separate baseline for
    every ADC phase (every `sym`-th sample) and subtracts it from the samples
    of that phase.

    Empirical testing has shown that a symmetry of 8 (should be 4 for
    non-interleaved) regularly shows better results in suppressing high
    frequency signals from the common mode. Going to 16 or higher is not
    recommended in most cases, as additional artifacts appear.

    Parameters
    ----------
        trace : array_like
            Vector to correct. The input is not modified; a corrected float32
            copy is returned.
        region : slice
            Region to use for computing the baseline. Samples are assigned to
            ADC phases by their absolute index in `trace`, so the region does
            not need to start at a multiple of `sym`.
        sym : int
            Periodic symmetry of ADC common mode.

    Returns
    -------
        trace : ndarray
            Corrected float32 vector, same shape as `trace`. A phase with no
            samples inside `region` becomes NaN (mean of an empty selection).
    """
    trace = np.asarray(trace).astype(np.float32)

    # Phase of every baseline sample, computed from its absolute index so that
    # each sample is attributed to the ADC that actually recorded it. Using
    # `trace[region][x::sym]` would count phases from region.start instead.
    region_phase = np.arange(trace.shape[0])[region] % sym
    region_values = trace[region]

    # Compute all baselines before subtracting anything, so no estimate is
    # taken from already-corrected samples.
    baselines = [region_values[region_phase == x].mean() for x in range(sym)]
    for x, baseline in enumerate(baselines):
        trace[x::sym] -= baseline

    return trace

def separate_pulses(traces, ppt, adq_train_region, adq_pulse_region, adq_sample_rate=4, sase=3):
    """Separate a train into individual pulses using the pulse pattern table.

    Parameters
    ----------
        traces : array_like
            2D array of traces to be split up, shape (channels, samples).
        ppt : array_like
            Pulse pattern table from time server device, passed to
            `indices_at_sase`.
        adq_train_region : slice
            Region containing actual signal for the entire train.
        adq_pulse_region : slice
            Region after pulse separation containing actual signal for the pulse.
        adq_sample_rate : array_like or int, optional
            Sample rate of the digitizer channels in GS/s. An array is only
            accepted if it yields the same pulse spacing for every channel,
            because all channels are split with one common spacing.
        sase : int, optional
            Which SASE the experiments run at, 3 by default.

    Returns
    -------
        pulses : ndarray
            Separated traces, shape (channels, pulses, pulse_samples). With
            fewer than two pulses in the pattern, the train region is returned
            as a single pulse. Otherwise the train region is cut into as many
            complete pulse periods as fit in it.

    Raises
    ------
        ValueError
            If `adq_sample_rate` yields differing pulse spacings, or the
            resulting pulse spacing is not positive.
    """
    trace = np.asarray(traces)[:, adq_train_region]
    pulse_ids = indices_at_sase(ppt, sase=sase)

    if len(pulse_ids) < 2:
        return trace[:, adq_pulse_region].reshape(trace.shape[0], 1, -1)

    # ADQ412s run at a clock factor of 440 relative to the PPT un-interleaved;
    # expressed here as 220 samples per PPT step per GS/s of sample rate.
    spacings = np.unique(
        (pulse_ids[1] - pulse_ids[0]) * 220 * np.asarray(adq_sample_rate))
    if spacings.size != 1:
        raise ValueError(
            'adq_sample_rate must give the same pulse spacing for all channels')
    # Slicing/reshaping needs an integer sample count (float rates give floats).
    pulse_spacing = int(np.rint(spacings[0]))
    if pulse_spacing <= 0:
        raise ValueError('pulse spacing must be positive, got %d' % pulse_spacing)

    # Cut the train region (not the raw traces) into whole pulse periods,
    # counting samples along the last axis (not channels).
    num_periods = trace.shape[-1] // pulse_spacing
    periods = trace[:, :num_periods * pulse_spacing].reshape(
        trace.shape[0], num_periods, pulse_spacing)
    return periods[..., adq_pulse_region]


def _f(t, a, b, c, d):
        return a / t**3 + b / t**2 + c / t + d

def _df(t, a, b, c, d):
        return -3 * a / t**4 - 2 * b / t**3 - c / t**2
    
def energy_calibration(traces, calib_params, sample_rate, valid_energies, 
                       energy_nodes=None, model=_f, model_derivative=_df):
    """Calibrate time of flight traces to energy spectra.

    Parameters
    ----------
        traces : array_like
            Traces to be calibrated, shape (detectors, ..., samples). Any
            number of intermediate axes (e.g. pulses) is supported.
        calib_params : array_like
            Calibration parameters passed to the model, first dimension needs
            to be number of function arguments. Each entry is either a scalar
            (shared by all detectors) or has one value per detector.
        sample_rate : array_like or int
            Sample rate for all digitizer channels in GS/s (scalar or one
            value per detector).
        valid_energies : slice
            Slice along the sample axis that discards divergent values (e.g.
            samples near t = 0, where the default model diverges).
        energy_nodes : array_like, optional
            Energy values at which the spectra are resampled by interpolation.
            If None, no resampling is done.
        model : function, optional
            Mapping from time (in seconds) to energy, arguments (t, *args).
            Defaults to a / t**3 + b / t**2 + c / t + d.
        model_derivative : function, optional
            Derivative of `model` with respect to t.

    Returns
    -------
        energy : ndarray
            Energy values corresponding to time of flight, shape
            (detectors, valid samples); a single row if the calibration is
            shared by all detectors.
        spectra : ndarray
            Traces divided by the model derivative, shape of
            ``traces[..., valid_energies]`` (broadcast against per-detector
            calibrations).
        resampled_spectra : ndarray or None
            Spectra resampled at `energy_nodes`, shape
            ``spectra.shape[:-1] + np.shape(energy_nodes)``; all NaN (with a
            RuntimeWarning) if interpolation fails, e.g. because a node lies
            outside the calibrated energy range. None if `energy_nodes` is None.
        time : ndarray
            Times of flight in ns corresponding to the valid samples.
    """
    traces = np.asarray(traces)

    # Sample times in ns (sample_rate is in GS/s). Shape (samples, 1) for a
    # scalar rate or (samples, detectors) for per-detector rates, so the model
    # broadcasts against per-detector calibration parameters.
    time = np.arange(traces.shape[-1])[:, np.newaxis] / sample_rate
    t_seconds = 1e-9 * time

    energy = model(t_seconds, *calib_params)[valid_energies].T
    dE = model_derivative(t_seconds, *calib_params)[valid_energies].T
    # dE is (detectors, valid samples). Insert singleton axes between the two
    # so it aligns with the detector axis of `traces` and broadcasts over any
    # number of intermediate axes, instead of assuming exactly one of them.
    n_middle = max(traces.ndim - 2, 0)
    dE = dE.reshape(dE.shape[:1] + (1,) * n_middle + dE.shape[1:])
    spectra = traces[..., valid_energies] / dE

    if energy_nodes is None:
        return energy, spectra, None, time[valid_energies]

    # Pair each detector row of `spectra` with its energy axis; a calibration
    # shared by all detectors (one energy row) is repeated for every row
    # instead of silently resampling only the first detector.
    energy_rows = np.broadcast_to(energy, (spectra.shape[0], energy.shape[-1]))
    try:
        resampled_spectra = np.asarray([
            interp1d(row_energy, row_spectrum)(energy_nodes)
            for row_energy, row_spectrum in zip(energy_rows, spectra)
        ])
    except ValueError as err:
        # interp1d raises ValueError e.g. when a node is outside the
        # interpolation range; keep the result shape and mark it as invalid.
        warnings.warn(f'Resampling spectra to energy_nodes failed: {err}',
                      RuntimeWarning, stacklevel=2)
        resampled_spectra = np.full(
            (*spectra.shape[:-1], *np.shape(energy_nodes)), np.nan)

    return energy, spectra, resampled_spectra, time[valid_energies]