import pytest
import numpy as np

from cal_tools.agipdutils import cast_array_inplace


@pytest.mark.parametrize(
    'dtype_str', ['f8', 'f4', 'f2', 'i4', 'i2', 'i1', 'u4', 'u2', 'u1'])
def test_downcast_array_inplace(dtype_str):
    """Test downcasting an array in-place.

    A float64 array is cast to each parametrized dtype (all of equal or
    smaller item size than float64). The result is compared against a
    regular copying ``astype`` of an untouched reference copy. It must
    share memory with the input and be a C-contiguous, aligned array
    that does not own its data.
    """

    dtype = np.dtype(dtype_str)

    # Seeded generator so that any failure (e.g. a float16 rounding
    # mismatch) can be reproduced exactly. Values lie in [0, 100), which
    # fits every parametrized dtype, including i1/u1, without overflow.
    rng = np.random.default_rng(seed=42)
    ref_data = rng.random((2, 3, 4)) * 100

    # Cast a copy: the cast result shares memory with the array passed in,
    # so ref_data must stay untouched to serve as the reference.
    orig_data = ref_data.copy()
    cast_data = cast_array_inplace(orig_data, dtype)

    np.testing.assert_allclose(cast_data, ref_data.astype(dtype))
    # The result must reuse the input buffer rather than allocate anew.
    assert np.may_share_memory(orig_data, cast_data)
    assert cast_data.dtype == dtype
    assert cast_data.flags.c_contiguous
    assert cast_data.flags.aligned
    # It refers to orig_data's buffer, so it must not own its memory.
    assert not cast_data.flags.owndata


def test_upcast_array_inplace():
    """Test whether upcasting an array in-place fails.

    A float32 buffer holds fewer bytes per item than float64, so an
    in-place cast to float64 is expected to be rejected.
    """

    data = np.random.rand(4, 5, 6).astype(np.float32)

    # Guard the premise: the target dtype really is larger than the input.
    assert np.dtype(np.float64).itemsize > data.dtype.itemsize

    # Only the call under test sits inside pytest.raises, so an error while
    # building the input cannot make this test pass by accident.
    # NOTE(review): narrow Exception to the specific type raised by
    # cast_array_inplace once confirmed.
    with pytest.raises(Exception):
        cast_array_inplace(data, np.float64)


def test_noncontiguous_cast_inplace():
    """Test whether casting a non-contiguous array in-place fails."""

    # Transposing a C-ordered 3-D array yields a view that is not
    # C-contiguous.
    data = np.random.rand(4, 5, 6).T

    # Guard the premise: the input must actually be non-C-contiguous.
    assert not data.flags.c_contiguous

    # Only the call under test sits inside pytest.raises, so an error while
    # building the input cannot make this test pass by accident.
    # NOTE(review): the transpose is also a view that does not own its
    # data; if cast_array_inplace rejects such arrays as well, this test
    # may pass for that reason too. Consider narrowing Exception to the
    # specific type/message raised for non-contiguous input.
    with pytest.raises(Exception):
        cast_array_inplace(data, np.int32)