From 8bcacbe22cc0d3de7f13cd105781c9367135cd33 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Thu, 29 Jul 2021 14:52:20 +0100
Subject: [PATCH] Use afterany dependency for summary Slurm job

---
 src/xfel_calibrate/calibrate.py | 44 +++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/src/xfel_calibrate/calibrate.py b/src/xfel_calibrate/calibrate.py
index b20a40220..051bf518c 100755
--- a/src/xfel_calibrate/calibrate.py
+++ b/src/xfel_calibrate/calibrate.py
@@ -695,12 +695,13 @@ def get_slurm_partition_or_reservation(args) -> List[str]:
     return ['--partition', sprof]
 
 
-def get_launcher_command(args, temp_path, dep_jids=()) -> List[str]:
+def get_launcher_command(args, temp_path, after_ok=(), after_any=()) -> List[str]:
     """
     Return a slurm launcher command
     :param args: Command line arguments
     :param temp_path: Temporary path to run job
-    :param dep_jids: A list of dependent jobs
+    :param after_ok: A list of jobs which must succeed first
+    :param after_any: A list of jobs which must finish first, but may fail
     :return: List of commands and parameters to be used by subprocess
     """
 
@@ -716,11 +717,13 @@ def get_launcher_command(args, temp_path, dep_jids=()) -> List[str]:
 
     launcher_slurm.append("--mem={}G".format(args.get('slurm_mem', '500')))
 
-    if len(dep_jids):
-        launcher_slurm.append(
-            "--dependency=afterok:" + ":".join(str(j) for j in dep_jids)
-        )
-
+    deps = []
+    if after_ok:
+        deps.append("afterok:" + ":".join(str(j) for j in after_ok))
+    if after_any:
+        deps.append("afterany:" + ":".join(str(j) for j in after_ok))
+    if deps:
+        launcher_slurm.append("--dependency=" + ",".join(deps))
     return launcher_slurm
 
 
@@ -743,7 +746,7 @@ def remove_duplications(l) -> list:
 def concurrent_run(
     temp_path: str, nb, nb_path: Path, args: dict, cparm=None, cval=None,
     cluster_cores=8,
-    sequential=False, dep_jids=(),
+    sequential=False, after_ok=(), after_any=(),
     show_title=True, user_venv: Optional[Path] = None,
 ) -> Optional[str]:
     """ Launch a concurrent job on the cluster via Slurm
@@ -779,7 +782,7 @@ def concurrent_run(
     # then run an sbatch job
     cmd = []
     if not sequential:
-        cmd = get_launcher_command(args, temp_path, dep_jids)
+        cmd = get_launcher_command(args, temp_path, after_ok, after_any)
         print(" ".join(cmd))
 
     if user_venv:
@@ -1010,7 +1013,7 @@ def run():
     if user_venv:
         user_venv = Path(user_venv.format(**args))
 
-    joblist = []
+    pre_jobs = []
     cluster_cores = concurrency.get("cluster cores", 8)
     # Check if there are pre-notebooks
     for pre_notebook in pre_notebooks:
@@ -1022,16 +1025,17 @@ def run():
                                cluster_cores=cluster_cores,
                                sequential=sequential, user_venv=user_venv
                               )
-        joblist.append(jobid)
+        pre_jobs.append(jobid)
 
+    main_jobs = []
     if concurrency.get("parameter", None) is None:
         jobid = concurrent_run(run_tmp_path, nb,
                                notebook_path, args,
                                cluster_cores=cluster_cores,
                                sequential=sequential,
-                               dep_jids=joblist, user_venv=user_venv
+                               after_ok=pre_jobs, user_venv=user_venv
                               )
-        joblist.append(jobid)
+        main_jobs.append(jobid)
     else:
         cvar = concurrency["parameter"]
         cvals = args.get(cvar, None)
@@ -1082,7 +1086,6 @@ def run():
         cvtype = get_par_attr(parms, cvar, 'type', list)
         cvals = remove_duplications(cvals)
 
-        jlist = []
         for cnum, cval in enumerate(cvals):
             show_title = cnum == 0
             cval = [cval, ] if not isinstance(cval, list) and cvtype is list else cval
@@ -1092,23 +1095,26 @@ def run():
                                    cluster_cores=cluster_cores,
                                    sequential=sequential,
                                    show_title=show_title,
-                                   dep_jids=joblist,
+                                   after_ok=pre_jobs,
                                   )
-            jlist.append(jobid)
-        joblist.extend(jlist)
+            main_jobs.append(jobid)
 
     # Run dependent notebooks (e.g. summaries after correction)
+    dep_jobs = []
     for i, dep_notebook in enumerate(dep_notebooks):
         dep_notebook_path = Path(PKG_DIR, dep_notebook)
         dep_nb = nbformat.read(dep_notebook_path, as_version=4)
         jobid = concurrent_run(run_tmp_path, dep_nb,
                                dep_notebook_path,
                                args,
-                               dep_jids=joblist,
+                               after_ok=pre_jobs,
+                               after_any=main_jobs,
                                cluster_cores=cluster_cores,
                                sequential=sequential,
                               )
-        joblist.append(jobid)
+        dep_jobs.append(jobid)
+
+    joblist = pre_jobs + main_jobs + dep_jobs
 
     joblist.append(run_finalize(
         fmt_args=fmt_args,
-- 
GitLab