#!/usr/bin/env python

import argparse
import copy
from datetime import datetime
import nbconvert
import nbformat
from nbparameterise import (
    extract_parameters, replace_definitions, parameter_values
)
import os
import pprint
import re
from subprocess import Popen, PIPE, check_output
import sys
from uuid import uuid4
import warnings
from .settings import *
from .notebooks import notebooks
from jinja2 import Template
import textwrap

# Help formatter combining raw description text (so the hand-built
# epilog below keeps its layout), type names as metavars, and automatic
# "(default: ...)" suffixes in option help strings. Base-class order
# fixes the MRO and therefore which formatting methods win.
class RawTypeFormatter(argparse.RawDescriptionHelpFormatter,
                       argparse.MetavarTypeHelpFormatter,
                       argparse.ArgumentDefaultsHelpFormatter):
    pass


# The argument parser for calibrate.py, will be extended depending
# on the options given.

def make_initial_parser():
    """ Create the base argument parser for the calibration entry point.

    The returned parser is later extended with notebook-specific
    options depending on the detector and calibration type requested
    on the command line.

    :return: a fresh argparse.ArgumentParser
    """
    ap = argparse.ArgumentParser(
        description="Main entry point for offline calibration",
        formatter_class=RawTypeFormatter)

    ap.add_argument('detector', metavar='DETECTOR', type=str,
                    help='The detector to calibrate')
    ap.add_argument('type', metavar='TYPE', type=str,
                    help='Type of calibration: ' + ",".join(notebooks.keys()))
    ap.add_argument('--no-cluster-job', action="store_true", default=False,
                    help="Do not run as a cluster job")
    ap.add_argument('--report-to', type=str,
                    help='Filename (and optionally path) for output report')
    ap.add_argument('--priority', type=int, default=2,
                    help="Priority of batch jobs. If priority<=1, reserved "
                         "nodes become available.")
    ap.add_argument('--vector-figs', action="store_true", default=False,
                    help="Use vector graphics for figures in the report.")

    # Adds a third action group (index 2); do_parse() below relies on
    # parser._action_groups[2] being this "required arguments" group.
    ap.add_argument_group('required arguments')

    return ap


# Module-level parser instance; extended below depending on sys.argv.
parser = make_initial_parser()

# Helper functions for parser extensions

def make_intelli_list(ltype):
    """ Build an argparse action parsing range and comma expressions.

    An expression of the form "1-5,6" will be parsed into the following
    list: [1,2,3,4,6] (ranges are end-exclusive). Every parsed element
    is finally converted with `ltype`.

    :param ltype: type each parsed element is cast to (e.g. int)
    :return: an argparse.Action subclass to pass as ``action=``
    """
    class IntelliListAction(argparse.Action):

        def __call__(self, parser, namespace, values, option_string=None):
            # `values` is the list of strings collected by nargs='+';
            # re-join so "1-5 6" and "1-5,6" parse identically. After the
            # join `values` is always a str, so the former isinstance
            # branches for list/tuple were dead code and are removed.
            values = ",".join(values)
            parsed_values = []
            for rcomp in values.split(","):
                if "-" in rcomp:
                    # Range expression, e.g. "1-5" -> 1,2,3,4
                    start, end = rcomp.split("-")
                    parsed_values += list(range(int(start), int(end)))
                else:
                    parsed_values += [int(rcomp)]
            parsed_values = [self.ltype(p) for p in parsed_values]
            print("Parsed input {} to {}".format(values, parsed_values))
            setattr(namespace, self.dest, parsed_values)

    # Attach the element type to the generated class.
    IntelliListAction.ltype = ltype
    return IntelliListAction


def consolize_name(name):
    """ Translate a parameter name to its console form (dashes). """
    return "-".join(name.split("_"))


def deconsolize_args(args):
    """ Translate console option names back to variable names. """
    return {key.replace("-", "_"): val for key, val in args.items()}


