diff --git a/src/calng/utils.py b/src/calng/utils.py index 54ed40fedbb706379ac46611d62dc4a49dc43878..344752374491c1a4cbaaed6f1f0f03c492f4069f 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -112,7 +112,8 @@ def transpose_order(axes_in, axes_out): def stacking_buffer_shape(array_shape, stack_num, axis=0): - """Figures out the shape you would need for np.stack""" + """Figures out the shape you would need for np.stack. Think of the axis in terms of + array after adding additional axis, i.e. the number of axes is len(aray_shape)+1.""" if axis > len(array_shape) or axis < -len(array_shape) - 1: # complain when np.stack would raise np.AxisError( @@ -126,7 +127,7 @@ def stacking_buffer_shape(array_shape, stack_num, axis=0): def set_on_axis(array, vals, index, axis): """set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x""" - if axis >= len(array): + if axis >= array.ndim: raise IndexError( f"too many indices for array: array is {len(array.shape)}-dimensional, " f"but {axis+1} were indexed" diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py index ebe7d64061f47baf1f6f04b4db425e3d96962a49..d53a31eadc4818994516267f5b83f9d55be7f522 100644 --- a/src/tests/test_utils.py +++ b/src/tests/test_utils.py @@ -4,9 +4,56 @@ import time import timeit import numpy as np +import pytest from calng import utils +def test_stacking_buffer_shape(): + original_shape = (1, 2, 3) + assert utils.stacking_buffer_shape(original_shape, 4, 0) == (4, 1, 2, 3) + assert utils.stacking_buffer_shape(original_shape, 4, 2) == (1, 2, 4, 3) + assert utils.stacking_buffer_shape(original_shape, 4, 3) == (1, 2, 3, 4) + assert utils.stacking_buffer_shape(original_shape, 4, -1) == (1, 2, 3, 4) + assert utils.stacking_buffer_shape(original_shape, 4, -4) == (4, 1, 2, 3) + with pytest.raises(np.AxisError): + utils.stacking_buffer_shape(original_shape, 4, 4) + with pytest.raises(np.AxisError): + utils.stacking_buffer_shape(original_shape, 4, -5) + + +def test_set_on_axis(): + A = np.array([[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]]) + + manual = A.copy() + manual[0] = 0 + fun = A.copy() + utils.set_on_axis(fun, 0, 0, 0) + assert np.array_equal(manual, fun) + + manual = A.copy() + manual[1] = np.arange(5) + fun = A.copy() + utils.set_on_axis(fun, np.arange(5), 1, 0) + assert np.array_equal(manual, fun) + + manual = A.copy() + manual[:, 1] = 0 + fun = A.copy() + utils.set_on_axis(fun, 0, 1, 1) + assert np.array_equal(manual, fun) + + with pytest.raises(IndexError): + utils.set_on_axis(fun, ..., ..., 3) + + # case triggering obvious bug I had made + A = np.array([[[1, 2, 3], [4, 5, 6]]]) + manual = A.copy() + manual[:, 1] = 0 + fun = A.copy() + utils.set_on_axis(fun, 0, 1, 1) + assert np.array_equal(manual, fun) + + def test_get_c_type(): assert utils.np_dtype_to_c_type(np.float16) == "half" assert utils.np_dtype_to_c_type(np.float32) == "float"