From f5cc4a4123ad808b5ec4a214f35ea9afa52de6aa Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Tue, 2 Apr 2024 13:20:29 +0200
Subject: [PATCH] add tests for decide_run_parameters

---
 tests/test_webservice.py | 29 ++++++++++++++++++++++++-----
 webservice/webservice.py | 11 ++++-------
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/tests/test_webservice.py b/tests/test_webservice.py
index 692e0ad12..feae3c386 100644
--- a/tests/test_webservice.py
+++ b/tests/test_webservice.py
@@ -1,25 +1,27 @@
-from collections import namedtuple
+import datetime as dt
 import logging
 import sys
-import datetime as dt
+from collections import namedtuple
 from pathlib import Path
 from unittest import mock
 
 import pytest
 
+from webservice.webservice import decide_run_parameters
+
 sys.path.insert(0, Path(__file__).parent / 'webservice')
 import webservice  # noqa: import not at top of file
 from webservice.messages import MigrationError  # noqa: import not at top
 from webservice.webservice import (  # noqa: import not at top of file
     check_files,
     check_run_type_skip,
+    config,
+    get_slurm_nice,
+    get_slurm_partition,
     merge,
     parse_config,
     run_action,
     wait_on_transfer,
-    get_slurm_partition,
-    get_slurm_nice,
-    config,
 )
 
 VALID_BEAMTIME = {
@@ -444,3 +446,20 @@ async def test_skip_runs_exception(return_value, status_code, caplog):
     assert "run information does not contain expected key" in caplog.text
     # And `False` should be returned
     assert ret == False
+
+
+def test_decide_run_parameters_single_run():
+    assert decide_run_parameters([123], 1) == {'run': 123}
+
+
+def test_decide_run_parameters_three_runs():
+    expected = {'run-high': 1, 'run-med': 2, 'run-low': 3}
+    assert decide_run_parameters([1, 2, 3], 3) == expected
+
+
+def test_decide_run_parameters_mismatched_runs():
+    with pytest.raises(ValueError):
+        decide_run_parameters([1, 2, 3, 4], 3)
+
+    with pytest.raises(ValueError):
+        decide_run_parameters([1, 2], 3)
diff --git a/webservice/webservice.py b/webservice/webservice.py
index 9cd06b81c..bd263c4c5 100644
--- a/webservice/webservice.py
+++ b/webservice/webservice.py
@@ -404,13 +404,10 @@ def decide_run_parameters(runs, expected_number_of_runs):
         dict: Mapping for run notebook parameters names and values.
     """
 
-    runs_dict = {}
-    if expected_number_of_runs == 3:
-        if len(runs) == 1:
-            return {'run-high': runs[0], 'run-med': '0', 'run-low': '0'}
-        else:  # len(runs) == 3
-            return {
-                'run-high': runs[0], 'run-med': runs[1], 'run-low': runs[2]}
+    if expected_number_of_runs == 3 and len(runs) == 1:
+        return {'run-high': runs[0], 'run-med': '0', 'run-low': '0'}
+    elif expected_number_of_runs == 3 and len(runs) == 3:
+        return {'run-high': runs[0], 'run-med': runs[1], 'run-low': runs[2]}
     elif expected_number_of_runs == 2 and len(runs) == 2:
         return {'run-high': runs[0], 'run-low': runs[1]}
     elif len(runs) == 1:  # single run operation modes.
-- 
GitLab