def extract_title_author_version(nb):
    """ Tries to extract title, author and versions from markdown

    All three fields are looked up in the notebook's first markdown
    cell; any that is not found comes back as None.

    :param nb: jupyter notebook
    :return: (title, author, version) tuple of strings or None
    """
    first_md = first_markdown_cell(nb)
    source = first_md["source"]

    def _first_hit(pattern, **kwargs):
        # First regex match in the header cell, or None if absent.
        hits = re.findall(pattern, source, **kwargs)
        return hits[0] if hits else None

    title = _first_hit(r'\#+\s*(.*)\s*\#+')
    author = _first_hit(
        r'author[\s]*[:][\s]*(.*?)\s*(?:[,?]|version)', flags=re.IGNORECASE)
    version = _first_hit(r'version[\s]*:\s*(.*)', flags=re.IGNORECASE)
    return title, author, version


def get_cell_n(nb, cell_type, cell_n):
    """
    Return notebook cell with given number and given type

    :param nb: jupyter notebook
    :param cell_type: cell type, 'code' or 'markdown'
    :param cell_n: cell number (count from 0)
    :return: notebook cell, or None if there are not enough matching cells
    """
    counter = 0
    for cell in nb.cells:
        if cell.cell_type == cell_type:
            if counter == cell_n:
                return cell
            # BUG FIX: was `counter=+1` (assigning +1 each time), which
            # made every cell_n >= 2 unreachable; increment instead.
            counter += 1
    return None


def first_code_cell(nb):
    """ Return the first code cell of a notebook """
    return get_cell_n(nb, cell_type='code', cell_n=0)


def first_markdown_cell(nb):
    """ Return the first markdown cell of a notebook """
    return get_cell_n(nb, cell_type='markdown', cell_n=0)


def make_epilog(nb, caltype=None):
    """ Make an epilog from the notebook to add to parser help

    The first markdown cell of the notebook provides the text. If
    `caltype` is given, its name is shown in a left column and the
    description lines are indented below the first one.

    :param nb: jupyter notebook
    :param caltype: calibration type name for the left column (optional)
    :return: the formatted epilog string
    """
    msg = ""
    header_cell = first_markdown_cell(nb)
    lines = header_cell.source.split("\n")
    if caltype:
        # Left-align the type name in a 15-character column.
        msg += "{:<15}  {}".format(caltype, lines[0]) + "\n"
    else:
        msg += "{}".format(lines[0]) + "\n"
    pp = pprint.PrettyPrinter(indent=(17 if caltype else 0))
    if len(lines[1:]):
        # pformat renders the remaining lines as an indented list; the
        # replace/reverse steps below strip the list syntax (quotes,
        # commas) that pformat adds, leaving only indented text.
        plines = pp.pformat(lines[1:])[1:-1].split("\n")
        for line in plines:
            sline = line.replace("'", "", 1)
            sline = sline.replace("', '", " "*(17 if caltype else 0), 1)
            # Drop the trailing quote: reverse, remove first quote, reverse back.
            sline = sline[::-1].replace("'", "", 1)[::-1]
            sline = sline.replace(" ,", " ")
            if len(sline) > 1 and sline[0] == ",":
                sline = sline[1:]
            msg += sline + "\n"
    msg += "\n"
    return msg


def get_notebook_function(nb, fname):
    """ Return the source of function `fname` defined in a notebook.

    Scans the code cells for a line matching ``def fname(...):`` and
    collects the following lines while they stay at (or beyond) the
    indentation of the first body line.

    :param nb: jupyter notebook
    :param fname: name of the function to look for
    :return: the function source as a string, or None if not found
    """
    # NOTE: the redundant local `import re` was removed; `re` is already
    # imported at module level.
    flines = []
    def_found = False
    indent = None
    for cell in nb.cells:
        if cell.cell_type == 'code':
            lines = cell.source.split("\n")
            for line in lines:

                if def_found:
                    lin = len(line) - len(line.lstrip())
                    if indent is None:
                        if lin != 0:
                            # First indented line fixes the body's level.
                            indent = lin
                            flines.append(line)
                    elif lin >= indent:
                        flines.append(line)
                    else:
                        # Dedent: the function body has ended.
                        return "\n".join(flines)

                if re.search(r"def\s+{}\(.*\):\s*".format(fname), line) and not def_found:
                    def_found = True
                    flines.append(line)
    # BUG FIX: a function defined at the very end of the last code cell
    # used to fall through and return None; return what was collected.
    if def_found and flines:
        return "\n".join(flines)
    return None


