From ac8dab49652d316dc16d7db0ee19800f8b51a42b Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Thu, 24 Aug 2023 14:03:35 +0100
Subject: [PATCH] Add methods for loading constants

---
 src/cal_tools/calcat_interface2.py | 66 ++++++++++++++++++++++++++++--
 1 file changed, 62 insertions(+), 4 deletions(-)

diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
index 4b28f8f6c..5b0b827e4 100644
--- a/src/cal_tools/calcat_interface2.py
+++ b/src/cal_tools/calcat_interface2.py
@@ -4,13 +4,14 @@ from datetime import datetime
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
+import h5py
+import pasha as psh
 from calibration_client import CalibrationClient
 from calibration_client.modules import CalibrationConstantVersion
 
 from .calcat_interface import CalCatApi, CalCatError
 from .tools import module_index_to_qm
 
-global_client = None
 
 class ModuleNameError(KeyError):
     def __init__(self, name):
@@ -20,6 +21,9 @@ class ModuleNameError(KeyError):
         return f"No module named {self.name!r}"
 
 
+global_client = None
+
+
 def get_client():
     global global_client
     if global_client is None:
@@ -43,6 +47,37 @@ def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
     )
 
 
+# Ellipsis means "not looked up yet". None is a valid cached result: it
+# records that neither known store directory exists.
+_default_caldb_root = ...
+
+
+def _get_default_caldb_root():
+    """Return the default calibration store directory, or None.
+
+    The candidate directories are probed on the first call only. The
+    outcome, including None when neither exists, is cached in the
+    module-level ``_default_caldb_root`` and reused afterwards.
+    """
+    global _default_caldb_root
+    if _default_caldb_root is not ...:
+        return _default_caldb_root
+
+    # Order matters: the first existing directory wins.
+    candidates = (
+        Path("/common/cal/caldb_store"),
+        Path("/gpfs/exfel/d/cal/caldb_store"),
+    )
+    found = None
+    for candidate in candidates:
+        if candidate.is_dir():
+            found = candidate
+            break
+
+    _default_caldb_root = found
+    return found
+
+
 @dataclass
 class SingleConstantVersion:
     """A Calibration Constant Version for 1 detector module"""
@@ -76,6 +98,38 @@ class SingleConstantVersion:
             physical_name=ccv["physical_detector_unit"]["physical_name"],
         )
 
+    def dataset_obj(self, caldb_root=None):
+        """Open this constant's file and return its "data" dataset.
+
+        caldb_root is the directory that ``self.path`` is relative to.
+        When it is None, the default store directory is looked up with
+        _get_default_caldb_root().
+
+        The file is opened read-only and is not closed here; the
+        returned h5py dataset belongs to that open file.
+
+        Raises FileNotFoundError if caldb_root is None and no default
+        store directory was found.
+        """
+        if caldb_root is not None:
+            caldb_root = Path(caldb_root)
+        else:
+            caldb_root = _get_default_caldb_root()
+            if caldb_root is None:
+                # None / self.path below would otherwise fail with an
+                # unrelated-looking TypeError.
+                raise FileNotFoundError(
+                    "No default calibration store directory found; "
+                    "pass caldb_root explicitly"
+                )
+
+        f = h5py.File(caldb_root / self.path, "r")
+        return f[self.dataset]["data"]
+
+    def ndarray(self, caldb_root=None):
+        """Return the full contents of the "data" dataset as an array."""
+        return self.dataset_obj(caldb_root)[:]
+
 
 @dataclass
 class ModulesConstantVersions:
@@ -133,7 +167,7 @@ class CalibrationData(Mapping):
 
         api = CalCatApi(client or get_client())
 
-        detector_id = api.detector(detector_name)['id']
+        detector_id = api.detector(detector_name)["id"]
         all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
         if modules is None:
             modules = sorted(all_modules)
@@ -249,10 +283,16 @@ class CalibrationData(Mapping):
         return type(self)(d, self.aggregators)
 
     def select_modules(self, aggregators):
-        return type(self)({
-                cal_type: mcv.select_modules(aggregators).constants
-                for (cal_type, mcv) in self.constant_groups.items()
-        }, sorted(aggregators))
+        """Return a new object of this type limited to aggregators.
+
+        Each constant group is narrowed with its own select_modules();
+        the new object is given the sorted aggregators.
+        """
+        narrowed = {}
+        for cal_type, mcv in self.constant_groups.items():
+            narrowed[cal_type] = mcv.select_modules(aggregators).constants
+
+        return type(self)(narrowed, sorted(aggregators))
 
     def merge(self, *others: "CalibrationData") -> "CalibrationData":
         d = {}
@@ -268,6 +305,40 @@ class CalibrationData(Mapping):
 
         return type(self)(d, sorted(aggregators))
 
+    def load_all(self, caldb_root=None):
+        """Load the data of every constant, reading in parallel.
+
+        caldb_root is passed on to the dataset_obj() method of each
+        constant.
+
+        Returns a nested dict ``{calibration type: {module: array}}``.
+        Each array is allocated through the pasha process context and
+        filled by a worker with the contents of one constant's dataset.
+        """
+        res = {}
+
+        const_load_mp = psh.ProcessContext(num_workers=24)
+        keys = []
+        # Allocate every output array up front, before the workers run.
+        for cal_type, mcv in self.constant_groups.items():
+            res[cal_type] = {}
+            for module in mcv.aggregators:
+                dset = mcv.constants[module].dataset_obj(caldb_root)
+                res[cal_type][module] = const_load_mp.alloc(
+                    shape=dset.shape, dtype=dset.dtype
+                )
+                keys.append((cal_type, module))
+
+        def _load_constant_dataset(wid, index, key):
+            # wid and index come from pasha's kernel signature; unused.
+            cal_type, mod = key
+            dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
+            dset.read_direct(res[cal_type][mod])
+
+        const_load_mp.map(_load_constant_dataset, keys)
+
+        return res
+
 
 class ConditionsBase:
     calibration_types = {}  # For subclasses: {calibration: [parameter names]}
-- 
GitLab