
Revised CalCat API

Merged Thomas Kluyver requested to merge calcat-api-2 into master
2 files changed: +167 -68
@@ -7,9 +7,9 @@ from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Union
from urllib.parse import urljoin
from warnings import warn
import h5py
import numpy as np
import pasha as psh
import requests
from oauth2_xfel_client import Oauth2ClientBackend
@@ -225,7 +225,8 @@ def prepare_selection(
n_specified = sum([module_nums is not None, aggs is not None, qm_names is not None])
if n_specified > 1:
raise TypeError(
"select_modules() accepts only one of module_nums, aggregator_names & qm_names"
"select_modules() accepts only one of module_nums, aggregator_names "
"& qm_names"
)
if module_nums is not None:
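For context, prepare_selection() backs select_modules(), which takes exactly one of the three selector keywords named in the reworded error message. A minimal usage sketch, not part of the diff (cal_data is a hypothetical CalibrationData instance; the module numbers and names are made up):

    subset = cal_data.select_modules(module_nums=[0, 1, 2])         # fine
    subset = cal_data.select_modules(aggregator_names=["AGIPD00"])  # fine
    # Passing more than one selector raises TypeError:
    # cal_data.select_modules(module_nums=[0], qm_names=["Q1M1"])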
@@ -282,28 +283,39 @@ class ModulesConstantVersions:
if m["karabo_da"] in self.constants
]
def ndarray(self, caldb_root=None):
def ndarray(self, caldb_root=None, *, parallel=0):
eg_dset = self.constants[self.aggregator_names[0]].dataset_obj(caldb_root)
shape = (len(self.constants),) + eg_dset.shape
arr = np.zeros(shape, eg_dset.dtype)
for i, agg in enumerate(self.aggregator_names):
dset = self.constants[agg].dataset_obj(caldb_root)
dset.read_direct(arr[i])
if parallel > 0:
load_ctx = psh.ProcessContext(num_workers=parallel)
else:
load_ctx = psh.SerialContext()
# Allocate through the load context so ProcessContext workers write into shared memory
arr = load_ctx.alloc(shape, eg_dset.dtype, fill=0)
def _load_constant_dataset(wid, index, mod):
dset = self.constants[mod].dataset_obj(caldb_root)
dset.read_direct(arr[index])
load_ctx.map(_load_constant_dataset, self.aggregator_names)
return arr
def xarray(self, module_naming="da", caldb_root=None):
def xarray(self, module_naming="modnum", caldb_root=None, *, parallel=0):
import xarray
if module_naming == "da":
if module_naming == "aggregator":
modules = self.aggregator_names
elif module_naming == "modno":
elif module_naming == "modnum":
modules = self.module_nums
elif module_naming == "qm":
modules = self.qm_names
else:
raise ValueError(f"{module_naming=} (must be 'da', 'modno' or 'qm'")
raise ValueError(
f"{module_naming=} (must be 'aggregator', 'modnum' or 'qm')"
)
ndarr = self.ndarray(caldb_root)
ndarr = self.ndarray(caldb_root, parallel=parallel)
# Dimension labels
dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
@@ -314,7 +326,7 @@ class ModulesConstantVersions:
@lru_cache()
def detector(identifier, client=None):
def detector_by_name(identifier, client=None):
client = client or get_client()
res = client.get("detectors", {"identifier": identifier})
if not res:
@@ -324,6 +336,14 @@ def detector(identifier, client=None):
return res[0]
@lru_cache()
def detector_id_to_name(det_id: int, client=None):
"""Convert a numeric detector ID to a name like 'FXE_DET_LPD1M-1'"""
client = client or get_client()
res = client.get(f"detectors/{det_id}")
return res["identifier"] # "name" & "karabo_name" appear to be equivalent
@lru_cache()
def calibration_id(name, client=None):
"""ID for a calibration in CalCat."""
@@ -347,9 +367,10 @@ def calibration_name(cal_id, client=None):
class CalibrationData(Mapping):
"""Collected constants for a given detector"""
def __init__(self, constant_groups, module_details):
def __init__(self, constant_groups, module_details, detector_name):
self.constant_groups = constant_groups
self.module_details = module_details
self.detector_name = detector_name
@staticmethod
def _format_cond(condition):
@@ -390,7 +411,7 @@ class CalibrationData(Mapping):
client = client or get_client()
detector_id = detector(detector_name)["id"]
detector_id = detector_by_name(detector_name, client)["id"]
pdus = client.get(
"physical_detector_units/get_all_by_detector",
{
@@ -436,7 +457,7 @@ class CalibrationData(Mapping):
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mcvs, module_details)
return cls(mcvs, module_details, detector_name)
@classmethod
def from_report(
@@ -456,21 +477,40 @@ class CalibrationData(Mapping):
res = client.get("calibration_constant_versions", params)
constant_groups = {}
pdus = []
pdus = {} # keyed by karabo_da (e.g. 'AGIPD00')
det_ids = set() # Should only have one detector
for ccv in res:
pdus.append(ccv["physical_detector_unit"])
pdu = ccv["physical_detector_unit"]
# We're only interested in the PDU mapping from the CCV start time
kda = pdu["karabo_da"] = pdu.pop("karabo_da_at_ccv_begin_at")
det_id = pdu["detector_id"] = pdu.pop("detector_id_at_ccv_begin_at")
pdu["virtual_device_name"] = pdu.pop("virtual_device_name_at_ccv_begin_at")
det_ids.add(det_id)
if kda in pdus:
if pdu["physical_name"] != pdus[kda]["physical_name"]:
raise Exception(
f"Mismatched PDU mapping from calibration report: {kda} is both"
f" {pdu['physical_name']} and {pdus[kda]['physical_name']}"
)
else:
pdus[kda] = pdu
cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"])
aggr = ccv["physical_detector_unit"]["karabo_da"]
const_group = constant_groups.setdefault(cal_type, {})
const_group[aggr] = SingleConstantVersion.from_response(ccv)
const_group[kda] = SingleConstantVersion.from_response(ccv)
module_details = sorted(pdus, key=lambda d: d["karabo_da"])
if len(det_ids) > 1:
raise Exception(f"Found multiple detector IDs in report: {det_ids}")
det_name = detector_id_to_name(det_ids.pop(), client)
module_details = sorted(pdus.values(), key=lambda d: d["karabo_da"])
mcvs = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mcvs, module_details)
return cls(mcvs, module_details, det_name)
def __getitem__(self, key) -> ModulesConstantVersions:
return self.constant_groups[key]
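from_report() now deduplicates PDUs by karabo_da (taking the mapping as of the CCV begin time), checks that all constants belong to a single detector, and stores that detector's name on the result. A hedged sketch of the intended usage, not part of the diff and assuming the report is identified by a numeric ID:

    cd = CalibrationData.from_report(12345)   # report ID is made up
    cd.detector_name                          # e.g. 'FXE_DET_LPD1M-1'
    sorted(cd.constant_groups)                # calibration types found in the report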
@@ -513,20 +553,6 @@ class CalibrationData(Mapping):
mods.intersection_update(self[cal_type].constants)
return self.select_modules(aggregator_names=mods)
def select_calibrations(self, calibrations, require_all=True):
if require_all:
missing = set(calibrations) - set(self.constant_groups)
if missing:
raise KeyError(f"Missing calibrations: {', '.join(sorted(missing))}")
d = {
cal_type: mcv.constants
for (cal_type, mcv) in self.constant_groups.items()
if cal_type in calibrations
}
# TODO: missing for some modules?
return type(self)(d, self.aggregators)
def select_modules(
self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "CalibrationData":
@@ -545,41 +571,49 @@ class CalibrationData(Mapping):
module_details = [m for m in self.module_details if m["karabo_da"] in aggs]
return type(self)(mcvs, module_details, self.detector_name)
def select_calibrations(self, calibrations) -> "CalibrationData":
mcvs = {c: self.constant_groups[c] for c in calibrations}
return type(self)(mcvs, self.module_details, self.detector_name)
def merge(self, *others: "CalibrationData") -> "CalibrationData":
d = {}
for cal_type, mcv in self.constant_groups.items():
d[cal_type] = mcv.constants.copy()
for other in others:
if cal_type in other:
d[cal_type].update(other[cal_type].constants)
det_names = set(cd.detector_name for cd in (self,) + others)
if len(det_names) > 1:
raise Exception("Cannot merge calibration data for different "
f"detectors: " + ", ".join(sorted(det_names)))
det_name = det_names.pop()
cal_types = set(self.constant_groups)
aggregators = set(self.aggregator_names)
pdus_d = {m["karabo_da"]: m for m in self.module_details}
for other in others:
cal_types.update(other.constant_groups)
aggregators.update(other.aggregator_names)
return type(self)(d, sorted(aggregators))
def load_all(self, caldb_root=None):
res = {}
const_load_mp = psh.ProcessContext(num_workers=24)
keys = []
for cal_type, mcv in self.constant_groups.items():
res[cal_type] = {}
for module in mcv.aggregators:
dset = mcv.constants[module].dataset_obj(caldb_root)
res[cal_type][module] = const_load_mp.alloc(
shape=dset.shape, dtype=dset.dtype
)
keys.append((cal_type, module))
def _load_constant_dataset(wid, index, key):
cal_type, mod = key
dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
dset.read_direct(res[cal_type][mod])
const_load_mp.map(_load_constant_dataset, keys)
return res
for md in other.module_details:
# Warn if constants don't refer to same modules
md_da = md["karabo_da"]
if md_da in pdus_d:
pdu_a = pdus_d[md_da]["physical_name"]
pdu_b = md["physical_name"]
if pdu_a != pdu_b:
warn(
f"Merging constants with different modules for "
f"{md_da}: {pdu_a!r} != {pdu_b!r}",
stacklevel=2,
)
else:
pdus_d[md_da] = md
module_details = sorted(pdus_d.values(), key=lambda d: d["karabo_da"])
mcvs = {}
for cal_type in cal_types:
d = {}
for caldata in (self,) + others:
if cal_type in caldata:
d.update(caldata[cal_type].constants)
mcvs[cal_type] = ModulesConstantVersions(d, module_details)
return type(self)(mcvs, module_details, det_name)
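merge() now takes the union of calibration types and modules across several CalibrationData objects for the same detector: mixing detectors raises, while a karabo_da that maps to different physical modules only triggers a warning. Illustrative only, not part of the diff (cd_dark and cd_gain are hypothetical instances for one detector):

    cd_all = cd_dark.merge(cd_gain)
    sorted(cd_all.constant_groups)    # union of both sets of calibration types
    # cd_dark.merge(cd_other_detector) would raise Exception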
class ConditionsBase: