Skip to content
Snippets Groups Projects

Revised CalCat API

Merged Thomas Kluyver requested to merge calcat-api-2 into master
Compare and Show latest version
2 files
+ 318
141
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -2,14 +2,14 @@ import json
import re
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, date, time, timezone
from datetime import date, datetime, time, timezone
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Union
from urllib.parse import urljoin
from warnings import warn
import h5py
import numpy as np
import pasha as psh
import requests
from oauth2_xfel_client import Oauth2ClientBackend
@@ -173,8 +173,11 @@ def _get_default_caldb_root():
@dataclass
class SingleConstantVersion:
"""A Calibration Constant Version for 1 detector module"""
class SingleConstant:
"""A calibration constant for one detector module
CalCat calls this a calibration constant version (CCV).
"""
id: int
version_name: str
@@ -189,7 +192,7 @@ class SingleConstantVersion:
physical_name: str # PDU name
@classmethod
def from_response(cls, ccv: dict) -> "SingleConstantVersion":
def from_response(cls, ccv: dict) -> "SingleConstant":
const = ccv["calibration_constant"]
return cls(
id=ccv["id"],
@@ -218,44 +221,53 @@ class SingleConstantVersion:
return self.dataset_obj(caldb_root)[:]
def _aggregators_for(module_details, field, wanted, label):
    """Translate *wanted* values of ``field`` into ``karabo_da`` names, in order."""
    by_field = {m[field]: m for m in module_details}
    # Materialise first: the values are traversed twice below, and the caller
    # may have passed a one-shot iterator.
    wanted = list(wanted)
    missing = [w for w in wanted if w not in by_field]
    if missing:
        raise KeyError(f"{label} not found: " + ", ".join(str(w) for w in missing))
    return [by_field[w]["karabo_da"] for w in wanted]


def prepare_selection(
    module_details, module_nums=None, aggregator_names=None, qm_names=None
):
    """Resolve a module selection to a list of aggregator (``karabo_da``) names.

    Exactly one of ``module_nums``, ``aggregator_names`` and ``qm_names`` must
    be given. The error messages refer to ``select_modules()`` because this is
    the shared argument handling for the ``select_modules()`` methods.

    Parameters
    ----------
    module_details : iterable of dict
        One dict per module. Each needs a ``"karabo_da"`` key, plus
        ``"module_number"`` or ``"virtual_device_name"`` when selecting by
        module number or QM name respectively.
    module_nums : iterable, optional
        Values matched against ``"module_number"``.
    aggregator_names : iterable of str, optional
        Values matched against ``"karabo_da"``.
    qm_names : iterable, optional
        Values matched against ``"virtual_device_name"``.

    Returns
    -------
    list
        Aggregator names, in the order the selection was given.

    Raises
    ------
    TypeError
        If none, or more than one, of the selectors is given.
    KeyError
        If any requested module is not in ``module_details``; the message
        lists the values that were not found.
    """
    n_specified = sum(
        arg is not None for arg in (module_nums, aggregator_names, qm_names)
    )
    if n_specified > 1:
        raise TypeError(
            "select_modules() accepts only one of module_nums, aggregator_names "
            "& qm_names"
        )
    if n_specified < 1:
        raise TypeError("select_modules() requires an argument")

    if module_nums is not None:
        return _aggregators_for(
            module_details, "module_number", module_nums, "Module numbers"
        )
    if qm_names is not None:
        return _aggregators_for(
            module_details, "virtual_device_name", qm_names, "QM names"
        )

    # Copy to a list so a generator isn't returned already exhausted by the
    # membership check below.
    aggs = list(aggregator_names)
    miss = set(aggs) - {m["karabo_da"] for m in module_details}
    if miss:
        raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
    return aggs
@dataclass
class ModulesConstantVersions:
"""A group of similar CCVs for several modules of one detector"""
class MultiModuleConstant:
"""A group of similar constants for several modules of one detector"""
constants: Dict[str, SingleConstantVersion] # Keys e.g. 'LPD00'
constants: Dict[str, SingleConstant] # Keys e.g. 'LPD00'
module_details: List[Dict]
def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None
) -> "ModulesConstantVersions":
n_specified = sum(
[module_nums is not None, aggregators is not None, qm_names is not None]
self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "MultiModuleConstant":
aggs = prepare_selection(
self.module_details, module_nums, aggregator_names, qm_names
)
if n_specified < 1:
raise TypeError("select_modules() requires an argument")
elif n_specified > 1:
raise TypeError(
"select_modules() accepts only one of module_nums, aggregators & qm_names"
)
if module_nums is not None:
by_mod_no = {m["module_number"]: m for m in self.module_details}
aggregators = [by_mod_no[n]["karabo_da"] for n in module_nums]
elif qm_names is not None:
by_qm = {m["virtual_device_name"]: m for m in self.module_details}
aggregators = [by_qm[s]["karabo_da"] for s in qm_names]
elif aggregators is not None:
miss = set(aggregators) - {m["karabo_da"] for m in self.module_details}
if miss:
raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators}
return ModulesConstantVersions(d, self.module_details)
d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
mods = [m for m in self.module_details if m["karabo_da"] in d]
return MultiModuleConstant(d, mods)
# These properties label only the modules we have constants for, which may
# be a subset of what's in module_details
@property
def aggregators(self):
def aggregator_names(self):
return sorted(self.constants)
@property
@@ -274,39 +286,50 @@ class ModulesConstantVersions:
if m["karabo_da"] in self.constants
]
def ndarray(self, caldb_root=None):
eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root)
def ndarray(self, caldb_root=None, *, parallel=0):
eg_dset = self.constants[self.aggregator_names[0]].dataset_obj(caldb_root)
shape = (len(self.constants),) + eg_dset.shape
arr = np.zeros(shape, eg_dset.dtype)
for i, agg in enumerate(self.aggregators):
dset = self.constants[agg].dataset_obj(caldb_root)
dset.read_direct(arr[i])
if parallel > 0:
load_ctx = psh.ProcessContext(num_workers=parallel)
else:
load_ctx = psh.SerialContext()
arr = psh.alloc(shape, eg_dset.dtype, fill=0)
def _load_constant_dataset(wid, index, mod):
dset = self.constants[mod].dataset_obj(caldb_root)
dset.read_direct(arr[index])
load_ctx.map(_load_constant_dataset, self.aggregator_names)
return arr
def xarray(self, module_naming="da", caldb_root=None):
def xarray(self, module_naming="modnum", caldb_root=None, *, parallel=0):
import xarray
if module_naming == "da":
modules = self.aggregators
elif module_naming == "modno":
if module_naming == "aggregator":
modules = self.aggregator_names
elif module_naming == "modnum":
modules = self.module_nums
elif module_naming == "qm":
modules = self.qm_names
else:
raise ValueError(f"{module_naming=} (must be 'da', 'modno' or 'qm'")
raise ValueError(
f"{module_naming=} (must be 'aggregator', 'modnum' or 'qm'"
)
ndarr = self.ndarray(caldb_root)
ndarr = self.ndarray(caldb_root, parallel=parallel)
# Dimension labels
dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
coords = {"module": modules}
name = self.constants[self.aggregators[0]].constant_name
name = self.constants[self.aggregator_names[0]].constant_name
return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
@lru_cache()
def detector(identifier, client=None):
def detector_by_name(identifier, client=None):
client = client or get_client()
res = client.get("detectors", {"identifier": identifier})
if not res:
@@ -316,6 +339,14 @@ def detector(identifier, client=None):
return res[0]
@lru_cache()
def detector_id_to_name(det_id: int, client=None):
    """Look up the name (like 'FXE_DET_LPD1M-1') for a numeric detector ID.

    The record is fetched with ``client.get("detectors/<det_id>")`` and its
    ``"identifier"`` field is returned. If ``client`` is not given (or is
    falsy), the one from ``get_client()`` is used. Results are memoised by
    ``lru_cache`` on the ``(det_id, client)`` arguments.
    """
    if not client:
        client = get_client()
    detector_info = client.get(f"detectors/{det_id}")
    # "name" & "karabo_name" appear to be equivalent
    return detector_info["identifier"]
