diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index ff4b893110bd278d901d4da54f2481ea1a27fd94..c8f786e8c9719e8500464da656486251c2360c13 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -1,17 +1,17 @@ +import re from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Dict, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union import h5py import numpy as np import pasha as psh from calibration_client import CalibrationClient -from calibration_client.modules import CalibrationConstantVersion +from calibration_client.modules import CalibrationConstantVersion, PhysicalDetectorUnit from .calcat_interface import CalCatApi, CalCatError -from .tools import module_index_to_qm class ModuleNameError(KeyError): @@ -117,22 +117,52 @@ class ModulesConstantVersions: """A group of similar CCVs for several modules of one detector""" constants: Dict[str, SingleConstantVersion] # Keys e.g. 'LPD00' + module_details: List[Dict] + + def select_modules( + self, module_nums=None, *, aggregators=None, qm_names=None + ) -> "ModulesConstantVersions": + n_specified = sum([ + module_nums is not None, + aggregators is not None, + qm_names is not None + ]) + if n_specified < 1: + raise TypeError("select_modules() requires an argument") + elif n_specified > 1: + raise TypeError( + "select_modules() accepts only one of module_nums, aggregators & qm_names" + ) + + if module_nums is not None: + by_mod_no = {m['module_number']: m for m in self.module_details} + aggregators = [by_mod_no[n]['karabo_da'] for n in module_nums] + elif qm_names is not None: + by_qm = {m['virtual_device_name']: m for m in self.module_details} + aggregators = [by_qm[s]['karabo_da'] for s in qm_names] + elif aggregators is not None: + miss = set(aggregators) - {m['karabo_da'] for m in self.module_details} + if miss: + raise KeyError("Aggregators not found: " + ', '.join(sorted(miss))) - def select_modules(self, aggregators) -> "ModulesConstantVersions": d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggregators} - return ModulesConstantVersions(d) + return ModulesConstantVersions(d, self.module_details) + # These properties label only the modules we have constants for, which may + # be a subset of what's in module_details @property def aggregators(self): return sorted(self.constants) @property def module_nums(self): - return [int(da[-2:]) for da in self.aggregators] + return [m['module_number'] for m in self.module_details + if m['karabo_da'] in self.constants] @property def qm_names(self): - return [module_index_to_qm(n) for n in self.module_nums] + return [m['virtual_device_name'] for m in self.module_details + if m['karabo_da'] in self.constants] def ndarray(self, caldb_root=None): eg_dset = self.constants[self.aggregators[0]].dataset_obj(caldb_root) @@ -168,19 +198,18 @@ class ModulesConstantVersions: class CalibrationData(Mapping): """Collected constants for a given detector""" - def __init__(self, constant_groups, aggregators): + def __init__(self, constant_groups, module_details): self.constant_groups = { const_type: ModulesConstantVersions(d) for const_type, d in constant_groups.items() } - self.aggregators = aggregators + self.module_details = module_details @classmethod def from_condition( cls, condition: "ConditionsBase", detector_name, - modules: Optional[Sequence[str]] = None, calibrations=None, client=None, event_at=None, @@ -199,14 +228,15 @@ class CalibrationData(Mapping): api = CalCatApi(client or get_client()) detector_id = api.detector(detector_name)["id"] - all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at) - if modules is None: - modules = sorted(all_modules) - else: - modules = sorted(modules) - for m in modules: - if m not in all_modules: - raise ModuleNameError(m) + resp_pdus = PhysicalDetectorUnit.get_all_by_detector( + api.client, detector_id, api.format_time(pdu_snapshot_at) + ) + if not resp_pdus["success"]: + raise CalCatError(resp_pdus) + module_details = sorted(resp_pdus["data"], key=lambda d: d['karabo_da']) + for mod in module_details: + if mod.get('module_number', -1) < 0: + mod['module_number'] = int(re.findall(r'\d+', mod['karabo_da'])[-1]) d = {} @@ -223,7 +253,7 @@ class CalibrationData(Mapping): detector_name, calibration_ids, condition_dict, - modules[0] if len(modules) == 1 else "", + "", event_at, pdu_snapshot_at or event_at, ) @@ -236,10 +266,7 @@ class CalibrationData(Mapping): ccv ) - res = cls(d, modules) - if modules: - res = res.select_modules(modules) - return res + return cls(d, module_details) @classmethod def from_report( @@ -286,15 +313,26 @@ class CalibrationData(Mapping): def __repr__(self): return (f"<CalibrationData: {', '.join(sorted(self.constant_groups))} " - f"constants for {len(self.aggregators)} modules>") + f"constants for {len(self.module_details)} modules>") + # These properties may include modules for which we have no constants - + # when created with .from_condition(), they represent all modules present in + # the detector (at the specified time). @property def module_nums(self): - return [int(da[-2:]) for da in self.aggregators] + return [m['module_number'] for m in self.module_details] + + @property + def aggregator_names(self): + return [m['karabo_da'] for m in self.module_details] @property def qm_names(self): - return [module_index_to_qm(n) for n in self.module_nums] + return [m['virtual_device_name'] for m in self.module_details] + + @property + def pdu_names(self): + return [m['physical_name'] for m in self.module_details] def require_calibrations(self, calibrations): """Drop any modules missing the specified constant types""" @@ -317,10 +355,16 @@ class CalibrationData(Mapping): # TODO: missing for some modules? return type(self)(d, self.aggregators) - def select_modules(self, aggregators): + def select_modules( + self, module_nums=None, *, aggregators=None, qm_names=None + ) -> "CalibrationData": return type(self)( { - cal_type: mcv.select_modules(aggregators).constants + cal_type: mcv.select_modules( + module_nums=module_nums, + aggregators=aggregators, + qm_names=qm_names, + ).constants for (cal_type, mcv) in self.constant_groups.items() }, sorted(aggregators),