
Revised CalCat API

Merged: Thomas Kluyver requested to merge calcat-api-2 into master
1 file changed: +377 −68
import json
import re
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, date, time, timezone
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from urllib.parse import urljoin
from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion
import h5py
import numpy as np
import pasha as psh
import requests
from oauth2_xfel_client import Oauth2ClientBackend
from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm
global_client = None
class ModuleNameError(KeyError):
def __init__(self, name):
@@ -20,6 +23,92 @@ class ModuleNameError(KeyError):
return f"No module named {self.name!r}"
class APIError(requests.HTTPError):
"""Used when the response includes error details as JSON"""
class CalCatAPIClient:
def __init__(self, base_api_url, oauth_client=None, user_email=""):
if oauth_client is not None:
self.oauth_client = oauth_client
self.session = self.oauth_client.session
else:
# Oauth disabled - used with base_api_url pointing to an
# xfel-oauth-proxy instance
self.oauth_client = None
self.session = requests.Session()
self.user_email = user_email
# Ensure the base URL has a trailing slash
self.base_api_url = base_api_url.rstrip("/") + "/"
def default_headers(self):
return {
"content-type": "application/json",
"Accept": "application/json; version=2",
"X-User-Email": self.user_email,
}
@classmethod
def format_time(cls, dt):
"""Parse different ways to specify time to CalCat."""
if isinstance(dt, datetime):
return dt.astimezone(timezone.utc).isoformat()
elif isinstance(dt, date):
return cls.format_time(datetime.combine(dt, time()))
return dt
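# Illustrative behaviour (sketch, values made up): aware datetimes are converted
# to UTC, dates become midnight on that day, anything else passes through.
#
#     CalCatAPIClient.format_time(datetime(2022, 1, 5, 12, 0, tzinfo=timezone.utc))
#     # -> '2022-01-05T12:00:00+00:00'
#     CalCatAPIClient.format_time(date(2022, 1, 5))
#     # -> ISO timestamp for midnight on that date, converted to UTC
#     CalCatAPIClient.format_time("2022-01-05T12:00:00")  # returned unchanged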
def get_request(self, relative_url, params=None, headers=None, **kwargs):
# The base URL may include a path such as '/api/'. It is prefixed to all
# request URLs, even ones that look like absolute paths.
url = urljoin(self.base_api_url, relative_url.lstrip("/"))
_headers = self.default_headers()
if headers:
_headers.update(headers)
return self.session.get(url, params=params, headers=_headers, **kwargs)
@staticmethod
def _parse_response(resp: requests.Response):
if resp.status_code >= 400:
try:
d = json.loads(resp.content.decode("utf-8"))
except Exception:
resp.raise_for_status()
else:
raise APIError(
f"Error {resp.status_code} from API: "
f"{d.get('info', 'missing details')}"
)
if resp.content == b"":
return None
else:
res = json.loads(resp.content.decode("utf-8"))
pagination_header_fields = [
"X-Total-Pages",
"X-Count-Per-Page",
"X-Current-Page",
"X-Total-Count",
]
pagination_info = {
k[2:].lower().replace("-", "_"): int(resp.headers[k])
for k in pagination_header_fields
if k in resp.headers
}
if pagination_info and isinstance(res, dict):
# Only dict responses can carry the extra '.pages' key
res[".pages"] = pagination_info
return res
def get(self, relative_url, params=None, **kwargs):
resp = self.get_request(relative_url, params, **kwargs)
return self._parse_response(resp)
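# Usage sketch (illustrative only; the base URL and detector identifier are
# examples, not fixed by this code):
#
#     client = CalCatAPIClient(
#         "https://in.xfel.eu/calibration/api/", user_email="someone@example.com"
#     )
#     detectors = client.get("detectors", {"identifier": "SPB_DET_AGIPD1M-1"})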
global_client = None
def get_client():
global global_client
if global_client is None:
@@ -27,22 +116,56 @@ def get_client():
return global_client
def setup_client(
base_url,
client_id,
client_secret,
user_email,
scope="",
session_token=None,
oauth_retries=3,
oauth_timeout=12,
ssl_verify=True,
):
global global_client
if client_id is not None:
oauth_client = Oauth2ClientBackend(
client_id=client_id,
client_secret=client_secret,
scope=scope,
token_url=f"{base_url}/oauth/token",
session_token=session_token,
max_retries=oauth_retries,
timeout=oauth_timeout,
ssl_verify=ssl_verify,
)
else:
oauth_client = None
global_client = CalCatAPIClient(
f"{base_url}/api/",
oauth_client=oauth_client,
user_email=user_email,
)
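# Illustrative setup (URL and credentials are placeholders; with client_id=None
# OAuth is skipped, e.g. when going through an xfel-oauth-proxy instance):
#
#     setup_client(
#         "https://in.xfel.eu/calibration",
#         client_id="...", client_secret="...", user_email="someone@example.com",
#     )
#     client = get_client()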
_default_caldb_root = ...
def _get_default_caldb_root():
global _default_caldb_root
if _default_caldb_root is ...:
onc_path = Path("/common/cal/caldb_store")
maxwell_path = Path("/gpfs/exfel/d/cal/caldb_store")
if onc_path.is_dir():
_default_caldb_root = onc_path
elif maxwell_path.is_dir():
_default_caldb_root = maxwell_path
else:
_default_caldb_root = None
return _default_caldb_root
@dataclass
class SingleConstantVersion:
"""A Calibration Constant Version for 1 detector module"""
@@ -76,46 +199,163 @@ class SingleConstantVersion:
physical_name=ccv["physical_detector_unit"]["physical_name"],
)
def dataset_obj(self, caldb_root=None) -> h5py.Dataset:
if caldb_root is not None:
caldb_root = Path(caldb_root)
else:
caldb_root = _get_default_caldb_root()
f = h5py.File(caldb_root / self.path, "r")
return f[self.dataset]["data"]
def ndarray(self, caldb_root=None):
return self.dataset_obj(caldb_root)[:]
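# Sketch: given a SingleConstantVersion `scv` (e.g. one value from
# ModulesConstantVersions.constants), the data can be read lazily or in full:
#
#     dset = scv.dataset_obj()   # h5py.Dataset, nothing loaded yet
#     arr = scv.ndarray()        # full numpy array in memory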
@dataclass
class ModulesConstantVersions:
"""A group of similar CCVs for several modules of one detector"""
constants: Dict[str, SingleConstantVersion] # Keys e.g. 'LPD00'
module_details: List[Dict]
def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None
) -> "ModulesConstantVersions":
n_specified = sum(
[module_nums is not None, aggregators is not None, qm_names is not None]
)
if n_specified < 1:
raise TypeError("select_modules() requires an argument")
elif n_specified > 1:
raise TypeError(
"select_modules() accepts only one of module_nums, aggregators & qm_names"
)
if module_nums is not None:
by_mod_no = {m['module_number']: m for m in self.module_details}
aggregators = [by_mod_no[n]['karabo_da'] for n in module_nums]
elif qm_names is not None:
by_qm = {m['virtual_device_name']: m for m in self.module_details}
aggregators = [by_qm[s]['karabo_da'] for s in qm_names]
elif aggregators is not None:
miss = set(aggregators) - {m['karabo_da'] for m in self.module_details}
if miss:
raise KeyError("Aggregators not found: " + ', '.join(sorted(miss)))
d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators}
return ModulesConstantVersions(d, self.module_details)
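# Selection sketch (LPD-style names are examples; the three keyword forms are
# mutually exclusive):
#
#     mcv.select_modules(module_nums=[0, 1])
#     mcv.select_modules(aggregators=['LPD00', 'LPD01'])
#     mcv.select_modules(qm_names=['Q1M1', 'Q1M2'])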
# These properties label only the modules we have constants for, which may
# be a subset of what's in module_details
@property
def aggregators(self):
return sorted(self.constants)
@property
def module_nums(self):
return [m['module_number'] for m in self.module_details
if m['karabo_da'] in self.constants]
@property
def qm_names(self):
return [m['virtual_device_name'] for m in self.module_details
if m['karabo_da'] in self.constants]
def ndarray(self, caldb_root=None):
eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root)
shape = (len(self.constants),) + eg_dset.shape
arr = np.zeros(shape, eg_dset.dtype)
for i, agg in enumerate(self.aggregators):
dset = self.constants[agg].dataset_obj(caldb_root)
dset.read_direct(arr[i])
return arr
def xarray(self, module_naming="da", caldb_root=None):
import xarray
if module_naming == "da":
modules = self.aggregators
elif module_naming == "modno":
modules = self.module_nums
elif module_naming == "qm":
modules = self.qm_names
else:
raise ValueError(f"{module_naming=} (must be 'da', 'modno' or 'qm'")
ndarr = self.ndarray(caldb_root)
# Dimension labels
dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
coords = {"module": modules}
name = self.constants[self.aggregators[0]].constant_name
return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
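# Sketch: stack the per-module constants into a labelled array, e.g.
#
#     arr = mcv.xarray(module_naming="qm")   # dims: module, dim_0, dim_1, ...
#     arr.sel(module="Q1M1")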
@lru_cache()
def detector(identifier, client=None):
client = client or get_client()
res = client.get("detectors", {"identifier": identifier})
if not res:
raise KeyError(f"No detector with identifier {identifier}")
elif len(res) > 1:
raise ValueError(f"Multiple detectors found with identifier {identifier}")
return res[0]
@lru_cache()
def calibration_id(name, client=None):
"""ID for a calibration in CalCat."""
client = client or get_client()
res = client.get("calibrations", {"name": name})
if not res:
raise KeyError(f"No calibration with name {name}")
elif len(res) > 1:
raise ValueError(f"Multiple calibrations found with name {name}")
return res[0]["id"]
@lru_cache()
def calibration_name(cal_id, client=None):
"""Name for a calibration in CalCat."""
client = client or get_client()
res = client.get(f"calibrations/{cal_id}")
return res["name"]
class CalibrationData(Mapping):
"""Collected constants for a given detector"""
def __init__(self, constant_groups, module_details):
self.constant_groups = {
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
self.module_details = module_details
@staticmethod
def _format_cond(condition):
"""Encode operating condition to CalCat API format.
Args:
condition (dict): Mapping of parameter DB name to value
Returns:
(dict) Operating condition for use in CalCat API.
"""
return {
"parameters_conditions_attributes": [
{"parameter_name": k, "value": str(v)} for k, v in condition.items()
]
}
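# For example (values illustrative):
#
#     _format_cond({"Sensor Bias Voltage": 300, "Memory cells": 352})
#     # -> {"parameters_conditions_attributes": [
#     #        {"parameter_name": "Sensor Bias Voltage", "value": "300"},
#     #        {"parameter_name": "Memory cells", "value": "352"}]}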
@classmethod
def from_condition(
cls,
condition: "ConditionsBase",
detector_name,
modules: Optional[Sequence[str]] = None,
calibrations=None,
client=None,
event_at=None,
@@ -131,17 +371,20 @@ class CalibrationData(Mapping):
if cal_type in calibrations:
cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)
client = client or get_client()
detector_id = detector(detector_name)["id"]
pdus = client.get(
"physical_detector_units/get_all_by_detector",
{
"detector_id": detector_id,
"pdu_snapshot_at": client.format_time(pdu_snapshot_at),
},
)
module_details = sorted(pdus, key=lambda d: d["karabo_da"])
for mod in module_details:
if mod.get("module_number", -1) < 0:
mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
d = {}
@@ -149,18 +392,20 @@ class CalibrationData(Mapping):
condition_dict = condition.make_dict(params)
cal_id_map = {
calibration_id(calibration): calibration for calibration in cal_types
}
calibration_ids = list(cal_id_map.keys())
query_res = client.get(
"calibration_constant_versions/get_by_detector_conditions",
{
"detector_identifier": detector_name,
"calibration_id": str(calibration_ids),
"karabo_da": "",
"event_at": client.format_time(event_at),
"pdu_snapshot_at": client.format_time(pdu_snapshot_at),
},
data=json.dumps(cls._format_cond(condition_dict)),
)
for ccv in query_res:
@@ -171,10 +416,7 @@ class CalibrationData(Mapping):
ccv
)
return cls(d, module_details)
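# Illustrative call (detector name, condition values and timestamp are
# examples only):
#
#     cond = LPDConditions(sensor_bias_voltage=250, memory_cells=512)
#     caldata = CalibrationData.from_condition(
#         cond, "FXE_DET_LPD1M-1", event_at="2022-05-30T12:00:00"
#     )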
@classmethod
def from_report(
@@ -183,29 +425,23 @@ class CalibrationData(Mapping):
client=None,
):
client = client or get_client()
# Use max page size, hopefully always enough for CCVs from 1 report
params = {"page_size": 500}
if isinstance(report_id_or_path, int):
params["report_id"] = report_id_or_path # Numeric ID
else:
params["report.file_path"] = str(report_id_or_path)
res = client.get("calibration_constant_versions", params)
d = {}
module_details = {}  # PDU (module) info per aggregator name
for ccv in res:
pdu = ccv["physical_detector_unit"]
aggr = pdu["karabo_da"]
module_details.setdefault(aggr, pdu)
cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"])
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)
return cls(d, [module_details[aggr] for aggr in sorted(module_details)])
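# Sketch (the report ID and path are placeholders):
#
#     caldata = CalibrationData.from_report(3757)
#     caldata = CalibrationData.from_report("/path/to/calibration_report")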
@@ -219,13 +455,28 @@ class CalibrationData(Mapping):
def __len__(self):
return len(self.constant_groups)
def __repr__(self):
return (f"<CalibrationData: {', '.join(sorted(self.constant_groups))} "
f"constants for {len(self.module_details)} modules>")
# These properties may include modules for which we have no constants -
# when created with .from_condition(), they represent all modules present in
# the detector (at the specified time).
@property
def module_nums(self):
return [m['module_number'] for m in self.module_details]
@property
def aggregator_names(self):
return [m['karabo_da'] for m in self.module_details]
@property
def qm_names(self):
return [m['virtual_device_name'] for m in self.module_details]
@property
def pdu_names(self):
return [m['physical_name'] for m in self.module_details]
def require_calibrations(self, calibrations):
"""Drop any modules missing the specified constant types"""
@@ -248,11 +499,20 @@ class CalibrationData(Mapping):
# TODO: missing for some modules?
return type(self)(d, self.module_details)
def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None
) -> "CalibrationData":
return type(self)(
{
cal_type: mcv.select_modules(
module_nums=module_nums,
aggregators=aggregators,
qm_names=qm_names,
).constants
for (cal_type, mcv) in self.constant_groups.items()
},
self.module_details,
)
def merge(self, *others: "CalibrationData") -> "CalibrationData":
d = {}
@@ -268,6 +528,28 @@ class CalibrationData(Mapping):
return type(self)(d, sorted(aggregators))
def load_all(self, caldb_root=None):
res = {}
const_load_mp = psh.ProcessContext(num_workers=24)
keys = []
for cal_type, mcv in self.constant_groups.items():
res[cal_type] = {}
for module in mcv.aggregators:
dset = mcv.constants[module].dataset_obj(caldb_root)
res[cal_type][module] = const_load_mp.alloc(
shape=dset.shape, dtype=dset.dtype
)
keys.append((cal_type, module))
def _load_constant_dataset(wid, index, key):
cal_type, mod = key
dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
dset.read_direct(res[cal_type][mod])
const_load_mp.map(_load_constant_dataset, keys)
return res
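# Sketch: load every selected constant into memory in one go, e.g.
#
#     consts = caldata.load_all()
#     consts["Offset"]["LPD00"]   # numpy array for one module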
class ConditionsBase:
calibration_types = {} # For subclasses: {calibration: [parameter names]}
@@ -363,3 +645,30 @@ class LPDConditions(ConditionsBase):
"FFMap": _illuminated_parameters,
"BadPixelsFF": _illuminated_parameters,
}
@dataclass
class DSSCConditions(ConditionsBase):
sensor_bias_voltage: float
memory_cells: int
pulse_id_checksum: Optional[float] = None
acquisition_rate: Optional[float] = None
target_gain: Optional[int] = None
encoded_gain: Optional[int] = None
pixels_x: int = 512
pixels_y: int = 128
_params = [
"Sensor Bias Voltage",
"Memory cells",
"Pixels X",
"Pixels Y",
"Pulse id checksum",
"Acquisition rate",
"Target gain",
"Encoded gain",
]
calibration_types = {
"Offset": _params,
"Noise": _params,
}
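# Illustrative use of the new DSSC conditions (values and detector name are
# examples only):
#
#     cond = DSSCConditions(sensor_bias_voltage=100, memory_cells=400)
#     caldata = CalibrationData.from_condition(
#         cond, "SCS_DET_DSSC1M-1", event_at="2022-05-30T12:00:00"
#     )
#     offset = caldata["Offset"].ndarray()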