diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 456f85517473dc05aad59223d9061f0e74642d06..cdd65238747d69ef947c59d488cc5b06953281d2 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -12,6 +12,13 @@ from .tools import module_index_to_qm global_client = None +class ModuleNameError(KeyError): + def __init__(self, name): + self.name = name + + def __str__(self): + return f"No module named {self.name!r}" + def get_client(): global global_client @@ -96,11 +103,12 @@ class ModulesConstantVersions: class CalibrationData(Mapping): """Collected constants for a given detector""" - def __init__(self, constant_groups: Dict[str, Dict[str, SingleConstantVersion]]): + def __init__(self, constant_groups, aggregators): self.constant_groups = { const_type: ModulesConstantVersions(d) for const_type, d in constant_groups.items() } + self.aggregators = aggregators @classmethod def from_condition( @@ -115,6 +123,8 @@ class CalibrationData(Mapping): ): if calibrations is None: calibrations = set(condition.calibration_types) + if pdu_snapshot_at is None: + pdu_snapshot_at = event_at cal_types_by_params_used = {} for cal_type, params in condition.calibration_types.items(): @@ -123,6 +133,16 @@ class CalibrationData(Mapping): api = CalCatApi(client or get_client()) + detector_id = api.detector(detector_name)['id'] + all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at) + if modules is None: + modules = sorted(all_modules) + else: + modules = sorted(modules) + for m in modules: + if m not in all_modules: + raise ModuleNameError(m) + d = {} for params, cal_types in cal_types_by_params_used.items(): @@ -151,7 +171,7 @@ class CalibrationData(Mapping): ccv ) - res = cls(d) + res = cls(d, modules) if modules: res = res.select_modules(*modules) return res @@ -178,15 +198,17 @@ class CalibrationData(Mapping): raise CalCatError(resp) d = {} + aggregators = set() for ccv in resp["data"]: aggr = ccv["physical_detector_unit"]["karabo_da"] + aggregators.add(aggr) cal_type = api.calibration_name( ccv["calibration_constant"]["calibration_id"] ) d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv) - return cls(d) + return cls(d, sorted(aggregators)) def __getitem__(self, key) -> ModulesConstantVersions: return self.constant_groups[key] @@ -197,13 +219,6 @@ class CalibrationData(Mapping): def __len__(self): return len(self.constant_groups) - @property - def aggregators(self): - names = set() - for mcv in self.constant_groups.values(): - names.update(mcv.aggregators) - return sorted(names) - @property def module_nums(self): return [int(da[-2:]) for da in self.aggregators]