from datetime import datetime
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pytest
import yaml
import zmq
from extra_data import open_run
from iCalibrationDB import Conditions, ConstantMetaData, Constants

from cal_tools.plotting import show_processed_modules
from cal_tools.tools import (
    creation_date_file_metadata,
    creation_date_train_timestamp,
    get_dir_creation_date,
    get_from_db,
    get_pdu_from_db,
    map_seq_files,
    module_index_to_qm,
    recursive_update,
    send_to_db,
    write_constants_fragment,
    reorder_axes,
)

# AGIPD operating conditions, used by the `_agipd_const_cond` fixture to
# build the dark-offset condition.
ACQ_RATE = 1.1
BIAS_VOLTAGE = 300
GAIN_SETTING = 0
MEM_CELLS = 352

# Detector identifier used when querying and injecting constants.
AGIPD_KARABO_ID = "SPB_DET_AGIPD1M-1"
# Deliberately invalid module name, passed as `db_module` in the send_to_db
# tests so that a failing mock cannot inject constants for a real module.
WRONG_AGIPD_MODULE = "AGIPD_**"

# Calibration database interface, plus a deliberately wrong address
# (port 0000) used by the tests that expect zmq.error.Again.
CAL_DB_INTERFACE = "tcp://max-exfl-cal002:8020"
WRONG_CAL_DB_INTERFACE = "tcp://max-exfl-cal002:0000"

# Proposal of the CALLAB test runs opened with extra_data.open_run.
PROPOSAL = 900113


@pytest.fixture
def _agipd_const_cond():
    """Provide an AGIPD dark Offset constant and its operating condition.

    Returns:
        tuple: ``(constant, condition)``, where ``constant`` is a fresh
        ``Constants.AGIPD.Offset()`` and ``condition`` is a
        ``Conditions.Dark.AGIPD`` built from the module-level
        MEM_CELLS, BIAS_VOLTAGE, ACQ_RATE and GAIN_SETTING values.
    """
    offset_constant = Constants.AGIPD.Offset()
    dark_condition = Conditions.Dark.AGIPD(
        memory_cells=MEM_CELLS,
        bias_voltage=BIAS_VOLTAGE,
        acquisition_rate=ACQ_RATE,
        gain_setting=GAIN_SETTING,
    )
    return offset_constant, dark_condition


def test_show_processed_modules():
    """show_processed_modules rejects an unknown detector instance.

    'LDP' is not a supported detector name, so a ValueError is expected
    and its message is checked to mention the offending name.
    """
    mnames = ['Q1M1']

    with pytest.raises(ValueError) as err:
        show_processed_modules('LDP', mnames=mnames,
                               constants={}, mode='processed')
    # The check must come after the `with` block: statements placed after
    # the raising call inside it are never executed. `err.value` is the
    # exception instance itself, so it is converted with str(), not called.
    assert 'LDP' in str(err.value)


@pytest.mark.requires_gpfs
@pytest.mark.parametrize(
    "karabo_da,sequences,expected_len",
    [
        ("AGIPD00", [-1], 3),
        ("AGIPD00", [0, 1], 2),
        ("EPIX01", [-1], 0),
        ("AGIPD00", [117], 0),
    ],
)
def test_map_seq_files(karabo_da, sequences, expected_len):
    """Check map_seq_files against the sequence files of a known run.

    ``sequences=[-1]`` requests all sequence files. An aggregator without
    files or a non-existent sequence must give an empty list and a zero
    count.
    """
    run_folder = Path('/gpfs/exfel/exp/CALLAB/202031/p900113/raw/r9983')

    # Expected sequence numbers are computed in their own variable, so that
    # `sequences` reaches map_seq_files unchanged. Rebinding it would
    # replace [-1] with an explicit range and skip the "all sequences" path.
    if not expected_len:
        expected_seqs = []
    elif sequences == [-1]:
        expected_seqs = range(expected_len)
    else:
        expected_seqs = sequences

    expected_dict = {
        karabo_da: [
            run_folder / f"RAW-R9983-{karabo_da}-S{s:05d}.h5"
            for s in expected_seqs
        ]
    }

    assert map_seq_files(run_folder, [karabo_da], sequences) == (
        expected_dict, expected_len)


