import argparse
import glob
import os
import shlex
import sqlite3
from datetime import datetime, timedelta, timezone
from dateutil.parser import parse as parse_datetime
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from shutil import copyfileobj
from subprocess import check_output
from typing import Optional

import yaml
from jinja2 import Template

from xfel_calibrate.settings import free_nodes_cmd, preempt_nodes_cmd

# Import the package-local config when running as part of the package,
# falling back to a plain top-level import when executed as a script.
# Catch only ImportError: a bare `except:` would also swallow
# SystemExit/KeyboardInterrupt and hide real bugs in the config module.
try:
    from .config import serve_overview as config
except ImportError:
    from config import serve_overview as config


def elapsed_to_timedelta(text):
    """Parse a Slurm elapsed-time string into a :class:`timedelta`.

    Accepts ``SS``, ``MM:SS``, ``HH:MM:SS`` and the ``D-HH:MM:SS`` form
    Slurm uses for jobs that have run for more than a day (the previous
    implementation raised ``ValueError`` on the day-prefixed form).
    """
    days = 0
    if '-' in text:
        day_part, text = text.split('-', 1)
        days = int(day_part)
    # Fields right-to-left are seconds, minutes, hours -> scale by 60**i.
    seconds = sum(60 ** i * int(part)
                  for i, part in enumerate(reversed(text.split(':'))))
    return timedelta(days=days, seconds=seconds)

def datetime_to_grafana(dt):
    """Return *dt* as integer milliseconds since the Unix epoch.

    Grafana URLs expect timestamps in this millisecond format.
    """
    epoch_ms = dt.timestamp() * 1000
    return int(epoch_ms)


