Skip to content
Snippets Groups Projects

Revised CalCat API

Merged Thomas Kluyver requested to merge calcat-api-2 into master
Compare and Show latest version
3 files
+ 314
88
Compare changes
  • Side-by-side
  • Inline
Files
3
import json
import re
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from datetime import date, datetime, time, timezone
from functools import lru_cache
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
from typing import Dict, List, Optional, Union
from urllib.parse import urljoin
import h5py
import numpy as np
import pasha as psh
from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion
from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm
import requests
from oauth2_xfel_client import Oauth2ClientBackend
class ModuleNameError(KeyError):
@@ -22,6 +23,95 @@ class ModuleNameError(KeyError):
return f"No module named {self.name!r}"
class APIError(requests.HTTPError):
    """HTTP error raised when the API's error response carries JSON details."""
class CalCatAPIClient:
    """Thin HTTP client for the CalCat REST API.

    Requests are sent through the session of the given OAuth client, or
    through a plain ``requests.Session`` when no OAuth client is given.
    """

    # Response headers from which get_paged() builds its pagination info
    _pagination_headers = (
        "X-Total-Pages",
        "X-Count-Per-Page",
        "X-Current-Page",
        "X-Total-Count",
    )

    def __init__(self, base_api_url, oauth_client=None, user_email=""):
        """Set up the HTTP session and store the API base URL.

        Args:
            base_api_url (str): Prefix for every request URL. It is stored
                with exactly one trailing slash.
            oauth_client: Object exposing a ``session`` attribute used to
                send requests, or None to use a new ``requests.Session``.
            user_email (str): Value sent in the ``X-User-Email`` header.
        """
        self.oauth_client = oauth_client
        if oauth_client is not None:
            self.session = oauth_client.session
        else:
            # Oauth disabled - used with base_api_url pointing to an
            # xfel-oauth-proxy instance
            self.session = requests.Session()

        self.user_email = user_email
        # Ensure the base URL has a trailing slash
        self.base_api_url = base_api_url.rstrip("/") + "/"

    def default_headers(self):
        """Return the headers sent with every request."""
        return {
            "content-type": "application/json",
            "Accept": "application/json; version=2",
            "X-User-Email": self.user_email,
        }

    @classmethod
    def format_time(cls, dt):
        """Convert a time specification to the form sent to CalCat.

        * ``datetime`` objects become ISO 8601 strings in UTC (naive
          datetimes are interpreted by ``astimezone`` as system local time).
        * ``date`` objects are treated as midnight at the start of that day.
        * Anything else (e.g. a string or None) is returned unchanged.
        """
        # datetime is a subclass of date, so it has to be checked first
        if isinstance(dt, datetime):
            return dt.astimezone(timezone.utc).isoformat()
        if isinstance(dt, date):
            return cls.format_time(datetime.combine(dt, time()))
        return dt

    def get_request(self, relative_url, params=None, headers=None, **kwargs):
        """Make a GET request, return the HTTP response object.

        ``headers`` are merged over the default headers; extra keyword
        arguments are passed on to the session's ``get()``.
        """
        # Base URL may include e.g. '/api/'. This is a prefix for all URLs;
        # even if they look like an absolute path.
        url = urljoin(self.base_api_url, relative_url.lstrip("/"))
        _headers = self.default_headers()
        if headers:
            _headers.update(headers)
        return self.session.get(url, params=params, headers=_headers, **kwargs)

    @staticmethod
    def _parse_response(resp: "requests.Response"):
        """Decode the JSON body of a response, raising on HTTP errors.

        Returns:
            The decoded JSON content, or None if the body is empty.

        Raises:
            APIError: Status >= 400 and the body is JSON; the message
                includes the 'info' field if there is one, and the response
                is attached as ``.response``.
            requests.HTTPError: Status >= 400 and the body is not JSON
                (raised through ``resp.raise_for_status()``).
        """
        if resp.status_code >= 400:
            try:
                payload = json.loads(resp.content.decode("utf-8"))
            except ValueError:
                # Body is not valid UTF-8 JSON (both decode errors are
                # ValueError subclasses): fall back to the generic HTTP error
                resp.raise_for_status()
            else:
                # An error body is not guaranteed to be a JSON object
                if isinstance(payload, dict):
                    info = payload.get("info", "missing details")
                else:
                    info = payload
                raise APIError(
                    f"Error {resp.status_code} from API: {info}",
                    response=resp,
                )

        if resp.content == b"":
            return None
        return json.loads(resp.content.decode("utf-8"))

    def get(self, relative_url, params=None, **kwargs):
        """Make a GET request, return response content from JSON"""
        resp = self.get_request(relative_url, params, **kwargs)
        return self._parse_response(resp)

    def get_paged(self, relative_url, params=None, **kwargs):
        """Make a GET request, return response content & pagination info.

        The pagination info is a dict built from whichever pagination
        headers are present, e.g. 'X-Total-Pages' -> ``{'total_pages': 3}``.
        """
        resp = self.get_request(relative_url, params, **kwargs)
        content = self._parse_response(resp)
        pagination_info = {
            # Drop the 'X-' prefix and convert to snake_case
            k[2:].lower().replace("-", "_"): int(resp.headers[k])
            for k in self._pagination_headers
            if k in resp.headers
        }
        return content, pagination_info
