From 8eb5c411d5b12bf08073aac346f7b6bac0f1feeb Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Tue, 1 Aug 2023 17:16:33 +0100
Subject: [PATCH] Various additions to multi-module API

---
 src/cal_tools/calcat_interface2.py | 41 +++++++++++++++++++++++++-----
 tests/test_calcat_interface2.py    |  2 +-
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index 9f387caa1..5101cc4d9 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -82,11 +82,11 @@ class ModulesConstantVersions:
 
     @property
     def aggregators(self):
-        return list(self.constants)
+        return sorted(self.constants)
 
     @property
     def module_nums(self):
-        return [int(da[-2:]) for da in self.constants]
+        return [int(da[-2:]) for da in self.aggregators]
 
     @property
     def qm_names(self):
@@ -187,8 +187,8 @@ class CalibrationData(Mapping):
 
         return cls(d)
 
-    def __getitem__(self, key):
-        return self.constant_groups
+    def __getitem__(self, key) -> ModulesConstantVersions:
+        return self.constant_groups[key]
 
     def __iter__(self):
         return iter(self.constant_groups)
@@ -196,6 +196,28 @@ class CalibrationData(Mapping):
     def __len__(self):
         return len(self.constant_groups)
 
+    @property
+    def aggregators(self):
+        names = set()
+        for mcv in self.constant_groups.values():
+            names.update(mcv.aggregators)
+        return sorted(names)
+
+    @property
+    def module_nums(self):
+        return [int(da[-2:]) for da in self.aggregators]
+
+    @property
+    def qm_names(self):
+        return [module_index_to_qm(n) for n in self.module_nums]
+
+    def require_calibrations(self, *calibrations):
+        """Drop any modules missing the specified constant types"""
+        mods = set(self.aggregators)
+        for cal_type in calibrations:
+            mods.intersection_update(self[cal_type].constants)
+        return self.select_modules(mods)
+
     def select_calibrations(self, *calibrations, require_all=True):
         if require_all:
             missing = set(calibrations) - set(self.constant_groups)
@@ -218,9 +240,14 @@ class CalibrationData(Mapping):
             }
         )
 
-    def merge(self, other: "CalibrationData") -> "CalibrationData":
-        d = self.constant_groups.copy()
-        d.update(other.constant_groups)
+    def merge(self, *others: Sequence["CalibrationData"]) -> "CalibrationData":
+        d = {}
+        for cal_type, mcv in self.constant_groups.items():
+            d[cal_type] = mcv.constants.copy()
+            for other in others:
+                if cal_type in other:
+                    d[cal_type].update(other[cal_type].constants)
+
         return type(self)(d)
 
 
diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py
index f3507d0e7..b372476e7 100644
--- a/tests/test_calcat_interface2.py
+++ b/tests/test_calcat_interface2.py
@@ -35,5 +35,5 @@ def test_AGIPD_CalibrationData_report():
     # Report ID: https://in.xfel.eu/calibration/reports/3757
     agipd_cd = CalibrationData.from_report(3757)
     assert set(agipd_cd) == {'Offset', 'Noise', 'ThresholdsDark', 'BadPixelsDark'}
-    assert agipd_cd['Offset'].aggregators == [f'AGIPD{n:02}' for n in range(16)]
+    assert agipd_cd.aggregators == [f'AGIPD{n:02}' for n in range(16)]
     assert isinstance(agipd_cd['Offset'].constants['AGIPD00'], SingleConstantVersion)
-- 
GitLab