import argparse
import asyncio
import copy
import glob
import logging
import os
import subprocess
from asyncio.subprocess import PIPE

import yaml
import zmq
import zmq.asyncio
from git import InvalidGitRepositoryError, Repo
from messages import Errors

# Module-level event loop handle, created/fetched at import time.
# NOTE(review): nothing in this file reads this global (the __main__ section
# calls asyncio.get_event_loop() again). On newer Python versions calling
# get_event_loop() with no running loop may emit a DeprecationWarning —
# confirm that no other module imports `loop` before removing it.
loop = asyncio.get_event_loop()


def init_config_repo(config, branch='master'):
    """ Make sure the configuration repository is present and up-to-date.

    If ``config['repo-local']`` is not a git repository yet, it is cloned
    from ``config['figures-remote']``. The repository is then pulled; if the
    pull fails, the current branch is hard-reset to ``origin/<branch>``
    (discarding local changes) and pulled again.

    :param config: the git section of the service configuration; the keys
                   ``repo-local`` and ``figures-remote`` are read.
    :param branch: remote branch used for the hard reset when pulling fails.
                   Defaults to 'master', the previously hard-coded value.
    """
    os.makedirs(config['repo-local'], exist_ok=True)

    try:
        # Check if it is a git repository.
        repo = Repo(config['repo-local'])
    except InvalidGitRepositoryError:
        # Log before cloning so the (possibly long) clone is visible in logs.
        logging.info("Cloning the repository")
        repo = Repo.clone_from(config['figures-remote'],
                               config['repo-local'])
    try:
        # make sure it is updated
        repo.remote().pull()
    except Exception as e:
        logging.error(e)
        # Diverged/dirty local state: align the local head with the remote
        # branch, then pull again. A second failure propagates to the caller.
        repo.remote().fetch()
        repo.git.reset('--hard', 'origin/{}'.format(branch))
        repo.remote().pull()
    logging.info("Config repo is initialized")


async def parse_output(output):
    """ Collect job ids from the text printed by a job submission.

    Every line containing the word 'Submitted' contributes its third
    whitespace-separated token to the result, in order of appearance.

    :param output: decoded stdout of the submission command.
    :return: list of job-id strings.
    """
    submitted = [text.split()[2]
                 for text in output.split('\n')
                 if 'Submitted' in text]
    logging.info('joblist: {}'.format(submitted))
    return submitted


async def wait_jobs(joblist, poll_interval=10):
    """ Wait until none of the given SLURM jobs is listed by ``squeue``.

    ``squeue`` is run without blocking the event loop and polled every
    ``poll_interval`` seconds. A job counts as still queued while its id
    (as a string) appears anywhere in a line of the ``squeue`` output.

    :param joblist: job ids to wait for. An empty list returns immediately.
    :param poll_interval: seconds to sleep between two ``squeue`` calls.
    :raises subprocess.CalledProcessError: if ``squeue`` exits non-zero.
    """
    if not joblist:
        logging.info('Jobs are finished')
        return

    job_ids = [str(job) for job in joblist]
    while True:
        # Run squeue asynchronously: a blocking call here would stall the
        # whole service (including the ZMQ request loop) while polling.
        proc = await asyncio.create_subprocess_exec(
            'squeue', stdout=asyncio.subprocess.PIPE)
        stdout, _ = await proc.communicate()
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, ['squeue'],
                                                output=stdout)

        lines = stdout.decode('utf8').split("\n")
        still_queued = any(job_id in line
                           for line in lines
                           for job_id in job_ids)
        if not still_queued:
            logging.info('Jobs are finished')
            break
        await asyncio.sleep(poll_interval)

