From d784e47385dba7f21f6acc3dda18bceb996118fd Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Fri, 8 Oct 2021 11:22:08 +0100
Subject: [PATCH] Refactor parsing command line args & loading main notebook

---
Reviewer notes (this section is ignored by `git am`):

* The line structure of this patch was rebuilt from a whitespace-collapsed
  copy. Blank lines were placed so that every hunk matches the line counts in
  its @@ header. Indentation of context lines, and runs of spaces inside
  string literals, are best guesses - confirm against the target tree (or
  apply with `git apply --ignore-whitespace`).
* nb_args.py, parse_argv_and_load_nb(): the reproducibility check moved here
  from calibrate.run() used `args[x]` / `args['not_reproducible']`. In this
  function `args` is read by attribute (`args.detector`, `args.type`), so
  subscripting it would raise TypeError. Changed to `getattr(args, x)` and
  `args.not_reproducible`.
* calibrate.py, run(): the "cannot be run concurrently" check compared
  parameter names with the literal "parameter" although the warning text
  refers to `concurrency_par`. It now compares with `concurrency_par`.
* Both fixes replace text on existing `+` lines only; hunk sizes and the
  diffstat are unchanged. The `index` lines still carry the blob ids of the
  original commit.

 src/xfel_calibrate/calibrate.py | 106 ++++++++------------------------
 src/xfel_calibrate/nb_args.py   |  90 ++++++++++++++++++++++-----
 2 files changed, 101 insertions(+), 95 deletions(-)

diff --git a/src/xfel_calibrate/calibrate.py b/src/xfel_calibrate/calibrate.py
index 27403447d..961d39a7e 100755
--- a/src/xfel_calibrate/calibrate.py
+++ b/src/xfel_calibrate/calibrate.py
@@ -25,14 +25,11 @@ import yaml
 import cal_tools.tools
 
 from .finalize import tex_escape
