Skip to content
Snippets Groups Projects

Revised CalCat API

Merged Thomas Kluyver requested to merge calcat-api-2 into master
Compare and Show latest version
2 files
+ 159
− 12
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -4,13 +4,15 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import h5py
import numpy as np
import pasha as psh
from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion
from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm
global_client = None
class ModuleNameError(KeyError):
def __init__(self, name):
@@ -20,6 +22,9 @@ class ModuleNameError(KeyError):
return f"No module named {self.name!r}"
global_client = None
def get_client():
global global_client
if global_client is None:
@@ -43,6 +48,24 @@ def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
)
_default_caldb_root = ...
def _get_default_caldb_root():
global _default_caldb_root
if _default_caldb_root is ...:
onc_path = Path("/common/cal/caldb_store")
maxwell_path = Path("/gpfs/exfel/d/cal/caldb_store")
if onc_path.is_dir():
_default_caldb_root = onc_path
elif maxwell_path.is_dir():
_default_caldb_root = maxwell_path
else:
_default_caldb_root = None
return _default_caldb_root
@dataclass
class SingleConstantVersion:
"""A Calibration Constant Version for 1 detector module"""
@@ -76,6 +99,18 @@ class SingleConstantVersion:
physical_name=ccv["physical_detector_unit"]["physical_name"],
)
def dataset_obj(self, caldb_root=None) -> h5py.Dataset:
    """Open this constant's HDF5 file and return its 'data' dataset.

    caldb_root: directory holding the constant store; if None, the
    host-specific default from _get_default_caldb_root() is used.
    The file is opened read-only and not closed here; the returned
    dataset is only usable while the file object stays alive.
    """
    root = Path(caldb_root) if caldb_root is not None else _get_default_caldb_root()
    h5file = h5py.File(root / self.path, "r")
    return h5file[self.dataset]["data"]
def ndarray(self, caldb_root=None):
    """Read the whole constant into memory as a numpy array."""
    dset = self.dataset_obj(caldb_root)
    return dset[:]
@dataclass
class ModulesConstantVersions:
@@ -83,7 +118,7 @@ class ModulesConstantVersions:
constants: Dict[str, SingleConstantVersion] # Keys e.g. 'LPD00'
def select_modules(self, aggregators) -> "ModulesConstantVersions":
    """Return a new ModulesConstantVersions with only the named modules.

    aggregators: iterable of data-aggregator names (e.g. 'LPD00').
    Modules not present in self.constants are silently ignored.
    """
    d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators}
    return ModulesConstantVersions(d)
@@ -99,6 +134,36 @@ class ModulesConstantVersions:
def qm_names(self):
    """Quadrant/module names (e.g. 'Q1M1') for the selected modules."""
    return list(map(module_index_to_qm, self.module_nums))
def ndarray(self, caldb_root=None):
    """Stack all modules' constants into a single numpy array.

    Output shape is (n_modules,) + per-module shape, with the first
    axis ordered to match self.aggregators.
    """
    # Use the first module's dataset as a template for shape & dtype.
    template = self.constants[self.aggregators[0]].dataset_obj(caldb_root)
    stacked = np.zeros((len(self.constants),) + template.shape, template.dtype)
    for idx, aggr in enumerate(self.aggregators):
        # read_direct fills the pre-allocated slice without an extra copy.
        self.constants[aggr].dataset_obj(caldb_root).read_direct(stacked[idx])
    return stacked
def xarray(self, module_naming="da", caldb_root=None):
    """Load the constants for all modules as an xarray.DataArray.

    module_naming selects the labels on the 'module' dimension:
      'da'    - data aggregator names (e.g. 'LPD00')
      'modno' - integer module numbers
      'qm'    - quadrant/module names (e.g. 'Q1M1')

    Raises ValueError for any other module_naming value.
    """
    # Validate arguments before touching the optional xarray dependency,
    # so a bad argument fails fast even if xarray is not installed.
    if module_naming == "da":
        modules = self.aggregators
    elif module_naming == "modno":
        modules = self.module_nums
    elif module_naming == "qm":
        modules = self.qm_names
    else:
        # Closing parenthesis was missing from this message.
        raise ValueError(f"{module_naming=} (must be 'da', 'modno' or 'qm')")

    import xarray

    ndarr = self.ndarray(caldb_root)

    # Dimension labels: 'module' plus generic dim_N for the constant's axes
    dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
    coords = {"module": modules}
    name = self.constants[self.aggregators[0]].constant_name

    return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
class CalibrationData(Mapping):
"""Collected constants for a given detector"""
@@ -133,7 +198,7 @@ class CalibrationData(Mapping):
api = CalCatApi(client or get_client())
detector_id = api.detector(detector_name)['id']
detector_id = api.detector(detector_name)["id"]
all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
if modules is None:
modules = sorted(all_modules)
@@ -173,7 +238,7 @@ class CalibrationData(Mapping):
res = cls(d, modules)
if modules:
res = res.select_modules(*modules)
res = res.select_modules(modules)
return res
@classmethod
@@ -219,6 +284,10 @@ class CalibrationData(Mapping):
def __len__(self):
    """Number of constant types held (Mapping protocol)."""
    groups = self.constant_groups
    return len(groups)
