From 536bfed688d6c2e0f65df4a0deb935a852501d4c Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Thu, 7 Oct 2021 17:21:53 +0100
Subject: [PATCH] Accept argv as input to various functions

---
 src/xfel_calibrate/calibrate.py | 18 +++++++++++-------
 src/xfel_calibrate/nb_args.py   | 19 +++++++++----------
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/src/xfel_calibrate/calibrate.py b/src/xfel_calibrate/calibrate.py
index 8a36c93c4..27403447d 100755
--- a/src/xfel_calibrate/calibrate.py
+++ b/src/xfel_calibrate/calibrate.py
@@ -5,6 +5,7 @@ import locale
 import math
 import os
 import re
+import shlex
 import shutil
 import stat
 import sys
@@ -260,7 +261,7 @@ def run_finalize(fmt_args, temp_path, job_list, sequential=False):
     return jobid
 
 
-def save_executed_command(run_tmp_path, version):
+def save_executed_command(run_tmp_path, version, argv):
     """
     Create a file with string used to execute `xfel_calibrate`
 
@@ -271,7 +272,7 @@ def save_executed_command(run_tmp_path, version):
     f_name = os.path.join(run_tmp_path, "run_calibrate.sh")
     with open(f_name, "w") as finfile:
         finfile.write(f'# pycalibration version: {version}\n')
-        finfile.write(' '.join(sys.argv))
+        finfile.write(shlex.join(argv))
 
 
 class SlurmOptions:
@@ -572,13 +573,16 @@ def make_par_table(parms, run_tmp_path: str):
         finfile.write(textwrap.dedent(tmpl.render(p=col_type, lines=l_parms)))
 
 
-def run():
+def run(argv=None):
     """ Run a calibration task with parser arguments """
     # Ensure files are opened as UTF-8 by default, regardless of environment.
     locale.setlocale(locale.LC_CTYPE, ('en_US', 'UTF-8'))
 
-    parser = make_extended_parser()
-    args = deconsolize_args(vars(parser.parse_args()))
+    if argv is None:
+        argv = sys.argv
+
+    parser = make_extended_parser(argv)
+    args = deconsolize_args(vars(parser.parse_args(argv[1:])))
     detector = args["detector"].upper()
     caltype = args["type"].upper()
     sequential = args["no_cluster_job"]
@@ -640,7 +644,7 @@ def run():
     # extend parameters if needed
     ext_func = nb_info.get("extend parms", None)
     if ext_func is not None:
-        extend_params(nb, ext_func)
+        extend_params(nb, ext_func, argv)
 
     parms = extract_parameters(nb, lang='python')
 
@@ -675,7 +679,7 @@ def run():
     parms = parameter_values(parms, **args)
     make_par_table(parms, run_tmp_path)
     # And save the invocation of this script itself
-    save_executed_command(run_tmp_path, version)
+    save_executed_command(run_tmp_path, version, argv)
 
     # Copy the bash script which will be used to run notebooks
     shutil.copy2(
diff --git a/src/xfel_calibrate/nb_args.py b/src/xfel_calibrate/nb_args.py
index c32eaa152..45aa4a9f3 100644
--- a/src/xfel_calibrate/nb_args.py
+++ b/src/xfel_calibrate/nb_args.py
@@ -305,7 +305,7 @@ def deconsolize_args(args):
     return {k.replace("-", "_"): v for k, v in args.items()}
 
 
-def extend_params(nb, extend_func_name):
+def extend_params(nb, extend_func_name, argv):
     """Add parameters in the first code cell by calling a function in the notebook
     """
     func = get_notebook_function(nb, extend_func_name)
@@ -320,7 +320,7 @@ def extend_params(nb, extend_func_name):
     # Make a temporary parser that won't exit if it sees -h or --help
     pre_parser = make_initial_parser(add_help=False)
     add_args_from_nb(nb, pre_parser, no_required=True)
-    known, _ = pre_parser.parse_known_args()
+    known, _ = pre_parser.parse_known_args(argv[1:])
     args = deconsolize_args(vars(known))
 
     df = {}
@@ -333,13 +333,12 @@ def extend_params(nb, extend_func_name):
     fcc["source"] += "\n" + extension
 
 
-def make_extended_parser() -> argparse.ArgumentParser:
+def make_extended_parser(argv) -> argparse.ArgumentParser:
     """Create an ArgumentParser using information from the notebooks"""
-
     # extend the parser according to user input
     # the first case is if a detector was given, but no calibration type
-    if len(sys.argv) == 3 and "-h" in sys.argv[2]:
-        detector = sys.argv[1].upper()
+    if len(argv) == 3 and "-h" in argv[2]:
+        detector = argv[1].upper()
         try:
             det_notebooks = notebooks[detector]
         except KeyError:
@@ -379,12 +378,12 @@ def make_extended_parser() -> argparse.ArgumentParser:
                 msg += make_epilog(nb, caltype=caltype)
 
         return make_initial_parser(epilog=msg)
-    elif len(sys.argv) <= 3:
+    elif len(argv) <= 3:
         return make_initial_parser()
 
     # A detector and type was given. We derive the arguments
     # from the corresponding notebook
-    args, _ = make_initial_parser(add_help=False).parse_known_args()
+    args, _ = make_initial_parser(add_help=False).parse_known_args(argv[1:])
     try:
         nb_info = notebooks[args.detector.upper()][args.type.upper()]
     except KeyError:
@@ -409,7 +408,7 @@ def make_extended_parser() -> argparse.ArgumentParser:
         for var in user_notebook_variables:
             user_notebook_parser.add_argument(f"--{var}")
 
-        user_notebook_args, _ = user_notebook_parser.parse_known_args()
+        user_notebook_args, _ = user_notebook_parser.parse_known_args(argv[1:])
 
         nb_info["notebook"] = user_notebook_path.format(**vars(user_notebook_args))
         notebook = nb_info["notebook"]
@@ -421,7 +420,7 @@ def make_extended_parser() -> argparse.ArgumentParser:
     # extend parameters if needed
     ext_func = nb_info.get("extend parms", None)
     if ext_func is not None:
-        extend_params(nb, ext_func)
+        extend_params(nb, ext_func, argv)
 
     # No extend parms function - add statically defined parameters from the
     # first code cell
-- 
GitLab