From a537f3a37dea66192a0cce6a389ed817e97e9310 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Mon, 22 Jan 2024 10:40:54 +0000
Subject: [PATCH] Allow selecting constant & module directly from
 CalibrationData

---
 src/cal_tools/calcat_interface2.py | 25 +++++++++++++++++--------
 tests/test_calcat_interface2.py    |  7 ++++---
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index 9af50846c..a69ea576d 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -330,16 +330,16 @@ class MultiModuleConstant(Mapping):
         return len(self.constants)
 
     def __getitem__(self, key):
+        if key in (None, ""):
+            raise KeyError(key)
+
         candidate_kdas = set()
         if key in self.constants:  # Karabo DA name, e.g. 'LPD00'
             candidate_kdas.add(key)
 
         for m in self.module_details:
-            if key in (
-                m["module_number"],
-                m["virtual_device_name"],
-                m["physical_name"],
-            ) and m['karabo_da'] in self.constants:
+            names = (m["module_number"], m["virtual_device_name"], m["physical_name"])
+            if key in names and m["karabo_da"] in self.constants:
                 candidate_kdas.add([m["karabo_da"]])
 
         if not candidate_kdas:
@@ -576,9 +576,15 @@ class CalibrationData(Mapping):
         return cls(constant_groups, module_details, det_name)
 
     def __getitem__(self, key) -> MultiModuleConstant:
-        return MultiModuleConstant(
-            self.constant_groups[key], self.module_details, self.detector_name, key
-        )
+        if isinstance(key, str):
+            return MultiModuleConstant(
+                self.constant_groups[key], self.module_details, self.detector_name, key
+            )
+        elif isinstance(key, tuple) and len(key) == 2:
+            cal_type, module = key
+            return self[cal_type][module]
+        else:
+            raise TypeError(f"Key should be string or 2-tuple (got {key!r})")
 
     def __iter__(self):
         return iter(self.constant_groups)
@@ -586,6 +592,9 @@ class CalibrationData(Mapping):
     def __len__(self):
         return len(self.constant_groups)
 
+    def __contains__(self, item):
+        return item in self.constant_groups
+
     def __repr__(self):
         return (
             f"<CalibrationData: {', '.join(sorted(self.constant_groups))} "
diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py
index b6972345b..ca84e3a5c 100644
--- a/tests/test_calcat_interface2.py
+++ b/tests/test_calcat_interface2.py
@@ -33,7 +33,8 @@ def test_AGIPD_CalibrationData_metadata():
     assert agipd_cd.detector_name == "MID_DET_AGIPD1M-1"
     assert "Offset" in agipd_cd
     assert set(agipd_cd["Offset"].constants) == {f"AGIPD{m:02}" for m in range(16)}
-    assert isinstance(agipd_cd["Offset"].constants["AGIPD00"], SingleConstant)
+    assert isinstance(agipd_cd["Offset", "AGIPD00"], SingleConstant)
+    assert agipd_cd["Offset", "Q1M2"] == agipd_cd["Offset", "AGIPD01"]
 
 
 @pytest.mark.requires_gpfs
@@ -94,7 +95,7 @@ def test_AGIPD_CalibrationData_metadata_SPB():
     assert agipd_cd["Offset"].qm_names == [
         f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)
     ]
-    assert isinstance(agipd_cd["Offset"].constants["AGIPD00"], SingleConstant)
+    assert isinstance(agipd_cd["Offset", 0], SingleConstant)
 
 
 @pytest.mark.requires_gpfs
@@ -183,4 +184,4 @@ def test_AGIPD_CalibrationData_report():
     assert agipd_cd.detector_name == "SPB_DET_AGIPD1M-1"
     assert set(agipd_cd) == {"Offset", "Noise", "ThresholdsDark", "BadPixelsDark"}
     assert agipd_cd.aggregator_names == [f"AGIPD{n:02}" for n in range(16)]
-    assert isinstance(agipd_cd["Offset"].constants["AGIPD00"], SingleConstant)
+    assert isinstance(agipd_cd["Offset", "AGIPD00"], SingleConstant)
-- 
GitLab