import numpy as np
import pytest

from ..knife_edge import erfc, knife_edge_base, prepare_arrays, range_mask


def test_range_mask():
    """Check ``range_mask`` on a small sorted array.

    Covers the following behaviours:
    - Bounds are inclusive.
    - A bound lying between samples excludes the neighbours beyond it.
    - An omitted bound leaves that side open.
    - A bound that excludes every element raises ``ValueError``.
    - NaN entries are never selected.
    """
    arr = np.array([1, 2, 3, 4, 5])

    # Bounds coincide exactly with array values (inclusive on both ends)
    mask = range_mask(arr, minimum=2, maximum=4)
    np.testing.assert_array_equal(mask, [False, True, True, True, False])

    # Range extends slightly past the closest values
    mask = range_mask(arr, minimum=1.75, maximum=4.25)
    np.testing.assert_array_equal(mask, [False, True, True, True, False])

    # Range stops short of the closest values
    mask = range_mask(arr, minimum=2.25, maximum=3.75)
    np.testing.assert_array_equal(mask, [False, False, True, False, False])

    # Bounds lie exactly halfway between neighbouring values
    mask = range_mask(arr, minimum=2.5, maximum=4.5)
    np.testing.assert_array_equal(mask, [False, False, True, True, False])

    # Out of bounds, valid minimum (below every element)
    mask = range_mask(arr, minimum=0)
    np.testing.assert_array_equal(mask, [True, True, True, True, True])

    # Out of bounds, invalid minimum (above every element)
    with pytest.raises(ValueError):
        range_mask(arr, minimum=6)

    # Out of bounds, valid maximum (above every element)
    mask = range_mask(arr, maximum=6)
    np.testing.assert_array_equal(mask, [True, True, True, True, True])

    # Out of bounds, invalid maximum (below every element)
    with pytest.raises(ValueError):
        range_mask(arr, maximum=0)

    # NaN entries are never selected
    arr = np.array([1, np.nan, 3, np.nan, 5])
    mask = range_mask(arr, minimum=3)
    np.testing.assert_array_equal(mask, [False, False, True, False, True])


def test_prepare_arrays_nans():
    """Check that ``prepare_arrays`` drops samples containing NaNs.

    A NaN in the signal removes that single sample. A NaN in the motor
    removes every pulse recorded for that train.
    """
    n_trains, n_pulses = 5, 10
    n_total = n_trains * n_pulses
    motor = np.arange(n_trains)
    signal = np.random.randint(100, size=(n_trains, n_pulses))

    # Baseline: everything finite, so nothing is dropped
    pos, inten = prepare_arrays(motor, signal)
    assert pos.shape == (n_total,)
    assert inten.shape == (n_total,)

    # 20 NaNs scattered through the signal -> 20 samples dropped
    nan_signal = _with_values(signal, value=np.nan, num=20)
    pos, inten = prepare_arrays(motor, nan_signal)
    expected_shape = (n_total - 20,)
    for out in (pos, inten):
        assert out.shape == expected_shape
        assert np.isfinite(out).all()

    # 3 NaN motor positions -> all pulses of those 3 trains dropped
    nan_motor = _with_values(motor, value=np.nan, num=3)
    pos, inten = prepare_arrays(nan_motor, signal)
    expected_shape = ((n_trains - 3) * n_pulses,)
    for out in (pos, inten):
        assert out.shape == expected_shape
        assert np.isfinite(out).all()


def test_prepare_arrays_size():
    """Check that output length follows the signal's dimensionality.

    A 2D ``(trains, pulses)`` signal yields ``trains * pulses`` samples.
    A 1D per-train signal yields one sample per train.
    """
    n_trains, n_pulses = 5, 10
    motor = np.arange(n_trains)
    signal = np.random.randint(100, size=(n_trains, n_pulses))

    cases = [
        (signal, (n_trains * n_pulses,)),  # 2D signal
        (signal[:, 0], (n_trains,)),       # 1D signal (first pulse only)
    ]
    for sig, expected_shape in cases:
        pos, inten = prepare_arrays(motor, sig)
        assert pos.shape == expected_shape
        assert inten.shape == expected_shape


def test_prepare_arrays_range():
    """Check that ``xRange`` restricts the returned positions.

    With ``xRange=[2, 4]`` only positions 2, 3 and 4 survive, and all pulses
    of those trains are kept. A reversed range or a zero-width range raises
    ``ValueError``.
    """
    n_trains, n_pulses = 5, 10
    motor = np.arange(n_trains)
    signal = np.random.randint(100, size=(n_trains, n_pulses))

    # Valid range, inside the motor bounds
    pos, inten = prepare_arrays(motor, signal, xRange=[2, 4])
    lo, hi = pos.min(), pos.max()
    assert (lo, hi) == (2, 4)
    kept = np.unique(pos)
    np.testing.assert_array_equal(kept, [2, 3, 4])
    assert inten.shape == (kept.size * n_pulses,)

    # Reversed and zero-width ranges are rejected
    for bad_range in ([5, 3], [3, 3]):
        with pytest.raises(ValueError):
            prepare_arrays(motor, signal, xRange=bad_range)


def test_knife_edge_base():
    """Fit a noisy ``erfc`` profile and its mirror image with ``knife_edge_base``.

    The fitted parameters must match the generating parameters ``p0`` within
    ``atol=0.1``. For the flipped profile, the second parameter is expected
    as its absolute value.
    """
    p0 = [0, -1.5, 1, 0]
    x = np.linspace(-3, 3)
    y = erfc(x, *p0)
    # Local seeded generator: a fit that lands just outside the tolerance
    # fails reproducibly instead of intermittently.
    rng = np.random.default_rng(0)
    noise = y * rng.normal(0, .02, y.shape)  # 2% multiplicative noise
    eff_y = y + noise

    # Test noisy data
    popt, _ = knife_edge_base(x, eff_y)
    np.testing.assert_allclose(p0, popt, atol=1e-1)

    # Test flipped data: the fit reports the absolute value of p0[1].
    # Build a separate expectation rather than mutating p0 in place.
    flipped_p0 = list(p0)
    flipped_p0[1] = abs(flipped_p0[1])
    popt, _ = knife_edge_base(x, eff_y[::-1])
    np.testing.assert_allclose(flipped_p0, popt, atol=1e-1)


def _with_values(array, value, num=5):
    copy = array.astype(np.float)
    copy.ravel()[np.random.choice(copy.size, num, replace=False)] = value
    return copy