""" Photo-Electron Spectrometer (PES) related sub-routines

    Copyright (2021) SCS Team.

    (contributions preferably comply with pep8 code structure
    guidelines.)
"""

import logging
import numpy as np
import xarray as xr
import extra_data as ed

from ..misc.bunch_pattern_external import is_sase_3
from ..mnemonics_machinery import (mnemonics_to_process,
                                   mnemonics_for_run)
from ..constants import mnemonics as _mnemonics

# Names exported by ``from <this module> import *``.
__all__ = [
    'get_pes_params',
    'get_pes_tof',
]


# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


def get_pes_tof(run, mnemonics=None, merge_with=None,
                start=31390, width=300, origin=None, width_ns=None,
                subtract_baseline=True,
                baseStart=None, baseWidth=80,
                sample_rate=2e9):
    """
    Extracts time-of-flight spectra from raw digitizer traces. The
    traces are either loaded via ToolBox mnemonics or those in the
    optionally provided merge_with dataset. The spectra are aligned
    by pulse Id using the SASE 3 bunch pattern, and have time coordinates
    in nanoseconds.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the digitizer data
    mnemonics: str or list of str
        mnemonics for PES, e.g. "PES_W_raw" or ["PES_W_raw", "PES_ENE_raw"].
        If None and no merge_with dataset is provided, defaults to "PES_W_raw".
    merge_with: xarray Dataset
        If provided, the resulting Dataset will be merged with this
        one. The PES variables of merge_with (if any) will also be
        computed and merged.
    start: int
        starting sample of the first spectrum in the raw trace.
    width: int
        number of samples per spectra. Ignored if width_ns is given.
    origin: int
        sample of the raw trace that corresponds to time-of-flight origin.
        If None, origin is equal to start.
    width_ns: float
        time window for one spectrum. If None, the time window is defined by
        width / sample rate. If given, it overrides width.
    subtract_baseline: bool
        If True, subtract baseline defined by baseStart and baseWidth to each
        spectrum.
    baseStart: int
        starting sample of the baseline. If None, baseStart is equal to
        start.
    baseWidth: int
        number of samples to average (starting from baseStart) for baseline
        calculation.
    sample_rate: float
        sample rate of the digitizer.

    Returns
    -------
    pes: xarray Dataset
        Dataset containing the PES time-of-flight spectra (e.g. "PES_W_tof"),
        merged with optionally provided merge_with dataset. Spectra that
        would extend beyond the end of the recorded trace are not extracted;
        a mnemonic for which no spectrum fits at all is skipped with a
        warning.

    Example
    -------
    >>> import toolbox_scs as tb
    >>> import toolbox_scs.detectors as tbdet
    >>> run, ds = tb.load(2927, 100, "PES_W_raw")
    >>> pes = tbdet.get_pes_tof(run, merge_with=ds)
    """
    def to_processed_name(name):
        return name.replace('raw', 'tof')

    to_process = mnemonics_to_process(mnemonics, merge_with,
                                      'PES', to_processed_name)
    run_mnemonics = mnemonics_for_run(run)

    # Look for a bunch pattern table, first in merge_with, then in the run.
    if bool(merge_with) and 'bunchPatternTable' in merge_with:
        bpt = merge_with['bunchPatternTable']
    elif 'bunchPatternTable' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable'].values())
    elif 'bunchPatternTable_SA3' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable_SA3'].values())
    else:
        # NOTE(review): is_sase_3() is called with None in this case; its
        # behaviour for None is not visible from this module - confirm.
        bpt = None

    mask = is_sase_3(bpt).assign_coords({'pulse_slot': np.arange(2700)})
    mask_on = mask.where(mask, drop=True)
    # Number of pulses and pulse spacing are taken from the first entry
    # along the leading dimension of the mask only.
    npulses = int(mask.sum(dim='pulse_slot')[0].values)
    if npulses > 1:
        period = int(mask_on['pulse_slot'].diff(dim='pulse_slot')[0].values)
    else:
        period = 0
    # Offset, in digitizer samples, between two consecutive spectra.
    # NOTE(review): 440 is presumably the number of digitizer samples per
    # pulse slot at the default sample rate - confirm.
    stride = period * 440

    if origin is None:
        origin = start
    if baseStart is None:
        baseStart = start
    if width_ns is not None:
        width = int(sample_rate * width_ns * 1e-9)
    time_ns = 1e9 * (np.arange(start, start + width) - origin) / sample_rate

    # Last sample (exclusive) needed to extract all npulses spectra: the
    # final spectrum has index npulses - 1.
    required = (npulses - 1) * stride + start + width if npulses > 0 else 0

    ds = xr.Dataset()
    for m in to_process:
        if bool(merge_with) and m in merge_with:
            arr = merge_with[m]
        else:
            arr = run.get_array(*run_mnemonics[m].values(), name=m)
        nsamples = arr.sizes['PESsampleId']
        if nsamples < required:
            log.warning('Not all pulses were recorded. The number of samples on '
                        f'the digitizer {arr.sizes["PESsampleId"]} is not enough '
                        f'to cover the {npulses} spectra. Missing spectra will be '
                        'filled with NaNs.')
        spectra = []
        for p in range(npulses):
            begin = p * stride + start
            end = begin + width
            if end > nsamples:
                # Trace too short for this and all following spectra.
                break
            pes = arr.isel(PESsampleId=slice(begin, end))
            if subtract_baseline:
                baseBegin = p * stride + baseStart
                baseEnd = baseBegin + baseWidth
                bl = arr.isel(
                    PESsampleId=slice(baseBegin, baseEnd)).mean(dim='PESsampleId')
                pes = pes - bl
            spectra.append(pes)
        if not spectra:
            # xr.concat() raises on an empty list; skip this mnemonic.
            log.warning(f'No spectrum could be extracted from {m}.')
            continue
        spectra = xr.concat(spectra,
                            dim='sa3_pId').rename(to_processed_name(m))
        ds = ds.merge(spectra)

    if len(ds.variables) > 0:
        # Label the extracted spectra with the pulse slots of the SASE 3
        # pulses (truncated to the number of spectra actually extracted).
        pulse_ids = mask_on['pulse_slot'][:ds.sizes['sa3_pId']].values
        ds = ds.assign_coords({'sa3_pId': pulse_ids})
        ds = ds.rename({'PESsampleId': 'time_ns'})
        ds = ds.assign_coords({'time_ns': time_ns})
    if bool(merge_with):
        # Replace the raw traces of merge_with by the processed spectra.
        ds = merge_with.drop_vars(to_process,
                                  errors='ignore').merge(ds, join='left')

    return ds