-from .notebooks import notebooks
 from .nb_args import (
     consolize_name,
-    deconsolize_args,
-    extend_params,
     first_markdown_cell,
     get_notebook_function,
-    make_extended_parser,
+    parse_argv_and_load_nb,
     set_figure_format,
 )
 from .settings import (
@@ -581,52 +578,9 @@ def run(argv=None):
     if argv is None:
         argv = sys.argv
 
-    parser = make_extended_parser(argv)
-    args = deconsolize_args(vars(parser.parse_args(argv[1:])))
-    detector = args["detector"].upper()
-    caltype = args["type"].upper()
-    sequential = args["no_cluster_job"]
-
-    # Pick out any arguments that may prevent reproducibility from
-    # working, sorted alphabetically and converted back to their
-    # canonical representation.
-    not_reproducible_args = sorted(
-        ('--' + x.replace('_', '-')
-         for x in ['skip_env_freeze']
-         if args[x]))
-
-    # If any of these arguments are set, present a warning.
-    if not_reproducible_args:
-        print('WARNING: One or more command line arguments ({}) may prevent '
-              'this specific correction result from being reproducible based '
-              'on its metadata. It may not be possible to restore identical '
-              'output data files when they have been deleted or lost. Please '
-              'ensure that the data retention policy of the chosen storage '
-              'location is sufficient for your '
-              'needs.'.format(', '.join(not_reproducible_args)))
-
-        if not args['not_reproducible']:
-            # If not explicitly specified that reproducibility may be
-            # broken, remind the user and exit.
-            print('To proceed, you can explicitly allow reproducibility to '
-                  'be broken by adding --not-reproducible')
-            sys.exit(1)
-
-        reproducible = False
-    else:
-        reproducible = True
-
-    try:
-        nb_info = notebooks[detector][caltype]
-    except KeyError:
-        print("Not one of the known calibrations or detectors")
-        return 1
-
-    pre_notebooks = nb_info.get("pre_notebooks", [])
-    notebook = nb_info["notebook"]
-    dep_notebooks = nb_info.get("dep_notebooks", [])
-    concurrency = nb_info.get("concurrency", {'parameter': None})
+    args, nb_details = parse_argv_and_load_nb(argv)
 
+    concurrency = nb_details.concurrency
     concurrency_par = args["concurrency_par"] or concurrency['parameter']
     if concurrency_par == concurrency['parameter']:
         # Use the defaults from notebook.py to split the work into several jobs
@@ -637,22 +591,14 @@ def run(argv=None):
         # don't use the associated settings from there.
         concurrency_defval = concurrency_func = None
 
-
-    notebook_path = Path(PKG_DIR, notebook)
-    nb = nbformat.read(notebook_path, as_version=4)
-
-    # extend parameters if needed
-    ext_func = nb_info.get("extend parms", None)
-    if ext_func is not None:
-        extend_params(nb, ext_func, argv)
-
-    parms = extract_parameters(nb, lang='python')
+    notebook_path = nb_details.path
+    nb = nb_details.contents
 
     title, author = extract_title_author(nb)
     version = get_pycalib_version()
 
     if not title:
-        title = "{} {} Calibration".format(detector, caltype)
+        title = f"{nb_details.detector} {nb_details.caltype} Calibration"
     if not author:
         author = "anonymous"
     if not version:
@@ -663,20 +609,24 @@ def run(argv=None):
     run_uuid = f"t{datetime.now().strftime('%y%m%d_%H%M%S')}"
 
     # check if concurrency parameter is given and we run concurrently
-    if not any(p.name == "parameter" for p in parms) and concurrency_par is not None:
+    if not any(p.name == concurrency_par for p in nb_details.default_params) and concurrency_par is not None:
         msg = f"Notebook cannot be run concurrently: no {concurrency_par} parameter"
         warnings.warn(msg, RuntimeWarning)
 
     # If not explicitly specified, use a new profile for ipcluster
-    if args.get("cluster_profile") in {None, parser.get_default("cluster_profile")}:
-        args['cluster_profile'] = "slurm_prof_{}".format(run_uuid)
+    default_params_by_name = {p.name: p.value for p in nb_details.default_params}
+    if 'cluster_profile' in default_params_by_name:
+        if args.get("cluster_profile") == default_params_by_name["cluster_profile"]:
+            args['cluster_profile'] = "slurm_prof_{}".format(run_uuid)
 
     # create a temporary output directory to work in
-    run_tmp_path = os.path.join(temp_path, f"slurm_out_{detector}_{caltype}_{run_uuid}")
+    run_tmp_path = os.path.join(
+        temp_path, f"slurm_out_{nb_details.detector}_{nb_details.caltype}_{run_uuid}"
+    )
     os.makedirs(run_tmp_path)
 
     # Write all input parameters to rst file to be included to final report
-    parms = parameter_values(parms, **args)
+    parms = parameter_values(nb_details.default_params, **args)
     make_par_table(parms, run_tmp_path)
     # And save the invocation of this script itself
     save_executed_command(run_tmp_path, version, argv)
@@ -688,7 +638,7 @@ def run(argv=None):
     )
 
     # wait on all jobs to run and then finalize the run by creating a report from the notebooks
-    out_path = Path(default_report_path) / detector.upper() / caltype.upper() / datetime.now().isoformat()
+    out_path = Path(default_report_path) / nb_details.detector / nb_details.caltype / datetime.now().isoformat()
     if try_report_to_output:
         if "out_folder" in args:
             out_path = Path(args["out_folder"]).absolute()
@@ -712,11 +662,9 @@ def run(argv=None):
             print(f"report_to path contained no path, saving report in '{out_path}'")
             report_to = out_path / report_to
 
-    user_venv = nb_info.get("user", {}).get("venv")
-    if user_venv:
-        user_venv = Path(user_venv.format(**args))
-        print("Using specified venv:", user_venv)
-        python_exe = str(user_venv / 'bin' / 'python')
+    if nb_details.user_venv:
+        print("Using specified venv:", nb_details.user_venv)
+        python_exe = str(nb_details.user_venv / 'bin' / 'python')
     else:
         python_exe = python_path
 
@@ -731,7 +679,7 @@ def run(argv=None):
     metadata["pycalibration-version"] = version
     metadata["report-path"] = f"{report_to}.pdf" if report_to \
         else '# REPORT SKIPPED #'
-    metadata['reproducible'] = reproducible
+    metadata['reproducible'] = not args['not_reproducible']
     metadata["concurrency"] = {
         'parameter': concurrency_par,
         'default': concurrency_defval,
@@ -766,8 +714,7 @@ def run(argv=None):
     pre_jobs = []
     cluster_cores = concurrency.get("cluster cores", 8)
     # Check if there are pre-notebooks
-    for pre_notebook in pre_notebooks:
-        pre_notebook_path = Path(PKG_DIR, pre_notebook)
+    for pre_notebook_path in nb_details.pre_paths:
         lead_nb = nbformat.read(pre_notebook_path, as_version=4)
         pre_jobs.append(prepare_job(
             run_tmp_path, lead_nb, pre_notebook_path, args,
@@ -794,7 +741,7 @@ def run(argv=None):
             defcval = get_par_attr(parms, concurrency_par, 'value')
             if defcval is not None:
                 print(f"Concurrency parameter '{concurrency_par}' "
-                      f"is taken from '{notebook}'")
+                      f"is taken from '{notebook_path}'")
                 cvals = defcval if isinstance(defcval, (list, tuple)) else [defcval]
 
         if concurrency_func:
@@ -834,8 +781,7 @@ def run(argv=None):
 
     # Prepare dependent notebooks (e.g. summaries after correction)
     dep_jobs = []
-    for i, dep_notebook in enumerate(dep_notebooks):
-        dep_notebook_path = Path(PKG_DIR, dep_notebook)
+    for i, dep_notebook_path in enumerate(nb_details.dep_paths):
         dep_nb = nbformat.read(dep_notebook_path, as_version=4)
         dep_jobs.append(prepare_job(
             run_tmp_path, dep_nb, dep_notebook_path, args,
@@ -859,7 +805,7 @@ def run(argv=None):
         print("Files prepared, not executing now (--prepare-only option).")
         print("To execute the notebooks, run:")
         rpt_opts = ''
-        if user_venv is not None:
+        if nb_details.user_venv is not None:
             rpt_opts = f'--python {python_exe}'
         print(f" python -m xfel_calibrate.repeat {run_tmp_path} {rpt_opts}")
         return
@@ -867,7 +813,7 @@ def run(argv=None):
     submission_time = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
 
     # Launch the calibration work
-    if sequential:
+    if args["no_cluster_job"]:
         print("Running notebooks directly, not via Slurm...")
         errors = job_chain.run_direct()
         joblist = []
@@ -899,7 +845,7 @@ def run(argv=None):
             fmt_args=fmt_args,
             temp_path=run_tmp_path,
             job_list=joblist,
-            sequential=sequential,
+            sequential=args["no_cluster_job"],
         ))
 
     if any(j is not None for j in joblist):
diff --git a/src/xfel_calibrate/nb_args.py b/src/xfel_calibrate/nb_args.py
index 45aa4a9f3..c8b9f7ebf 100644
--- a/src/xfel_calibrate/nb_args.py
+++ b/src/xfel_calibrate/nb_args.py
@@ -8,9 +8,12 @@ import re
 import string
 import sys
 import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
 import nbformat
-from nbparameterise import extract_parameters
+from nbparameterise import extract_parameters, Parameter
 
 from .notebooks import notebooks
 
@@ -148,20 +151,18 @@ def consolize_name(name):
     return name.replace("_", "-")
 
 
-def add_args_from_nb(nb, parser, cvar=None, no_required=False):
+def add_args_from_nb(parms, parser, cvar=None, no_required=False):
     """Add argparse arguments for parameters in the first cell of a notebook.
 
     Uses nbparameterise to extract the parameter information. Each foo_bar
     parameter gets a --foo-bar command line option.
     Boolean parameters get a pair of flags like --abc and --no-abc.
 
-    :param nb: NotebookNode object representing a loaded .ipynb file
-    :param parser: argparse.ArgumentParser instance
+    :param parms: List of nbparameterise Parameter objects
+    :param parser: argparse.ArgumentParser instance to modify
     :param str cvar: Name of the concurrency parameter.
     :param bool no_required: If True, none of the added options are required.
     """
-    parser.description = make_epilog(nb)
-    parms = extract_parameters(nb, lang='python')
 
     for p in parms:
         helpstr = ("Default: %(default)s" if not p.comment
@@ -333,8 +334,22 @@ def extend_params(nb, extend_func_name, argv):
     fcc["source"] += "\n" + extension
 
 
-def make_extended_parser(argv) -> argparse.ArgumentParser:
-    """Create an ArgumentParser using information from the notebooks"""
+@dataclass
+class NBDetails:
+    """Details of a notebook-based workflow to run"""
+    detector: str  # e.g. AGIPD
+    caltype: str  # e.g. CORRECT
+    path: Path
+    pre_paths: List[Path]  # Notebooks to run before the main notebook
+    dep_paths: List[Path]  # Notebooks to run after the main notebooks
+    contents: nbformat.NotebookNode
+    default_params: List[Parameter]
+    concurrency: Dict[str, Any]  # Contents as in notebooks.py
+    user_venv: Optional[Path]
+
+
+def parse_argv_and_load_nb(argv) -> Tuple[Dict, NBDetails]:
+    """Parse command-line arguments for xfel-calibrate to run a notebook"""
     # extend the parser according to user input
     # the first case is if a detector was given, but no calibration type
     if len(argv) == 3 and "-h" in argv[2]:
@@ -377,9 +392,11 @@ def make_extended_parser(argv) -> argparse.ArgumentParser:
             nb = nbformat.read(nbpath, as_version=4)
             msg += make_epilog(nb, caltype=caltype)
 
-        return make_initial_parser(epilog=msg)
+        make_initial_parser(epilog=msg).parse_args(argv[1:])
+        sys.exit()  # parse_args should already exit for --help
     elif len(argv) <= 3:
-        return make_initial_parser()
+        make_initial_parser().parse_args(argv[1:])
+        sys.exit()  # parse_args should already exit - not enough args
 
     # A detector and type was given. We derive the arguments
     # from the corresponding notebook
@@ -390,6 +407,31 @@ def make_extended_parser(argv) -> argparse.ArgumentParser:
         print("Not one of the known calibrations or detectors")
         sys.exit(1)
 
+    # Pick out any arguments that may prevent reproducibility from
+    # working, sorted alphabetically and converted back to their
+    # canonical representation.
+    not_reproducible_args = sorted(
+        ('--' + x.replace('_', '-')
+         for x in ['skip_env_freeze']
+         if getattr(args, x)))
+
+    # If any of these arguments are set, present a warning.
+    if not_reproducible_args:
+        print('WARNING: One or more command line arguments ({}) may prevent '
+              'this specific correction result from being reproducible based '
+              'on its metadata. It may not be possible to restore identical '
+              'output data files when they have been deleted or lost. Please '
+              'ensure that the data retention policy of the chosen storage '
+              'location is sufficient for your '
+              'needs.'.format(', '.join(not_reproducible_args)))
+
+        if not args.not_reproducible:
+            # If not explicitly specified that reproducibility may be
+            # broken, remind the user and exit.
+            print('To proceed, you can explicitly allow reproducibility to '
+                  'be broken by adding --not-reproducible')
+            sys.exit(1)
+
     if nb_info["notebook"]:
         notebook = os.path.join(PKG_DIR, nb_info["notebook"])
     else:
@@ -413,7 +455,7 @@ def make_extended_parser(argv) -> argparse.ArgumentParser:
         nb_info["notebook"] = user_notebook_path.format(**vars(user_notebook_args))
         notebook = nb_info["notebook"]
 
-    cvar = nb_info.get("concurrency", {}).get("parameter", None)
+    concurrency = nb_info.get("concurrency", {'parameter': None})
 
     nb = nbformat.read(notebook, as_version=4)
 
@@ -422,8 +464,26 @@ def make_extended_parser(argv) -> argparse.ArgumentParser:
     if ext_func is not None:
         extend_params(nb, ext_func, argv)
 
-    # No extend parms function - add statically defined parameters from the
-    # first code cell
+    default_params = extract_parameters(nb, lang='python')
+
     parser = make_initial_parser()
-    add_args_from_nb(nb, parser, cvar=cvar)
-    return parser
+    parser.description = make_epilog(nb)
+    add_args_from_nb(default_params, parser, cvar=concurrency['parameter'])
+
+    arg_dict = deconsolize_args(vars(parser.parse_args(argv[1:])))
+
+    user_venv = nb_info.get("user", {}).get("venv")
+    if user_venv is not None:
+        user_venv = Path(user_venv.format(**arg_dict))
+
+    return arg_dict, NBDetails(
+        detector=args.detector.upper(),
+        caltype=args.type.upper(),
+        path=Path(notebook),
+        pre_paths=[Path(PKG_DIR, p) for p in nb_info.get('pre_notebooks', [])],
+        dep_paths=[Path(PKG_DIR, p) for p in nb_info.get('dep_notebooks', [])],
+        contents=nb,
+        default_params=default_params,
+        concurrency=concurrency,
+        user_venv=user_venv,
+    )
-- 
GitLab