Skip to content
Snippets Groups Projects
Commit c2fa065a authored by Thomas Kluyver's avatar Thomas Kluyver
Browse files

Store module names (aggregators) explicitly in CalibrationData

parent 3fce0af0
1 merge request!885Revised CalCat API
......@@ -12,6 +12,13 @@ from .tools import module_index_to_qm
global_client = None
class ModuleNameError(KeyError):
def __init__(self, name):
self.name = name
def __str__(self):
return f"No module named {self.name!r}"
def get_client():
global global_client
......@@ -96,11 +103,12 @@ class ModulesConstantVersions:
class CalibrationData(Mapping):
"""Collected constants for a given detector"""
def __init__(self, constant_groups: Dict[str, Dict[str, SingleConstantVersion]]):
def __init__(self, constant_groups, aggregators):
self.constant_groups = {
const_type: ModulesConstantVersions(d)
for const_type, d in constant_groups.items()
}
self.aggregators = aggregators
@classmethod
def from_condition(
......@@ -115,6 +123,8 @@ class CalibrationData(Mapping):
):
if calibrations is None:
calibrations = set(condition.calibration_types)
if pdu_snapshot_at is None:
pdu_snapshot_at = event_at
cal_types_by_params_used = {}
for cal_type, params in condition.calibration_types.items():
......@@ -123,6 +133,16 @@ class CalibrationData(Mapping):
api = CalCatApi(client or get_client())
detector_id = api.detector(detector_name)['id']
all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
if modules is None:
modules = sorted(all_modules)
else:
modules = sorted(modules)
for m in modules:
if m not in all_modules:
raise ModuleNameError(m)
d = {}
for params, cal_types in cal_types_by_params_used.items():
......@@ -151,7 +171,7 @@ class CalibrationData(Mapping):
ccv
)
res = cls(d)
res = cls(d, modules)
if modules:
res = res.select_modules(*modules)
return res
......@@ -178,15 +198,17 @@ class CalibrationData(Mapping):
raise CalCatError(resp)
d = {}
aggregators = set()
for ccv in resp["data"]:
aggr = ccv["physical_detector_unit"]["karabo_da"]
aggregators.add(aggr)
cal_type = api.calibration_name(
ccv["calibration_constant"]["calibration_id"]
)
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)
return cls(d)
return cls(d, sorted(aggregators))
def __getitem__(self, key) -> ModulesConstantVersions:
return self.constant_groups[key]
......@@ -197,13 +219,6 @@ class CalibrationData(Mapping):
def __len__(self):
return len(self.constant_groups)
@property
def aggregators(self):
names = set()
for mcv in self.constant_groups.values():
names.update(mcv.aggregators)
return sorted(names)
@property
def module_nums(self):
return [int(da[-2:]) for da in self.aggregators]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment