Skip to content
Snippets Groups Projects
Commit 54f8fc83 authored by Thomas Kluyver's avatar Thomas Kluyver
Browse files

Fix select_modules()

parent 9847e518
No related branches found
No related tags found
1 merge request!885Revised CalCat API
...@@ -226,10 +226,11 @@ class ModulesConstantVersions: ...@@ -226,10 +226,11 @@ class ModulesConstantVersions:
module_details: List[Dict] module_details: List[Dict]
def select_modules( def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "ModulesConstantVersions": ) -> "ModulesConstantVersions":
aggs = aggregator_names # Shorter name -> fewer multi-line statements
n_specified = sum( n_specified = sum(
[module_nums is not None, aggregators is not None, qm_names is not None] [module_nums is not None, aggs is not None, qm_names is not None]
) )
if n_specified < 1: if n_specified < 1:
raise TypeError("select_modules() requires an argument") raise TypeError("select_modules() requires an argument")
...@@ -240,22 +241,23 @@ class ModulesConstantVersions: ...@@ -240,22 +241,23 @@ class ModulesConstantVersions:
if module_nums is not None: if module_nums is not None:
by_mod_no = {m["module_number"]: m for m in self.module_details} by_mod_no = {m["module_number"]: m for m in self.module_details}
aggregators = [by_mod_no[n]["karabo_da"] for n in module_nums] aggs = [by_mod_no[n]["karabo_da"] for n in module_nums]
elif qm_names is not None: elif qm_names is not None:
by_qm = {m["virtual_device_name"]: m for m in self.module_details} by_qm = {m["virtual_device_name"]: m for m in self.module_details}
aggregators = [by_qm[s]["karabo_da"] for s in qm_names] aggs = [by_qm[s]["karabo_da"] for s in qm_names]
elif aggregators is not None: elif aggs is not None:
miss = set(aggregators) - {m["karabo_da"] for m in self.module_details} miss = set(aggs) - {m["karabo_da"] for m in self.module_details}
if miss: if miss:
raise KeyError("Aggregators not found: " + ", ".join(sorted(miss))) raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators} d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
return ModulesConstantVersions(d, self.module_details) mods = [m for m in self.module_details if m["karabo_da"] in d]
return ModulesConstantVersions(d, mods)
# These properties label only the modules we have constants for, which may # These properties label only the modules we have constants for, which may
# be a subset of what's in module_details # be a subset of what's in module_details
@property @property
def aggregators(self): def aggregator_names(self):
return sorted(self.constants) return sorted(self.constants)
@property @property
...@@ -275,10 +277,10 @@ class ModulesConstantVersions: ...@@ -275,10 +277,10 @@ class ModulesConstantVersions:
] ]
def ndarray(self, caldb_root=None): def ndarray(self, caldb_root=None):
eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root) eg_dset = self.constants[self.aggregator_names[0]].dataset_obj(caldb_root)
shape = (len(self.constants),) + eg_dset.shape shape = (len(self.constants),) + eg_dset.shape
arr = np.zeros(shape, eg_dset.dtype) arr = np.zeros(shape, eg_dset.dtype)
for i, agg in enumerate(self.aggregators): for i, agg in enumerate(self.aggregator_names):
dset = self.constants[agg].dataset_obj(caldb_root) dset = self.constants[agg].dataset_obj(caldb_root)
dset.read_direct(arr[i]) dset.read_direct(arr[i])
return arr return arr
...@@ -287,7 +289,7 @@ class ModulesConstantVersions: ...@@ -287,7 +289,7 @@ class ModulesConstantVersions:
import xarray import xarray
if module_naming == "da": if module_naming == "da":
modules = self.aggregators modules = self.aggregator_names
elif module_naming == "modno": elif module_naming == "modno":
modules = self.module_nums modules = self.module_nums
elif module_naming == "qm": elif module_naming == "qm":
...@@ -300,7 +302,7 @@ class ModulesConstantVersions: ...@@ -300,7 +302,7 @@ class ModulesConstantVersions:
# Dimension labels # Dimension labels
dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)] dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
coords = {"module": modules} coords = {"module": modules}
name = self.constants[self.aggregators[0]].constant_name name = self.constants[self.aggregator_names[0]].constant_name
return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name) return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
...@@ -340,10 +342,7 @@ class CalibrationData(Mapping): ...@@ -340,10 +342,7 @@ class CalibrationData(Mapping):
"""Collected constants for a given detector""" """Collected constants for a given detector"""
def __init__(self, constant_groups, module_details): def __init__(self, constant_groups, module_details):
self.constant_groups = { self.constant_groups = constant_groups
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
self.module_details = module_details self.module_details = module_details
@staticmethod @staticmethod
...@@ -398,7 +397,7 @@ class CalibrationData(Mapping): ...@@ -398,7 +397,7 @@ class CalibrationData(Mapping):
if mod.get("module_number", -1) < 0: if mod.get("module_number", -1) < 0:
mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1]) mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
d = {} constant_groups = {}
for params, cal_types in cal_types_by_params_used.items(): for params, cal_types in cal_types_by_params_used.items():
condition_dict = condition.make_dict(params) condition_dict = condition.make_dict(params)
...@@ -424,11 +423,14 @@ class CalibrationData(Mapping): ...@@ -424,11 +423,14 @@ class CalibrationData(Mapping):
aggr = ccv["physical_detector_unit"]["karabo_da"] aggr = ccv["physical_detector_unit"]["karabo_da"]
cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]] cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response( const_group = constant_groups.setdefault(cal_type, {})
ccv const_group[aggr] = SingleConstantVersion.from_response(ccv)
)
return cls(d, module_details) mcvs = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mcvs, module_details)
@classmethod @classmethod
def from_report( def from_report(
...@@ -447,16 +449,22 @@ class CalibrationData(Mapping): ...@@ -447,16 +449,22 @@ class CalibrationData(Mapping):
res = client.get("calibration_constant_versions", params) res = client.get("calibration_constant_versions", params)
d = {} constant_groups = {}
pdus = [] pdus = []
for ccv in res: for ccv in res:
pdus.append(ccv["physical_detector_unit"]) pdus.append(ccv["physical_detector_unit"])
cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"]) cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"])
aggr = ccv["physical_detector_unit"]["karabo_da"] aggr = ccv["physical_detector_unit"]["karabo_da"]
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv) const_group = constant_groups.setdefault(cal_type, {})
const_group[aggr] = SingleConstantVersion.from_response(ccv)
return cls(d, sorted(pdus, key=lambda d: d["karabo_da"])) module_details = sorted(pdus, key=lambda d: d["karabo_da"])
mcvs = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mcvs, module_details)
def __getitem__(self, key) -> ModulesConstantVersions: def __getitem__(self, key) -> ModulesConstantVersions:
return self.constant_groups[key] return self.constant_groups[key]
...@@ -494,7 +502,7 @@ class CalibrationData(Mapping): ...@@ -494,7 +502,7 @@ class CalibrationData(Mapping):
def require_calibrations(self, calibrations): def require_calibrations(self, calibrations):
"""Drop any modules missing the specified constant types""" """Drop any modules missing the specified constant types"""
mods = set(self.aggregators) mods = set(self.aggregator_names)
for cal_type in calibrations: for cal_type in calibrations:
mods.intersection_update(self[cal_type].constants) mods.intersection_update(self[cal_type].constants)
return self.select_modules(mods) return self.select_modules(mods)
...@@ -514,19 +522,19 @@ class CalibrationData(Mapping): ...@@ -514,19 +522,19 @@ class CalibrationData(Mapping):
return type(self)(d, self.aggregators) return type(self)(d, self.aggregators)
def select_modules( def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "CalibrationData": ) -> "CalibrationData":
return type(self)( mcvs = {
{ cal_type: mcv.select_modules(
cal_type: mcv.select_modules( module_nums=module_nums,
module_nums=module_nums, aggregator_names=aggregator_names,
aggregators=aggregators, qm_names=qm_names,
qm_names=qm_names, )
).constants for (cal_type, mcv) in self.constant_groups.items()
for (cal_type, mcv) in self.constant_groups.items() }
}, aggs = set().union(*[c.aggregator_names for c in mcvs.values()])
sorted(aggregators), module_details = [m for m in self.module_details if m["karabo_da"] in aggs]
) return type(self)(mcvs, module_details)
def merge(self, *others: "CalibrationData") -> "CalibrationData": def merge(self, *others: "CalibrationData") -> "CalibrationData":
d = {} d = {}
...@@ -536,9 +544,9 @@ class CalibrationData(Mapping): ...@@ -536,9 +544,9 @@ class CalibrationData(Mapping):
if cal_type in other: if cal_type in other:
d[cal_type].update(other[cal_type].constants) d[cal_type].update(other[cal_type].constants)
aggregators = set(self.aggregators) aggregators = set(self.aggregator_names)
for other in others: for other in others:
aggregators.update(other.aggregators) aggregators.update(other.aggregator_names)
return type(self)(d, sorted(aggregators)) return type(self)(d, sorted(aggregators))
......
...@@ -80,12 +80,12 @@ def test_DSSC_modules_missing(): ...@@ -80,12 +80,12 @@ def test_DSSC_modules_missing():
aggs_q3 = [f"DSSC{m:02}" for m in modnos_q3] aggs_q3 = [f"DSSC{m:02}" for m in modnos_q3]
qm_q3 = [f"Q3M{i}" for i in range(1, 5)] qm_q3 = [f"Q3M{i}" for i in range(1, 5)]
assert offset.select_modules(modnos_q3).module_nums == modnos_q3 assert offset.select_modules(modnos_q3).module_nums == modnos_q3
assert offset.select_modules(aggregators=aggs_q3).module_nums == modnos_q3 assert offset.select_modules(aggregator_names=aggs_q3).module_nums == modnos_q3
assert offset.select_modules(qm_names=qm_q3).module_nums == modnos_q3 assert offset.select_modules(qm_names=qm_q3).module_nums == modnos_q3
# test CalibrationData.select_modules() # test CalibrationData.select_modules()
assert dssc_cd.select_modules(modnos_q3).module_nums == modnos_q3 assert dssc_cd.select_modules(modnos_q3).module_nums == modnos_q3
assert dssc_cd.select_modules(aggregators=aggs_q3).module_nums == modnos_q3 assert dssc_cd.select_modules(aggregator_names=aggs_q3).module_nums == modnos_q3
assert dssc_cd.select_modules(qm_names=qm_q3).module_nums == modnos_q3 assert dssc_cd.select_modules(qm_names=qm_q3).module_nums == modnos_q3
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment