Skip to content
Snippets Groups Projects
Commit a75e9ee1 authored by Thomas Kluyver's avatar Thomas Kluyver
Browse files

Fix CalibrationData.require_constants()

parent 54f8fc83
No related branches found
No related tags found
1 merge request!885Revised CalCat API
...@@ -218,6 +218,31 @@ class SingleConstantVersion: ...@@ -218,6 +218,31 @@ class SingleConstantVersion:
return self.dataset_obj(caldb_root)[:] return self.dataset_obj(caldb_root)[:]
def prepare_selection(
module_details, module_nums=None, aggregator_names=None, qm_names=None
):
aggs = aggregator_names # Shorter name -> fewer multi-line statements
n_specified = sum([module_nums is not None, aggs is not None, qm_names is not None])
if n_specified > 1:
raise TypeError(
"select_modules() accepts only one of module_nums, aggregator_names & qm_names"
)
if module_nums is not None:
by_mod_no = {m["module_number"]: m for m in module_details}
return [by_mod_no[n]["karabo_da"] for n in module_nums]
elif qm_names is not None:
by_qm = {m["virtual_device_name"]: m for m in module_details}
return [by_qm[s]["karabo_da"] for s in qm_names]
elif aggs is not None:
miss = set(aggs) - {m["karabo_da"] for m in module_details}
if miss:
raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
return aggs
else:
raise TypeError("select_modules() requires an argument")
@dataclass @dataclass
class ModulesConstantVersions: class ModulesConstantVersions:
"""A group of similar CCVs for several modules of one detector""" """A group of similar CCVs for several modules of one detector"""
...@@ -228,28 +253,9 @@ class ModulesConstantVersions: ...@@ -228,28 +253,9 @@ class ModulesConstantVersions:
def select_modules( def select_modules(
self, module_nums=None, *, aggregator_names=None, qm_names=None self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "ModulesConstantVersions": ) -> "ModulesConstantVersions":
aggs = aggregator_names # Shorter name -> fewer multi-line statements aggs = prepare_selection(
n_specified = sum( self.module_details, module_nums, aggregator_names, qm_names
[module_nums is not None, aggs is not None, qm_names is not None]
) )
if n_specified < 1:
raise TypeError("select_modules() requires an argument")
elif n_specified > 1:
raise TypeError(
"select_modules() accepts only one of module_nums, aggregators & qm_names"
)
if module_nums is not None:
by_mod_no = {m["module_number"]: m for m in self.module_details}
aggs = [by_mod_no[n]["karabo_da"] for n in module_nums]
elif qm_names is not None:
by_qm = {m["virtual_device_name"]: m for m in self.module_details}
aggs = [by_qm[s]["karabo_da"] for s in qm_names]
elif aggs is not None:
miss = set(aggs) - {m["karabo_da"] for m in self.module_details}
if miss:
raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs} d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
mods = [m for m in self.module_details if m["karabo_da"] in d] mods = [m for m in self.module_details if m["karabo_da"] in d]
return ModulesConstantVersions(d, mods) return ModulesConstantVersions(d, mods)
...@@ -505,7 +511,7 @@ class CalibrationData(Mapping): ...@@ -505,7 +511,7 @@ class CalibrationData(Mapping):
mods = set(self.aggregator_names) mods = set(self.aggregator_names)
for cal_type in calibrations: for cal_type in calibrations:
mods.intersection_update(self[cal_type].constants) mods.intersection_update(self[cal_type].constants)
return self.select_modules(mods) return self.select_modules(aggregator_names=mods)
def select_calibrations(self, calibrations, require_all=True): def select_calibrations(self, calibrations, require_all=True):
if require_all: if require_all:
...@@ -524,11 +530,14 @@ class CalibrationData(Mapping): ...@@ -524,11 +530,14 @@ class CalibrationData(Mapping):
def select_modules( def select_modules(
self, module_nums=None, *, aggregator_names=None, qm_names=None self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "CalibrationData": ) -> "CalibrationData":
# Validate the specified modules against those we know about.
# Each specific constant type may have only a subset of these modules.
aggs = prepare_selection(
self.module_details, module_nums, aggregator_names, qm_names
)
mcvs = { mcvs = {
cal_type: mcv.select_modules( cal_type: mcv.select_modules(
module_nums=module_nums, aggregator_names=set(aggs).intersection(mcv.aggregator_names)
aggregator_names=aggregator_names,
qm_names=qm_names,
) )
for (cal_type, mcv) in self.constant_groups.items() for (cal_type, mcv) in self.constant_groups.items()
} }
......
...@@ -103,7 +103,11 @@ def test_LPD_constant_missing(): ...@@ -103,7 +103,11 @@ def test_LPD_constant_missing():
assert lpd_cd.qm_names == [f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)] assert lpd_cd.qm_names == [f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)]
# When we look at a specific constant, module LPD05 is missing # When we look at a specific constant, module LPD05 is missing
assert lpd_cd["Offset"].module_nums == list(range(0, 5)) + list(range(6, 16)) modnos_w_constant = list(range(0, 5)) + list(range(6, 16))
assert lpd_cd["Offset"].module_nums == modnos_w_constant
# Test CalibrationData.require_constant()
assert lpd_cd.require_calibrations(["Offset"]).module_nums == modnos_w_constant
@pytest.mark.xfail @pytest.mark.xfail
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment