import argparse
from datetime import datetime
import nbconvert
import nbformat
from nbparameterise import (
    extract_parameters, replace_definitions, parameter_values
)
import os
import re
from subprocess import Popen, PIPE, check_output
import sys
from uuid import uuid4
import warnings

# settings
notebooks = {
             "AGIPD": {
                       "DARK": {
                               "notebook": "AGIPD/Characterize_AGIPD_Gain_Darks_NBC.ipynb",
                               "concurrency": (None, None),                                                                
                               },
                       "PC":   {
                               "notebook": "AGIPD/Chracterize_AGIPD_Gain_PC_NBC.ipynb",
                               "concurrency": ("modules", 16),                                                                
                               },
                       "FF":   {
                               "notebook": "AGIPD/Characterize_AGIPD_Gain_FlatFields_NBC.ipynb",
                               "concurrency": ("modules", 16),                                                                
                               },
                       "CORRECT":   {
                               "notebook": "AGIPD/AGIPD_Correct_and_Verify.ipynb",
                               "concurrency": (None, None),                                                                
                               },
                       "COMBINE":   {
                               "notebook": "AGIPD/AGIPD_Characterize_Gain_Combine_NBC.ipynb",
                               "concurrency": (None, None),                                                                
                               },
                      },
             "LPD": {
                       "DARK": {
                               "notebook": "LPD/LPDChar_Darks_NBC.ipynb",
                               "concurrency": (None, None),                                                                
                               },
                       "PC":   {
                               "notebook": "LPD/Characterize_LPD_GAIN_CI_per_pixel_NBC.ipynb",
                               "concurrency": ("modules", 16),                                                                
                               },
                       "CORRECT":   {
                               "notebook": "LPD/LPD_Correct_and_Verify.ipynb",
                               "concurrency": (None, None),                                                                
                               },
                      },
             
            }

temp_path = os.getcwd()
slurm_partion = "exfel"
python_path = "python"
ipython_path = "ipython"
jupyter_path = "jupyter"
karabo_activate_path = "/gpfs/exfel/data/user/haufs/karabo/activate"
ipcluster_path = "/gpfs/exfel/data/user/haufs/karabo/extern/bin/ipcluster"
report_path = "{}/calibration_reports/".format(os.getcwd())
try_report_to_output = True

# logic
parser = argparse.ArgumentParser(description="Main entry point "
                                             "for offline calibration",
                                formatter_class=argparse.MetavarTypeHelpFormatter)

parser.add_argument('detector', metavar='DETECTOR', type=str,
                    help='The detector to calibrate')

parser.add_argument('type', metavar='TYPE', type=str,
                    help='Type of calibration: '+",".join(notebooks.keys()))

def make_intelli_list(ltype):
    class IntelliListAction(argparse.Action):
        ltype = ltype
        def __init__(self, *args, **kwargs):
            super(IntelliListAction, self).__init__(*args, **kwargs)            

        def __call__(self, parser, namespace, values, option_string=None):
            
            parsed_values = []
            values = ",".join(values)
            if isinstance(values, str):
                for rcomp in values.split(","):
                    if "-" in rcomp:
                        start, end = rcomp.split("-")
                        parsed_values += list(range(int(start), int(end)))
                    else:
                        parsed_values += [int(rcomp)]
            elif isinstance(values, (list, tuple)):
                parsed_values = values
            else:
                parsed_values = [values,]
            parsed_values = [self.ltype(p) for p in parsed_values]
            print("Parsed input {} to {}".format(values, parsed_values))
            setattr(namespace, self.dest, parsed_values)
    return IntelliListAction

        
def consolize_name(name):
    return name.replace("_", "-")

def deconsolize_args(args):
    new_args = {}
    for k, v in args.items():
        new_args[k.replace("-", "_")] = v
    return new_args

def extract_title_author_version(nb):
    def find_first_markdown(nb):
        for cell in nb.cells:
            if cell.cell_type == 'markdown':
                return cell
    first_md = find_first_markdown(nb)
    source = first_md["source"]
    title = re.findall(r'\#+\s*(.*)\s*\#+', source)
    author = re.findall(r'author[\s]*[:][\s]*(.*?)\s*(?:[,?]|version)', source, flags=re.IGNORECASE)
    version = re.findall(r'version[\s]*:\s*(.*)', source, flags=re.IGNORECASE)
    
    title = title[0] if len(title) else None
    author = author[0] if len(author) else None
    version = version[0] if len(version) else None
    return title, author, version
    

if len(sys.argv) >= 3:
    detector = sys.argv[1].upper()
    caltype = sys.argv[2].upper()
    try:
        notebook = notebooks[detector][caltype]["notebook"]
        cvar = notebooks[detector][caltype].get("concurrency", (None, None))[0]
    except KeyError:
        print("Not one of the known calibrations or detectors")
        exit()
    with open(notebook, "r") as f:
        nb = nbformat.read(f, as_version=4)
        parms = extract_parameters(nb)
        for p in parms:
            helpstr = ("Default: %(default)s" if not p.comment 
                       else "{}. Default: %(default)s".format(p.comment.replace("#", " ").strip()))
            required = p.comment is not None and "required" in p.comment
            
            if p.type == list:
                if len(p.value):
                    ltype = type(p.value[0])
                else:
                    ltype = str
                range_allowed = "RANGE ALLOWED" in p.comment.upper() if p.comment else False
                parser.add_argument("--{}".format(consolize_name(p.name)),
                                    nargs='+',
                                    type=ltype if not range_allowed else str,
                                    default=p.value if (not required) and p.name != cvar else None,
                                    help=helpstr,
                                    required=required and p.name != cvar,
                                    action=make_intelli_list(ltype) if range_allowed else None)
            elif p.type == bool:
                parser.add_argument("--{}".format(consolize_name(p.name)),
                                    action="store_true",
                                    default=p.value if not required else None,
                                    help=helpstr,
                                    required=required)
            else:
                parser.add_argument("--{}".format(consolize_name(p.name)),
                                    type=p.type,
                                    default=p.value if not required else None,
                                    help=helpstr,
                                    required=required)

                
def has_parm(parms, name):
    for p in parms:
        if p.name == name:
            return True
    return False

