"""Monitor calibration jobs in Slurm and send status updates"""
import argparse
import json
import locale
import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from subprocess import run, PIPE

from kafka import KafkaProducer
from kafka.errors import KafkaError

try:
    from .config import webservice as config
    from .messages import MDC, Errors, MigrationError, Success
    from .webservice import init_job_db, init_md_client
except ImportError:
    from config import webservice as config
    from messages import MDC, Errors, MigrationError, Success
    from webservice import init_job_db, init_md_client

log = logging.getLogger(__name__)


class NoOpProducer:
    """Stand-in used when constructing the real Kafka producer fails.

    Exposes the same ``send`` interface as ``KafkaProducer`` but silently
    discards every message, so callers need no special-casing.
    """

    def send(self, topic, value):
        """Drop the message instead of publishing it to Kafka."""


def init_kafka_producer(config):
    """Build a JSON-serialising Kafka producer from ``config['kafka']``.

    :param config: webservice configuration mapping with a ``kafka`` section
    :return: a connected ``KafkaProducer``, or a ``NoOpProducer`` that drops
             messages when the broker connection cannot be set up.
    """
    brokers = config['kafka']['brokers']
    try:
        producer = KafkaProducer(
            bootstrap_servers=brokers,
            value_serializer=lambda d: json.dumps(d).encode('utf-8'),
            # Don't get stuck trying to send Kafka messages
            max_block_ms=2000,
        )
    except KafkaError:
        log.warning("Problem initialising Kafka producer; "
                    "Kafka notifications will not be sent.", exc_info=True)
        return NoOpProducer()
    return producer


def slurm_status(filter_user=True):
    """Return the status of Slurm jobs by calling squeue.

    :param filter_user: set to True to list only jobs of the current user
    :return: a dict mapping Slurm job id -> (status, run time), or None
             when squeue exits non-zero (callers treat None as "Slurm is
             not giving proper feedback").
    """
    cmd = ["squeue"]
    if filter_user:
        cmd += ["--me"]
    res = run(cmd, stdout=PIPE)
    if res.returncode != 0:
        # Explicit None (was implicit): signals squeue failure to the caller.
        return None
    statii = {}
    # The first line of squeue output is the column header; skip it.
    for line in res.stdout.decode().split("\n")[1:]:
        fields = line.split()
        # squeue's default format has exactly 8 columns; anything else
        # (blank trailing line, wrapped output) is ignored.
        if len(fields) == 8:
            jobid, _, _, _, status, runtime, _, _ = fields
            statii[jobid.strip()] = (status, runtime)
    return statii