@lru_cache()
def calibration_id(name, client=None):
"""ID for a calibration in CalCat."""
@@ -339,12 +370,10 @@ def calibration_name(cal_id, client=None):
class CalibrationData(Mapping):
"""Collected constants for a given detector"""
def __init__(self, constant_groups, module_details):
self.constant_groups = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
def __init__(self, constant_groups, module_details, detector_name):
self.constant_groups = constant_groups
self.module_details = module_details
self.detector_name = detector_name
@staticmethod
def _format_cond(condition):
@@ -385,7 +414,7 @@ class CalibrationData(Mapping):
client = client or get_client()
detector_id = detector(detector_name)["id"]
detector_id = detector_by_name(detector_name, client)["id"]
pdus = client.get(
"physical_detector_units/get_all_by_detector",
{
@@ -398,7 +427,7 @@ class CalibrationData(Mapping):
if mod.get("module_number", -1) < 0:
mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
d = {}
constant_groups = {}
for params, cal_types in cal_types_by_params_used.items():
condition_dict = condition.make_dict(params)
@@ -424,11 +453,14 @@ class CalibrationData(Mapping):
aggr = ccv["physical_detector_unit"]["karabo_da"]
cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(
ccv
)
const_group = constant_groups.setdefault(cal_type, {})
const_group[aggr] = SingleConstant.from_response(ccv)
return cls(d, module_details)
mmcs = {
const_type: MultiModuleConstant(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mmcs, module_details, detector_name)
@classmethod
def from_report(
@@ -447,18 +479,43 @@ class CalibrationData(Mapping):
res = client.get("calibration_constant_versions", params)
d = {}
aggregators = set()
constant_groups = {}
pdus = {} # keyed by karabo_da (e.g. 'AGIPD00')
det_ids = set() # Should only have one detector
for ccv in res:
aggr = ccv["physical_detector_unit"]["karabo_da"]
aggregators.add(aggr)
pdu = ccv["physical_detector_unit"]
# We're only interested in the PDU mapping from the CCV start time
kda = pdu["karabo_da"] = pdu.pop("karabo_da_at_ccv_begin_at")
det_id = pdu["detector_id"] = pdu.pop("detector_id_at_ccv_begin_at")
pdu["virtual_device_name"] = pdu.pop("virtual_device_name_at_ccv_begin_at")
det_ids.add(det_id)
if kda in pdus:
if pdu["physical_name"] != pdus[kda]["physical_name"]:
raise Exception(
f"Mismatched PDU mapping from calibration report: {kda} is both"
f" {pdu['physical_name']} and {pdus[kda]['physical_name']}"
)
else:
pdus[kda] = pdu
cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"])
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)
const_group = constant_groups.setdefault(cal_type, {})
const_group[kda] = SingleConstant.from_response(ccv)
return cls(d, sorted(aggregators))
if len(det_ids) > 1:
raise Exception(f"Found multiple detector IDs in report: {det_ids}")
det_name = detector_id_to_name(det_ids.pop(), client)
def __getitem__(self, key) -> ModulesConstantVersions:
module_details = sorted(pdus.values(), key=lambda d: d["karabo_da"])
mmcs = {
const_type: MultiModuleConstant(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mmcs, module_details, det_name)
def __getitem__(self, key) -> MultiModuleConstant:
return self.constant_groups[key]
def __iter__(self):
@@ -470,7 +527,7 @@ class CalibrationData(Mapping):
def __repr__(self):
return (
f"<CalibrationData: {', '.join(sorted(self.constant_groups))} "
f"constants for {len(self.module_details)} modules>"
f"constants for {len(self.module_details)} modules of {self.detector_name}>"
)
# These properties may include modules for which we have no constants -
@@ -494,75 +551,72 @@ class CalibrationData(Mapping):
def require_calibrations(self, calibrations):
"""Drop any modules missing the specified constant types"""
mods = set(self.aggregators)
mods = set(self.aggregator_names)
for cal_type in calibrations:
mods.intersection_update(self[cal_type].constants)
return self.select_modules(mods)
def select_calibrations(self, calibrations, require_all=True):
if require_all:
missing = set(calibrations) - set(self.constant_groups)
if missing:
raise KeyError(f"Missing calibrations: {', '.join(sorted(missing))}")
d = {
cal_type: mcv.constants
for (cal_type, mcv) in self.constant_groups.items()
if cal_type in calibrations
}
# TODO: missing for some modules?
return type(self)(d, self.aggregators)
return self.select_modules(aggregator_names=mods)
def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None
self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "CalibrationData":
return type(self)(
{
cal_type: mcv.select_modules(
module_nums=module_nums,
aggregators=aggregators,
qm_names=qm_names,
).constants
for (cal_type, mcv) in self.constant_groups.items()
},
sorted(aggregators),
# Validate the specified modules against those we know about.
# Each specific constant type may have only a subset of these modules.
aggs = prepare_selection(
self.module_details, module_nums, aggregator_names, qm_names
)
mmcs = {
cal_type: mmc.select_modules(
aggregator_names=set(aggs).intersection(mmc.aggregator_names)
)
for (cal_type, mmc) in self.constant_groups.items()
}
aggs = set().union(*[c.aggregator_names for c in mmcs.values()])
module_details = [m for m in self.module_details if m["karabo_da"] in aggs]
return type(self)(mmcs, module_details, self.detector_name)
def select_calibrations(self, calibrations) -> "CalibrationData":
    """Return a new object holding only the named calibration types.

    ``calibrations`` is an iterable of keys of ``self.constant_groups``; the
    result keeps them in the order given. ``module_details`` and
    ``detector_name`` are passed through unchanged.

    Raises KeyError naming every requested calibration that is not
    available, rather than only the first one encountered.
    """
    # Materialise first: the names are traversed twice, and the caller may
    # have passed a one-shot iterator.
    wanted = list(calibrations)
    missing = sorted(set(wanted) - set(self.constant_groups))
    if missing:
        raise KeyError(f"Missing calibrations: {', '.join(missing)}")
    mmcs = {c: self.constant_groups[c] for c in wanted}
    return type(self)(mmcs, self.module_details, self.detector_name)
def merge(self, *others: "CalibrationData") -> "CalibrationData":
d = {}
for cal_type, mcv in self.constant_groups.items():
d[cal_type] = mcv.constants.copy()
for other in others:
if cal_type in other:
d[cal_type].update(other[cal_type].constants)
aggregators = set(self.aggregators)
det_names = set(cd.detector_name for cd in (self,) + others)
if len(det_names) > 1:
raise Exception("Cannot merge calibration data for different "
"detectors: " + ", ".join(sorted(det_names)))
det_name = det_names.pop()
cal_types = set(self.constant_groups)
aggregators = set(self.aggregator_names)
pdus_d = {m["karabo_da"]: m for m in self.module_details}
for other in others:
aggregators.update(other.aggregators)
return type(self)(d, sorted(aggregators))
def load_all(self, caldb_root=None):
res = {}
const_load_mp = psh.ProcessContext(num_workers=24)
keys = []
for cal_type, mcv in self.constant_groups.items():
res[cal_type] = {}
for module in mcv.aggregators:
dset = mcv.constants[module].dataset_obj(caldb_root)
res[cal_type][module] = const_load_mp.alloc(
shape=dset.shape, dtype=dset.dtype
)
keys.append((cal_type, module))
def _load_constant_dataset(wid, index, key):
cal_type, mod = key
dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
dset.read_direct(res[cal_type][mod])
const_load_mp.map(_load_constant_dataset, keys)
return res
cal_types.update(other.constant_groups)
aggregators.update(other.aggregator_names)
for md in other.module_details:
# Warn if constants don't refer to same modules
md_da = md["karabo_da"]
if md_da in pdus_d:
pdu_a = pdus_d[md_da]["physical_name"]
pdu_b = md["physical_name"]
if pdu_a != pdu_b:
warn(
f"Merging constants with different modules for "
f"{md_da}: {pdu_a!r} != {pdu_b!r}",
stacklevel=2,
)
else:
pdus_d[md_da] = md
module_details = sorted(pdus_d.values(), key=lambda d: d["karabo_da"])
mmcs = {}
for cal_type in cal_types:
d = {}
for caldata in (self,) + others:
if cal_type in caldata:
d.update(caldata[cal_type].constants)
mmcs[cal_type] = MultiModuleConstant(d, module_details)
return type(self)(mmcs, module_details, det_name)
class ConditionsBase:
Loading