diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index c8f786e8c9719e8500464da656486251c2360c13..39e4f1752a69b254d2a4b9579415cdd56ac7e222 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -1,17 +1,18 @@ +import json import re from collections.abc import Mapping from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, date, time, timezone +from functools import lru_cache from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Optional, Union +from urllib.parse import urljoin import h5py import numpy as np import pasha as psh -from calibration_client import CalibrationClient -from calibration_client.modules import CalibrationConstantVersion, PhysicalDetectorUnit - -from .calcat_interface import CalCatApi, CalCatError +import requests +from oauth2_xfel_client import Oauth2ClientBackend class ModuleNameError(KeyError): @@ -22,6 +23,89 @@ class ModuleNameError(KeyError): return f"No module named {self.name!r}" +class APIError(requests.HTTPError): + """Used when the response includes error details as JSON""" + + +class CalCatAPIClient: + def __init__(self, base_api_url, oauth_client=None, user_email=""): + if oauth_client is not None: + self.oauth_client = oauth_client + self.session = self.oauth_client.session + else: + # Oauth disabled - used with base_api_url pointing to an + # xfel-oauth-proxy instance + self.oauth_client = None + self.session = requests.Session() + + self.user_email = user_email + # Ensure the base URL has a trailing slash + self.base_api_url = base_api_url.rstrip("/") + "/" + + def default_headers(self): + return { + "content-type": "application/json", + "Accept": "application/json; version=2", + "X-User-Email": self.user_email, + } + + @classmethod + def format_time(cls, dt): + """Parse different ways to specify time to CalCat.""" + + if isinstance(dt, datetime): + return dt.astimezone(timezone.utc).isoformat() + elif isinstance(dt, date): + return cls.format_time(datetime.combine(dt, time())) + + return dt + + def get_request(self, relative_url, params=None, headers=None, **kwargs): + # Base URL may include e.g. '/api/'. This is a prefix for all URLs; + # even if they look like an absolute path. + url = urljoin(self.base_api_url, relative_url.lstrip("/")) + _headers = self.default_headers() + if headers: + _headers.update(headers) + return self.session.get(url, params=params, headers=_headers, **kwargs) + + @staticmethod + def _parse_response(resp: requests.Response): + if resp.status_code >= 400: + try: + d = json.loads(resp.content.decode("utf-8")) + except Exception: + resp.raise_for_status() + else: + raise APIError( + f"Error {resp.status_code} from API: " + f"{d.get('info', 'missing details')}" + ) + + if resp.content == b"": + return None + else: + res = json.loads(resp.content.decode("utf-8")) + pagination_header_fields = [ + "X-Total-Pages", + "X-Count-Per-Page", + "X-Current-Page", + "X-Total-Count", + ] + pagination_info = { + k[2:].lower().replace("-", "_"): int(resp.headers[k]) + for k in pagination_header_fields + if k in resp.headers + } + if pagination_info: + res[".pages"] = pagination_info + return res + + def get(self, relative_url, params=None, **kwargs): + resp = self.get_request(relative_url, params, **kwargs) + return self._parse_response(resp) + + global_client = None @@ -32,19 +116,35 @@ def get_client(): return global_client -def setup_client(base_url, client_id, client_secret, user_email, **kwargs): +def setup_client( + base_url, + client_id, + client_secret, + user_email, + scope="", + session_token=None, + oauth_retries=3, + oauth_timeout=12, + ssl_verify=True, +): global global_client - global_client = CalibrationClient( - use_oauth2=(client_id is not None), - client_id=client_id, - client_secret=client_secret, + if client_id is not None: + oauth_client = Oauth2ClientBackend( + client_id=client_id, + client_secret=client_secret, + scope=scope, + token_url=f"{base_url}/oauth/token", + session_token=session_token, + max_retries=oauth_retries, + timeout=oauth_timeout, + ssl_verify=ssl_verify, + ) + else: + oauth_client = None + global_client = CalCatAPIClient( + f"{base_url}/api/", + oauth_client=oauth_client, user_email=user_email, - base_api_url=f"{base_url}/api/", - token_url=f"{base_url}/oauth/token", - refresh_url=f"{base_url}/oauth/token", - auth_url=f"{base_url}/oauth/authorize", - scope="", - **kwargs, ) @@ -120,13 +220,11 @@ class ModulesConstantVersions: module_details: List[Dict] def select_modules( - self, module_nums=None, *, aggregators=None, qm_names=None + self, module_nums=None, *, aggregators=None, qm_names=None ) -> "ModulesConstantVersions": - n_specified = sum([ - module_nums is not None, - aggregators is not None, - qm_names is not None - ]) + n_specified = sum( + [module_nums is not None, aggregators is not None, qm_names is not None] + ) if n_specified < 1: raise TypeError("select_modules() requires an argument") elif n_specified > 1: @@ -195,6 +293,37 @@ class ModulesConstantVersions: return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name) +@lru_cache() +def detector(identifier, client=None): + client = client or get_client() + res = client.get("detectors", {"identifier": identifier}) + if not res: + raise KeyError(f"No detector with identifier {identifier}") + elif len(res) > 1: + raise ValueError(f"Multiple detectors found with identifier {identifier}") + return res[0] + + +@lru_cache() +def calibration_id(name, client=None): + """ID for a calibration in CalCat.""" + client = client or get_client() + res = client.get("calibrations", {"name": name}) + if not res: + raise KeyError(f"No calibration with name {name}") + elif len(res) > 1: + raise ValueError(f"Multiple calibrations found with name {name}") + return res[0]["id"] + + +@lru_cache() +def calibration_name(cal_id, client=None): + """Name for a calibration in CalCat.""" + client = client or get_client() + res = client.get(f"calibrations/{cal_id}") + return res["name"] + + class CalibrationData(Mapping): """Collected constants for a given detector""" @@ -205,6 +334,23 @@ class CalibrationData(Mapping): } self.module_details = module_details + @staticmethod + def _format_cond(condition): + """Encode operating condition to CalCat API format. + + Args: + condition (dict): Mapping of parameter DB name to value + + Returns: + (dict) Operating condition for use in CalCat API. + """ + + return { + "parameters_conditions_attributes": [ + {"parameter_name": k, "value": str(v)} for k, v in condition.items() + ] + } + @classmethod def from_condition( cls, @@ -225,18 +371,20 @@ class CalibrationData(Mapping): if cal_type in calibrations: cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type) - api = CalCatApi(client or get_client()) + client = client or get_client() - detector_id = api.detector(detector_name)["id"] - resp_pdus = PhysicalDetectorUnit.get_all_by_detector( - api.client, detector_id, api.format_time(pdu_snapshot_at) + detector_id = detector(detector_name)["id"] + pdus = client.get( + "physical_detector_units/get_all_by_detector", + { + "detector_id": detector_id, + "pdu_snapshot_at": client.format_time(pdu_snapshot_at), + }, ) - if not resp_pdus["success"]: - raise CalCatError(resp_pdus) - module_details = sorted(resp_pdus["data"], key=lambda d: d['karabo_da']) + module_details = sorted(pdus, key=lambda d: d["karabo_da"]) for mod in module_details: - if mod.get('module_number', -1) < 0: - mod['module_number'] = int(re.findall(r'\d+', mod['karabo_da'])[-1]) + if mod.get("module_number", -1) < 0: + mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1]) d = {} @@ -244,18 +392,20 @@ class CalibrationData(Mapping): condition_dict = condition.make_dict(params) cal_id_map = { - api.calibration_id(calibration): calibration - for calibration in cal_types + calibration_id(calibration): calibration for calibration in cal_types } calibration_ids = list(cal_id_map.keys()) - query_res = api._closest_ccv_by_time_by_condition( - detector_name, - calibration_ids, - condition_dict, - "", - event_at, - pdu_snapshot_at or event_at, + query_res = client.get( + "calibration_constant_versions/get_by_detector_conditions", + { + "detector_identifier": detector_name, + "calibration_id": str(calibration_ids), + "karabo_da": "", + "event_at": client.format_time(event_at), + "pdu_snapshot_at": client.format_time(pdu_snapshot_at), + }, + data=json.dumps(cls._format_cond(condition_dict)), ) for ccv in query_res: @@ -275,29 +425,23 @@ class CalibrationData(Mapping): client=None, ): client = client or get_client() - api = CalCatApi(client) + # Use max page size, hopefully always enough for CCVs from 1 report + params = {"page_size": 500} if isinstance(report_id_or_path, int): - resp = CalibrationConstantVersion.get_by_report_id( - client, report_id_or_path - ) + params["report_id"] = report_id_or_path # Numeric ID else: - resp = CalibrationConstantVersion.get_by_report_path( - client, report_id_or_path - ) + params["report.file_path"] = str(report_id_or_path) - if not resp["success"]: - raise CalCatError(resp) + res = client.get("calibration_constant_versions", params) d = {} aggregators = set() - for ccv in resp["data"]: + for ccv in res: aggr = ccv["physical_detector_unit"]["karabo_da"] aggregators.add(aggr) - cal_type = api.calibration_name( - ccv["calibration_constant"]["calibration_id"] - ) + cal_type = calibration_name(ccv["calibration_constant"]["calibration_id"]) d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv) return cls(d, sorted(aggregators))