from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion

from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm

# Shared client, created lazily by get_client() or explicitly by setup_client().
global_client = None


def get_client():
    """Return the module-wide CalibrationClient, creating a default one if needed.

    If setup_client() has not been called yet, a client without OAuth2
    credentials is configured against ``http://exflcalproxy:8080``.
    """
    global global_client
    if global_client is None:
        setup_client("http://exflcalproxy:8080/", None, None, None)
    return global_client


def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
    """Create and store the module-wide CalibrationClient.

    OAuth2 is enabled only when ``client_id`` is not None. The API and OAuth
    endpoint URLs are derived from ``base_url``; a trailing slash on
    ``base_url`` is ignored, so ``"http://host/"`` and ``"http://host"`` are
    equivalent. Extra keyword arguments are passed to ``CalibrationClient``.
    """
    global global_client
    # Strip a trailing "/" so the joins below don't produce "//api/" etc.
    # (the default URL used by get_client() ends with "/").
    base_url = base_url.rstrip("/")
    global_client = CalibrationClient(
        use_oauth2=(client_id is not None),
        client_id=client_id,
        client_secret=client_secret,
        user_email=user_email,
        base_api_url=f"{base_url}/api/",
        token_url=f"{base_url}/oauth/token",
        refresh_url=f"{base_url}/oauth/token",
        auth_url=f"{base_url}/oauth/authorize",
        scope="",
        **kwargs,
    )


@dataclass
class SingleConstantVersion:
    """A Calibration Constant Version for 1 detector module"""

    id: int
    version_name: str
    constant_id: int
    constant_name: str
    condition_id: int
    path: Path  # path_to_file / file_name from the CCV record
    dataset: str
    # NOTE(review): from_response stores the raw response values in the two
    # fields below without parsing them, despite the datetime annotation —
    # confirm what callers expect.
    begin_validity_at: datetime
    end_validity_at: datetime
    raw_data_location: str
    physical_name: str  # PDU name

    @classmethod
    def from_response(cls, ccv: dict) -> "SingleConstantVersion":
        """Build an instance from one CCV record returned by CalCat.

        ``ccv`` must contain a nested ``calibration_constant`` dict and a
        nested ``physical_detector_unit`` dict, as read below.
        """
        const = ccv["calibration_constant"]
        return cls(
            id=ccv["id"],
            version_name=ccv["name"],
            constant_id=const["id"],
            constant_name=const["name"],
            condition_id=const["condition_id"],
            path=Path(ccv["path_to_file"]) / ccv["file_name"],
            dataset=ccv["data_set_name"],
            begin_validity_at=ccv["begin_validity_at"],
            end_validity_at=ccv["end_validity_at"],
            raw_data_location=ccv["raw_data_location"],
            physical_name=ccv["physical_detector_unit"]["physical_name"],
        )


@dataclass
class ModulesConstantVersions:
    """A group of similar CCVs for several modules of one detector"""

    constants: Dict[str, SingleConstantVersion]  # Keys e.g. 'LPD00'

    def select_modules(self, *aggregators) -> "ModulesConstantVersions":
        """Return a new group containing only the given aggregator names."""
        d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators}
        return ModulesConstantVersions(d)

    @property
    def aggregators(self):
        """Sorted aggregator (karabo_da) names present in this group."""
        return sorted(self.constants)

    @property
    def module_nums(self):
        """Module numbers, taken from the last two characters of each aggregator."""
        return [int(da[-2:]) for da in self.aggregators]

    @property
    def qm_names(self):
        """Module names as produced by module_index_to_qm for each module number."""
        return [module_index_to_qm(n) for n in self.module_nums]


