From b3066fb107c77489036bfe992ccea6f9c8073526 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Tue, 2 Jan 2024 12:12:21 +0000
Subject: [PATCH] Create MultiModuleConstant on demand

---
 src/cal_tools/calcat_interface2.py | 59 +++++++++++++++---------------
 1 file changed, 29 insertions(+), 30 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index c0da9dcc8..bb1cf1f7f 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -253,6 +253,7 @@ class MultiModuleConstant:
 
     constants: Dict[str, SingleConstant]  # Keys e.g. 'LPD00'
     module_details: List[Dict]
+    detector_name: str  # e.g. 'HED_DET_AGIPD500K2G'
 
     def select_modules(
         self, module_nums=None, *, aggregator_names=None, qm_names=None
@@ -262,7 +263,7 @@ class MultiModuleConstant:
         )
         d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
         mods = [m for m in self.module_details if m["karabo_da"] in d]
-        return MultiModuleConstant(d, mods)
+        return MultiModuleConstant(d, mods, self.detector_name)
 
     # These properties label only the modules we have constants for, which may
     # be a subset of what's in module_details
@@ -371,6 +372,7 @@ class CalibrationData(Mapping):
     """Collected constants for a given detector"""
 
     def __init__(self, constant_groups, module_details, detector_name):
+        # {calibration: {karabo_da: SingleConstant}}
         self.constant_groups = constant_groups
         self.module_details = module_details
         self.detector_name = detector_name
@@ -456,11 +458,7 @@ class CalibrationData(Mapping):
                 const_group = constant_groups.setdefault(cal_type, {})
                 const_group[aggr] = SingleConstant.from_response(ccv)
 
-        mmcs = {
-            const_type: MultiModuleConstant(d, module_details)
-            for const_type, d in constant_groups.items()
-        }
-        return cls(mmcs, module_details, detector_name)
+        return cls(constant_groups, module_details, detector_name)
 
     @classmethod
     def from_report(
@@ -509,14 +507,12 @@ class CalibrationData(Mapping):
         det_name = detector_id_to_name(det_ids.pop(), client)
 
         module_details = sorted(pdus.values(), key=lambda d: d["karabo_da"])
-        mmcs = {
-            const_type: MultiModuleConstant(d, module_details)
-            for const_type, d in constant_groups.items()
-        }
-        return cls(mmcs, module_details, det_name)
+        return cls(constant_groups, module_details, det_name)
 
     def __getitem__(self, key) -> MultiModuleConstant:
-        return self.constant_groups[key]
+        return MultiModuleConstant(
+            self.constant_groups[key], self.module_details, self.detector_name
+        )
 
     def __iter__(self):
         return iter(self.constant_groups)
@@ -564,25 +560,29 @@ class CalibrationData(Mapping):
         aggs = prepare_selection(
             self.module_details, module_nums, aggregator_names, qm_names
         )
-        mmcs = {
-            cal_type: mmc.select_modules(
-                aggregator_names=set(aggs).intersection(mmc.aggregator_names)
-            )
-            for (cal_type, mmc) in self.constant_groups.items()
-        }
-        aggs = set().union(*[c.aggregator_names for c in mmcs.values()])
-        module_details = [m for m in self.module_details if m["karabo_da"] in aggs]
-        return type(self)(mmcs, module_details, self.detector_name)
+        constant_groups = {}
+        matched_aggregators = set()
+        for cal_type, const_group in self.constant_groups.items():
+            constant_groups[cal_type] = d = {
+                aggr: const for (aggr, const) in const_group.items() if aggr in aggs
+            }
+            matched_aggregators.update(d.keys())
+        module_details = [
+            m for m in self.module_details if m["karabo_da"] in matched_aggregators
+        ]
+        return type(self)(constant_groups, module_details, self.detector_name)
 
     def select_calibrations(self, calibrations) -> "CalibrationData":
-        mmcs = {c: self.constant_groups[c] for c in calibrations}
-        return type(self)(mmcs, self.module_details, self.detector_name)
+        const_groups = {c: self.constant_groups[c] for c in calibrations}
+        return type(self)(const_groups, self.module_details, self.detector_name)
 
     def merge(self, *others: "CalibrationData") -> "CalibrationData":
         det_names = set(cd.detector_name for cd in (self,) + others)
         if len(det_names) > 1:
-            raise Exception("Cannot merge calibration data for different "
-                            "detectors: " + ", ".join(sorted(det_names)))
+            raise Exception(
+                "Cannot merge calibration data for different "
+                "detectors: " + ", ".join(sorted(det_names))
+            )
         det_name = det_names.pop()
 
         cal_types = set(self.constant_groups)
@@ -608,15 +608,14 @@ class CalibrationData(Mapping):
 
         module_details = sorted(pdus_d.values(), key=lambda d: d["karabo_da"])
 
-        mmcs = {}
+        constant_groups = {}
         for cal_type in cal_types:
-            d = {}
+            d = constant_groups[cal_type] = {}
             for caldata in (self,) + others:
                 if cal_type in caldata:
-                    d.update(caldata[cal_type].constants)
-            mmcs[cal_type] = MultiModuleConstant(d, module_details)
+                    d.update(caldata.constant_groups[cal_type])
 
-        return type(self)(mmcs, module_details, det_name)
+        return type(self)(constant_groups, module_details, det_name)
 
 
 class ConditionsBase:
-- 
GitLab