async def get_run_base(instr_name, det_name, detector):
    """ Build the ``xfel-calibrate`` argument list for one detector.

    The list starts with ``xfel-calibrate`` followed by the entries of
    ``detector['det-type']`` (which must be a list). Every other key becomes
    a ``--key`` option followed by its value; list values contribute one
    argument per element. The ``out-folder`` value is formatted with the
    ``instrument`` and ``detector`` placeholders. Spaces in values are
    backslash-escaped so the list can be joined with spaces and passed to
    a shell.

    :param instr_name: instrument name, used to format ``out-folder``.
    :param det_name: detector name, used to format ``out-folder``.
    :param detector: the detector's configuration mapping (not modified).
    :return: list of command-line tokens.
    """
    # '+' creates a new list, so detector['det-type'] is never mutated.
    run_base = ['xfel-calibrate'] + detector['det-type']

    for key, item in detector.items():
        if key == 'det-type':
            continue

        run_base.append('--{}'.format(str(key)))
        if isinstance(item, list):
            # '\\ ' (backslash + space): the former '\ ' literal was an
            # invalid escape sequence that only worked by accident.
            run_base.extend(str(val).replace(' ', '\\ ') for val in item)
        else:
            if key == 'out-folder':
                item = item.format(instrument=instr_name, detector=det_name)
            run_base.append(str(item).replace(' ', '\\ '))

    return run_base


async def del_folder(fpath):
    """ Delete a temporary file or folder, e.g. the out-folder of newly
        generated plots or a SLURM log file.

    ``rm -rf`` is executed directly (no shell, so paths containing spaces
    or shell metacharacters are handled safely) and awaited, so the path is
    gone when this coroutine returns. A missing path is not an error.

    :param fpath: the path that needs to be deleted.
    """
    # '--' stops option parsing so a path starting with '-' is not an option.
    proc = await asyncio.create_subprocess_exec("rm", "-rf", "--", fpath)
    await proc.wait()
    if proc.returncode != 0:
        logging.error('could not delete {} (rm exit code {})'.format(
            fpath, proc.returncode))
        return
    logging.info('temp file {} has been deleted'.format(fpath))


async def copy_files(f, path, sem):
    """ Copy a file with ``rsync -a``, optionally under a concurrency limit.

    The rsync process is executed directly (no shell, so paths with spaces
    work) and awaited. When a semaphore is given it is held for the whole
    copy, so it really bounds the number of concurrent rsync processes.
    Failures are logged, not raised.

    :param f: the main file with its current path.
    :param path: the path, where f is copied to.
    :param sem: an ``asyncio.Semaphore`` limiting concurrent copies, or
                None for copying a small number of files (e.g. 1 file).
    """
    async def _rsync():
        try:
            proc = await asyncio.create_subprocess_exec("rsync", "-a", f, path)
        except OSError as err:
            # e.g. rsync not installed; keep copying best-effort.
            logging.error('copying {} to {} failed: {}'.format(f, path, err))
            return
        await proc.wait()
        if proc.returncode != 0:
            logging.error('copying {} to {} failed (rsync exit code {})'
                          .format(f, path, proc.returncode))

    if sem:
        async with sem:
            await _rsync()
    else:
        await _rsync()


async def build_dc_report(dc_folder, report_fmt, timeout=7200):
    """
    Generating a DC report (latex or html) using maxwell nodes.
    With the supported inputs a slurm job is submitted to sphinx-build
    pdf or html depending on mode of the report_service.
    html for prod mode and pdf or html for local mode depending
    on the chosen report_fmt

    The job is submitted with ``sbatch`` (without blocking the event loop)
    and awaited. Its SLURM log file is deleted afterwards; on timeout it is
    kept for inspection and an error is logged.

    :param dc_folder: the local DC folder path with figures and rst files
    :param report_fmt: the expected report format(html or pdf)
    :param timeout: seconds to wait for the SLURM job (default 2 hours).
    :raises subprocess.CalledProcessError: if ``sbatch`` exits non-zero.
    """
    temp_path = "{}/temp/build_dc_report/".format(os.getcwd())
    os.makedirs(temp_path, exist_ok=True)

    # launching a slurm job and assigning the bash script to it.
    # Built as a list (not split from a string) so paths with spaces survive.
    sprof = os.environ.get("XFELCALSLURM", "exfel")
    srun_base = ["sbatch", "-t", "24:00:00", "--mem", "500G", "--requeue",
                 "--output", "{}/slurm-%j.out".format(temp_path),
                 "-p", sprof,
                 os.path.abspath("./build_dc_report.sh"),
                 os.path.abspath("{}/doc".format(dc_folder)),
                 report_fmt]
    logging.info("Building DC report submission: {}".format(srun_base))

    proc = await asyncio.create_subprocess_exec(
        *srun_base, stdout=asyncio.subprocess.PIPE)
    stdout, _ = await proc.communicate()
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, srun_base,
                                            output=stdout)
    output = stdout.decode('utf8')

    jobid = None
    for line in output.split("\n"):
        if "Submitted batch job " in line:
            jobid = line.split(" ")[3]
    logging.info("Submitted job for building a report: {}".format(jobid))
    if jobid is None:
        logging.error("No job id found in sbatch output: {}".format(output))
        return

    try:
        await asyncio.wait_for(wait_jobs([jobid]), timeout=timeout)
    except asyncio.TimeoutError:
        # This coroutine usually runs as a background task, so an uncaught
        # exception would be lost; log it and keep the SLURM log file.
        logging.error("Building the DC report (job {}) timed out; keeping "
                      "{}/slurm-{}.out".format(jobid, temp_path, jobid))
        return
    # delete the slurm log only after the pending slurm job finishes
    await del_folder("{}/slurm-{}.out".format(temp_path, jobid))


