diff --git a/src/xfel_calibrate/calibrate.py b/src/xfel_calibrate/calibrate.py index ad73d0aaad7b14226e53603b8b2aeb0250cf86d3..d11e294f949fc6fcaff1e4d17a2f10f6e160db2e 100755 --- a/src/xfel_calibrate/calibrate.py +++ b/src/xfel_calibrate/calibrate.py @@ -78,7 +78,7 @@ def make_initial_parser(**kwargs): ' report') parser.add_argument('--concurrency-par', type=str, - help='Name of cuncurrency parameter.' + help='Name of concurrency parameter. ' 'If not given, it is taken from configuration.') parser.add_argument('--priority', type=int, default=2, @@ -899,10 +899,18 @@ def run(): pre_notebooks = nb_info.get("pre_notebooks", []) notebook = nb_info["notebook"] dep_notebooks = nb_info.get("dep_notebooks", []) - concurrency = nb_info.get("concurrency", None) + concurrency = nb_info.get("concurrency", {'parameter': None}) + + concurrency_par = args["concurrency_par"] or concurrency['parameter'] + if concurrency_par == concurrency['parameter']: + # Use the defaults from notebooks.py to split the work into several jobs + concurrency_defval = concurrency.get('default concurrency', None) + concurrency_func = concurrency.get('use function', None) + else: + # --concurrency-par specified something different from notebooks.py: + # don't use the associated settings from there. 
+ concurrency_defval = concurrency_func = None - if args["concurrency_par"] is not None: - concurrency["parameter"] = args["concurrency_par"] notebook_path = Path(PKG_DIR, notebook) nb = nbformat.read(notebook_path, as_version=4) @@ -928,9 +936,8 @@ def run(): run_uuid = f"t{datetime.now().strftime('%y%m%d_%H%M%S')}" # check if concurrency parameter is given and we run concurrently - if not any(p.name == "parameter" for p in parms) and concurrency["parameter"] is not None: - msg = "Notebook cannot be run concurrently: no {} parameter".format( - concurrency["parameter"]) + if not any(p.name == concurrency_par for p in parms) and concurrency_par is not None: + msg = f"Notebook cannot be run concurrently: no {concurrency_par} parameter" warnings.warn(msg, RuntimeWarning) # If not explicitly specified, use a new profile for ipcluster @@ -1012,7 +1019,7 @@ def run(): pre_jobs.append(jobid) main_jobs = [] - if concurrency.get("parameter", None) is None: + if concurrency_par is None: jobid = concurrent_run(run_tmp_path, nb, notebook_path, args, cluster_cores=cluster_cores, @@ -1021,43 +1028,43 @@ def run(): ) main_jobs.append(jobid) else: - cvar = concurrency["parameter"] - cvals = args.get(cvar, None) + cvals = args.get(concurrency_par, None) - con_func = concurrency.get("use function", None) # Consider [-1] as None - if cvals is None or cvals == [-1]: - defcval = concurrency.get("default concurrency", None) - if defcval is not None: - print(f"Concurrency parameter '{cvar}' " - f"is taken from notebooks.py") - cvals = defcval if isinstance(defcval, (list, tuple)) else range(defcval) + if (cvals is None or cvals == [-1]) and concurrency_defval is not None: + print(f"Concurrency parameter '{concurrency_par}' " + f"is taken from notebooks.py") + cvals = concurrency_defval if isinstance(concurrency_defval, (list, tuple)) else range(concurrency_defval) + if cvals is None: - defcval = get_par_attr(parms, cvar, 'value') + defcval = get_par_attr(parms, concurrency_par, 'value') if 
defcval is not None: - print(f"Concurrency parameter '{cvar}' " + print(f"Concurrency parameter '{concurrency_par}' " f"is taken from '{notebook}'") cvals = defcval if isinstance(defcval, (list, tuple)) else [defcval] - if con_func: - func = get_notebook_function(nb, con_func) + + if concurrency_func: + func = get_notebook_function(nb, concurrency_func) if func is None: - warnings.warn(f"Didn't find concurrency function {con_func} in notebook", - RuntimeWarning) + warnings.warn( + f"Didn't find concurrency function {concurrency_func} in notebook", + RuntimeWarning + ) else: df = {} exec(func, df) - f = df[con_func] + f = df[concurrency_func] import inspect sig = inspect.signature(f) if cvals: # in case default needs to be used for function call - args[cvar] = cvals + args[concurrency_par] = cvals callargs = [args[arg] for arg in sig.parameters] cvals = f(*callargs) print(f"Split concurrency into {cvals}") # get expected type - cvtype = get_par_attr(parms, cvar, 'type', list) + cvtype = get_par_attr(parms, concurrency_par, 'type', list) cvals = remove_duplications(cvals) for cnum, cval in enumerate(cvals): @@ -1065,7 +1072,7 @@ def run(): cval = [cval, ] if not isinstance(cval, list) and cvtype is list else cval jobid = concurrent_run(run_tmp_path, nb, notebook_path, args, - cvar, cval, + concurrency_par, cval, cluster_cores=cluster_cores, sequential=sequential, show_title=show_title,