From 963331af656a9d87d5b4500e476340bcf9a76a6f Mon Sep 17 00:00:00 2001 From: Thomas Kluyver <thomas@kluyver.me.uk> Date: Fri, 4 Aug 2023 17:22:55 +0100 Subject: [PATCH] Fix passing through aggregators when selecting & merging --- src/cal_tools/calcat_interface2.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index cdd652387..1e679d2a9 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -246,17 +246,15 @@ class CalibrationData(Mapping): if cal_type in calibrations } # TODO: missing for some modules? - return type(self)(d) + return type(self)(d, self.aggregators) def select_modules(self, *aggregators): - return type(self)( - { + return type(self)({ cal_type: mcv.select_modules(*aggregators).constants for (cal_type, mcv) in self.constant_groups.items() - } - ) + }, sorted(aggregators)) - def merge(self, *others: Sequence["CalibrationData"]) -> "CalibrationData": + def merge(self, *others: "CalibrationData") -> "CalibrationData": d = {} for cal_type, mcv in self.constant_groups.items(): d[cal_type] = mcv.constants.copy() @@ -264,7 +262,11 @@ class CalibrationData(Mapping): if cal_type in other: d[cal_type].update(other[cal_type].constants) - return type(self)(d) + aggregators = set(self.aggregators) + for other in others: + aggregators.update(other.aggregators) + + return type(self)(d, sorted(aggregators)) class ConditionsBase: -- GitLab