import argparse
import asyncio
import copy
import getpass
import glob
import json
import logging
import os
import sqlite3
import subprocess
import urllib.parse
from datetime import datetime

import yaml
import zmq
import zmq.asyncio
import zmq.auth.thread
from dateutil import parser as timeparser
from git import Repo, InvalidGitRepositoryError
from messages import Errors, Success
from metadata_client.metadata_client import MetadataClient


async def init_job_db(config):
    """ Initialize the sqlite job database

    A new database is created if no pre-existing one is present. A single
    table is created: jobs, which has columns:

        rid, id, proposal, run, flg, status

    :param config: the configuration parsed from the webservice YAML config
    :return: a sqlite3 connection instance to the database
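
    A minimal usage sketch (the in-memory path here is illustrative, not the
    production default):

        conn = await init_job_db({'web-service': {'job-db': ':memory:'}})
        c = conn.cursor()
        c.execute("SELECT * FROM jobs")
        print(c.fetchall())  # -> [] for a freshly created database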
    """
    logging.info("Initializing database")
    conn = sqlite3.connect(config['web-service']['job-db'])
    c = conn.cursor()
    try:
        c.execute("SELECT * FROM jobs")
    except sqlite3.OperationalError:
        logging.info("Creating initial job database")
        c.execute("CREATE TABLE jobs(rid, id, proposal, run, flg, status)")
    return conn


async def init_md_client(config):
    """ Initialize an MDC client connection

    :param config: the configuration parsed from the webservice YAML config
    :return: an MDC client connection
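
    The `metadata-client` section of the configuration is expected to provide
    the keys used below; a sketch with placeholder values:

        metadata-client:
            user-id: "..."
            user-secret: "..."
            user-email: "calibration@example.com"
            token-url: "https://..."
            refresh-url: "https://..."
            auth-url: "https://..."
            scope: ""
            base-api-url: "https://..."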
    """
    mdconf = config['metadata-client']
    client_conn = MetadataClient(client_id=mdconf['user-id'],
                                 client_secret=mdconf['user-secret'],
                                 user_email=mdconf['user-email'],
                                 token_url=mdconf['token-url'],
                                 refresh_url=mdconf['refresh-url'],
                                 auth_url=mdconf['auth-url'],
                                 scope=mdconf['scope'],
                                 base_api_url=mdconf['base-api-url'])
    return client_conn


def init_config_repo(config):
    """ Make sure the configuration repo is present and up-to-data

    :param config: the configuration defined in the `config-repo` section
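
    A minimal `config-repo` section (path and URL are placeholders):

        config-repo:
            local-path: /path/to/local/checkout
            url: https://git.example.com/calibration-configurations.git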
    """
    os.makedirs(config['local-path'], exist_ok=True)
    # check if it is a repo
    try:
        repo = Repo(config['local-path'])
    except InvalidGitRepositoryError:
        repo = Repo.clone_from(config['url'], config['local-path'])
    try:
        repo.remote().pull()
    except Exception:
        logging.warning("Could not pull config repo; continuing with the "
                        "local copy")
    logging.info("Config repo is initialized")


async def upload_config(socket, config, yaml, instrument, cycle, proposal):
    """ Upload a new configuration YAML

    :param socket: ZMQ socket to send reply on
    :param config: the configuration defined in the `config-repo` section
        of the webservice.yaml configuration.
    :param yaml: the YAML contents to write
    :param instrument: the instrument the update is for
    :param cycle: the facility cycle the update is for
    :param proposal: the proposal the update is for
    
    The YAML contents will be placed into a file at
    
        {config.local-path}/{cycle}/{proposal}.yaml

    If it exists it is overwritten and then the new version is pushed to
    the configuration git repo.
    """
    repo = Repo(config['local-path'])
    # ensure we are on the most current version
    repo.remote().pull()
    prop_dir = os.path.join(repo.working_tree_dir, cycle)
    os.makedirs(prop_dir, exist_ok=True)
    with open("{}/{}.yaml".format(prop_dir, proposal), "w") as f:
        f.write(yaml)
    fpath = "{}/{}.yaml".format(prop_dir, proposal)
    repo.index.add([fpath])
    repo.index.commit(
        "Update to proposal {}: {}".format(proposal,
                                           datetime.now().isoformat()))
    repo.remote().push()
    logging.info(Success.UPLOADED_CONFIG.format(cycle, proposal))
    socket.send(Success.UPLOADED_CONFIG.format(cycle, proposal).encode())


def merge(source, destination):
    """
    Deep merge two dictionaries

    :param source: source dictionary to merge into destination
    :param destination: destination dictionary which is merged into
    :return: the updated destination dictionary
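
    For example (note that values from `source` win on conflicts):

        merge({'a': {'b': 1}}, {'a': {'c': 2}, 'd': 3})
        # -> {'a': {'c': 2, 'b': 1}, 'd': 3}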

    Taken from:
    https://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data
    """
    for key, value in source.items():
        if isinstance(value, dict):
            # get node or create one
            node = destination.setdefault(key, {})
            merge(value, node)
        else:
            destination[key] = value

    return destination


async def change_config(socket, config, updated_config, instrument, cycle,
                        proposal, apply=False):
    """
    Change the configuration of a proposal

    If no proposal-specific configuration exists yet, one is first created
    from the default configuration, restricted to the given instrument.

    Changes are committed to git.

    :param socket: ZMQ socket to send reply on
    :param config: repo config as given in YAML config file
    :param updated_config: a dictionary containing the updated config
    :param instrument: the instrument to change config for
    :param cycle: the cycle to change config for
    :param proposal: the proposal to change config for
    :param apply: set to True to actually commit a change, otherwise a dry-run
                  is performed
    :return: The updated config to the requesting zmq socket
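
    For example, an `updated_config` of the following shape (keys shown are
    illustrative) would be merged on top of the proposal configuration for
    the "correct" action and the given instrument:

        {"correct": {"SPB": {"AGIPD": {"mem-cells": 176}}}}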
    """
    # first check if a proposal specific config exists, if not create one
    repo = Repo(config['local-path'])
    repo.remote().pull()
    prop_dir = os.path.join(repo.working_tree_dir, cycle)
    os.makedirs(prop_dir, exist_ok=True)
    fpath = "{}/p{:06d}.yaml".format(prop_dir, int(proposal))
    if not os.path.exists(fpath):
        with open("{}/default.yaml".format(repo.working_tree_dir), "r") as f:
            defconf = yaml.load(f.read(), Loader=yaml.SafeLoader)
            subconf = {}
            for action, instruments in defconf.items():
                subconf[action] = {}
                subconf[action][instrument] = instruments[instrument]
            with open(fpath, "w") as wf:
                wf.write(yaml.dump(subconf, default_flow_style=False))
    new_conf = None
    with open(fpath, "r") as rf:
        existing_conf = yaml.load(rf.read(), Loader=yaml.SafeLoader)
        new_conf = merge(updated_config, existing_conf)
    if apply:
        with open(fpath, "w") as wf:
            wf.write(yaml.dump(new_conf, default_flow_style=False))
        repo.index.add([fpath])
        repo.index.commit(
            "Update to proposal YAML: {}".format(datetime.now().isoformat()))
        repo.remote().push()
    logging.info(Success.UPLOADED_CONFIG.format(cycle, proposal))
    socket.send(yaml.dump(new_conf, default_flow_style=False).encode())


async def slurm_status(filter_user=True):
    """ Return the status of slurm jobs by calling squeue

    :param filter_user: set to True to only list jobs of the current user
    :return: a dictionary indexed by slurm jobid and containing a tuple
             of (status, run time) as values.
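
    Example return value (job ids and times are illustrative):

        {'1234567': ('R', '2:35'), '1234568': ('PD', '0:00')}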
    """
    cmd = ["squeue"]
    if filter_user:
        cmd += ["-u", getpass.getuser()]
    ret = subprocess.run(cmd, stdout=subprocess.PIPE)
    if ret.returncode == 0:
        rlines = ret.stdout.decode().split("\n")
        statii = {}
        for r in rlines[1:]:
            try:
                jobid, _, _, _, status, runtime, _, _ = r.split()
                jobid = jobid.strip()
                statii[jobid] = status, runtime
            except ValueError:
                pass
        return statii
    return None


async def query_rid(conn, socket, rid):
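    """ Query the status of jobs for a given MDC run id and reply on socket

    :param conn: a connection to the job database
    :param socket: ZMQ socket to send the reply on
    :param rid: the run id within the MDC

    Replies with the accumulated job status strings, or "DONE" if no jobs
    for this run id remain in the database.
    """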
    c = conn.cursor()
    c.execute("SELECT * FROM jobs WHERE rid LIKE '{}'".format(rid))
    combined = {}
    for r in c.fetchall():
        rid, jobid, proposal, run, flg, status = r
        logging.debug(
            "Job {}, proposal {}, run {} has status {}".format(jobid,
                                                               proposal,
                                                               run,
                                                               status))
        cflg, cstatus = combined.get(rid, ([], []))
        cflg.append(flg)
        cstatus.append(status)
        combined[rid] = cflg, cstatus
    flg_order = {"R": 2, "A": 1, "NA": 0}

    msg = ""
    for rid, value in combined.items():
        flgs, statii = value
        flg = max(flgs, key=lambda i: flg_order[i])
        msg += "\n".join(statii)
    if msg == "":
        msg = "DONE"
    socket.send(msg.encode())


async def update_job_db(config):
    """ Update the job database and send out updates to MDC

    :param config: configuration parsed from webservice YAML
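
    For each run the flags of all its jobs are reduced to a single flag
    reported to the MDC: with the ordering R > A > NA, the highest-ranked
    flag among the jobs wins.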
    """
    logging.info("Starting config db handling")
    conn = await init_job_db(config)
    mdc = await init_md_client(config)
    while True:
        statii = await slurm_status()
        c = conn.cursor()
        c.execute("SELECT * FROM jobs")
        combined = {}
        for r in c.fetchall():
            rid, jobid, proposal, run, flg, status = r
            if jobid in statii:
                slstatus, runtime = statii[jobid]
                query = "UPDATE jobs SET status='{status} - {runtime}' WHERE id LIKE '{jobid}'"  # noqa
                c.execute(query.format(status=slstatus,
                                       runtime=runtime,
                                       jobid=jobid))
            elif not "QUEUED" in status:
                c.execute("DELETE FROM jobs WHERE id LIKE '{jobid}'".format(
                    jobid=jobid))
                cflg, cstatus = combined.get(rid, ([], []))
                cflg.append("A")
                cstatus.append('DONE')
                combined[rid] = cflg, cstatus
            else:
                # check for timed out jobs
                _, start_time = status.split(": ")
                dt = datetime.now() - timeparser.parse(start_time)
                if dt.total_seconds() > config['web-service']['job-timeout']:
                    c.execute("DELETE FROM jobs WHERE id LIKE ?", (jobid,))
                    cflg, cstatus = combined.get(rid, ([], []))
                    cflg.append("R")
                    cstatus.append('PENDING SLURM SCHEDULING / TIMED OUT?')
                    combined[rid] = cflg, cstatus
        conn.commit()
        c.execute("SELECT * FROM jobs")
        for r in c.fetchall():
            rid, jobid, proposal, run, flg, status = r
            cflg, cstatus = combined.get(rid, ([], []))
            cflg.append(flg)
            cstatus.append(status)
            combined[rid] = cflg, cstatus
        flg_order = {"R": 2, "A": 1, "NA": 0}

        for rid, value in combined.items():
            flgs, statii = value
            flg = max(flgs, key=lambda i: flg_order[i])
            msg = "\n".join(statii)
            response = mdc.update_run_api(rid, {'flg_cal_data_status': flg,
                                                'cal_pipeline_reply': msg})
            if response.status_code != 200:
                logging.error(Errors.MDC_RESPONSE.format(response))
        await asyncio.sleep(int(config['web-service']['job-update-interval']))


async def copy_untouched_files(file_list, out_folder, run):
    """ Copy those files whicha are not touched by calibration to outpot dir

    :param file_list: The list of files to copy
    :param out_folder: The output folder
    :param run: The run which is being handled

    Copying is done via an asyncio subprocess call
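
    The destination path is derived from the source path by substituting
    "raw" with "proc" and "RAW" with "CORR", e.g. (path is illustrative):

        .../raw/r0123/RAW-R0123-DA01-S00000.h5
        -> .../proc/r0123/CORR-R0123-DA01-S00000.h5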
    """
    os.makedirs("{}/r{:04d}".format(out_folder, int(run)), exist_ok=True)
    for f in file_list:
        of = f.replace("raw", "proc").replace("RAW", "CORR")
        cmd = ["rsync", "-av", f, of]
        await asyncio.subprocess.create_subprocess_shell(" ".join(cmd))
        logging.info("Copying {} to {}".format(f, of))


async def run_correction(conn, cmd, mode, proposal, run, rid):
    """ Run a correction command
    
    :param conn: a connection to the job database
    :param cmd: the command to run, as a list as expected by subprocess.run
    :param mode: "prod" or "sim", in the latter case nothing will be executed
                 but the command will be logged
    :param proposal: proposal the command was issued for
    :param run: run the command was issued for
    :param rid: run id in the MDC

    Returns a formatted Success or Error message indicating outcome of the
    execution.
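
    A typical `cmd` as assembled by the server loop looks like this (detector
    name and paths are illustrative):

        ["python", "-m", "xfel_calibrate.calibrate", "AGIPD", "CORRECT",
         "--in-folder", ".../raw", "--out-folder", ".../proc", "--run", "123"]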
    """
    if mode == "prod":
        logging.info(" ".join(cmd))
        ret = subprocess.run(cmd, stdout=subprocess.PIPE)
        if ret.returncode == 0:
            logging.info(Success.START_CORRECTION.format(proposal, run))
            # enter jobs in job db
            c = conn.cursor()
            rstr = ret.stdout.decode()
            query = "INSERT INTO jobs VALUES ('{rid}', '{jobid}', '{proposal}', '{run}', 'R', 'QUEUED: {now}')"  # noqa
            for r in rstr.split("\n"):
                if "Submitted job:" in r:
                    _, jobid = r.split(":")
                    c.execute(query.format(rid=rid, jobid=jobid.strip(),
                                           proposal=proposal, run=run,
                                           now=datetime.now().isoformat()))
            conn.commit()
            logging.debug(" ".join(cmd))
            return Success.START_CORRECTION.format(proposal, run)
        else:
            logging.error(Errors.JOB_LAUNCH_FAILED.format(cmd, ret.returncode))
            return Errors.JOB_LAUNCH_FAILED.format(cmd, ret.returncode)

    else:
        logging.debug(Success.START_CORRECTION_SIM.format(proposal, run))
        logging.debug(cmd)
        return Success.START_CORRECTION_SIM.format(proposal, run)


async def server_runner(config, mode):
    """ The main server loop
    
    The main server loop handles remote requests via a ZMQ interface.
    
    Requests are in the form of a ZMQ REQ message and have the format

        command, *parms

    where *parms is a string-encoded python list as defined by the
    commands. The following commands are currently understood:

    - correct, with parameters rid, sase, instrument, cycle, proposal, runnr,
      priority

       where

       :param rid: is the run id within the MDC database
       :param sase: is the sase beamline
       :param instrument: is the instrument
       :param cycle: is the facility cycle
       :param proposal: is the proposal id
       :param runnr: is the run number in integer form, i.e. without a leading
                     "r"
       :param priority: is the priority with which to launch the SLURM jobs

       This will trigger a correction process to be launched for that run in
       the given cycle and proposal.

    - upload-yaml, with parameters sase, instrument, cycle, proposal, yaml

       where

       :param sase: is the sase beamline
       :param instrument: is the instrument
       :param cycle: is the facility cycle
       :param proposal: is the proposal id
       :param yaml: is url-encoded (quotes and spaces) representation of
                    new YAML file

       This will create or replace the existing YAML configuration for the
       proposal and cycle with the newly sent one, and then push it to the git
       configuration repo.
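
    - dark, with parameters rid, sase, instrument, cycle, proposal, followed
      by one or more run specifications

       Run specifications are either plain run numbers or string-encoded
       tuples such as ("reservation", name) or (run_type, runnr). This will
       trigger dark characterization for the given runs.

    - query-rid, with parameter rid

       Replies with the current job status for that MDC run id.

    - update_conf, with parameters sase, instrument, cycle, proposal,
      config_yaml (JSON-encoded), apply

       This merges the given updates into the proposal-specific configuration
       via change_config. Unless apply is "TRUE" only a dry-run is performed.

    A minimal client interaction could look like this (a sketch; host, port
    and payload values are illustrative, the actual address is given by the
    `web-service` section of the configuration):

        import zmq

        context = zmq.Context()
        sock = context.socket(zmq.REQ)
        sock.connect("tcp://localhost:5555")
        req = ['correct', '42', 'SASE1', 'SPB', '201831', '900019',
               '123', '1']
        sock.send(str(req).encode())
        print(sock.recv().decode())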

    """

    init_config_repo(config['config-repo'])
    job_db = await init_job_db(config)
    mdc = await init_md_client(config)

    context = zmq.asyncio.Context()
    auth = zmq.auth.thread.ThreadAuthenticator(context)
    if mode == "prod-auth":
        auth.start()
        auth.allow(config['web-service']['allowed-ips'])

    socket = context.socket(zmq.REP)
    socket.zap_domain = b'global'
    socket.bind("{}:{}".format(config['web-service']['bind-to'],
                               config['web-service']['port']))

    while True:
        response = await socket.recv_multipart()
        if isinstance(response, list) and len(response) == 1:
            try:  # protect against unparseable requests
                response = eval(response[0])
            except Exception as e:
                logging.error(str(e))
                socket.send(Errors.REQUEST_FAILED.encode())
                continue

        if len(response) < 2:  # catch parseable but malformed requests
            logging.error(Errors.REQUEST_MALFORMED.format(response))
            socket.send(Errors.REQUEST_MALFORMED.format(response).encode())
            continue

        action, payload = response[0], response[1:]

        if action not in ["correct", 'dark', 'query-rid',
                          'upload-yaml', 'update_conf']:  # only handle known actions
            logging.warning(Errors.UNKNOWN_ACTION.format(action))
            socket.send(Errors.UNKNOWN_ACTION.format(action).encode())
            continue

        if action == "query-rid":
            rid = payload[0]
            await query_rid(job_db, socket, rid)
            continue

        async def do_action(action, payload):
            in_folder = None
            out_folder = None
            run_mapping = {}
            priority = None
            req_res = None

            if action in ['update_conf']:
                updated_config = None
                try:
                    sase, instrument, cycle, proposal, config_yaml, apply = payload  # noqa
                    updated_config = json.loads(config_yaml)
                    await change_config(socket, config['config-repo'],
                                        updated_config, instrument, cycle,
                                        proposal, apply.upper() == "TRUE")
                except Exception as e:
                    logging.error(f"Failure applying config for {payload}: "
                                  f"{e}: {updated_config}")

            if action in ['dark', 'correct']:
                wait_runs = []

                if action == 'correct':
                    rid, sase, instrument, cycle, proposal, runnr, priority = payload
                    runnr = runnr.replace("r", "")
                    wait_runs = [runnr]
                if action == 'dark':
                    rid, sase, instrument, cycle, proposal = payload[:5]
                    runs = payload[5:]  # can be many
                    for i, run in enumerate(runs):
                        erun = eval(run)
                        if isinstance(erun, (list, tuple)):
                            typ, runnr = erun
                            if typ == "reservation":
                                req_res = runnr
                                continue

                            runnr = runnr.replace("r", "")
                            run_mapping[typ] = runnr
                            wait_runs.append(runnr)
                        else:
                            run_mapping['no_mapping_{}'.format(i)] = erun
                            wait_runs.append(erun)
                proposal = proposal.replace("p", "")
                proposal = "{:06d}".format(int(proposal))
                specific_conf_file = "{}/{}/{}.yaml".format(
                    config['config-repo']['local-path'], cycle, proposal)
                if os.path.exists(specific_conf_file):
                    with open(specific_conf_file, "r") as f:
                        pconf = yaml.load(
                            f.read(), Loader=yaml.SafeLoader)[action]
                else:
                    print("Using default file, as {} does not exist".format(
                        specific_conf_file))
                    default_file = "{}/default.yaml".format(
                        config['config-repo']['local-path'])
                    with open(default_file, "r") as f:
                        pconf = yaml.load(
                            f.read(), Loader=yaml.SafeLoader)[action]
                if instrument not in pconf:
                    socket.send(Errors.NOT_CONFIGURED.encode())
                    return

                in_folder = config[action]['in-folder'].format(
                    instrument=instrument, cycle=cycle, proposal=proposal)

                msg = "Queued proposal {}, run {} for offline calibration, priority: {}".format(
                    proposal, ", ".join(wait_runs), priority)
                socket.send(msg.encode())
                logging.debug(msg)

                all_transfers = []
                for runnr in wait_runs:
                    rpath = "{}/r{:04d}/".format(in_folder, int(runnr))

                    async def wait_on_transfer():
                        if 'pnfs' in os.path.realpath(
                                rpath):  # dcache files are assumed migrated
                            return True
                        rstr = None
                        ret = None
                        max_tries = 300  # 3000s
                        tries = 0
                        while not os.path.exists(rpath):
                            await asyncio.sleep(10)
                        # await asyncio.sleep(1)
                        while rstr is None or 'status="online"' in rstr or 'status="Online"' in rstr or ret.returncode != 0:  # noqa
                            await asyncio.sleep(10)
                            ret = subprocess.run(
                                ["getfattr", "-n", "user.status", rpath],
                                stdout=subprocess.PIPE)
                            rstr = ret.stdout.decode()
                            print(rstr)
                            if tries > max_tries:
                                return False
                            tries += 1

                        return ret.returncode == 0

                    transfer_complete = await wait_on_transfer()
                    print("Transfer complete: ", transfer_complete)
                    all_transfers.append(transfer_complete)
                    if not transfer_complete:
                        logging.error(
                            Errors.TRANSFER_EVAL_FAILED.format(proposal,
                                                               runnr))
                        msg = "Timeout waiting for migration. Contact det-support@xfel.eu"
                        response = mdc.update_run_api(rid, {
                            'flg_cal_data_status': 'NA',
                            'cal_pipeline_reply': msg})
                        if response.status_code != 200:
                            logging.error(Errors.MDC_RESPONSE.format(response))

                print("All transfers", all(all_transfers))
                if not all(all_transfers):
                    logging.error(Errors.TRANSFER_EVAL_FAILED.format(proposal,
                                                                     ",".join(
                                                                         wait_runs)))
                    return

            print("Now doing: {}".format(action))
            if action == 'dark':
                print("Running dark cal")
                status = []
                detectors = {}
                out_folder = config[action]['out-folder'].format(
                    instrument=instrument, cycle=cycle, proposal=proposal,
                    runs="_".join(wait_runs))
                for runnr in wait_runs:
                    rpath = "{}/r{:04d}/".format(in_folder, int(runnr))
                    if len(detectors) == 0:
                        logging.warning(Errors.NOTHING_TO_DO.format(rpath))
                        msg = "Nothing to characterize for these runs"
                        response = mdc.update_run_api(rid, {
                            'flg_cal_data_status': 'NA',
                            'cal_pipeline_reply': msg})
                        if response.status_code != 200:
                            logging.error(Errors.MDC_RESPONSE.format(response))

                    for detector, dconfig in pconf[instrument].items():
                        # check if we find files according to mapping in raw run folder
                        fl = glob.glob(
                            "{}/RAW-*{}*.h5".format(rpath, dconfig["inset"]))
                        if len(fl):
                            thisconf = copy.copy(dconfig)
                            thisconf["in-folder"] = in_folder
                            thisconf["out-folder"] = out_folder

                            # this is not needed by xfel-calibrate
                            del thisconf["inset"]
                            detectors[detector] = thisconf
                print("Detectors:", detectors)
                for detector, dconfig in detectors.items():
                    if "-" in detector:
                        detector, _ = detector.split("-")
                    priority = '1'

                    if detector.upper() in ["JUNGFRAU", "FASTCCD", "PNCCD",
                                            "EPIX", "EPIX10K"]:
                        priority = '0'
                    cmd = ["python", "-m", "xfel_calibrate.calibrate",
                           detector, "DARK", '--priority', priority]

                    if req_res:
                        cmd += ['--reservation', req_res]

                    run_config = []
                    for typ, run in run_mapping.items():
                        if "no_mapping" in typ:
                            run_config.append(run)
                        else:
                            dconfig[typ] = run
                    if len(run_config):
                        dconfig["runs"] = ",".join(run_config)

                    for key, value in dconfig.items():
                        if not isinstance(value, bool):
                            cmd += ["--{}".format(key), str(value)]
                        else:
                            cmd += ["--{}".format(key)]
                    ret = await run_correction(job_db, cmd, mode, proposal,
                                               wait_runs[0], rid)
                    status.append(ret)

            if action == 'correct':
                runnr = wait_runs[0]
                rpath = "{}/r{:04d}/".format(in_folder, int(runnr))

                out_folder = config[action]['out-folder'].format(
                    instrument=instrument, cycle=cycle, proposal=proposal)
                corr_file_list = set()
                copy_file_list = set(glob.glob("{}/*.h5".format(rpath)))
                detectors = {}
                for detector, dconfig in pconf[instrument].items():
                    # check if we find files according to mapping in raw run folder
                    fl = glob.glob(
                        "{}/RAW-*{}*.h5".format(rpath, dconfig["inset"]))
                    if len(fl):
                        corr_file_list = corr_file_list.union(set(fl))
                        thisconf = copy.copy(dconfig)
                        thisconf["in-folder"] = in_folder
                        thisconf["out-folder"] = out_folder
                        thisconf["run"] = runnr
                        # this is not needed by xfel-calibrate
                        del thisconf["inset"]
                        detectors[detector] = thisconf
                copy_file_list = copy_file_list.difference(corr_file_list)
                print(detectors)
                asyncio.ensure_future(
                    copy_untouched_files(copy_file_list, out_folder, runnr))
                if len(detectors) == 0:
                    logging.warning(Errors.NOTHING_TO_DO.format(rpath))
                    msg = "Nothing to calibrate for this run, copied raw data only"
                    response = mdc.update_run_api(rid,
                                                  {'flg_cal_data_status': 'NA',
                                                   'cal_pipeline_reply': msg})
                    if response.status_code != 200:
                        logging.error(Errors.MDC_RESPONSE.format(response))
                    return
                status = []
                for detector, dconfig in detectors.items():
                    if "-" in detector:
                        detector, _ = detector.split("-")
                    cmd = ["python", "-m", "xfel_calibrate.calibrate",
                           detector, "CORRECT"]
                    for key, value in dconfig.items():
                        if not isinstance(value, bool):
                            cmd += ["--{}".format(key), str(value)]
                        else:
                            cmd += ["--{}".format(key)]
                    if priority:
                        cmd += ["--priority", str(priority)]
                    ret = await run_correction(job_db, cmd, mode, proposal,
                                               runnr, rid)
                    status.append(ret)

            if action == 'upload-yaml':
                sase, instrument, cycle, proposal, this_yaml = payload
                this_yaml = urllib.parse.unquote_plus(this_yaml)
                await upload_config(socket, config['config-repo'], this_yaml,
                                    instrument, cycle, proposal)

        try:
            asyncio.ensure_future(
                do_action(copy.copy(action), copy.copy(payload)))
        except Exception as e:  # actions that fail are only error logged
            logging.error(str(e))


parser = argparse.ArgumentParser(
    description='Start the calibration webservice')
parser.add_argument('--config-file', type=str, default='./webservice.yaml')
parser.add_argument('--log-file', type=str, default='./web.log')
parser.add_argument('--mode', type=str, default="sim", choices=['sim', 'prod'])
parser.add_argument('--logging', type=str, default="INFO",
                    choices=['INFO', 'DEBUG', 'ERROR'])
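
# Typical invocation (a sketch; the module file name is assumed):
#
#   python webservice.py --mode sim --config-file ./webservice.yaml \
#       --log-file ./web.log --logging DEBUG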

if __name__ == "__main__":
    args = vars(parser.parse_args())
    conf_file = args["config_file"]
    with open(conf_file, "r") as f:
        config = yaml.load(f.read(), Loader=yaml.SafeLoader)
    mode = args["mode"]
    logfile = args["log_file"]
    fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(filename=logfile,
                        level=getattr(logging, args['logging']),
                        format=fmt)
    loop = asyncio.get_event_loop()
    loop.create_task(update_job_db(config))
    loop.run_until_complete(server_runner(config, mode))
    loop.close()