From 44085af960c9420a1a195d1bf230dbdc7b8b6941 Mon Sep 17 00:00:00 2001
From: Karim Ahmed <karim.ahmed@xfel.eu>
Date: Mon, 21 Jun 2021 12:14:53 +0200
Subject: [PATCH] Test get_from_db and send_to_db from cal_tools

---
 src/cal_tools/tools.py  |  32 +++-
 tests/conftest.py       |   2 +-
 tests/test_cal_tools.py | 320 +++++++++++++++++++++++++++++++++++-----
 3 files changed, 309 insertions(+), 45 deletions(-)

diff --git a/src/cal_tools/tools.py b/src/cal_tools/tools.py
index 4a31df713..d13d97691 100644
--- a/src/cal_tools/tools.py
+++ b/src/cal_tools/tools.py
@@ -532,6 +532,8 @@ def get_from_db(karabo_id: str, karabo_da: str,
                 ntries -= 1
                 timeout *= 2
                 sleep(np.random.randint(30))
+                # TODO: reevaluate the need for doraise
+                # and remove if not needed.
                 if ntries == 0 and doraise:
                     raise
             except Exception as e:
@@ -547,14 +549,24 @@ def get_from_db(karabo_id: str, karabo_da: str,
         if ntries > 0:
             mdata_const = metadata.calibration_constant_version
             if load_data and meta_only:
-                fpath = Path(mdata_const.hdf5path, mdata_const.filename)
-                with h5py.File(fpath, "r") as f:
-                    arr = f[f"{mdata_const.h5path}/data"][()]
+                hdf5path = getattr(mdata_const, 'hdf5path', None)
+                filename = getattr(mdata_const, 'filename', None)
+                h5path = getattr(mdata_const, 'h5path', None)
+                if not (hdf5path and filename and h5path):
+                    raise ValueError(
+                        "Wrong metadata received to access the constant data."
+                        f" Retrieved constant filepath is {hdf5path}/{filename}"  # noqa
+                        f" and data_set_name is {h5path}."
+                    )
+                with h5py.File(Path(hdf5path, filename), "r") as f:
+                    arr = f[f"{h5path}/data"][()]
                 metadata.calibration_constant.data = arr
 
             if verbosity > 0:
                 if constant.name not in already_printed or verbosity > 1:
                     already_printed[constant.name] = True
+                    # TODO: Reset mdata_const.begin_at
+                    # if comm_db_success is False.
                     begin_at = mdata_const.begin_at
                     print(f"Retrieved {constant.name} "
                           f"with creation time: {begin_at}")
@@ -601,6 +613,12 @@ def send_to_db(db_module: str, karabo_id: str, constant, condition,
         if report_path:
             # calibration_client expects a dict of injected report path
             # of at least 2 characters for each key.
+            if not isinstance(report_path, str) or len(report_path) < 2:
+                raise TypeError(
+                    "\"report_path\" needs to be a string "
+                    "of at least 2 characters."
+                )
+
             report = {"name": path.basename(report_path),
                       "file_path": report_path}
             metadata.calibration_constant_version.report_path = report
@@ -609,13 +627,17 @@ def send_to_db(db_module: str, karabo_id: str, constant, condition,
         metadata.calibration_constant_version.device_name = db_module
         metadata.calibration_constant_version.karabo_da = None
         metadata.calibration_constant_version.raw_data_location = file_loc
-
+        if constant.data is None:
+            raise ValueError(
+                "There is no data available to "
+                "inject to the database."
+            )
         while ntries > 0:
 
             this_interface = get_random_db_interface(cal_db_interface)
             try:
                 metadata.send(this_interface, timeout=timeout)
-                success = True
+                success = True  # TODO: use comm_db_success
                 break
             except zmq.error.Again:
                 ntries -= 1
diff --git a/tests/conftest.py b/tests/conftest.py
index d0e461a4c..f8186a593 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@ def pytest_addoption(parser):
     parser.addoption(
         "--no-gpfs",
         action="store_true",
-        default="false",
+        default=False,
         help="Skips tests marked as requiring GPFS access",
     )
 
diff --git a/tests/test_cal_tools.py b/tests/test_cal_tools.py
index 62567e438..f17a29019 100644
--- a/tests/test_cal_tools.py
+++ b/tests/test_cal_tools.py
@@ -1,15 +1,48 @@
 from datetime import datetime
 from pathlib import Path
+from unittest.mock import patch
 
+import numpy as np
 import pytest
+import zmq
+from iCalibrationDB import Conditions, ConstantMetaData, Constants
+
 from cal_tools.agipdlib import AgipdCorrections
 from cal_tools.plotting import show_processed_modules
 from cal_tools.tools import (
     get_dir_creation_date,
+    get_from_db,
     get_pdu_from_db,
     module_index_to_qm,
+    send_to_db,
 )
-from iCalibrationDB import Conditions, Constants
+
+# AGIPD operating conditions.
+ACQ_RATE = 1.1
+BIAS_VOLTAGE = 300
+GAIN_SETTING = 0
+MEM_CELLS = 352
+PHOTON_ENERGY = 9.2
+
+AGIPD_KARABO_ID = "SPB_DET_AGIPD1M-1"
+WRONG_AGIPD_MODULE = "AGIPD_**"
+
+CAL_DB_INTERFACE = "tcp://max-exfl017:8020"
+WRONG_CAL_DB_INTERFACE = "tcp://max-exfl017:0000"
+
+
+@pytest.fixture
+def _agipd_const_cond():
+    # AGIPD dark offset metadata
+    constant = Constants.AGIPD.Offset()
+
+    condition = Conditions.Dark.AGIPD(
+        memory_cells=MEM_CELLS,
+        bias_voltage=BIAS_VOLTAGE,
+        acquisition_rate=ACQ_RATE,
+        gain_setting=GAIN_SETTING,
+    )
+    return constant, condition
 
 
 def test_show_processed_modules():
@@ -39,22 +72,224 @@ def test_dir_creation_date():
     assert str(date) == '2019-12-16 08:52:25.196603'
 
 
-# AGIPD dark offset metadata
-constant = Constants.AGIPD.Offset()
-mem_cells = 352
-bias_voltage = 300
-acq_rate = 1.1
-gain_setting = 0
-photon_energy = 9.2
-condition = Conditions.Dark.AGIPD(memory_cells=mem_cells,
-                                  bias_voltage=bias_voltage,
-                                  acquisition_rate=acq_rate,
-                                  gain_setting=gain_setting)
-cal_db_interface = "tcp://max-exfl017:8020"
+def _call_get_from_db(
+    constant,
+    condition,
+    karabo_id,
+    karabo_da,
+    load_data=True,
+    cal_db_interface=CAL_DB_INTERFACE,
+    creation_time=None,
+    doraise=True,
+    timeout=10000,
+):
+
+    data, metadata = get_from_db(
+        karabo_id=karabo_id,
+        karabo_da=karabo_da,
+        constant=constant,
+        condition=condition,
+        empty_constant=None,
+        cal_db_interface=cal_db_interface,
+        creation_time=creation_time,
+        meta_only=True,
+        load_data=load_data,
+        ntries=1,
+        doraise=doraise,
+        timeout=timeout,
+    )
+    return data, metadata
+
+
+def _call_send_to_db(
+    constant,
+    condition,
+    db_module,
+    data=np.zeros((2, 2, 2)),
+    cal_db_interface=CAL_DB_INTERFACE,
+    report_path="",
+    doraise=True,
+    timeout=1000,
+    ntries=1,
+):
+
+    # TODO: create a known_constant for testing.
+    constant.data = data
+    metadata = send_to_db(
+        karabo_id=AGIPD_KARABO_ID,
+        db_module=db_module,
+        constant=constant,
+        condition=condition,
+        file_loc="proposal, runs",
+        report_path=report_path,
+        cal_db_interface=cal_db_interface,
+        creation_time=None,
+        ntries=ntries,
+        doraise=doraise,
+        timeout=timeout,
+    )
+    return metadata
+
+
+# TODO add a marker for accessing zmq end_point
+@pytest.mark.requires_gpfs
+def test_get_from_db_load_data(_agipd_const_cond):
+    """ Test retrieving calibration constants with get_from_db
+    with different loading data scenarios.
+    """
+    constant, condition = _agipd_const_cond
+
+    # Normal operation and loading data from h5file.
+    data, md = _call_get_from_db(
+        constant=constant, condition=condition,
+        karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
+    )
+
+    assert type(data) is np.ndarray
+    assert isinstance(md, ConstantMetaData)
+
+    # None karabo_id is given.
+    data, md = _call_get_from_db(
+        constant=constant, condition=condition,
+        karabo_id=None, karabo_da="AGIPD00",
+    )
+
+    assert data is None
+    assert md is None
+
+    # Retrieve constant without loading the data.
+    data, md = _call_get_from_db(
+        constant=constant, condition=condition,
+        karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
+        load_data=False,
+    )
+
+    assert data is None
+    assert isinstance(md, ConstantMetaData)
 
 
-def test_get_pdu_from_db():
+# TODO add a marker for accessing zmq end_point
+@pytest.mark.requires_gpfs
+def test_raise_get_from_db(_agipd_const_cond):
+    """ Test error raised scenarios for get_from_db:"""
+
+    constant, condition = _agipd_const_cond
+
+    # Wrong address for the calibration database.
+    with pytest.raises(zmq.error.Again) as excinfo:
+        _call_get_from_db(
+            constant=constant, condition=condition,
+            karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
+            cal_db_interface=WRONG_CAL_DB_INTERFACE,
+        )
+    assert str(excinfo.value) == "Resource temporarily unavailable"
+
+    # Wrong type for creation_time.
+    with pytest.raises(ValueError):
+        _call_get_from_db(
+            constant=constant, condition=condition,
+            karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
+            creation_time="WRONG_CREATION_TIME",
+        )
+
+    # No constant file path metadata retrieved.
+    with patch("iCalibrationDB.ConstantMetaData.retrieve", return_value=""):
+        with pytest.raises(ValueError):
+            _call_get_from_db(
+                constant=constant, condition=condition,
+                karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
+            )
+
+
+def test_no_doraise_get_from_db(_agipd_const_cond):
+    """get_from_db using wrong cal_db_interface
+    fails without raising errors, as doraise = False
+    """
+    constant, condition = _agipd_const_cond
+
+    data, _ = _call_get_from_db(
+        constant=constant, condition=condition,
+        karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
+        cal_db_interface=WRONG_CAL_DB_INTERFACE,
+        doraise=False,
+    )
+    assert data is None
+
 
+@patch(
+    'iCalibrationDB.ConstantMetaData.send',
+    return_value='',
+)
+def test_send_to_db_success(send, _agipd_const_cond):
+    """test sending constants to the database (send_to_db):
+    Injecting constant as expected.
+    # TODO: Add a test calibration constant to the test physical module
+    # to inject without mocking `send` method.
+    """
+    # Use wrong AGIPD module as a backup.
+    # To avoid constants injection, in-case of mock failure.
+
+    constant, condition = _agipd_const_cond
+
+    db_module = WRONG_AGIPD_MODULE
+    metadata = _call_send_to_db(
+        constant=constant,
+        condition=condition,
+        db_module=db_module,
+    )
+
+    assert isinstance(metadata, ConstantMetaData)
+
+
+@patch(
+    'iCalibrationDB.ConstantMetaData.send',
+    return_value='',
+)
+def test_raise_send_to_db_mocked(send, _agipd_const_cond):
+    """Test raised errors while sending constants to the
+    database (send_to_db):
+    """
+    # Use wrong AGIPD module as a backup.
+    # To avoid constants injection, in-case of mock failure.
+    constant, condition = _agipd_const_cond
+
+    # report_path has the wrong type.
+    with pytest.raises(TypeError):
+        _call_send_to_db(
+            constant=constant,
+            condition=condition,
+            db_module=WRONG_AGIPD_MODULE,
+            report_path=2,
+        )
+
+    # No constant data to inject.
+    with pytest.raises(ValueError):
+        _call_send_to_db(
+            constant=constant,
+            condition=condition,
+            db_module=WRONG_AGIPD_MODULE,
+            data=None,
+        )
+
+
+def test_raise_send_to_db(_agipd_const_cond):
+
+    constant, condition = _agipd_const_cond
+
+    # wrong calibration database address.
+    with pytest.raises(zmq.error.Again) as excinfo:
+        _call_send_to_db(
+            constant=constant,
+            condition=condition,
+            db_module=WRONG_AGIPD_MODULE,
+            cal_db_interface=WRONG_CAL_DB_INTERFACE,
+        )
+    assert str(excinfo.value) == "Resource temporarily unavailable"
+
+
+def test_get_pdu_from_db(_agipd_const_cond):
+
+    constant, condition = _agipd_const_cond
     snapshot_at = "2021-05-06 00:20:10.00"
 
     # A karabo_da str returns a list of one element.
@@ -62,7 +297,7 @@ def test_get_pdu_from_db():
                                karabo_da="TEST_DET_CAL_DA0",
                                constant=constant,
                                condition=condition,
-                               cal_db_interface=cal_db_interface,
+                               cal_db_interface=CAL_DB_INTERFACE,
                                snapshot_at=snapshot_at,
                                timeout=30000)
     assert len(pdu_dict) == 1
@@ -70,11 +305,13 @@ def test_get_pdu_from_db():
 
     # A list of karabo_das to return thier PDUs, if available.
     pdu_dict = get_pdu_from_db(karabo_id="TEST_DET_CAL_CI-1",
-                               karabo_da=["TEST_DET_CAL_DA0", "TEST_DET_CAL_DA1",
-                                          "UNAVAILABLE_DA"],
+                               karabo_da=[
+                                   "TEST_DET_CAL_DA0",
+                                   "TEST_DET_CAL_DA1",
+                                   "UNAVAILABLE_DA"],
                                constant=constant,
                                condition=condition,
-                               cal_db_interface=cal_db_interface,
+                               cal_db_interface=CAL_DB_INTERFACE,
                                snapshot_at=snapshot_at,
                                timeout=30000)
 
@@ -87,7 +324,7 @@ def test_get_pdu_from_db():
                                karabo_da="all",
                                constant=constant,
                                condition=condition,
-                               cal_db_interface=cal_db_interface,
+                               cal_db_interface=CAL_DB_INTERFACE,
                                snapshot_at=snapshot_at,
                                timeout=30000)
 
@@ -96,27 +333,30 @@ def test_get_pdu_from_db():
                         "CAL_PHYSICAL_DETECTOR_UNIT-2_TEST"]
 
 
+# TODO add a marker for accessing zmq end_point
 @pytest.mark.requires_gpfs
 def test_initialize_from_db():
-    creation_time = datetime.strptime("2020-01-07 13:26:48.00",
-                                      "%Y-%m-%d %H:%M:%S.%f")
+    creation_time = datetime.strptime(
+        "2020-01-07 13:26:48.00", "%Y-%m-%d %H:%M:%S.%f")
 
-    agipd_corr = AgipdCorrections(max_cells=mem_cells,
-                                  max_pulses=[0, 500, 1])
+    agipd_corr = AgipdCorrections(
+        max_cells=MEM_CELLS,
+        max_pulses=[0, 500, 1])
 
-    agipd_corr.allocate_constants(modules=[0],
-                                  constant_shape=(3, mem_cells, 512, 128))
+    agipd_corr.allocate_constants(
+        modules=[0],
+        constant_shape=(3, MEM_CELLS, 512, 128))
 
     dark_const_time_dict = agipd_corr.initialize_from_db(
-        karabo_id="TEST_DET_CI-2",
-        karabo_da="TEST_DAQ_DA_01",
-        cal_db_interface=cal_db_interface,
+        karabo_id="TEST_DET_CAL_CI-1",
+        karabo_da="TEST_DET_CAL_DA1",
+        cal_db_interface=CAL_DB_INTERFACE,
         creation_time=creation_time,
-        memory_cells=mem_cells,
-        bias_voltage=bias_voltage,
-        photon_energy=photon_energy,
-        gain_setting=gain_setting,
-        acquisition_rate=acq_rate,
+        memory_cells=MEM_CELLS,
+        bias_voltage=BIAS_VOLTAGE,
+        photon_energy=PHOTON_ENERGY,
+        gain_setting=GAIN_SETTING,
+        acquisition_rate=ACQ_RATE,
         module_idx=0,
         only_dark=False,
     )
@@ -129,14 +369,16 @@ def test_initialize_from_db():
     }
 
     dark_const_time_dict = agipd_corr.initialize_from_db(
-        karabo_id="SPB_DET_AGIPD1M-1",
+        karabo_id=AGIPD_KARABO_ID,
         karabo_da="AGIPD00",
-        cal_db_interface=cal_db_interface,
+        cal_db_interface=CAL_DB_INTERFACE,
         creation_time=creation_time,
-        memory_cells=mem_cells, bias_voltage=bias_voltage,
-        photon_energy=photon_energy, gain_setting=gain_setting,
-        acquisition_rate=acq_rate, module_idx=0,
-        only_dark=False)
+        memory_cells=MEM_CELLS, bias_voltage=BIAS_VOLTAGE,
+        photon_energy=PHOTON_ENERGY, gain_setting=GAIN_SETTING,
+        acquisition_rate=ACQ_RATE, module_idx=0,
+        only_dark=False,
+    )
+
     # A retrieved constant has a value of datetime creation_time
     assert isinstance(dark_const_time_dict["Offset"], datetime)
     assert list(dark_const_time_dict.keys()) == [
-- 
GitLab