Skip to content
Snippets Groups Projects
Commit ac8dab49 authored by Thomas Kluyver's avatar Thomas Kluyver
Browse files

Add methods for loading constants

parent afab4fc3
No related branches found
No related tags found
1 merge request!885Revised CalCat API
......@@ -4,13 +4,14 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import h5py
import pasha as psh
from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion
from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm
global_client = None
class ModuleNameError(KeyError):
def __init__(self, name):
......@@ -20,6 +21,9 @@ class ModuleNameError(KeyError):
return f"No module named {self.name!r}"
global_client = None
def get_client():
global global_client
if global_client is None:
......@@ -43,6 +47,24 @@ def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
)
_default_caldb_root = ...
def _get_default_caldb_root():
global _default_caldb_root
if _default_caldb_root is ...:
onc_path = Path("/common/cal/caldb_store")
maxwell_path = Path("/gpfs/exfel/d/cal/caldb_store")
if onc_path.is_dir():
_default_caldb_root = onc_path
elif maxwell_path.is_dir():
_default_caldb_root = maxwell_path
else:
_default_caldb_root = None
return _default_caldb_root
@dataclass
class SingleConstantVersion:
"""A Calibration Constant Version for 1 detector module"""
......@@ -76,6 +98,18 @@ class SingleConstantVersion:
physical_name=ccv["physical_detector_unit"]["physical_name"],
)
def dataset_obj(self, caldb_root=None):
if caldb_root is not None:
caldb_root = Path(caldb_root)
else:
caldb_root = _get_default_caldb_root()
f = h5py.File(caldb_root / self.path, "r")
return f[self.dataset]["data"]
def ndarray(self, caldb_root=None):
return self.dataset_obj(caldb_root)[:]
@dataclass
class ModulesConstantVersions:
......@@ -133,7 +167,7 @@ class CalibrationData(Mapping):
api = CalCatApi(client or get_client())
detector_id = api.detector(detector_name)['id']
detector_id = api.detector(detector_name)["id"]
all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
if modules is None:
modules = sorted(all_modules)
......@@ -249,10 +283,13 @@ class CalibrationData(Mapping):
return type(self)(d, self.aggregators)
def select_modules(self, aggregators):
return type(self)({
return type(self)(
{
cal_type: mcv.select_modules(aggregators).constants
for (cal_type, mcv) in self.constant_groups.items()
}, sorted(aggregators))
},
sorted(aggregators),
)
def merge(self, *others: "CalibrationData") -> "CalibrationData":
d = {}
......@@ -268,6 +305,27 @@ class CalibrationData(Mapping):
return type(self)(d, sorted(aggregators))
def load_all(self, caldb_root=None):
res = {}
const_load_mp = psh.ProcessContext(num_workers=24)
keys = []
for cal_type, mcv in self.constant_groups.items():
res[cal_type] = {}
for module in mcv.aggregators:
dset = mcv.constants[module].dataset_obj(caldb_root)
res[cal_type][module] = const_load_mp.alloc(
shape=dset.shape, dtype=dset.dtype
)
keys.append((cal_type, module))
def _load_constant_dataset(wid, index, key):
cal_type, mod = key
dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
dset.read_direct(res[cal_type][mod])
const_load_mp.map(_load_constant_dataset, keys)
class ConditionsBase:
calibration_types = {} # For subclasses: {calibration: [parameter names]}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment