From c2fa065aad475ba2f33dd0a0aed022ce078116e4 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Fri, 4 Aug 2023 17:18:15 +0100
Subject: [PATCH] Store module names (aggregators) explicitly in
 CalibrationData

---
 src/cal_tools/calcat_interface2.py | 35 +++++++++++++++++++++---------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index 456f85517..cdd652387 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -12,6 +12,13 @@ from .tools import module_index_to_qm
 
 global_client = None
 
+class ModuleNameError(KeyError):
+    def __init__(self, name):
+        self.name = name
+
+    def __str__(self):
+        return f"No module named {self.name!r}"
+
 
 def get_client():
     global global_client
@@ -96,11 +103,12 @@ class ModulesConstantVersions:
 class CalibrationData(Mapping):
     """Collected constants for a given detector"""
 
-    def __init__(self, constant_groups: Dict[str, Dict[str, SingleConstantVersion]]):
+    def __init__(self, constant_groups, aggregators):
         self.constant_groups = {
             const_type: ModulesConstantVersions(d)
             for const_type, d in constant_groups.items()
         }
+        self.aggregators = aggregators
 
     @classmethod
     def from_condition(
@@ -115,6 +123,8 @@ class CalibrationData(Mapping):
     ):
         if calibrations is None:
             calibrations = set(condition.calibration_types)
+        if pdu_snapshot_at is None:
+            pdu_snapshot_at = event_at
 
         cal_types_by_params_used = {}
         for cal_type, params in condition.calibration_types.items():
@@ -123,6 +133,16 @@ class CalibrationData(Mapping):
 
         api = CalCatApi(client or get_client())
 
+        detector_id = api.detector(detector_name)['id']
+        all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
+        if modules is None:
+            modules = sorted(all_modules)
+        else:
+            modules = sorted(modules)
+            for m in modules:
+                if m not in all_modules:
+                    raise ModuleNameError(m)
+
         d = {}
 
         for params, cal_types in cal_types_by_params_used.items():
@@ -151,7 +171,7 @@ class CalibrationData(Mapping):
                     ccv
                 )
 
-        res = cls(d)
+        res = cls(d, modules)
         if modules:
             res = res.select_modules(*modules)
         return res
@@ -178,15 +198,17 @@ class CalibrationData(Mapping):
             raise CalCatError(resp)
 
         d = {}
+        aggregators = set()
 
         for ccv in resp["data"]:
             aggr = ccv["physical_detector_unit"]["karabo_da"]
+            aggregators.add(aggr)
             cal_type = api.calibration_name(
                 ccv["calibration_constant"]["calibration_id"]
             )
             d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)
 
-        return cls(d)
+        return cls(d, sorted(aggregators))
 
     def __getitem__(self, key) -> ModulesConstantVersions:
         return self.constant_groups[key]
@@ -197,13 +219,6 @@ class CalibrationData(Mapping):
     def __len__(self):
         return len(self.constant_groups)
 
-    @property
-    def aggregators(self):
-        names = set()
-        for mcv in self.constant_groups.values():
-            names.update(mcv.aggregators)
-        return sorted(names)
-
     @property
     def module_nums(self):
         return [int(da[-2:]) for da in self.aggregators]
-- 
GitLab