Skip to content
Snippets Groups Projects

Revised CalCat API

Merged Thomas Kluyver requested to merge calcat-api-2 into master
Compare and Show latest version
3 files
+ 323
157
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -2,11 +2,12 @@ import json
import re
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, date, time, timezone
from datetime import date, datetime, time, timezone
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Union
from urllib.parse import urljoin
from warnings import warn
import h5py
import numpy as np
@@ -61,6 +62,7 @@ class CalCatAPIClient:
return dt
def get_request(self, relative_url, params=None, headers=None, **kwargs):
"""Make a GET request, return the HTTP response object"""
# Base URL may include e.g. '/api/'. This is a prefix for all URLs;
# even if they look like an absolute path.
url = urljoin(self.base_api_url, relative_url.lstrip("/"))
@@ -85,26 +87,31 @@ class CalCatAPIClient:
if resp.content == b"":
return None
else:
res = json.loads(resp.content.decode("utf-8"))
pagination_header_fields = [
"X-Total-Pages",
"X-Count-Per-Page",
"X-Current-Page",
"X-Total-Count",
]
pagination_info = {
k[2:].lower().replace("-", "_"): int(resp.headers[k])
for k in pagination_header_fields
if k in resp.headers
}
if pagination_info:
res[".pages"] = pagination_info
return res
return json.loads(resp.content.decode("utf-8"))
def get(self, relative_url, params=None, **kwargs):
    """Issue a GET request and return the decoded JSON body.

    ``relative_url`` and ``params`` (plus any extra keyword arguments) are
    handed to ``get_request()``; the resulting response object is decoded
    by ``_parse_response()`` and that value is returned.
    """
    return self._parse_response(
        self.get_request(relative_url, params, **kwargs)
    )
# Response headers that carry pagination details. get_paged() strips the
# leading "X-" and converts the rest to lower_snake_case for its result keys.
_pagination_headers = (
    "X-Total-Pages",
    "X-Count-Per-Page",
    "X-Current-Page",
    "X-Total-Count",
)

def get_paged(self, relative_url, params=None, **kwargs):
    """Issue a GET request, return ``(content, pagination_info)``.

    ``content`` is the body decoded by ``_parse_response()``.
    ``pagination_info`` is a dict with an integer entry for each header in
    ``_pagination_headers`` that is present on the response, e.g.
    ``"X-Total-Pages"`` becomes ``"total_pages"``. Absent headers are
    simply left out.
    """
    response = self.get_request(relative_url, params, **kwargs)
    body = self._parse_response(response)

    pages = {}
    for header in self._pagination_headers:
        if header in response.headers:
            # "X-Total-Pages" -> "total_pages"
            field = header[2:].lower().replace("-", "_")
            pages[field] = int(response.headers[header])
    return body, pages
global_client = None
@@ -212,6 +219,31 @@ class SingleConstantVersion:
return self.dataset_obj(caldb_root)[:]
def _aggregators_for(module_details, key, wanted, label):
    """Map each value in ``wanted`` (matched on field ``key``) to its aggregator.

    Raises KeyError naming every requested value that no module has.
    """
    by_key = {m[key]: m for m in module_details}
    wanted = list(wanted)  # Traversed twice below; allow one-shot iterables
    # dict.fromkeys de-duplicates while keeping the order of the request
    missing = [w for w in dict.fromkeys(wanted) if w not in by_key]
    if missing:
        raise KeyError(f"{label} not found: " + ", ".join(map(str, missing)))
    return [by_key[w]["karabo_da"] for w in wanted]


def prepare_selection(
    module_details, module_nums=None, aggregator_names=None, qm_names=None
):
    """Validate a module selection and express it as aggregator names.

    Exactly one of ``module_nums``, ``aggregator_names`` and ``qm_names``
    must be given (i.e. not None).

    Parameters:
        module_details: iterable of dicts; the fields read here are
            ``"module_number"``, ``"virtual_device_name"`` and ``"karabo_da"``.
        module_nums: values matched against ``"module_number"``.
        aggregator_names: values matched against ``"karabo_da"``.
        qm_names: values matched against ``"virtual_device_name"``.

    Returns:
        For ``module_nums`` / ``qm_names``: a list of ``"karabo_da"`` values
        in the order requested. For ``aggregator_names``: the object passed
        in, unchanged, once it has been validated.

    Raises:
        TypeError: if none, or more than one, of the selectors is given.
        KeyError: if any requested module is not in ``module_details``; the
            message lists all of the unknown values.
    """
    aggs = aggregator_names  # Shorter name -> fewer multi-line statements
    n_specified = sum(sel is not None for sel in (module_nums, aggs, qm_names))
    if n_specified > 1:
        raise TypeError(
            "select_modules() accepts only one of module_nums, aggregator_names & qm_names"
        )
    if n_specified == 0:
        raise TypeError("select_modules() requires an argument")

    if module_nums is not None:
        return _aggregators_for(
            module_details, "module_number", module_nums, "Module numbers"
        )
    if qm_names is not None:
        return _aggregators_for(
            module_details, "virtual_device_name", qm_names, "QM names"
        )

    # Only aggregator_names is left
    miss = set(aggs) - {m["karabo_da"] for m in module_details}
    if miss:
        raise KeyError("Aggregators not found: " + ", ".join(sorted(miss)))
    return aggs
