"""Monitor calibration jobs in Slurm and send status updates""" import argparse import json import locale import logging import signal import time from datetime import datetime, timezone from pathlib import Path from subprocess import run, PIPE from kafka import KafkaProducer from kafka.errors import KafkaError try: from .config import webservice as config from .messages import MDC, Errors, MigrationError, Success from .webservice import init_job_db, init_md_client, time_db_transaction except ImportError: from config import webservice as config from messages import MDC, Errors, MigrationError, Success from webservice import init_job_db, init_md_client, time_db_transaction log = logging.getLogger(__name__) STATES_FINISHED = { # https://slurm.schedmd.com/squeue.html#lbAG 'BOOT_FAIL', 'CANCELLED', 'COMPLETED', 'DEADLINE', 'FAILED', 'OUT_OF_MEMORY', 'SPECIAL_EXIT', 'TIMEOUT', } class NoOpProducer: """Fills in for Kafka producer object when setting that up fails""" def send(self, topic, value): pass def init_kafka_producer(config): try: return KafkaProducer( bootstrap_servers=config['kafka']['brokers'], value_serializer=lambda d: json.dumps(d).encode('utf-8'), max_block_ms=2000, # Don't get stuck trying to send Kafka messages ) except KafkaError: log.warning("Problem initialising Kafka producer; " "Kafka notifications will not be sent.", exc_info=True) return NoOpProducer() def slurm_status(filter_user=True): """ Return the status of slurm jobs by calling squeue :param filter_user: set to true to filter ony jobs from current user :return: a dictionary indexed by slurm jobid and containing a tuple of (status, run time) as values. """ cmd = ["squeue", "--states=all"] if filter_user: cmd += ["--me"] res = run(cmd, stdout=PIPE, stderr=PIPE) if res.returncode == 0: rlines = res.stdout.decode().split("\n") statii = {} for r in rlines[1:]: try: jobid, _, _, _, status, runtime, _, _ = r.split() jobid = jobid.strip() statii[jobid] = status, runtime except ValueError: # not enough values to unpack in split pass return statii else: log.error("Running squeue failed. stdout: %r, stderr: %r", res.stdout.decode(), res.stderr.decode()) return None def slurm_job_status(jobid): """ Return the status of slurm job :param jobid: Slurm job Id :return: Slurm state, Elapsed. """ cmd = ["sacct", "-j", str(jobid), "--format=JobID,Elapsed,state"] res = run(cmd, stdout=PIPE) if res.returncode == 0: rlines = res.stdout.decode().split("\n") log.debug("Job {} state {}".format(jobid, rlines[2].split())) if len(rlines[2].split()) == 3: return rlines[2].replace("+", "").split() return "NA", "NA", "NA" class JobsMonitor: def __init__(self, config): log.info("Starting jobs monitor") self.job_db = init_job_db(config) self.mdc = init_md_client(config) self.kafka_prod = init_kafka_producer(config) self.kafka_topic = config['kafka']['topic'] self.time_interval = int(config['web-service']['job-update-interval']) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.job_db.close() self.kafka_prod.close(timeout=5) def run(self): while True: try: self.do_updates() except Exception: log.error("Failure to update job DB", exc_info=True) time.sleep(self.time_interval) def do_updates(self): ongoing_jobs_by_exn = self.get_updates_by_exec_id() # For executions still running, regroup the jobs by request # (by run, for correction requests): reqs_still_going = {} for exec_id, running_jobs_info in ongoing_jobs_by_exn.items(): if running_jobs_info: req_id = self.job_db.execute( "SELECT req_id FROM executions WHERE exec_id = ?", (exec_id,) ).fetchone()[0] reqs_still_going.setdefault(req_id, []).extend(running_jobs_info) # For executions that have finished, send out notifications, and # check if the whole request (several executions) has finished. reqs_finished = set() for exec_id, running_jobs_info in ongoing_jobs_by_exn.items(): if not running_jobs_info: req_id = self.process_execution_finished(exec_id) if req_id not in reqs_still_going: reqs_finished.add(req_id) # Now send updates for all requests which hadn't already finished # by the last time this ran: for req_id, running_jobs_info in reqs_still_going.items(): self.process_request_still_going(req_id, running_jobs_info) for req_id in reqs_finished: self.process_request_finished(req_id) def get_updates_by_exec_id(self) -> dict: """Get statuses of unfinished jobs, grouped by execution ID E.g. {12345: ['R-5:41', 'PD-0:00', ...]} Newly completed executions are present with an empty list. """ jobs_to_check = self.job_db.execute( "SELECT job_id, exec_id FROM slurm_jobs WHERE finished = 0" ).fetchall() if not jobs_to_check: log.debug("No unfinished jobs to check") return {} statii = slurm_status() # Check that slurm is giving proper feedback if statii is None: return {} log.debug(f"SLURM info {statii}") ongoing_jobs_by_exn = {} updates = [] for r in jobs_to_check: log.debug(f"Job in DB before update: %s", tuple(r)) execn_ongoing_jobs = ongoing_jobs_by_exn.setdefault(r['exec_id'], []) if str(r['job_id']) in statii: # statii contains jobs which are still going (from squeue) slstatus, runtime = statii[str(r['job_id'])] execn_ongoing_jobs.append(f"{slstatus}-{runtime}") else: # These jobs have finished (successfully or otherwise) _, runtime, slstatus = slurm_job_status(r['job_id']) finished = slstatus in STATES_FINISHED updates.append((finished, runtime, slstatus, r['job_id'])) with time_db_transaction(self.job_db, 'Update jobs'): self.job_db.executemany( "UPDATE slurm_jobs SET finished=?, elapsed=?, status=? WHERE job_id = ?", updates ) return ongoing_jobs_by_exn def process_request_still_going(self, req_id, running_jobs_info): """Send myMdC updates for a request with jobs still running/pending""" mymdc_id, action = self.job_db.execute( "SELECT mymdc_id, action FROM requests WHERE req_id = ?", (req_id,) ).fetchone() if all(s.startswith('PD-') for s in running_jobs_info): # Avoid swamping myMdC with updates for jobs still pending. log.debug("No update for %s request with mymdc id %s: jobs pending", action, mymdc_id) return msg = "\n".join(running_jobs_info) log.debug("Update MDC for %s, %s: %s", action, mymdc_id, ', '.join(running_jobs_info) ) if action == 'CORRECT': self.mymdc_update_run(mymdc_id, msg) else: # action == 'DARK' self.mymdc_update_dark(mymdc_id, msg) def process_execution_finished(self, exec_id): """Send notification & record that one execution has finished""" statuses = [r[0] for r in self.job_db.execute( "SELECT status FROM slurm_jobs WHERE exec_id = ?", (exec_id,) ).fetchall()] success = set(statuses) == {'COMPLETED'} r = self.job_db.execute( "SELECT det_type, karabo_id, req_id, proposal, run, action " "FROM executions JOIN requests USING (req_id)" "WHERE exec_id = ?", (exec_id,) ).fetchone() with time_db_transaction(self.job_db, 'Update execution'): self.job_db.execute( "UPDATE executions SET success = ? WHERE exec_id = ?", (success, exec_id) ) log.info("Execution finished: %s for (p%s, r%s, %s), success=%s", r['action'], r['proposal'], r['run'], r['karabo_id'], success) if r['action'] == 'CORRECT': try: self.kafka_prod.send(self.kafka_topic, { 'event': 'correction_complete', 'proposal': r['proposal'], 'run': str(r['run']), 'detector': r['det_type'], 'detector_identifier': r['karabo_id'], 'success': success, }) except KafkaError: log.warning("Error sending Kafka notification", exc_info=True) return r['req_id'] def process_request_finished(self, req_id): """Send Kafka notifications and update myMDC that a request has finished.""" krb_id_successes = {r[0]: r[1] for r in self.job_db.execute( "SELECT karabo_id, success FROM executions WHERE req_id = ?", (req_id,) ).fetchall()} success = (set(krb_id_successes.values()) == {1}) r = self.job_db.execute( "SELECT * FROM requests WHERE req_id = ?", (req_id,) ).fetchone() log.info( "Jobs finished - action: %s, myMdC id: %s, success: %s", r['action'], r['mymdc_id'], success, ) if r['action'] == 'CORRECT': try: self.kafka_prod.send(self.kafka_topic, { 'event': 'run_corrections_complete', 'proposal': r['proposal'], 'run': str(r['run']), 'success': success, }) except KafkaError: log.warning("Error sending Kafka notification", exc_info=True) if success: msg = "Calibration jobs succeeded" else: # List success & failure by karabo_id krb_ids_ok = {k for (k, v) in krb_id_successes.items() if v == 1} ok = ', '.join(sorted(krb_ids_ok)) if krb_ids_ok else 'none' krb_ids_failed = {k for (k, v) in krb_id_successes.items() if v == 0} msg = f"Succeeded: {ok}; Failed: {', '.join(sorted(krb_ids_failed))}" log.debug("Update MDC for %s, %s: %s", r['action'], r['mymdc_id'], msg) if r['action'] == 'CORRECT': if success: status = 'A' # Available elif set(krb_id_successes.values()) == {0, 1}: status = 'AW' # Available with Warning (failed for some detectors) else: status = 'E' # Error self.mymdc_update_run(r['mymdc_id'], msg, status) else: # r['action'] == 'DARK' status = 'F' if success else 'E' # Finished/Error self.mymdc_update_dark(r['mymdc_id'], msg, status) def mymdc_update_run(self, run_id, msg, status='IP'): """Update correction status in MyMdC""" data = {'cal_pipeline_reply': msg, 'flg_cal_data_status': status} if status != 'IP': data['cal_last_end_at'] = datetime.now(tz=timezone.utc).isoformat() response = self.mdc.update_run_api(run_id, data) if response.status_code != 200: log.error("Failed to update MDC run id %s", run_id) log.error(Errors.MDC_RESPONSE.format(response)) def mymdc_update_dark(self, dark_run_id, msg, status='IP'): """Update dark run status in MyMdC""" data = {'dark_run': {'flg_status': status, 'calcat_feedback': msg}} response = self.mdc.update_dark_run_api(dark_run_id, data) if response.status_code != 200: log.error("Failed to update MDC dark run id %s", dark_run_id) log.error(Errors.MDC_RESPONSE.format(response)) def interrupted(signum, frame): raise KeyboardInterrupt def main(argv=None): # Ensure files are opened as UTF-8 by default, regardless of environment. locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8')) parser = argparse.ArgumentParser( description='Start the calibration webservice' ) parser.add_argument('--config-file', type=str, default=None) parser.add_argument('--log-file', type=str, default='./monitor.log') parser.add_argument( '--log-level', type=str, default="INFO", choices=['INFO', 'DEBUG', 'ERROR'] # noqa ) args = parser.parse_args(argv) if args.config_file is not None: config.configure(includes_for_dynaconf=[Path(args.config_file).absolute()]) fmt = '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] %(message)s' # noqa logging.basicConfig( filename=args.log_file, level=getattr(logging, args.log_level), format=fmt ) # DEBUG logs from kafka-python are very verbose, so we'll turn them off logging.getLogger('kafka').setLevel(logging.INFO) # Treat SIGTERM like SIGINT (Ctrl-C) & do a clean shutdown signal.signal(signal.SIGTERM, interrupted) with JobsMonitor(config) as jm: try: jm.run() except KeyboardInterrupt: logging.info("Shutting down on SIGINT/SIGTERM") if __name__ == "__main__": main()