diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index a2af181460c37e2210cd706ec83cba61d2b23afb..be9ff416c27700db859fba880a04badbf482fb06 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -7,6 +7,7 @@ from functools import lru_cache from pathlib import Path from typing import Dict, List, Optional, Union from urllib.parse import urljoin +from warnings import warn import h5py import numpy as np @@ -542,19 +543,43 @@ class CalibrationData(Mapping): module_details = [m for m in self.module_details if m["karabo_da"] in aggs] return type(self)(mcvs, module_details) - def merge(self, *others: "CalibrationData") -> "CalibrationData": - d = {} - for cal_type, mcv in self.constant_groups.items(): - d[cal_type] = mcv.constants.copy() - for other in others: - if cal_type in other: - d[cal_type].update(other[cal_type].constants) + def select_calibrations(self, calibrations) -> "CalibrationData": + mcvs = {c: self.constant_groups[c] for c in calibrations} + return type(self)(mcvs, self.module_details) + def merge(self, *others: "CalibrationData") -> "CalibrationData": + cal_types = set(self.constant_groups) aggregators = set(self.aggregator_names) + pdus_d = {m["karabo_da"]: m for m in self.module_details} for other in others: + cal_types.update(other.constant_groups) aggregators.update(other.aggregator_names) + for md in other.module_details: + # Warn if constants don't refer to same modules + md_da = md["karabo_da"] + if md_da in pdus_d: + pdu_a = pdus_d[md_da]["physical_name"] + pdu_b = md["physical_name"] + if pdu_a != pdu_b: + warn( + f"Merging constants with different modules for " + f"{md_da}: {pdu_a!r} != {pdu_b!r}", + stacklevel=2, + ) + else: + pdus_d[md_da] = md + + module_details = sorted(pdus_d.values(), key=lambda d: d["karabo_da"]) + + mcvs = {} + for cal_type in cal_types: + d = {} + for caldata in (self,) + others: + if cal_type in caldata: + d.update(caldata[cal_type].constants) + mcvs[cal_type] = ModulesConstantVersions(d, module_details) - return type(self)(d, sorted(aggregators)) + return type(self)(mcvs, module_details) class ConditionsBase: diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py index 2e0fcf15f74df856e86d18168a550d09bfc3c8d3..43b0a000dea4746adbd6d27cf8b93de79d261306 100644 --- a/tests/test_calcat_interface2.py +++ b/tests/test_calcat_interface2.py @@ -35,6 +35,39 @@ def test_AGIPD_CalibrationData_metadata(): assert isinstance(agipd_cd["Offset"].constants["AGIPD00"], SingleConstantVersion) +@pytest.mark.requires_gpfs +def test_AGIPD_merge(): + cond = AGIPDConditions( + # From: https://in.xfel.eu/calibration/calibration_constants/5754#condition + sensor_bias_voltage=300, # V + memory_cells=352, + acquisition_rate=2.2, # MHz + gain_mode=0, + gain_setting=1, + integration_time=12, + source_energy=9.2, + ) + agipd_cd = CalibrationData.from_condition( + cond, + "MID_DET_AGIPD1M-1", + event_at="2022-09-01 13:26:48.00", + calibrations=["Offset", "SlopesFF"], + ) + + modnos_q1 = list(range(0, 4)) + modnos_q4 = list(range(12, 16)) + merged = agipd_cd.select_modules(modnos_q1).merge(agipd_cd.select_modules(modnos_q4)) + assert merged.module_nums == modnos_q1 + modnos_q4 + + offset_only = agipd_cd.select_calibrations(["Offset"]) + slopes_only = agipd_cd.select_calibrations(["SlopesFF"]) + assert set(offset_only) == {"Offset"} + assert set(slopes_only) == {"SlopesFF"} + merged_cals = offset_only.merge(slopes_only) + assert set(merged_cals) == {"Offset", "SlopesFF"} + assert merged_cals.module_nums == list(range(16)) + + @pytest.mark.requires_gpfs def test_AGIPD_CalibrationData_metadata_SPB(): """Test CalibrationData with AGIPD condition""" @@ -79,8 +112,8 @@ def test_AGIPD_load_data(): ) arr = agipd_cd["Offset"].select_modules(list(range(4))).xarray() assert arr.shape == (4, 128, 512, 352, 3) - assert arr.dims[0] == 'module' - np.testing.assert_array_equal(arr.coords['module'], np.arange(0, 4)) + assert arr.dims[0] == "module" + np.testing.assert_array_equal(arr.coords["module"], np.arange(0, 4)) assert arr.dtype == np.float64 # Load parallel