def slurm_job_status(jobid):
    """Return the status of a single Slurm job via sacct.

    :param jobid: Slurm job Id
    :return: (JobID, Elapsed, state) parsed from sacct output, with any
             '+' stripped; ("NA", "NA", "NA") when the job is unknown or
             sacct output cannot be parsed.
    """
    cmd = ["sacct", "-j", str(jobid), "--format=JobID,Elapsed,state"]

    res = run(cmd, stdout=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")

        # rlines[0] is the header, rlines[1] the separator row; the job's
        # data line is rlines[2].  Guard the index: for an unknown job id
        # sacct may print fewer lines, which previously raised IndexError.
        if len(rlines) > 2:
            log.debug("Job {} state {}".format(jobid, rlines[2].split()))
            if len(rlines[2].split()) == 3:
                return rlines[2].replace("+", "").split()
    return "NA", "NA", "NA"


def update_job_db(config):
    """ Update the job database and send out updates to MDC

    Runs forever: every ``job-update-interval`` seconds, poll squeue/sacct
    for the state of jobs recorded in the local jobs table, persist the new
    state, and push one aggregate update per (run id, action) to myMdC and
    (for corrections) to Kafka.

    :param config: configuration parsed from webservice YAML
    """
    log.info("Starting jobs monitor")
    conn = init_job_db(config)
    mdc = init_md_client(config)
    kafka_prod = init_kafka_producer(config)
    kafka_topic = config['kafka']['topic']
    time_interval = int(config['web-service']['job-update-interval'])

    while True:
        statii = slurm_status()
        # Check that slurm is giving proper feedback
        if statii is None:
            time.sleep(time_interval)
            continue
        try:
            c = conn.cursor()
            # Only jobs still in flight: Running, PenDing or CompletinG.
            c.execute("SELECT * FROM jobs WHERE status IN ('R', 'PD', 'CG') ")
            # Group per (run id, action) so a single myMdC update covers
            # all Slurm jobs belonging to one run.
            combined = {}
            log.debug(f"SLURM info {statii}")

            for r in c.fetchall():
                # NOTE(review): 'run' unpacked here shadows subprocess.run
                # imported at module level; harmless within this loop.
                rid, jobid, proposal, run, status, _time, det, action = r
                log.debug(f"DB info {r}")

                # cflg/cstatus are the shared per-(rid, action) lists;
                # setdefault creates them on first sight of the key.
                cflg, cstatus, *_ = combined.setdefault((rid, action), (
                    [], [], proposal, run, det
                ))
                if jobid in statii:
                    # Job still visible to squeue: record status + runtime.
                    slstatus, runtime = statii[jobid]
                    query = "UPDATE jobs SET status=?, time=? WHERE jobid LIKE ?"
                    c.execute(query, (slstatus, runtime, jobid))

                    cflg.append('R')
                    cstatus.append(f"{slstatus}-{runtime}")
                else:
                    # Job left the queue: ask sacct for its final state.
                    _, sltime, slstatus = slurm_job_status(jobid)
                    query = "UPDATE jobs SET status=? WHERE jobid LIKE ?"
                    c.execute(query, (slstatus, jobid))

                    if slstatus == 'COMPLETED':
                        cflg.append("A")  # A = data Available
                    else:
                        cflg.append("NA")  # NA = Not Available (job failed)
                    cstatus.append(slstatus)
            conn.commit()

            flg_order = {"R": 2, "A": 1, "NA": 0}
            # Map the aggregate flag to myMdC dark-run status codes:
            # Error / In Progress / Finished.
            dark_flags = {'NA': 'E', 'R': 'IP', 'A': 'F'}
            for rid, action in combined:
                if int(rid) == 0:  # this job was not submitted from MyMDC
                    continue
                # NOTE(review): 'statii' is reused here; from now on it is
                # the per-run list of status strings, not the squeue dict.
                flgs, statii, proposal, run, det = combined[rid, action]
                # sort by least done status: 'R' dominates, so flg is only
                # 'A'/'NA' once no job of the run is still running.
                flg = max(flgs, key=lambda i: flg_order[i])
                if flg != 'R':
                    log.info(
                        "Jobs finished - action: %s, run id: %s, status: %s",
                        action, rid, flg,
                    )
                    if action == 'CORRECT':
                        try:
                            kafka_prod.send(kafka_topic, {
                                'event': 'correction_complete',
                                'proposal': proposal,
                                'run': run,
                                'detector': det,
                                'success': (flg == 'A'),  # A for Available
                            })
                        except KafkaError:
                            log.warning("Error sending Kafka notification",
                                            exc_info=True)

                if all(s.startswith('PD-') for s in statii):
                    # Avoid swamping myMdC with updates for jobs still pending.
                    log.debug(
                        "No update for action %s, rid %s: jobs pending",
                        action, rid
                    )
                    continue

                msg = "\n".join(statii)
                msg_debug = f"Update MDC {rid}, {msg}"
                log.debug(msg_debug.replace('\n', ', '))

                if action == 'CORRECT':
                    data = {'flg_cal_data_status': flg,
                            'cal_pipeline_reply': msg}
                    if flg != 'R':
                        # Timestamp the end of the correction (UTC, ISO 8601).
                        data['cal_last_end_at'] = datetime.now(tz=timezone.utc).isoformat()
                    response = mdc.update_run_api(rid, data)

                else:  # action == 'DARK' but it's dark_request
                    data = {'dark_run': {'flg_status': dark_flags[flg],
                                         'calcat_feedback': msg}}
                    response = mdc.update_dark_run_api(rid, data)

                if response.status_code != 200:
                    log.error("Failed to update MDC for action %s, rid %s",
                                  action, rid)
                    log.error(Errors.MDC_RESPONSE.format(response))
        except Exception:
            # Boundary catch: one bad polling cycle must not kill the
            # long-running monitor; log with traceback and retry later.
            log.error("Failure to update job DB", exc_info=True)

        time.sleep(time_interval)



def main(argv=None):
    """Entry point: parse CLI arguments, configure logging, run the monitor.

    :param argv: argument list for argparse (defaults to ``sys.argv[1:]``)
    """
    parser = argparse.ArgumentParser(
        description='Start the calibration webservice'
    )
    parser.add_argument('--config-file', type=str, default=None)
    parser.add_argument('--log-file', type=str, default='./monitor.log')
    parser.add_argument(
        '--log-level', type=str, default="INFO", choices=['INFO', 'DEBUG', 'ERROR'] # noqa
    )
    args = parser.parse_args(argv)

    if args.config_file is not None:
        config.configure(includes_for_dynaconf=[Path(args.config_file).absolute()])

    fmt = '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] %(message)s' # noqa
    logging.basicConfig(
        filename=args.log_file,
        level=getattr(logging, args.log_level),
        format=fmt
    )

    # Ensure files are opened as UTF-8 by default, regardless of environment.
    # The en_US.UTF-8 locale may be absent on minimal systems; previously
    # this raised locale.Error and killed the monitor before logging was
    # even set up, so it now runs after basicConfig and only warns.
    try:
        locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8'))
    except locale.Error:
        log.warning("Could not set LC_CTYPE to en_US.UTF-8", exc_info=True)

    update_job_db(config)


# Script entry point: run the monitor when executed directly.
if __name__ == "__main__":
    main()