# extend the parser according to user input
# the first case is if a detector was given, but no calibration type
if len(sys.argv) == 3 and "-h" in sys.argv[2]:
    detector = sys.argv[1].upper()
    try:
        det_notebooks = notebooks[detector]
    except KeyError:
        print("Not one of the known detectors: {}".format(notebooks.keys()))
        exit()

    msg = "Options for detector {}\n".format(detector)
    msg += "*"*len(msg)+"\n\n"

    # basically, this creates help in the form of
    #
    # TYPE        some description that is
    #             indented for this type.
    #
    # The information is extracted from the first markdown cell of
    # the notebook.
    for caltype, notebook in det_notebooks.items():
        nbpath = os.path.abspath(
            "{}/{}".format(os.path.dirname(__file__), notebook["notebook"]))
        with open(nbpath, "r") as f:
            nb = nbformat.read(f, as_version=4)
            msg += make_epilog(nb, caltype=caltype)

    parser.epilog = msg
# second case is if no detector was given either
elif len(sys.argv) == 2 and "-h" in sys.argv[1]:
    epilog = "Available detectors are: {}".format(
        ", ".join([k for k in notebooks.keys()]))
    parser.epilog = epilog
# final case: a detector and type was given. We derive the arguments
# from the corresponding notebook
elif len(sys.argv) >= 3:
    detector = sys.argv[1].upper()
    caltype = sys.argv[2].upper()
    try:
        notebook = os.path.abspath(
            "{}/{}".format(os.path.dirname(__file__), notebooks[detector][caltype]["notebook"]))
        # cvar is the name of the concurrency parameter; it is read by
        # do_parse() below as a closure variable.
        cvar = notebooks[detector][caltype].get("concurrency",
                                                {"parameter": None,
                                                 "default concurrency": None,
                                                 "cluster cores": 8})["parameter"]
    except KeyError:
        print("Not one of the known calibrations or detectors")
        exit()
    with open(notebook, "r") as f:
        nb = nbformat.read(f, as_version=4)

        ext_func = notebooks[detector][caltype].get("extend parms", None)

        # Add one console option per notebook parameter. Parameters whose
        # comment contains "required" go into the extra action group
        # created in make_initial_parser() (unless overwrite_reqs is set).
        def do_parse(nb, parser, overwrite_reqs=False):
            parser.description = make_epilog(nb)
            parms = extract_parameters(nb)

            for p in parms:

                helpstr = ("Default: %(default)s" if not p.comment
                           else "{}. Default: %(default)s".format(p.comment.replace("#", " ").strip()))
                required = (p.comment is not None
                            and "required" in p.comment
                            and not overwrite_reqs
                            and p.name != cvar)

                # This may be not a public API
                # May require reprogramming in case of argparse updates
                pars_group = parser._action_groups[2 if required else 1]

                default = p.value if (not required) else None

                # List-typed parameters (and the concurrency parameter)
                # become nargs='+' options; the element type is inferred
                # from the first default value.
                if p.type == list or p.name == cvar:
                    if p.type is list:
                        try:
                            ltype = type(p.value[0])
                        except:
                            print(
                                "List '{}' is empty. Parameter type can not be defined.".format(p.name))
                            print("See first code cell in jupyter-notebook: '{}'".format(
                                notebooks[detector][caltype]["notebook"]))
                            exit()
                    else:
                        ltype = p.type

                    # "range allowed" in the comment enables "1-5,6" syntax
                    # via the IntelliListAction.
                    range_allowed = "RANGE ALLOWED" in p.comment.upper() if p.comment else False
                    pars_group.add_argument("--{}".format(consolize_name(p.name)),
                                            nargs='+',
                                            type=ltype if not range_allowed else str,
                                            default=default,
                                            help=helpstr,
                                            required=required,
                                            action=make_intelli_list(ltype) if range_allowed else None)
                elif p.type == bool:
                    pars_group.add_argument("--{}".format(consolize_name(p.name)),
                                            action="store_true",
                                            default=default,
                                            help=helpstr,
                                            required=required)

                else:
                    pars_group.add_argument("--{}".format(consolize_name(p.name)),
                                            type=p.type,
                                            default=default,
                                            help=helpstr,
                                            required=required)

        do_parse(nb, parser, True)

        # extend parameters if needed
        ext_func = notebooks[detector][caltype].get("extend parms", None)
        if ext_func is not None:
            func = get_notebook_function(nb, ext_func)

            if func is None:
                warnings.warn("Didn't find concurrency function {} in notebook".format(ext_func),
                              RuntimeWarning)

            else:
                # remove help calls as they will cause the argument parser to exit
                add_help = False
                if "-h" in sys.argv:
                    sys.argv.remove("-h")
                    add_help = True
                if "--help" in sys.argv:
                    sys.argv.remove("--help")
                    add_help = True
                known, remaining = parser.parse_known_args()
                if add_help:
                    sys.argv.append("--help")
                args = deconsolize_args(vars(known))

                df = {}

                # NOTE(review): executes code extracted from the notebook;
                # acceptable for trusted local notebooks, unsafe for
                # untrusted input.
                exec(func, df)
                f = df[ext_func]
                import inspect
                sig = inspect.signature(f)
                callargs = []
                for arg in sig.parameters:
                    callargs.append(args[arg])

                # Append the generated parameter code to the first code
                # cell and re-parse with a fresh parser.
                extention = f(*callargs)
                fcc = first_code_cell(nb)
                fcc["source"] += "\n"+extention
                parser = make_initial_parser()
                do_parse(nb, parser, False)