async def push_figures(repo_master, addf, retries=10, retry_delay=2):
    """ Upload new figures: add them to the index, commit and push.

    Adding is retried (the files may still be being written); if it never
    succeeds, nothing is committed or pushed and an error is logged,
    instead of producing a commit that does not contain the figures.

    :param repo_master: the local git-repository.
    :param addf: the generated figures to be added.
    :param retries: maximum number of ``index.add`` attempts.
    :param retry_delay: seconds to wait after a failed attempt.
    """
    repo = Repo(repo_master)

    for _ in range(retries):
        try:
            repo.index.add(addf)
            break
        except Exception as e:
            logging.error(str(e))
            await asyncio.sleep(retry_delay)
    else:
        logging.error('Could not add {} new figures after {} attempts; '
                      'nothing was committed or pushed'
                      .format(len(addf), retries))
        return

    repo.index.commit("Add {} new figures".format(len(addf)))
    repo.remote().push()
    logging.info('Pushed to git')


# Keys every client request must provide.
_REQUEST_KEYS = ('req', 'upload', 'report-fmt')

# Strong references to fire-and-forget tasks so they are not garbage
# collected while still running (see asyncio task documentation).
_background_tasks = set()


def _spawn(coro):
    """ Schedule ``coro`` as a background task whose failure is logged. """
    task = asyncio.ensure_future(coro)
    _background_tasks.add(task)

    def _on_done(done):
        _background_tasks.discard(done)
        if not done.cancelled() and done.exception() is not None:
            exc = done.exception()
            logging.error('Background task failed: {}'.format(exc),
                          exc_info=exc)

    task.add_done_callback(_on_done)
    return task


def _build_request_config(response, config, mode):
    """ Turn one validated client request into a private report config.

    The service configuration is deep-copied so per-request flags
    ('upload', 'report-fmt') never leak into ``config`` or into actions
    of other requests that are still running.

    :param response: request dict containing all of ``_REQUEST_KEYS``.
    :param config: the service configuration (never modified).
    :param mode: the service mode ('sim', 'prod' or 'local').
    :return: the request configuration, or None if the request is malformed.
    """
    requested = response['req']

    # if list, it should either have instrument names or ['all'].
    # if dict, it should hold the details of the requested reports,
    # used instead of report_conf.yaml (and must contain 'GLOBAL').
    if isinstance(requested, dict):
        if not isinstance(requested.get('GLOBAL'), dict):
            logging.error(Errors.REQUEST_MALFORMED.format(requested))
            return None
        req_cfg = requested
    elif isinstance(requested, list):
        if requested == ['all']:
            req_cfg = copy.deepcopy(config)
        else:
            req_cfg = {'GLOBAL': copy.deepcopy(config['GLOBAL'])}
            for instr in requested:
                try:
                    req_cfg[instr] = copy.deepcopy(config[instr])
                except (KeyError, TypeError):
                    logging.error(Errors.INSTRUMENT_NOT_FOUND.format(instr))
    else:
        logging.error(Errors.REQUEST_MALFORMED.format(requested))
        return None

    # No interaction with DC repository (local or remote)
    # is allowed if sim mode.
    if mode == 'sim':
        req_cfg['GLOBAL']['upload'] = False
        req_cfg['GLOBAL']['report-fmt'] = False
    else:
        # boolean for pushing to DC git repo.
        req_cfg['GLOBAL']['upload'] = response['upload']
        if mode == 'prod':
            req_cfg['GLOBAL']['report-fmt'] = 'html'
        else:
            req_cfg['GLOBAL']['report-fmt'] = response['report-fmt']
    return req_cfg


