From 623c75e262f65eb4dd58faccd7244c2ef089b621 Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Thu, 2 Jun 2022 10:40:13 +0200
Subject: [PATCH] add tests for gotthard2alg

---
 .gitlab-ci.yml                                |  2 +-
 .../{test_agipdalgs.py => test_cythonalgs.py} |  1 +
 tests/test_gotthard2algs.py                   | 61 +++++++++++++++++++
 3 files changed, 63 insertions(+), 1 deletion(-)
 rename tests/{test_agipdalgs.py => test_cythonalgs.py} (56%)
 create mode 100644 tests/test_gotthard2algs.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index bb8dd05a0..c4d3d5ba6 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -57,4 +57,4 @@ cython-editable-install-test:
   <<: *before_script
   script:
     - python3 -m pip install -e ".[test]"
-    - python3 -m pytest --color yes --verbose ./tests/test_agipdalgs.py
+    - python3 -m pytest --color yes --verbose ./tests/test_cythonalgs.py 
diff --git a/tests/test_agipdalgs.py b/tests/test_cythonalgs.py
similarity index 56%
rename from tests/test_agipdalgs.py
rename to tests/test_cythonalgs.py
index 6296eb049..a20d23703 100644
--- a/tests/test_agipdalgs.py
+++ b/tests/test_cythonalgs.py
@@ -1,2 +1,3 @@
 def test_import():
     from cal_tools import agipdalgs  # noqa
+    from cal_tools import gotthard2algs  # noqa
diff --git a/tests/test_gotthard2algs.py b/tests/test_gotthard2algs.py
new file mode 100644
index 000000000..bb58accff
--- /dev/null
+++ b/tests/test_gotthard2algs.py
@@ -0,0 +1,61 @@
+import numpy as np
+import pytest
+
+from cal_tools.gotthard2algs import convert_to_10bit, correct_train
+
+
+def test_convert_to_10bit():
+
+    n_stripes = 10
+    n_pulses = 500
+
+    # Define LUT, raw data 12 bit, and raw data 10bit array.
+    lut = np.array(
+        [[list(range(4096//2))*2, list(range(4096//2, 4096))*2]] * n_stripes,
+        dtype=np.uint16
+    )
+    raw_data = np.array([list(range(n_stripes))]*n_pulses, dtype=np.uint16)
+    raw_data10bit = np.zeros(raw_data.shape, dtype=np.float32)
+
+    result = np.concatenate(
+        [
+            np.array(x)[:, None] for x in [
+                list(range(n_stripes)),
+                list(range(2048, 2048+n_stripes))
+            ] * (n_pulses//2)], axis=1, dtype=np.float32,
+    ).T
+
+    convert_to_10bit(raw_data, lut.astype(np.uint16), raw_data10bit)
+    assert np.allclose(result, raw_data10bit)
+
+
+@pytest.mark.parametrize("gain_corr", [True, False])
+def test_correct_train(gain_corr):
+
+    raw_d = np.random.randn(2700, 1280).astype(np.float32)
+    gain = np.random.choice([0, 1, 2], size=(2700, 1280)).astype(np.uint8)
+
+    offset = np.random.randn(1280, 2, 3).astype(np.float32)
+    relgain = np.random.randn(1280, 2, 3).astype(np.float32)
+    badpixles = np.zeros_like(offset).astype(np.uint32).astype(np.uint32)
+
+    test_data = raw_d.copy()
+    mask = np.zeros_like(test_data).astype(np.uint32)
+
+    correct_train(
+        test_data, mask, gain, offset, relgain, badpixles, gain_corr)
+
+    ref_data = raw_d.copy()
+
+    ref_data[::2, :] -= np.choose(
+        gain[::2, :], np.moveaxis(offset[:, 0, :], 1, 0))
+    ref_data[1::2, :] -= np.choose(
+        gain[1::2, :], np.moveaxis(offset[:, 1, :], 1, 0))
+
+    if gain_corr:
+        ref_data[::2, :] /= np.choose(
+            gain[::2, :], np.moveaxis(relgain[:, 0, :], 1, 0))
+        ref_data[1::2, :] /= np.choose(
+            gain[1::2, :], np.moveaxis(relgain[:, 1, :], 1, 0))
+
+    assert np.allclose(test_data, test_data)
-- 
GitLab