class CalibrationData(Mapping):
    """Collected constants for a given detector

    Maps calibration type name (e.g. "Offset") to a ModulesConstantVersions.
    """

    def __init__(self, constant_groups: Dict[str, Dict[str, SingleConstantVersion]]):
        self.constant_groups = {
            const_type: ModulesConstantVersions(d)
            for const_type, d in constant_groups.items()
        }

    @classmethod
    def from_condition(
        cls,
        condition: "ConditionsBase",
        detector_name,
        modules: Optional[Sequence[str]] = None,
        calibrations=None,
        client=None,
        event_at=None,
        pdu_snapshot_at=None,
    ):
        """Look up the constants matching ``condition`` for a detector.

        Args:
            condition: Operating conditions; its ``calibration_types`` says
                which parameters each calibration type is matched on.
            detector_name: Detector identifier passed to the CalCat query.
            modules: Optional aggregator names to keep. If exactly one is
                given it is also passed to the query; in all cases the
                result is filtered to these modules.
            calibrations: Calibration types to fetch; defaults to all types in
                ``condition.calibration_types``. Types not listed there are
                ignored.
            client: CalibrationClient to use; defaults to get_client().
            event_at: Passed to the closest-by-time query.
            pdu_snapshot_at: Passed to the query; defaults to ``event_at``.
        """
        if calibrations is None:
            calibrations = set(condition.calibration_types)

        # Calibration types that are matched on the same parameters can be
        # fetched together with a single query.
        cal_types_by_params_used = {}
        for cal_type, params in condition.calibration_types.items():
            if cal_type in calibrations:
                cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)

        api = CalCatApi(client or get_client())

        # With exactly one module requested, let the query filter by it;
        # otherwise ("" ) query all modules and filter locally below.
        karabo_da = modules[0] if modules and len(modules) == 1 else ""

        d = {}
        for params, cal_types in cal_types_by_params_used.items():
            condition_dict = condition.make_dict(params)

            # Only request the types that use *these* parameters, and map
            # the calibration IDs in the response back to type names.
            cal_id_map = {
                api.calibration_id(cal_type): cal_type for cal_type in cal_types
            }
            query_res = api._closest_ccv_by_time_by_condition(
                detector_name,
                list(cal_id_map),
                condition_dict,
                karabo_da,
                event_at,
                pdu_snapshot_at or event_at,
            )

            for ccv in query_res:
                aggr = ccv["physical_detector_unit"]["karabo_da"]
                cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]
                d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(
                    ccv
                )

        res = cls(d)
        if modules:
            # select_modules takes names as separate arguments.
            res = res.select_modules(*modules)
        return res

    @classmethod
    def from_report(
        cls,
        report_id_or_path: Union[int, str],
        client=None,
    ):
        """Load the constants recorded in a calibration report.

        An int is treated as a report ID, anything else as a report path.

        Raises:
            CalCatError: if the response does not report success.
        """
        client = client or get_client()
        api = CalCatApi(client)

        if isinstance(report_id_or_path, int):
            resp = CalibrationConstantVersion.get_by_report_id(
                client, report_id_or_path
            )
        else:
            resp = CalibrationConstantVersion.get_by_report_path(
                client, report_id_or_path
            )

        if not resp["success"]:
            raise CalCatError(resp)

        d = {}
        for ccv in resp["data"]:
            aggr = ccv["physical_detector_unit"]["karabo_da"]
            cal_type = api.calibration_name(
                ccv["calibration_constant"]["calibration_id"]
            )
            d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)

        return cls(d)

    def __getitem__(self, key) -> ModulesConstantVersions:
        return self.constant_groups[key]

    def __iter__(self):
        return iter(self.constant_groups)

    def __len__(self):
        return len(self.constant_groups)

    @property
    def aggregators(self):
        """Sorted aggregator names present for any calibration type."""
        names = set()
        for mcv in self.constant_groups.values():
            names.update(mcv.aggregators)
        return sorted(names)

    @property
    def module_nums(self):
        """Module numbers, taken from the last two characters of each aggregator."""
        return [int(da[-2:]) for da in self.aggregators]

    @property
    def qm_names(self):
        """Module names as produced by module_index_to_qm for each module number."""
        return [module_index_to_qm(n) for n in self.module_nums]

    def require_calibrations(self, *calibrations):
        """Drop any modules missing the specified constant types

        Raises KeyError if a requested calibration type is absent entirely.
        """
        mods = set(self.aggregators)
        for cal_type in calibrations:
            mods.intersection_update(self[cal_type].constants)
        # select_modules takes names as separate arguments.
        return self.select_modules(*mods)

    def select_calibrations(self, *calibrations, require_all=True):
        """Return a new object with only the given calibration types.

        Raises:
            KeyError: if ``require_all`` is true and any requested type is
                not present.
        """
        if require_all:
            missing = set(calibrations) - set(self.constant_groups)
            if missing:
                raise KeyError(f"Missing calibrations: {', '.join(sorted(missing))}")

        d = {
            cal_type: mcv.constants
            for (cal_type, mcv) in self.constant_groups.items()
            if cal_type in calibrations
        }
        # TODO: missing for some modules?
        return type(self)(d)

    def select_modules(self, *aggregators):
        """Return a new object with only the given aggregator names."""
        return type(self)(
            {
                cal_type: mcv.select_modules(*aggregators).constants
                for (cal_type, mcv) in self.constant_groups.items()
            }
        )

    def merge(self, *others: "CalibrationData") -> "CalibrationData":
        """Combine constants from this object and ``others``.

        Calibration types present in any input are kept. When the same type
        and module appear in several inputs, the later argument wins.
        """
        d = {
            cal_type: mcv.constants.copy()
            for cal_type, mcv in self.constant_groups.items()
        }
        for other in others:
            for cal_type, mcv in other.items():
                d.setdefault(cal_type, {}).update(mcv.constants)
        return type(self)(d)