async def _do_action(cfg, service_mode):
    """ Produce the plots requested in ``cfg`` and optionally publish them.

    For each instrument one ``xfel-calibrate`` submission is started per
    detector and the resulting SLURM jobs are awaited (bounded by the
    ``job-timeout`` setting, default 3600 s). When ``cfg['GLOBAL']['upload']``
    is true, the PNG figures are copied into the local DC repository, the
    configuration (without GLOBAL) is written to figures/report_conf.yaml,
    figures are pushed in 'prod' mode and a DC report build is scheduled.

    :param cfg: per-request configuration; 'GLOBAL' is deleted from it
                before report_conf.yaml is written.
    :param service_mode: the service mode ('sim', 'prod' or 'local').
    """
    logging.info('Run plot production')
    local_repo = cfg['GLOBAL']['git']['repo-local']
    fig_local = '{}/figures'.format(local_repo)
    jobs_timeout = cfg['GLOBAL']['report-service'].get('job-timeout', 3600)
    upload = cfg['GLOBAL']['upload']
    all_new_files = []
    # One shared concurrency limit for all copies of this action.
    # 50 has been chosen by trial; it is not a hard maximum.
    sem = asyncio.Semaphore(50)

    for instr_name, instrument in cfg.items():
        if instr_name == "GLOBAL":
            continue

        launched_jobs = []
        for det_name, det_conf in instrument.items():
            logging.info('Process detector: {}'.format(det_name))
            logging.debug('Config information: {}'.format(det_conf))

            run_base = await get_run_base(instr_name, det_name, det_conf)
            try:
                proc = await asyncio.create_subprocess_shell(
                    " ".join(run_base), stdout=PIPE, stderr=PIPE)
            except Exception as e:
                # Skip this detector instead of terminating the whole
                # service (exit() inside a task stops the event loop).
                logging.error('Submission failed: {}'.format(e))
                continue
            launched_jobs.append(proc.communicate())
            logging.info('Submission information: {}:'.format(run_base))

        outputs = await asyncio.gather(*launched_jobs)

        job_list = []
        for stdout, stderr in outputs:
            if stdout:
                logging.info('Submission Output: {}'
                             .format(stdout.decode('utf8')))
            if stderr:
                logging.error('Submission Error: {}'
                              .format(stderr.decode('utf8')))
            job_list += await parse_output(stdout.decode('utf8'))

        try:
            await asyncio.wait_for(wait_jobs(job_list), timeout=jobs_timeout)
            logging.info('All jobs are finished')
        except asyncio.TimeoutError:
            logging.error('Jobs have timed-out!')
            logging.error('{}/temp has not been deleted.'.format(
                          os.path.dirname(os.path.abspath(__file__))))

        # Avoid copying files if upload bool is False
        # to avoid causing local git repository errors.
        if not upload:
            continue

        for det_name, det_conf in instrument.items():
            out_folder = det_conf['out-folder'].format(instrument=instr_name,
                                                       detector=det_name)
            figures = glob.glob("{}/*png".format(out_folder))

            det_new_files = {}
            for f in figures:
                fname = os.path.basename(f)
                # Figures are grouped by the file-name prefix before '_'.
                const = fname.split('_')[0]
                fpath = '{}/{}/{}/{}'.format(fig_local, instr_name,
                                             det_name, const)
                os.makedirs(fpath, exist_ok=True)
                det_new_files[f] = fpath
                all_new_files.append('{}/{}'.format(fpath, fname))

            await asyncio.gather(*[copy_files(k, v, sem)
                                   for k, v in det_new_files.items()])

            logging.info('{} figures of {} are copied into {}'
                         .format(len(figures), det_name, fig_local))

    if upload:
        try:
            report_fmt = cfg['GLOBAL']['report-fmt']
            # Remove sensitive information from the config file.
            del cfg['GLOBAL']
            # Write the requested cfg.yaml before pushing all figures.
            with open('{}/report_conf.yaml'.format(fig_local),
                      'w') as outfile:
                yaml.dump(cfg, outfile, default_flow_style=False)

            if service_mode == 'prod':
                # add report_conf.yaml to the files of the new git commit
                # before pushing to remote
                all_new_files.append('{}/report_conf.yaml'.format(fig_local))
                _spawn(push_figures(local_repo, all_new_files))
            # build either html or pdf depending on the running mode
            # of the report_service and requested report format.
            _spawn(build_dc_report(local_repo, report_fmt))
        except Exception as upload_e:
            logging.error("upload failed: {}".format(upload_e))

        # TODO: delete the detectors' out-folders once they are no longer
        # needed (e.g. with del_folder after the copies have finished).

    logging.info('Generating requested plots is finished!')
    logging.info('=======================================')


