Skip to content
Snippets Groups Projects
Commit 48c2779b authored by David Hammer's avatar David Hammer
Browse files

Fixing bug in set_on_axis, adding a few tests

parent 99de26dd
No related branches found
No related tags found
1 merge request!12Snapshot: field test deployed version as of end of run 202201
......@@ -112,7 +112,8 @@ def transpose_order(axes_in, axes_out):
def stacking_buffer_shape(array_shape, stack_num, axis=0):
"""Figures out the shape you would need for np.stack"""
"""Figures out the shape you would need for np.stack. Think of the axis in terms of
array after adding additional axis, i.e. the number of axes is len(aray_shape)+1."""
if axis > len(array_shape) or axis < -len(array_shape) - 1:
# complain when np.stack would
raise np.AxisError(
......@@ -126,7 +127,7 @@ def stacking_buffer_shape(array_shape, stack_num, axis=0):
def set_on_axis(array, vals, index, axis):
"""set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x"""
if axis >= len(array):
if axis >= array.ndim:
raise IndexError(
f"too many indices for array: array is {len(array.shape)}-dimensional, "
f"but {axis+1} were indexed"
......
......@@ -4,9 +4,56 @@ import time
import timeit
import numpy as np
import pytest
from calng import utils
def test_stacking_buffer_shape():
original_shape = (1, 2, 3)
assert utils.stacking_buffer_shape(original_shape, 4, 0) == (4, 1, 2, 3)
assert utils.stacking_buffer_shape(original_shape, 4, 2) == (1, 2, 4, 3)
assert utils.stacking_buffer_shape(original_shape, 4, 3) == (1, 2, 3, 4)
assert utils.stacking_buffer_shape(original_shape, 4, -1) == (1, 2, 3, 4)
assert utils.stacking_buffer_shape(original_shape, 4, -4) == (4, 1, 2, 3)
with pytest.raises(np.AxisError):
utils.stacking_buffer_shape(original_shape, 4, 4)
with pytest.raises(np.AxisError):
utils.stacking_buffer_shape(original_shape, 4, -5)
def test_set_on_axis():
A = np.array([[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]])
manual = A.copy()
manual[0] = 0
fun = A.copy()
utils.set_on_axis(fun, 0, 0, 0)
assert np.array_equal(manual, fun)
manual = A.copy()
manual[1] = np.arange(5)
fun = A.copy()
utils.set_on_axis(fun, np.arange(5), 1, 0)
assert np.array_equal(manual, fun)
manual = A.copy()
manual[:, 1] = 0
fun = A.copy()
utils.set_on_axis(fun, 0, 1, 1)
assert np.array_equal(manual, fun)
with pytest.raises(IndexError):
utils.set_on_axis(fun, ..., ..., 3)
# case triggering obvious bug I had made
A = np.array([[[1, 2, 3], [4, 5, 6]]])
manual = A.copy()
manual[:, 1] = 0
fun = A.copy()
utils.set_on_axis(fun, 0, 1, 1)
assert np.array_equal(manual, fun)
def test_get_c_type():
assert utils.np_dtype_to_c_type(np.float16) == "half"
assert utils.np_dtype_to_c_type(np.float32) == "float"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment