diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 400c9486bd7cdb7a30556d6495479ec114ce7548..8126b0a60162d8d07713bb54b1c60f266247b8b3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
     # If `CI_MERGE_REQUEST_TARGET_BRANCH_SHA` env var is set then this will
     # run flake8 on the diff of the merge request, otherwise it will run
     # flake8 as it would usually execute via the pre-commit hook
-    entry: bash -c 'if [ -z ${CI_MERGE_REQUEST_TARGET_BRANCH_SHA} ]; then (flake8 "$@"); else (git diff $CI_MERGE_REQUEST_TARGET_BRANCH_SHA...$CI_MERGE_REQUEST_SOURCE_BRANCH_SHA | flake8 --diff); fi' --
+    entry: bash -c 'if [ -z ${CI_MERGE_REQUEST_TARGET_BRANCH_SHA} ]; then (flake8 "$@" --max-line-length 88); else (git diff $CI_MERGE_REQUEST_TARGET_BRANCH_SHA...$CI_MERGE_REQUEST_SOURCE_BRANCH_SHA | flake8 --diff --max-line-length 88); fi' --
 - repo: https://github.com/myint/rstcheck
   rev: 3f92957478422df87bd730abde66f089cc1ee19b  # commit where pre-commit support was added
   hooks:
diff --git a/setup.py b/setup.py
index cfd9aaf32ebd94b6117163f02b30a367ca5a2e52..9177c82c32ff8cceab356c6a180d548940c27b5f 100644
--- a/setup.py
+++ b/setup.py
@@ -58,7 +58,7 @@ install_requires = [
     "markupsafe==2.0.1",
     "astcheck==0.3.0",
     "cfelpyutils==2.0.6",
-    "calibration_client==11.2.0",
+    "calibration_client==11.3.0",
     "dill==0.3.0",
     "docutils==0.17.1",
     "dynaconf==3.1.4",
diff --git a/src/cal_tools/calcat_interface.py b/src/cal_tools/calcat_interface.py
index b6fae338496ad62a1ddf33c3c6d768dcc1c8796e..164ce0ee0f52c08a7e1f48f81a2a505f6ab454aa 100644
--- a/src/cal_tools/calcat_interface.py
+++ b/src/cal_tools/calcat_interface.py
@@ -2,6 +2,7 @@
 from datetime import date, datetime, time, timezone
 from functools import lru_cache
 from pathlib import Path
+from typing import Optional, Sequence
 from weakref import WeakKeyDictionary
 
 import h5py
@@ -113,8 +114,7 @@ class CalCatApi(metaclass=ClientWrapper):
         """Encode operating condition to CalCat API format.
 
         Args:
-            caldata (CalibrationData): Calibration data instance used to
-                interface with database.
+            condition (dict): Mapping of parameter DB name to value
 
         Returns:
             (dict) Operating condition for use in CalCat API.
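The reworded docstring above now matches what `format_cond` actually receives. As a rough sketch (not part of the patch itself, values made up), the encoding wraps a plain mapping of DB parameter names to values for the CalCat API, mirroring `CalibrationData._format_cond` in the new module below:

    condition = {"Sensor Bias Voltage": 300, "Memory cells": 352}
    encoded = {
        "parameters_conditions_attributes": [
            {"parameter_name": k, "value": str(v)} for k, v in condition.items()
        ]
    }
    # encoded["parameters_conditions_attributes"][0]
    # -> {"parameter_name": "Sensor Bias Voltage", "value": "300"}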
@@ -192,6 +192,19 @@ class CalCatApi(metaclass=ClientWrapper):
 
         return resp_calibration["data"]["id"]
 
+    @lru_cache()
+    def calibration_name(self, calibration_id):
+        """Name for a calibration in CalCat."""
+
+        resp_calibration = Calibration.get_by_id(
+            self.client, calibration_id
+        )
+
+        if not resp_calibration["success"]:
+            raise CalCatError(resp_calibration)
+
+        return resp_calibration["data"]["name"]
+
     @lru_cache()
     def parameter_id(self, param_name):
         """ID for an operating condition parameter in CalCat."""
@@ -203,6 +216,33 @@ class CalCatApi(metaclass=ClientWrapper):
 
         return resp_parameter["data"]["id"]
 
+    def _closest_ccv_by_time_by_condition(
+        self,
+        detector_name: str,
+        calibration_ids: Sequence[int],
+        condition: dict,
+        karabo_da: Optional[str] = None,
+        event_at=None,
+        pdu_snapshot_at=None,
+    ):
+        resp = CalibrationConstantVersion.get_closest_by_time_by_detector_conditions(
+            self.client,
+            detector_name,
+            calibration_ids,
+            self.format_cond(condition),
+            karabo_da=karabo_da or "",
+            event_at=self.format_time(event_at),
+            pdu_snapshot_at=self.format_time(pdu_snapshot_at),
+        )
+
+        if not resp["success"]:
+            if resp["status_code"] == 200:
+                # calibration_client turns empty response into an error
+                return []
+            raise CalCatError(resp)
+
+        return resp["data"]
+
     def closest_ccv_by_time_by_condition(
         self,
         detector_name,
@@ -284,20 +324,16 @@ class CalCatApi(metaclass=ClientWrapper):
         # afterwards, if necessary.
         karabo_da = next(iter(da_to_modname)) if len(da_to_modname) == 1 else '',
 
-        resp_versions = CalibrationConstantVersion.get_closest_by_time_by_detector_conditions(  # noqa
-            self.client,
+        resp_data = self._closest_ccv_by_time_by_condition(
             detector_name,
             calibration_ids,
-            self.format_cond(condition),
+            condition,
             karabo_da=karabo_da,
             event_at=event_at,
             pdu_snapshot_at=pdu_snapshot_at,
         )
 
-        if not resp_versions["success"]:
-            raise CalCatError(resp_versions)
-
-        for ccv in resp_versions["data"]:
+        for ccv in resp_data:
             try:
                 mod = da_to_modname[ccv['physical_detector_unit']['karabo_da']]
             except KeyError:
diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9048db129d60e69991e63edd14f659dc902643f7
--- /dev/null
+++ b/src/cal_tools/calcat_interface2.py
@@ -0,0 +1,839 @@
+import json
+import re
+from collections.abc import Mapping
+from dataclasses import dataclass, field, replace
+from datetime import date, datetime, time, timezone
+from functools import lru_cache
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+from urllib.parse import urljoin
+from warnings import warn
+
+import h5py
+import pasha as psh
+import requests
+from oauth2_xfel_client import Oauth2ClientBackend
+
+
+# Default address to connect to, only available internally
+CALCAT_PROXY_URL = "http://exflcalproxy.desy.de:8080/"
+
+
+class ModuleNameError(KeyError):
+    def __init__(self, name):
+        self.name = name
+
+    def __str__(self):
+        return f"No module named {self.name!r}"
+
+
+class CalCatAPIError(requests.HTTPError):
+    """Used when the response includes error details as JSON"""
+
+
+class CalCatAPIClient:
+    def __init__(self, base_api_url, oauth_client=None, user_email=""):
+        if oauth_client is not None:
+            self.oauth_client = oauth_client
+            self.session = self.oauth_client.session
+        else:
+            # Oauth disabled - used with base_api_url pointing to an
+            # xfel-oauth-proxy instance
+            self.oauth_client = None
+            self.session = requests.Session()
+
+        self.user_email = user_email
+        # Ensure the base URL has a trailing slash
+        self.base_api_url = base_api_url.rstrip("/") + "/"
+
+    def default_headers(self):
+        return {
+            "content-type": "application/json",
+            "Accept": "application/json; version=2",
+            "X-User-Email": self.user_email,
+        }
+
+    @classmethod
+    def format_time(cls, dt):
+        """Format a time, given as a datetime, date or string, for CalCat."""
+
+        if isinstance(dt, datetime):
+            return dt.astimezone(timezone.utc).isoformat()
+        elif isinstance(dt, date):
+            return cls.format_time(datetime.combine(dt, time()))
+        elif not isinstance(dt, str):
+            raise TypeError(
+                f"Timestamp parameter ({dt!r}) must be a string, datetime or "
+                f"date object"
+            )
+
+        return dt
+
+    def get_request(self, relative_url, params=None, headers=None, **kwargs):
+        """Make a GET request, return the HTTP response object"""
+        # Base URL may include e.g. '/api/'. This is a prefix for all URLs,
+        # even if they look like an absolute path.
+        url = urljoin(self.base_api_url, relative_url.lstrip("/"))
+        _headers = self.default_headers()
+        if headers:
+            _headers.update(headers)
+        return self.session.get(url, params=params, headers=_headers, **kwargs)
+
+    @staticmethod
+    def _parse_response(resp: requests.Response):
+        if resp.status_code >= 400:
+            try:
+                d = json.loads(resp.content.decode("utf-8"))
+            except Exception:
+                resp.raise_for_status()
+            else:
+                raise CalCatAPIError(
+                    f"Error {resp.status_code} from API: "
+                    f"{d.get('info', 'missing details')}"
+                )
+
+        if resp.content == b"":
+            return None
+        else:
+            return json.loads(resp.content.decode("utf-8"))
+
+    def get(self, relative_url, params=None, **kwargs):
+        """Make a GET request, return response content from JSON"""
+        resp = self.get_request(relative_url, params, **kwargs)
+        return self._parse_response(resp)
+
+    _pagination_headers = (
+        "X-Total-Pages",
+        "X-Count-Per-Page",
+        "X-Current-Page",
+        "X-Total-Count",
+    )
+
+    def get_paged(self, relative_url, params=None, **kwargs):
+        """Make a GET request, return response content & pagination info"""
+        resp = self.get_request(relative_url, params, **kwargs)
+        content = self._parse_response(resp)
+        pagination_info = {
+            k[2:].lower().replace("-", "_"): int(resp.headers[k])
+            for k in self._pagination_headers
+            if k in resp.headers
+        }
+        return content, pagination_info
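As a sketch of how these request helpers compose (the e-mail is a placeholder, and a reachable CalCat/proxy instance is assumed):

    client = CalCatAPIClient(
        "http://exflcalproxy.desy.de:8080/api/",  # assumes access from inside the network
        user_email="user@example.org",            # placeholder
    )
    dets = client.get("detectors", {"page_size": 100})
    # get_paged() also decodes the X-Total-Pages etc. headers:
    ccvs, pages = client.get_paged("calibration_constant_versions", {"page_size": 500})
    # pages -> e.g. {"total_pages": 40, "count_per_page": 500, ...}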
+
+    # ------------------
+    # Cached wrappers for simple ID lookups of fixed-ish info
+    #
+    # N.B. lru_cache behaves oddly with instance methods (it's a global cache,
+    # with the instance as part of the key), but in this case it should be OK.
+    @lru_cache()
+    def calibration_by_id(self, cal_id):
+        return self.get(f"calibrations/{cal_id}")
+
+    @lru_cache()
+    def detector_by_id(self, det_id):
+        return self.get(f"detectors/{det_id}")
+
+    # --------------------
+    # Shortcuts to find 1 of something by an ID-like field (e.g. name) other
+    # than CalCat's own integer IDs. Error on no match or >1 matches.
+    @lru_cache()
+    def detector_by_identifier(self, identifier):
+        # The "identifier", "name" & "karabo_name" fields appear to hold the
+        # same value
+        res = self.get("detectors", {"identifier": identifier})
+        if not res:
+            raise KeyError(f"No detector with identifier {identifier}")
+        elif len(res) > 1:
+            raise ValueError(f"Multiple detectors found with identifier {identifier}")
+        return res[0]
+
+    @lru_cache()
+    def calibration_by_name(self, name):
+        res = self.get("calibrations", {"name": name})
+        if not res:
+            raise KeyError(f"No calibration with name {name}")
+        elif len(res) > 1:
+            raise ValueError(f"Multiple calibrations found with name {name}")
+        return res[0]
+
+
+global_client = None
+
+
+def get_client():
+    global global_client
+    if global_client is None:
+        setup_client(CALCAT_PROXY_URL, None, None, None)
+    return global_client
+
+
+def setup_client(
+    base_url,
+    client_id,
+    client_secret,
+    user_email,
+    scope="",
+    session_token=None,
+    oauth_retries=3,
+    oauth_timeout=12,
+    ssl_verify=True,
+):
+    global global_client
+    if client_id is not None:
+        oauth_client = Oauth2ClientBackend(
+            client_id=client_id,
+            client_secret=client_secret,
+            scope=scope,
+            token_url=f"{base_url}/oauth/token",
+            session_token=session_token,
+            max_retries=oauth_retries,
+            timeout=oauth_timeout,
+            ssl_verify=ssl_verify,
+        )
+    else:
+        oauth_client = None
+    global_client = CalCatAPIClient(
+        f"{base_url}/api/",
+        oauth_client=oauth_client,
+        user_email=user_email,
+    )
+
+    # Check we can connect to exflcalproxy
+    if oauth_client is None and base_url == CALCAT_PROXY_URL:
+        try:
+            # timeout=(connect_timeout, read_timeout)
+            global_client.get_request("me", timeout=(1, 5))
+        except requests.ConnectionError as e:
+            raise RuntimeError(
+                "Could not connect to calibration catalog proxy. This proxy allows "
+                "unauthenticated access inside the XFEL/DESY network. To look up "
+                "calibration constants from outside, you will need to create an Oauth "
+                "client ID & secret in the CalCat web interface. You will still not "
+                "be able to load constants without the constant store folder."
+            ) from e
+
+
+_default_caldb_root = None
+
+
+def _get_default_caldb_root():
+    global _default_caldb_root
+    if _default_caldb_root is None:
+        onc_path = Path("/common/cal/caldb_store")
+        maxwell_path = Path("/gpfs/exfel/d/cal/caldb_store")
+        if onc_path.is_dir():
+            _default_caldb_root = onc_path
+        elif maxwell_path.is_dir():
+            _default_caldb_root = maxwell_path
+        else:
+            raise RuntimeError(
+                f"Neither {onc_path} nor {maxwell_path} was found. If the caldb_store "
+                "directory is at another location, pass its path as caldb_root."
+            )
+
+    return _default_caldb_root
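A minimal sketch of configuring the module-level client; the client ID and secret below are placeholders, not real credentials:

    import cal_tools.calcat_interface2 as ci2

    # Inside the XFEL/DESY network, the unauthenticated proxy is used implicitly:
    client = ci2.get_client()

    # From elsewhere, use an Oauth client created in the CalCat web interface:
    ci2.setup_client(
        "https://in.xfel.eu/calibration",
        client_id="<client-id>",          # placeholder
        client_secret="<client-secret>",  # placeholder
        user_email="user@example.org",
    )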
+ """ + + path: Path + dataset: str + ccv_id: Optional[int] + pdu_name: Optional[str] + _metadata: dict = field(default_factory=dict) + _have_calcat_metadata: bool = False + + @classmethod + def from_response(cls, ccv: dict) -> "SingleConstant": + return cls( + path=Path(ccv["path_to_file"]) / ccv["file_name"], + dataset=ccv["data_set_name"], + ccv_id=ccv["id"], + pdu_name=ccv["physical_detector_unit"]["physical_name"], + _metadata=ccv, + _have_calcat_metadata=True, + ) + + def dataset_obj(self, caldb_root=None) -> h5py.Dataset: + if caldb_root is not None: + caldb_root = Path(caldb_root) + else: + caldb_root = _get_default_caldb_root() + + f = h5py.File(caldb_root / self.path, "r") + return f[self.dataset]["data"] + + def ndarray(self, caldb_root=None): + return self.dataset_obj(caldb_root)[:] + + def _load_calcat_metadata(self, client=None): + client = client or get_client() + calcat_meta = client.get(f"calibration_constant_versions/{self.ccv_id}") + # Any metadata we already have takes precedence over CalCat, so + # this can't change a value that was previously returned. + self._metadata = calcat_meta | self._metadata + self._have_calcat_metadata = True + + def metadata(self, key, client=None): + """Get a specific metadata field, e.g. 'begin_validity_at' + + This may make a request to CalCat if the value is not already known. + """ + if key not in self._metadata and not self._have_calcat_metadata: + if self.ccv_id is None: + raise KeyError(f"{key!r} (no CCV ID to request data from CalCat") + self._load_calcat_metadata(client) + + return self._metadata[key] + + def metadata_dict(self, client=None): + """Get a dict of available metadata + + If this constant didn't come from CalCat but we have a CalCat CCV ID, + this will fetch metadata from CalCat. + """ + if (not self._have_calcat_metadata) and (self.ccv_id is not None): + self._load_calcat_metadata(client) + return self._metadata.copy() + + +def prepare_selection( + module_details, module_nums=None, aggregator_names=None, qm_names=None +): + aggs = aggregator_names # Shorter name -> fewer multi-line statements + n_specified = sum([module_nums is not None, aggs is not None, qm_names is not None]) + if n_specified > 1: + raise TypeError( + "select_modules() accepts only one of module_nums, aggregator_names " + "& qm_names" + ) + + if module_nums is not None: + by_mod_no = {m["module_number"]: m for m in module_details} + return [by_mod_no[n]["karabo_da"] for n in module_nums] + elif qm_names is not None: + by_qm = {m["virtual_device_name"]: m for m in module_details} + return [by_qm[s]["karabo_da"] for s in qm_names] + elif aggs is not None: + miss = set(aggs) - {m["karabo_da"] for m in module_details} + if miss: + raise KeyError("Aggregators not found: " + ", ".join(sorted(miss))) + return aggs + else: + raise TypeError("select_modules() requires an argument") + + +@dataclass +class MultiModuleConstant(Mapping): + """A group of similar constants for several modules of one detector""" + + constants: Dict[str, SingleConstant] # Keys e.g. 'LPD00' + module_details: List[Dict] + detector_name: str # e.g. 
+
+
+@dataclass
+class MultiModuleConstant(Mapping):
+    """A group of similar constants for several modules of one detector"""
+
+    constants: Dict[str, SingleConstant]  # Keys e.g. 'LPD00'
+    module_details: List[Dict]
+    detector_name: str  # e.g. 'HED_DET_AGIPD500K2G'
+    calibration_name: str
+
+    def __repr__(self):
+        return (
+            f"<MultiModuleConstant: {self.calibration_name} for "
+            f"{len(self.constants)} modules of {self.detector_name}>"
+        )
+
+    def __iter__(self):
+        return iter(self.constants)
+
+    def __len__(self):
+        return len(self.constants)
+
+    def __getitem__(self, key):
+        if key in (None, ""):
+            raise KeyError(key)
+
+        candidate_kdas = set()
+        if key in self.constants:  # Karabo DA name, e.g. 'LPD00'
+            candidate_kdas.add(key)
+
+        for m in self.module_details:
+            names = (m["module_number"], m["virtual_device_name"], m["physical_name"])
+            if key in names and m["karabo_da"] in self.constants:
+                candidate_kdas.add(m["karabo_da"])
+
+        if not candidate_kdas:
+            raise KeyError(key)
+        elif len(candidate_kdas) > 1:
+            raise KeyError(f"Ambiguous key: {key} matched {candidate_kdas}")
+
+        return self.constants[candidate_kdas.pop()]
+
+    def select_modules(
+        self, module_nums=None, *, aggregator_names=None, qm_names=None
+    ) -> "MultiModuleConstant":
+        aggs = prepare_selection(
+            self.module_details, module_nums, aggregator_names, qm_names
+        )
+        d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
+        mods = [m for m in self.module_details if m["karabo_da"] in d]
+        return replace(self, constants=d, module_details=mods)
+
+    # These properties label only the modules we have constants for, which may
+    # be a subset of what's in module_details
+    @property
+    def aggregator_names(self):
+        return sorted(self.constants)
+
+    @property
+    def module_nums(self):
+        return [
+            m["module_number"]
+            for m in self.module_details
+            if m["karabo_da"] in self.constants
+        ]
+
+    @property
+    def qm_names(self):
+        return [
+            m["virtual_device_name"]
+            for m in self.module_details
+            if m["karabo_da"] in self.constants
+        ]
+
+    @property
+    def pdu_names(self):
+        return [
+            m["physical_name"]
+            for m in self.module_details
+            if m["karabo_da"] in self.constants
+        ]
+
+    def ndarray(self, caldb_root=None, *, parallel=0):
+        eg_dset = self.constants[self.aggregator_names[0]].dataset_obj(caldb_root)
+        shape = (len(self.constants),) + eg_dset.shape
+
+        if parallel > 0:
+            load_ctx = psh.ProcessContext(num_workers=parallel)
+        else:
+            load_ctx = psh.SerialContext()
+
+        arr = psh.alloc(shape, eg_dset.dtype, fill=0)
+
+        def _load_constant_dataset(wid, index, mod):
+            dset = self.constants[mod].dataset_obj(caldb_root)
+            dset.read_direct(arr[index])
+
+        load_ctx.map(_load_constant_dataset, self.aggregator_names)
+        return arr
+
+    def xarray(self, module_naming="modnum", caldb_root=None, *, parallel=0):
+        import xarray
+
+        if module_naming == "aggregator":
+            modules = self.aggregator_names
+        elif module_naming == "modnum":
+            modules = self.module_nums
+        elif module_naming == "qm":
+            modules = self.qm_names
+        else:
+            raise ValueError(
+                f"{module_naming=} (must be 'aggregator', 'modnum' or 'qm')"
+            )
+
+        ndarr = self.ndarray(caldb_root, parallel=parallel)
+
+        # Dimension labels
+        dims = ["module"] + ["dim_%d" % i for i in range(ndarr.ndim - 1)]
+        coords = {"module": modules}
+        name = self.calibration_name
+
+        return xarray.DataArray(ndarr, dims=dims, coords=coords, name=name)
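Loading a whole group as a labelled array, continuing the sketch above (`cal_data` as before): `module_naming` selects the coordinate labels and `parallel` loads modules in worker processes.

    offset = cal_data["Offset"].select_modules(qm_names=["Q1M1", "Q1M2"])
    arr = offset.xarray(module_naming="qm", parallel=2)
    # arr.dims[0] == "module"; list(arr.coords["module"].values) == ["Q1M1", "Q1M2"]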
+
+
+class CalibrationData(Mapping):
+    """Collected constants for a given detector"""
+
+    def __init__(self, constant_groups, module_details, detector_name):
+        # {calibration: {karabo_da: SingleConstant}}
+        self.constant_groups = constant_groups
+        self.module_details = module_details
+        self.detector_name = detector_name
+
+    @staticmethod
+    def _format_cond(condition):
+        """Encode operating condition to CalCat API format.
+
+        Args:
+            condition (dict): Mapping of parameter DB name to value
+
+        Returns:
+            (dict) Operating condition for use in CalCat API.
+        """
+
+        return {
+            "parameters_conditions_attributes": [
+                {"parameter_name": k, "value": str(v)} for k, v in condition.items()
+            ]
+        }
+
+    @classmethod
+    def from_condition(
+        cls,
+        condition: "ConditionsBase",
+        detector_name,
+        calibrations=None,
+        client=None,
+        event_at=None,
+        pdu_snapshot_at=None,
+    ):
+        if calibrations is None:
+            calibrations = set(condition.calibration_types)
+        if pdu_snapshot_at is None:
+            pdu_snapshot_at = event_at
+
+        cal_types_by_params_used = {}
+        for cal_type, params in condition.calibration_types.items():
+            if cal_type in calibrations:
+                cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)
+
+        client = client or get_client()
+
+        detector_id = client.detector_by_identifier(detector_name)["id"]
+        pdus = client.get(
+            "physical_detector_units/get_all_by_detector",
+            {
+                "detector_id": detector_id,
+                "pdu_snapshot_at": client.format_time(pdu_snapshot_at),
+            },
+        )
+        module_details = sorted(pdus, key=lambda d: d["karabo_da"])
+        for mod in module_details:
+            if mod.get("module_number") is None:
+                mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
+
+        constant_groups = {}
+
+        for params, cal_types in cal_types_by_params_used.items():
+            condition_dict = condition.make_dict(params)
+
+            cal_id_map = {
+                client.calibration_by_name(name)["id"]: name for name in cal_types
+            }
+            calibration_ids = list(cal_id_map.keys())
+
+            query_res = client.get(
+                "calibration_constant_versions/get_by_detector_conditions",
+                {
+                    "detector_identifier": detector_name,
+                    "calibration_id": str(calibration_ids),
+                    "karabo_da": "",
+                    "event_at": client.format_time(event_at),
+                    "pdu_snapshot_at": client.format_time(pdu_snapshot_at),
+                },
+                data=json.dumps(cls._format_cond(condition_dict)),
+            )
+
+            for ccv in query_res:
+                aggr = ccv["physical_detector_unit"]["karabo_da"]
+                cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]
+
+                const_group = constant_groups.setdefault(cal_type, {})
+                const_group[aggr] = SingleConstant.from_response(ccv)
+
+        return cls(constant_groups, module_details, detector_name)
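Usage mirrors the tests added later in this patch:

    cond = AGIPDConditions(
        sensor_bias_voltage=300, memory_cells=352, acquisition_rate=1.1,
        gain_setting=0, gain_mode=0, source_energy=9.2, integration_time=12,
    )
    agipd_cd = CalibrationData.from_condition(
        cond, "SPB_DET_AGIPD1M-1", event_at="2020-01-07 13:26:48.00",
    )
    agipd_cd["Offset"]  # -> MultiModuleConstant for the matching modules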
+
+    @classmethod
+    def from_report(
+        cls,
+        report_id_or_path: Union[int, str],
+        client=None,
+    ):
+        client = client or get_client()
+
+        # Use max page size, hopefully always enough for CCVs from 1 report
+        params = {"page_size": 500}
+        if isinstance(report_id_or_path, int):
+            params["report_id"] = report_id_or_path  # Numeric ID
+        else:
+            params["report.file_path"] = str(report_id_or_path)
+
+        res = client.get("calibration_constant_versions", params)
+
+        constant_groups = {}
+        pdus = {}  # keyed by karabo_da (e.g. 'AGIPD00')
+        det_ids = set()  # Should only have one detector
+
+        for ccv in res:
+            pdu = ccv["physical_detector_unit"]
+            # We're only interested in the PDU mapping from the CCV start time
+            kda = pdu["karabo_da"] = pdu.pop("karabo_da_at_ccv_begin_at")
+            det_id = pdu["detector_id"] = pdu.pop("detector_id_at_ccv_begin_at")
+            pdu["virtual_device_name"] = pdu.pop("virtual_device_name_at_ccv_begin_at")
+            if pdu.get("module_number_at_ccv_begin_at") is not None:
+                pdu["module_number"] = pdu.pop("module_number_at_ccv_begin_at")
+            else:
+                pdu["module_number"] = int(re.findall(r"\d+", kda)[-1])
+
+            det_ids.add(det_id)
+            if kda in pdus:
+                if pdu["physical_name"] != pdus[kda]["physical_name"]:
+                    raise Exception(
+                        f"Mismatched PDU mapping from calibration report: {kda} is both"
+                        f" {pdu['physical_name']} and {pdus[kda]['physical_name']}"
+                    )
+            else:
+                pdus[kda] = pdu
+
+            cal_type = client.calibration_by_id(
+                ccv["calibration_constant"]["calibration_id"]
+            )["name"]
+            const_group = constant_groups.setdefault(cal_type, {})
+            const_group[kda] = SingleConstant.from_response(ccv)
+
+        if len(det_ids) > 1:
+            raise Exception(f"Found multiple detector IDs in report: {det_ids}")
+        # The "identifier", "name" & "karabo_name" fields appear to hold the
+        # same value
+        det_name = client.detector_by_id(det_ids.pop())["identifier"]
+
+        module_details = sorted(pdus.values(), key=lambda d: d["karabo_da"])
+        return cls(constant_groups, module_details, det_name)
+
+    def __getitem__(self, key) -> MultiModuleConstant:
+        if isinstance(key, str):
+            return MultiModuleConstant(
+                self.constant_groups[key], self.module_details, self.detector_name, key
+            )
+        elif isinstance(key, tuple) and len(key) == 2:
+            cal_type, module = key
+            return self[cal_type][module]
+        else:
+            raise TypeError(f"Key should be string or 2-tuple (got {key!r})")
+
+    def __iter__(self):
+        return iter(self.constant_groups)
+
+    def __len__(self):
+        return len(self.constant_groups)
+
+    def __contains__(self, item):
+        return item in self.constant_groups
+
+    def __repr__(self):
+        return (
+            f"<CalibrationData: {', '.join(sorted(self.constant_groups))} "
+            f"constants for {len(self.module_details)} modules of {self.detector_name}>"
+        )
+
+    # These properties may include modules for which we have no constants -
+    # when created with .from_condition(), they represent all modules present in
+    # the detector (at the specified time).
+    @property
+    def module_nums(self):
+        return [m["module_number"] for m in self.module_details]
+
+    @property
+    def aggregator_names(self):
+        return [m["karabo_da"] for m in self.module_details]
+
+    @property
+    def qm_names(self):
+        return [m["virtual_device_name"] for m in self.module_details]
+
+    @property
+    def pdu_names(self):
+        return [m["physical_name"] for m in self.module_details]
+
+    def require_calibrations(self, calibrations):
+        """Drop any modules missing the specified constant types"""
+        mods = set(self.aggregator_names)
+        for cal_type in calibrations:
+            mods.intersection_update(self[cal_type].constants)
+        return self.select_modules(aggregator_names=mods)
+
+    def select_modules(
+        self, module_nums=None, *, aggregator_names=None, qm_names=None
+    ) -> "CalibrationData":
+        # Validate the specified modules against those we know about.
+        # Each specific constant type may have only a subset of these modules.
+        aggs = prepare_selection(
+            self.module_details, module_nums, aggregator_names, qm_names
+        )
+        constant_groups = {}
+        matched_aggregators = set()
+        for cal_type, const_group in self.constant_groups.items():
+            constant_groups[cal_type] = d = {
+                aggr: const for (aggr, const) in const_group.items() if aggr in aggs
+            }
+            matched_aggregators.update(d.keys())
+        module_details = [
+            m for m in self.module_details if m["karabo_da"] in matched_aggregators
+        ]
+        return type(self)(constant_groups, module_details, self.detector_name)
+
+    def select_calibrations(self, calibrations) -> "CalibrationData":
+        const_groups = {c: self.constant_groups[c] for c in calibrations}
+        return type(self)(const_groups, self.module_details, self.detector_name)
+
+    def merge(self, *others: "CalibrationData") -> "CalibrationData":
+        det_names = set(cd.detector_name for cd in (self,) + others)
+        if len(det_names) > 1:
+            raise Exception(
+                "Cannot merge calibration data for different "
+                "detectors: " + ", ".join(sorted(det_names))
+            )
+        det_name = det_names.pop()
+
+        cal_types = set(self.constant_groups)
+        aggregators = set(self.aggregator_names)
+        pdus_d = {m["karabo_da"]: m for m in self.module_details}
+        for other in others:
+            cal_types.update(other.constant_groups)
+            aggregators.update(other.aggregator_names)
+            for md in other.module_details:
+                # Warn if constants don't refer to same modules
+                md_da = md["karabo_da"]
+                if md_da in pdus_d:
+                    pdu_a = pdus_d[md_da]["physical_name"]
+                    pdu_b = md["physical_name"]
+                    if pdu_a != pdu_b:
+                        warn(
+                            f"Merging constants with different modules for "
+                            f"{md_da}: {pdu_a!r} != {pdu_b!r}",
+                            stacklevel=2,
+                        )
+                else:
+                    pdus_d[md_da] = md
+
+        module_details = sorted(pdus_d.values(), key=lambda d: d["karabo_da"])
+
+        constant_groups = {}
+        for cal_type in cal_types:
+            d = constant_groups[cal_type] = {}
+            for caldata in (self,) + others:
+                if cal_type in caldata:
+                    d.update(caldata.constant_groups[cal_type])
+
+        return type(self)(constant_groups, module_details, det_name)
+
+
+class ConditionsBase:
+    calibration_types = {}  # For subclasses: {calibration: [parameter names]}
+
+    def make_dict(self, parameters) -> dict:
+        d = dict()
+
+        for db_name in parameters:
+            value = getattr(self, db_name.lower().replace(" ", "_"))
+            if value is not None:
+                d[db_name] = value
+
+        return d
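`make_dict()` turns each DB parameter name into an attribute lookup by lower-casing it and replacing spaces with underscores, skipping `None` values. A small sketch using the DSSC subclass defined below:

    cond = DSSCConditions(sensor_bias_voltage=100, memory_cells=600)
    cond.make_dict(["Sensor Bias Voltage", "Memory cells", "Target gain"])
    # -> {"Sensor Bias Voltage": 100, "Memory cells": 600}
    # "Target gain" is dropped because the target_gain attribute is None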
+
+
+@dataclass
+class AGIPDConditions(ConditionsBase):
+    sensor_bias_voltage: float
+    memory_cells: int
+    acquisition_rate: float
+    gain_setting: Optional[int]
+    gain_mode: Optional[int]
+    source_energy: float
+    integration_time: int = 12
+    pixels_x: int = 512
+    pixels_y: int = 128
+
+    _gain_parameters = [
+        "Sensor Bias Voltage",
+        "Pixels X",
+        "Pixels Y",
+        "Memory cells",
+        "Acquisition rate",
+        "Gain setting",
+        "Integration time",
+    ]
+    _other_dark_parameters = _gain_parameters + ["Gain mode"]
+    _illuminated_parameters = _gain_parameters + ["Source energy"]
+
+    calibration_types = {
+        "Offset": _other_dark_parameters,
+        "Noise": _other_dark_parameters,
+        "ThresholdsDark": _other_dark_parameters,
+        "BadPixelsDark": _other_dark_parameters,
+        "BadPixelsPC": _gain_parameters,
+        "SlopesPC": _gain_parameters,
+        "BadPixelsFF": _illuminated_parameters,
+        "SlopesFF": _illuminated_parameters,
+    }
+
+    def make_dict(self, parameters):
+        cond = super().make_dict(parameters)
+
+        # Fix-up some database quirks.
+        if int(cond.get("Gain mode", -1)) == 0:
+            del cond["Gain mode"]
+
+        if int(cond.get("Integration time", -1)) == 12:
+            del cond["Integration time"]
+
+        return cond
+
+
+@dataclass
+class LPDConditions(ConditionsBase):
+    sensor_bias_voltage: float
+    memory_cells: int
+    memory_cell_order: Optional[str] = None
+    feedback_capacitor: float = 5.0
+    source_energy: float = 9.2
+    category: int = 1
+    pixels_x: int = 256
+    pixels_y: int = 256
+
+    _base_params = [
+        "Sensor Bias Voltage",
+        "Memory cells",
+        "Pixels X",
+        "Pixels Y",
+        "Feedback capacitor",
+    ]
+    _dark_parameters = _base_params + [
+        "Memory cell order",
+    ]
+    _illuminated_parameters = _base_params + ["Source Energy", "category"]
+
+    calibration_types = {
+        "Offset": _dark_parameters,
+        "Noise": _dark_parameters,
+        "BadPixelsDark": _dark_parameters,
+        "RelativeGain": _illuminated_parameters,
+        "GainAmpMap": _illuminated_parameters,
+        "FFMap": _illuminated_parameters,
+        "BadPixelsFF": _illuminated_parameters,
+    }
+
+
+@dataclass
+class DSSCConditions(ConditionsBase):
+    sensor_bias_voltage: float
+    memory_cells: int
+    pulse_id_checksum: Optional[float] = None
+    acquisition_rate: Optional[float] = None
+    target_gain: Optional[int] = None
+    encoded_gain: Optional[int] = None
+    pixels_x: int = 512
+    pixels_y: int = 128
+
+    _params = [
+        "Sensor Bias Voltage",
+        "Memory cells",
+        "Pixels X",
+        "Pixels Y",
+        "Pulse id checksum",
+        "Acquisition rate",
+        "Target gain",
+        "Encoded gain",
+    ]
+    calibration_types = {
+        "Offset": _params,
+        "Noise": _params,
+    }
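Putting the new API together, a sketch of the intended workflow (matching the tests below):

    dssc_cd = CalibrationData.from_condition(
        DSSCConditions(sensor_bias_voltage=100, memory_cells=600),
        "SQS_DET_DSSC1M-1",
        event_at="2023-11-29 00:00:00",
    )
    dark = dssc_cd.require_calibrations(["Offset", "Noise"])  # drop incomplete modules
    q3 = dark.select_modules(qm_names=[f"Q3M{i}" for i in range(1, 5)])
    offsets = q3["Offset"].ndarray()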
diff --git a/tests/test_calcat_interface2.py b/tests/test_calcat_interface2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca84e3a5ccf162185f3bde549566adfd4c1906c8
--- /dev/null
+++ b/tests/test_calcat_interface2.py
@@ -0,0 +1,187 @@
+import numpy as np
+import pytest
+import xarray as xr
+
+from cal_tools.calcat_interface2 import (
+    AGIPDConditions,
+    CalibrationData,
+    DSSCConditions,
+    LPDConditions,
+    SingleConstant,
+)
+
+
+@pytest.mark.requires_gpfs
+def test_AGIPD_CalibrationData_metadata():
+    """Test CalibrationData with AGIPD condition"""
+    cond = AGIPDConditions(
+        # From: https://in.xfel.eu/calibration/calibration_constants/5754#condition
+        sensor_bias_voltage=300,  # V
+        memory_cells=352,
+        acquisition_rate=2.2,  # MHz
+        gain_mode=0,
+        gain_setting=1,
+        integration_time=12,
+        source_energy=9.2,
+    )
+    agipd_cd = CalibrationData.from_condition(
+        cond,
+        "MID_DET_AGIPD1M-1",
+        event_at="2022-09-01 13:26:48.00",
+        calibrations=["Offset", "SlopesFF"],
+    )
+    assert agipd_cd.detector_name == "MID_DET_AGIPD1M-1"
+    assert "Offset" in agipd_cd
+    assert set(agipd_cd["Offset"].constants) == {f"AGIPD{m:02}" for m in range(16)}
+    assert isinstance(agipd_cd["Offset", "AGIPD00"], SingleConstant)
+    assert agipd_cd["Offset", "Q1M2"] == agipd_cd["Offset", "AGIPD01"]
+
+
+@pytest.mark.requires_gpfs
+def test_AGIPD_merge():
+    cond = AGIPDConditions(
+        # From: https://in.xfel.eu/calibration/calibration_constants/5754#condition
+        sensor_bias_voltage=300,  # V
+        memory_cells=352,
+        acquisition_rate=2.2,  # MHz
+        gain_mode=0,
+        gain_setting=1,
+        integration_time=12,
+        source_energy=9.2,
+    )
+    agipd_cd = CalibrationData.from_condition(
+        cond,
+        "MID_DET_AGIPD1M-1",
+        event_at="2022-09-01 13:26:48.00",
+        calibrations=["Offset", "SlopesFF"],
+    )
+
+    modnos_q1 = list(range(0, 4))
+    modnos_q4 = list(range(12, 16))
+    merged = agipd_cd.select_modules(modnos_q1).merge(
+        agipd_cd.select_modules(modnos_q4)
+    )
+    assert merged.module_nums == modnos_q1 + modnos_q4
+
+    offset_only = agipd_cd.select_calibrations(["Offset"])
+    slopes_only = agipd_cd.select_calibrations(["SlopesFF"])
+    assert set(offset_only) == {"Offset"}
+    assert set(slopes_only) == {"SlopesFF"}
+    merged_cals = offset_only.merge(slopes_only)
+    assert set(merged_cals) == {"Offset", "SlopesFF"}
+    assert merged_cals.module_nums == list(range(16))
+
+
+@pytest.mark.requires_gpfs
+def test_AGIPD_CalibrationData_metadata_SPB():
+    """Test CalibrationData with AGIPD condition"""
+    cond = AGIPDConditions(
+        sensor_bias_voltage=300,
+        memory_cells=352,
+        acquisition_rate=1.1,
+        integration_time=12,
+        source_energy=9.2,
+        gain_mode=0,
+        gain_setting=0,
+    )
+    agipd_cd = CalibrationData.from_condition(
+        cond,
+        "SPB_DET_AGIPD1M-1",
+        event_at="2020-01-07 13:26:48.00",
+    )
+    assert "Offset" in agipd_cd
+    assert set(agipd_cd["Offset"].constants) == {f"AGIPD{m:02}" for m in range(16)}
+    assert agipd_cd["Offset"].module_nums == list(range(16))
+    assert agipd_cd["Offset"].qm_names == [
+        f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)
+    ]
+    assert isinstance(agipd_cd["Offset", 0], SingleConstant)
+
+
+@pytest.mark.requires_gpfs
+def test_AGIPD_load_data():
+    cond = AGIPDConditions(
+        sensor_bias_voltage=300,
+        memory_cells=352,
+        acquisition_rate=1.1,
+        integration_time=12,
+        source_energy=9.2,
+        gain_mode=0,
+        gain_setting=0,
+    )
+    agipd_cd = CalibrationData.from_condition(
+        cond,
+        "SPB_DET_AGIPD1M-1",
+        event_at="2020-01-07 13:26:48.00",
+    )
+    arr = agipd_cd["Offset"].select_modules(list(range(4))).xarray()
+    assert arr.shape == (4, 128, 512, 352, 3)
+    assert arr.dims[0] == "module"
+    np.testing.assert_array_equal(arr.coords["module"], np.arange(0, 4))
+    assert arr.dtype == np.float64
+
+    # Load parallel
+    arr_p = agipd_cd["Offset"].select_modules(list(range(4))).xarray(parallel=4)
+    xr.testing.assert_identical(arr_p, arr)
+
+
+@pytest.mark.requires_gpfs
+def test_DSSC_modules_missing():
+    dssc_cd = CalibrationData.from_condition(
+        DSSCConditions(sensor_bias_voltage=100, memory_cells=600),
+        "SQS_DET_DSSC1M-1",
+        event_at="2023-11-29 00:00:00",
+    )
+    # DSSC was used with only 3 quadrants at this point
+    modnos = list(range(4)) + list(range(8, 16))
+    assert dssc_cd.aggregator_names == [f"DSSC{m:02}" for m in modnos]
+    assert dssc_cd.module_nums == modnos
+    assert dssc_cd.qm_names == [f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in modnos]
+
+    offset = dssc_cd["Offset"]
+    assert offset.module_nums == modnos
+
+    # test MultiModuleConstant.select_modules()
+    modnos_q3 = list(range(8, 12))
+    aggs_q3 = [f"DSSC{m:02}" for m in modnos_q3]
+    qm_q3 = [f"Q3M{i}" for i in range(1, 5)]
+    assert offset.select_modules(modnos_q3).module_nums == modnos_q3
+    assert offset.select_modules(aggregator_names=aggs_q3).module_nums == modnos_q3
+    assert offset.select_modules(qm_names=qm_q3).module_nums == modnos_q3
+
+    # test CalibrationData.select_modules()
+    assert dssc_cd.select_modules(modnos_q3).module_nums == modnos_q3
+    assert dssc_cd.select_modules(aggregator_names=aggs_q3).module_nums == modnos_q3
+    assert dssc_cd.select_modules(qm_names=qm_q3).module_nums == modnos_q3
+
+
+@pytest.mark.requires_gpfs
+def test_LPD_constant_missing():
+    lpd_cd = CalibrationData.from_condition(
+        LPDConditions(memory_cells=200, sensor_bias_voltage=250),
+        "FXE_DET_LPD1M-1",
+        event_at="2022-05-22T02:00:00",
+    )
+    # Constants are missing for 1 module (LPD05), but it was still included in
+    # the PDUs for the detector, so it should still appear in the lists.
+    assert lpd_cd.aggregator_names == [f"LPD{m:02}" for m in range(16)]
+    assert lpd_cd.module_nums == list(range(16))
+    assert lpd_cd.qm_names == [f"Q{(m // 4) + 1}M{(m % 4) + 1}" for m in range(16)]
+
+    # When we look at a specific constant, module LPD05 is missing
+    modnos_w_constant = list(range(0, 5)) + list(range(6, 16))
+    assert lpd_cd["Offset"].module_nums == modnos_w_constant
+
+    # Test CalibrationData.require_calibrations()
+    assert lpd_cd.require_calibrations(["Offset"]).module_nums == modnos_w_constant
+
+
+@pytest.mark.requires_gpfs
+def test_AGIPD_CalibrationData_report():
+    """Test CalibrationData with data from report"""
+    # Report ID: https://in.xfel.eu/calibration/reports/3757
+    agipd_cd = CalibrationData.from_report(3757)
+    assert agipd_cd.detector_name == "SPB_DET_AGIPD1M-1"
+    assert set(agipd_cd) == {"Offset", "Noise", "ThresholdsDark", "BadPixelsDark"}
+    assert agipd_cd.aggregator_names == [f"AGIPD{n:02}" for n in range(16)]
+    assert isinstance(agipd_cd["Offset", "AGIPD00"], SingleConstant)
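A final sketch of the report-based lookup exercised by the last test; loading the actual data still needs access to the caldb_store folder:

    agipd_cd = CalibrationData.from_report(3757)
    const = agipd_cd["Offset", "AGIPD00"]
    info = const.metadata_dict()  # CCV metadata as returned by CalCat
    data = const.ndarray()        # reads from the constant store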