async def server_runner(conf_file, mode):
    """
    The main server loop. After pulling the latest changes
    of the DC project, it awaits receiving configurations
    on how to generate the requested reports.

    Depending on receiving a conf yaml file or a list of
    instruments, the server proceeds by generating figures.
    Figures are copied to the corresponding folder in the
    DC project and an add-commit-push is performed to
    update the remote project and build reports that can
    be accessed through ReadTheDocs.

    Each request is handled by a background task with its own deep copy
    of the configuration; malformed requests are logged and skipped, and
    failures of background tasks are logged without stopping the server.

    :param conf_file: path of the YAML service configuration.
    :param mode: the service mode ('sim', 'prod' or 'local').
    """
    with open(conf_file, "r") as f:
        config = yaml.load(f.read(), Loader=yaml.FullLoader)

    # perform git-dir checks and pull the project
    # for updates only in production mode.
    if mode != 'sim':
        init_config_repo(config['GLOBAL']['git'])

    logging.info("Report service started in mode: {}".format(mode))

    service_cfg = config['GLOBAL']['report-service']
    logging.info("report service port: {}:{}"
                 .format(service_cfg['bind-to'], service_cfg['port']))

    context = zmq.asyncio.Context()
    socket = context.socket(zmq.REP)
    socket.bind("{}:{}".format(service_cfg['bind-to'], service_cfg['port']))

    while True:
        # NOTE(review): recv_pyobj unpickles the incoming message, which can
        # execute arbitrary code; only expose this port to trusted clients.
        response = await socket.recv_pyobj()
        # A REP socket must reply before the next recv, so answer first.
        await socket.send_pyobj('Build DC reports through -->')

        if not (isinstance(response, dict)
                and all(key in response for key in _REQUEST_KEYS)):
            logging.error(Errors.REQUEST_MALFORMED.format(response))
            continue

        logging.info("response: {} with uploading: {} and report format: {}"
                     .format(response['req'],
                             response['upload'],
                             response['report-fmt']))

        req_cfg = _build_request_config(response, config, mode)
        if req_cfg is None:
            continue
        logging.info('Requested Configuration: {}'.format(req_cfg))

        # actions that fail are only error logged (see _spawn).
        _spawn(_do_action(req_cfg, mode))

arg_parser = argparse.ArgumentParser(description='Start the report service')
arg_parser.add_argument('--config-file', type=str,
                        default='./report_conf.yaml',
                        help='config file path with '
                             'reportservice port. '
                             'Default=./report_conf.yaml')
arg_parser.add_argument('--mode', type=str, default="sim",
                        choices=['sim', 'prod', 'local'],
                        help='service mode: sim (no interaction with the '
                             'DC repository), prod (html reports, figures '
                             'pushed) or local. Default=sim')
arg_parser.add_argument('--log-file', type=str, default='./report.log',
                        help='The report log file path. Default=./report.log')
arg_parser.add_argument('--logging', type=str, default="INFO",
                        help='logging modes: INFO, DEBUG or ERROR. '
                             'Default=INFO',
                        choices=['INFO', 'DEBUG', 'ERROR'])

if __name__ == "__main__":
    args = vars(arg_parser.parse_args())
    conf_file = args["config_file"]
    logfile = args["log_file"]

    logging.basicConfig(filename=logfile, filemode='a+',
                        level=getattr(logging, args['logging']),
                        format='%(levelname)-6s: %(asctime)s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    mode = args["mode"]
    loop = asyncio.get_event_loop()
    loop.run_until_complete(server_runner(conf_file, mode))
    loop.close()