"""Monitor calibration jobs in Slurm and send status updates"""
import argparse
import json
import locale
import logging
import os.path
import shlex
import signal
import time
from bisect import insort_left
from collections import defaultdict, deque
from datetime import datetime, timezone
from pathlib import Path
from subprocess import run, PIPE

from kafka import KafkaProducer
from kafka.errors import KafkaError

try:
    from .common import notify_ready, file_and_stderr_logs, elapsed_to_seconds
    from .config import webservice as config
    from .messages import MDC, Errors, MigrationError, Success
    from .webservice import init_job_db, init_md_client, time_db_transaction
except ImportError:
    from common import notify_ready, file_and_stderr_logs, elapsed_to_seconds
    from config import webservice as config
    from messages import MDC, Errors, MigrationError, Success
    from webservice import init_job_db, init_md_client, time_db_transaction

log = logging.getLogger(__name__)

STATES_FINISHED = {  # https://slurm.schedmd.com/squeue.html#lbAG
    'BOOT_FAIL',  'CANCELLED', 'COMPLETED',  'DEADLINE', 'FAILED',
    'OUT_OF_MEMORY', 'SPECIAL_EXIT', 'TIMEOUT',
    'NA',  # Unknown (used internally if job ID missing)
}
STATES_FAILED = {'FAILED'}
STATE_ABBREVS = {
    'PENDING': 'PD',
    'RUNNING': 'R',
}


class ExpiringEvents:
    """Track events occuring within a time window.

    Args:
        window (int, optional): Time window in seconds, 600 by default.
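
    Example (illustrative usage):

        counter = ExpiringEvents(window=60)
        counter.add()  # records an event now, returns the count in the window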
    """

    def __init__(self, window=600):
        self.window = window
        self.events = deque()

    def add(self, event=None):
        """Record an event (now by default) and return the count in the window"""
        now = time.monotonic()

        # Keep the deque sorted so the oldest events sit at the left end
        insort_left(self.events, event or now)

        # Drop events that have fallen out of the time window
        cutoff = now - self.window
        while self.events and self.events[0] < cutoff:
            self.events.popleft()

        return len(self.events)

    def __len__(self):
        return len(self.events)


class NoOpProducer:
    """Fills in for Kafka producer object when setting that up fails"""
    def send(self, topic, value):
        pass

    def close(self, timeout=None):
        # Mirror KafkaProducer.close() so JobsMonitor.__exit__ works either way
        pass


def init_kafka_producer(config):
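    # The 'producer-config' section is passed straight through to KafkaProducer,
    # e.g. {'bootstrap_servers': ['broker1:9092']} (illustrative values).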
    try:
        kp = KafkaProducer(
            value_serializer=lambda d: json.dumps(d).encode('utf-8'),
            max_block_ms=2000,  # Don't get stuck trying to send Kafka messages
            **config['kafka']['producer-config'].to_dict()
        )
    except KafkaError:
        log.warning("Problem initialising Kafka producer; "
                        "Kafka notifications will not be sent.", exc_info=True)
        return NoOpProducer()
    else:
        log.info("Connected to Kafka broker (%s) to send notifications",
                 kp.config['bootstrap_servers'])
        return kp


def slurm_status(filter_user=True):
    """ Return the status of slurm jobs by calling squeue

    :param filter_user: set to True to only include jobs from the current user
    :return: a dictionary indexed by Slurm job ID and containing a tuple
             of (status, run time, hostname) as values.
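
    An entry might look like (illustrative values)::

        {'1234567': ('RUNNING', '5:41', 'node-001')}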
    """
    cmd = ["squeue", "--states=all", "--format=%i %T %M %N"]
    if filter_user:
        cmd += ["--me"]
    res = run(cmd, stdout=PIPE, stderr=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().rstrip("\n").split("\n")
        statii = {}
        for r in rlines[1:]:
            try:
                jobid, status, runtime, hostname = r.split(' ')
                jobid = jobid.strip()
                statii[jobid] = status, runtime, hostname
            except ValueError:  # not enough values to unpack in split
                log.warning("Could not parse squeue status line %r", r)
        return statii
    else:
        log.error("Running squeue failed. stdout: %r, stderr: %r",
                  res.stdout.decode(), res.stderr.decode())
        return None


def slurm_job_status(jobid):
    """ Return the status of slurm job

    :param jobid: Slurm job Id
    :return: Slurm state, Elapsed.
    """
    cmd = ["sacct", "-j", str(jobid), "--format=JobID,Elapsed,state"]

    res = run(cmd, stdout=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")
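        # sacct prints a header line and a separator line before the data rows,
        # so the job's record is the third line (rlines[2])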

        log.debug("Job {} state {}".format(jobid, rlines[2].split()))
        if len(rlines[2].split()) == 3:
            return rlines[2].replace("+", "").split()
    return "NA", "NA", "NA"


class JobsMonitor:
    def __init__(self, config):
        log.info("Starting jobs monitor from %s", Path(__file__).parent)
        self.job_db = init_job_db(config)
        self.mdc = init_md_client(config)
        self.kafka_prod = init_kafka_producer(config)
        self.kafka_topic = config['kafka']['topic']
        self.time_interval = int(config['web-service']['job-update-interval'])
        # Track jobs failing almost immediately, grouped by node, to help
        # spot broken nodes (see get_updates_by_exec_id).
        self.instant_fails_by_host = defaultdict(ExpiringEvents)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.job_db.close()
        self.kafka_prod.close(timeout=5)

    def run(self):
        while True:
            try:
                self.do_updates()
            except Exception:
                log.error("Failure to update job DB", exc_info=True)
            time.sleep(self.time_interval)

    def do_updates(self):
        ongoing_jobs_by_exn = self.get_updates_by_exec_id()

        # For executions still running, regroup the jobs by request
        # (by run, for correction requests):
        reqs_still_going = {}
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if running_jobs_info:
                req_id = self.job_db.execute(
                    "SELECT req_id FROM executions WHERE exec_id = ?", (exec_id,)
                ).fetchone()[0]
                reqs_still_going.setdefault(req_id, []).extend(running_jobs_info)

        # For executions that have finished, send out notifications, and
        # check if the whole request (several executions) has finished.
        reqs_finished = set()
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if not running_jobs_info:
                req_id = self.process_execution_finished(exec_id)

                if req_id not in reqs_still_going:
                    reqs_finished.add(req_id)

        # Now send updates for all requests which hadn't already finished
        # by the last time this ran:

        for req_id, running_jobs_info in reqs_still_going.items():
            self.process_request_still_going(req_id, running_jobs_info)

        for req_id in reqs_finished:
            self.process_request_finished(req_id)

    def get_updates_by_exec_id(self) -> dict:
        """Get statuses of unfinished jobs, grouped by execution ID

        E.g. {12345: ['R-5:41', 'PD-0:00', ...]}

        Newly completed executions are present with an empty list.
        """
        jobs_to_check = self.job_db.execute(
            "SELECT job_id, exec_id, hostname FROM slurm_jobs WHERE finished = 0"
        ).fetchall()
        if not jobs_to_check:
            log.debug("No unfinished jobs to check")
            return {}

        statii = slurm_status()
        # Check that slurm is giving proper feedback
        if statii is None:
            return {}
        log.debug(f"SLURM info {statii}")

        ongoing_jobs_by_exn = {}
        updates = []
        for r in jobs_to_check:
            log.debug(f"Job in DB before update: %s", tuple(r))
            execn_ongoing_jobs = ongoing_jobs_by_exn.setdefault(r['exec_id'], [])

            if str(r['job_id']) in statii:
                # Jobs which are pending, running, or recently finished (from squeue)
                # Jobs stay for >= 150s after finishing, so we should always see them.
                slstatus, runtime, hostname = statii[str(r['job_id'])]
            else:
                # Fallback: get job info from sacct
                _, runtime, slstatus = slurm_job_status(r['job_id'])
                # We *don't* take hostname from sacct; with some GPFS issues
                # the job may not get recorded, and we don't want to overwrite
                # a hostname we previously got from squeue.
                hostname = r['hostname']

            finished = slstatus in STATES_FINISHED
            if not finished:
                short_state = STATE_ABBREVS.get(slstatus, slstatus)
                execn_ongoing_jobs.append(f"{short_state}-{runtime}")
            elif slstatus in STATES_FAILED and elapsed_to_seconds(runtime) < 2:
                # Specific branch to catch potentially broken nodes.
                num_fails = self.instant_fails_by_host[hostname].add()
                log.warning(f"Job {r['job_id']} failed instantly on "
                            f"{hostname}, {num_fails} on this host within "
                            f"the last 10 minutes")

            updates.append((finished, runtime, slstatus, hostname, r['job_id']))

        with time_db_transaction(self.job_db, 'Update jobs'):
            self.job_db.executemany(
                "UPDATE slurm_jobs SET "
                "finished=?, elapsed=?, status=?, hostname=? WHERE job_id = ?",
                updates
            )

        return ongoing_jobs_by_exn

    def process_request_still_going(self, req_id, running_jobs_info):
        """Send myMdC updates for a request with jobs still running/pending"""
        mymdc_id, action = self.job_db.execute(
            "SELECT mymdc_id, action FROM requests WHERE req_id = ?",
            (req_id,)
        ).fetchone()

        if all(s.startswith('PD-') for s in running_jobs_info):
            # Avoid swamping myMdC with updates for jobs still pending.
            log.debug("No update for %s request with mymdc id %s: jobs pending",
                      action, mymdc_id)
            return

        msg = "\n".join(running_jobs_info)
        log.debug("Update MDC for %s, %s: %s",
                  action, mymdc_id, ', '.join(running_jobs_info)
                  )

        if action == 'CORRECT':
            self.mymdc_update_run(mymdc_id, msg)
        else:  # action == 'DARK'
            self.mymdc_update_dark(mymdc_id, msg)

    def process_execution_finished(self, exec_id):
        """Send notification & record that one execution has finished"""
        statuses = [r[0] for r in self.job_db.execute(
            "SELECT status FROM slurm_jobs WHERE exec_id = ?", (exec_id,)
        ).fetchall()]
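        # The execution succeeded only if every one of its Slurm jobs COMPLETED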
        success = set(statuses) == {'COMPLETED'}
        r = self.job_db.execute(
            "SELECT det_type, karabo_id, command, "
            "req_id, proposal, run, action, mymdc_id, timestamp "
            "FROM executions JOIN requests USING (req_id)"
            "WHERE exec_id = ?",
            (exec_id,)
        ).fetchone()
        with time_db_transaction(self.job_db, 'Update execution'):
            self.job_db.execute(
                "UPDATE executions SET success = ? WHERE exec_id = ?",
                (success, exec_id)
            )
        log.info("Execution finished: %s for (p%s, r%s, %s), success=%s",
                 r['action'], r['proposal'], r['run'], r['karabo_id'], success)

        if r['action'] == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'correction_complete',
                    'proposal': r['proposal'],
                    'run': str(r['run']),
                    'detector': r['det_type'],
                    'detector_identifier': r['karabo_id'],
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

            self.record_correction_report(
                r['mymdc_id'], r['command'], r['karabo_id'], success, r['timestamp']
            )

        return r['req_id']

    def process_request_finished(self, req_id):
        """Send Kafka notifications and update myMDC that a request has finished."""
        krb_id_successes = {r[0]: r[1] for r in self.job_db.execute(
            "SELECT karabo_id, success FROM executions WHERE req_id = ?",
            (req_id,)
        ).fetchall()}
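        # The request succeeded only if every execution (one per karabo_id) did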
        success = (set(krb_id_successes.values()) == {1})

        r = self.job_db.execute(
            "SELECT * FROM requests WHERE req_id = ?", (req_id,)
        ).fetchone()
        log.info(
            "Jobs finished - action: %s, myMdC id: %s, success: %s",
            r['action'], r['mymdc_id'], success,
        )
        if r['action'] == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'run_corrections_complete',
                    'proposal': r['proposal'],
                    'run': str(r['run']),
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

        if success:
            msg = "Calibration jobs succeeded"
        else:
            # List success & failure by karabo_id
            krb_ids_ok = {k for (k, v) in krb_id_successes.items() if v == 1}
            ok = ', '.join(sorted(krb_ids_ok)) if krb_ids_ok else 'none'
            krb_ids_failed = {k for (k, v) in krb_id_successes.items() if v == 0}
            msg = f"Succeeded: {ok}; Failed: {', '.join(sorted(krb_ids_failed))}"

        log.debug("Update MDC for %s, %s: %s", r['action'], r['mymdc_id'], msg)

        if r['action'] == 'CORRECT':
            if success:
                status = 'A'  # Available
            elif set(krb_id_successes.values()) == {0, 1}:
                status = 'AW'  # Available with Warning (failed for some detectors)
            else:
                status = 'E'  # Error
            self.mymdc_update_run(r['mymdc_id'], msg, status)
        else:  # r['action'] == 'DARK'
            status = 'F' if success else 'E'  # Finished/Error
            self.mymdc_update_dark(r['mymdc_id'], msg, status)

    def mymdc_update_run(self, run_id, msg, status='IP'):
        """Update correction status in MyMdC"""
        data = {'cal_pipeline_reply': msg, 'flg_cal_data_status': status}
        if status != 'IP':
            data['cal_last_end_at'] = datetime.now(tz=timezone.utc).isoformat()
        response = self.mdc.update_run_api(run_id, data)
        if response.status_code != 200:
            log.error("Failed to update MDC run id %s", run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def mymdc_update_dark(self, dark_run_id, msg, status='IP'):
        """Update dark run status in MyMdC"""
        data = {'dark_run': {'flg_status': status,
                             'calcat_feedback': msg}}
        response = self.mdc.update_dark_run_api(dark_run_id, data)

        if response.status_code != 200:
            log.error("Failed to update MDC dark run id %s", dark_run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def record_correction_report(
            self, mymdc_run_id, command, karabo_id, success, request_time: str
    ):
        """Add report to MyMdC when a correction execution has finished"""
        args = shlex.split(command)
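        # The correction command includes e.g. "--report-to /some/path/report"
        # (illustrative); the rendered PDF is that path with ".pdf" appended.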
        try:
            report_path = args[args.index("--report-to") + 1] + ".pdf"
        except (ValueError, IndexError):
            log.error("Couldn't find report path in %r", args)
            return

        if not os.path.isfile(report_path):
            log.error("Jobs finished, but report file %s missing", report_path)
            return

        desc = f"{karabo_id} detector corrections"
        if not success:
            desc += " (errors occurred)"

        log.debug("Adding report file %s to MDC for run ID %s",
                  report_path, mymdc_run_id)

        response = self.mdc.create_report_api({
            "name": os.path.basename(report_path),
            "cal_report_path": os.path.dirname(report_path).rstrip('/') + '/',
            "cal_report_at": request_time,
            "run_id": mymdc_run_id,
            "description": desc,
        })

        if response.status_code >= 400:
            log.error("Failed to add report to MDC for run ID %s: HTTP status %s",
                      mymdc_run_id, response.status_code)


def interrupted(signum, frame):
    raise KeyboardInterrupt


def main(argv=None):
    # Ensure files are opened as UTF-8 by default, regardless of environment.
    locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8'))

    parser = argparse.ArgumentParser(
        description='Monitor calibration jobs in Slurm and send status updates'
    )
    parser.add_argument('--config-file', type=str, default=None)
    parser.add_argument('--log-file', type=Path, default='~/webservice-logs/monitor.log')
    parser.add_argument(
        '--log-level', type=str, default="INFO", choices=['INFO', 'DEBUG', 'ERROR'] # noqa
    )
    args = parser.parse_args(argv)

    if args.config_file is not None:
        config.configure(includes_for_dynaconf=[Path(args.config_file).absolute()])

    logging.basicConfig(
        handlers=file_and_stderr_logs(args.log_file),
        level=args.log_level,
    )
    # DEBUG logs from kafka-python are very verbose, so we'll turn them off
    logging.getLogger('kafka').setLevel(logging.INFO)
    # Likewise requests_oauthlib
    logging.getLogger('requests_oauthlib').setLevel(logging.INFO)

    # Treat SIGTERM like SIGINT (Ctrl-C) & do a clean shutdown
    signal.signal(signal.SIGTERM, interrupted)

    with JobsMonitor(config) as jm:
        notify_ready()
        try:
            jm.run()
        except KeyboardInterrupt:
            logging.info("Shutting down on SIGINT/SIGTERM")


if __name__ == "__main__":
    main()