diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index ceb9751b706da0f8e458a3d728fbce4e77440f9d..af6d76acd97bc663a10f00a524cb7f98d4efe605 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Dict, Optional, Sequence, Union import h5py +import numpy as np import pasha as psh from calibration_client import CalibrationClient from calibration_client.modules import CalibrationConstantVersion @@ -98,7 +99,7 @@ class SingleConstantVersion: physical_name=ccv["physical_detector_unit"]["physical_name"], ) - def dataset_obj(self, caldb_root=None): + def dataset_obj(self, caldb_root=None) -> h5py.Dataset: if caldb_root is not None: caldb_root = Path(caldb_root) else: @@ -133,6 +134,36 @@ class ModulesConstantVersions: def qm_names(self): return [module_index_to_qm(n) for n in self.module_nums] + def ndarray(self, caldb_root=None): + eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root) + shape = (len(self.constants),) + eg_dset.shape + arr = np.zeros(shape, eg_dset.dtype) + for i, agg in enumerate(self.aggregators): + dset = self.constants[agg].dataset_obj(caldb_root) + dset.read_direct(arr[i]) + return arr + + def xarray(self, module_naming="da", caldb_root=None): + import xarray + + if module_naming == "da": + modules = self.aggregators + elif module_naming == "modno": + modules = self.module_nums + elif module_naming == "qm": + modules = self.qm_names + else: + raise ValueError(f"{module_naming=} (must be 'da', 'modno' or 'qm'") + + ndarr = self.ndarray(caldb_root) + + # Dimension labels + dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)] + coords = {"module": modules} + name = self.constants[self.aggregators[0]].constant_name + + return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name) + class CalibrationData(Mapping): """Collected constants for a given detector"""