import argparse
import glob
import os
import sqlite3
from collections import OrderedDict
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer
from subprocess import check_output
from urllib.parse import parse_qsl
from uuid import uuid4

import yaml
from jinja2 import Template
from xfel_calibrate.settings import (free_nodes_cmd, preempt_nodes_cmd,
                                     reservation)


class LimitedSizeDict(OrderedDict):
    """An ``OrderedDict`` that forgets its oldest entries beyond a size cap.

    Pass ``size_limit=N`` as a keyword argument to cap the number of entries.
    Whenever an insertion makes the mapping longer than ``N``, entries are
    dropped in insertion order (oldest first) until the length is ``N``
    again. Without ``size_limit`` (or with ``size_limit=None``) it behaves
    like a plain ``OrderedDict``.
    """

    def __init__(self, *args, **kwds):
        # Set the cap before the base initializer runs, in case it inserts
        # the initial items through __setitem__ (which consults the cap).
        self.size_limit = kwds.pop("size_limit", None)
        super().__init__(*args, **kwds)
        self._check_size_limit()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._check_size_limit()

    def _check_size_limit(self):
        """Evict the oldest entries until the size cap is respected."""
        limit = self.size_limit
        if limit is None:
            return
        while len(self) > limit:
            self.popitem(last=False)


# Server configuration loaded from YAML by run() before the server starts.
config = None
# Calibration configuration loaded from config["web-service"]["cal-config"]
# by run(); defined here so the name exists even before run() is called.
cal_config = None
# Maps random UUIDs used in generated "/pdf?<uuid>" links to PDF file paths.
# Only the 50 most recently registered links are kept.
pdf_queue = LimitedSizeDict(size_limit=50)


