import numpy as np
import pytest

from cal_tools.gotthard2.gotthard2algs import convert_to_10bit, correct_train


def test_convert_to_10bit():
    """Check the LUT-based conversion of 12bit Gotthard2 data to 10bit.

    The LUT holds two tables per stripe: the first maps onto the values
    0..2047 and the second onto 2048..4095 (each sequence repeated twice to
    fill 4096 entries). The expected output alternates between the two
    tables from one pulse to the next.
    """
    n_stripes = 10
    n_pulses = 500
    half = 4096 // 2

    # LUT of shape (n_stripes, 2, 4096), identical for every stripe.
    first_table = np.tile(np.arange(half, dtype=np.uint16), 2)
    second_table = np.tile(np.arange(half, 4096, dtype=np.uint16), 2)
    lut = np.tile(np.stack([first_table, second_table]), (n_stripes, 1, 1))

    # Every pulse carries the raw values 0..n_stripes-1.
    raw_data = np.tile(
        np.arange(n_stripes, dtype=np.uint16), (n_pulses, 1))
    raw_data10bit = np.zeros((n_pulses, n_stripes), dtype=np.float32)

    # Even pulses are expected to read the first table, odd pulses the
    # second one.
    pulse_pair = np.stack([
        np.arange(n_stripes),
        np.arange(half, half + n_stripes),
    ])
    expected = np.tile(pulse_pair, (n_pulses // 2, 1)).astype(np.float32)

    convert_to_10bit(raw_data, lut, raw_data10bit)
    assert np.allclose(expected, raw_data10bit)


@pytest.mark.parametrize("gain_corr", [True, False])
def test_correct_train(gain_corr):
    """Test gotthard2 correction function.

    ``correct_train`` is called on a copy of random data and is expected to
    have modified that copy in place. The result is compared against a NumPy
    reference computed from the same inputs: the offset selected by each
    element's gain value is subtracted and, when ``gain_corr`` is true, the
    result is divided by the matching gain-map entry. Even rows use index 0
    and odd rows index 1 of the constants' second axis.

    Parameters
    ----------
    gain_corr : bool
        Forwarded as ``apply_gain``; also controls whether the reference
        applies the gain-map division.
    """
    raw_d = np.random.randn(2700, 1280).astype(np.float32)
    gain = np.random.choice([0, 1, 2], size=(2700, 1280)).astype(np.uint8)

    offset = np.random.randn(1280, 2, 3).astype(np.float32)
    gain_map = np.random.randn(1280, 2, 3).astype(np.float32)
    # All-zero bad pixel map with the same shape as the other constants.
    bad_pixels = np.zeros(offset.shape, dtype=np.uint32)

    test_data = raw_d.copy()
    mask = np.zeros(test_data.shape, dtype=np.uint32)

    correct_train(
        data=test_data, mask=mask, gain=gain, offset_map=offset,
        gain_map=gain_map, bpix_map=bad_pixels, apply_gain=gain_corr,
    )

    ref_data = raw_d.copy()

    for cell in (0, 1):
        rows = slice(cell, None, 2)
        # moveaxis turns the (1280, 3) constants into (3, 1280), i.e. one
        # row of per-pixel values per gain value, so np.choose can pick the
        # constant for each element from its gain value.
        ref_data[rows] -= np.choose(
            gain[rows], np.moveaxis(offset[:, cell, :], 1, 0))
        if gain_corr:
            ref_data[rows] /= np.choose(
                gain[rows], np.moveaxis(gain_map[:, cell, :], 1, 0))

    # Compare against the independent reference; comparing test_data with
    # itself would pass regardless of what correct_train does.
    assert np.allclose(test_data, ref_data)