From a75e9ee1f61be187ce12482bcb8f0e5677e29cc8 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver <thomas@kluyver.me.uk> Date: Tue, 19 Dec 2023 16:44:10 +0000 Subject: [PATCH] Fix CalibrationData.require_constants() --- src/cal_tools/calcat_interface2.py | 59 +++++++++++++++++------------- tests/test_calcat_interface2.py | 6 ++- 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index ab8c7762c..18bc5a925 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -218,6 +218,31 @@ class SingleConstantVersion: return self.dataset_obj(caldb_root)[:] +def prepare_selection( + module_details, module_nums=None, aggregator_names=None, qm_names=None +): + aggs = aggregator_names # Shorter name -> fewer multi-line statements + n_specified = sum([module_nums is not None, aggs is not None, qm_names is not None]) + if n_specified > 1: + raise TypeError( + "select_modules() accepts only one of module_nums, aggregator_names & qm_names" + ) + + if module_nums is not None: + by_mod_no = {m["module_number"]: m for m in module_details} + return [by_mod_no[n]["karabo_da"] for n in module_nums] + elif qm_names is not None: + by_qm = {m["virtual_device_name"]: m for m in module_details} + return [by_qm[s]["karabo_da"] for s in qm_names] + elif aggs is not None: + miss = set(aggs) - {m["karabo_da"] for m in module_details} + if miss: + raise KeyError("Aggregators not found: " + ", ".join(sorted(miss))) + return aggs + else: + raise TypeError("select_modules() requires an argument") + + @dataclass class ModulesConstantVersions: """A group of similar CCVs for several modules of one detector""" @@ -228,28 +253,9 @@ class ModulesConstantVersions: def select_modules( self, module_nums=None, *, aggregator_names=None, qm_names=None ) -> "ModulesConstantVersions": - aggs = aggregator_names # Shorter name -> fewer multi-line statements - n_specified = sum( - [module_nums is not None, aggs is not None, qm_names is not None] + aggs = prepare_selection( + self.module_details, module_nums, aggregator_names, qm_names ) - if n_specified < 1: - raise TypeError("select_modules() requires an argument") - elif n_specified > 1: - raise TypeError( - "select_modules() accepts only one of module_nums, aggregators & qm_names" - ) - - if module_nums is not None: - by_mod_no = {m["module_number"]: m for m in self.module_details} - aggs = [by_mod_no[n]["karabo_da"] for n in module_nums] - elif qm_names is not None: - by_qm = {m["virtual_device_name"]: m for m in self.module_details} - aggs = [by_qm[s]["karabo_da"] for s in qm_names] - elif aggs is not None: - miss = set(aggs) - {m["karabo_da"] for m in self.module_details} - if miss: - raise KeyError("Aggregators not found: " + ", ".join(sorted(miss))) - d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs} mods = [m for m in self.module_details if m["karabo_da"] in d] return ModulesConstantVersions(d, mods) @@ -505,7 +511,7 @@ class CalibrationData(Mapping): mods = set(self.aggregator_names) for cal_type in calibrations: mods.intersection_update(self[cal_type].constants) - return self.select_modules(mods) + return self.select_modules(aggregator_names=mods) def select_calibrations(self, calibrations, require_all=True): if require_all: @@ -524,11 +530,14 @@ class CalibrationData(Mapping): def select_modules( self, module_nums=None, *, aggregator_names=None, qm_names=None ) -> "CalibrationData": + # Validate the specified modules against those we know about. + # Each specific constant type may have only a subset of these modules. + aggs = prepare_selection( + self.module_details, module_nums, aggregator_names, qm_names + ) mcvs = { cal_type: mcv.select_modules( - module_nums=module_nums, - aggregator_names=aggregator_names, - qm_names=qm_names, + aggregator_names=set(aggs).intersection(mcv.aggregator_names) ) for (cal_type, mcv) in self.constant_groups.items() } diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py index 674b7ba72..d4a08eeeb 100644 --- a/tests/test_calcat_interface2.py +++ b/tests/test_calcat_interface2.py @@ -103,7 +103,11 @@ def test_LPD_constant_missing(): assert lpd_cd.qm_names == [f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)] # When we look at a specific constant, module LPD05 is missing - assert lpd_cd["Offset"].module_nums == list(range(0, 5)) + list(range(6, 16)) + modnos_w_constant = list(range(0, 5)) + list(range(6, 16)) + assert lpd_cd["Offset"].module_nums == modnos_w_constant + + # Test CalibrationData.require_constant() + assert lpd_cd.require_calibrations(["Offset"]).module_nums == modnos_w_constant @pytest.mark.xfail -- GitLab