# -*- coding: utf-8 -*-
"""
    Toolbox for SCS.

    Various utility functions to quickly process data
    measured at the SCS instrument.

    Copyright (2019) SCS Team.
"""
import logging
import os

import numpy as np
import xarray as xr
import extra_data as ed
from extra_data.read_machinery import find_proposal

from .constants import mnemonics as _mnemonics
from .mnemonics_machinery import mnemonics_for_run
from .util.exceptions import ToolBoxValueError
import toolbox_scs.detectors as tbdet

__all__ = [
    'concatenateRuns',
    'find_run_path',
    'get_array',
    'load',
    'open_run',
    'run_by_path',
]

log = logging.getLogger(__name__)


def load(proposalNB=None, runNB=None,
         fields=None,
         subFolder='raw',
         display=False,
         validate=False,
         subset=ed.by_index[:],
         rois={},
         extract_tim=True,
         extract_laser=True,
         extract_xgm=True,
         extract_bam=True,
         tim_bp='sase3',
         laser_bp='scs_ppl',
         ):
    """
    Load a run and extract the data. Output is an xarray Dataset with
    aligned trainIds.

    Parameters
    ----------
    proposalNB: str, int
        proposal number e.g. 'p002252' or 2252
    runNB: str, int
        run number as integer
    fields: str, list of str, list of dict
        list of mnemonics to load specific data such as "fastccd",
        "SCS_XGM", or dictionaries defining a custom mnemonic such as
        {"extra": {'source': 'SCS_CDIFFT_MAG/SUPPLY/CURRENT',
                   'key': 'actual_current.value', 'dim': None}}
    subFolder: str
        'raw', 'proc' (processed) or 'all' (both 'raw' and 'proc') to access
        data from either or both of those folders. If 'all' is used, sources
        present in 'proc' overwrite those in 'raw'. The default is 'raw'.
    display: bool
        whether to show the run.info or not
    validate: bool
        whether to run extra-data-validate or not
    subset:
        a subset of trains that can be loaded with by_index[:5] for the
        first 5 trains
    rois: dict
        a dictionary of mnemonics with a list of ROI definitions and the
        desired names, for example:
        {'fastccd': {'ref': {'roi': by_index[730:890, 535:720],
                             'dim': ['ref_x', 'ref_y']},
                     'sam': {'roi': by_index[1050:1210, 535:720],
                             'dim': ['sam_x', 'sam_y']}}}
    extract_tim: bool
        If True, extracts the peaks from TIM variables (e.g. 'MCP2raw',
        'MCP3apd') and aligns the pulse Id with the sase3 bunch pattern.
    extract_laser: bool
        If True, extracts the peaks from FastADC variables (e.g.
        'FastADC5raw', 'FastADC3peaks') and aligns the pulse Id with the
        PP laser bunch pattern.
    extract_xgm: bool
        If True, extracts the values from XGM variables (e.g. 'SCS_SA3',
        'XTD10_XGM') and aligns the pulse Id with the sase1 / sase3 bunch
        pattern.
    extract_bam: bool
        If True, extracts the values from BAM variables (e.g. 'BAM1932M')
        and aligns the pulse Id with the sase3 bunch pattern.
    tim_bp: str
        bunch pattern used to extract the TIM pulses.
        Ignored if extract_tim=False.
    laser_bp: str
        bunch pattern used to extract the PP laser pulses.
        Ignored if extract_laser=False.

    Returns
    -------
    run, data: DataCollection, xarray.Dataset
        extra_data DataCollection of the proposal and run number and an
        xarray Dataset with aligned trainIds and pulseIds

    Example
    -------
    >>> import toolbox_scs as tb
    >>> run, data = tb.load(2212, 208, ['SCS_SA3', 'MCP2apd', 'nrj'])
    """
    runFolder = find_run_path(proposalNB, runNB, subFolder)
    run = ed.RunDirectory(runFolder).select_trains(subset)
    if fields is None:
        return run, xr.Dataset()
    if isinstance(fields, str):
        fields = [fields]
    if validate:
        # validation via the extra-data-validate CLI is currently disabled
        # get_ipython().system('extra-data-validate ' + runFolder)
        pass
    if display:
        print('Loading data from {}'.format(runFolder))
        run.info()
    data_arrays = []
    run_mnemonics = mnemonics_for_run(run)
    # load pulse pattern info
    bpt = load_bpt(run, run_mnemonics=run_mnemonics)
    if bpt is None:
        log.warning('Bunch pattern table not found in run. Skipping!')
    else:
        data_arrays.append(bpt)

    for f in fields:
        if isinstance(f, dict):
            # extracting mnemonic defined on the spot
            if len(f.keys()) > 1:
                print('Loading only one "on-the-spot" mnemonic at a time, '
                      'skipping all others!')
            k = list(f.keys())[0]
            v = f[k]
        else:
            # extracting mnemonic from the table
            if f in run_mnemonics:
                v = run_mnemonics[f]
                k = f
            else:
                if f in _mnemonics:
                    log.warning(f'Mnemonic "{f}" not found in run. Skipping!')
                    print(f'Mnemonic "{f}" not found in run. Skipping!')
                else:
                    log.warning(f'Unknown mnemonic "{f}". Skipping!')
                    print(f'Unknown mnemonic "{f}". Skipping!')
                continue
        if k in [d.name for d in data_arrays]:
            continue  # already loaded, skip
        if display:
            print(f'Loading {k}')
        if v['source'] not in run.all_sources:
            log.warning(f'Source {v["source"]} not found in run. Skipping!')
            print(f'Source {v["source"]} not found in run. Skipping!')
            continue
        if k not in rois:
            # no ROIs selection, we read everything
            arr = run.get_array(*v.values(), name=k)
            if len(arr) == 0:
                log.warning(f'Empty array for {f}: {v["source"]}, '
                            f'{v["key"]}. Skipping!')
                print(f'Empty array for {f}: {v["source"]}, {v["key"]}. '
                      'Skipping!')
                continue
            data_arrays.append(arr)
        else:
            # ROIs selection, for each ROI we select a region of the data
            # and save it with a new name and dimensions
            for nk, nv in rois[k].items():
                arr = run.get_array(v['source'], v['key'],
                                    extra_dims=nv['dim'],
                                    roi=nv['roi'],
                                    name=nk)
                if len(arr) == 0:
                    log.warning(f'Empty array for {f}: {v["source"]}, '
                                f'{v["key"]}. Skipping!')
                    print(f'Empty array for {f}: {v["source"]}, {v["key"]}. '
                          'Skipping!')
                    continue
                data_arrays.append(arr)

    data = xr.merge(data_arrays, join='inner')
    data.attrs['runFolder'] = runFolder

    tim = [k for k in run_mnemonics if 'MCP' in k and k in data]
    if extract_tim and len(tim) > 0:
        data = tbdet.get_tim_peaks(run, mnemonics=tim, merge_with=data,
                                   bunchPattern=tim_bp)
    laser = [k for k in run_mnemonics if 'FastADC' in k and k in data]
    if extract_laser and len(laser) > 0:
        data = tbdet.get_laser_peaks(run, mnemonics=laser, merge_with=data,
                                     bunchPattern=laser_bp)
    xgm = ['XTD10_XGM', 'XTD10_XGM_sigma', 'XTD10_SA3', 'XTD10_SA3_sigma',
           'XTD10_SA1', 'XTD10_SA1_sigma', 'SCS_XGM', 'SCS_XGM_sigma',
           'SCS_SA1', 'SCS_SA1_sigma', 'SCS_SA3', 'SCS_SA3_sigma']
    xgm = [k for k in xgm if k in data]
    if extract_xgm and len(xgm) > 0:
        data = tbdet.get_xgm(run, mnemonics=xgm, merge_with=data)
    bam = [k for k in run_mnemonics if 'BAM' in k and k in data]
    if extract_bam and len(bam) > 0:
        data = tbdet.get_bam(run, mnemonics=bam, merge_with=data)
    return run, data
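
# A minimal usage sketch for `load` combining a standard mnemonic, an
# "on-the-spot" mnemonic and a ROI selection (the proposal / run numbers
# and the ROI coordinates below are illustrative, not from a real dataset):
#
#     import toolbox_scs as tb
#     from extra_data import by_index
#
#     run, data = tb.load(
#         2212, 208,
#         fields=['SCS_SA3',
#                 {'extra': {'source': 'SCS_CDIFFT_MAG/SUPPLY/CURRENT',
#                            'key': 'actual_current.value', 'dim': None}}],
#         rois={'fastccd': {'sam': {'roi': by_index[1050:1210, 535:720],
#                                   'dim': ['sam_x', 'sam_y']}}})
#     # data is an xarray Dataset with 'SCS_SA3', 'extra' and 'sam'
#     # variables aligned on trainId
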
def run_by_path(path):
    """
    Return the specified run.

    Wraps the extra_data RunDirectory routine, to ease its use for the
    scs-toolbox user.

    Parameters
    ----------
    path: str
        path to the run directory

    Returns
    -------
    run: extra_data.DataCollection
        DataCollection object containing information about the specified
        run. Data can be loaded using built-in class methods.
    """
    return ed.RunDirectory(path)


def find_run_path(proposalNB, runNB, data='raw'):
    """
    Return the run path given the specified proposal and run numbers.

    Parameters
    ----------
    proposalNB: (str, int)
        proposal number e.g. 'p002252' or 2252
    runNB: (str, int)
        run number as integer
    data: str
        'raw', 'proc' (processed) or 'all' (both 'raw' and 'proc') to access
        data from either or both of those folders. If 'all' is used, sources
        present in 'proc' overwrite those in 'raw'. The default is 'raw'.

    Returns
    -------
    path: str
        The run path.
    """
    if isinstance(runNB, int):
        runNB = 'r{:04d}'.format(runNB)
    if isinstance(proposalNB, int):
        proposalNB = 'p{:06d}'.format(proposalNB)
    return os.path.join(find_proposal(proposalNB), data, runNB)
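
# A sketch of how `find_run_path` and `run_by_path` fit together (the
# resolved path depends on the local proposal folder layout, so the path
# in the comment below is only indicative):
#
#     path = find_run_path(2212, 208)  # ints become 'p002212' / 'r0208'
#     # e.g. /gpfs/exfel/exp/SCS/201901/p002212/raw/r0208
#     run = run_by_path(path)
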
def open_run(proposalNB, runNB, subset=ed.by_index[:], **kwargs):
    """
    Get extra_data.DataCollection in a given proposal.

    Wraps the extra_data open_run routine and adds subset selection, out of
    convenience for the toolbox user. More information can be found in the
    extra_data documentation.

    Parameters
    ----------
    proposalNB: (str, int)
        proposal number e.g. 'p002252' or 2252
    runNB: (str, int)
        run number e.g. 17 or 'r0017'
    subset:
        a subset of trains that can be loaded with by_index[:5] for the
        first 5 trains
    **kwargs
        extra keyword arguments passed to extra_data.open_run, such as
        data (default 'raw') or include (default '*').

    Returns
    -------
    run: extra_data.DataCollection
        DataCollection object containing information about the specified
        run. Data can be loaded using built-in class methods.
    """
    return ed.open_run(proposalNB, runNB, **kwargs).select_trains(subset)


def get_array(run=None, mnemonic=None, stepsize=None,
              subset=ed.by_index[:], subFolder='raw',
              proposalNB=None, runNB=None):
    """
    Loads one data array for the specified mnemonic and rounds its values to
    integer multiples of stepsize for consistent grouping (except for
    stepsize=None). Returns a 1D array of ones if mnemonic is set to None.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the data. Used if proposalNB and runNB
        are None.
    mnemonic: str
        Identifier of a single item in the mnemonic collection. None
        creates a dummy 1D array of ones with length equal to the number
        of trains.
    stepsize: float
        nominal stepsize of the array data - values will be rounded to
        integer multiples of this value.
    subset:
        a subset of trains that can be loaded with by_index[:5] for the
        first 5 trains
    subFolder: str
        sub-folder from which to load the data. Use 'raw' for raw data or
        'proc' for processed data.
    proposalNB: (str, int)
        proposal number e.g. 'p002252' or 2252.
    runNB: (str, int)
        run number e.g. 17 or 'r0017'.

    Returns
    -------
    data: xarray.DataArray
        xarray DataArray containing rounded array values using the trainId
        as coordinate.

    Raises
    ------
    ToolBoxValueError: Exception
        Toolbox specific exception, indicating a non-valid mnemonic entry

    Example
    -------
    >>> import toolbox_scs as tb
    >>> run = tb.open_run(2212, 235)
    >>> mnemonic = 'PP800_PhaseShifter'
    >>> data_PhaseShifter = tb.get_array(run, mnemonic, 0.5)
    """
    if run is None:
        run = open_run(proposalNB, runNB, subset, data=subFolder)
    if not isinstance(run, ed.DataCollection):
        raise TypeError(f'run argument has type {type(run)} but '
                        'expected type is extra_data.DataCollection')
    run = run.select_trains(subset)
    run_mnemonics = mnemonics_for_run(run)
    try:
        if mnemonic is None:
            data = xr.DataArray(
                np.ones(len(run.train_ids), dtype=np.int16),
                dims=['trainId'], coords={'trainId': run.train_ids})
        elif mnemonic in run_mnemonics:
            mnem = run_mnemonics[mnemonic]
            data = run.get_array(*mnem.values(), name=mnemonic)
        else:
            raise ToolBoxValueError("Invalid mnemonic", mnemonic)
        if stepsize is not None:
            data = stepsize * np.round(data / stepsize)
        log.debug(f"Got data for {mnemonic}")
    except ToolBoxValueError as err:
        log.error(f"{err.message}")
        raise
    return data
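
# The `stepsize` rounding in `get_array` snaps noisy scan values onto a
# regular grid so that trains can be grouped by nominal position. A minimal
# sketch of the arithmetic with made-up numbers:
#
#     import numpy as np
#     vals = np.array([10.12, 10.31, 10.76])
#     stepsize = 0.5
#     print(stepsize * np.round(vals / stepsize))  # -> [10.  10.5 11. ]
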
def concatenateRuns(runs):
    """
    Sorts and concatenates a list of runs with identical data variables
    along the trainId dimension.

    Parameters
    ----------
    runs: list
        the xarray Datasets to concatenate

    Returns
    -------
    data: xarray.Dataset
        the concatenated xarray Dataset
    """
    firstTid = {i: int(run.trainId[0].values) for i, run in enumerate(runs)}
    orderedDict = dict(sorted(firstTid.items(), key=lambda t: t[1]))
    orderedRuns = [runs[i] for i in orderedDict]
    keys = orderedRuns[0].keys()
    for run in orderedRuns[1:]:
        if run.keys() != keys:
            print('data fields between different runs are not identical. '
                  'Cannot combine runs.')
            return

    result = xr.concat(orderedRuns, dim='trainId')
    for k in orderedRuns[0].attrs.keys():
        result.attrs[k] = [run.attrs[k] for run in orderedRuns]
    return result


def load_bpt(run, merge_with=None, run_mnemonics=None):
    """
    Load the bunch pattern table of a run, reusing the one from the
    merge_with dataset if it is already present there.
    """
    if run_mnemonics is None:
        run_mnemonics = mnemonics_for_run(run)
    for key in ['bunchPatternTable', 'bunchPatternTable_SA3']:
        if bool(merge_with) and key in merge_with:
            log.debug(f'Using {key} from merge_with dataset.')
            return merge_with[key]
        if key in run_mnemonics:
            bpt = run.get_array(*run_mnemonics[key].values(),
                                name='bunchPatternTable')
            log.debug(f'Loaded {key} from DataCollection.')
            return bpt
    log.debug('Could not find bunch pattern table.')
    return None
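
# A minimal sketch of combining two scans with `concatenateRuns` (run
# numbers are illustrative). The datasets must share the same data
# variables; the input order does not matter, as the runs are sorted by
# their first trainId:
#
#     import toolbox_scs as tb
#     _, ds1 = tb.load(2212, 208, ['SCS_SA3'])
#     _, ds2 = tb.load(2212, 209, ['SCS_SA3'])
#     combined = tb.concatenateRuns([ds2, ds1])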