def concurrent_run(temp_path, nb, nbname, args, cparm=None, cval=None, final_job=False, job_list=[], fmtcmd=""):
    if cparm is not None:
        args[cparm] = cval
        
    suffix = "_".join([str(v) for v in cval]) if isinstance(cval, list) else cval
    if "cluster_profile" in args:
        args["cluster_profile"] = "{}_{}".format(args["cluster_profile"], suffix)
        
    # first convert the notebook
    params = parameter_values(parms, **args)
    new_nb = replace_definitions(nb, params, execute=False)
    base_name = nbname.replace(".ipynb", "")
    
    new_name = "{}_{}.ipynb".format(os.path.basename(base_name), suffix)
    
    
    
    nbpath = "{}/{}".format(temp_path, new_name)
    with open(nbpath, "w") as f:
        f.write(nbconvert.exporters.export(nbconvert.NotebookExporter, new_nb)[0])
        
    
    # add finalization to the last job
    if final_job:
        import stat
        with open("{}/finalize.sh".format(temp_path), "w") as finfile:
                finfile.write("#!/bin/tcsh\n")
                finfile.write("module load texlive/2017\n")
                finfile.write("echo 'Running finalize script'\n")           
                finfile.write("python3 -c {}\n".format(fmtcmd.format(joblist=job_list)))
        os.chmod("{}/finalize.sh".format(temp_path), stat.S_IXUSR | stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
    # then run an sbatch job
    srun_base = ["sbatch", "-p", slurm_partion, "-t", "24:00:00",
                 "--mem", "500G", "--mail-type", " END,FAIL"]
    
    srun_base += [os.path.abspath("{}/slurm_calibrate.sh".format(os.getcwd())), # path to helper sh
                  os.path.abspath(nbpath), # path to notebook
                  python_path, # path to python
                  ipython_path, # path to ipython
                  jupyter_path, # path to jupyter                  
                  ipcluster_path, # path to ipcluster
                  karabo_activate_path if karabo_activate_path else "NO_KRB_ACTIVATE", # karabo activate path
                  args.get("cluster_profile", "NO_CLUSTER"),
                  '"{}"'.format(base_name.upper()),
                  '"{}"'.format(args["detector"].upper()),
                  '"{}"'.format(args["type"].upper()),
                  "FINAL" if final_job else "NONFINAL",
                  "{}/finalize.sh".format(os.path.abspath(temp_path))]
                              
    output = check_output(srun_base).decode('utf8')
    jobid = None
    for line in output.split("\n"):
        if "Submitted batch job " in line:
            jobid = line.split(" ")[3]
    print("Submitted job: {}".format(jobid))
    return jobid
                
    
def run():
    args = deconsolize_args(vars(parser.parse_args()))
    detector = args["detector"].upper()
    caltype = args["type"].upper()
    try:
        notebook = notebooks[detector][caltype]["notebook"]
        concurrency = notebooks[detector][caltype].get("concurrency", None)
        version = notebooks[detector][caltype].get("version", "NA")
        author = notebooks[detector][caltype].get("author", "anonymous")
        
    except KeyError:
        print("Not one of the known calibrations or detectors")
        return
    with open(notebook, "r") as f:
        nb = nbformat.read(f, as_version=4)
        parms = extract_parameters(nb)
        title, author, version = extract_title_author_version(nb)
        
        if not title:
            title = "{}: {} Calibration".format(detector, caltype)
        if not author:
            author = "anonymous"
        if not version:
            version = ""
        
        run_uuid = uuid4()
        
        # check that a modules field is present if we run concurrently
        if not has_parm(parms,  concurrency[0]) and concurrency[0] is not None:
            msg = "Notebook cannot be run concurrently: no {} parameter".format(concurrency[0])
            warnings.warn(msg, RuntimeWarning)
            
        if not has_parm(parms, "cluster_profile"):
            warnings.warn("Notebook has no cluster_profile parameter, "+
                          "running on cluster will likeyl fail!", RuntimeWarning)
        elif "cluster_profile" not in args or args["cluster_profile"] == parser.get_default('cluster_profile'):
            args["cluster_profile"] = "slurm_prof_{}".format(run_uuid)

        # create a temporary output directory to work in        
        run_tmp_path = "{}/slurm_tmp_{}".format(temp_path, run_uuid)
        os.makedirs(run_tmp_path)
        
        # wait on all jobs to run and then finalize the run by creating a report from the notebooks
        out_path = "{}/{}/{}/{}".format(report_path, detector.upper(), caltype.upper(), datetime.now().isoformat())
        if try_report_to_output:
            if "out_folder" in args:
                out_path = args["out_folder"]
                if "run" in args:
                    rr = args["run"]
                    if isinstance(rr, int):
                        out_path = "{}/r{:04d}/".format(out_path, rr)
                    else:
                        out_path = "{}/{}/".format(out_path, rr)
            else:
                print("No 'out_folder' defined as argument, outputting to '{}' instead.".format(out_path))
        else:
            os.makedirs(out_path)
        cmd = ('"from cal_tools.cal_tools import finalize; ' +
               'finalize({{joblist}}, \'{run_path}\', \'{out_path}\', '+
               '\'{project}\', \'{calibration}\', \'{author}\', \'{version}\')"')
        
        fmtcmd = cmd.format(run_path=run_tmp_path, out_path=out_path,
                            project=title, calibration=title,
                            author=author, version=version)
        
        joblist = []
        if concurrency[0] is None:
            jobid = concurrent_run(run_tmp_path, nb, os.path.basename(notebook), args,
                                   final_job=True, job_list=joblist, fmtcmd=fmtcmd)
            joblist.append(jobid)
        else:
            cvar = concurrency[0]
            cvals = args.get(cvar, None)

            if cvals is None:
                cvals = range(concurrency[1])
            
            for cnum, cval in enumerate(cvals):
                jobid = concurrent_run(run_tmp_path, nb, notebook, args,
                                       cvar, [cval,], cnum==len(list(cvals))-1, joblist, fmtcmd)
                joblist.append(jobid)
                
        print("Submitted the following SLURM jobs: {}".format(",".join(joblist)))
              
        
        
        #srun_base = ["sbatch", "-p", "exfel", "-t", "24:00:00",
        #             "--chdir", os.getcwd(), "{}/finalize.sh".format(run_tmp_path)]

        #Popen(srun_base).wait()
if __name__ == "__main__":
    run()