import hashlib
import io
import multiprocessing
import pathlib
import tempfile
import time
from contextlib import redirect_stdout
from datetime import datetime
from functools import partial
from subprocess import PIPE, run
from typing import Any, Dict, List, Tuple

import h5py
import pytest

import xfel_calibrate.calibrate as calibrate
from cal_tools import h5_copy_except

from .callab_tests import automated_test_config


def file_md5(
    tested_file: str,
    block_size: int = 2 ** 20,
) -> bytes:
    """Generate the MD5 checksum for a file.

    Args:
        tested_file: File to be tested.
        block_size (int, optional): Block size for reading the file.
            Defaults to 2**20.

    Returns:
        The MD5 digest of the file contents.
    """
    md5 = hashlib.md5()
    with open(tested_file, "rb") as f:
        while True:
            data = f.read(block_size)
            if not data:
                break
            md5.update(data)
    return md5.digest()


def validate_h5files(
    f: pathlib.PosixPath, ref_folder: pathlib.PosixPath
) -> Tuple[bool, pathlib.PosixPath]:
    """Validate the h5files based on the generated MD5 checksum.

    Args:
        f: file to validate
        ref_folder: Reference folder which has the same filename
            as of the file to validate.

    Returns:
        Validation result.
        The validated file.
    """
    return file_md5(f) == file_md5(ref_folder / f.name), f


def validate_constant_file(
    ref_folder: pathlib.PosixPath,
    out_folder: pathlib.PosixPath,
    test_file: pathlib.PosixPath,
) -> Tuple[bool, pathlib.PosixPath]:
    """Validate the constants files by validating the metadata
    stored in the h5files, e.g. conditions and db_modules. This is
    done after excluding the report and data.

    Args:
        ref_folder: The reference folder for validating the files.
        out_folder: The output folder for the test constant files.
        test_file: The output file to be validated.

    Returns:
        result: Validation result for the metadata.
        test_file: The validated file.

    """
    out_tf = tempfile.NamedTemporaryFile(
        dir=out_folder,
        suffix=".tmp",
        prefix="cal_",
        delete=True,
    )
    ref_tf = tempfile.NamedTemporaryFile(
        dir=out_folder,
        suffix=".tmp",
        prefix="cal_",
        delete=True,
    )
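    # Use h5py's in-memory "core" driver; with backing_store=True the stripped
    # copies are flushed to the temporary files on close so they can be hashed.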
    hp1 = h5py.File(out_tf.name, "w", driver="core", backing_store=True)
    hp2 = h5py.File(ref_tf.name, "w", driver="core", backing_store=True)
    # Copy the constant file contents, excluding the report path,
    # which differs between test runs.
    with h5py.File(test_file, "r") as sfile:
        h5_copy_except.h5_copy_except_paths(
            sfile,
            hp1,
            ["report"],
        )
    with h5py.File(ref_folder / test_file.name, "r") as sfile:
        h5_copy_except.h5_copy_except_paths(
            sfile,
            hp2,
            ["report"],
        )
    hp1.close()
    hp2.close()
    result = file_md5(out_tf.name) == file_md5(ref_tf.name)
    out_tf.close()
    ref_tf.close()

    return result, test_file


def parse_config(cmd: List[str], config: Dict[str, Any],
                 out_folder: str) -> List[str]:
    """Convert a dictionary to a list of arguments.

    Values that are not strings will be cast.
    Lists will be converted to several strings following their `--key`
    flag.
    Booleans will be converted to a `--key` flag, where `key` is the
    dictionary key.
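
    Illustrative example (key names here are hypothetical):

        >>> parse_config(
        ...     ["xfel-calibrate"],
        ...     {"run": [9985], "no-cluster-job": True},
        ...     "/gpfs/out",
        ... )
        ['xfel-calibrate', '--run', '9985', '--no-cluster-job']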
    """

    for key, value in config.items():
        if " " in key or (isinstance(value, str) and " " in value):
            raise ValueError("Spaces are not allowed", key, value)

        if isinstance(value, list):
            cmd.append(f"--{key}")
            cmd += [str(v) for v in value]
        elif isinstance(value, bool):
            if value:
                cmd.append(f"--{key}")
        else:
            if value in ['""', "''"]:
                value = ""
            if key == "out-folder":
                value = out_folder
            cmd += [f"--{key}", str(value)]
    return cmd


@pytest.mark.manual_run
@pytest.mark.parametrize(
    "test_key, val_dict",
    list(automated_test_config.items()),
    ids=list(automated_test_config.keys()),
)
def test_xfel_calibrate(
    test_key: str, val_dict: dict, release_test_config: Tuple
):
    """Test xfel calibrate detectors and calibrations written
    in the given callab_test.py file.

    xfel-calibrate CL is ran over each test_key. Expected a val_dict of
    1. out-folder, 2. in-folder, 3. run numbers, 4.karabo-id, 5. detector type
    6. calibration type, and 7. reference folder to compare the data with.
    The val_dict can/should acquire the needed custom configurations for each test.
    e.g. AGIPD notebooks needs to differentiate between the different `karabo-id-control`
    which differ across instruments.

    Each calibration test is ran sequentially. Each calibration test consists of multiple tests.
    1. Testing if all SLURM jobs were successfull.
    2. Testing if there is an available report.
    3. Validate number of H5 files against files in reference folder.
    4. Validate numberical values for the H5 files against reference files.
      a. CORRECTED files are compared using MD5 checksum.
      b. CONSTANTS local files saved using cal_tools are the one being tested.
      Unfortunately they can't be compared using MD5 checksum. As the metadata in files
      consists of the report path. which differ between each test. So the metadata is cleaned
      from the report path first before comparing the H5 constant files numerically.
    Args:
        test_key: Key for the xfel-calibrate test.
        val_dict: Dictionary of the configurations for the running test.
        release_test_config: Tuple of the release-test configuration
            (selected detectors, calibration type, picked test,
            skip-numerical-validation flag, reference and output base folders).
    """

    (
        detectors,
        calibration,
        picked_test,
        skip_numerical_validation,
        reference_dir_base,
        out_dir_base,
    ) = release_test_config

    cal_type = val_dict["cal_type"]
    det_type = val_dict["det_type"]

    if picked_test is None:
        # Skip non-selected detectors
        if detectors != ["all"] and det_type.lower() not in [
            d.lower() for d in detectors
        ]:
            pytest.skip()

        # Skip non-selected calibration
        if calibration != "all" and cal_type.lower() != calibration:
            pytest.skip()
    else:
        if test_key != picked_test:
            pytest.skip()

    cmd = ["xfel-calibrate", det_type, cal_type]
    cal_conf = val_dict["config"]
    report_name = f"{test_key}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
    cal_conf["report-to"] = report_name

    out_folder = pathlib.Path(
        cal_conf["out-folder"].format(out_dir_base,
                                      cal_conf["karabo-id"], test_key)
    )
    cmd = parse_config(cmd, cal_conf, out_folder)

    reference_folder = pathlib.Path(
        val_dict["reference-folder"].format(
            reference_dir_base, cal_conf["karabo-id"], test_key
        )
    )

    cmd += ["--slurm-name", test_key]
    f = io.StringIO()

    with redirect_stdout(f):
        calibrate.run(cmd)

        out_str = f.getvalue()

    slurm_watcher = True
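    # xfel-calibrate prints the submitted SLURM job IDs to stdout; parse them
    # from the captured output so the jobs can be monitored via sacct below.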
    for r in out_str.split("\n"):
        if "Submitted the following SLURM jobs:" in r:
            _, jobids = r.split(":")

    # Sleep to let the SLURM jobs initialize (scaled by the number of jobs).
    time.sleep(len(jobids.split(",")))
    jobids = jobids.strip()
    while slurm_watcher:
        cmd = ["sacct", "-j", jobids, "--format=state"]

        res = run(cmd, stdout=PIPE)
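        # Drop the two sacct header lines and the trailing empty line,
        # keeping only the state column of each job step.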
        states = res.stdout.decode().split("\n")[2:-1]

        if not any(
            s.strip() in ["COMPLETING", "RUNNING", "CONFIGURING", "PENDING"]
            for s in states
        ):  # noqa
            slurm_watcher = False
        else:
            time.sleep(0.5)

    # 1st check that all jobs were COMPLETED without errors.
    states = res.stdout.decode().split("\n")[2:-1]
    assert all(
        s.strip() == "COMPLETED" for s in states
    ), f"{test_key} failure, calibration jobs were not completed. {jobids}: {states}"  # noqa
    print(f"{test_key}'s jobs were COMPLETED")
    time.sleep(2.0)

    # 2nd check for report availability.
    report_file = out_folder / f"{report_name}.pdf"
    assert report_file.exists(), f"{test_key} failure, report doesn't exists."

    # 3rd Check number of produced h5 files.
    h5files = list(out_folder.glob("*.h5"))
    expected_h5files = list(reference_folder.glob("*.h5"))
    assert len(h5files) == len(
        expected_h5files
    ), f"{test_key} failure, number of files is not as expected."  # noqa
    print(f"{test_key}'s number of calibration h5files is as expected.")

    # Stop tests at this point, if desired.
    if skip_numerical_validation:
        return
    non_valid_files = []

    # 4th check that h5files are exactly the same as the reference h5files.
    if cal_type.lower() == "correct":
        with multiprocessing.Pool() as pool:
            result = pool.starmap(
                validate_h5files, zip(
                    h5files, len(h5files) * [reference_folder])
            )
    else:  # "dark"
        validate_files = partial(
            validate_constant_file, reference_folder, out_folder
        )  # noqa
        with multiprocessing.Pool() as pool:
            result = pool.map(validate_files, h5files)

    for valid, file in result:
        if not valid:
            non_valid_files.append(file)
    assert (
        len(non_valid_files) == 0
    ), f"{test_key} failure, while validating metadata for {non_valid_files}"  # noqa
    print(f"{test_key}'s calibration h5files are validated successfully.")