from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import h5py
import pasha as psh
from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion

from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm


class ModuleNameError(KeyError):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f"No module named {self.name!r}"


global_client = None


def get_client():
    global global_client
    if global_client is None:
        setup_client("http://exflcalproxy:8080/", None, None, None)
    return global_client


def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
    global global_client
    global_client = CalibrationClient(
        use_oauth2=(client_id is not None),
        client_id=client_id,
        client_secret=client_secret,
        user_email=user_email,
        base_api_url=f"{base_url}/api/",
        token_url=f"{base_url}/oauth/token",
        refresh_url=f"{base_url}/oauth/token",
        auth_url=f"{base_url}/oauth/authorize",
        scope="",
        **kwargs,
    )


_default_caldb_root = ...


def _get_default_caldb_root():
    global _default_caldb_root
    if _default_caldb_root is ...:
        onc_path = Path("/common/cal/caldb_store")
        maxwell_path = Path("/gpfs/exfel/d/cal/caldb_store")

        if onc_path.is_dir():
            _default_caldb_root = onc_path
        elif maxwell_path.is_dir():
            _default_caldb_root = maxwell_path
        else:
            _default_caldb_root = None

    return _default_caldb_root


@dataclass
class SingleConstantVersion:
    """A Calibration Constant Version for 1 detector module"""

    id: int
    version_name: str
    constant_id: int
    constant_name: str
    condition_id: int
    path: Path
    dataset: str
    begin_validity_at: datetime
    end_validity_at: datetime
    raw_data_location: str
    physical_name: str  # PDU name

    @classmethod
    def from_response(cls, ccv: dict) -> "SingleConstantVersion":
        const = ccv["calibration_constant"]
        return cls(
            id=ccv["id"],
            version_name=ccv["name"],
            constant_id=const["id"],
            constant_name=const["name"],
            condition_id=const["condition_id"],
            path=Path(ccv["path_to_file"]) / ccv["file_name"],
            dataset=ccv["data_set_name"],
            begin_validity_at=ccv["begin_validity_at"],
            end_validity_at=ccv["end_validity_at"],
            raw_data_location=ccv["raw_data_location"],
            physical_name=ccv["physical_detector_unit"]["physical_name"],
        )

    def dataset_obj(self, caldb_root=None):
        if caldb_root is not None:
            caldb_root = Path(caldb_root)
        else:
            caldb_root = _get_default_caldb_root()

        f = h5py.File(caldb_root / self.path, "r")
        return f[self.dataset]["data"]

    def ndarray(self, caldb_root=None):
        return self.dataset_obj(caldb_root)[:]


@dataclass
class ModulesConstantVersions:
    """A group of similar CCVs for several modules of one detector"""

    constants: Dict[str, SingleConstantVersion]  # Keys e.g. 'LPD00'
    def select_modules(self, aggregators) -> "ModulesConstantVersions":
        d = {
            aggr: scv
            for (aggr, scv) in self.constants.items()
            if aggr in aggregators
        }
        return ModulesConstantVersions(d)

    @property
    def aggregators(self):
        return sorted(self.constants)

    @property
    def module_nums(self):
        return [int(da[-2:]) for da in self.aggregators]

    @property
    def qm_names(self):
        return [module_index_to_qm(n) for n in self.module_nums]


class CalibrationData(Mapping):
    """Collected constants for a given detector"""

    def __init__(self, constant_groups, aggregators):
        self.constant_groups = {
            const_type: ModulesConstantVersions(d)
            for const_type, d in constant_groups.items()
        }
        self.aggregators = aggregators

    @classmethod
    def from_condition(
        cls,
        condition: "ConditionsBase",
        detector_name,
        modules: Optional[Sequence[str]] = None,
        calibrations=None,
        client=None,
        event_at=None,
        pdu_snapshot_at=None,
    ):
        if calibrations is None:
            calibrations = set(condition.calibration_types)
        if pdu_snapshot_at is None:
            pdu_snapshot_at = event_at

        cal_types_by_params_used = {}
        for cal_type, params in condition.calibration_types.items():
            if cal_type in calibrations:
                cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)

        api = CalCatApi(client or get_client())

        detector_id = api.detector(detector_name)["id"]
        all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
        if modules is None:
            modules = sorted(all_modules)
        else:
            modules = sorted(modules)
            for m in modules:
                if m not in all_modules:
                    raise ModuleNameError(m)

        d = {}

        for params, cal_types in cal_types_by_params_used.items():
            condition_dict = condition.make_dict(params)

            cal_id_map = {
                api.calibration_id(calibration): calibration
                for calibration in cal_types
            }
            calibration_ids = list(cal_id_map.keys())

            query_res = api._closest_ccv_by_time_by_condition(
                detector_name,
                calibration_ids,
                condition_dict,
                modules[0] if len(modules) == 1 else "",
                event_at,
                pdu_snapshot_at or event_at,
            )

            for ccv in query_res:
                aggr = ccv["physical_detector_unit"]["karabo_da"]
                cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]

                d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(
                    ccv
                )

        res = cls(d, modules)
        if modules:
            res = res.select_modules(modules)
        return res

    @classmethod
    def from_report(
        cls,
        report_id_or_path: Union[int, str],
        client=None,
    ):
        client = client or get_client()
        api = CalCatApi(client)

        if isinstance(report_id_or_path, int):
            resp = CalibrationConstantVersion.get_by_report_id(
                client, report_id_or_path
            )
        else:
            resp = CalibrationConstantVersion.get_by_report_path(
                client, report_id_or_path
            )

        if not resp["success"]:
            raise CalCatError(resp)

        d = {}
        aggregators = set()

        for ccv in resp["data"]:
            aggr = ccv["physical_detector_unit"]["karabo_da"]
            aggregators.add(aggr)
            cal_type = api.calibration_name(
                ccv["calibration_constant"]["calibration_id"]
            )

            d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)

        return cls(d, sorted(aggregators))

    def __getitem__(self, key) -> ModulesConstantVersions:
        return self.constant_groups[key]

    def __iter__(self):
        return iter(self.constant_groups)

    def __len__(self):
        return len(self.constant_groups)

    @property
    def module_nums(self):
        return [int(da[-2:]) for da in self.aggregators]

    @property
    def qm_names(self):
        return [module_index_to_qm(n) for n in self.module_nums]

    def require_calibrations(self, calibrations):
        """Drop any modules missing the specified constant types"""
        mods = set(self.aggregators)
        for cal_type in calibrations:
            mods.intersection_update(self[cal_type].constants)
        return self.select_modules(mods)
    def select_calibrations(self, calibrations, require_all=True):
        if require_all:
            missing = set(calibrations) - set(self.constant_groups)
            if missing:
                raise KeyError(f"Missing calibrations: {', '.join(sorted(missing))}")

        d = {
            cal_type: mcv.constants
            for (cal_type, mcv) in self.constant_groups.items()
            if cal_type in calibrations
        }
        # TODO: missing for some modules?
        return type(self)(d, self.aggregators)

    def select_modules(self, aggregators):
        return type(self)(
            {
                cal_type: mcv.select_modules(aggregators).constants
                for (cal_type, mcv) in self.constant_groups.items()
            },
            sorted(aggregators),
        )

    def merge(self, *others: "CalibrationData") -> "CalibrationData":
        d = {}
        for cal_type, mcv in self.constant_groups.items():
            d[cal_type] = mcv.constants.copy()
            for other in others:
                if cal_type in other:
                    d[cal_type].update(other[cal_type].constants)

        aggregators = set(self.aggregators)
        for other in others:
            aggregators.update(other.aggregators)

        return type(self)(d, sorted(aggregators))

    def load_all(self, caldb_root=None):
        res = {}

        const_load_mp = psh.ProcessContext(num_workers=24)
        keys = []
        for cal_type, mcv in self.constant_groups.items():
            res[cal_type] = {}
            for module in mcv.aggregators:
                dset = mcv.constants[module].dataset_obj(caldb_root)
                res[cal_type][module] = const_load_mp.alloc(
                    shape=dset.shape, dtype=dset.dtype
                )
                keys.append((cal_type, module))

        def _load_constant_dataset(wid, index, key):
            cal_type, mod = key
            dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
            dset.read_direct(res[cal_type][mod])

        const_load_mp.map(_load_constant_dataset, keys)

        return res


class ConditionsBase:
    calibration_types = {}  # For subclasses: {calibration: [parameter names]}

    def make_dict(self, parameters) -> dict:
        d = dict()

        for db_name in parameters:
            value = getattr(self, db_name.lower().replace(" ", "_"))
            if value is not None:
                d[db_name] = value

        return d


@dataclass
class AGIPDConditions(ConditionsBase):
    sensor_bias_voltage: float
    memory_cells: int
    acquisition_rate: float
    gain_setting: Optional[int]
    gain_mode: Optional[int]
    source_energy: float
    integration_time: int = 12
    pixels_x: int = 512
    pixels_y: int = 128

    _gain_parameters = [
        "Sensor Bias Voltage",
        "Pixels X",
        "Pixels Y",
        "Memory cells",
        "Acquisition rate",
        "Gain setting",
        "Integration time",
    ]
    _other_dark_parameters = _gain_parameters + ["Gain mode"]
    _illuminated_parameters = _gain_parameters + ["Source energy"]

    calibration_types = {
        "Offset": _other_dark_parameters,
        "Noise": _other_dark_parameters,
        "ThresholdsDark": _other_dark_parameters,
        "BadPixelsDark": _other_dark_parameters,
        "BadPixelsPC": _gain_parameters,
        "SlopesPC": _gain_parameters,
        "BadPixelsFF": _illuminated_parameters,
        "SlopesFF": _illuminated_parameters,
    }

    def make_dict(self, parameters):
        cond = super().make_dict(parameters)

        # Fix-up some database quirks.
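        # (Assumed rationale: adaptive gain mode (0) and the default integration
        # time (12) are recorded in CalCat by leaving the parameter out of the
        # condition, so they are dropped here to match the stored constants.)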
        if int(cond.get("Gain mode", -1)) == 0:
            del cond["Gain mode"]

        if int(cond.get("Integration time", -1)) == 12:
            del cond["Integration time"]

        return cond


@dataclass
class LPDConditions(ConditionsBase):
    sensor_bias_voltage: float
    memory_cells: int
    memory_cell_order: Optional[str] = None
    feedback_capacitor: float = 5.0
    source_energy: float = 9.2
    category: int = 1
    pixels_x: int = 256
    pixels_y: int = 256

    _base_params = [
        "Sensor Bias Voltage",
        "Memory cells",
        "Pixels X",
        "Pixels Y",
        "Feedback capacitor",
    ]
    _dark_parameters = _base_params + [
        "Memory cell order",
    ]
    _illuminated_parameters = _base_params + ["Source Energy", "category"]

    calibration_types = {
        "Offset": _dark_parameters,
        "Noise": _dark_parameters,
        "BadPixelsDark": _dark_parameters,
        "RelativeGain": _illuminated_parameters,
        "GainAmpMap": _illuminated_parameters,
        "FFMap": _illuminated_parameters,
        "BadPixelsFF": _illuminated_parameters,
    }
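
# Example usage (illustrative sketch only; the detector name, condition values
# and timestamp below are hypothetical placeholders, and a working CalCat
# connection is required):
#
#     cond = LPDConditions(sensor_bias_voltage=250.0, memory_cells=512)
#     caldata = CalibrationData.from_condition(
#         cond, "FXE_DET_LPD1M-1", event_at="2022-05-01T12:00:00"
#     )
#     dark = caldata.require_calibrations(["Offset", "Noise", "BadPixelsDark"])
#     offset = dark["Offset"].constants["LPD00"].ndarray()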