From cdde441db35625423883c4442d229d5aa526da40 Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Thu, 4 Jul 2024 15:09:02 +0200
Subject: [PATCH] feat: Introduce new module for injecting new calibration
 constant version

---
 src/cal_tools/inject.py         | 436 ++++++++++++++++++++++++++++++++
 src/cal_tools/restful_config.py |   2 +-
 2 files changed, 437 insertions(+), 1 deletion(-)
 create mode 100644 src/cal_tools/inject.py

diff --git a/src/cal_tools/inject.py b/src/cal_tools/inject.py
new file mode 100644
index 000000000..57ee8d9b7
--- /dev/null
+++ b/src/cal_tools/inject.py
@@ -0,0 +1,436 @@
+'''
+1- Set Condition
+    a. /parameters [GET]  [DONE]
+    b. /conditions/set_expected_condition [POST]
+2. Set CC
+    a. /calibrations [GET]
+    b. /detector_types [GET]
+    c. /calibration_constants/{id} [GET]
+    d. /calibration_constants [POST]
+3. Set Report
+    a. /reports  [GET]
+    b. /reports/set [POST]
+4. Set CCV
+    a. /physical_detector_units [GET]
+    b. /calibration_constant_versions [POST]
+'''
+import binascii
+import json
+import time
+from dataclasses import asdict, dataclass, field
+from datetime import datetime, timezone
+from hashlib import md5
+from pathlib import Path
+from shutil import copyfile
+from struct import pack, unpack
+from typing import List, Optional, Union
+from urllib.parse import urljoin
+from functools import lru_cache
+
+import h5py
+
+from cal_tools.calcat_interface2 import (
+    CalCatAPIClient,
+    CalCatAPIError,
+    _get_default_caldb_root,
+)
+from cal_tools.restful_config import extra_calibration_client
+import requests
+import pprint
+
+CALIBRATION_CONSTANT_VERSIONS = "calibration_constant_versions"
+
+
+class InjectAPIError(CalCatAPIError):
+    ...
+
+
+class CCVAlreadyInjectedError(InjectAPIError):
+    ...
+
+
+@dataclass
+class ParameterConditionAttribute:
+    parameter_id: int
+    value: str
+    lower_deviation_value: float = 0
+    upper_deviation_value: float = 0
+    flg_available: str = 'false'
+    description: str = ''
+
+
+@dataclass
+class Condition:
+    name: str
+    parameters_conditions_attributes: List[ParameterConditionAttribute] = field(default_factory=list)
+    flg_available: str = 'false'
+    event_at: str = str(datetime.today())  # TODO: Why is this needed? it is not written in swagger.
+    description: str = ''
+
+
+@dataclass
+class ConditionRequest:
+    condition: Condition
+
+
+@dataclass
+class CalibrationConstant:
+    name: str
+    detector_type_id: int
+    calibration_id: int
+    condition_id: int
+    flg_auto_approve: str = 'false'
+    flg_available: str = 'false'
+    description: str = ""
+
+
+@dataclass
+class Report:
+    name: str
+    file_path: str
+    flg_available: str = 'false'
+    description: str = ""
+
+
+@dataclass
+class CalibrationConstantVersion:
+    name: str
+    file_name: str
+    path_to_file: str
+    data_set_name: str
+    calibration_constant_id: int
+    physical_detector_unit_id: int
+    begin_validity_at: str
+    end_validity_at: str
+    begin_at: str
+    start_idx: str = '0'
+    end_idx: str = '0'
+    flg_deployed: str = 'true'
+    flg_good_quality: str = 'true'
+    raw_data_location: str = ''
+    report_id: Optional[int] = None
+    description: str = ''
+
+
+def generate_unique_name(detector_type, pdu_name, pdu_uuid, cond_params):
+    # Generate condition name.
+    unique_name = detector_type[:detector_type.index('-Type')] + ' Def'
+    cond_hash = md5(pdu_name.encode())
+    cond_hash.update(int(pdu_uuid).to_bytes(
+        length=8, byteorder='little', signed=False))
+
+    for pname, pattrs in cond_params.items():
+        cond_hash.update(pname.encode())
+        cond_hash.update(str(pattrs.value).encode())
+
+    unique_name += binascii.b2a_base64(cond_hash.digest()).decode()
+    return unique_name[:60]
+
+
+def create_unique_cc_name(det_type, calibration, condition_name):
+    """
+    Generating CC name from condition name,
+    detector type, and calibration name.
+    """
+    cc_name_1 = f'{det_type}_{calibration}'
+    return f'{cc_name_1[:40]}_{condition_name}'  # I guess there is a limit to the name characters?
+
+
+def create_unique_ccv_name(start_idx):
+    # Generate unique name if it doesn't exist
+    datetime_str = datetime_str = datetime.now(
+        timezone.utc).strftime('%Y%m%d_%H%M%S')
+    return f'{datetime_str}_sIdx={start_idx}'
+
+
+def get_raw_data_location(proposal, runs):
+    if proposal and len(runs) > 0:
+        return (
+            f'proposal:{proposal} runs: {" ".join([str(x) for x in runs])}')
+    else:
+        return ""  # Fallback for non-run based constants
+
+
+def _failed_response(resp):
+    # TODO: Add more variant errors.
+    if CALIBRATION_CONSTANT_VERSIONS in resp.url.lstrip("/"):
+        if resp.status_code == 422:
+            raise CCVAlreadyInjectedError
+            warn(
+                "Calibration Constant Version has already been injected: "
+                f"{dir(resp)}\n{resp.reason}-{resp.status_code}-{resp.text}-{resp.url}", CCVAlreadyInjectedWarning)
+    elif resp.status_code >= 400:
+        try:
+            d = json.loads(resp.content.decode("utf-8"))
+        except Exception:
+            resp.raise_for_status()
+        else:
+            raise InjectAPIError(
+                f"Error {resp.status_code} from API: "
+                f"{d.get('info', 'missing details')}"
+            )
+
+
+class InjectAPI(CalCatAPIClient):
+    def __init__(self, base_api_url, oauth_client=None, user_email=""):
+        super().__init__(
+            base_api_url=base_api_url,
+            oauth_client=oauth_client,
+            user_email=user_email
+        )
+
+    @staticmethod
+    def _parse_post_response(resp: requests.Response):
+
+        if resp.status_code >= 400:
+            _failed_response(resp)
+
+        if resp.content == b"":
+            return None
+        else:
+            return json.loads(resp.content.decode("utf-8"))
+
+    # ------------------
+    # Cached wrappers for simple ID lookups of fixed-ish info
+    #
+    # N.B. lru_cache behaves oddly with instance methods (it's a global cache,
+    # with the instance as part of the key), but in this case it should be OK.
+    @lru_cache
+    def get_by_name(self, endpoint, name, name_key="name"):
+        res = self.get(endpoint, {name_key: name})
+        if not res:
+            raise KeyError(f"No {endpoint[:-1]} with name {name}")
+        elif len(res) > 1:
+            raise ValueError(f"Multiple {endpoint} found with name {name}")
+        return res[0]
+
+    @lru_cache
+    def parameter_by_name(self, name):
+        return self.get_by_name("parameters", name)
+
+    @lru_cache
+    def calibration_by_name(self, name):
+        return self.get_by_name("calibrations", name)
+
+    @lru_cache
+    def detector_type_by_name(self, name):
+        return self.get_by_name("detector_types", name)
+
+    @lru_cache
+    def pdu_by_name(self, name):
+        return self.get_by_name(
+            "physical_detector_units", name, name_key="physical_name")
+
+    @lru_cache
+    def report(self, report: Report):
+        return self.get("reports", asdict(report))
+
+    @lru_cache
+    def get_calibration_constant(
+        self, calibration_constant: CalibrationConstant):
+        return self.get(
+            f"calibrations/{calibration_constant.calibration_id}/get_calibration_constant",
+            asdict(calibration_constant)
+        )
+
+    def post_request(self, relative_url, data=None, headers=None, **kwargs):
+        """Make a POST request, return the HTTP response object"""
+        # Base URL may include e.g. '/api/'. This is a prefix for all URLs;
+        # even if they look like an absolute path.
+        url = urljoin(self.base_api_url, relative_url.lstrip("/"))
+        _headers = self.default_headers()
+        if headers:
+            _headers.update(headers)
+        return self.session.post(url, data=json.dumps(data), headers=_headers, **kwargs)
+
+    def post(self, relative_url, params=None, **kwargs):
+        """Make a POST request, return response content from JSON"""
+        resp = self.post_request(relative_url, params, **kwargs)
+        return self._parse_post_response(resp)
+
+    def set_expected_condition(self, conditions: ConditionRequest):
+        return self.post("conditions/set_expected_condition", asdict(conditions))
+
+    def create_calibration_constant(
+        self, calibration_constant: CalibrationConstant):
+        return self.post("calibration_constants", asdict(calibration_constant))
+
+    def create_report(self, report: Report):
+        # Based on create or get API
+        return self.post("reports/set", asdict(report))
+
+    def create_calibration_constant_version(
+        self, ccv: CalibrationConstantVersion):
+        return self.post(CALIBRATION_CONSTANT_VERSIONS, asdict(ccv))
+
+    def create_new_ccv(
+        self,
+        cond_params,
+        begin_at,
+        proposal,
+        runs,
+        pdu_name,
+        pdu_uuid,
+        detector_type,
+        calibration,
+        const_rel_path,
+        const_filename,
+        ccv_root,
+        report_to=None,
+        
+    ):
+        """Inject new CCV into CalCat."""
+
+        cond_name = generate_unique_name(
+            detector_type, pdu_name, pdu_uuid, cond_params)
+
+        # Create condition table in database, if not available.
+        resp = self.set_condition(cond_name, list(cond_params.values()))
+
+        # Prepare some parameters to set Calibration Constant.
+        condition_id = resp['id']
+        cal_id = self.calibration_by_name(calibration)['id']
+        det_type_id = self.detector_type_by_name(detector_type)['id']
+        cc_name = create_unique_cc_name(detector_type, calibration, cond_name)
+
+        # Create Calibration Constant in database, if not available.
+        resp = self.get_or_create_calibration_constant(
+            cc_name, cal_id, det_type_id, condition_id)
+        cc_id = resp["id"]
+
+        # Create report in database, if not available
+        report_id = None
+        if report_to:
+            report_path = Path(report_to).absolute().with_suffix('.pdf')
+            resp = self.get_or_create_report(name=report_path.stem, file_path=str(report_path))
+            report_id = resp["id"]
+
+        # Get PDU ID before creating new CCV.
+        resp = self.pdu_by_name(pdu_name)
+        pdu_id = resp["id"]
+        start_idx = 0
+        ccv = CalibrationConstantVersion(
+            name=create_unique_ccv_name(start_idx),
+            file_name=const_filename,
+            path_to_file=const_rel_path,
+            data_set_name=ccv_root,
+            calibration_constant_id=cc_id,
+            physical_detector_unit_id=pdu_id,
+            raw_data_location=get_raw_data_location(proposal, runs),
+            report_id=report_id,
+            begin_validity_at=begin_at,
+            end_validity_at='',
+            begin_at=begin_at,
+            start_idx=start_idx
+        )
+
+        return self.create_calibration_constant_version(ccv)
+
+    def set_condition(self, name, parameter_conditions):
+        cond = Condition(
+            name=name, parameters_conditions_attributes=parameter_conditions)
+        return self.set_expected_condition(ConditionRequest(cond))
+
+    def get_or_create_calibration_constant(
+        self, name: str, cal_id: int, det_type_id: int, condition_id: int):
+        cond_id = condition_id
+
+        calibration_constant = CalibrationConstant(
+            name=name,
+            calibration_id=cal_id,
+            condition_id=cond_id,
+            detector_type_id=det_type_id
+        )
+        
+        resp = self.get_calibration_constant(calibration_constant)
+        return resp if resp else self.create_calibration_constant(
+            calibration_constant)
+
+    def get_or_create_report(self, name: str, file_path: str):
+        report = Report(name=name, file_path=file_path)
+        resp = self.report(report)
+        # TODO: confirm if this create_report isn't already enough (it does get_or_create also)?
+        # In case report hasn't been created still.
+        return resp if resp else self.create_report(report)
+
+    def set_calibration_constant_version(self, ccv: CalibrationConstantVersion):
+        return self.create_calibration_constant_version(ccv)
+
+    def get_ccv_info_from_file(self, cfile: Union[str, Path], pdu: str, ccv_root: str):
+        def to_string(value):
+            """Send only accepted value types to CALCAT."""
+            if isinstance(value, bool):
+                value = float(value)
+            return str(value)
+
+        with h5py.File(cfile, 'r') as const_file:
+            pdu_group = const_file[pdu]
+            pdu_uuid = pdu_group.attrs['uuid']
+            detector_type = pdu_group.attrs['detector_type']
+            ccv_group = const_file[ccv_root]
+
+            proposal, runs = ccv_group.attrs['proposal'], ccv_group.attrs['runs']
+            begin_at = ccv_group.attrs['begin_at']
+
+            cond_params = {}
+            condition_group = ccv_group['operating_condition']
+            # It's really not ideal we're mixing conditionS and condition now.
+            # Get parameter data and create a list of `ParameterConditionAttribute`s
+            for parameter in condition_group:
+                param_dset = condition_group[parameter]
+                param_name = param_dset.attrs['database_name']
+                cond_params[param_name] = ParameterConditionAttribute(
+                        parameter_id=self.parameter_by_name(param_name)['id'],
+                        value=to_string(param_dset[()]),
+                        lower_deviation_value=param_dset.attrs['lower_deviation'],
+                        upper_deviation_value=param_dset.attrs['upper_deviation'],
+                    )
+            det_uuid = 'Detector UUID'
+            # Add PDU "UUID" to parameters.
+            cond_params[det_uuid] = ParameterConditionAttribute(
+                value=to_string(unpack('d', pack('q', pdu_uuid))[0]),
+                parameter_id=self.parameter_by_name(det_uuid)['id'],
+            )
+
+        return cond_params, begin_at, proposal, runs, pdu_uuid, detector_type
+
+
+def inject_ccv(
+    const_src: Union[str, Path],
+    ccv_root: str,  # pdu/calibration/key
+    report_to: Optional[Union[str, Path]] = None,
+    client: Optional[InjectAPI] = None,
+):
+    if client is None:
+        client = extra_calibration_client(inject=True)
+        
+    pdu_name, calibration, key = ccv_root.lstrip('/').split('/')
+
+    (
+        cond_params, begin_at, proposal, runs, pdu_uuid, detector_type
+    ) = client.get_ccv_info_from_file(const_src, pdu_name, ccv_root)
+
+    const_rel_path = f'xfel/cal/{detector_type.lower()}/{pdu_name.lower()}'
+    const_filename = f'cal.{time.time()}.h5'
+    const_dest = _get_default_caldb_root() / const_rel_path / const_filename
+    const_dest.parent.mkdir(parents=True, exist_ok=True)
+    copyfile(const_src, const_dest)
+
+    resp = client.create_new_ccv(
+        cond_params,
+        begin_at,
+        proposal,
+        runs,
+        pdu_name,
+        pdu_uuid,
+        detector_type,
+        calibration,
+        const_rel_path,
+        const_filename,
+        ccv_root,
+        report_to,
+    )
+
+    if not resp['success']:
+        const_dest.unlink()
+        raise RuntimeError(resp)
diff --git a/src/cal_tools/restful_config.py b/src/cal_tools/restful_config.py
index 671dec619..ee00554fa 100644
--- a/src/cal_tools/restful_config.py
+++ b/src/cal_tools/restful_config.py
@@ -44,7 +44,7 @@ def calibration_client():
         scope='')
 
 
-def extra_calibration_client():
+def extra_calibration_client(inject=False):
     """Obtain an initialized CalCatAPIClient object."""
 
     from cal_tools import calcat_interface2
-- 
GitLab