@pytest.mark.requires_gpfs
def test_dir_creation_date():
    """get_dir_creation_date returns the creation date of a run directory.

    This test is based on not connecting to MDC, so
    `creation_date_metadata_client()` cannot be used.
    """
    raw_dir = '/gpfs/exfel/exp/CALLAB/202031/p900113/raw'
    expectations = [
        (9983, '2020-09-23 13:30:45.821262+00:00'),
        # Run 9999 predates the addition of creation_time in metadata.
        (9999, '2019-12-16 07:52:25.196603+00:00'),
    ]

    for run_number, expected in expectations:
        created = get_dir_creation_date(raw_dir, run_number)
        assert isinstance(created, datetime)
        assert str(created) == expected


@pytest.mark.requires_gpfs
def test_raise_dir_creation_date():
    """A missing run directory raises FileNotFoundError naming its path."""
    raw_dir = '/gpfs/exfel/exp/CALLAB/202031/p900113/raw'
    missing_run_dir = Path(raw_dir) / 'r0004'

    with pytest.raises(FileNotFoundError) as exc_info:
        get_dir_creation_date(raw_dir, 4)

    # The offending path is carried as the second exception argument.
    assert exc_info.value.args[1] == missing_run_dir


@pytest.mark.requires_gpfs
def test_creation_date_file_metadata():
    """Read a run's creation date from its files' METADATA section."""
    raw_dir = Path('/gpfs/exfel/exp/CALLAB/202031/p900113/raw')

    created = creation_date_file_metadata(raw_dir / 'r9983')
    assert isinstance(created, datetime)
    assert str(created) == '2020-09-23 13:30:50+00:00'

    # Old run without METADATA/CreationDate: no date is returned.
    assert creation_date_file_metadata(raw_dir / 'r9999') is None


@pytest.mark.requires_gpfs
def test_creation_date_train_timestamp():
    """Derive a run's creation date from its train timestamps."""
    recent_run = open_run(PROPOSAL, 9983)
    created = creation_date_train_timestamp(recent_run)
    assert isinstance(created, datetime)
    assert str(created) == '2020-09-23 13:30:45.821262+00:00'

    # Old run without trainId timestamps: no date is returned.
    old_run = open_run(PROPOSAL, 9999)
    assert creation_date_train_timestamp(old_run) is None


def _call_get_from_db(
    constant,
    condition,
    karabo_id,
    karabo_da,
    load_data=True,
    cal_db_interface=CAL_DB_INTERFACE,
    creation_time=None,
    doraise=True,
    timeout=10000,
):
    """Call get_from_db with the settings shared by these tests.

    Always requests ``meta_only=True`` with no empty-constant fallback
    (``empty_constant=None``) and a single attempt (``ntries=1``). The
    remaining arguments are forwarded unchanged.

    Returns:
        tuple: ``(data, metadata)`` as returned by get_from_db.
    """
    fixed_options = dict(
        empty_constant=None,
        meta_only=True,
        ntries=1,
    )
    retrieved, meta = get_from_db(
        karabo_id=karabo_id,
        karabo_da=karabo_da,
        constant=constant,
        condition=condition,
        cal_db_interface=cal_db_interface,
        creation_time=creation_time,
        load_data=load_data,
        doraise=doraise,
        timeout=timeout,
        **fixed_options,
    )
    return retrieved, meta


# Sentinel meaning "no `data` argument given". `None` cannot be the default
# because tests pass data=None explicitly to exercise send_to_db's
# missing-data error path.
_DEFAULT_DATA = object()


def _call_send_to_db(
    constant,
    condition,
    db_module,
    data=_DEFAULT_DATA,
    cal_db_interface=CAL_DB_INTERFACE,
    report_path="",
    doraise=True,
    timeout=1000,
    ntries=1,
):
    """Attach `data` to `constant` and inject it with send_to_db.

    Args:
        constant: calibration constant object; its ``data`` attribute is
            overwritten with `data`.
        condition: operating condition sent along with the constant.
        db_module: module name forwarded to send_to_db as ``db_module``.
        data: constant data. Defaults to a fresh ``np.zeros((2, 2, 2))``
            on every call; pass ``None`` explicitly to send without data.
        cal_db_interface, report_path, doraise, timeout, ntries: forwarded
            unchanged to send_to_db.

    Returns:
        Whatever send_to_db returns (the tests expect ConstantMetaData).
    """
    if data is _DEFAULT_DATA:
        # Build a new array per call: a default of np.zeros(...) in the
        # signature is evaluated once and shared by every call.
        data = np.zeros((2, 2, 2))

    # TODO: create a known_constant for testing.
    constant.data = data
    metadata = send_to_db(
        karabo_id=AGIPD_KARABO_ID,
        db_module=db_module,
        constant=constant,
        condition=condition,
        file_loc="proposal, runs",
        report_path=report_path,
        cal_db_interface=cal_db_interface,
        creation_time=None,
        ntries=ntries,
        doraise=doraise,
        timeout=timeout,
    )
    return metadata


@pytest.mark.requires_caldb
@pytest.mark.requires_gpfs
def test_get_from_db_load_data(_agipd_const_cond):
    """Retrieve calibration constants with get_from_db, with and without
    loading the constant data, and with a missing karabo_id.
    """
    constant, condition = _agipd_const_cond

    def fetch(karabo_id, **extra):
        return _call_get_from_db(
            constant=constant, condition=condition,
            karabo_id=karabo_id, karabo_da="AGIPD00", **extra,
        )

    # Normal operation: data is loaded from the h5 file.
    data, md = fetch(AGIPD_KARABO_ID)
    assert type(data) is np.ndarray
    assert isinstance(md, ConstantMetaData)

    # karabo_id is None: neither data nor metadata is returned.
    data, md = fetch(None)
    assert data is None
    assert md is None

    # Metadata only: the constant data is not loaded.
    data, md = fetch(AGIPD_KARABO_ID, load_data=False)
    assert data is None
    assert isinstance(md, ConstantMetaData)


@pytest.mark.requires_caldb
@pytest.mark.requires_gpfs
def test_raise_get_from_db(_agipd_const_cond):
    """get_from_db raises for a wrong database address, a wrongly typed
    creation_time and missing constant file path metadata.
    """
    constant, condition = _agipd_const_cond
    common_kwargs = dict(
        constant=constant, condition=condition,
        karabo_id=AGIPD_KARABO_ID, karabo_da="AGIPD00",
    )

    # Unreachable calibration database address.
    with pytest.raises(zmq.error.Again) as excinfo:
        _call_get_from_db(
            cal_db_interface=WRONG_CAL_DB_INTERFACE, **common_kwargs)
    assert str(excinfo.value) == "Resource temporarily unavailable"

    # creation_time has the wrong type.
    with pytest.raises(TypeError):
        _call_get_from_db(
            creation_time="WRONG_CREATION_TIME", **common_kwargs)

    # Metadata retrieval yields no constant file path.
    with patch("iCalibrationDB.ConstantMetaData.retrieve", return_value=""):
        with pytest.raises(ValueError):
            _call_get_from_db(**common_kwargs)


def test_no_doraise_get_from_db(_agipd_const_cond):
    """With doraise=False, get_from_db on a wrong cal_db_interface
    returns no data instead of raising.
    """
    constant, condition = _agipd_const_cond

    outcome = _call_get_from_db(
        constant=constant,
        condition=condition,
        karabo_id=AGIPD_KARABO_ID,
        karabo_da="AGIPD00",
        cal_db_interface=WRONG_CAL_DB_INTERFACE,
        doraise=False,
    )
    retrieved_data = outcome[0]
    assert retrieved_data is None