class ConditionsBase:
    # For subclasses: {calibration: [parameter names]}
    calibration_types = {}

    def make_dict(self, parameters) -> dict:
        """Map database parameter names to this object's non-None values.

        Each name is turned into an attribute name by lower-casing it and
        replacing spaces with underscores (e.g. "Pixels X" -> pixels_x).
        """
        d = dict()

        for db_name in parameters:
            value = getattr(self, db_name.lower().replace(" ", "_"))
            if value is not None:
                d[db_name] = value

        return d


@dataclass
class AGIPDConditions(ConditionsBase):
    """Operating conditions used to select AGIPD constants."""

    sensor_bias_voltage: float
    memory_cells: int
    acquisition_rate: float
    gain_setting: Optional[int]
    gain_mode: Optional[int]
    source_energy: float
    integration_time: int = 12
    pixels_x: int = 512
    pixels_y: int = 128

    _gain_parameters = [
        "Sensor Bias Voltage",
        "Pixels X",
        "Pixels Y",
        "Memory cells",
        "Acquisition rate",
        "Gain setting",
        "Integration time",
    ]
    _other_dark_parameters = _gain_parameters + ["Gain mode"]
    _illuminated_parameters = _gain_parameters + ["Source energy"]

    calibration_types = {
        "Offset": _other_dark_parameters,
        "Noise": _other_dark_parameters,
        "ThresholdsDark": _other_dark_parameters,
        "BadPixelsDark": _other_dark_parameters,
        "BadPixelsPC": _gain_parameters,
        "SlopesPC": _gain_parameters,
        "BadPixelsFF": _illuminated_parameters,
        "SlopesFF": _illuminated_parameters,
    }

    def make_dict(self, parameters):
        """Like ConditionsBase.make_dict, minus some default-valued parameters."""
        cond = super().make_dict(parameters)

        # Fix-up some database quirks.
        # "Gain mode" == 0 and "Integration time" == 12 are dropped from the
        # condition rather than sent.
        if int(cond.get("Gain mode", -1)) == 0:
            del cond["Gain mode"]

        if int(cond.get("Integration time", -1)) == 12:
            del cond["Integration time"]

        return cond


@dataclass
class LPDConditions(ConditionsBase):
    """Operating conditions used to select LPD constants."""

    sensor_bias_voltage: float
    memory_cells: int
    memory_cell_order: Optional[str] = None
    feedback_capacitor: float = 5.0
    source_energy: float = 9.2
    category: int = 1
    pixels_x: int = 256
    pixels_y: int = 256

    _base_params = [
        "Sensor Bias Voltage",
        "Memory cells",
        "Pixels X",
        "Pixels Y",
        "Feedback capacitor",
    ]
    _dark_parameters = _base_params + [
        "Memory cell order",
    ]
    _illuminated_parameters = _base_params + ["Source Energy", "category"]

    calibration_types = {
        "Offset": _dark_parameters,
        "Noise": _dark_parameters,
        "BadPixelsDark": _dark_parameters,
        "RelativeGain": _illuminated_parameters,
        "GainAmpMap": _illuminated_parameters,
        "FFMap": _illuminated_parameters,
        "BadPixelsFF": _illuminated_parameters,
    }