import hashlib
import io
import logging
import multiprocessing
import pathlib
import tempfile
import time
from contextlib import redirect_stdout
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from subprocess import PIPE, run
from typing import Any, Dict, List, Tuple

import h5py
import numpy as np
import pytest

import xfel_calibrate.calibrate as calibrate

from .callab_tests import automated_test_config

LOGGER = logging.getLogger(__name__)


def file_md5(
    tested_file: str,
    block_size: int = 2 ** 20,
) -> bytes:
    """Generating MD5 checksum for a file.

    Args:
        tested_file: File to be tested.
        block_size (_type_, optional): Block size for reading the file.
            Defaults to 2**20.
    """
    f = open(tested_file, "rb")
    md5 = hashlib.md5()
    while True:
        data = f.read(block_size)
        if not data:
            break
        md5.update(data)
    f.close()
    return md5.digest()
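
# Usage sketch (hypothetical file names), e.g. to check whether two
# attribute-stripped copies are byte-identical, as validate_files() does:
#     identical = file_md5("out/cal_abc.tmp") == file_md5("ref/cal_abc.tmp")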


def collect_attrs(groups, datasets, objects, exclude_attrs, name, node):
    """Collect h5 attrs in groups, datasets, and objects lists."""
    if node.name not in exclude_attrs:
        if isinstance(node, h5py.Group):
            groups.append(name)
        elif isinstance(node, h5py.Dataset):
            if node.dtype == 'object':
                objects.append(name)
            else:
                datasets.append(name)
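
# Minimal usage sketch (hypothetical file name) showing how collect_attrs is
# meant to be driven via functools.partial and h5py's visititems, mirroring
# find_differences() below:
#     groups, datasets, objects = [], [], []
#     with h5py.File("some_file.h5", "r") as f:
#         f.visititems(
#             partial(collect_attrs, groups, datasets, objects, set()))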


def compare_datasets(
    file1,
    file2,
    datasets: list
):
    """Compare the values of datasets in two h5 files."""
    h5_diff = []
    for d in datasets:
        try:
            if not np.allclose(file1[d][()], file2[d][()], equal_nan=True):
                h5_diff.append(d)
        except ValueError as e:
            LOGGER.error(f"ValueError: {e}, {d}")
            h5_diff.append(d)
        except AttributeError as e:
            LOGGER.error(f"AttributeError: {e}, {d}")
            h5_diff.append(d)
    return h5_diff


def compare_objects(
    file1,
    file2,
    objects: list
):
    """Compare the objects in two h5 files."""
    h5_diff = []
    for d in objects:
        try:
            if isinstance(file1[d][()], bytes):
                if (
                    file1[d][()].decode('utf-8') != file2[d][()].decode('utf-8')  # noqa
                ):
                    h5_diff.append(d)
            elif (
                file1[d][()].dtype != file2[d][()].dtype or
                not np.array_equal(file1[d][()], file2[d][()])
            ):  # e.g. pnCCD files store lists of bytes here
                h5_diff.append(d)
        except ValueError as e:
            LOGGER.error(f"ValueError: {e}, {d}")
            h5_diff.append(d)
        except AttributeError as e:
            LOGGER.error(f"AttributeError: {e}, {d}, "
                         f"{file1[d][()].decode('utf-8')}")
            h5_diff.append(d)
    return h5_diff


def find_differences(
    test_file,
    reference_file,
    exclude_attrs,
):
    """
    Find difference in groups, datasets, and objects between two h5files.
    Args:
        file1: first h5 file.
        file2: second h5 file.
    """

    groups_f1 = []
    datasets_f1 = []
    objects_f1 = []

    groups_f2 = []
    datasets_f2 = []
    objects_f2 = []

    with h5py.File(test_file, 'r') as file1, h5py.File(reference_file, 'r') as file2:  # noqa

        # Fill groups, datasets, and objects list
        # to compare both h5files' attrs.
        file1.visititems(
            partial(
                collect_attrs,
                groups_f1,
                datasets_f1,
                objects_f1,
                exclude_attrs,
                ))
        file2.visititems(
            partial(
                collect_attrs,
                groups_f2,
                datasets_f2,
                objects_f2,
                exclude_attrs,
                ))

        start_time = time.perf_counter()
        # Compare groups, datasets, and objects to have the same content.
        assert set(groups_f1) == set(groups_f2), f"{test_file} and {reference_file} consist of different groups."  # noqa
        assert set(datasets_f1) == set(datasets_f2), f"{test_file} and {reference_file} consist of different datasets."  # noqa
        assert set(objects_f1) == set(objects_f2), f"{test_file} and {reference_file} consist of different objects."  # noqa
        duration = time.perf_counter() - start_time
        LOGGER.debug("Elapsed time comparing groups, "
                    f"datasets, and objects: {duration} seconds")
        LOGGER.debug("Groups, datasets, and objects have the same content.")

        # Compare datasets and objects.
        start_time = time.perf_counter()
        h5_diff_datasets = compare_datasets(file1, file2, datasets_f1)
        duration = time.perf_counter() - start_time
        LOGGER.debug(f"Elapsed time comparing datasets: {duration} seconds")
        start_time = time.perf_counter()
        h5_diff_objects = compare_objects(file1, file2, objects_f1)
        duration = time.perf_counter() - start_time
        LOGGER.debug(f"Elapsed time comparing objects: {duration} seconds")

        assert not h5_diff_datasets, f"{[d for d in h5_diff_datasets]} datasets contain different values for {test_file} and {reference_file}"  # noqa
        LOGGER.debug("Datasets are validated.")
        assert not h5_diff_objects, f"{[d for d in h5_diff_objects]} objects contain different values for {test_file} and {reference_file}"  # noqa
        LOGGER.debug("Objects are validated.")


def validate_files(
    ref_folder: pathlib.PosixPath,
    out_folder: pathlib.PosixPath,
    exclude_attrs: list,
    test_file: pathlib.PosixPath,
) -> Tuple[bool, pathlib.PosixPath]:
    """Validate file similarities. Create temporary files to exclude
    h5 attributes known to be different. e.g `report` for constants.
    If both files are not identical, the function is able to loop over
    both files and find and fail on the difference.

    Args:
        ref_folder: The reference folder for validating the files
        out_folder: The output folder for the test constant files.
        test_file: The output file to be validated.
        exclude_attrs: A list of datasets, groups to exclude
          from validated files.
    Returns:
        result: validation result for metadata.
        test_file: The validated file.
    """
    start_validating = time.perf_counter()

    def exclude_sources(source_file, dest, excluded_sources):
        # Open the source file in read-only mode
        with h5py.File(source_file, 'r') as source:

            # Recursively visit all objects in the source file
            def visit_func(name, obj):
                # Check if the object should be excluded
                if name in excluded_sources:
                    return

                # Check if the object is a dataset
                if isinstance(obj, h5py.Dataset):
                    # Create a new dataset in the destination
                    # file and copy the data
                    dest.create_dataset(name, data=obj[()])

            # Visit all objects in the source file and
            # copy them to the destination file
            source.visititems(visit_func)

    with tempfile.NamedTemporaryFile(
        dir=out_folder,
        suffix=".tmp",
        prefix="cal_",
        delete=True,
        ) as out_tf, tempfile.NamedTemporaryFile(
            dir=out_folder,
            suffix=".tmp",
            prefix="cal_",
            delete=True,
            ) as ref_tf:

        # Open the temporary HDF5 files used for the stripped copies.
        with h5py.File(out_tf.name, 'a') as hp1, h5py.File(ref_tf.name, 'a') as hp2:  # noqa

            start_time = time.perf_counter()
            # Copy h5 files for validation and exclude selected attrs.
            exclude_sources(test_file, hp1, exclude_attrs)

            duration = time.perf_counter() - start_time
            LOGGER.debug(f"Elapsed time copying {test_file}: "
                        f"{duration} seconds")

            start_time = time.perf_counter()
            exclude_sources(ref_folder / test_file.name, hp2, exclude_attrs)

            duration = time.perf_counter() - start_time
            LOGGER.debug(f"Elapsed time copying {ref_folder / test_file.name}: "
                        f"{duration} seconds")

            start_time = time.perf_counter()
            result = file_md5(out_tf.name) == file_md5(ref_tf.name)
            duration = time.perf_counter() - start_time
            LOGGER.debug(f"MD5 validation for {test_file}: {duration} seconds")
    duration = time.perf_counter() - start_validating
    LOGGER.debug(f"Elapsed time validating {test_file}: {duration} seconds")
    return result, test_file
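
# Usage sketch (hypothetical folders and file): validate one output file
# against its copy in the reference folder:
#     ok, checked_file = validate_files(
#         pathlib.Path("/ref"), pathlib.Path("/out"),
#         ["report"], pathlib.Path("/out/const_Offset.h5"))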


@dataclass
class ComparisonResult:
    filename: str
    new_dsets: list
    missing_dsets: list
    changed_dsets: list

    def found_differences(self):
        return bool(self.new_dsets or self.missing_dsets or self.changed_dsets)

    def show(self):
        if not self.found_differences():
            print(f"{self.filename} - ✓ no changes")
            return

        print(self.filename)
        for ds in self.new_dsets:
            print(f"  + NEW {ds}")
        for ds in self.missing_dsets:
            print(f"  - DEL {ds}")
        for ds in self.changed_dsets:
            print(f"  ~ DIF {ds}")

def gather_dsets(f: h5py.File):
    res = set()
    def visitor(name, obj):
        if isinstance(obj, h5py.Dataset):
            res.add(name)
    f.visititems(visitor)
    return res

def validate_file(
    ref_folder: pathlib.PosixPath,
    out_folder: pathlib.PosixPath,
    exclude_dsets: set,
    test_file: str
) -> ComparisonResult:
    ref_file = ref_folder / test_file
    out_file = out_folder / test_file
    with h5py.File(ref_file) as fref, h5py.File(out_file) as fout:
        ref_dsets = gather_dsets(fref)
        out_dsets = gather_dsets(fout)
        changed = []
        for dsname in sorted((ref_dsets & out_dsets) - exclude_dsets):
            ref_arr = fref[dsname][()]
            out_arr = fout[dsname][()]
            if isinstance(ref_arr, np.ndarray) ^ isinstance(out_arr, np.ndarray):
                eq = False  # One is an array, the other not
            elif isinstance(ref_arr, np.ndarray):
                # Both arrays
                eq = np.array_equal(ref_arr, out_arr, equal_nan=True)
            else:
                # Both single values
                eq = (ref_arr == out_arr)
            if not eq:
                changed.append(dsname)

    return ComparisonResult(
        test_file,
        new_dsets=sorted(out_dsets - ref_dsets),
        missing_dsets=sorted(ref_dsets - out_dsets),
        changed_dsets=changed
    )
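
# Usage sketch (hypothetical folders and file name): compare a single
# corrected file between the output and reference folders, excluding the
# metadata that changes on every run:
#     result = validate_file(
#         pathlib.Path("/ref"), pathlib.Path("/out"),
#         {"METADATA/creationDate", "METADATA/updateDate"},
#         "CORR-R0001-AGIPD00-S00000.h5")
#     result.show()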

def parse_config(cmd: List[str], config: Dict[str, Any], out_folder: str) -> List[str]:
    """Convert a dictionary to a list of arguments.

    Values that are not strings will be cast.
    Lists will be converted to several strings following their `--key`
    flag.
    Booleans will be converted to a `--key` flag, where `key` is the
    dictionary key.
    """

    for key, value in config.items():
        if " " in key or (isinstance(value, str) and " " in value):
            raise ValueError("Spaces are not allowed", key, value)

        if isinstance(value, list):
            cmd.append(f"--{key}")
            cmd += [str(v) for v in value]
        elif isinstance(value, bool):
            if value:
                cmd.append(f"--{key}")
        else:
            if value in ['""', "''"]:
                value = ""
            if key == "out-folder":
                value = out_folder
            cmd += [f"--{key}", str(value)]
    return cmd
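
# Example of the conversion with a hypothetical config dict:
#     parse_config(
#         ["xfel-calibrate", "AGIPD", "CORRECT"],
#         {"karabo-id": "SPB_DET_AGIPD1M-1", "modules": [0, 1], "rel-gain": True},
#         "/out")
# returns:
#     ["xfel-calibrate", "AGIPD", "CORRECT",
#      "--karabo-id", "SPB_DET_AGIPD1M-1",
#      "--modules", "0", "1", "--rel-gain"]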


def validate_hdf5_files(
    test_key: str,
    out_folder: pathlib.Path,
    reference_folder: pathlib.Path,
    cal_type: str,
    find_difference: bool,
):
    """Apply HDF5 data validation.

    Args:
        test_key (str): The test name.
        out_folder (pathlib.Path): The OUT folder for the tested data.
        reference_folder (pathlib.Path): The reference folder holding
          the reference data to validate against.
        cal_type (str): The type of calibration processing,
          e.g. dark or correct.
        find_difference (bool): A flag indicating the need to find the
          differences between two files if the tested data is
          not identical to the reference data.
    """
    print("\n--- Compare HDF5 files  ----")
    print("REF:", reference_folder)
    print("NEW:", out_folder)
    ok = True

    result_h5files = {p.name for p in out_folder.glob("*.h5")}
    ref_h5files = {p.name for p in reference_folder.glob("*.h5")}
    missing_files = ref_h5files - result_h5files
    if missing_files:
        print("Files missing from result (*.h5):", ", ".join(missing_files))
        ok = False
    new_files = result_h5files - ref_h5files
    if new_files:
        print("New files in result (*.h5):", ", ".join(new_files))
        ok = False

    files_to_check = sorted(result_h5files & ref_h5files)

    # Hard-coded datasets to exclude from numerical validation.
    # These datasets are known to be updated every time.
    if cal_type.lower() == "correct":
        exclude_attrs = {"METADATA/creationDate", "METADATA/updateDate"}
    else:
        exclude_attrs = {"report"}

    _validate_file = partial(
        validate_file,
        reference_folder,
        out_folder,
        exclude_attrs,
    )
    with multiprocessing.Pool(processes=8) as pool:
        for comparison in pool.imap(_validate_file, files_to_check):
            comparison.show()
            if comparison.found_differences():
                ok = False

    assert ok, "HDF5 files changed - see details above"



def slurm_watcher(
    test_key: str,
    std_out: str
):
    """
    Watch for submitted slurm jobs and wait for them to finish.
    After they finish apply first test and check
    if they were `COMPLETED`, successfully.

    Args:
        test_key (str): Test name.
        out_str (str): xfel-calibrate CLU std output.
    """
    waiting = True

    LOGGER.info(f"{test_key} - xfel-calibrate std out: {std_out}")

    jobids = None
    for r in std_out.split("\n"):
        if "Submitted the following SLURM jobs:" in r:
            _, jobids = r.split(":")
    assert jobids is not None, f"{test_key} failure, no SLURM jobs were submitted."  # noqa

    # Sleep to give the slurm jobs time to initialize.
    time.sleep(len(jobids.split(",")))
    jobids = jobids.strip()
    while waiting:
        cmd = ["sacct", "-j", jobids, "--format=state"]

        res = run(cmd, stdout=PIPE)
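        # `sacct --format=state` prints a header line and a separator line
        # before one state per job step, followed by a trailing newline;
        # slice off both the header and the empty last element.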
        states = res.stdout.decode().split("\n")[2:-1]

        if not any(s.strip() in [
            "COMPLETING",
            "RUNNING",
            "CONFIGURING",
            "PENDING",
        ] for s in states):
            waiting = False
        else:
            time.sleep(2)

    # 1st check: all jobs finished in the COMPLETED state without errors.
    assert all(s.strip() == "COMPLETED" for s in states), f"{test_key} failure, calibration jobs were not completed. {jobids}: {states}"  # noqa
    LOGGER.info(f"{test_key}'s jobs were COMPLETED")


@pytest.mark.manual_run
@pytest.mark.parametrize(
    "test_key, val_dict",
    list(automated_test_config.items()),
    ids=list(automated_test_config.keys()),
)
def test_xfel_calibrate(
    test_key: str, val_dict: dict,
    release_test_config: Tuple
):
    """ Test xfel calibrate detectors and calibrations written
    in the given callab_test YAML file.
    Args:
        test_key : Key for the xfel-calibrate test.
        val_dict: Dictionary of the configurations for the running test.
        release_test_config: Tuple of booleans to pick or skip tests
            based on the given boolean configs.
    """

    (
        detectors, calibration, picked_test,
        skip_numerical_validation, only_validate, find_difference,
        use_slurm, reference_dir_base, out_dir_base,
    ) = release_test_config

    cal_type = val_dict["cal_type"]
    det_type = val_dict["det_type"]

    if not picked_test:
        # Skip non-selected detectors
        if (
            detectors != ["all"] and
            det_type.lower() not in [d.lower() for d in detectors]
        ):
            pytest.skip()

        # Skip non-selected calibration
        if calibration != "all" and cal_type.lower() != calibration:
            pytest.skip()
    else:
        if test_key != picked_test:
            pytest.skip()

    cmd = ["xfel-calibrate", det_type, cal_type]

    cal_conf = val_dict["config"]

    out_folder = pathlib.Path(cal_conf["out-folder"].format(
        out_dir_base, cal_conf["karabo-id"], test_key))
    reference_folder = pathlib.Path(val_dict["reference-folder"].format(
        reference_dir_base, cal_conf["karabo-id"], test_key))

    report_name = (
        out_folder /
        f"{test_key}_{datetime.now().strftime('%y%m%d_%H%M%S')}")

    cal_conf["report-to"] = str(report_name)

    cmd = parse_config(cmd, cal_conf, out_folder)

    if only_validate:
        validate_hdf5_files(
            test_key,
            out_folder,
            reference_folder,
            cal_type,
            find_difference,
            )
        return

    if not use_slurm:  # e.g. for Gitlab CI.
        cmd += ["--no-cluster-job"]

    cmd += ["--slurm-name", test_key, "--cal-db-interface", "tcp://max-exfl-cal001:8015#8045"]
    f = io.StringIO()
    LOGGER.info(f"Submitting CL: {cmd}")
    with redirect_stdout(f):
        errors = calibrate.run(cmd)
        out_str = f.getvalue()

    if use_slurm:
        slurm_watcher(test_key, out_str)
    else:
        # confirm that all jobs succeeded.
        assert errors == 0

    time_to_wait = 5
    time_counter = 0
    # 2nd check for report availability.
    report_file = pathlib.Path(f"{report_name}.pdf")
    while not report_file.exists():
        time.sleep(1)
        time_counter += 1
        if time_counter > time_to_wait:
            assert False, f"{test_key} failure, report doesn't exists."
    LOGGER.info("Report found.")

    # Numerical validation of the HDF5 files, unless explicitly skipped.
    if not skip_numerical_validation:
        validate_hdf5_files(
            test_key,
            out_folder,
            reference_folder,
            cal_type,
            find_difference,
        )