@patch('iCalibrationDB.ConstantMetaData.send', return_value='')
def test_send_to_db_success(send, _agipd_const_cond):
    """send_to_db returns ConstantMetaData when the injection succeeds.

    `ConstantMetaData.send` is mocked, so no real injection is attempted.
    # TODO: Add a test calibration constant to the test physical module
    # to inject without mocking `send` method.
    """
    constant, condition = _agipd_const_cond

    # WRONG_AGIPD_MODULE is a safety net: if the mock ever fails, no
    # constant is injected for a real module.
    injected = _call_send_to_db(
        constant=constant,
        condition=condition,
        db_module=WRONG_AGIPD_MODULE,
    )
    assert isinstance(injected, ConstantMetaData)


@patch('iCalibrationDB.ConstantMetaData.send', return_value='')
def test_raise_send_to_db_mocked(send, _agipd_const_cond):
    """send_to_db raises for a wrongly typed report_path and for missing
    constant data, with `ConstantMetaData.send` mocked.
    """
    constant, condition = _agipd_const_cond

    # Each entry: expected exception, arguments overriding the defaults.
    bad_calls = [
        (TypeError, {"report_path": 2}),  # report_path has the wrong type
        (ValueError, {"data": None}),  # no constant data to inject
    ]
    for expected_error, overrides in bad_calls:
        # WRONG_AGIPD_MODULE avoids injection in case the mock fails.
        with pytest.raises(expected_error):
            _call_send_to_db(
                constant=constant,
                condition=condition,
                db_module=WRONG_AGIPD_MODULE,
                **overrides,
            )


def test_raise_send_to_db(_agipd_const_cond):
    """send_to_db raises zmq.error.Again for a wrong database address."""
    constant, condition = _agipd_const_cond

    with pytest.raises(zmq.error.Again) as exc_info:
        _call_send_to_db(
            constant=constant,
            condition=condition,
            db_module=WRONG_AGIPD_MODULE,
            cal_db_interface=WRONG_CAL_DB_INTERFACE,
        )

    message = str(exc_info.value)
    assert message == "Resource temporarily unavailable"


@pytest.mark.requires_caldb
def test_get_pdu_from_db(_agipd_const_cond):
    """get_pdu_from_db resolves physical detector units (PDUs) for a
    single karabo_da, a list of karabo_das, and "all" of a karabo_id.
    """
    constant, condition = _agipd_const_cond

    def query(karabo_da):
        return get_pdu_from_db(karabo_id="TEST_DET_CAL_CI-1",
                               karabo_da=karabo_da,
                               constant=constant,
                               condition=condition,
                               cal_db_interface=CAL_DB_INTERFACE,
                               snapshot_at="2021-05-06 00:20:10.00",
                               timeout=30000)

    pdu_1 = "CAL_PHYSICAL_DETECTOR_UNIT-1_TEST"
    pdu_2 = "CAL_PHYSICAL_DETECTOR_UNIT-2_TEST"

    # A karabo_da string gives a list with one element.
    pdus = query("TEST_DET_CAL_DA0")
    assert len(pdus) == 1
    assert pdus[0] == pdu_1

    # A list of karabo_das gives their PDUs, with None where unavailable.
    pdus = query(["TEST_DET_CAL_DA0", "TEST_DET_CAL_DA1", "UNAVAILABLE_DA"])
    assert pdus == [pdu_1, pdu_2, None]

    # "all" gives every unit corresponding to the karabo_id.
    pdus = query("all")
    assert len(pdus) == 2
    assert pdus == [pdu_1, pdu_2]


def test_module_index_to_qm():
    """module_index_to_qm maps a module index to a 'Q<n>M<m>' name."""
    default_cases = {0: 'Q1M1', 1: 'Q1M2', 4: 'Q2M1', 6: 'Q2M3'}
    for index, expected_name in default_cases.items():
        assert module_index_to_qm(index) == expected_name

    # With the optional second argument (presumably the module count --
    # TODO confirm against module_index_to_qm's signature).
    assert module_index_to_qm(4, 5) == 'Q5M1'

    # Out-of-range indices trip an assertion.
    for bad_args in [(18,), (7, 5)]:
        with pytest.raises(AssertionError):
            module_index_to_qm(*bad_args)