@dataclass
class ModulesConstantVersions:
"""A group of similar CCVs for several modules of one detector"""
@@ -220,75 +252,75 @@ class ModulesConstantVersions:
module_details: List[Dict]
def select_modules(
    self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "ModulesConstantVersions":
    """Return a new group holding only the constants of the chosen modules.

    Exactly one of ``module_nums``, ``aggregator_names`` and ``qm_names``
    must be given; ``prepare_selection()`` validates the request against
    ``self.module_details`` and raises TypeError / KeyError for a bad one.

    The result keeps the entries of ``self.constants`` whose aggregator was
    selected, and the ``module_details`` entries for those aggregators.
    """
    aggs = prepare_selection(
        self.module_details, module_nums, aggregator_names, qm_names
    )
    wanted = set(aggs)  # Set membership instead of scanning a list per key
    d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in wanted}
    mods = [m for m in self.module_details if m["karabo_da"] in d]
    return ModulesConstantVersions(d, mods)
# These properties label only the modules we have constants for, which may
# be a subset of what's in module_details
@property
def aggregator_names(self):
    """Sorted list of the keys of ``self.constants`` (aggregator names)."""
    return sorted(self.constants)
@property
def module_nums(self):
    """Module numbers of the modules that have an entry in ``self.constants``.

    Entries follow the order of ``self.module_details``.
    """
    return [
        m["module_number"]
        for m in self.module_details
        if m["karabo_da"] in self.constants
    ]
@property
def qm_names(self):
    """Virtual device names of the modules that have an entry in ``self.constants``.

    Entries follow the order of ``self.module_details``.
    """
    return [
        m["virtual_device_name"]
        for m in self.module_details
        if m["karabo_da"] in self.constants
    ]
def ndarray(self, caldb_root=None, *, parallel=0):
    """Load the constants for all modules into one array.

    The result has shape ``(number of constants,) + <dataset shape>``, with
    rows in the order of ``self.aggregator_names``. The shape and dtype are
    taken from the first module's dataset.
    NOTE(review): this assumes every module's dataset has that same shape
    and dtype - confirm.

    Parameters:
        caldb_root: passed through to each constant's ``dataset_obj()``.
        parallel: if greater than 0, read with a ``psh.ProcessContext`` of
            that many workers; otherwise read with ``psh.SerialContext``.
    """
    aggs = self.aggregator_names  # The property sorts; evaluate it once
    eg_dset = self.constants[aggs[0]].dataset_obj(caldb_root)
    shape = (len(self.constants),) + eg_dset.shape

    if parallel > 0:
        load_ctx = psh.ProcessContext(num_workers=parallel)
    else:
        load_ctx = psh.SerialContext()

    arr = psh.alloc(shape, eg_dset.dtype, fill=0)

    def _load_constant_dataset(wid, index, mod):
        # index is the position of mod in aggs, i.e. the row to fill
        dset = self.constants[mod].dataset_obj(caldb_root)
        dset.read_direct(arr[index])

    load_ctx.map(_load_constant_dataset, aggs)
    return arr
