From 963331af656a9d87d5b4500e476340bcf9a76a6f Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Fri, 4 Aug 2023 17:22:55 +0100
Subject: [PATCH] Fix passing through aggregators when selecting & merging

---
 src/cal_tools/calcat_interface2.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index cdd652387..1e679d2a9 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -246,17 +246,15 @@ class CalibrationData(Mapping):
             if cal_type in calibrations
         }
         # TODO: missing for some modules?
-        return type(self)(d)
+        return type(self)(d, self.aggregators)
 
     def select_modules(self, *aggregators):
-        return type(self)(
-            {
+        return type(self)({
                 cal_type: mcv.select_modules(*aggregators).constants
                 for (cal_type, mcv) in self.constant_groups.items()
-            }
-        )
+        }, sorted(aggregators))
 
-    def merge(self, *others: Sequence["CalibrationData"]) -> "CalibrationData":
+    def merge(self, *others: "CalibrationData") -> "CalibrationData":
         d = {}
         for cal_type, mcv in self.constant_groups.items():
             d[cal_type] = mcv.constants.copy()
@@ -264,7 +262,11 @@ class CalibrationData(Mapping):
                 if cal_type in other:
                     d[cal_type].update(other[cal_type].constants)
 
-        return type(self)(d)
+        aggregators = set(self.aggregators)
+        for other in others:
+            aggregators.update(other.aggregators)
+
+        return type(self)(d, sorted(aggregators))
 
 
 class ConditionsBase:
-- 
GitLab