"""Monitor calibration jobs in Slurm and send status updates"""
import argparse
import json
import locale
import logging
import os.path
import shlex
import signal
import time
from bisect import insort_left
from collections import defaultdict, deque
from datetime import datetime, timezone
from pathlib import Path
from subprocess import run, PIPE

from kafka import KafkaProducer
from kafka.errors import KafkaError

try:
    from .common import notify_ready, file_and_stderr_logs, elapsed_to_seconds
    from .config import webservice as config
    from .messages import MDC, Errors, MigrationError, Success
    from .webservice import init_job_db, init_md_client, time_db_transaction
except ImportError:
    from common import notify_ready, file_and_stderr_logs, elapsed_to_seconds
    from config import webservice as config
    from messages import MDC, Errors, MigrationError, Success
    from webservice import init_job_db, init_md_client, time_db_transaction

log = logging.getLogger(__name__)

STATES_FINISHED = {  # https://slurm.schedmd.com/squeue.html#lbAG
    'BOOT_FAIL',
    'CANCELLED',
    'COMPLETED',
    'DEADLINE',
    'FAILED',
    'OUT_OF_MEMORY',
    'SPECIAL_EXIT',
    'TIMEOUT',
    'NA',  # Unknown (used internally if job ID missing)
}
STATES_FAILED = {'FAILED'}
STATE_ABBREVS = {
    'PENDING': 'PD',
    'RUNNING': 'R',
}


class ExpiringEvents:
    """Track events occurring within a time window.

    Event times are kept sorted in a deque; every call to :meth:`add`
    discards the entries older than ``window`` seconds relative to the
    current ``time.monotonic()`` reading.

    Args:
        window (int, optional): Time window in seconds, 600 by default.
    """

    def __init__(self, window=600):
        self.window = window
        self.events = deque()

    def add(self, event=None):
        """Record an event and return how many events remain in the window.

        Args:
            event (float, optional): Timestamp of the event on the
                ``time.monotonic()`` scale. Defaults to the current time.

        Returns:
            int: Number of stored events after expired ones were dropped.
        """
        now = time.monotonic()
        # Test against None explicitly so a timestamp of 0 is not replaced.
        insort_left(self.events, now if event is None else event)

        cutoff = now - self.window
        while self.events and self.events[0] < cutoff:
            self.events.popleft()

        return len(self.events)

    def __len__(self):
        return len(self.events)


class NoOpProducer:
    """Fills in for Kafka producer object when setting that up fails"""

    def send(self, topic, value):
        """Discard the message."""
        pass

    def close(self, timeout=None):
        """Do nothing; present because JobsMonitor.__exit__ calls
        ``close(timeout=5)`` on whichever producer object is in use.
        """
        pass


def init_kafka_producer(config):
    """Create the Kafka producer used for notifications.

    :param config: configuration object; ``config['kafka']['producer-config']``
        is expanded into keyword arguments for ``KafkaProducer``.
    :return: a ``KafkaProducer`` serialising values as UTF-8 JSON, or a
        ``NoOpProducer`` if creating the producer raised ``KafkaError``.
    """
    try:
        kp = KafkaProducer(
            value_serializer=lambda d: json.dumps(d).encode('utf-8'),
            max_block_ms=2000,  # Don't get stuck trying to send Kafka messages
            **config['kafka']['producer-config'].to_dict()
        )
    except KafkaError:
        log.warning("Problem initialising Kafka producer; "
                    "Kafka notifications will not be sent.", exc_info=True)
        return NoOpProducer()
    else:
        log.info("Connected to Kafka broker (%s) to send notifications",
                 kp.config['bootstrap_servers'])
        return kp