def xarray(self, module_naming="modnum", caldb_root=None, *, parallel=0):
    """Load the constants into a labelled ``xarray.DataArray``.

    The first dimension is ``"module"``; the remaining ones are named
    ``dim_0``, ``dim_1``, ... The array is named after the
    ``constant_name`` of the first constant.

    Parameters:
        module_naming: how to label the module axis - ``"aggregator"``
            (``karabo_da``), ``"modnum"`` (``module_number``) or ``"qm"``
            (``virtual_device_name``).
        caldb_root, parallel: passed through to ``ndarray()``.

    Raises:
        ValueError: if ``module_naming`` is not one of the accepted values.
    """
    if module_naming not in ("aggregator", "modnum", "qm"):
        raise ValueError(
            f"{module_naming=} (must be 'aggregator', 'modnum' or 'qm')"
        )

    # Imported only once the arguments are known to be valid
    import xarray

    aggs = self.aggregator_names
    if module_naming == "aggregator":
        modules = aggs
    else:
        # ndarray() orders its rows by sorted aggregator name, so derive the
        # labels in that same order rather than in module_details order.
        field = "module_number" if module_naming == "modnum" else "virtual_device_name"
        by_da = {m["karabo_da"]: m for m in self.module_details}
        modules = [by_da[a][field] for a in aggs]

    ndarr = self.ndarray(caldb_root, parallel=parallel)

    # Dimension labels
    dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
    coords = {"module": modules}
    name = self.constants[aggs[0]].constant_name

    return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
@@ -328,10 +360,7 @@ class CalibrationData(Mapping):
"""Collected constants for a given detector"""
def __init__(self, constant_groups, module_details):
    """Store the given constant groups and module details unchanged.

    ``constant_groups`` is kept as passed, with no wrapping or copying.
    NOTE(review): the visible callers pass a dict mapping constant type to
    a ready-made ModulesConstantVersions - confirm for any others.
    """
    self.constant_groups = constant_groups
    self.module_details = module_details
@staticmethod
@@ -386,7 +415,7 @@ class CalibrationData(Mapping):
if mod.get("module_number", -1) < 0:
mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
d = {}
constant_groups = {}
for params, cal_types in cal_types_by_params_used.items():
condition_dict = condition.make_dict(params)
@@ -412,11 +441,14 @@ class CalibrationData(Mapping):
aggr = ccv["physical_detector_unit"]["karabo_da"]
cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(
ccv
)
const_group = constant_groups.setdefault(cal_type, {})
const_group[aggr] = SingleConstantVersion.from_response(ccv)
return cls(d, module_details)
mcvs = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mcvs, module_details)
@classmethod
def from_report(
@@ -435,16 +467,40 @@ class CalibrationData(Mapping):
res = client.get("calibration_constant_versions", params)
d = {}
aggregators = set()
constant_groups = {}
pdus = {} # keyed by karabo_da (e.g. 'AGIPD00')
det_ids = set() # Should only have one detector
for ccv in res:
aggr = ccv["physical_detector_unit"]["karabo_da"]
aggregators.add(aggr)
pdu = ccv["physical_detector_unit"]
# We're only interested in the PDU mapping from the CCV start time
kda = pdu["karabo_da"] = pdu.pop("karabo_da_at_ccv_begin_at")
det_id = pdu["detector_id"] = pdu.pop("detector_id_at_ccv_begin_at")
pdu["virtual_device_name"] = pdu.pop("virtual_device_name_at_ccv_begin_at")
det_ids.add(det_id)
if kda in pdus:
if pdu["physical_name"] != pdus[kda]["physical_name"]:
raise Exception(
f"Mismatched PDU mapping from calibration report: {kda} is both"
f" {pdu['physical_name']} and {pdus[kda]['physical_name']}"
)
else:
pdus[kda] = pdu
cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"])
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)
const_group = constant_groups.setdefault(cal_type, {})
const_group[kda] = SingleConstantVersion.from_response(ccv)
return cls(d, sorted(aggregators))
if len(det_ids) > 1:
raise Exception(f"Found multiple detector IDs in report: {det_ids}")
module_details = sorted(pdus.values(), key=lambda d: d["karabo_da"])
mcvs = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
return cls(mcvs, module_details)
def __getitem__(self, key) -> ModulesConstantVersions:
    """Return the entry of ``self.constant_groups`` stored under ``key``."""
    groups = self.constant_groups
    return groups[key]
@@ -456,99 +512,92 @@ class CalibrationData(Mapping):
return len(self.constant_groups)
def __repr__(self):
    """Summarise the sorted constant types and the number of modules."""
    cal_types = ", ".join(sorted(self.constant_groups))
    n_modules = len(self.module_details)
    return f"<CalibrationData: {cal_types} constants for {n_modules} modules>"
# These properties may include modules for which we have no constants -
# when created with .from_condition(), they represent all modules present in
# the detector (at the specified time).
@property
def module_nums(self):
    """The ``module_number`` of every entry in ``self.module_details``."""
    return [m["module_number"] for m in self.module_details]
