diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 18bc5a9251a4313261760732d92dfb171f912946..084f194b0951754854f805f4746b25d7c3d3d4da 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -282,28 +282,39 @@ class ModulesConstantVersions: if m["karabo_da"] in self.constants ] - def ndarray(self, caldb_root=None): + def ndarray(self, caldb_root=None, *, parallel=0): eg_dset = self.constants[self.aggregator_names[0]].dataset_obj(caldb_root) shape = (len(self.constants),) + eg_dset.shape - arr = np.zeros(shape, eg_dset.dtype) - for i, agg in enumerate(self.aggregator_names): - dset = self.constants[agg].dataset_obj(caldb_root) - dset.read_direct(arr[i]) + + if parallel > 0: + load_ctx = psh.ProcessContext(num_workers=parallel) + else: + load_ctx = psh.SerialContext() + + arr = psh.alloc(shape, eg_dset.dtype, fill=0) + + def _load_constant_dataset(wid, index, mod): + dset = self.constants[mod].dataset_obj(caldb_root) + dset.read_direct(arr[index]) + + load_ctx.map(_load_constant_dataset, self.aggregator_names) return arr - def xarray(self, module_naming="da", caldb_root=None): + def xarray(self, module_naming="modnum", caldb_root=None, *, parallel=0): import xarray - if module_naming == "da": + if module_naming == "aggregator": modules = self.aggregator_names - elif module_naming == "modno": + elif module_naming == "modnum": modules = self.module_nums elif module_naming == "qm": modules = self.qm_names else: - raise ValueError(f"{module_naming=} (must be 'da', 'modno' or 'qm'") + raise ValueError( + f"{module_naming=} (must be 'aggregator', 'modnum' or 'qm'" + ) - ndarr = self.ndarray(caldb_root) + ndarr = self.ndarray(caldb_root, parallel=parallel) # Dimension labels dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)] diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py index d4a08eeeb9fcc71e6a1e28f8d7567bc2c611b435..2e0fcf15f74df856e86d18168a550d09bfc3c8d3 100644 --- a/tests/test_calcat_interface2.py +++ b/tests/test_calcat_interface2.py @@ -1,4 +1,6 @@ +import numpy as np import pytest +import xarray as xr from cal_tools.calcat_interface2 import ( CalibrationData, @@ -59,6 +61,33 @@ def test_AGIPD_CalibrationData_metadata_SPB(): assert isinstance(agipd_cd["Offset"].constants["AGIPD00"], SingleConstantVersion) +@pytest.mark.requires_gpfs +def test_AGIPD_load_data(): + cond = AGIPDConditions( + sensor_bias_voltage=300, + memory_cells=352, + acquisition_rate=1.1, + integration_time=12, + source_energy=9.2, + gain_mode=0, + gain_setting=0, + ) + agipd_cd = CalibrationData.from_condition( + cond, + "SPB_DET_AGIPD1M-1", + event_at="2020-01-07 13:26:48.00", + ) + arr = agipd_cd["Offset"].select_modules(list(range(4))).xarray() + assert arr.shape == (4, 128, 512, 352, 3) + assert arr.dims[0] == 'module' + np.testing.assert_array_equal(arr.coords['module'], np.arange(0, 4)) + assert arr.dtype == np.float64 + + # Load parallel + arr_p = agipd_cd["Offset"].select_modules(list(range(4))).xarray(parallel=4) + xr.testing.assert_identical(arr_p, arr) + + @pytest.mark.requires_gpfs def test_DSSC_modules_missing(): dssc_cd = CalibrationData.from_condition(