diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 1e679d2a9a326f322cb12eb2c6a13c004709e232..4b28f8f6c0c96e2163bbe73765858c888bb46bcc 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -83,7 +83,7 @@ class ModulesConstantVersions: constants: Dict[str, SingleConstantVersion] # Keys e.g. 'LPD00' - def select_modules(self, *aggregators) -> "ModulesConstantVersions": + def select_modules(self, aggregators) -> "ModulesConstantVersions": d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators} return ModulesConstantVersions(d) @@ -173,7 +173,7 @@ class CalibrationData(Mapping): res = cls(d, modules) if modules: - res = res.select_modules(*modules) + res = res.select_modules(modules) return res @classmethod @@ -227,14 +227,14 @@ class CalibrationData(Mapping): def qm_names(self): return [module_index_to_qm(n) for n in self.module_nums] - def require_calibrations(self, *calibrations): + def require_calibrations(self, calibrations): """Drop any modules missing the specified constant types""" mods = set(self.aggregators) for cal_type in calibrations: mods.intersection_update(self[cal_type].constants) return self.select_modules(mods) - def select_calibrations(self, *calibrations, require_all=True): + def select_calibrations(self, calibrations, require_all=True): if require_all: missing = set(calibrations) - set(self.constant_groups) if missing: @@ -248,9 +248,9 @@ class CalibrationData(Mapping): # TODO: missing for some modules? return type(self)(d, self.aggregators) - def select_modules(self, *aggregators): + def select_modules(self, aggregators): return type(self)({ - cal_type: mcv.select_modules(*aggregators).constants + cal_type: mcv.select_modules(aggregators).constants for (cal_type, mcv) in self.constant_groups.items() }, sorted(aggregators))