diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 94969690d237c379c74964858cff0dfd2d7dd9f4..ab8c7762c6f0d0fee26846464b6a1f4e1799ea9c 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -226,10 +226,11 @@ class ModulesConstantVersions: module_details: List[Dict] def select_modules( - self, module_nums=None, *, aggregators=None, qm_names=None + self, module_nums=None, *, aggregator_names=None, qm_names=None ) -> "ModulesConstantVersions": + aggs = aggregator_names # Shorter name -> fewer multi-line statements n_specified = sum( - [module_nums is not None, aggregators is not None, qm_names is not None] + [module_nums is not None, aggs is not None, qm_names is not None] ) if n_specified < 1: raise TypeError("select_modules() requires an argument") @@ -240,22 +241,23 @@ class ModulesConstantVersions: if module_nums is not None: by_mod_no = {m["module_number"]: m for m in self.module_details} - aggregators = [by_mod_no[n]["karabo_da"] for n in module_nums] + aggs = [by_mod_no[n]["karabo_da"] for n in module_nums] elif qm_names is not None: by_qm = {m["virtual_device_name"]: m for m in self.module_details} - aggregators = [by_qm[s]["karabo_da"] for s in qm_names] - elif aggregators is not None: - miss = set(aggregators) - {m["karabo_da"] for m in self.module_details} + aggs = [by_qm[s]["karabo_da"] for s in qm_names] + elif aggs is not None: + miss = set(aggs) - {m["karabo_da"] for m in self.module_details} if miss: raise KeyError("Aggregators not found: " + ", ".join(sorted(miss))) - d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators} - return ModulesConstantVersions(d, self.module_details) + d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs} + mods = [m for m in self.module_details if m["karabo_da"] in d] + return ModulesConstantVersions(d, mods) # These properties label only the modules we have constants for, which may # be a subset of what's in module_details @property - def aggregators(self): + def aggregator_names(self): return sorted(self.constants) @property @@ -275,10 +277,10 @@ class ModulesConstantVersions: ] def ndarray(self, caldb_root=None): - eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root) + eg_dset = self.constants[self.aggregator_names[0]].dataset_obj(caldb_root) shape = (len(self.constants),) + eg_dset.shape arr = np.zeros(shape, eg_dset.dtype) - for i, agg in enumerate(self.aggregators): + for i, agg in enumerate(self.aggregator_names): dset = self.constants[agg].dataset_obj(caldb_root) dset.read_direct(arr[i]) return arr @@ -287,7 +289,7 @@ class ModulesConstantVersions: import xarray if module_naming == "da": - modules = self.aggregators + modules = self.aggregator_names elif module_naming == "modno": modules = self.module_nums elif module_naming == "qm": @@ -300,7 +302,7 @@ class ModulesConstantVersions: # Dimension labels dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)] coords = {"module": modules} - name = self.constants[self.aggregators[0]].constant_name + name = self.constants[self.aggregator_names[0]].constant_name return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name) @@ -340,10 +342,7 @@ class CalibrationData(Mapping): """Collected constants for a given detector""" def __init__(self, constant_groups, module_details): - self.constant_groups = { - const_type: ModulesConstantVersions(d, module_details) - for const_type, d in constant_groups.items() - } + self.constant_groups = constant_groups self.module_details = module_details @staticmethod @@ -398,7 +397,7 @@ class CalibrationData(Mapping): if mod.get("module_number", -1) < 0: mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1]) - d = {} + constant_groups = {} for params, cal_types in cal_types_by_params_used.items(): condition_dict = condition.make_dict(params) @@ -424,11 +423,14 @@ class CalibrationData(Mapping): aggr = ccv["physical_detector_unit"]["karabo_da"] cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]] - d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response( - ccv - ) + const_group = constant_groups.setdefault(cal_type, {}) + const_group[aggr] = SingleConstantVersion.from_response(ccv) - return cls(d, module_details) + mcvs = { + const_type: ModulesConstantVersions(d, module_details) + for const_type, d in constant_groups.items() + } + return cls(mcvs, module_details) @classmethod def from_report( @@ -447,16 +449,22 @@ class CalibrationData(Mapping): res = client.get("calibration_constant_versions", params) - d = {} + constant_groups = {} pdus = [] for ccv in res: pdus.append(ccv["physical_detector_unit"]) cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"]) aggr = ccv["physical_detector_unit"]["karabo_da"] - d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv) + const_group = constant_groups.setdefault(cal_type, {}) + const_group[aggr] = SingleConstantVersion.from_response(ccv) - return cls(d, sorted(pdus, key=lambda d: d["karabo_da"])) + module_details = sorted(pdus, key=lambda d: d["karabo_da"]) + mcvs = { + const_type: ModulesConstantVersions(d, module_details) + for const_type, d in constant_groups.items() + } + return cls(mcvs, module_details) def __getitem__(self, key) -> ModulesConstantVersions: return self.constant_groups[key] @@ -494,7 +502,7 @@ class CalibrationData(Mapping): def require_calibrations(self, calibrations): """Drop any modules missing the specified constant types""" - mods = set(self.aggregators) + mods = set(self.aggregator_names) for cal_type in calibrations: mods.intersection_update(self[cal_type].constants) return self.select_modules(mods) @@ -514,19 +522,19 @@ class CalibrationData(Mapping): return type(self)(d, self.aggregators) def select_modules( - self, module_nums=None, *, aggregators=None, qm_names=None + self, module_nums=None, *, aggregator_names=None, qm_names=None ) -> "CalibrationData": - return type(self)( - { - cal_type: mcv.select_modules( - module_nums=module_nums, - aggregators=aggregators, - qm_names=qm_names, - ).constants - for (cal_type, mcv) in self.constant_groups.items() - }, - sorted(aggregators), - ) + mcvs = { + cal_type: mcv.select_modules( + module_nums=module_nums, + aggregator_names=aggregator_names, + qm_names=qm_names, + ) + for (cal_type, mcv) in self.constant_groups.items() + } + aggs = set().union(*[c.aggregator_names for c in mcvs.values()]) + module_details = [m for m in self.module_details if m["karabo_da"] in aggs] + return type(self)(mcvs, module_details) def merge(self, *others: "CalibrationData") -> "CalibrationData": d = {} @@ -536,9 +544,9 @@ class CalibrationData(Mapping): if cal_type in other: d[cal_type].update(other[cal_type].constants) - aggregators = set(self.aggregators) + aggregators = set(self.aggregator_names) for other in others: - aggregators.update(other.aggregators) + aggregators.update(other.aggregator_names) return type(self)(d, sorted(aggregators)) diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py index 6cfc68bac579e6c1c080743d591532acab49ee2c..674b7ba72f9242c773e5692531f7f2c2baf42857 100644 --- a/tests/test_calcat_interface2.py +++ b/tests/test_calcat_interface2.py @@ -80,12 +80,12 @@ def test_DSSC_modules_missing(): aggs_q3 = [f"DSSC{m:02}" for m in modnos_q3] qm_q3 = [f"Q3M{i}" for i in range(1, 5)] assert offset.select_modules(modnos_q3).module_nums == modnos_q3 - assert offset.select_modules(aggregators=aggs_q3).module_nums == modnos_q3 + assert offset.select_modules(aggregator_names=aggs_q3).module_nums == modnos_q3 assert offset.select_modules(qm_names=qm_q3).module_nums == modnos_q3 # test CalibrationData.select_modules() assert dssc_cd.select_modules(modnos_q3).module_nums == modnos_q3 - assert dssc_cd.select_modules(aggregators=aggs_q3).module_nums == modnos_q3 + assert dssc_cd.select_modules(aggregator_names=aggs_q3).module_nums == modnos_q3 assert dssc_cd.select_modules(qm_names=qm_q3).module_nums == modnos_q3