From ac8dab49652d316dc16d7db0ee19800f8b51a42b Mon Sep 17 00:00:00 2001 From: Thomas Kluyver <thomas@kluyver.me.uk> Date: Thu, 24 Aug 2023 14:03:35 +0100 Subject: [PATCH] Add methods for loading constants --- src/cal_tools/calcat_interface2.py | 66 ++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 4b28f8f6c..5b0b827e4 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -4,13 +4,14 @@ from datetime import datetime from pathlib import Path from typing import Dict, Optional, Sequence, Union +import h5py +import pasha as psh from calibration_client import CalibrationClient from calibration_client.modules import CalibrationConstantVersion from .calcat_interface import CalCatApi, CalCatError from .tools import module_index_to_qm -global_client = None class ModuleNameError(KeyError): def __init__(self, name): @@ -20,6 +21,9 @@ class ModuleNameError(KeyError): return f"No module named {self.name!r}" +global_client = None + + def get_client(): global global_client if global_client is None: @@ -43,6 +47,24 @@ def setup_client(base_url, client_id, client_secret, user_email, **kwargs): ) +_default_caldb_root = ... + + +def _get_default_caldb_root(): + global _default_caldb_root + if _default_caldb_root is ...: + onc_path = Path("/common/cal/caldb_store") + maxwell_path = Path("/gpfs/exfel/d/cal/caldb_store") + if onc_path.is_dir(): + _default_caldb_root = onc_path + elif maxwell_path.is_dir(): + _default_caldb_root = maxwell_path + else: + _default_caldb_root = None + + return _default_caldb_root + + @dataclass class SingleConstantVersion: """A Calibration Constant Version for 1 detector module""" @@ -76,6 +98,18 @@ class SingleConstantVersion: physical_name=ccv["physical_detector_unit"]["physical_name"], ) + def dataset_obj(self, caldb_root=None): + if caldb_root is not None: + caldb_root = Path(caldb_root) + else: + caldb_root = _get_default_caldb_root() + + f = h5py.File(caldb_root / self.path, "r") + return f[self.dataset]["data"] + + def ndarray(self, caldb_root=None): + return self.dataset_obj(caldb_root)[:] + @dataclass class ModulesConstantVersions: @@ -133,7 +167,7 @@ class CalibrationData(Mapping): api = CalCatApi(client or get_client()) - detector_id = api.detector(detector_name)['id'] + detector_id = api.detector(detector_name)["id"] all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at) if modules is None: modules = sorted(all_modules) @@ -249,10 +283,13 @@ class CalibrationData(Mapping): return type(self)(d, self.aggregators) def select_modules(self, aggregators): - return type(self)({ + return type(self)( + { cal_type: mcv.select_modules(aggregators).constants for (cal_type, mcv) in self.constant_groups.items() - }, sorted(aggregators)) + }, + sorted(aggregators), + ) def merge(self, *others: "CalibrationData") -> "CalibrationData": d = {} @@ -268,6 +305,27 @@ class CalibrationData(Mapping): return type(self)(d, sorted(aggregators)) + def load_all(self, caldb_root=None): + res = {} + + const_load_mp = psh.ProcessContext(num_workers=24) + keys = [] + for cal_type, mcv in self.constant_groups.items(): + res[cal_type] = {} + for module in mcv.aggregators: + dset = mcv.constants[module].dataset_obj(caldb_root) + res[cal_type][module] = const_load_mp.alloc( + shape=dset.shape, dtype=dset.dtype + ) + keys.append((cal_type, module)) + + def _load_constant_dataset(wid, index, key): + cal_type, mod = key + dset = self[cal_type].constants[mod].dataset_obj(caldb_root) + dset.read_direct(res[cal_type][mod]) + + const_load_mp.map(_load_constant_dataset, keys) + class ConditionsBase: calibration_types = {} # For subclasses: {calibration: [parameter names]} -- GitLab