diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 9f387caa164d58c8822a9ac32f9b35468563e236..5101cc4d93588593b6c3e10f7487a757c71a4bf7 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -82,11 +82,11 @@ class ModulesConstantVersions: @property def aggregators(self): - return list(self.constants) + return sorted(self.constants) @property def module_nums(self): - return [int(da[-2:]) for da in self.constants] + return [int(da[-2:]) for da in self.aggregators] @property def qm_names(self): @@ -187,8 +187,8 @@ class CalibrationData(Mapping): return cls(d) - def __getitem__(self, key): - return self.constant_groups + def __getitem__(self, key) -> ModulesConstantVersions: + return self.constant_groups[key] def __iter__(self): return iter(self.constant_groups) @@ -196,6 +196,28 @@ class CalibrationData(Mapping): def __len__(self): return len(self.constant_groups) + @property + def aggregators(self): + names = set() + for mcv in self.constant_groups.values(): + names.update(mcv.aggregators) + return sorted(names) + + @property + def module_nums(self): + return [int(da[-2:]) for da in self.aggregators] + + @property + def qm_names(self): + return [module_index_to_qm(n) for n in self.module_nums] + + def require_calibrations(self, *calibrations): + """Drop any modules missing the specified constant types""" + mods = set(self.aggregators) + for cal_type in calibrations: + mods.intersection_update(self[cal_type].constants) + return self.select_modules(mods) + def select_calibrations(self, *calibrations, require_all=True): if require_all: missing = set(calibrations) - set(self.constant_groups) @@ -218,9 +240,14 @@ class CalibrationData(Mapping): } ) - def merge(self, other: "CalibrationData") -> "CalibrationData": - d = self.constant_groups.copy() - d.update(other.constant_groups) + def merge(self, *others: Sequence["CalibrationData"]) -> "CalibrationData": + d = {} + for cal_type, mcv in self.constant_groups.items(): + d[cal_type] = mcv.constants.copy() + for other in others: + if cal_type in other: + d[cal_type].update(other[cal_type].constants) + return type(self)(d) diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py index f3507d0e71226b5f9e76f38f14eb85ad9b4ac157..b372476e7ba302fcb29ebafc0f56dd332bf809dc 100644 --- a/tests/test_calcat_interface2.py +++ b/tests/test_calcat_interface2.py @@ -35,5 +35,5 @@ def test_AGIPD_CalibrationData_report(): # Report ID: https://in.xfel.eu/calibration/reports/3757 agipd_cd = CalibrationData.from_report(3757) assert set(agipd_cd) == {'Offset', 'Noise', 'ThresholdsDark', 'BadPixelsDark'} - assert agipd_cd['Offset'].aggregators == [f'AGIPD{n:02}' for n in range(16)] + assert agipd_cd.aggregators == [f'AGIPD{n:02}' for n in range(16)] assert isinstance(agipd_cd['Offset'].constants['AGIPD00'], SingleConstantVersion)