global_client = None
@@ -32,19 +122,35 @@ def get_client():
return global_client
def setup_client(
    base_url,
    client_id,
    client_secret,
    user_email,
    scope="",
    session_token=None,
    oauth_retries=3,
    oauth_timeout=12,
    ssl_verify=True,
):
    """Create the module-level client and store it in ``global_client``.

    Args:
        base_url (str): Root URL of the service; '/api/' and '/oauth/token'
            are appended to it. A trailing slash is tolerated.
        client_id: OAuth client ID, or None to make requests without OAuth.
        client_secret: OAuth client secret (only used with a client ID).
        user_email (str): Passed to the client as ``user_email``.
        scope, session_token, oauth_retries, oauth_timeout, ssl_verify:
            Passed through to ``Oauth2ClientBackend`` (as ``scope``,
            ``session_token``, ``max_retries``, ``timeout`` and
            ``ssl_verify``) when a client ID is given.
    """
    global global_client

    # Avoid '//' in the derived URLs when base_url ends with a slash
    base_url = base_url.rstrip("/")

    if client_id is not None:
        oauth_client = Oauth2ClientBackend(
            client_id=client_id,
            client_secret=client_secret,
            scope=scope,
            token_url=f"{base_url}/oauth/token",
            session_token=session_token,
            max_retries=oauth_retries,
            timeout=oauth_timeout,
            ssl_verify=ssl_verify,
        )
    else:
        oauth_client = None

    global_client = CalCatAPIClient(
        f"{base_url}/api/",
        oauth_client=oauth_client,
        user_email=user_email,
    )
@@ -117,22 +223,56 @@ class ModulesConstantVersions:
"""A group of similar CCVs for several modules of one detector"""
constants: Dict[str, SingleConstantVersion] # Keys e.g. 'LPD00'
module_details: List[Dict]
def select_modules(
    self, module_nums=None, *, aggregators=None, qm_names=None
) -> "ModulesConstantVersions":
    """Return a new group holding only the constants of the chosen modules.

    Exactly one of ``module_nums``, ``aggregators`` or ``qm_names`` must be
    given. Module numbers and QM names are translated to aggregator names
    through ``self.module_details``.

    Raises:
        TypeError: If no selector, or more than one, is given.
        KeyError: If a module number, QM name or aggregator is not found
            in ``self.module_details``.
    """
    selectors = (module_nums, aggregators, qm_names)
    n_given = sum([s is not None for s in selectors])
    if n_given < 1:
        raise TypeError("select_modules() requires an argument")
    if n_given > 1:
        raise TypeError(
            "select_modules() accepts only one of module_nums, aggregators & qm_names"
        )

    if module_nums is not None:
        details_by_number = {m["module_number"]: m for m in self.module_details}
        aggregators = [details_by_number[n]["karabo_da"] for n in module_nums]
    elif qm_names is not None:
        details_by_qm = {m["virtual_device_name"]: m for m in self.module_details}
        aggregators = [details_by_qm[s]["karabo_da"] for s in qm_names]
    else:
        # Exactly one selector was given, so this is the aggregators case
        known = {m["karabo_da"] for m in self.module_details}
        unknown = set(aggregators) - known
        if unknown:
            raise KeyError("Aggregators not found: " + ", ".join(sorted(unknown)))

    selected = {}
    for aggr, scv in self.constants.items():
        if aggr in aggregators:
            selected[aggr] = scv
    return ModulesConstantVersions(selected, self.module_details)
