import numpy as np

from cal_tools.agipdutils_ff import get_mask, set_par_limits


def test_get_mask():
    """Check that ``get_mask`` yields 0 for a fixed fit summary and limits.

    The summary below is a hard-coded fit result for four peaks (g0..g3).
    NOTE(review): presumably a result of 0 means that no bad-pixel flag was
    raised for this fit -- confirm against ``get_mask``.
    """
    # Selection limits handed positionally to get_mask.
    peak_limits = [-30, 30]
    d0_limits = [10, 80]
    chi2_limits = [0, 3.0]
    peak_width_limits = np.array([[0.9, 1.55], [0.95, 1.65]])

    summary = {
        'chi2_ndof': 1.674524751845516,

        # Peak g0: value, error, limit and fix flag for n / mean / sigma.
        'g0n': 6031.641198873036,
        'error_g0n': 94.63055028459667,
        'limit_g0n': np.array([0.0, None]),
        'fix_g0n': False,
        'g0mean': -13.711814669099589,
        'error_g0mean': 0.2532017427306297,
        'limit_g0mean': np.array([-30, 30]),
        'fix_g0mean': False,
        'g0sigma': 13.478502058651568,
        'error_g0sigma': 0.2588135637661919,
        'limit_g0sigma': np.array([0, 30]),
        'fix_g0sigma': False,

        # Peak g1 (no limit on the mean).
        'g1n': 4337.126861254491,
        'error_g1n': 69.764180118274,
        'limit_g1n': np.array([0, None]),
        'fix_g1n': False,
        'g1mean': 53.90265411499657,
        'error_g1mean': 0.27585684670864746,
        'limit_g1mean': None,
        'fix_g1mean': False,
        'g1sigma': 15.687448834904817,
        'error_g1sigma': 0.2951166525483524,
        'limit_g1sigma': np.array([0, 35]),
        'fix_g1sigma': False,

        # Peak g2 (no limit on the mean).
        'g2n': 1542.531700635003,
        'error_g2n': 43.20145030604567,
        'limit_g2n': np.array([0, None]),
        'fix_g2n': False,
        'g2mean': 120.98535387591575,
        'error_g2mean': 0.509566354942716,
        'limit_g2mean': None,
        'fix_g2mean': False,
        'g2sigma': 15.550452880533143,
        'error_g2sigma': 0.5003254358001863,
        'limit_g2sigma': np.array([0, 40]),
        'fix_g2sigma': False,

        # Peak g3 (no limit on the mean).
        'g3n': 1261189.2282171287,
        'error_g3n': 1261190.2282163086,
        'limit_g3n': np.array([0, None]),
        'fix_g3n': False,
        'g3mean': 348.68766379647343,
        'error_g3mean': 17.23872285713375,
        'limit_g3mean': None,
        'fix_g3mean': False,
        'g3sigma': 44.83987230934497,
        'error_g3sigma': 30.956164693249242,
        'limit_g3sigma': np.array([0, 45]),
        'fix_g3sigma': False,

        # Remaining scalar and boolean entries of the summary.
        'fval': 336.5794751209487,
        'edm': 0.00011660826330754263,
        'tolerance': 0.1,
        'nfcn': 4620,
        'ncalls': 4620,
        'up': 1.0,
        'is_valid': True,
        'has_valid_parameters': True,
        'has_accurate_covar': True,
        'has_posdef_covar': True,
        'has_made_posdef_covar': False,
        'hesse_failed': False,
        'has_covariance': True,
        'is_above_max_edm': False,
        'has_reached_call_limit': False,
    }

    result = get_mask(
        summary,
        peak_limits,
        d0_limits,
        chi2_limits,
        peak_width_limits,
    )
    assert result == 0


def test_set_par_limits():
    """Check that ``set_par_limits`` adds the expected ``limit_*`` entries.

    ``parameters`` holds start values for four peaks (g0..g3) and is updated
    in place by the call. Afterwards it must contain the start values
    unchanged plus, for every peak ``i``, the entries ``limit_g<i>n``,
    ``limit_g<i>mean`` and ``limit_g<i>sigma`` equal to row ``i`` of
    ``peak_norm_range``, ``peak_range`` and ``peak_width_range``.
    """
    peak_range = np.array([[-30, 30],
                           [35, 70],
                           [95, 135],
                           [145, 220]])

    # The ``None`` upper bounds make this an object-dtype array.
    peak_norm_range = np.array([[0.0, None],
                                [0, None],
                                [0, None],
                                [0, None]])
    peak_width_range = np.array([[0, 30],
                                 [0, 35],
                                 [0, 40],
                                 [0, 45]])

    parameters = {
        'g0sigma': 9.620186459204016,
        'g0n': 5659.0,
        'g0mean': -3.224686340342817,
        'g1sigma': 8.149415371586683,
        'g1n': 3612.0,
        'g1mean': 54.6281838316722,
        'g2sigma': 9.830124777667839,
        'g2n': 1442.0,
        'g2mean': 114.92510402219139,
        'g3sigma': 15.336595220228498,
        'g3n': 474.0,
        'g3mean': 167.0295358649789,
    }

    # Snapshot the start values *before* the call so that any accidental
    # modification of them by set_par_limits is detected below.
    expected = {
        **parameters,
        'limit_g0n': np.array([0.0, None]),
        'limit_g0mean': np.array([-30, 30]),
        'limit_g0sigma': np.array([0, 30]),
        'limit_g1n': np.array([0, None]),
        'limit_g1mean': np.array([35, 70]),
        'limit_g1sigma': np.array([0, 35]),
        'limit_g2n': np.array([0, None]),
        'limit_g2mean': np.array([95, 135]),
        'limit_g2sigma': np.array([0, 40]),
        'limit_g3n': np.array([0, None]),
        'limit_g3mean': np.array([145, 220]),
        'limit_g3sigma': np.array([0, 45]),
    }

    set_par_limits(parameters, peak_range, peak_norm_range, peak_width_range)

    assert parameters.keys() == expected.keys()
    for key, want in expected.items():
        got = parameters[key]
        if isinstance(want, np.ndarray):
            # array_equal also compares shapes, so a result that merely
            # broadcasts against the expectation (e.g. a scalar) is rejected
            # instead of silently passing as with ``np.all(got == want)``.
            assert np.array_equal(got, want), (
                f"{key}: got {got!r}, expected {want!r}")
        else:
            assert got == want, f"{key}: got {got!r}, expected {want!r}"