"""Monitor calibration jobs in Slurm and send status updates"""
import argparse
import json
import locale
import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from subprocess import run, PIPE
from kafka import KafkaProducer
from kafka.errors import KafkaError
try:
from .config import webservice as config
from .messages import MDC, Errors, MigrationError, Success
from .webservice import init_job_db, init_md_client
except ImportError:
from config import webservice as config
from messages import MDC, Errors, MigrationError, Success
from webservice import init_job_db, init_md_client
log = logging.getLogger(__name__)


class NoOpProducer:
    """Fills in for Kafka producer object when setting that up fails"""
    def send(self, topic, value):
        pass


def init_kafka_producer(config):
    try:
        return KafkaProducer(
            bootstrap_servers=config['kafka']['brokers'],
            value_serializer=lambda d: json.dumps(d).encode('utf-8'),
            max_block_ms=2000,  # Don't get stuck trying to send Kafka messages
        )
    except KafkaError:
        log.warning("Problem initialising Kafka producer; "
                    "Kafka notifications will not be sent.", exc_info=True)
        return NoOpProducer()


def slurm_status(filter_user=True):
    """Return the status of Slurm jobs by calling squeue

    :param filter_user: set to True to list only jobs from the current user
    :return: a dictionary indexed by Slurm job ID and containing a tuple
             of (status, run time) as values, or None if squeue failed
    """
    cmd = ["squeue"]
    if filter_user:
        cmd.append("--me")  # restrict the listing to the current user's jobs
    res = run(cmd, stdout=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")
        statii = {}
        for r in rlines[1:]:
            try:
                jobid, _, _, _, status, runtime, _, _ = r.split()
                jobid = jobid.strip()
                statii[jobid] = status, runtime
            except ValueError:  # not enough values to unpack in split
                pass
        return statii
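
# For reference, the default `squeue` listing has a header row followed by one
# line per job with 8 columns (JOBID, PARTITION, NAME, USER, ST, TIME, NODES,
# NODELIST(REASON)), which is what the 8-way unpack above relies on, e.g.:
#   1234567   exfel  correct_p  someuser  R   5:41   1  some-node
# giving statii == {'1234567': ('R', '5:41')}. The sample values are purely
# illustrative.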


def slurm_job_status(jobid):
    """Return the status of a single Slurm job by calling sacct

    :param jobid: Slurm job ID
    :return: a (job ID, elapsed time, state) tuple, or ("NA", "NA", "NA") if
             sacct gave no usable output
    """
    cmd = ["sacct", "-j", str(jobid), "--format=JobID,Elapsed,state"]
    res = run(cmd, stdout=PIPE)
    if res.returncode == 0:
        rlines = res.stdout.decode().split("\n")
        log.debug("Job {} state {}".format(jobid, rlines[2].split()))
        if len(rlines[2].split()) == 3:
            # sacct marks truncated fields with a trailing '+' (e.g. 'CANCELLED+')
            return rlines[2].replace("+", "").split()
    return "NA", "NA", "NA"


class JobsMonitor:
    def __init__(self, config):
        log.info("Starting jobs monitor")
        self.job_db = init_job_db(config)
        self.mdc = init_md_client(config)
        self.kafka_prod = init_kafka_producer(config)
        self.kafka_topic = config['kafka']['topic']
        self.time_interval = int(config['web-service']['job-update-interval'])

    def run(self):
        while True:
            try:
                self.do_updates()
            except Exception:
                log.error("Failure to update job DB", exc_info=True)
            time.sleep(self.time_interval)

    def do_updates(self):
        ongoing_jobs_by_exn = self.get_updates_by_exec_id()
        # ^ dict grouping statuses of unfinished jobs by execution ID:
        #   {12345: ['R-5:41', 'PD-0:00', ...]}
        # Newly completed executions are present with an empty list.

        # For executions still running, regroup the jobs by request
        # (by run, for correction requests):
        reqs_still_going = {}
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if running_jobs_info:
                req_id = self.job_db.execute(
                    "SELECT req_id FROM executions WHERE exec_id = ?", (exec_id,)
                ).fetchone().req_id
                reqs_still_going.setdefault(req_id, []).extend(running_jobs_info)

        # For executions that have finished, send out notifications, and
        # check if the whole request (several executions) has finished.
        reqs_finished = set()
        for exec_id, running_jobs_info in ongoing_jobs_by_exn.items():
            if not running_jobs_info:
                req_id = self.process_execution_finished(exec_id)
                if req_id not in reqs_still_going:
                    reqs_finished.add(req_id)

        # Now send updates for all requests which hadn't already finished
        # by the last time this ran:
        for req_id, running_jobs_info in reqs_still_going.items():
            self.process_request_still_going(req_id, running_jobs_info)

        for req_id in reqs_finished:
            self.process_request_finished(req_id)

    def get_updates_by_exec_id(self):
        c = self.job_db.cursor()
        c.execute("SELECT job_id, exec_id FROM slurm_jobs WHERE finished = 0")
        statii = slurm_status()
        # Check that slurm is giving proper feedback
        if statii is None:
            return {}
        log.debug(f"SLURM info {statii}")

        ongoing_jobs_by_exn = {}
        for r in c.fetchall():
            log.debug(f"DB info {r}")
            execn_ongoing_jobs = ongoing_jobs_by_exn.setdefault(r.exec_id, [])

            if r.job_id in statii:
                # statii contains jobs which are still going (from squeue)
                slstatus, runtime = statii[r.job_id]
                finished = False
                execn_ongoing_jobs.append(f"{slstatus}-{runtime}")
            else:
                # These jobs have finished (successfully or otherwise)
                _, runtime, slstatus = slurm_job_status(r.job_id)
                finished = True

            c.execute(
                "UPDATE slurm_jobs SET finished=?, elapsed=?, status=? WHERE job_id = ?",
                (finished, runtime, slstatus, r.job_id)
            )

        self.job_db.commit()
        return ongoing_jobs_by_exn

    def process_request_still_going(self, req_id, running_jobs_info):
        """Send myMdC updates for a request with jobs still running/pending"""
        mymdc_id, action = self.job_db.execute(
            "SELECT mymdc_id, action FROM requests WHERE req_id = ?",
            (req_id,)
        ).fetchone()

        if all(s.startswith('PD-') for s in running_jobs_info):
            # Avoid swamping myMdC with updates for jobs still pending.
            log.debug("No update for %s request with mymdc id %s: jobs pending",
                      action, mymdc_id)
            return

        msg = "\n".join(running_jobs_info)
        log.debug("Update MDC for %s, %s: %s",
                  action, mymdc_id, ', '.join(running_jobs_info))

        if action == 'CORRECT':
            self.mymdc_update_run(mymdc_id, msg)
        else:  # action == 'DARK'
            self.mymdc_update_dark(mymdc_id, msg)

    def process_execution_finished(self, exec_id):
        """Send notification that one execution has finished"""
        statuses = [r.status for r in self.job_db.execute(
            "SELECT status FROM slurm_jobs WHERE exec_id = ?", (exec_id,)
        ).fetchall()]
        success = set(statuses) == {'COMPLETED'}
        r = self.job_db.execute(
            "SELECT det_type, karabo_id, req_id, proposal, run, action "
            "FROM executions JOIN requests USING (req_id) "
            "WHERE exec_id = ?",
            (exec_id,)
        ).fetchone()
        with self.job_db:
            self.job_db.execute(
                "UPDATE executions SET success = ? WHERE exec_id = ?",
                (success, exec_id)
            )
        log.info("Execution finished: %s for (p%s, r%s, %s), success=%s",
                 r.action, r.proposal, r.run, r.karabo_id, success)

        if r.action == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'correction_complete',
                    'proposal': r.proposal,
                    'run': r.run,
                    'detector': r.det_type,
                    'karabo_id': r.karabo_id,
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

        return r.req_id

    def process_request_finished(self, req_id):
        """Send notifications that a request has finished"""
        exec_successes = {r.success for r in self.job_db.execute(
            "SELECT success FROM executions WHERE req_id = ?", (req_id,)
        ).fetchall()}
        success = (exec_successes == {1})

        r = self.job_db.execute(
            "SELECT * FROM requests WHERE req_id = ?", (req_id,)
        ).fetchone()
        log.info(
            "Jobs finished - action: %s, myMdC id: %s, success: %s",
            r.action, r.mymdc_id, success,
        )
        if r.action == 'CORRECT':
            try:
                self.kafka_prod.send(self.kafka_topic, {
                    'event': 'run_corrections_complete',
                    'proposal': r.proposal,
                    'run': r.run,
                    'success': success,
                })
            except KafkaError:
                log.warning("Error sending Kafka notification",
                            exc_info=True)

        if success:
            msg = "Calibration jobs succeeded"
        else:
            # Count failed jobs
            job_statuses = [job.status for job in self.job_db.execute(
                "SELECT slurm_jobs.status FROM slurm_jobs "
                "INNER JOIN executions USING (exec_id) "
                "INNER JOIN requests USING (req_id) "
                "WHERE req_id = ?", (req_id,)
            ).fetchall()]
            n_failed = sum(s != 'COMPLETED' for s in job_statuses)
            msg = f"{n_failed}/{len(job_statuses)} calibration jobs failed"

        log.debug("Update MDC for %s, %s: %s", r.action, r.mymdc_id, msg)

        if r.action == 'CORRECT':
            status = 'A' if success else 'NA'  # Not-/Available
            self.mymdc_update_run(r.mymdc_id, msg, status)
        else:  # r.action == 'DARK'
            status = 'F' if success else 'E'  # Finished/Error
            self.mymdc_update_dark(r.mymdc_id, msg, status)

    def mymdc_update_run(self, run_id, msg, status='R'):
        data = {'flg_cal_data_status': status,
                'cal_pipeline_reply': msg}
        if status != 'R':
            data['cal_last_end_at'] = datetime.now(tz=timezone.utc).isoformat()
        response = self.mdc.update_run_api(run_id, data)
        if response.status_code != 200:
            log.error("Failed to update MDC run id %s", run_id)
            log.error(Errors.MDC_RESPONSE.format(response))

    def mymdc_update_dark(self, dark_run_id, msg, status='IP'):
        data = {'dark_run': {'flg_status': status,
                             'calcat_feedback': msg}}
        response = self.mdc.update_dark_run_api(dark_run_id, data)
        if response.status_code != 200:
            log.error("Failed to update MDC dark run id %s", dark_run_id)
            log.error(Errors.MDC_RESPONSE.format(response))


def main(argv=None):
    # Ensure files are opened as UTF-8 by default, regardless of environment.
    locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8'))

    parser = argparse.ArgumentParser(
        description='Monitor Slurm calibration jobs and send status updates'
    )
    parser.add_argument('--config-file', type=str, default=None)
    parser.add_argument('--log-file', type=str, default='./monitor.log')
    parser.add_argument(
        '--log-level', type=str, default="INFO",
        choices=['INFO', 'DEBUG', 'ERROR']
    )
    args = parser.parse_args(argv)

    if args.config_file is not None:
        config.configure(includes_for_dynaconf=[Path(args.config_file).absolute()])

    fmt = '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] %(message)s'  # noqa
    logging.basicConfig(
        filename=args.log_file,
        level=getattr(logging, args.log_level),
        format=fmt,
    )
    JobsMonitor(config).run()


if __name__ == "__main__":
    main()
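
# Typical invocation, shown here as a comment for reference; the module path
# and config file name are assumptions for illustration, not taken from this
# file:
#   python -m webservice.job_monitor --config-file ./webservice.yaml --log-level DEBUG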