def slurm_status(filter_user=True):
    """ Return the status of slurm jobs by calling squeue

    :param filter_user: set to true to filter only jobs from current user
    :return: a dictionary indexed by slurm jobid (as a string) and containing
             a tuple of (status, run time, hostname) as values, or None if
             running squeue failed.
    """
    cmd = ["squeue", "--states=all", "--format=%i %T %M %N"]
    if filter_user:
        cmd += ["--me"]
    res = run(cmd, stdout=PIPE, stderr=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().rstrip("\n").split("\n")
        statii = {}
        # The first line of the output is skipped (column headers).
        for r in rlines[1:]:
            try:
                # Split on single spaces (not arbitrary whitespace) so that an
                # empty trailing field still yields four values.
                jobid, status, runtime, hostname = r.split(' ')
                jobid = jobid.strip()
                statii[jobid] = status, runtime, hostname
            except ValueError:  # not enough values to unpack in split
                log.warning("Could not parse squeue status line %r", r)
        return statii
    else:
        log.error("Running squeue failed. stdout: %r, stderr: %r",
                  res.stdout.decode(), res.stderr.decode())
        return None


def slurm_job_status(jobid):
    """ Return the status of slurm job

    :param jobid: Slurm job Id
    :return: three values in the order requested from sacct
             (JobID, Elapsed, State), with any '+' characters removed;
             ("NA", "NA", "NA") if sacct failed or the third line of its
             output does not contain exactly three fields.
    """
    cmd = ["sacct", "-j", str(jobid), "--format=JobID,Elapsed,state"]

    res = run(cmd, stdout=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")
        # The job's own record is read from the third output line; guard
        # against shorter output instead of raising IndexError.
        if len(rlines) > 2:
            log.debug("Job %s state %s", jobid, rlines[2].split())
            if len(rlines[2].split()) == 3:
                return rlines[2].replace("+", "").split()
    return "NA", "NA", "NA"


class JobsMonitor:
    """Poll Slurm for unfinished jobs recorded in the job database, store
    their new state, and send progress/completion updates to myMdC and Kafka.

    Usable as a context manager: leaving the ``with`` block closes the job
    database and the Kafka producer.
    """

    def __init__(self, config):
        log.info("Starting jobs monitor from %s", Path(__file__).parent)
        self.job_db = init_job_db(config)
        self.mdc = init_md_client(config)
        self.kafka_prod = init_kafka_producer(config)
        self.kafka_topic = config['kafka']['topic']
        self.time_interval = int(config['web-service']['job-update-interval'])
        # Per-hostname record of jobs which failed almost immediately
        self.instant_fails_by_host = defaultdict(ExpiringEvents)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.job_db.close()
        self.kafka_prod.close(timeout=5)

    def run(self):
        """Run do_updates() forever, sleeping ``time_interval`` seconds
        between rounds. Exceptions from a round are logged, not propagated.
        """
        while True:
            try:
                self.do_updates()
            except Exception:
                log.error("Failure to update job DB", exc_info=True)
            time.sleep(self.time_interval)

    def do_updates(self):
        """Perform one round: refresh job states, then notify about
        executions and requests which are still going or have just finished.
        """
        ongoing_jobs_by_exn = self.get_updates_by_exec_id()

        # For executions still running, regroup the jobs by request
        # (by run, for correction requests):
        reqs_still_going = {}
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if running_jobs_info:
                req_id = self.job_db.execute(
                    "SELECT req_id FROM executions WHERE exec_id = ?", (exec_id,)
                ).fetchone()[0]
                reqs_still_going.setdefault(req_id, []).extend(running_jobs_info)

        # For executions that have finished, send out notifications, and
        # check if the whole request (several executions) has finished.
        reqs_finished = set()
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if not running_jobs_info:
                req_id = self.process_execution_finished(exec_id)

                if req_id not in reqs_still_going:
                    reqs_finished.add(req_id)

        # Now send updates for all requests which hadn't already finished
        # by the last time this ran:

        for req_id, running_jobs_info in reqs_still_going.items():
            self.process_request_still_going(req_id, running_jobs_info)

        for req_id in reqs_finished:
            self.process_request_finished(req_id)

    def get_updates_by_exec_id(self) -> dict:
        """Get statuses of unfinished jobs, grouped by execution ID

        E.g. {12345: ['R-5:41', 'PD-0:00', ...]}

        Newly completed executions are present with an empty list.

        Also writes the refreshed finished flag, elapsed time, status and
        hostname of every checked job back to the ``slurm_jobs`` table.
        Returns an empty dict if there is nothing to check or squeue failed.
        """
        jobs_to_check = self.job_db.execute(
            "SELECT job_id, exec_id, hostname FROM slurm_jobs WHERE finished = 0"
        ).fetchall()
        if not jobs_to_check:
            log.debug("No unfinished jobs to check")
            return {}

        statii = slurm_status()
        # Check that slurm is giving proper feedback
        if statii is None:
            return {}
        log.debug("SLURM info %s", statii)

        ongoing_jobs_by_exn = {}
        updates = []
        for r in jobs_to_check:
            log.debug("Job in DB before update: %s", tuple(r))
            execn_ongoing_jobs = ongoing_jobs_by_exn.setdefault(r['exec_id'], [])

            if str(r['job_id']) in statii:
                # Jobs which are pending, running, or recently finished (from squeue)
                # Jobs stay for >= 150s after finishing, so we should always see them.
                slstatus, runtime, hostname = statii[str(r['job_id'])]
            else:
                # Fallback: get job info from sacct
                _, runtime, slstatus = slurm_job_status(r['job_id'])
                # We *don't* take hostname from sacct; with some GPFS issues
                # the job may not get recorded, and we don't want to overwrite
                # a hostname we previously got from squeue.
                hostname = r['hostname']

            finished = slstatus in STATES_FINISHED
            if not finished:
                short_state = STATE_ABBREVS.get(slstatus, slstatus)
                execn_ongoing_jobs.append(f"{short_state}-{runtime}")
            elif slstatus in STATES_FAILED and elapsed_to_seconds(runtime) < 2:
                # Specific branch to catch potentially broken nodes.
                num_fails = self.instant_fails_by_host[hostname].add()
                log.warning("Job %s failed instantly on %s, %s on this host "
                            "within the last 10 minutes",
                            r['job_id'], hostname, num_fails)

            updates.append((finished, runtime, slstatus, hostname, r['job_id']))

        with time_db_transaction(self.job_db, 'Update jobs'):
            self.job_db.executemany(
                "UPDATE slurm_jobs SET "
                "finished=?, elapsed=?, status=?, hostname=? WHERE job_id = ?",
                updates
            )

        return ongoing_jobs_by_exn

    def process_request_still_going(self, req_id, running_jobs_info):
        """Send myMdC updates for a request with jobs still running/pending"""
        mymdc_id, action = self.job_db.execute(
            "SELECT mymdc_id, action FROM requests WHERE req_id = ?",
            (req_id,)
        ).fetchone()

        if all(s.startswith('PD-') for s in running_jobs_info):
            # Avoid swamping myMdC with updates for jobs still pending.
            log.debug("No update for %s request with mymdc id %s: jobs pending",
                      action, mymdc_id)
            return

        msg = "\n".join(running_jobs_info)
        log.debug("Update MDC for %s, %s: %s",
                  action, mymdc_id, ', '.join(running_jobs_info)
                  )

        if action == 'CORRECT':
            self.mymdc_update_run(mymdc_id, msg)
        else:  # action == 'DARK'
            self.mymdc_update_dark(mymdc_id, msg)

    def process_execution_finished(self, exec_id):
        """Send notification & record that one execution has finished

        The execution counts as successful only if every one of its Slurm
        jobs has status 'COMPLETED'. Returns the ID of the request that the
        execution belongs to.
        """
        statuses = [r[0] for r in self.job_db.execute(
            "SELECT status FROM slurm_jobs WHERE exec_id = ?", (exec_id,)
        ).fetchall()]
        success = set(statuses) == {'COMPLETED'}
        # NOTE(review): the concatenated query has no whitespace between
        # "(req_id)" and "WHERE"; left byte-identical here - confirm the
        # database accepts it before tidying.
        r = self.job_db.execute(
            "SELECT det_type, karabo_id, command, "
            "req_id, proposal, run, action, mymdc_id, timestamp "
            "FROM executions JOIN requests USING (req_id)"
            "WHERE exec_id = ?",
            (exec_id,)
        ).fetchone()
        with time_db_transaction(self.job_db, 'Update execution'):
            self.job_db.execute(
                "UPDATE executions SET success = ? WHERE exec_id = ?",
                (success, exec_id)
            )
        log.info("Execution finished: %s for (p%s, r%s, %s), success=%s",
                 r['action'], r['proposal'], r['run'], r['karabo_id'], success)

        if r['action'] == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'correction_complete',
                    'proposal': r['proposal'],
                    'run': str(r['run']),
                    'detector': r['det_type'],
                    'detector_identifier': r['karabo_id'],
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

            self.record_correction_report(
                r['mymdc_id'], r['command'], r['karabo_id'], success, r['timestamp']
            )

        return r['req_id']

    def process_request_finished(self, req_id):
        """Send Kafka notifications and update myMDC that a request has finished."""
        krb_id_successes = {r[0]: r[1] for r in self.job_db.execute(
            "SELECT karabo_id, success FROM executions WHERE req_id = ?",
            (req_id,)
        ).fetchall()}
        # The request succeeded only if every execution is marked success = 1
        success = (set(krb_id_successes.values()) == {1})

        r = self.job_db.execute(
            "SELECT * FROM requests WHERE req_id = ?", (req_id,)
        ).fetchone()
        log.info(
            "Jobs finished - action: %s, myMdC id: %s, success: %s",
            r['action'], r['mymdc_id'], success,
        )
        if r['action'] == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'run_corrections_complete',
                    'proposal': r['proposal'],
                    'run': str(r['run']),
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

        if success:
            msg = "Calibration jobs succeeded"
        else:
            # List success & failure by karabo_id
            krb_ids_ok = {k for (k, v) in krb_id_successes.items() if v == 1}
            ok = ', '.join(sorted(krb_ids_ok)) if krb_ids_ok else 'none'
            krb_ids_failed = {k for (k, v) in krb_id_successes.items() if v == 0}
            msg = f"Succeeded: {ok}; Failed: {', '.join(sorted(krb_ids_failed))}"

        log.debug("Update MDC for %s, %s: %s", r['action'], r['mymdc_id'], msg)

        if r['action'] == 'CORRECT':
            if success:
                status = 'A'  # Available
            elif set(krb_id_successes.values()) == {0, 1}:
                status = 'AW'  # Available with Warning (failed for some detectors)
            else:
                status = 'E'  # Error
            self.mymdc_update_run(r['mymdc_id'], msg, status)
        else:  # r['action'] == 'DARK'
            status = 'F' if success else 'E'  # Finished/Error
            self.mymdc_update_dark(r['mymdc_id'], msg, status)

    def mymdc_update_run(self, run_id, msg, status='IP'):
        """Update correction status in MyMdC

        An end timestamp (current UTC time) is included whenever the status
        is anything other than the default 'IP'. A response code other than
        200 is logged as an error.
        """
        data = {'cal_pipeline_reply': msg, 'flg_cal_data_status': status}
        if status != 'IP':
            data['cal_last_end_at'] = datetime.now(tz=timezone.utc).isoformat()
        response = self.mdc.update_run_api(run_id, data)
        if response.status_code != 200:
            log.error("Failed to update MDC run id %s", run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def mymdc_update_dark(self, dark_run_id, msg, status='IP'):
        """Update dark run status in MyMdC

        A response code other than 200 is logged as an error.
        """
        data = {'dark_run': {'flg_status': status,
                             'calcat_feedback': msg}}
        response = self.mdc.update_dark_run_api(dark_run_id, data)
        if response.status_code != 200:
            log.error("Failed to update MDC dark run id %s", dark_run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def record_correction_report(
            self, mymdc_run_id, command, karabo_id, success, request_time: str
    ):
        """Add report to MyMdC when a correction execution has finished

        The report path is the argument following ``--report-to`` in
        *command*, with ".pdf" appended. Nothing is sent (an error is logged
        instead) if that option is absent or the file does not exist.
        """
        args = shlex.split(command)
        try:
            report_path = args[args.index("--report-to") + 1] + ".pdf"
        except (ValueError, IndexError):
            log.error("Couldn't find report path in %r", args)
            return

        if not os.path.isfile(report_path):
            log.error("Jobs finished, but report file %s missing", report_path)
            return

        desc = f"{karabo_id} detector corrections"
        if not success:
            desc += " (errors occurred)"

        log.debug("Adding report file %s to MDC for run ID %s",
                  report_path, mymdc_run_id)

        response = self.mdc.create_report_api({
            "name": os.path.basename(report_path),
            "cal_report_path": os.path.dirname(report_path).rstrip('/') + '/',
            "cal_report_at": request_time,
            "run_id": mymdc_run_id,
            "description": desc,
        })

        if response.status_code >= 400:
            log.error("Failed to add report to MDC for run ID %s: HTTP status %s",
                      mymdc_run_id, response.status_code)


def interrupted(signum, frame):
    """Signal handler: turn the signal into a KeyboardInterrupt."""
    raise KeyboardInterrupt


def main(argv=None):
    """Parse command line options, set up logging and run the jobs monitor
    until it is interrupted.

    :param argv: list of command line arguments; None makes argparse read
        them from ``sys.argv``.
    """
    # Ensure files are opened as UTF-8 by default, regardless of environment.
    locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8'))

    parser = argparse.ArgumentParser(
        description='Start the calibration webservice'
    )
    parser.add_argument('--config-file', type=str, default=None)
    parser.add_argument('--log-file', type=Path, default='~/webservice-logs/monitor.log')
    parser.add_argument(
        '--log-level', type=str, default="INFO", choices=['INFO', 'DEBUG', 'ERROR']  # noqa
    )
    args = parser.parse_args(argv)

    if args.config_file is not None:
        config.configure(includes_for_dynaconf=[Path(args.config_file).absolute()])

    logging.basicConfig(
        handlers=file_and_stderr_logs(args.log_file),
        level=args.log_level,
    )
    # DEBUG logs from kafka-python are very verbose, so we'll turn them off
    logging.getLogger('kafka').setLevel(logging.INFO)
    # Likewise requests_oauthlib
    logging.getLogger('requests_oauthlib').setLevel(logging.INFO)

    # Treat SIGTERM like SIGINT (Ctrl-C) & do a clean shutdown
    signal.signal(signal.SIGTERM, interrupted)

    with JobsMonitor(config) as jm:
        notify_ready()
        try:
            jm.run()
        except KeyboardInterrupt:
            log.info("Shutting down on SIGINT/SIGTERM")


if __name__ == "__main__":
    main()