# These properties label only the modules we have constants for, which may
# be a subset of what's in module_details
@property
def aggregators(self):
    """Sorted list of the keys of ``self.constants``."""
    names = list(self.constants)
    names.sort()
    return names
@property
def module_nums(self):
    """Module numbers of the modules we have constants for.

    Follows the order of ``self.module_details``.
    """
    numbers = []
    for details in self.module_details:
        if details["karabo_da"] in self.constants:
            numbers.append(details["module_number"])
    return numbers
@property
def qm_names(self):
    """Virtual device names of the modules we have constants for.

    Follows the order of ``self.module_details``.
    """
    names = []
    for details in self.module_details:
        if details["karabo_da"] in self.constants:
            names.append(details["virtual_device_name"])
    return names
def ndarray(self, caldb_root=None):
eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root)
@@ -165,22 +305,69 @@ class ModulesConstantVersions:
return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
@lru_cache()
def detector(identifier, client=None):
    """Look up the single detector record matching an identifier.

    Uses the given client, or the one from ``get_client()`` if none is
    passed. Results are cached per (identifier, client) pair.

    Raises:
        KeyError: If no detector matches.
        ValueError: If more than one detector matches.
    """
    api = client or get_client()
    matches = api.get("detectors", {"identifier": identifier})
    if not matches:
        raise KeyError(f"No detector with identifier {identifier}")
    if len(matches) > 1:
        raise ValueError(f"Multiple detectors found with identifier {identifier}")
    (first, *_) = matches
    return first
@lru_cache()
def calibration_id(name, client=None):
    """Return the CalCat ID of the calibration with the given name.

    Uses the given client, or the one from ``get_client()`` if none is
    passed. Results are cached per (name, client) pair.

    Raises:
        KeyError: If no calibration matches.
        ValueError: If more than one calibration matches.
    """
    api = client or get_client()
    matches = api.get("calibrations", {"name": name})
    if not matches:
        raise KeyError(f"No calibration with name {name}")
    if len(matches) > 1:
        raise ValueError(f"Multiple calibrations found with name {name}")
    only = matches[0]
    return only["id"]
@lru_cache()
def calibration_name(cal_id, client=None):
    """Return the name of the calibration with the given CalCat ID.

    Uses the given client, or the one from ``get_client()`` if none is
    passed. Results are cached per (cal_id, client) pair.
    """
    api = client or get_client()
    record = api.get(f"calibrations/{cal_id}")
    return record["name"]
