From eb35f3552a710db2f9bf9ae2d95004df7d09e2c8 Mon Sep 17 00:00:00 2001
From: Cammille Carinan <cammille.carinan@xfel.eu>
Date: Fri, 21 May 2021 18:04:04 +0200
Subject: [PATCH] Support unknown signal dims

---
 src/toolbox_scs/base/knife_edge.py            |  3 +--
 src/toolbox_scs/base/tests/test_knife_edge.py | 19 ++++++++++++++++++-
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/toolbox_scs/base/knife_edge.py b/src/toolbox_scs/base/knife_edge.py
index 82e3aee..0b64fee 100644
--- a/src/toolbox_scs/base/knife_edge.py
+++ b/src/toolbox_scs/base/knife_edge.py
@@ -43,9 +43,8 @@ def prepare_arrays(positions: np.ndarray, intensities: np.ndarray,
         intensities = intensities[slice_]
 
     # Convert both arrays to 1D of the same size
-    n_pulses = intensities.shape[1]
-    positions = np.repeat(positions, n_pulses)
     intensities = intensities.flatten()
+    positions = np.repeat(positions, len(intensities) // len(positions))
     assert positions.shape == intensities.shape
 
     # Clean both arrays by only getting finite values
diff --git a/src/toolbox_scs/base/tests/test_knife_edge.py b/src/toolbox_scs/base/tests/test_knife_edge.py
index 4699597..a1cd9ae 100644
--- a/src/toolbox_scs/base/tests/test_knife_edge.py
+++ b/src/toolbox_scs/base/tests/test_knife_edge.py
@@ -45,7 +45,7 @@ def test_range_mask():
     np.testing.assert_array_equal(slice_, [False, False, True, False, True])
 
 
-def test_prepare_arrays():
+def test_prepare_arrays_nans():
     # Setup test values
     trains, pulses = 5, 10
     size = trains * pulses
@@ -74,6 +74,23 @@ def test_prepare_arrays():
     assert np.isfinite(intensities).all()
 
 
+def test_prepare_arrays_size():
+    trains, pulses = 5, 10
+    size = trains * pulses
+    motor = np.arange(trains)
+    signal = np.random.randint(100, size=(trains, pulses))
+
+    # Test finite motor and 2D signals
+    positions, intensities = prepare_arrays(motor, signal)
+    assert positions.shape == (size,)
+    assert intensities.shape == (size,)
+
+    # Test finite motor and 1D signals
+    positions, intensities = prepare_arrays(motor, signal.reshape(1, -1))
+    assert positions.shape == (size,)
+    assert intensities.shape == (size,)
+
+
 def with_values(array, value, num=5):
     copy = array.astype(np.float)
     copy.ravel()[np.random.choice(copy.size, num, replace=False)] = value
-- 
GitLab