def get_pes_params(run):
    """
    Read the PES settings of a given extra_data DataCollection.

    The gas in use is the first of N2, Ne, Kr, Xe whose run value equals
    'ON'. The retardation voltage is the first value of the 'PES_RV'
    mnemonic. Both are looked up in the first 20 trains of the run.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the digitizer data

    Returns
    -------
    params: dict
        dictionary of PES parameters with keys:
        'gas' (name of the gas, or 'unknown' if none was found to be ON),
        'binding_energy' (only present if the gas was identified),
        'ret_voltage' (float).
    """
    first_trains = run.select_trains(ed.by_index[:20])
    run_mnemonics = mnemonics_for_run(run)
    # Value reported as 'binding_energy' for each supported gas.
    binding_energies = {'N2': 409.9, 'Ne': 870.2, 'Kr': 1921, 'Xe': 1148.7}

    params = {}
    for gas, energy in binding_energies.items():
        entry = _mnemonics[f'PES_{gas}'][0]
        state = first_trains.get_run_value(entry['source'], entry['key'])
        if state == 'ON':
            params['gas'] = gas
            params['binding_energy'] = energy
            break
    else:
        # No gas reported 'ON'.
        params['gas'] = 'unknown'
        log.warning('Could not find which PES gas was used.')

    voltage = first_trains.get_array(*run_mnemonics['PES_RV'].values())
    params['ret_voltage'] = float(voltage[0].values)
    return params