def __repr__(self):
    """Short summary: sorted constant types and module count."""
    cal_types = ", ".join(sorted(self.constant_groups))
    n_modules = len(self.aggregators)
    return f"<CalibrationData: {cal_types} constants for {n_modules} modules>"
@property
def module_nums(self):
    """Integer module numbers parsed from the last two characters of
    each aggregator name (e.g. 'LPD07' -> 7)."""
    return [int(aggr[-2:]) for aggr in self.aggregators]
@@ -227,14 +296,14 @@ class CalibrationData(Mapping):
def qm_names(self):
    """Quadrant/module names (e.g. 'Q1M1') for this detector's modules."""
    return list(map(module_index_to_qm, self.module_nums))
def require_calibrations(self, calibrations):
    """Drop any modules missing the specified constant types.

    calibrations: iterable of constant type names (e.g. 'Offset').
    Returns a new CalibrationData containing only modules which have
    every one of the requested constants.
    """
    mods = set(self.aggregators)
    for cal_type in calibrations:
        # Keep only modules that also have this constant type.
        mods.intersection_update(self[cal_type].constants)
    return self.select_modules(mods)
def select_calibrations(self, *calibrations, require_all=True):
def select_calibrations(self, calibrations, require_all=True):
if require_all:
missing = set(calibrations) - set(self.constant_groups)
if missing:
@@ -246,17 +315,18 @@ class CalibrationData(Mapping):
if cal_type in calibrations
}
# TODO: missing for some modules?
return type(self)(d)
return type(self)(d, self.aggregators)
def select_modules(self, aggregators):
    """Return a new CalibrationData restricted to the given modules.

    aggregators: iterable of data-aggregator names (e.g. 'LPD00');
    the result's aggregator list is sorted.
    """
    return type(self)(
        {
            cal_type: mcv.select_modules(aggregators).constants
            for (cal_type, mcv) in self.constant_groups.items()
        },
        sorted(aggregators),
    )
def merge(self, *others: Sequence["CalibrationData"]) -> "CalibrationData":
def merge(self, *others: "CalibrationData") -> "CalibrationData":
d = {}
for cal_type, mcv in self.constant_groups.items():
d[cal_type] = mcv.constants.copy()
@@ -264,7 +334,33 @@ class CalibrationData(Mapping):
if cal_type in other:
d[cal_type].update(other[cal_type].constants)
return type(self)(d)
aggregators = set(self.aggregators)
for other in others:
aggregators.update(other.aggregators)
return type(self)(d, sorted(aggregators))
def load_all(self, caldb_root=None):
    """Load every constant for every module into numpy arrays.

    Returns a nested dict: {constant type: {aggregator name: ndarray}}.
    Datasets are read in parallel worker processes via pasha; the output
    arrays are allocated through the process context so workers can fill
    them in place and the parent sees the results.
    """
    res = {}
    # NOTE(review): 24 workers is hard-coded — presumably tuned for the
    # cluster nodes this runs on; confirm before changing.
    const_load_mp = psh.ProcessContext(num_workers=24)
    keys = []
    for cal_type, mcv in self.constant_groups.items():
        res[cal_type] = {}
        for module in mcv.aggregators:
            # Open each dataset once up front just to get shape & dtype
            # for allocating the output array.
            dset = mcv.constants[module].dataset_obj(caldb_root)
            res[cal_type][module] = const_load_mp.alloc(
                shape=dset.shape, dtype=dset.dtype
            )
            keys.append((cal_type, module))

    def _load_constant_dataset(wid, index, key):
        # Runs in a worker: read one constant directly into its
        # pre-allocated array.
        cal_type, mod = key
        dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
        dset.read_direct(res[cal_type][mod])

    # One task per (constant type, module) pair.
    const_load_mp.map(_load_constant_dataset, keys)
    return res
class ConditionsBase:
@@ -361,3 +457,30 @@ class LPDConditions(ConditionsBase):
"FFMap": _illuminated_parameters,
"BadPixelsFF": _illuminated_parameters,
}
@dataclass
class DSSCConditions(ConditionsBase):
    """Operating conditions identifying DSSC calibration constants.

    The dataclass fields correspond to the CalCat parameter names listed
    in _params; Optional fields left as None are presumably omitted from
    the lookup (handled by ConditionsBase — not visible here, confirm).
    """
    sensor_bias_voltage: float
    memory_cells: int
    pulse_id_checksum: Optional[float] = None
    acquisition_rate: Optional[float] = None
    target_gain: Optional[int] = None
    encoded_gain: Optional[int] = None
    pixels_x: int = 512
    pixels_y: int = 128

    # Deliberately unannotated: a plain class attribute, NOT a dataclass
    # field. CalCat parameter names matching the fields above.
    _params = [
        "Sensor Bias Voltage",
        "Memory cells",
        "Pixels X",
        "Pixels Y",
        "Pulse id checksum",
        "Acquisition rate",
        "Target gain",
        "Encoded gain",
    ]

    # Constant types available for DSSC, each conditioned on the full
    # parameter set.
    calibration_types = {
        "Offset": _params,
        "Noise": _params,
    }
Loading