def has_parm(parms, name):
    """ Check if a parameter of `name` exists in parms """
    return any(p.name == name for p in parms)


def flatten_list(l):
    """ Flatten a (possibly nested) list into an underscore-joined string.

    Non-list values are returned unchanged.
    """
    if isinstance(l, list):
        return "_".join(str(flatten_list(item)) for item in l)
    return l


def set_figure_format(nb, enable_vector_format):
    """
    Set svg format in inline backend for figures

    If `enable_vector_format` is True, svg output is enabled for the
    matplotlib inline backend, so the rendered notebook (and hence the
    report) uses vector-graphics figures.

    :param nb: jupyter notebook
    :param enable_vector_format: whether to switch figures to svg
    """
    if not enable_vector_format:
        return
    # Append the backend config to the second code cell of the notebook.
    cell = get_cell_n(nb, 'code', 1)
    cell.source += "\n%config InlineBackend.figure_formats = ['svg']\n"


def concurrent_run(temp_path, nb, nbname, args, cparm=None, cval=None,
                   final_job=False, job_list=None, fmtcmd="", cluster_cores=8,
                   sequential=False, priority=2):
    """ Launch a concurrent job on the cluster via SLURM

    :param temp_path: directory used for temporary job files
    :param nb: the notebook to parameterise and run
    :param nbname: file name of the notebook
    :param args: dict of parameter values substituted into the notebook
    :param cparm: name of the concurrency parameter, or None
    :param cval: value(s) of the concurrency parameter for this job
    :param final_job: if True, also write the finalize.sh script
    :param job_list: ids of previously submitted jobs, used by finalize
    :param fmtcmd: format template of the finalization python command
    :param cluster_cores: number of cluster cores to request
    :param sequential: if True, run locally instead of submitting via sbatch
    :param priority: batch priority; <= 1 may use reserved nodes
    :return: the SLURM job id as a string, or None when sequential
    """
    # BUG FIX: `job_list` used to default to a shared mutable list ([]),
    # which persists across calls; use a None sentinel instead.
    if job_list is None:
        job_list = []

    if cparm is not None:
        args[cparm] = cval

    suffix = flatten_list(cval)
    if "cluster_profile" in args:
        # Make the ipcluster profile name unique per concurrent job.
        args["cluster_profile"] = "{}_{}".format(
            args["cluster_profile"], suffix)

    # first convert the notebook
    parms = extract_parameters(nb)
    params = parameter_values(parms, **args)
    new_nb = replace_definitions(nb, params, execute=False)
    set_figure_format(new_nb, args["vector_figs"])
    base_name = nbname.replace(".ipynb", "")
    new_name = "{}__{}__{}.ipynb".format(
        os.path.basename(base_name), cparm, suffix)

    nbpath = "{}/{}".format(temp_path, new_name)
    with open(nbpath, "w") as f:
        f.write(nbconvert.exporters.export(
            nbconvert.NotebookExporter, new_nb)[0])

    # add finalization to the last job
    if final_job:
        import stat
        with open("{}/finalize.sh".format(temp_path), "w") as finfile:
            finfile.write("#!/bin/tcsh\n")
            finfile.write("module load texlive/2017\n")
            finfile.write("if [[ `which python` != *'karabo'* ]];\n")
            finfile.write("then module load anaconda/3; fi\n")
            finfile.write("echo 'Running finalize script'\n")
            finfile.write(
                "python3 -c {}\n".format(fmtcmd.format(joblist=job_list)))
        # rwx for the user, read-only for group and others.
        all_stats = stat.S_IXUSR | stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH
        os.chmod("{}/finalize.sh".format(temp_path), all_stats)
    # then run an sbatch job
    if not sequential:
        # calculate number of general nodes available
        free = int(check_output(free_nodes_cmd, shell=True).decode('utf8'))
        preempt = int(check_output(
            preempt_nodes_cmd, shell=True).decode('utf8'))
        # Use the reservation only when the general partition is congested
        # and this job has high priority.
        if free + preempt >= max_reserved or priority > 1 or reservation == "":
            srun_base = launcher_command.format(
                temp_path=temp_path) + " -p {}".format(sprof)
            srun_base = srun_base.split()
        else:
            srun_base = launcher_command.format(
                temp_path=temp_path) + " --reservation={}".format(reservation)
            srun_base = srun_base.split()
            print(" ".join(srun_base))
    else:
        srun_base = []

    srun_base += [os.path.abspath("{}/bin/slurm_calibrate.sh".format(os.path.dirname(__file__))),  # path to helper sh
                  os.path.abspath(nbpath),  # path to notebook
                  python_path,  # path to python
                  ipython_path,  # path to ipython
                  jupyter_path,  # path to jupyter
                  ipcluster_path,  # path to ipcluster
                  # karabo activate path
                  karabo_activate_path if karabo_activate_path else "NO_KRB_ACTIVATE",
                  args.get("cluster_profile", "NO_CLUSTER"),
                  '"{}"'.format(base_name.upper()),
                  '"{}"'.format(args["detector"].upper()),
                  '"{}"'.format(args["type"].upper()),
                  "FINAL" if final_job else "NONFINAL",
                  ". {}/finalize.sh".format(os.path.abspath(temp_path)),
                  str(cluster_cores)]

    output = check_output(srun_base).decode('utf8')
    jobid = None
    if not sequential:
        # Parse the sbatch acknowledgement for the job id.
        for line in output.split("\n"):
            if "Submitted batch job " in line:
                jobid = line.split(" ")[3]
        print("Submitted job: {}".format(jobid))
    return jobid


