"""Monitor calibration jobs in Slurm and send status updates"""
import argparse
import json
import locale
import logging
import os
import shlex
import signal
import time
from datetime import datetime, timezone
from pathlib import Path
from subprocess import run, PIPE

from kafka import KafkaProducer
from kafka.errors import KafkaError

try:
    from .common import notify_ready
    from .config import webservice as config
    from .messages import MDC, Errors, MigrationError, Success
    from .webservice import init_job_db, init_md_client, time_db_transaction
except ImportError:
    from common import notify_ready
    from config import webservice as config
    from messages import MDC, Errors, MigrationError, Success
    from webservice import init_job_db, init_md_client, time_db_transaction

log = logging.getLogger(__name__)

STATES_FINISHED = {  # https://slurm.schedmd.com/squeue.html#lbAG
    'BOOT_FAIL',  'CANCELLED', 'COMPLETED',  'DEADLINE', 'FAILED',
    'OUT_OF_MEMORY', 'SPECIAL_EXIT', 'TIMEOUT',
}

STATE_ABBREVS = {
    'PENDING': 'PD',
    'RUNNING': 'R',
}

class NoOpProducer:
    """Fills in for Kafka producer object when setting that up fails"""
    def send(self, topic, value):
        pass

    def close(self, timeout=None):
        # Mirror KafkaProducer.close() so JobsMonitor.__exit__ works either way
        pass


def init_kafka_producer(config):
    try:
        kp = KafkaProducer(
            value_serializer=lambda d: json.dumps(d).encode('utf-8'),
            max_block_ms=2000,  # Don't get stuck trying to send Kafka messages
            **config['kafka']['producer-config'].to_dict()
        )
    except Exception:
        log.warning("Problem initialising Kafka producer; "
                    "Kafka notifications will not be sent.", exc_info=True)
        return NoOpProducer()
    else:
        log.info("Connected to Kafka broker (%s) to send notifications",
                 kp.config['bootstrap_servers'])
        return kp


def slurm_status(filter_user=True):
    """ Return the status of slurm jobs by calling squeue

    :param filter_user: set to True to only list jobs from the current user
    :return: a dictionary indexed by slurm jobid and containing a tuple
             of (status, run time) as values.
    """
    cmd = ["squeue", "--states=all", "--format=%i %T %M"]
    if filter_user:
        cmd += ["--me"]
    res = run(cmd, stdout=PIPE, stderr=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")
        statii = {}
        for r in rlines[1:]:
            try:
                jobid, status, runtime = r.split()
                jobid = jobid.strip()
                statii[jobid] = status, runtime
            except ValueError:  # not enough values to unpack in split
                pass
        return statii
    else:
        log.error("Running squeue failed. stdout: %r, stderr: %r",
                  res.stdout.decode(), res.stderr.decode())
        return None
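
# Illustrative example (job IDs and times are hypothetical, not from the source):
# with "--format=%i %T %M" (job id, state, time used), squeue output such as
#     JOBID STATE TIME
#     11498126 RUNNING 5:41
#     11498127 PENDING 0:00
# is parsed by slurm_status() into
#     {'11498126': ('RUNNING', '5:41'), '11498127': ('PENDING', '0:00')}
# The header line is skipped and malformed lines are ignored.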


def slurm_job_status(jobid):
    """ Return the status of slurm job

    :param jobid: Slurm job Id
    :return: (JobID, Elapsed, State) from sacct, or ("NA", "NA", "NA") if unavailable.
    """
    cmd = ["sacct", "-j", str(jobid), "--format=JobID,Elapsed,state"]

    res = run(cmd, stdout=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")

        log.debug("Job {} state {}".format(jobid, rlines[2].split()))
        if len(rlines[2].split()) == 3:
            return rlines[2].replace("+", "").split()
    return "NA", "NA", "NA"


def parse_log_file(file_path):
    results = []
    with open(file_path, 'r') as file:
        for line in file:
            try:
                log_entry = json.loads(line.strip())
                error_message = log_entry.get('message', '')
                error_class = log_entry.get('class', '')
                results.append((error_message, error_class))
            except json.JSONDecodeError:
                log.error(f"Skipping invalid JSON: {line.strip()}")
    return results
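
# Illustrative example (field values are hypothetical): each line of the file is
# expected to be a JSON object like
#     {"message": "No constants found", "class": "ValueError"}
# which parse_log_file() collects as [("No constants found", "ValueError")].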


def get_report_dir(command):
    args = shlex.split(command)
    try:
        return args[args.index("--report-to") + 1]
    except (ValueError, IndexError):
        log.error("Couldn't find report directory in %r", args)
        return
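
# Illustrative example (hypothetical command string): for a command containing
#     "... --report-to /path/to/report ..."
# get_report_dir() returns "/path/to/report"; record_correction_report() later
# appends ".pdf" to this path to locate the report file.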


def process_log_file(job_id, karabo_id, report_dir, karabo_id_log, file):
    if file.exists():
        with open(file, 'r') as f:
            for line in f:
                try:
                    json_line = json.loads(line)
                    if "default" not in json_line['class'].lower():
                        message = json_line['message']
                        if message:
                            karabo_id_log.setdefault(
                                message, []).append(job_id)
                except json.JSONDecodeError:
                    log.error(
                        f"Invalid JSON in errors file {file}:"
                        f" {line.strip()}")
    return karabo_id_log
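
# Illustrative example (hypothetical log entry): a line such as
#     {"class": "ValueError", "message": "No pedestal constants found"}
# in the file for job 11498126 results in
#     karabo_id_log == {"No pedestal constants found": ["11498126"]}
# Entries whose "class" contains "default" (case-insensitive) are skipped.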


def compress_job_ids(job_ids):
    """Compress list of job IDs to a shorter representation.

    Args:
        job_ids (list): List of job IDs

    Returns:
        str: Compressed representation like "16 jobs (11498126-11498141)"
        or "2 jobs (11498142, 11498143)" for non-sequential IDs
    """
    if not job_ids:
        return "0 jobs"

    # Convert to integers and sort
    ids = sorted(int(id) for id in job_ids)

    # Check if they're sequential
    if len(ids) > 2 and ids[-1] - ids[0] + 1 == len(ids):
        return f"{len(ids)} jobs ({ids[0]}-{ids[-1]})"

    if len(ids) > 4:
        return f"{len(ids)} jobs (e.g. {ids[0]}, {ids[1]}...)"

    return f"{len(ids)} jobs ({', '.join(str(id) for id in ids)})"


def format_log_message(errors):
    """Format log messages with compressed job IDs."""
    formatted = {}
    for karabo_id, messages in errors.items():
        formatted[karabo_id] = {
            msg: compress_job_ids(job_ids)
            for msg, job_ids in messages.items()
        }
    return formatted
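
# Illustrative example (hypothetical content): an errors dict like
#     {"AGIPD1": {"No constants found": ["11498126", "11498127"]}}
# is formatted as
#     {"AGIPD1": {"No constants found": "2 jobs (11498126, 11498127)"}}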


class JobsMonitor:
    def __init__(self, config):
        log.info("Starting jobs monitor")
        self.job_db = init_job_db(config)
        self.mdc = init_md_client(config)
        self.kafka_prod = init_kafka_producer(config)
        self.kafka_topic = config['kafka']['topic']
        self.time_interval = int(config['web-service']['job-update-interval'])

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.job_db.close()
        self.kafka_prod.close(timeout=5)

    def run(self):
        while True:
            try:
                self.do_updates()
            except Exception:
                log.error("Failure to update job DB", exc_info=True)
            time.sleep(self.time_interval)

    def do_updates(self):
        ongoing_jobs_by_exn = self.get_updates_by_exec_id()

        # For executions still running, regroup the jobs by request
        # (by run, for correction requests):
        reqs_still_going = {}
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if running_jobs_info:
                req_id = self.job_db.execute(
                    "SELECT req_id FROM executions WHERE exec_id = ?", (exec_id,)
                ).fetchone()[0]
                reqs_still_going.setdefault(req_id, []).extend(running_jobs_info)

        # For executions that have finished, send out notifications, and
        # check if the whole request (several executions) has finished.
        reqs_finished = set()
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if not running_jobs_info:
                req_id = self.process_execution_finished(exec_id)

                if req_id not in reqs_still_going:
                    reqs_finished.add(req_id)

        # Now send updates for all requests which hadn't already finished
        # by the last time this ran:

        for req_id, running_jobs_info in reqs_still_going.items():
            self.process_request_still_going(req_id, running_jobs_info)

        for req_id in reqs_finished:
            self.process_request_finished(req_id)

    def get_updates_by_exec_id(self) -> dict:
        """Get statuses of unfinished jobs, grouped by execution ID

        E.g. {12345: ['R-5:41', 'PD-0:00', ...]}

        Newly completed executions are present with an empty list.
        """
        jobs_to_check = self.job_db.execute(
            "SELECT job_id, exec_id FROM slurm_jobs WHERE finished = 0"
        ).fetchall()
        if not jobs_to_check:
            log.debug("No unfinished jobs to check")
            return {}

        statii = slurm_status()
        # Check that slurm is giving proper feedback
        if statii is None:
            return {}
        log.debug(f"SLURM info {statii}")

        ongoing_jobs_by_exn = {}
        updates = []
        for r in jobs_to_check:
            log.debug(f"Job in DB before update: %s", tuple(r))
            execn_ongoing_jobs = ongoing_jobs_by_exn.setdefault(r['exec_id'], [])
            if str(r['job_id']) in statii:
                # statii contains jobs which are still going (from squeue)
                slstatus, runtime = statii[str(r['job_id'])]
            else:
                # These jobs have finished (successfully or otherwise)
                _, runtime, slstatus = slurm_job_status(r['job_id'])

            finished = slstatus in STATES_FINISHED
            if not finished:
                short_state = STATE_ABBREVS.get(slstatus, slstatus)
                execn_ongoing_jobs.append(f"{short_state}-{runtime}")
            updates.append((finished, runtime, slstatus, r['job_id']))

        with time_db_transaction(self.job_db, 'Update jobs'):
            self.job_db.executemany(
                "UPDATE slurm_jobs SET finished=?, elapsed=?, status=? WHERE job_id = ?",
            )

        return ongoing_jobs_by_exn

    def process_request_still_going(self, req_id, running_jobs_info):
        """Send myMdC updates for a request with jobs still running/pending"""
        mymdc_id, action = self.job_db.execute(
            "SELECT mymdc_id, action FROM requests WHERE req_id = ?",
            (req_id,)
        ).fetchone()

        if all(s.startswith('PD-') for s in running_jobs_info):
            # Avoid swamping myMdC with updates for jobs still pending.
            log.debug("No update for %s request with mymdc id %s: jobs pending",
                      action, mymdc_id)
            return

        msg = "\n".join(running_jobs_info)
        log.debug("Update MDC for %s, %s: %s",
                  action, mymdc_id, ', '.join(running_jobs_info)
                  )

        if action == 'CORRECT':
            self.mymdc_update_run(mymdc_id, msg)
        else:  # action == 'DARK'
            self.mymdc_update_dark(mymdc_id, msg)

    def process_execution_finished(self, exec_id):
        """Send notification & record that one execution has finished"""
        statuses = [r[0] for r in self.job_db.execute(
            "SELECT status FROM slurm_jobs WHERE exec_id = ?", (exec_id,)
        ).fetchall()]
        success = set(statuses) == {'COMPLETED'}
        r = self.job_db.execute(
            "SELECT det_type, karabo_id, command, "
            "req_id, proposal, run, action, mymdc_id, timestamp "
            "FROM executions JOIN requests USING (req_id)"
            "WHERE exec_id = ?",
            (exec_id,)
        ).fetchone()
        with time_db_transaction(self.job_db, 'Update execution'):
            self.job_db.execute(
                "UPDATE executions SET success = ? WHERE exec_id = ?",
                (success, exec_id)
            )
        log.info("Execution finished: %s for (p%s, r%s, %s), success=%s",
                 r['action'], r['proposal'], r['run'], r['karabo_id'], success)
        if r['action'] == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'correction_complete',
                    'proposal': r['proposal'],
                    'detector': r['det_type'],
                    'detector_identifier': r['karabo_id'],
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

            self.record_correction_report(
                r['mymdc_id'], r['command'], r['karabo_id'], success, r['timestamp']
            )
        return r['req_id']

    def process_request_finished(self, req_id):
        """Send Kafka notifications and update myMDC that a request has finished."""
        execs = self.job_db.execute(
            "SELECT karabo_id, success, command FROM executions WHERE req_id = ?",
            (req_id,)).fetchall()

        # both dicts will be structured as {karabo_id: {message: [Job IDs]}}
        warnings = {}
        errors = {}
        krb_ids_failed = []
        krb_ids_success = []

        for karabo_id, exec_success, command in execs:

            # Get status for all jobs in this execution
            job_statuses = self.job_db.execute(
                "SELECT job_id, status FROM slurm_jobs WHERE exec_id IN "
                "(SELECT exec_id FROM executions WHERE karabo_id = ? AND req_id = ?)",
                (karabo_id, req_id)
            ).fetchall()

            # Look at logs and check if there are warnings/errors
            report_dir = get_report_dir(command)
            if not report_dir:
                log.error(f"Could not get report directory from command: {command}")
                continue
            report_dir = Path(report_dir)
            if not report_dir.exists():
                log.error(f"Report directory does not exist: {report_dir}")
                continue

            if exec_success:
                krb_ids_success.append(karabo_id)
            else:
                krb_ids_failed.append(karabo_id)

            for job_id, status in job_statuses:
                if status == "COMPLETED":
                    continue

                if not exec_success:  # no errors expected for successful execution.
                    karabo_id_err = errors.setdefault(karabo_id, {})

                    if status == "FAILED":  # process error logs
                        error_file = report_dir / f"errors_{job_id}.log"
                        process_log_file(
                            job_id, karabo_id, report_dir, karabo_id_err, error_file)
                        if len(karabo_id_err) == 0:
                            log.warning(f"Job {job_id} failed but no error log/messages found.")
                            karabo_id_err.setdefault(
                                "Job failed but no error logs found", []).append(job_id)
                    else:  # Job unsuccessful with a status other than `FAILED`
                        karabo_id_err.setdefault(
                            f"SLURM job terminated with status: {status}", []).append(job_id)

                # Process warning logs
                warning_file = report_dir / f"warnings_{job_id}.log"
                process_log_file(
                    job_id, karabo_id, report_dir,
                    warnings.setdefault(karabo_id, {}), warning_file)

        success = (not krb_ids_failed)

        r = self.job_db.execute(
            "SELECT * FROM requests WHERE req_id = ?", (req_id,)
        ).fetchone()
        log.info(
            "Jobs finished - action: %s, myMdC id: %s, success: %s",
            r['action'], r['mymdc_id'], success,
        )
        if r['action'] == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'run_corrections_complete',
                    'proposal': r['proposal'],
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

        if success:
            if warnings:
                msg = ("Calibration jobs succeeded with warnings: "
                       f"{json.dumps(format_log_message(warnings), indent=4)}")
            else:
                msg = "Calibration jobs succeeded"
        else:
            # List success & failure by karabo_id
            ok = ', '.join(sorted(krb_ids_success)) if krb_ids_success else 'none'
            msg = (
                f"Succeeded: {ok}; Failed: {', '.join(sorted(krb_ids_failed))}:"
                f" {json.dumps(format_log_message(errors), indent=4)}"
            )
        log.debug("Update MDC for %s, %s: %s", r['action'], r['mymdc_id'], msg)
        if r['action'] == 'CORRECT':
            if success:
                status = 'A'  # Available
            elif (krb_ids_success and krb_ids_failed) or warnings:
                # TODO: do we keep only an error when one detector fails?
                status = 'AW'  # Available with Warning (failed for some detectors)
            else:
                status = 'E'  # Error
            self.mymdc_update_run(r['mymdc_id'], msg, status)
        else:  # r['action'] == 'DARK'
            # TODO: add warning when available at myMDC.
            status = 'F' if success else 'E'  # Finished/Error
            self.mymdc_update_dark(r['mymdc_id'], msg, status)

    def mymdc_update_run(self, run_id, msg, status='IP'):
        """Update correction status in MyMdC"""
        data = {'cal_pipeline_reply': msg, 'flg_cal_data_status': status}
        if status != 'IP':
            data['cal_last_end_at'] = datetime.now(tz=timezone.utc).isoformat()
        response = self.mdc.update_run_api(run_id, data)
        if response.status_code != 200:
            log.error("Failed to update MDC run id %s", run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def mymdc_update_dark(self, dark_run_id, msg, status='IP'):
        """Update dark run status in MyMdC"""
        data = {'dark_run': {'flg_status': status,
                             'calcat_feedback': msg}}
        response = self.mdc.update_dark_run_api(dark_run_id, data)

        if response.status_code != 200:
            log.error("Failed to update MDC dark run id %s", dark_run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def record_correction_report(
            self, mymdc_run_id, command, karabo_id, success, request_time: str
    ):
        """Add report to MyMdC when a correction execution has finished"""
        report_dir = get_report_dir(command)
        if report_dir is None:
            return
        report_path = report_dir + ".pdf"
        if not os.path.isfile(report_path):
            log.error("Jobs finished, but report file %s missing", report_path)
            return

        desc = f"{karabo_id} detector corrections"
        if not success:
            desc += " (errors occurred)"

        log.debug("Adding report file %s to MDC for run ID %s",
                  report_path, mymdc_run_id)

        response = self.mdc.create_report_api({
            "name": os.path.basename(report_path),
            "cal_report_path": os.path.dirname(report_path).rstrip('/') + '/',
            "cal_report_at": request_time,
            "run_id": mymdc_run_id,
            "description": desc,
        })

        if response.status_code >= 400:
            log.error("Failed to add report to MDC for run ID %s: HTTP status %s",
                      mymdc_run_id, response.status_code)


def interrupted(signum, frame):
    raise KeyboardInterrupt

def main(argv=None):
    # Ensure files are opened as UTF-8 by default, regardless of environment.
    locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8'))

    parser = argparse.ArgumentParser(
        description='Start the calibration job monitor'
    )
    parser.add_argument('--config-file', type=str, default=None)
    parser.add_argument('--log-file', type=str, default='./monitor.log')
    parser.add_argument(
        '--log-level', type=str, default="INFO", choices=['INFO', 'DEBUG', 'ERROR'] # noqa
    )
    args = parser.parse_args(argv)

    if args.config_file is not None:
        config.configure(includes_for_dynaconf=[Path(args.config_file).absolute()])

    fmt = '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] %(message)s' # noqa
    logging.basicConfig(
        filename=args.log_file,
        level=getattr(logging, args.log_level),
        format=fmt
    )
    # Also log to the journal (via stderr), which keeps its own timestamps
    streamhandler = logging.StreamHandler()
    streamhandler.setFormatter(logging.Formatter(
        '%(name)s - %(levelname)s - [%(filename)s:%(lineno)d] %(message)s'
    ))
    logging.getLogger().addHandler(streamhandler)
    # DEBUG logs from kafka-python are very verbose, so we'll turn them off
    logging.getLogger('kafka').setLevel(logging.INFO)
    # Likewise requests_oauthlib
    logging.getLogger('requests_oauthlib').setLevel(logging.INFO)

    # Treat SIGTERM like SIGINT (Ctrl-C) & do a clean shutdown
    signal.signal(signal.SIGTERM, interrupted)

    with JobsMonitor(config) as jm:
        try:
            jm.run()
        except KeyboardInterrupt:
            logging.info("Shutting down on SIGINT/SIGTERM")


if __name__ == "__main__":
    main()