Skip to content
Snippets Groups Projects
inject.py 13.95 KiB
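'''
Steps for injecting a new calibration constant version into CalCat, with the
API endpoint and HTTP method used at each step:

1. Set Condition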
    a. /parameters [GET]  [DONE]
    b. /conditions/set_expected_condition [POST]
2. Set CC (calibration constant)
    a. /calibrations [GET]
    b. /detector_types [GET]
    c. /calibration_constants/{id} [GET]
    d. /calibration_constants [POST]
3. Set Report
    a. /reports  [GET]
    b. /reports/set [POST]
4. Set CCV (calibration constant version)
    a. /physical_detector_units [GET]
    b. /calibration_constant_versions [POST]
'''
import binascii
import json
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from hashlib import md5
from pathlib import Path
from shutil import copyfile
from struct import pack, unpack
from typing import List, Optional, Union
from urllib.parse import urljoin
from functools import lru_cache

import h5py

from cal_tools.calcat_interface2 import (
    CalCatAPIClient,
    CalCatAPIError,
    _get_default_caldb_root,
)
from cal_tools.restful_config import extra_calibration_client
import requests
import pprint

# CalCat endpoint used to POST new calibration constant versions (CCVs).
# _failed_response also looks for this string in a response URL to recognise
# errors coming from CCV creation.
CALIBRATION_CONSTANT_VERSIONS = "calibration_constant_versions"


class InjectAPIError(CalCatAPIError):
    """Error returned by the CalCat API while injecting calibration data."""


class CCVAlreadyInjectedError(InjectAPIError):
    """CalCat refused a new calibration constant version with HTTP 422.

    _failed_response raises this for a 422 from the
    calibration_constant_versions endpoint.
    """


@dataclass
class ParameterConditionAttribute:
    """One parameter of a CalCat condition.

    Instances are collected in ``Condition.parameters_conditions_attributes``
    and serialized with ``dataclasses.asdict`` into the JSON body POSTed to
    ``conditions/set_expected_condition``.
    """
    parameter_id: int  # CalCat parameter ID (see InjectAPI.parameter_by_name)
    value: str  # parameter value, already converted to a string
    lower_deviation_value: float = 0
    upper_deviation_value: float = 0
    flg_available: str = 'true'  # sent as the string 'true', not a bool
    description: str = ''


@dataclass
class Condition:
    """A named CalCat condition: the parameter values a constant applies to.

    Wrapped in ``ConditionRequest`` and serialized with
    ``dataclasses.asdict`` for ``conditions/set_expected_condition``.
    """
    name: str
    # Quoted so the class definition does not depend on the parameter
    # dataclass being resolvable at this point.
    parameters_conditions_attributes: List["ParameterConditionAttribute"] = field(default_factory=list)
    flg_available: str = 'true'
    # TODO: Why is this needed? it is not written in swagger.
    # default_factory evaluates the timestamp per instance; a plain default
    # would be computed once at import and shared by every Condition.
    event_at: str = field(default_factory=lambda: str(datetime.today()))
    description: str = ''


@dataclass
class ConditionRequest:
    """Request body for ``conditions/set_expected_condition``.

    ``dataclasses.asdict`` turns this into ``{"condition": {...}}``.
    """
    condition: Condition


@dataclass
class CalibrationConstant:
    """A CalCat calibration constant (CC), used both as lookup query and POST body.

    Not hashable (non-frozen dataclass with generated ``__eq__``), so it
    cannot be used as a ``functools.lru_cache`` key.
    """
    name: str  # built by create_unique_cc_name in InjectAPI.create_new_ccv
    detector_type_id: int
    calibration_id: int
    condition_id: int
    flg_auto_approve: str = 'true'  # flags are sent as strings, not bools
    flg_available: str = 'true'
    description: str = ""


@dataclass
class Report:
    """A CalCat report entry, used both as lookup query and POST body.

    Not hashable (non-frozen dataclass with generated ``__eq__``), so it
    cannot be used as a ``functools.lru_cache`` key.
    """
    name: str
    file_path: str  # InjectAPI.create_new_ccv passes an absolute '.pdf' path
    flg_available: str = 'true'
    description: str = ""


@dataclass
class CalibrationConstantVersion:
    """POST body for creating a calibration constant version (CCV) in CalCat."""
    name: str  # e.g. the output of create_unique_ccv_name
    file_name: str  # constant file name, without directory
    path_to_file: str  # directory of the file, relative to the calibration DB root
    data_set_name: str  # HDF5 path of the CCV inside the file (ccv_root)
    calibration_constant_id: int
    physical_detector_unit_id: int
    begin_validity_at: str
    end_validity_at: str  # create_new_ccv sends '' here
    begin_at: str
    # NOTE(review): typed str, but create_new_ccv passes an int for start_idx.
    start_idx: str = '0'
    end_idx: str = '0'
    flg_deployed: str = 'true'  # flags are sent as strings, not bools
    flg_good_quality: str = 'true'
    raw_data_location: str = ''  # e.g. the output of get_raw_data_location
    report_id: Optional[int] = None
    description: str = ''


def generate_unique_cond_name(detector_type, pdu_name, pdu_uuid, cond_params):
    """Derive a deterministic condition name for one PDU and its parameters.

    The name is the detector-type prefix (text before the first '-Type')
    followed by ' Def' and the base64 MD5 digest of the PDU name, the PDU
    UUID (8 bytes, little endian, unsigned) and every parameter name/value
    pair in ``cond_params`` iteration order. The result is cut to 60
    characters.

    NOTE: ``binascii.b2a_base64`` appends a newline, which stays in the name
    whenever it is shorter than 60 characters. It is kept on purpose so
    names stay identical to ones generated before.

    Raises:
        ValueError: if ``detector_type`` does not contain '-Type'.
    """
    type_prefix = detector_type[:detector_type.index('-Type')]

    digest = md5(pdu_name.encode())
    uuid_bytes = int(pdu_uuid).to_bytes(
        length=8, byteorder='little', signed=False)
    digest.update(uuid_bytes)
    for param_name, attrs in cond_params.items():
        digest.update(param_name.encode())
        digest.update(str(attrs.value).encode())

    encoded = binascii.b2a_base64(digest.digest()).decode()
    full_name = f'{type_prefix} Def{encoded}'
    return full_name[:60]


def create_unique_cc_name(det_type, calibration, condition_name):
    """Build a calibration-constant name from its detector type, calibration
    and condition name.

    The ``<det_type>_<calibration>`` part is truncated to 40 characters
    before ``_<condition_name>`` is appended (presumably to respect a name
    length limit in CalCat; unconfirmed).
    """
    type_and_calibration = "{}_{}".format(det_type, calibration)
    return "{}_{}".format(type_and_calibration[:40], condition_name)


def create_unique_ccv_name(start_idx):
    """Return a CCV name built from the current UTC time and ``start_idx``.

    Format: ``YYYYmmdd_HHMMSS_sIdx=<start_idx>``. Two calls within the same
    second with the same ``start_idx`` yield the same name.
    """
    datetime_str = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
    return f'{datetime_str}_sIdx={start_idx}'


def get_raw_data_location(proposal, runs):
    """Describe where a constant's raw data came from.

    Returns ``'proposal:<proposal> runs: <run> <run> ...'`` when
    ``proposal`` is truthy and ``runs`` is non-empty; otherwise an empty
    string. ``runs`` is only inspected when ``proposal`` is truthy.
    """
    if not proposal or len(runs) == 0:
        return ""  # Fallback for non-run based constants
    run_numbers = " ".join(map(str, runs))
    return f'proposal:{proposal} runs: {run_numbers}'


def _failed_response(resp):
    """Raise the exception matching a failed CalCat HTTP response.

    Returns None (raises nothing) when ``resp.status_code`` is below 400.
    Otherwise it always raises:

    * ``CCVAlreadyInjectedError`` for a 422 from the
      calibration_constant_versions endpoint;
    * ``InjectAPIError`` carrying the ``info`` field of a JSON error body
      (``'missing details'`` if absent or the body is not a JSON object);
    * the ``requests`` ``HTTPError`` from ``raise_for_status`` when the body
      is not valid UTF-8 JSON (or ``InjectAPIError`` for status codes that
      ``raise_for_status`` does not raise on).
    """
    # TODO: Add more errors if needed
    status = resp.status_code
    if status < 400:
        return

    if status == 422 and CALIBRATION_CONSTANT_VERSIONS in resp.url:
        raise CCVAlreadyInjectedError

    try:
        # ``content`` may be None (e.g. no underlying raw response).
        details = json.loads((resp.content or b"").decode("utf-8"))
    except ValueError:  # JSONDecodeError and UnicodeDecodeError
        resp.raise_for_status()
        details = None

    if isinstance(details, dict):
        info = details.get('info', 'missing details')
    else:
        info = 'missing details'
    raise InjectAPIError(f"Error {status} from API: {info}")


class InjectAPI(CalCatAPIClient):
    """CalCat client with the write operations needed to inject constants.

    Builds on ``CalCatAPIClient`` (whose ``get``, ``session``,
    ``default_headers`` and ``base_api_url`` are used here), adding JSON
    POST requests, cached name -> record lookups, and get-or-create helpers
    for conditions, calibration constants (CC), reports and calibration
    constant versions (CCV).
    """

    def __init__(self, base_api_url, oauth_client=None, user_email=""):
        super().__init__(
            base_api_url=base_api_url,
            oauth_client=oauth_client,
            user_email=user_email
        )

    @staticmethod
    def _parse_post_response(resp: requests.Response):
        """Decode a POST response body from JSON.

        Returns None for an empty body. For HTTP status >= 400 it always
        raises: ``_failed_response`` picks the specific error, and the
        fallbacks below ensure an error response is never returned as if it
        were a success.
        """
        if resp.status_code >= 400:
            _failed_response(resp)
            resp.raise_for_status()
            raise InjectAPIError(f"Error {resp.status_code} from API")

        if resp.content == b"":
            return None
        return json.loads(resp.content.decode("utf-8"))

    # ------------------
    # Cached wrappers for simple ID lookups of fixed-ish info
    #
    # N.B. lru_cache behaves oddly with instance methods (it's a global cache,
    # with the instance as part of the key, which also keeps the instance
    # alive), but in this case it should be OK. Arguments must be hashable.
    @lru_cache
    def get_by_name(self, endpoint, name, name_key="name"):
        """Return the single ``endpoint`` record whose ``name_key`` is ``name``.

        Raises:
            KeyError: if no record matches.
            ValueError: if more than one record matches.
        """
        res = self.get(endpoint, {name_key: name})
        if not res:
            raise KeyError(f"No {endpoint[:-1]} with name {name}")
        elif len(res) > 1:
            raise ValueError(f"Multiple {endpoint} found with name {name}")
        return res[0]

    @lru_cache
    def parameter_by_name(self, name):
        """Return the CalCat parameter record named ``name``."""
        return self.get_by_name("parameters", name)

    @lru_cache
    def calibration_by_name(self, name):
        """Return the CalCat calibration record named ``name``."""
        return self.get_by_name("calibrations", name)

    @lru_cache
    def detector_type_by_name(self, name):
        """Return the CalCat detector type record named ``name``."""
        return self.get_by_name("detector_types", name)

    @lru_cache
    def pdu_by_name(self, name):
        """Return the physical detector unit whose ``physical_name`` is ``name``."""
        return self.get_by_name(
            "physical_detector_units", name, name_key="physical_name")

    # ------------------
    # Uncached lookups
    #
    # These take dataclass instances, which are unhashable (``@dataclass``
    # sets ``__hash__`` to None unless frozen), so ``lru_cache`` would raise
    # TypeError on every call. Caching would also be wrong for get-or-create:
    # a "not found" result must not be reused after the record is created.
    def report(self, report: Report):
        """Query ``reports`` with all fields of ``report`` as parameters."""
        return self.get("reports", asdict(report))

    def get_calibration_constant(
        self, calibration_constant: CalibrationConstant):
        """Look up the calibration constant matching ``calibration_constant``."""
        return self.get(
            f"calibrations/{calibration_constant.calibration_id}/get_calibration_constant",
            asdict(calibration_constant)
        )

    def post_request(self, relative_url, data=None, headers=None, **kwargs):
        """Make a POST request, return the HTTP response object"""
        # Base URL may include e.g. '/api/'. This is a prefix for all URLs;
        # even if they look like an absolute path.
        url = urljoin(self.base_api_url, relative_url.lstrip("/"))
        _headers = self.default_headers()
        if headers:
            _headers.update(headers)
        return self.session.post(url, data=json.dumps(data), headers=_headers, **kwargs)

    def post(self, relative_url, params=None, **kwargs):
        """Make a POST request, return response content from JSON.

        ``params`` is JSON-encoded as the request body (not a query string).
        """
        resp = self.post_request(relative_url, params, **kwargs)
        return self._parse_post_response(resp)

    def set_expected_condition(self, conditions: ConditionRequest):
        """POST a condition to ``conditions/set_expected_condition``."""
        return self.post("conditions/set_expected_condition", asdict(conditions))

    def create_calibration_constant(
        self, calibration_constant: CalibrationConstant):
        """POST a new calibration constant."""
        return self.post("calibration_constants", asdict(calibration_constant))

    def create_report(self, report: Report):
        """POST a report to ``reports/set``."""
        # Based on create or get API
        return self.post("reports/set", asdict(report))

    def create_calibration_constant_version(
        self, ccv: CalibrationConstantVersion):
        """POST a new calibration constant version.

        A 422 from this endpoint is reported as ``CCVAlreadyInjectedError``.
        """
        return self.post(CALIBRATION_CONSTANT_VERSIONS, asdict(ccv))

    def create_new_ccv(
        self,
        cond_params,
        begin_at,
        pdu_name,
        pdu_uuid,
        detector_type,
        calibration,
        const_rel_path,
        const_filename,
        ccv_root,
        report_to=None,
        raw_data_location='',
    ):
        """Inject new CCV into CalCat.

        Gets or creates the condition, the calibration constant and
        (if ``report_to`` is given) the report, then POSTs the new CCV.

        Args:
            cond_params: dict of parameter name -> ParameterConditionAttribute.
            begin_at: used for both ``begin_at`` and ``begin_validity_at``.
            pdu_name: physical detector unit name (``physical_name``).
            pdu_uuid: PDU UUID, hashed into the condition name.
            detector_type: detector type name; must contain '-Type'.
            calibration: calibration name.
            const_rel_path: directory of the constant file, relative to the
                calibration DB root.
            const_filename: constant file name.
            ccv_root: HDF5 path of the CCV inside the file.
            report_to: optional report path; its '.pdf' variant is registered.
            raw_data_location: description of the source raw data.

        Returns:
            The decoded JSON response of the CCV POST.

        Raises:
            CCVAlreadyInjectedError, InjectAPIError: on API errors.
            KeyError: if a calibration, detector type or PDU name is unknown.
        """

        cond_name = generate_unique_cond_name(
            detector_type, pdu_name, pdu_uuid, cond_params)

        # Create condition table in database, if not available.
        resp = self.set_condition(cond_name, list(cond_params.values()))
        condition_id = resp['id']

        # Prepare some parameters to set Calibration Constant.
        cal_id = self.calibration_by_name(calibration)['id']
        det_type_id = self.detector_type_by_name(detector_type)['id']
        cc_name = create_unique_cc_name(detector_type, calibration, cond_name)

        # Create Calibration Constant in database, if not available.
        resp = self.get_or_create_calibration_constant(
            cc_name, cal_id, det_type_id, condition_id)
        cc_id = resp["id"]

        # Create report in database, if not available
        report_id = None
        if report_to:
            report_path = Path(report_to).absolute().with_suffix('.pdf')
            resp = self.get_or_create_report(
                name=report_path.stem, file_path=str(report_path))
            report_id = resp["id"]

        # Get PDU ID before creating new CCV.
        resp = self.pdu_by_name(pdu_name)
        pdu_id = resp["id"]
        # NOTE(review): start_idx is declared as str on the dataclass but an
        # int is sent here — confirm the API accepts both.
        start_idx = 0
        ccv = CalibrationConstantVersion(
            name=create_unique_ccv_name(start_idx),
            file_name=const_filename,
            path_to_file=const_rel_path,
            data_set_name=ccv_root,
            calibration_constant_id=cc_id,
            physical_detector_unit_id=pdu_id,
            raw_data_location=raw_data_location,
            report_id=report_id,
            begin_validity_at=begin_at,
            end_validity_at='',
            begin_at=begin_at,
            start_idx=start_idx
        )

        return self.create_calibration_constant_version(ccv)

    def set_condition(self, name, parameter_conditions):
        """Create (or fetch) the condition ``name`` with the given parameters."""
        cond = Condition(
            name=name, parameters_conditions_attributes=parameter_conditions)
        return self.set_expected_condition(ConditionRequest(cond))

    def get_or_create_calibration_constant(
        self, name: str, cal_id: int, det_type_id: int, condition_id: int):
        """Return the matching calibration constant, creating it if absent."""
        calibration_constant = CalibrationConstant(
            name=name,
            calibration_id=cal_id,
            condition_id=condition_id,
            detector_type_id=det_type_id
        )

        resp = self.get_calibration_constant(calibration_constant)
        return resp if resp else self.create_calibration_constant(
            calibration_constant)

    def get_or_create_report(self, name: str, file_path: str):
        """Return the report record for ``name``/``file_path``, creating it if absent."""
        report = Report(name=name, file_path=file_path)
        resp = self.report(report)
        # A collection query may return a list (as get_by_name assumes);
        # callers index the result with ["id"], so unwrap the first match.
        if isinstance(resp, list):
            resp = resp[0] if resp else None
        # TODO: confirm if this create_report isn't already enough (it does get_or_create also)?
        # In case report hasn't been created still.
        return resp if resp else self.create_report(report)

    def set_calibration_constant_version(self, ccv: CalibrationConstantVersion):
        """Alias of ``create_calibration_constant_version``."""
        return self.create_calibration_constant_version(ccv)


def extract_parameter_conditions(client, ccv_group, pdu_uuid):
    """Build the operating-condition parameters of a CCV from its HDF5 group.

    Every dataset in ``ccv_group['operating_condition']`` becomes one
    ``ParameterConditionAttribute``. The key is the dataset's
    ``database_name`` attribute; the deviations come from its
    ``lower_deviation``/``upper_deviation`` attributes; the parameter ID is
    resolved via ``client.parameter_by_name``. A 'Detector UUID' parameter
    is added whose value is the bit pattern of ``pdu_uuid`` (packed as a
    signed 64-bit int) reinterpreted as a 64-bit float.

    Returns:
        dict mapping CalCat parameter name -> ParameterConditionAttribute.
    """
    def _to_string(value):
        """Send only accepted value types to CALCAT."""
        # NOTE(review): only Python ``bool`` is converted to float; numpy
        # booleans (presumably what h5py returns) are not ``bool``
        # instances — confirm the intended string form for them.
        if isinstance(value, bool):
            value = float(value)
        return str(value)

    cond_params = {}
    # Note the naming: the HDF5 group is 'operating_condition' (singular),
    # while the API calls these parameter conditions.
    condition_group = ccv_group['operating_condition']
    for parameter in condition_group:
        param_dset = condition_group[parameter]
        param_name = param_dset.attrs['database_name']
        cond_params[param_name] = ParameterConditionAttribute(
                parameter_id=client.parameter_by_name(param_name)['id'],
                value=_to_string(param_dset[()]),
                lower_deviation_value=param_dset.attrs['lower_deviation'],
                upper_deviation_value=param_dset.attrs['upper_deviation'],
            )
    det_uuid = 'Detector UUID'
    # Add PDU "UUID" to parameters.
    cond_params[det_uuid] = ParameterConditionAttribute(
        value=_to_string(unpack('d', pack('q', pdu_uuid))[0]),
        parameter_id=client.parameter_by_name(det_uuid)['id'],
    )
    return cond_params


def get_ccv_info_from_file(
    client: InjectAPI, cfile: Union[str, Path], pdu: str, ccv_root: str):
    """Read what is needed to inject a CCV from a constant HDF5 file.

    Args:
        client: API client, used to resolve parameter IDs.
        cfile: path of the constant file.
        pdu: name of the PDU group inside the file.
        ccv_root: HDF5 path of the CCV group inside the file.

    Returns:
        Tuple ``(cond_params, begin_at, pdu_uuid, detector_type,
        raw_data_location)``.
    """
    with h5py.File(cfile, 'r') as h5file:
        pdu_attrs = h5file[pdu].attrs
        uuid = pdu_attrs['uuid']
        det_type = pdu_attrs['detector_type']

        ccv = h5file[ccv_root]
        location = get_raw_data_location(
            ccv.attrs['proposal'], ccv.attrs['runs'])
        begin = ccv.attrs['begin_at']
        params = extract_parameter_conditions(client, ccv, uuid)

    return params, begin, uuid, det_type, location


def inject_ccv(
    const_src: Union[str, Path],
    ccv_root: str,  # pdu/calibration/key
    report_to: Optional[Union[str, Path]] = None,
    client: Optional[InjectAPI] = None,
):
    """Copy a constant file into the calibration DB and register it as a CCV.

    The file is copied to
    ``<_get_default_caldb_root()>/xfel/cal/<detector type>/<pdu>/cal.<time>.h5``
    and then registered through ``InjectAPI.create_new_ccv``. If
    registration fails, the copy is removed again.

    Args:
        const_src: path of the constant HDF5 file.
        ccv_root: HDF5 path of the CCV, of the form ``pdu/calibration/key``.
        report_to: optional report path to link to the CCV.
        client: API client; created with ``extra_calibration_client`` if None.

    Returns:
        The decoded response of the CCV creation request.

    Raises:
        ValueError: if ``ccv_root`` does not have exactly three parts.
        Any error from reading the file or from the CalCat API.
    """
    if client is None:
        client = extra_calibration_client(inject=True)

    root_parts = ccv_root.lstrip('/').split('/')
    if len(root_parts) != 3:
        raise ValueError(
            f"ccv_root must look like 'pdu/calibration/key', got {ccv_root!r}")
    pdu_name, calibration, _ = root_parts

    # get_ccv_info_from_file is a module-level function, not a client method.
    (
        cond_params, begin_at, pdu_uuid, detector_type, raw_data_location
    ) = get_ccv_info_from_file(client, const_src, pdu_name, ccv_root)

    const_rel_path = f'xfel/cal/{detector_type.lower()}/{pdu_name.lower()}'
    const_filename = f'cal.{time.time()}.h5'
    const_dest = _get_default_caldb_root() / const_rel_path / const_filename
    const_dest.parent.mkdir(parents=True, exist_ok=True)
    copyfile(const_src, const_dest)

    try:
        return client.create_new_ccv(
            cond_params,
            begin_at,
            pdu_name,
            pdu_uuid,
            detector_type,
            calibration,
            const_rel_path,
            const_filename,
            ccv_root,
            report_to,
            raw_data_location,
        )
    except Exception:
        # Don't leave an unregistered copy in the calibration DB; missing_ok
        # keeps a cleanup failure from masking the original error.
        const_dest.unlink(missing_ok=True)
        raise