def make_par_table(parms, temp_path, run_uuid):
    """
    Create a table with input parameters of the notebook

    The table is written as a LaTeX tabular inside an rst file, to be
    included into the final report.

    :param parms: parameters of the notebook
    :param temp_path: path to temporary directory for job outputs
    :param run_uuid: inset of folder name containing job output
    """

    # Add space in long strings to wrap them in latex
    def split_len(seq, length):
        l = [seq[i:i + length] for i in range(0, len(seq), length)]
        return ' '.join(l)

    # Prepare strings and estimate their length
    l_parms = []
    len_parms = [0, 0]
    max_len = [30, 30]
    for p in parms:
        name = p.name.replace('_', '-')
        if len(name) > max_len[0]:
            len_parms[0] = max_len[0]
            name = split_len(name, max_len[0])

        value = str(p.value)
        if len(value) > max_len[1]:
            len_parms[1] = max_len[1]
            value = split_len(value, max_len[1])
        if p.type is str:
            # LaTeX-style quoting for string values.
            value = "``{}''".format(value)
        value = value.replace('_', '\\_')
        # [1:] drops the leading '#' of the parameter comment.
        # NOTE(review): when p.comment is None this yields "one"
        # (str(None)[1:]) — verify against nbparameterise output.
        comment = str(p.comment)[1:].replace('_', '\\_')
        l_parms.append([name, value, comment])

    # Fix column width if needed: switch wrapped columns from
    # 'l'/'c' to a fixed-width 'p' column.
    col_type = ['l', 'c', 'p{.3\\textwidth}']
    if len_parms[0] == max_len[0]:
        col_type[0] = col_type[2]
    if len_parms[1] == max_len[1]:
        col_type[1] = col_type[2]

    tmpl = Template('''
                    Input of the calibration pipeline 
                    =================================
                    
                    .. math::
                    
                        \\begin{tabular}{ {% for k in p %}{{ k }}{%- endfor %}  } 
                        \hline
                        {% for k in lines %}
                        {{ k[0] }} & {{ k[1] }} & {{ k[2] }} \\\\
                        {%- endfor %}
                        \hline
                        \end{tabular}
                    ''')

    f_name = "{}/slurm_tmp_{}/InputParameters.rst".format(temp_path, run_uuid)
    with open(f_name, "w") as finfile:
        finfile.write(textwrap.dedent(tmpl.render(p=col_type, lines=l_parms)))