def test_recursive_update():
    """recursive_update merges `src` into `tgt` in place.

    It returns False when no key of `tgt` had to be kept. It returns True
    when a nested key present in both dicts was left unchanged in `tgt`.
    """
    cases = [
        (
            {"a": {"d": 3}, "e": 4},
            False,
            {"a": {"b": 1, "d": 3}, "c": 2, "e": 4},
        ),
        (
            {"a": {"b": 3}, "e": 4},
            True,
            {"a": {"b": 1}, "c": 2, "e": 4},
        ),
    ]
    for src, expected_result, expected_tgt in cases:
        tgt = {"a": {"b": 1}, "c": 2}
        assert recursive_update(tgt, src) is expected_result
        assert tgt == expected_tgt


def test_write_constants_fragment(tmp_path: Path):
    """Test `write_constants_fragment` with Jungfrau constant metadata.

    The metadata is from constants used to correct the FXE_XAD_JF1M
    detector for proposal 900226, run 106. Exactly one
    ``metadata_frag*yml`` file is expected in `tmp_path`. For every module
    it must hold the physical name and, per constant, the ccv_id, the
    dataset, the creation time (taken from ``begin_validity_at``) and the
    file path prefixed with the caldb store root.

    Args:
        tmp_path (pathlib.Path): pytest-provided temporary directory for
            file tests, see
            https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html
    """

    # Retrieved-constant metadata for two Jungfrau modules, each with an
    # Offset10Hz and a BadPixelsDark10Hz constant. `path` is relative to
    # the caldb store root passed to write_constants_fragment below.
    jf_metadata = {
        "JNGFR01": {
            "Offset10Hz": {
                "cc_id": 7064,
                "cc_name": "jungfrau-Type_Offset10Hz_Jungfrau DefiFE6iJX",
                "condition_id": 2060,
                "ccv_id": 41876,
                "ccv_name": "20200304_152733_sIdx=0",
                "path": Path("xfel/cal/jungfrau-type/jungfrau_m233/cal.1583335651.8084984.h5"),
                "dataset": "/Jungfrau_M233/Offset10Hz/0",
                "begin_validity_at": "2020-03-04T15:16:34.000+01:00",
                "end_validity_at": None,
                "raw_data_location": "proposal:p900121 runs:136 137 138",
                "start_idx": 0,
                "end_idx": 0,
                "physical_name": "Jungfrau_M233"},
            "BadPixelsDark10Hz": {
                "cc_id": 7066,
                "cc_name": "jungfrau-Type_BadPixelsDark10Hz_Jungfrau DefiFE6iJX",
                "condition_id": 2060,
                "ccv_id": 41878,
                "ccv_name": "20200304_152740_sIdx=0",
                "path": Path("xfel/cal/jungfrau-type/jungfrau_m233/cal.1583335658.6813955.h5"),
                "dataset": "/Jungfrau_M233/BadPixelsDark10Hz/0",
                "begin_validity_at": "2020-03-04T15:16:34.000+01:00",
                "end_validity_at": None,
                "raw_data_location": "proposal:p900121 runs:136 137 138",
                "start_idx": 0,
                "end_idx": 0,
                "physical_name": "Jungfrau_M233"
                }
            },
        "JNGFR02": {
            "Offset10Hz": {
                "cc_id": 7067,
                "cc_name": "jungfrau-Type_Offset10Hz_Jungfrau DefzgIVHz1",
                "condition_id": 2061,
                "ccv_id": 41889,
                "ccv_name": "20200304_154434_sIdx=0",
                "path": Path("xfel/cal/jungfrau-type/jungfrau_m125/cal.1583336672.760199.h5"),
                "dataset": "/Jungfrau_M125/Offset10Hz/0",
                "begin_validity_at": "2020-03-04T15:16:34.000+01:00",
                "end_validity_at": None,
                "raw_data_location": "proposal:p900121 runs:136 137 138",
                "start_idx": 0,
                "end_idx": 0,
                "physical_name": "Jungfrau_M125",
                },
            "BadPixelsDark10Hz": {
                "cc_id": 7069,
                "cc_name": "jungfrau-Type_BadPixelsDark10Hz_Jungfrau DefzgIVHz1",
                "condition_id": 2061,
                "ccv_id": 41893,
                "ccv_name": "20200304_154441_sIdx=0",
                "path": Path("xfel/cal/jungfrau-type/jungfrau_m125/cal.1583336679.5835564.h5"),
                "dataset": "/Jungfrau_M125/BadPixelsDark10Hz/0",
                "begin_validity_at": "2020-03-04T15:16:34.000+01:00",
                "end_validity_at": None,
                "raw_data_location": "proposal:p900121 runs:136 137 138",
                "start_idx": 0,
                "end_idx": 0,
                "physical_name": "Jungfrau_M125",
                }
            }
        }

    # Write the fragment into tmp_path, using the caldb store as path root.
    write_constants_fragment(
        tmp_path,
        jf_metadata,
        Path("/gpfs/exfel/d/cal/caldb_store")
    )
    # Exactly one YAML fragment file is expected.
    fragments = list(tmp_path.glob("metadata_frag*yml"))
    assert len(fragments) == 1

    # Open YAML file
    with open(fragments[0], "r") as file:
        # Load YAML content into dictionary
        yaml_dict = yaml.safe_load(file)
        # Only a subset of the input fields is kept. `creation-time` comes
        # from `begin_validity_at`, and `path` is made absolute with the
        # caldb store root.
        assert yaml_dict == {
            "retrieved-constants":{
                "JNGFR01": {
                    "constants": {
                        "BadPixelsDark10Hz": {
                            "ccv_id": 41878,
                            "creation-time": "2020-03-04T15:16:34.000+01:00",
                            "dataset": "/Jungfrau_M233/BadPixelsDark10Hz/0",
                            "path": "/gpfs/exfel/d/cal/caldb_store/xfel/cal/jungfrau-type/jungfrau_m233/cal.1583335658.6813955.h5",  # noqa
                        },
                        "Offset10Hz": {
                            "ccv_id": 41876,
                            "creation-time": "2020-03-04T15:16:34.000+01:00",
                            "dataset": "/Jungfrau_M233/Offset10Hz/0",
                            "path": "/gpfs/exfel/d/cal/caldb_store/xfel/cal/jungfrau-type/jungfrau_m233/cal.1583335651.8084984.h5",  # noqa
                        },
                    },
                    "physical-name": "Jungfrau_M233",
                },
                "JNGFR02": {
                    "constants": {
                        "BadPixelsDark10Hz": {
                            "ccv_id": 41893,
                            "creation-time": "2020-03-04T15:16:34.000+01:00",
                            "dataset": "/Jungfrau_M125/BadPixelsDark10Hz/0",
                            "path": "/gpfs/exfel/d/cal/caldb_store/xfel/cal/jungfrau-type/jungfrau_m125/cal.1583336679.5835564.h5",  # noqa
                        },
                        "Offset10Hz": {
                            "ccv_id": 41889,
                            "creation-time": "2020-03-04T15:16:34.000+01:00",
                            "dataset": "/Jungfrau_M125/Offset10Hz/0",
                            "path": "/gpfs/exfel/d/cal/caldb_store/xfel/cal/jungfrau-type/jungfrau_m125/cal.1583336672.760199.h5",  # noqa
                        },
                    },
                    "physical-name": "Jungfrau_M125",
                },
            }
        }


def test_reorder_axes():
    """reorder_axes moves an array's axes from one named order to another."""
    arr = np.zeros((10, 32, 256, 3))
    src_order = ('cells', 'slow_scan', 'fast_scan', 'gain')

    cases = [
        (('slow_scan', 'fast_scan', 'cells', 'gain'), (32, 256, 10, 3)),
        (('gain', 'fast_scan', 'slow_scan', 'cells'), (3, 256, 32, 10)),
    ]
    for dst_order, expected_shape in cases:
        reordered = reorder_axes(arr, src_order, dst_order)
        assert reordered.shape == expected_shape