@property
def aggregator_names(self):
    """The ``karabo_da`` of every entry in ``self.module_details``."""
    return [m["karabo_da"] for m in self.module_details]
@property
def qm_names(self):
    """The ``virtual_device_name`` of every entry in ``self.module_details``."""
    return [m["virtual_device_name"] for m in self.module_details]
@property
def pdu_names(self):
    """The ``physical_name`` of every entry in ``self.module_details``."""
    return [m["physical_name"] for m in self.module_details]
def require_calibrations(self, calibrations):
    """Drop any modules missing the specified constant types.

    Starts from all of ``self.aggregator_names`` and keeps only those that
    have an entry in ``self[cal_type].constants`` for every ``cal_type`` in
    ``calibrations``. The surviving set is passed to
    ``select_modules(aggregator_names=...)`` and that result is returned.
    """
    mods = set(self.aggregator_names)
    for cal_type in calibrations:
        mods.intersection_update(self[cal_type].constants)
    return self.select_modules(aggregator_names=mods)
def select_modules(
    self, module_nums=None, *, aggregator_names=None, qm_names=None
) -> "CalibrationData":
    """Return a new object restricted to the selected modules.

    Exactly one of ``module_nums``, ``aggregator_names`` and ``qm_names``
    must be given. The resulting ``module_details`` contains only modules
    that still have a constant in at least one constant group.
    """
    # Validate the specified modules against those we know about.
    # Each specific constant type may have only a subset of these modules.
    wanted = set(
        prepare_selection(
            self.module_details, module_nums, aggregator_names, qm_names
        )
    )  # Built once, not once per constant type
    mcvs = {
        cal_type: mcv.select_modules(
            aggregator_names=wanted.intersection(mcv.aggregator_names)
        )
        for (cal_type, mcv) in self.constant_groups.items()
    }
    kept = set().union(*(c.aggregator_names for c in mcvs.values()))
    module_details = [m for m in self.module_details if m["karabo_da"] in kept]
    return type(self)(mcvs, module_details)
def select_calibrations(self, calibrations) -> "CalibrationData":
    """Return a new object holding only the named constant types.

    The groups appear in the order given by ``calibrations``;
    ``module_details`` is passed on unchanged.

    Raises:
        KeyError: if any requested type is not in ``self.constant_groups``;
            the message lists all of the missing types.
    """
    calibrations = list(calibrations)  # Traversed twice below
    missing = {c for c in calibrations if c not in self.constant_groups}
    if missing:
        raise KeyError(
            f"Missing calibrations: {', '.join(sorted(map(str, missing)))}"
        )
    mcvs = {c: self.constant_groups[c] for c in calibrations}
    return type(self)(mcvs, self.module_details)
def merge(self, *others: "CalibrationData") -> "CalibrationData":
    """Combine this object with ``others`` into a new CalibrationData.

    The result holds every constant type found in any input. For each type,
    constants are taken from ``self`` first and then updated from each of
    ``others`` in turn, so a later input wins for the same aggregator.

    Module details are merged by ``karabo_da``; the first entry seen for an
    aggregator is kept, and the merged list is sorted by ``karabo_da``. A
    warning is issued if two inputs give different ``physical_name`` values
    for the same aggregator.
    """
    # Use a dict as an ordered set, so the result's constant types come out
    # in a reproducible order (first seen first).
    cal_types = dict.fromkeys(self.constant_groups)
    pdus_d = {m["karabo_da"]: m for m in self.module_details}
    for other in others:
        cal_types.update(dict.fromkeys(other.constant_groups))
        for md in other.module_details:
            # Warn if constants don't refer to same modules
            md_da = md["karabo_da"]
            if md_da in pdus_d:
                pdu_a = pdus_d[md_da]["physical_name"]
                pdu_b = md["physical_name"]
                if pdu_a != pdu_b:
                    warn(
                        f"Merging constants with different modules for "
                        f"{md_da}: {pdu_a!r} != {pdu_b!r}",
                        stacklevel=2,
                    )
            else:
                pdus_d[md_da] = md

    module_details = sorted(pdus_d.values(), key=lambda d: d["karabo_da"])

    mcvs = {}
    for cal_type in cal_types:
        d = {}
        for caldata in (self,) + others:
            if cal_type in caldata:
                d.update(caldata[cal_type].constants)
        mcvs[cal_type] = ModulesConstantVersions(d, module_details)

    return type(self)(mcvs, module_details)
class ConditionsBase:
Loading