def run():
    """ Run a calibration task with parser arguments

    Reads the requested notebook, optionally extends its parameters via
    a notebook-defined function, writes the input-parameter table and
    submits one SLURM job per concurrency value (or a single job).
    """

    args = deconsolize_args(vars(parser.parse_args()))
    detector = args["detector"].upper()
    caltype = args["type"].upper()
    sequential = args["no_cluster_job"]
    priority = int(args['priority'])

    if sequential:
        print("Not running on cluster")

    try:
        notebook = notebooks[detector][caltype]["notebook"]
        notebook = os.path.abspath(
            "{}/{}".format(os.path.dirname(__file__), notebook))
        concurrency = notebooks[detector][caltype].get("concurrency", None)
        version = notebooks[detector][caltype].get("version", "NA")
        author = notebooks[detector][caltype].get("author", "anonymous")

    except KeyError:
        print("Not one of the known calibrations or detectors")
        return
    with open(notebook, "r") as f:
        nb = nbformat.read(f, as_version=4)

        # extend parameters if needed
        ext_func = notebooks[detector][caltype].get("extend parms", None)
        if ext_func is not None:
            func = get_notebook_function(nb, ext_func)

            if func is None:
                warnings.warn("Didn't find concurrency function {} in notebook".format(ext_func),
                              RuntimeWarning)

            else:
                # remove help calls as they will cause the argument parser to exit
                known, remaining = parser.parse_known_args()
                args = deconsolize_args(vars(known))
                df = {}
                # NOTE(review): executes code extracted from the notebook;
                # acceptable for trusted local notebooks only.
                exec(func, df)
                f = df[ext_func]
                import inspect
                sig = inspect.signature(f)
                callargs = []
                for arg in sig.parameters:
                    callargs.append(args[arg])

                # Append the generated parameter code to the first code cell.
                extention = f(*callargs)
                fcc = first_code_cell(nb)
                fcc["source"] += "\n"+extention

        parms = extract_parameters(nb)

        title, author, version = extract_title_author_version(nb)

        # Fall back to generic metadata when the notebook header
        # does not provide it.
        if not title:
            title = "{}: {} Calibration".format(detector, caltype)
        if not author:
            author = "anonymous"
        if not version:
            version = ""

        title = title.rstrip()

        run_uuid = uuid4()

        # check that a modules field is present if we run concurrently
        # NOTE(review): if `concurrency` is None (no "concurrency" entry in
        # notebooks.py) this subscript raises TypeError; also the None check
        # runs after has_parm() due to short-circuit order — verify intended.
        if not has_parm(parms,  concurrency["parameter"]) and concurrency["parameter"] is not None:
            msg = "Notebook cannot be run concurrently: no {} parameter".format(
                concurrency["parameter"])
            warnings.warn(msg, RuntimeWarning)

        if not has_parm(parms, "cluster_profile"):
            warnings.warn("Notebook has no cluster_profile parameter, " +
                          "running on cluster will likeyl fail!", RuntimeWarning)
        elif "cluster_profile" not in args or args["cluster_profile"] == parser.get_default('cluster_profile'):
            # Give this run its own ipcluster profile name.
            args["cluster_profile"] = "slurm_prof_{}".format(run_uuid)

        # create a temporary output directory to work in
        run_tmp_path = "{}/slurm_tmp_{}".format(temp_path, run_uuid)
        os.makedirs(run_tmp_path)

        # Write all input parameters to rst file to be included to final report
        parms = parameter_values(parms, **args)
        make_par_table(parms, temp_path, run_uuid)

        # wait on all jobs to run and then finalize the run by creating a report from the notebooks
        out_path = "{}/{}/{}/{}".format(report_path, detector.upper(),
                                        caltype.upper(), datetime.now().isoformat())
        if try_report_to_output:
            if "out_folder" in args:
                out_path = os.path.abspath(args["out_folder"])
                if "run" in args:
                    rr = args["run"]
                    if isinstance(rr, int):
                        out_path = "{}/r{:04d}/".format(out_path, rr)
                    else:
                        out_path = "{}/{}/".format(out_path, rr)
            else:
                print("No 'out_folder' defined as argument, outputting to '{}' instead.".format(
                    out_path))
        else:
            os.makedirs(out_path)
        # Python command executed by finalize.sh after all jobs finish;
        # {{joblist}} stays a placeholder to be filled by concurrent_run.
        cmd = ('"from cal_tools.tools import finalize; ' +
               'finalize({{joblist}}, \'{run_path}\', \'{out_path}\', ' +
               '\'{project}\', \'{calibration}\', \'{author}\', '
               '\'{version}\', \'{report_to}\')"')

        report_to = title.replace(" ", "")
        if args["report_to"] is not None:
            report_to = args["report_to"]

        fmtcmd = cmd.format(run_path=run_tmp_path, out_path=out_path,
                            project=title, calibration=title,
                            author=author, version=version, report_to=report_to)

        joblist = []
        if concurrency.get("parameter", None) is None:
            # No concurrency parameter: run the notebook as a single,
            # final job.
            cluster_cores = concurrency.get("cluster cores", 8)
            jobid = concurrent_run(run_tmp_path, nb, os.path.basename(notebook), args,
                                   final_job=True, job_list=joblist, fmtcmd=fmtcmd,
                                   cluster_cores=cluster_cores, sequential=sequential, priority=priority)
            joblist.append(jobid)
        else:
            cvar = concurrency["parameter"]
            cvals = args.get(cvar, None)
            cluster_cores = concurrency.get("cluster cores", 8)

            con_func = concurrency.get("use function", None)

            if cvals is None:
                # Fall back to the default concurrency from notebooks.py.
                defcval = concurrency.get("default concurrency", None)
                if defcval is not None:
                    print(
                        "Concurrency parameter '{}' is taken from notebooks.py".format(cvar))
                    if not isinstance(defcval, (list, tuple)):
                        cvals = range(defcval)
                    else:
                        cvals = defcval

            if cvals is None:
                # Last resort: use the parameter's default value from the
                # notebook itself.
                print(parms)
                for p in parms:
                    if p.name == cvar:
                        defcval = p.value
                if defcval is not None:
                    print("Concurrency parameter '{}' is taken from '{}'".format(
                        cvar, notebook))
                    if not isinstance(defcval, (list, tuple)):
                        cvals = [defcval]
                    else:
                        cvals = defcval

            if con_func:
                # A notebook-defined function may split the concurrency
                # values further (e.g. into chunks).
                func = get_notebook_function(nb, con_func)
                if func is None:
                    warnings.warn("Didn't find concurrency function {} in notebook".format(con_func),
                                  RuntimeWarning)

                else:
                    df = {}
                    exec(func, df)
                    f = df[con_func]
                    import inspect
                    sig = inspect.signature(f)
                    callargs = []
                    if cvals:

                        # in case default needs to be used for function call
                        args[cvar] = cvals
                    for arg in sig.parameters:
                        callargs.append(args[arg])
                    cvals = f(*callargs)
                    print("Split concurrency into {}".format(cvals))

            # get expected type
            cvtype = list
            for p in parms:
                if p.name == cvar:
                    cvtype = p.type
                    break

            # NOTE(review): len(list(cvals)) inside the loop assumes cvals
            # is re-iterable (not a generator) — confirm for "use function"
            # results. The last job (cnum == len-1) is marked final.
            for cnum, cval in enumerate(cvals):
                jobid = concurrent_run(run_tmp_path, nb, notebook, args,
                                       cvar, [cval, ] if not isinstance(
                                           cval, list) and cvtype is list else cval,
                                       cnum == len(list(cvals)) -
                                       1, joblist, fmtcmd,
                                       cluster_cores=cluster_cores, sequential=sequential, priority=priority)
                joblist.append(jobid)

        if not all([j is None for j in joblist]):
            print("Submitted the following SLURM jobs: {}".format(",".join(joblist)))


# Script entry point: parse arguments and launch the calibration run.
if __name__ == "__main__":
    run()