From 70526f914f2daa39de7dc04db294a1934c0108d4 Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Thu, 4 Jul 2024 18:46:30 +0200
Subject: [PATCH] fix: convert dataclass to hashable, keep only reports/set and
 small fixes for new functions

---
 src/cal_tools/inject.py | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/src/cal_tools/inject.py b/src/cal_tools/inject.py
index 8d201411c..b26da79a7 100644
--- a/src/cal_tools/inject.py
+++ b/src/cal_tools/inject.py
@@ -58,7 +58,7 @@ class ParameterConditionAttribute:
     description: str = ''
 
 
-@dataclass
+@dataclass(frozen=True)
 class Condition:
     name: str
     parameters_conditions_attributes: List[ParameterConditionAttribute] = field(default_factory=list)
@@ -67,12 +67,12 @@ class Condition:
     description: str = ''
 
 
-@dataclass
+@dataclass(frozen=True)
 class ConditionRequest:
     condition: Condition
 
 
-@dataclass
+@dataclass(frozen=True)
 class CalibrationConstant:
     name: str
     detector_type_id: int
@@ -83,7 +83,7 @@ class CalibrationConstant:
     description: str = ""
 
 
-@dataclass
+@dataclass(frozen=True)
 class Report:
     name: str
     file_path: str
@@ -91,7 +91,7 @@ class Report:
     description: str = ""
 
 
-@dataclass
+@dataclass(frozen=True)
 class CalibrationConstantVersion:
     name: str
     file_name: str
@@ -218,8 +218,10 @@ class InjectAPI(CalCatAPIClient):
             "physical_detector_units", name, name_key="physical_name")
 
     @lru_cache
-    def report(self, report: Report):
-        return self.get("reports", asdict(report))
+    def get_report(self, report: Report):
+        resp = self.get("reports", asdict(report))
+        # `Get all reports` response is a list 
+        return resp if not resp else resp[0]
 
     @lru_cache
     def get_calibration_constant(
@@ -251,7 +253,7 @@ class InjectAPI(CalCatAPIClient):
         self, calibration_constant: CalibrationConstant):
         return self.post("calibration_constants", asdict(calibration_constant))
 
-    def create_report(self, report: Report):
+    def get_or_create_report(self, report: Report):
         # Based on create or get API
         return self.post("reports/set", asdict(report))
 
@@ -281,7 +283,7 @@ class InjectAPI(CalCatAPIClient):
         # Create condition table in database, if not available.
         resp = self.set_condition(cond_name, list(cond_params.values()))
         condition_id = resp['id']
-        
+
         # Prepare some parameters to set Calibration Constant.
         cal_id = self.calibration_by_name(calibration)['id']
         det_type_id = self.detector_type_by_name(detector_type)['id']
@@ -297,7 +299,7 @@ class InjectAPI(CalCatAPIClient):
         if report_to:
             report_path = Path(report_to).absolute().with_suffix('.pdf')
             resp = self.get_or_create_report(
-                name=report_path.stem, file_path=str(report_path))
+                Report(name=report_path.stem, file_path=str(report_path)))
             report_id = resp["id"]
 
         # Get PDU ID before creating new CCV.
@@ -336,18 +338,10 @@ class InjectAPI(CalCatAPIClient):
             condition_id=cond_id,
             detector_type_id=det_type_id
         )
-        
         resp = self.get_calibration_constant(calibration_constant)
         return resp if resp else self.create_calibration_constant(
             calibration_constant)
 
-    def get_or_create_report(self, name: str, file_path: str):
-        report = Report(name=name, file_path=file_path)
-        resp = self.report(report)
-        # TODO: confirm if this create_report isn't already enough (it does get_or_create also)?
-        # In case report hasn't been created still.
-        return resp if resp else self.create_report(report)
-
     def set_calibration_constant_version(self, ccv: CalibrationConstantVersion):
         return self.create_calibration_constant_version(ccv)
 
@@ -378,6 +372,7 @@ def extract_parameter_conditions(client, ccv_group, pdu_uuid):
         value=_to_string(unpack('d', pack('q', pdu_uuid))[0]),
         parameter_id=client.parameter_by_name(det_uuid)['id'],
     )
+    return cond_params
 
 
 def get_ccv_info_from_file(
-- 
GitLab