# HTTPRequestHandler class
class RequestHandler(BaseHTTPRequestHandler):
    """Serve the HTML overview pages of the calibration webservice.

    Routes (see :meth:`do_GET`): the stylesheet, the dark-report
    overview, raw files under ``/gpfs``, and the main status page.
    """

    # NOTE(review): init_config() sets this as an *instance* attribute,
    # and HTTPServer creates a fresh handler instance per connection, so
    # the configuration/templates are effectively re-read on every
    # request.  Left unchanged to preserve behaviour (templates edited
    # on disk are picked up without a restart).
    conf_was_init = False

    def init_config(self):
        """Load shell commands, templates and the job DB from ``config``.

        Called lazily from :meth:`do_GET` because the module-level
        ``config`` may be re-configured by :func:`run` after import.
        """
        self.total_jobs_cmd = config["shell-commands"]["total-jobs"]
        self.tail_log_cmd = config["shell-commands"]["tail-log"]
        self.job_nodes_cmd = config["shell-commands"]["job-nodes"]
        self.run_candidates = config["run-candidates"]

        # Raw template text keyed by name; rendered with jinja2 on demand.
        self.templates = {}
        for template, tfile in config["templates"].items():
            with open(tfile, "r") as tf:
                self.templates[template] = tf.read()

        self.jobs_db = sqlite3.connect(config['web-service']['job-db'])

        self.conf_was_init = True

    def serve_css(self):
        """Serve /serve_overview.css"""
        self.send_response(200)
        self.send_header('Content-type', 'text/css')
        self.end_headers()
        self.wfile.write(self.templates["css"].encode('utf-8'))

    def serve_dark_overview(self):
        """Serve a page listing dark-characterization report PDFs
        per instrument and detector, newest first."""
        # Send headers
        self.send_response(200)
        self.send_header('Content-type', 'text/html')
        self.end_headers()
        host = config["server-config"]["host"]
        port = config["server-config"]["port"]
        reports = {}
        for instrument, detectors in cal_config['dark'].items():
            reports[instrument] = {}
            for detector in detectors:
                tmpl = f'/gpfs/exfel/d/cal/caldb_store/xfel/reports/{instrument}/{detector}/dark/*pdf'
                files = glob.glob(tmpl)
                files.sort(key=os.path.getmtime, reverse=True)
                file_info = []
                for file in files:
                    # Alternate row background colours for readability.
                    if (len(file_info) % 2) == 0:
                        bgcolor = 'EEEEEE'
                    else:
                        bgcolor = 'FFFFFF'
                    mtime = os.stat(file).st_mtime
                    # Fix: interpret the POSIX mtime as UTC directly.
                    # The previous ``.replace(tzinfo=timezone.utc)``
                    # stamped *local* wall-clock time as UTC.
                    d_time = datetime.fromtimestamp(mtime, tz=timezone.utc)
                    s_time = d_time.strftime('%y-%m-%d %H:%M')
                    file_info.append([file, s_time, bgcolor])

                reports[instrument][detector] = file_info

        tmpl = Template(self.templates["dark-overview"])
        message = tmpl.render(reports=reports, host=host, port=port)

        self.wfile.write(bytes(message, "utf8"))

    # File-extension suffix -> MIME type for files served from GPFS.
    _MIMETYPES = (
        (".html", 'text/html'),
        (".jpg", 'image/jpg'),
        (".gif", 'image/gif'),
        (".png", 'image/png'),
        (".pdf", 'application/pdf'),
        (".js", 'application/javascript'),
        (".css", 'text/css'),
    )

    def serve_file_from_gpfs(self):
        """Serve a file from a path starting with /gpfs"""
        # Security: the path comes straight from an untrusted client.
        # Collapse any '..' components and insist the result still lives
        # under /gpfs, otherwise requests like /gpfs/../etc/passwd would
        # escape the intended tree.
        path = os.path.normpath(self.path)
        if not path.startswith('/gpfs/'):
            return self.send_error(403)

        for ext, mimetype in self._MIMETYPES:
            if path.endswith(ext):
                break
        else:
            # Unknown extension: refuse rather than guess a MIME type.
            return self.send_error(404)

        if os.path.isfile(path):
            self.send_response(200)
            self.send_header('Content-Length', str(os.stat(path).st_size))
            self.send_header('Content-type', mimetype)
            self.end_headers()
            with open(path, "rb") as f:
                copyfileobj(f, self.wfile)
        else:
            self.send_error(404)

    def do_GET(self):
        """Route the request, or render the main status page."""
        if not self.conf_was_init:
            self.init_config()

        if "/serve_overview.css" in self.path:
            return self.serve_css()

        if "dark?" in self.path:
            return self.serve_dark_overview()

        if "/gpfs" in self.path:
            return self.serve_file_from_gpfs()

        # Send response status code
        self.send_response(200)

        # Send headers
        self.send_header('Content-type', 'text/html')
        self.end_headers()

        # Send message back to client
        # general maxwell stats
        free = int(check_output(free_nodes_cmd, shell=True).decode('utf8'))
        preempt = int(check_output(
            preempt_nodes_cmd, shell=True).decode('utf8'))
        nodes_avail_general = free + preempt
        total_jobs_running = check_output(self.total_jobs_cmd, shell=True)
        total_jobs_running = total_jobs_running.decode('utf8').split("/")[0]

        tmpl = Template(self.templates["maxwell-status"])
        maxwell_status_r = tmpl.render(nodes_avail_general=nodes_avail_general,
                                       total_jobs_running=total_jobs_running,
                                      )

        # Webservice log tail, with noisy lines filtered out.
        last_n_lines = check_output(self.tail_log_cmd,
                                    shell=True).decode('utf8').split("\n")
        last_n_lines = [l for l in last_n_lines
                        if ("Response error from MDC" not in l
                            and "DEBUG" not in l)]
        tmpl = Template(self.templates["log-output"])
        log_output_r = tmpl.render(service="Webservice", lines=last_n_lines)

        # Job-monitor log tail, rendered with the same template.
        last_n_lines_monitor = [l for l in check_output(
                config["shell-commands"]["tail-log-monitor"], shell=True
            ).decode('utf8').split("\n")
            if "DEBUG" not in l
        ]
        log_output_monitor_r = tmpl.render(
            service="Job monitor", lines=last_n_lines_monitor
        )

        host = config["server-config"]["host"]
        port = config["server-config"]["port"]

        tmpl = self.templates["last-characterizations"]
        last_characterizations_r = Template(tmpl).render(
            char_runs=self.get_last_chars(), host=host, port=port
        )

        tmpl = self.templates["last-correction"]
        last_correction_r = Template(tmpl).render(
            info=self.get_last_corrections(), host=host, port=port
        )

        # Slurm jobs not yet marked finished in the job DB.
        c = self.jobs_db.execute(
            "SELECT job_id, status, elapsed, det_type, proposal, run, action "
            "FROM slurm_jobs INNER JOIN executions USING (exec_id) "
            "INNER JOIN requests USING (req_id) "
            "WHERE finished = 0"
        )

        now = datetime.now()
        running_jobs = {}

        # Mapping of job ID (str) to node for running jobs.
        job_nodes = dict([
            x.split(",") for x in check_output(
                self.job_nodes_cmd, shell=True).decode("utf-8").split("\n")
            if x
        ])

        for job_id, status, elapsed, det, proposal, run, act in c:
            key = f'{proposal}/r{int(run):04d}/{det}/{act}'
            flg = "Q" if status in {"QUEUE", "PENDING"} else "R"
            rjobs = running_jobs.setdefault(key, [])

            rjobs.append((
                flg,
                f'{status[0]}-{elapsed}',
                # Approximate job start time, for linking into Grafana.
                datetime_to_grafana(now - elapsed_to_timedelta(elapsed)),
                job_nodes.get(str(job_id), '')
            ))

        tmpl = self.templates["running-jobs"]
        running_jobs_r = Template(tmpl).render(
            running_jobs=running_jobs, timestamp_now=datetime_to_grafana(now))

        tmpl = Template(self.templates["main-doc"])
        message = tmpl.render(maxwell_status=maxwell_status_r,
                              log_output=log_output_r,
                              log_output_monitor=log_output_monitor_r,
                              last_characterizations=last_characterizations_r,
                              last_correction=last_correction_r,
                              running_jobs=running_jobs_r)
        # Write content as utf-8 data
        self.wfile.write(bytes(message, "utf8"))
        return

    def parse_calibrate_command(self, cmd):
        """Extract key arguments from an ``xfel-calibrate`` command line.

        Returns ``(detector, instrument, in_folder, out_folder,
        report_to, runs)``.  Raises ``ValueError`` (from ``list.index``)
        when a required option is absent; callers catch and skip.
        """
        args = shlex.split(cmd)

        in_folder = args[args.index("--in-folder") + 1]
        report_to = args[args.index("--report-to") + 1]
        out_folder = args[args.index("--out-folder") + 1]
        detector = args[args.index("--karabo-id") + 1]
        # Proposal layout: /gpfs/exfel/exp/<instrument>/<cycle>/... so
        # component 4 of the path is the instrument name.
        instrument = in_folder.split('/')[4]

        # Collect values of whichever run options appear in the command.
        runs = []
        for rc in self.run_candidates:
            if rc in args:
                runs.append(args[args.index(rc) + 1])

        return detector, instrument, in_folder, out_folder, report_to, runs

    def get_last_chars(self):
        """Return info on the latest dark characterization per detector."""
        cur = self.jobs_db.execute("""
            SELECT command, det_type, karabo_id, timestamp
            FROM executions INNER JOIN requests USING (req_id)
            WHERE action = 'DARK'
            ORDER BY timestamp DESC
            LIMIT 100
        """)

        last_chars = {}
        for command, det_type, karabo_id, timestamp in cur:
            try:
                detector, instrument, in_folder, out_folder, report_to, runs = \
                    self.parse_calibrate_command(command)
            except Exception as e:
                print("Failed parsing xfel-calibrate command", e, flush=True)
                continue

            # Check if instrument is in detector name
            # This is not the case for e.g. CALLAB
            # (the original code computed this key twice; one copy removed)
            key = detector if instrument in detector else f"{instrument}-{detector}"  # noqa
            # Rows are ordered newest-first: keep only the latest per key.
            if key in last_chars:
                continue

            pdfs = [
                (os.path.basename(p), p)
                for p in sorted(glob.glob(f"{report_to}*.pdf"))
            ]
            # Total raw-data size of the runs, summed over data aggregators.
            tsize = 0
            for run in runs:
                run = int(run)
                if 'karabo-da' not in cal_config['data-mapping'].get(detector, {}):
                    continue
                # ToDo calculate tsize based on selected karabo-da
                for mp in cal_config['data-mapping'][detector]['karabo-da']:
                    for f in glob.glob(
                            f"{in_folder}/r{run:04d}/*{mp}*.h5"):
                        tsize += os.stat(f).st_size

            timestamp = parse_datetime(timestamp).strftime('%Y-%m-%d %H:%M:%S')

            last_chars[key] = {"in_path": in_folder,
                               "out_path": out_folder,
                               "runs": runs,
                               "pdfs": pdfs,
                               "size": "{:0.1f} GB".format(tsize / 1e9),
                               "requested": timestamp,
                               "device_type": detector,
                              }

        return last_chars

    def get_last_corrections(self):
        """Return recent correction requests grouped by instrument.

        At most ``n-calib`` distinct (proposal, first run) records are
        kept per instrument, newest first.
        """
        cur = self.jobs_db.execute("""
                    SELECT command, det_type, karabo_id, proposal, timestamp
                    FROM executions INNER JOIN requests USING (req_id)
                    WHERE action = 'CORRECT'
                    ORDER BY timestamp DESC
                    LIMIT 100
                """)

        last_calib = {}
        for command, det_type, karabo_id, proposal, timestamp in cur:
            try:
                detector, instrument, in_folder, out_folder, report_to, runs = \
                    self.parse_calibrate_command(command)
            except Exception as e:
                print("Failed parsing xfel-calibrate command", e, flush=True)
                continue

            inst_records = last_calib.setdefault(instrument, [])
            if len(inst_records) >= config["server-config"]["n-calib"]:  # noqa
                continue

            # Report PDFs, newest first, keyed by basename.
            pdfs = glob.glob(f"{report_to}*.pdf")
            pdfs.sort(key=os.path.getmtime, reverse=True)
            pdfs = {p.split("/")[-1]: p for p in pdfs}

            timestamp = parse_datetime(timestamp).strftime('%Y-%m-%d %H:%M')
            # De-duplicate on (proposal, first run) within the instrument.
            if not any(r[1:3] == (proposal, runs[0]) for r in inst_records):
                inst_records.append((
                    timestamp, proposal, runs[0], pdfs
                ))

        return last_calib


def run(config_file: Optional[str] = None):
    """Start the overview HTTP server and serve requests forever.

    If *config_file* is given, it is merged into the dynaconf-based
    ``config`` before anything else is read from it.
    """
    global cal_config

    if config_file is not None:
        config.configure(includes_for_dynaconf=[Path(config_file).absolute()])

    # The calibration configuration is shared with the request handler
    # through the module-level ``cal_config``.
    with open(config["web-service"]["cal-config"], "r") as cf:
        cal_config = yaml.load(cf.read(), Loader=yaml.FullLoader)

    print('starting server...', flush=True)
    sconfig = config["server-config"]
    httpd = HTTPServer((sconfig["host"], sconfig["port"]), RequestHandler)
    print('running server...', flush=True)
    httpd.serve_forever()


# Command-line interface: an optional path to an alternative config file.
parser = argparse.ArgumentParser(
    description='Start the overview server')
parser.add_argument('--config', type=str, default=None)

if __name__ == "__main__":
    cli_args = parser.parse_args()
    run(cli_args.config)