From bd4debeb4f82cefd5df4ce971ed2cceb28458d15 Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Tue, 18 Jun 2024 17:56:44 +0200
Subject: [PATCH] fix: refactor write_ccv into smaller functions and enable
 adding lower and upper deviations

---
 src/cal_tools/constants.py | 202 ++++++++++++++++++++++++++-----------
 1 file changed, 144 insertions(+), 58 deletions(-)

diff --git a/src/cal_tools/constants.py b/src/cal_tools/constants.py
index 21b047c00..8b5513d94 100644
--- a/src/cal_tools/constants.py
+++ b/src/cal_tools/constants.py
@@ -1,26 +1,37 @@
-
-from datetime import datetime, timezone
-from struct import pack, unpack
-from pathlib import Path
-from shutil import copyfile
-from hashlib import md5
 import binascii
 import time
+from hashlib import md5
+from pathlib import Path
+from shutil import copyfile
+from struct import pack, unpack
+from typing import List, Optional, Union
 
-import numpy as np
 import h5py
-
+import numpy as np
 from calibration_client import CalibrationClient
-from cal_tools.calcat_interface2 import get_default_caldb_root, get_client
-from cal_tools.tools import run_prop_seq_from_path
+
+from cal_tools.calcat_interface2 import (
+    CalCatAPIError,
+    _get_default_caldb_root,
+)
 from cal_tools.restful_config import calibration_client
 
+CONDITION_NAME_MAX_LENGTH = 60
+
+
+class InjectAPIError(CalCatAPIError):
+    ...
+
+
+class CCVAlreadyInjectedError(InjectAPIError):
+    ...
+
 
 def write_ccv(
     const_path,
     pdu_name, pdu_uuid, detector_type,
     calibration, conditions, created_at, proposal, runs,
-    data, dims, key='0'
+    data, dims, key='0', deviations={},
 ):
     """Write CCV data file.
 
@@ -37,8 +48,13 @@ def write_ccv(
         runs (Iterable of int): Raw data runs the calibration data is
             generated from
         data (ndarray): Calibration constant data
-        dims (Iterable of str):
-        key (str, optional):
+        dims (Iterable of str): Dimension names for the constant data.
+        key (str, optional): Key added in constant file dataset when
+            constant data are stored. Defaults to '0'.
+        deviations (dict, optional): Deviation values for operating conditions.
+            Each value can be a tuple (lower, upper) or a single value for
+            symmetric deviations. Defaults to {}.
+            e.g. {"integration_time": 0.025}
 
     Returns:
         (str) CCV HDF group name.
@@ -70,11 +86,18 @@ def write_ccv(
         opcond_dict = conditions.make_dict(
             conditions.calibration_types[calibration])
         for db_name, value in opcond_dict.items():
-            key = db_name.lower().replace(' ', '_')
-            dset = opcond_group.create_dataset(key, data=value,
-                                               dtype=np.float64)
-            dset.attrs['lower_deviation'] = 0.0
-            dset.attrs['upper_deviation'] = 0.0
+            cond_name = db_name.lower().replace(' ', '_')
+            dset = opcond_group.create_dataset(
+                cond_name, data=value, dtype=np.float64)
+
+            deviation = deviations.get(cond_name, (0.0, 0.0))
+            if isinstance(deviation, (int, float)):
+                lower_dev = upper_dev = deviation
+            else:
+                lower_dev, upper_dev = deviation
+
+            dset.attrs['lower_deviation'] = lower_dev
+            dset.attrs['upper_deviation'] = upper_dev
             dset.attrs['database_name'] = db_name
 
         dset = ccv_group.create_dataset('data', data=data)
@@ -83,7 +106,79 @@ def write_ccv(
     return ccv_group_name
 
 
-def inject_ccv(const_src, ccv_root, report_to=None):
+def get_condition_dict(
+    name: str,
+    value: Union[float, str, int, bool],
+    lower_deviation: float = 0.0,
+    upper_deviation: float = 0.0,
+):
+    """Return a condition parameter dict with the value as float or str."""
+    def to_float_or_string(value):
+        """CALCAT expects data to either be float or a string.
+        """
+        try:  # Any digit or boolean
+            return float(value)
+        except (TypeError, ValueError, OverflowError):  # not numeric
+            return str(value)
+
+    return {
+        'parameter_name': name,
+        'value': to_float_or_string(value),
+        'lower_deviation_value': lower_deviation,
+        'upper_deviation_value': upper_deviation,
+        'flg_available': True
+    }
+
+
+def generate_unique_condition_name(
+    detector_type: str,
+    pdu_name: str,
+    pdu_uuid: float,
+    cond_params: List[dict],
+):
+    """Generate a condition name from an MD5 hash of PDU and parameters.
+
+    Args:
+        detector_type (str): detector type.
+        pdu_name (str): Physical detector unit db name.
+        pdu_uuid (float): Physical detector unit db id.
+        cond_params (List[dict]): A list of dictionary with each condition
+            e.g. [{
+                "parameter_name": "Memory Cells",
+                "value": 352.0,
+                "lower_deviation_value": 0.0,
+                "upper_deviation_value": 0.0
+            }]
+
+    Returns:
+        str:  A unique name used for the table of conditions.
+    """
+    unique_name = detector_type[:detector_type.index('-Type')] + ' Def'
+    cond_hash = md5(pdu_name.encode())
+    cond_hash.update(int(pdu_uuid).to_bytes(
+        length=8, byteorder='little', signed=False))
+
+    for param_dict in cond_params:
+        cond_hash.update(str(param_dict['parameter_name']).encode())
+        cond_hash.update(str(param_dict['value']).encode())
+
+    unique_name += binascii.b2a_base64(cond_hash.digest()).decode()
+    return unique_name[:CONDITION_NAME_MAX_LENGTH]
+
+
+def get_raw_data_location(proposal: str, runs: list):
+    """Format proposal and runs as a raw data location string."""
+    if not proposal or len(runs) == 0:
+        return ""  # Fallback for non-run based constants
+    run_numbers = " ".join(str(run) for run in runs)
+    return f'proposal:{proposal} runs: {run_numbers}'
+
+
+def inject_ccv(
+    const_src: Union[Path, str],
+    ccv_root: str,
+    report_to: Optional[str] = None,
+):
     """Inject new CCV into CalCat.
 
     Args:
@@ -91,69 +186,51 @@ def inject_ccv(const_src, ccv_root, report_to=None):
         ccv_root (str): CCV HDF group name.
         report_to (str): Metadata location.
 
-    Returns:
-        None
-
     Raises:
         RuntimeError: If CalCat POST request fails.
     """
-
-    pdu_name, calibration, key = ccv_root.lstrip('/').split('/')
+    pdu_name, calibration, _ = ccv_root.lstrip('/').split('/')
 
     with h5py.File(const_src, 'r') as const_file:
+        if ccv_root not in const_file:
+            raise ValueError(
+                f"Invalid HDF5 structure: {ccv_root} not found in file.")
+
         pdu_group = const_file[pdu_name]
         pdu_uuid = pdu_group.attrs['uuid']
         detector_type = pdu_group.attrs['detector_type']
 
         ccv_group = const_file[ccv_root]
+
         proposal, runs = ccv_group.attrs['proposal'], ccv_group.attrs['runs']
         begin_at_str = ccv_group.attrs['begin_at']
 
         condition_group = ccv_group['operating_condition']
 
         cond_params = []
-
         # It's really not ideal we're mixing conditionS and condition now.
         for parameter in condition_group:
             param_dset = condition_group[parameter]
-            cond_params.append({
-                'parameter_name': param_dset.attrs['database_name'],
-                'value': float(param_dset[()]),
-                'lower_deviation_value': param_dset.attrs['lower_deviation'],
-                'upper_deviation_value': param_dset.attrs['upper_deviation'],
-                'flg_available': True
-            })
+            cond_params.append(get_condition_dict(
+                param_dset.attrs['database_name'],
+                param_dset[()],
+                param_dset.attrs['lower_deviation'],
+                param_dset.attrs['upper_deviation'],
+            ))
 
     const_rel_path = f'xfel/cal/{detector_type.lower()}/{pdu_name.lower()}'
     const_filename = f'cal.{time.time()}.h5'
 
-    if proposal and len(runs) > 0:
-        raw_data_location = 'proposal:{} runs: {}'.format(
-            proposal, ' '.join([str(x) for x in runs]))
-    else:
-        pass  # Fallback for non-run based constants
-
-    # Generate condition name.
-    unique_name = detector_type[:detector_type.index('-Type')] + ' Def'
-    cond_hash = md5(pdu_name.encode())
-    cond_hash.update(int(pdu_uuid).to_bytes(
-        length=8, byteorder='little', signed=False))
-
-    for param_dict in cond_params:
-        cond_hash.update(str(param_dict['parameter_name']).encode())
-        cond_hash.update(str(param_dict['value']).encode())
+    unique_name = generate_unique_condition_name(
+        detector_type, pdu_name, pdu_uuid, cond_params)
 
-    unique_name += binascii.b2a_base64(cond_hash.digest()).decode()
-    unique_name = unique_name[:60]
+    raw_data_location = get_raw_data_location(proposal, runs)
 
     # Add PDU "UUID" to parameters.
-    cond_params.append({
-        'parameter_name': 'Detector UUID',
-        'value': unpack('d', pack('q', pdu_uuid))[0],
-        'lower_deviation_value': 0.0,
-        'upper_deviation_value': 0.0,
-        'flg_available': True
-    })
+    cond_params.append(get_condition_dict(
+        'Detector UUID',
+        unpack('d', pack('q', pdu_uuid))[0]
+    ))
 
     inject_h = {
         'detector_condition': {
@@ -187,13 +264,22 @@ def inject_ccv(const_src, ccv_root, report_to=None):
             'file_path': str(report_path)
         }
 
-    const_dest = get_default_caldb_root() / const_rel_path / const_filename
+    const_dest = _get_default_caldb_root() / const_rel_path / const_filename
     const_dest.parent.mkdir(parents=True, exist_ok=True)
     copyfile(const_src, const_dest)
 
+    # TODO: Consider catching `RequestException`s
+    # when bypassing calibration_client
     resp = CalibrationClient.inject_new_calibration_constant_version(
         calibration_client(), inject_h)
 
     if not resp['success']:
         const_dest.unlink()  # Delete already copied CCV file.
-        raise RuntimeError(resp)
+        # TODO: Remove this when the new injection code is added.
+        if (
+            resp['status_code'] == 422 and
+            "taken" in resp['app_info'].get("begin_at", [""])[0]
+        ):
+            raise CCVAlreadyInjectedError
+        else:
+            raise RuntimeError(resp)
-- 
GitLab