# HTTPRequestHandler class
class RequestHandler(BaseHTTPRequestHandler):
    """Serve the calibration overview web page and its auxiliary resources.

    Requests are routed by substring match on ``self.path``, checked in
    this order:

    * ``/serve_overview.css`` -- the CSS template.
    * ``pdf?<uuid>`` -- a PDF previously registered in ``pdf_queue``.
    * ``process_dark?<query>`` -- the dark-request checkbox fragment, or the
      submission of a dark request when ``request_dark`` is in the query.
    * ``dark?`` -- an overview of dark-characterization PDF reports.
    * ``/gpfs`` -- static files below ``/gpfs/`` with a whitelisted suffix.
    * anything else -- the main overview page.

    Settings come from the module-level ``config`` and ``cal_config``
    globals, which ``run`` fills in before the server starts.
    """

    # NOTE(review): BaseHTTPRequestHandler creates a new handler for every
    # request and init_config only sets this flag on the instance, so the
    # configuration and templates are re-read on each request. Behaviour is
    # kept unchanged here.
    conf_was_init = False

    # Path suffix -> Content-type for static files served from /gpfs.
    _GPFS_MIME_TYPES = (
        (".html", 'text/html'),
        (".jpg", 'image/jpg'),
        (".gif", 'image/gif'),
        (".png", 'image/png'),
        (".pdf", 'application/pdf'),
        (".js", 'application/javascript'),
        (".css", 'text/css'),
    )

    def init_config(self):
        """Copy settings from the global ``config`` and load all templates.

        Sets the shell commands used to query node/job state and to read the
        request log, the per-detector-class file ``mappings`` used to size
        raw data, the ``run-candidates`` flags that precede run numbers in
        logged command lines, and the text of every file listed under
        ``templates``.
        """
        shell_cmds = config["shell-commands"]
        self.nodes_avail_res_cmd = shell_cmds["nodes-avail-res"]
        self.total_jobs_cmd = shell_cmds["total-jobs"]
        self.upex_jobs_cmd = shell_cmds["upex-jobs"]
        self.upex_prefix = shell_cmds["upex-prefix"]
        self.tail_log_cmd = shell_cmds["tail-log"]
        self.cat_log_cmd = shell_cmds["cat-log"]
        self.mappings = config["mappings"]
        self.run_candidates = config["run-candidates"]

        self.templates = {}
        for name, tfile in config["templates"].items():
            with open(tfile, "r") as tf:
                self.templates[name] = tf.read()
        self.pdf_queue = pdf_queue
        self.conf_was_init = True

    def do_GET(self):
        """Dispatch a GET request to the handler for the matching route."""
        if not self.conf_was_init:
            self.init_config()

        if "/serve_overview.css" in self.path:
            self._send_ok_headers('text/css')
            self.wfile.write(self.templates["css"].encode())
        elif "pdf?" in self.path:
            self._serve_pdf()
        elif "process_dark?" in self.path:
            self._serve_process_dark()
        elif "dark?" in self.path:
            self._serve_dark_overview()
        elif "/gpfs" in self.path:
            self._serve_gpfs_file()
        else:
            self._serve_overview()

    # ------------------------------------------------------------------
    # Response helpers
    # ------------------------------------------------------------------

    def _send_ok_headers(self, content_type):
        """Send a ``200 OK`` status line and Content-type, ending headers."""
        self.send_response(200)
        self.send_header('Content-type', content_type)
        self.end_headers()

    def _send_file(self, fpath, content_type):
        """Send the bytes of ``fpath`` with 200, or a 404 if unreadable."""
        try:
            with open(fpath, "rb") as f:
                data = f.read()
        except OSError:
            self.send_error(404, "File not found")
            return
        self._send_ok_headers(content_type)
        self.wfile.write(data)

    @staticmethod
    def _parse_query(path):
        """Return the URL-decoded query parameters of ``path`` as a dict.

        Parameters without a value (``?flag``) map to ``""``; for repeated
        names the last value wins. Insertion order follows the query.
        """
        query = path.split("?", 1)[1] if "?" in path else ""
        return dict(parse_qsl(query, keep_blank_values=True))

    # ------------------------------------------------------------------
    # Routes
    # ------------------------------------------------------------------

    def _serve_pdf(self):
        """Send the PDF registered under the UUID that follows ``?``."""
        puuid = self.path.split("?")[1]
        fpath = self.pdf_queue.get(puuid)
        if fpath is None:
            # Unknown UUID, or one already evicted from the bounded queue.
            self.send_error(404, "Unknown or expired PDF link")
            return
        self._send_file(fpath, 'application/pdf')

    def _serve_process_dark(self):
        """Render the dark-request form fragment or submit a dark request."""
        self._send_ok_headers('text/html')
        pars = self._parse_query(self.path)
        if pars['instrument'] in ['Nothing', 'none']:
            self.wfile.write(bytes("<br>", "utf8"))
            return

        if 'request_dark' in pars:
            msg = self._submit_dark_request(pars)
            self.wfile.write(bytes("<br>" + msg, "utf8"))
            return

        det_list = []
        if 'detectors' in pars:
            det_list = pars['detectors'].split(",")

        detectors = cal_config['dark'][pars['instrument']]
        det_names = [["checked" if d in det_list else "", d]
                     for d in detectors]

        # Detectors taking a single dark run vs. those taking three runs
        # (high/medium/low); a selection mixing both kinds is rejected.
        run1_det = ["FASTCCD", "EPIX", "DSSC", "PNCCD"]
        run3_det = ["LPD", "AGIPD", "JUNGFRAU"]
        has_run1 = any(y in x for x in det_list for y in run1_det)
        has_run3 = any(y in x for x in det_list for y in run3_det)
        run_names = []
        msg = ''
        if has_run1 and has_run3:
            msg = "Incompatible choice"
        elif has_run1:
            run_names = ['run']
        elif has_run3:
            run_names = ['run_high', 'run_med', 'run_low']

        tmpl = Template(self.templates["checkbox"])
        message = tmpl.render(detectors=det_names, runs=run_names,
                              message=msg,
                              host=config["server-config"]["host"],
                              port=config["server-config"]["port"])
        self.wfile.write(bytes(message, "utf8"))

    def _submit_dark_request(self, pars):
        """Run the configured dark-request script and return its output.

        Every query parameter except ``request_dark`` is forwarded as a
        ``--name value`` pair (underscores in names become dashes);
        ``detectors`` is split on commas into separate values and
        ``proposal`` is zero-padded to six digits. Empty arguments are
        dropped. On failure the exception text is returned instead so it is
        shown to the user.
        """
        scripts = config["scrpts-cnf"]
        par_list = [scripts["python-v"], scripts["req-dark"], '--bkg']
        for name, value in pars.items():
            if name == 'request_dark':
                continue
            par_list.append('--{}'.format(name.replace("_", "-")))
            if name == 'detectors':
                par_list.extend(value.split(","))
            elif name == 'proposal':
                par_list.append('{:06d}'.format(int(value)))
            else:
                par_list.append(value)

        par_list = [p for p in par_list if p]
        print('REQUEST DARK: ', par_list)
        timeout = config["server-config"]["dark-timeout"]
        try:
            # shell=False: request values become separate argv entries and
            # are never interpreted by a shell.
            return check_output(par_list, shell=False,
                                timeout=timeout).decode('utf8')
        except Exception as e:  # boundary: report any failure to the page
            return str(e)

    def _serve_dark_overview(self):
        """List dark-characterization PDFs per instrument and detector."""
        self._send_ok_headers('text/html')
        reports = {}
        for instrument, detectors in cal_config['dark'].items():
            reports[instrument] = {}
            for detector in detectors:
                det_inset = detector.replace('-', '_')
                pattern = '/gpfs/exfel/exp/{}/*/*/usr/dark/*/{}'.format(
                    instrument, det_inset)
                files = (glob.glob(pattern + '/*pdf')
                         + glob.glob(pattern + '/*/*pdf'))
                files.sort(key=os.path.getmtime, reverse=True)
                file_info = []
                for i, fname in enumerate(files):
                    # Alternate row colours for readability.
                    bgcolor = 'EEEEEE' if i % 2 == 0 else 'FFFFFF'
                    # NOTE(review): fromtimestamp() yields local time; the
                    # UTC tag does not change the formatted string.
                    d_time = datetime.fromtimestamp(
                        os.stat(fname).st_mtime).replace(tzinfo=timezone.utc)
                    file_info.append(
                        [fname, d_time.strftime('%y-%m-%d %H:%M'), bgcolor])
                reports[instrument][detector] = file_info

        tmpl = Template(self.templates["dark-overview"])
        message = tmpl.render(reports=reports,
                              host=config["server-config"]["host"],
                              port=config["server-config"]["port"])
        self.wfile.write(bytes(message, "utf8"))

    def _serve_gpfs_file(self):
        """Serve a static file below /gpfs/ whose suffix is whitelisted."""
        mimetype = next((mime for suffix, mime in self._GPFS_MIME_TYPES
                         if self.path.endswith(suffix)), None)
        # Collapse "." and ".." so a request such as /gpfs/../etc/x.html
        # cannot escape the /gpfs tree.
        fpath = os.path.normpath(self.path)
        if (mimetype is None or not fpath.startswith("/gpfs/")
                or not os.path.isfile(fpath)):
            self.send_error(404, "File not found")
            return
        self._send_file(fpath, mimetype)

    def _serve_overview(self):
        """Render the main overview page from its component templates."""
        self._send_ok_headers('text/html')

        maxwell_status_r = self._render_maxwell_status()
        log_output_r = self._render_log_output()

        host = config["server-config"]["host"]
        port = config["server-config"]["port"]
        last_chars, last_calib = self._collect_recent_requests(host, port)
        last_characterizations_r = Template(
            self.templates["last-characterizations"]).render(
                char_runs=last_chars, host=host, port=port)
        last_correction_r = Template(
            self.templates["last-correction"]).render(
                info=last_calib, host=host, port=port)

        running_jobs_r = Template(self.templates["running-jobs"]).render(
            running_jobs=self._collect_running_jobs())
        request_dark_r = Template(self.templates["request-dark"]).render(
            instruments=cal_config['dark'].keys())

        tmpl = Template(self.templates["main-doc"])
        message = tmpl.render(maxwell_status=maxwell_status_r,
                              log_output=log_output_r,
                              last_characterizations=last_characterizations_r,
                              last_correction=last_correction_r,
                              request_dark=request_dark_r,
                              running_jobs=running_jobs_r)
        self.wfile.write(bytes(message, "utf8"))

    # ------------------------------------------------------------------
    # Overview page sections
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_used_reserved(text):
        """Parse a ``"used/reserved"`` string into two ints.

        Returns ``(0, 0)`` when the text is not exactly two integer fields
        separated by ``/`` (e.g. the reservation does not exist).
        """
        parts = text.split("/")
        if len(parts) != 2:
            return 0, 0
        try:
            return int(parts[0]), int(parts[1])
        except ValueError:
            return 0, 0

    def _render_maxwell_status(self):
        """Render node availability, job counts and the recommendation."""
        free = int(check_output(free_nodes_cmd, shell=True).decode('utf8'))
        preempt = int(check_output(
            preempt_nodes_cmd, shell=True).decode('utf8'))
        nodes_avail_general = free + preempt

        used, reserved = self._parse_used_reserved(check_output(
            self.nodes_avail_res_cmd.format(reservation),
            shell=True).decode('utf8'))
        res_free = reserved - used
        nodes_avail_gr = "{}/{}".format(res_free, reserved)

        total_jobs_running = check_output(
            self.total_jobs_cmd, shell=True).decode('utf8').split("/")[0]

        upex_res = [r for r in
                    check_output(self.upex_jobs_cmd,
                                 shell=True).decode('utf8').split()
                    if self.upex_prefix in r]
        upex_reservations = {}
        for res in upex_res:
            # Separate names so the general reservation's numbers, used for
            # the recommendation below, are not overwritten.
            r_used, r_reserved = self._parse_used_reserved(check_output(
                self.nodes_avail_res_cmd.format(res),
                shell=True).decode('utf8'))
            upex_reservations[res] = "{}/{}".format(r_reserved - r_used,
                                                    r_reserved)

        recommendation = "DON'T SUBMIT TO RESERVATION"
        if nodes_avail_general < res_free:
            recommendation = "CONSIDER SUBMITTING TO RESERVATION"

        tmpl = Template(self.templates["maxwell-status"])
        return tmpl.render(nodes_avail_general=nodes_avail_general,
                           nodes_avail_general_res=nodes_avail_gr,
                           total_jobs_running=total_jobs_running,
                           upex_reservations=upex_reservations,
                           recommendation=recommendation)

    def _render_log_output(self):
        """Render the log tail, newest line first, without noisy lines."""
        lines = check_output(self.tail_log_cmd,
                             shell=True).decode('utf8').split("\n")
        lines = [line for line in lines
                 if ("Response error from MDC" not in line
                     and "DEBUG" not in line)]
        tmpl = Template(self.templates["log-output"])
        return tmpl.render(logout="<br>".join(reversed(lines)))

    def _parse_run_info(self, l, key):
        """
        Parse a line and extract information
        :param l: Line to parse
        :param key: A key work: DARK or CORRECT
        :return: [detector class, detector name, instrument name,
        input folder, output folder, list of runs, time of request], or
        None when ``key`` is not a token of the line or the line lacks
        the expected fields.
        """
        ls = l.split()
        if key not in ls:
            return None
        try:
            dclass = ls[ls.index(key) - 1]
            in_folder = ls[ls.index("--in-folder") + 1]
            out_folder = ls[ls.index("--out-folder") + 1]

            if "--db-module" in ls:
                detector = ls[ls.index("--db-module") + 1]
            else:
                detector = dclass

            if "--instrument" in ls:
                instrument = ls[ls.index("--instrument") + 1]
            elif detector == "PNCCD":
                instrument = "SQS"
            else:
                instrument = in_folder.split('/')[4]

            runs = [ls[ls.index(rc) + 1] for rc in self.run_candidates
                    if rc in ls]
            requested = "{} {}".format(ls[0], ls[1])
        except (ValueError, IndexError):
            # Malformed log line: skip it instead of failing the whole page.
            return None

        return [dclass, detector, instrument, in_folder, out_folder,
                runs, requested]

    def _collect_recent_requests(self, host, port):
        """Collect the latest dark and correction requests from the log.

        Log lines are visited in reverse order, so the first DARK entry
        seen for an instrument/detector pair is the last one logged, and
        at most ``n-calib`` corrections are kept per instrument.

        :return: (characterizations keyed by "instrument-detector",
                  corrections keyed by instrument)
        """
        lines = check_output(self.cat_log_cmd,
                             shell=True).decode('utf8').split("\n")[::-1]
        last_chars = {}
        last_calib = {}
        for line in lines:
            info = self._parse_run_info(line, 'DARK')
            if info is not None:
                (dclass, detector, instrument, in_folder, out_folder,
                 runs, requested) = info
                key = f"{instrument}-{detector}"
                if key in last_chars:
                    continue
                pdfs = self._register_pdfs(out_folder, host, port)
                tsize = self._raw_data_size(dclass, in_folder, runs)
                last_chars[key] = {"in_path": in_folder,
                                   "out_path": out_folder,
                                   "runs": runs,
                                   "pdfs": pdfs,
                                   "size": "{:0.1f} GB".format(tsize / 1e9),
                                   "requested": requested,
                                   "device_type": detector,
                                   "last_valid_from": ""}

            info = self._parse_run_info(line, 'CORRECT')
            if info is None:
                continue
            _, _, _, in_folder, out_folder, runs, requested = info
            folder_parts = in_folder.split('/')
            instrument = folder_parts[4]
            calibs = last_calib.setdefault(instrument, [])
            if len(calibs) >= config["server-config"]["n-calib"]:
                continue
            if len(folder_parts) < 7 or not runs:
                continue
            proposal = folder_parts[6]
            if any(x[1] == proposal and x[2] == runs[0] for x in calibs):
                continue
            pdfs = sorted(glob.glob(f"{out_folder}/*.pdf"),
                          key=os.path.getmtime, reverse=True)
            # requested[:-4] drops the last four characters of the
            # timestamp string.
            calibs.append([requested[:-4], proposal, runs[0],
                           {p.split("/")[-1]: p for p in pdfs}])
        return last_chars, last_calib

    def _register_pdfs(self, out_folder, host, port):
        """Register PDFs found in ``out_folder`` and return download links.

        :return: list of (file name, "http://host:port/pdf?<uuid>") pairs;
                 for duplicate file names the last sorted path wins.
        """
        paths = sorted(glob.glob(f"{out_folder}/*.pdf")
                       + glob.glob(f"{out_folder}/*/*.pdf"))
        by_name = {p.split("/")[-1]: p for p in paths}
        links = []
        for name, path in by_name.items():
            puuid = uuid4().hex
            self.pdf_queue[puuid] = path
            links.append((name, f"http://{host}:{port}/pdf?{puuid}"))
        return links

    def _raw_data_size(self, dclass, in_folder, runs):
        """Sum the bytes of .h5 files matching ``dclass`` mappings in runs."""
        if dclass not in self.mappings:
            return 0
        total = 0
        for run in runs:
            run = int(run)
            for mp in self.mappings[dclass]:
                for fname in glob.glob(f"{in_folder}/r{run:04d}/*{mp}*.h5"):
                    total += os.stat(fname).st_size
        return total

    def _collect_running_jobs(self):
        """Group active jobs from the job database by proposal/run/det/act.

        :return: dict mapping "proposal/rNNNN/det/act" to a list of
                 (flag, "status-time") tuples, flag "Q" for queued jobs
                 and "R" otherwise.
        """
        conn = sqlite3.connect(config['web-service']['job-db'])
        try:
            rows = conn.execute(
                "SELECT * FROM jobs WHERE status IN ('R', 'PD', 'CG')"
            ).fetchall()
        finally:
            # Closed on every request; previously the connection leaked.
            conn.close()

        running_jobs = {}
        for _rid, _jobid, proposal, run, status, time, det, act in rows:
            key = '{}/r{:04d}/{}/{}'.format(proposal, int(run), det, act)
            flg = "Q" if status in ["QUEUE", "PD"] else "R"
            running_jobs.setdefault(key, []).append(
                (flg, '{}-{}'.format(status, time)))
        return running_jobs


def run(configfile, port=8008):
    """Load both YAML configurations and serve the overview page forever.

    Populates the module-level ``config`` and ``cal_config`` globals used by
    ``RequestHandler``, then binds an ``HTTPServer`` to the host and port
    given in the ``server-config`` section.

    :param configfile: Path to the server's YAML configuration file.
    :param port: Not used by this function; the listening port is read
        from ``config["server-config"]["port"]``.
    """
    global config, cal_config

    print('reading config file')
    with open(configfile, "r") as handle:
        config = yaml.load(handle.read(), Loader=yaml.FullLoader)
    cal_config_path = config["web-service"]["cal-config"]
    with open(cal_config_path, "r") as handle:
        cal_config = yaml.load(handle.read(), Loader=yaml.FullLoader)

    print('starting server...')
    server_settings = config["server-config"]
    bind_address = (server_settings["host"], server_settings["port"])
    httpd = HTTPServer(bind_address, RequestHandler)
    print('running server...')
    httpd.serve_forever()


# Command-line interface: only the path of the YAML configuration file.
parser = argparse.ArgumentParser(description='Start the overview server')
parser.add_argument('--config', type=str, default="serve_overview.yaml")

if __name__ == "__main__":
    run(parser.parse_args().config)