class CalibrationData(Mapping):
"""Collected constants for a given detector"""
def __init__(self, constant_groups, aggregators):
def __init__(self, constant_groups, module_details):
self.constant_groups = {
const_type: ModulesConstantVersions(d)
const_type: ModulesConstantVersions(d, module_details)
for const_type, d in constant_groups.items()
}
self.aggregators = aggregators
self.module_details = module_details
@staticmethod
def _format_cond(condition):
"""Encode operating condition to CalCat API format.
Args:
condition (dict): Mapping of parameter DB name to value
Returns:
(dict) Operating condition for use in CalCat API.
"""
return {
"parameters_conditions_attributes": [
{"parameter_name": k, "value": str(v)} for k, v in condition.items()
]
}
@classmethod
def from_condition(
cls,
condition: "ConditionsBase",
detector_name,
modules: Optional[Sequence[str]] = None,
calibrations=None,
client=None,
event_at=None,
@@ -196,17 +383,20 @@ class CalibrationData(Mapping):
if cal_type in calibrations:
cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)
api = CalCatApi(client or get_client())
client = client or get_client()
detector_id = api.detector(detector_name)["id"]
all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
if modules is None:
modules = sorted(all_modules)
else:
modules = sorted(modules)
for m in modules:
if m not in all_modules:
raise ModuleNameError(m)
detector_id = detector(detector_name)["id"]
pdus = client.get(
"physical_detector_units/get_all_by_detector",
{
"detector_id": detector_id,
"pdu_snapshot_at": client.format_time(pdu_snapshot_at),
},
)
module_details = sorted(pdus, key=lambda d: d["karabo_da"])
for mod in module_details:
if mod.get("module_number", -1) < 0:
mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
d = {}
@@ -214,18 +404,20 @@ class CalibrationData(Mapping):
condition_dict = condition.make_dict(params)
cal_id_map = {
api.calibration_id(calibration): calibration
for calibration in cal_types
calibration_id(calibration): calibration for calibration in cal_types
}
calibration_ids = list(cal_id_map.keys())
query_res = api._closest_ccv_by_time_by_condition(
detector_name,
calibration_ids,
condition_dict,
modules[0] if len(modules) == 1 else "",
event_at,
pdu_snapshot_at or event_at,
query_res = client.get(
"calibration_constant_versions/get_by_detector_conditions",
{
"detector_identifier": detector_name,
"calibration_id": str(calibration_ids),
"karabo_da": "",
"event_at": client.format_time(event_at),
"pdu_snapshot_at": client.format_time(pdu_snapshot_at),
},
data=json.dumps(cls._format_cond(condition_dict)),
)
for ccv in query_res:
@@ -236,10 +428,7 @@ class CalibrationData(Mapping):
ccv
)
res = cls(d, modules)
if modules:
res = res.select_modules(modules)
return res
return cls(d, module_details)
@classmethod
def from_report(
@@ -248,32 +437,26 @@ class CalibrationData(Mapping):
client=None,
):
client = client or get_client()
api = CalCatApi(client)
# Use max page size, hopefully always enough for CCVs from 1 report
params = {"page_size": 500}
if isinstance(report_id_or_path, int):
resp = CalibrationConstantVersion.get_by_report_id(
client, report_id_or_path
)
params["report_id"] = report_id_or_path # Numeric ID
else:
resp = CalibrationConstantVersion.get_by_report_path(
client, report_id_or_path
)
params["report.file_path"] = str(report_id_or_path)
if not resp["success"]:
raise CalCatError(resp)
res = client.get("calibration_constant_versions", params)
d = {}
aggregators = set()
pdus = []
for ccv in resp["data"]:
for ccv in res:
pdus.append(ccv["physical_detector_unit"])
cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"])
aggr = ccv["physical_detector_unit"]["karabo_da"]
aggregators.add(aggr)
cal_type = api.calibration_name(
ccv["calibration_constant"]["calibration_id"]
)
d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)
return cls(d, sorted(aggregators))
return cls(d, sorted(pdus, key=lambda d: d["karabo_da"]))
def __getitem__(self, key) -> ModulesConstantVersions:
return self.constant_groups[key]
@@ -284,13 +467,30 @@ class CalibrationData(Mapping):
def __len__(self):
return len(self.constant_groups)
def __repr__(self):
return (
f"<CalibrationData: {', '.join(sorted(self.constant_groups))} "
f"constants for {len(self.module_details)} modules>"
)
# These properties may include modules for which we have no constants -
# when created with .from_condition(), they represent all modules present in
# the detector (at the specified time).
@property
def module_nums(self):
    """Module number of every entry in ``self.module_details``, in order."""
    numbers = []
    for details in self.module_details:
        numbers.append(details["module_number"])
    return numbers
@property
def aggregator_names(self):
    """The 'karabo_da' value of every entry in ``self.module_details``."""
    names = []
    for details in self.module_details:
        names.append(details["karabo_da"])
    return names
@property
def qm_names(self):
    """The 'virtual_device_name' of every entry in ``self.module_details``."""
    names = []
    for details in self.module_details:
        names.append(details["virtual_device_name"])
    return names
@property
def pdu_names(self):
    """The 'physical_name' of every entry in ``self.module_details``."""
    names = []
    for details in self.module_details:
        names.append(details["physical_name"])
    return names
def require_calibrations(self, calibrations):
"""Drop any modules missing the specified constant types"""
@@ -313,10 +513,16 @@ class CalibrationData(Mapping):
# TODO: missing for some modules?
return type(self)(d, self.aggregators)
def select_modules(self, aggregators):
def select_modules(
self, module_nums=None, *, aggregators=None, qm_names=None
) -> "CalibrationData":
return type(self)(
{
cal_type: mcv.select_modules(aggregators).constants
cal_type: mcv.select_modules(
module_nums=module_nums,
aggregators=aggregators,
qm_names=qm_names,
).constants
for (cal_type, mcv) in self.constant_groups.items()
},
sorted(aggregators),
Loading