From a75e9ee1f61be187ce12482bcb8f0e5677e29cc8 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Tue, 19 Dec 2023 16:44:10 +0000
Subject: [PATCH] Fix CalibrationData.require_constants()

---
 src/cal_tools/calcat_interface2.py | 59 +++++++++++++++++-------------
 tests/test_calcat_interface2.py    |  6 ++-
 2 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index ab8c7762c..18bc5a925 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -218,6 +218,31 @@ class SingleConstantVersion:
         return self.dataset_obj(caldb_root)[:]
 
 
+def prepare_selection(
+    module_details, module_nums=None, aggregator_names=None, qm_names=None
+):
+    aggs = aggregator_names  # Shorter name -> fewer multi-line statements
+    n_specified = sum([module_nums is not None, aggs is not None, qm_names is not None])
+    if n_specified > 1:
+        raise TypeError(
+            "select_modules() accepts only one of module_nums, aggregator_names & qm_names"
+        )
+
+    if module_nums is not None:
+        by_mod_no = {m["module_number"]: m for m in module_details}
+        return [by_mod_no[n]["karabo_da"] for n in module_nums]
+    elif qm_names is not None:
+        by_qm = {m["virtual_device_name"]: m for m in module_details}
+        return [by_qm[s]["karabo_da"] for s in qm_names]
+    elif aggs is not None:
+        miss = set(aggs) - {m["karabo_da"] for m in module_details}
+        if miss:
+            raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
+        return aggs
+    else:
+        raise TypeError("select_modules() requires an argument")
+
+
 @dataclass
 class ModulesConstantVersions:
     """A group of similar CCVs for several modules of one detector"""
@@ -228,28 +253,9 @@ class ModulesConstantVersions:
     def select_modules(
         self, module_nums=None, *, aggregator_names=None, qm_names=None
     ) -> "ModulesConstantVersions":
-        aggs = aggregator_names  # Shorter name -> fewer multi-line statements
-        n_specified = sum(
-            [module_nums is not None, aggs is not None, qm_names is not None]
+        aggs = prepare_selection(
+            self.module_details, module_nums, aggregator_names, qm_names
         )
-        if n_specified < 1:
-            raise TypeError("select_modules() requires an argument")
-        elif n_specified > 1:
-            raise TypeError(
-                "select_modules() accepts only one of module_nums, aggregators & qm_names"
-            )
-
-        if module_nums is not None:
-            by_mod_no = {m["module_number"]: m for m in self.module_details}
-            aggs = [by_mod_no[n]["karabo_da"] for n in module_nums]
-        elif qm_names is not None:
-            by_qm = {m["virtual_device_name"]: m for m in self.module_details}
-            aggs = [by_qm[s]["karabo_da"] for s in qm_names]
-        elif aggs is not None:
-            miss = set(aggs) - {m["karabo_da"] for m in self.module_details}
-            if miss:
-                raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
-
         d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
         mods = [m for m in self.module_details if m["karabo_da"] in d]
         return ModulesConstantVersions(d, mods)
@@ -505,7 +511,7 @@ class CalibrationData(Mapping):
         mods = set(self.aggregator_names)
         for cal_type in calibrations:
             mods.intersection_update(self[cal_type].constants)
-        return self.select_modules(mods)
+        return self.select_modules(aggregator_names=mods)
 
     def select_calibrations(self, calibrations, require_all=True):
         if require_all:
@@ -524,11 +530,14 @@ class CalibrationData(Mapping):
     def select_modules(
         self, module_nums=None, *, aggregator_names=None, qm_names=None
     ) -> "CalibrationData":
+        # Validate the specified modules against those we know about.
+        # Each specific constant type may have only a subset of these modules.
+        aggs = prepare_selection(
+            self.module_details, module_nums, aggregator_names, qm_names
+        )
         mcvs = {
             cal_type: mcv.select_modules(
-                module_nums=module_nums,
-                aggregator_names=aggregator_names,
-                qm_names=qm_names,
+                aggregator_names=set(aggs).intersection(mcv.aggregator_names)
             )
             for (cal_type, mcv) in self.constant_groups.items()
         }
diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py
index 674b7ba72..d4a08eeeb 100644
--- a/tests/test_calcat_interface2.py
+++ b/tests/test_calcat_interface2.py
@@ -103,7 +103,11 @@ def test_LPD_constant_missing():
     assert lpd_cd.qm_names == [f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)]
 
     # When we look at a specific constant, module LPD05 is missing
-    assert lpd_cd["Offset"].module_nums == list(range(0, 5)) + list(range(6, 16))
+    modnos_w_constant = list(range(0, 5)) + list(range(6, 16))
+    assert lpd_cd["Offset"].module_nums == modnos_w_constant
+
+    # Test CalibrationData.require_constant()
+    assert lpd_cd.require_calibrations(["Offset"]).module_nums == modnos_w_constant
 
 
 @pytest.mark.xfail
-- 
GitLab