import argparse
import glob
import hashlib
import os
import pickle
import re
import shutil
import subprocess as sp
import sys
import unittest
import warnings
from abc import abstractmethod
from enum import Enum
from functools import partial
from multiprocessing import Pool
from time import sleep

import h5py
import numpy as np
from git import Repo
from scipy import stats

warnings.filterwarnings('ignore')

parser = argparse.ArgumentParser()
parser.add_argument('--generate', action="store_true", default=False,
                    help="Set this flag to generate new artefacts from " +
                    "the test. These will be placed in the artefact " +
                    "directory, under the latest commit your git " +
                    "repository is on. This will launch the " +
                    "notebook under test first.")
parser.add_argument('--generate-wo-execution', action="store_true",
                    default=False, help="Set this flag to generate new " +
                    "artefacts from the test. These will be placed in " +
                    "the artefact directory, under the latest commit " +
                    "your git repository is on. This will not launch " +
                    "the notebook being tested, but assumes its " +
                    "output is already present. Use e.g. to debug tests.")
parser.add_argument('--test-wo-execution', action="store_true", default=False,
                    help="Run tests, but do not execute the notebook being " +
                    "tested first. This is assumes its output is already " +
                    "present. Use e.g. to debug tests.")
parser.add_argument('--skip-checksum', action="store_true", default=False,
                    help="Skip checksum tests (and artefact generation)")
parser.add_argument('--skip-histogram', action="store_true", default=False,
                    help="Skip histogram tests (and artefact generation)")
parser.add_argument('--skip-karabo-data', action="store_true", default=False,
                    help="Skip karabo_data tests (and artefact generation)")
parser.add_argument('--skip-report-gen', action="store_true", default=False,
                    help="Skip report generation tests")
parser.add_argument('--artefact-dir', type=str, default="./artefacts/",
                    help="Set directory to place artefacts in.")
parser.add_argument('unittest_args', nargs='*',
                    help="Any arguments to be passed to unittest")
args = parser.parse_args()
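
# A sketch of typical invocations of this module from the command line; the
# file name used below is only an illustration:
#
#   python3 test_correction_notebook.py --generate
#   python3 test_correction_notebook.py --test-wo-execution --skip-checksum
#
# Any remaining positional arguments are collected in `unittest_args` for
# passing on to unittest.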


class Failures(Enum):
    ARTEFACT_MISSING = "No artefact"
    EMPTY_ARTEFACT = "Empty artefact"


def _do_generate():
    """ Determine if artefacts should be generated
    """
    return args.generate or args.generate_wo_execution


def get_last_commit():
    """ Return the last commit from the git repo
    """
    r = Repo(os.path.dirname(os.path.realpath(__file__)) + "/..")
    last_commit = next(r.iter_commits())
    return last_commit


def get_artefact_dir(cls):
    """ Get the artefact directory for the last commit

    :param cls: Test class
    """
    last_commit = get_last_commit()
    print("Last git commit is: {}".format(last_commit))
    test_name = cls.__name__
    art_base = os.path.realpath(args.artefact_dir)
    suffix = "G" if _do_generate() else "T"
    path = "{}/{}/{}_{}".format(art_base, last_commit,
                                test_name, suffix)
    print("artefacts will be placed in: {}".format(path))
    return path


def get_artefact_comp_dir(cls):
    """Get the artefact directory for the last commit that generated any

    :param cls: Test class
    """
    r = Repo(os.path.dirname(os.path.realpath(__file__)) + "/..")
    test_name = cls.__name__
    art_base = os.path.realpath(args.artefact_dir)
    for commit in r.iter_commits():
        path = "{}/{}/{}_G".format(art_base, commit, test_name)
        if os.path.exists(path):
            return path
    msg = ("Could not find any previous artefacts for test {}"
           + " in directory {}! Check if directory is correct"
           + " or create artefacts using the --generate flag.")
    raise FileNotFoundError(msg.format(test_name, art_base))


def get_matching_h5_paths(f, template):
    """ Return paths in h5 file that match a path template

    :param f: an h5py file object
    :param template: a path template, which may contain regular expressions

    Paths are evaluated against the template on each hierarchy level.
    Levels are separated by "/", e.g.

    INSTRUMENT/(.+)LPD(.+)/DET/[0-9]+CH0:xtdf/image/mask

    will match any LPD instance on the second level and any channel on the
    fourth level.
    """
    matches = []

    def check_key(pre, pc):
        for k in f[pre].keys():
            m = re.match(pc[0], k)
            if m:
                if len(pc[1:]) > 0:
                    check_key(pre + m.group(0) + "/", pc[1:])
                else:
                    matches.append(pre + m.group(0))

    check_key("/", template.split("/"))
    return matches
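

# Usage sketch for `get_matching_h5_paths`; the file name and HDF5 layout
# below are illustrative assumptions, not actual test data:
#
#   with h5py.File("CORR-R0123-AGIPD00-S00000.h5", "r") as f:
#       for path in get_matching_h5_paths(
#               f, "INSTRUMENT/(.+)AGIPD(.+)/DET/[0-9]+CH0:xtdf/image/data"):
#           print(path)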


def parallel_hist_gen(hist_paths, artefact_dir, fname):
    """ Function to generate histogram artefacts in parallel

    :param hist_paths: paths to create histograms for. Should be a dict with
        the following structure:

        { path_template :
            { htype_name1 : { "bins": 100, "range": (0, 1000),
                              "scl_fun": np.log2},
              htype_name2 : { "bins": 100, "range": (0, 100000) },
            }
        }

        here `path_template` is a template accepted by `get_matching_h5_paths`,
        `htype_name1` is a name for the generated histogram, by which it
        will be identified during testing, `bins` and `range` define the
        histogram, and optionally, `scl_fun` is a scaling function to be
        executed on the data before creating the histogram.

    :param artefact_dir: the directory under which to place histograms.
        For each input file a histogram file (an npz archive) containing
        the named histograms defined by `hist_paths` is created.
    :param fname: path to the h5 file containing the paths specified in
        `hist_paths`.

    For pickling to work in subprocess calls, this needs to be defined
    on the module level.

    """
    all_hists = {}
    with h5py.File(fname, "r") as f:
        for path, hists in hist_paths.items():
            mpaths = get_matching_h5_paths(f, path)
            for mpath in mpaths:
                print("Creating histograms for path: {}".format(mpath))
                d = f[mpath]
                for htype, conf in hists.items():
                    bins = conf["bins"]
                    rnge = conf["range"]
                    fn = conf.get("scl_fun", None)
                    if fn is None:
                        h, _ = np.histogram(d, bins=bins, range=rnge)
                    else:
                        h, _ = np.histogram(fn(d), bins=bins, range=rnge)
                    all_hists["{}_{}".format(mpath, htype)] = h
                del d
    hist_fname = "{}.hist".format(fname)
    hpath = "{}/{}".format(artefact_dir, os.path.basename(hist_fname))
    if len(all_hists):
        np.savez(hpath, **all_hists)
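

# Illustrative `hist_paths` configuration as consumed by `parallel_hist_gen`
# and `parallel_hist_eval`; the detector path, bin counts and ranges below
# are assumptions, not values taken from an actual test:
#
#   hist_paths = {
#       "INSTRUMENT/(.+)LPD(.+)/DET/[0-9]+CH0:xtdf/image/data": {
#           "counts": {"bins": 100, "range": (0, 10000)},
#           "counts_log2": {"bins": 100, "range": (0, 16),
#                           "scl_fun": np.log2},
#       },
#   }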


def parallel_hist_eval(hist_paths, cls, fname):
    """ Function to compare histogram artefacts in parallel

    :param hist_paths: paths to create histograms for. Should be a dict with
        the following structure:

        { path_template :
            { htype_name1 : { "bins": 100, "range": (0, 1000),
                              "scl_fun": np.log2},
              htype_name2 : { "bins": 100, "range": (0, 100000) },
            }
        }

        here `path_template` is a template accepted by `get_matching_h5_paths`,
        `htype_name1` is a name for the generated histogram, by which it
        will be identified during testing, `bins` and `range` define the
        histogram, and optionally, `scl_fun` is a scaling function to be
        executed on the data before creating the histogram.

    :param cls: Test class
    :param fname: path to the h5 file containing the paths specified in
        `hist_paths`.

    For pickling to work in subprocess calls, this needs to be defined
    on the module level.
    """
    ret = []
    with h5py.File(fname, "r") as f:
        hist_fname = "{}.hist.npz".format(fname)
        test_art_dir = get_artefact_comp_dir(cls)
        hpath = "{}/{}".format(test_art_dir, os.path.basename(hist_fname))
        has_artefact = os.path.exists(hpath)
        if not has_artefact:
            return Failures.ARTEFACT_MISSING

        try:
            test_hists = np.load(hpath)
        except OSError:
            return Failures.EMPTY_ARTEFACT  # likely empty data

        for path, hists in hist_paths.items():
            mpaths = get_matching_h5_paths(f, path)
            for mpath in mpaths:
                print("Creating histograms for path: {}".format(mpath))
                d = f[mpath]
                for htype, conf in hists.items():
                    bins = conf["bins"]
                    rnge = conf["range"]
                    fn = conf.get("scl_fun", None)
                    if fn is None:
                        h, _ = np.histogram(d, bins=bins, range=rnge)
                    else:
                        h, _ = np.histogram(fn(d), bins=bins, range=rnge)
                    if mpath.startswith("/"):
                        # would have been truncated on saving
                        mpath = mpath[1:]
                    th = test_hists["{}_{}".format(mpath, htype)]
                    ret.append((mpath, htype, th, h))
                del d
    return ret


class CorrectionTestBase:
    """
    A base class for testing correction type notebooks.
    """
    detector = None
    task = None
    parms = {}
    rel_file_ext = ".h5"
    artefact_dir = None
    artefact_comp_dir = None
    hist_paths = []
    karabo_data_inspects = []
    expected_reports = []

    @classmethod
    def setUpClass(cls):
        """
        Sets up the test by executing the notebook under test

        If artefact generation is requested an artefact directory is
        created.

        Note that this method will block until any slurm jobs scheduled
        as part of the notebook execution have finished.
        """
        assert (cls.detector is not None)
        assert (cls.task is not None)
        cmd = ["xfel-calibrate", cls.detector, cls.task]
        for k, v in cls.parms.items():
            cmd += ["--{}".format(k), str(v)]

        print("Executing {}".format(" ".join(cmd)))

        
        print("Creating data paths for artefacts")
        cls.artefact_dir = get_artefact_dir(cls)
        if not os.path.exists(cls.artefact_dir):
            os.makedirs(cls.artefact_dir)

        if args.generate_wo_execution or args.test_wo_execution:
            return

        out = sp.check_output(cmd)
        joblist = None
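        # xfel-calibrate reports scheduled jobs on a single line of the
        # form "Submitted the following SLURM jobs: <id>,<id>,...".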

        for ln in out.decode().split("\n"):
            if "Submitted the following SLURM jobs:" in ln:
                txt, jobs = ln.split(":")
                joblist = [j.strip() for j in jobs.split(",")]
                break
        if joblist is not None:
            cls._wait_on_jobs(joblist)

    @abstractmethod
    def _output_to_path(self):
        """ Return the path a notebook under test places its results in.

        Must be overridden by concrete test classes.
        """
        pass

    @classmethod
    def _wait_on_jobs(cls, joblist):
        """ Wait on SLURM jobs defined by `joblist` to finish.
        """
        print("Waiting on jobs to finish: {}".format(joblist))
        while True:
            found_jobs = set()
            output = sp.check_output(['squeue']).decode('utf8')
            for line in output.split("\n"):
                for job in joblist:
                    if str(job) in line:
                        found_jobs.add(job)
            if len(found_jobs) == 0:
                break
            sleep(10)

    @unittest.skipUnless(_do_generate() and not args.skip_checksum,
                         "artefact generation is not requested")
    def test_generate_checksums(self):
        """ Generate Fletcher32 checksums of output files from notebook
        """
        out_folder = self._output_to_path()
        files_to_check = glob.glob(
            "{}/*{}".format(out_folder, self.rel_file_ext))

        for fname in files_to_check:
            with h5py.File(fname, "r") as f:
                d = {}
                def visitor(k, item):
                    if isinstance(item, h5py.Dataset):
                        d[k] = item.fletcher32

                f.visititems(visitor)
                
                chkfname = "{}.checksum".format(fname)
                chkpath = "{}/{}".format(self.artefact_dir,
                                         os.path.basename(chkfname))
                with open(chkpath, 'wb') as fc:
                    pickle.dump(d, fc, pickle.HIGHEST_PROTOCOL) 

    @unittest.skipIf(args.skip_checksum,
                     "User requested to skip checksum test")
    def test_checksums(self):
        """ Compare Fletcher32 checksums of notebook's output with artefacts

        This test verifies that datasets with checksums are identical.
        Even small changes in the correction logic are likely to make
        this test fail.
        If this is the case, it is recommended to verify correctness using
        the other tests, which inspect the data, and then create new
        checksums using the --generate option.

        If no previous checksum exists for a given file the test for that
        file will fail. It will also fail if a dataset previously did not
        have a checksum assigned.
        """
        out_folder = self._output_to_path()
        files_to_check = glob.glob(
            "{}/*{}".format(out_folder, self.rel_file_ext))
        for fname in files_to_check:
            chkfname = "{}.checksum".format(fname)
            test_art_dir = get_artefact_comp_dir(self.__class__)
            chkpath = "{}/{}".format(test_art_dir, os.path.basename(chkfname))
            with self.subTest(msg="Verifying against: {}".format(chkpath)):
                self.assertTrue(os.path.exists(chkpath),
                                "No comparison checksums found")
            with open(chkpath, 'rb') as fc:
                d = pickle.load(fc)
                
                with h5py.File(fname, "r") as f:
 
                    def visitor(k, item):
                        if isinstance(item, h5py.Dataset):
                            
                            msg = "Verify checksum of: {}".format(k)
                            with self.subTest(msg=msg):
                                self.assertIn(k, d)
                                self.assertEqual(d[k], item.fletcher32)

                    f.visititems(visitor)

    @unittest.skipUnless(_do_generate() and not args.skip_histogram,
                         "artefact generation is not requested")
    def test_generate_histograms(self):
        """ Generate histogram artefacts for the output of the notebook
        """
        out_folder = self._output_to_path()
        files_to_check = glob.glob(
            "{}/*{}".format(out_folder, self.rel_file_ext))
        with Pool(8) as p:
            pf = partial(parallel_hist_gen, self.hist_paths, self.artefact_dir)
            p.map(pf, files_to_check)
        self.assertTrue(True)

    @unittest.skipIf(args.skip_histogram,
                     "User requested to skip histogram test")
    def test_histograms(self):
        """ Compare histograms of notebook output with previous artefacts

        Comparison is performed in multiple tests:

        * using np.allclose, which tests that histograms are equal within
          numerical limits

        * using statistical tests which check for p-value compatibility
          at confidence levels of 0.9, 0.95 and 0.99:

          - via a Kolmogorov-Smirnov test to verify that distributions
            are of similar shape

          - via a χ2 test

          - via a Shapiro-Wilk test to verify that the deviation between
            both histograms is normally distributed (which would be
            expected if the deviation is of statistical and not
            systematic nature).

        If no previous histograms exist for a given file the test for that
        file will fail.

        Empty files are skipped.
        """
        out_folder = self._output_to_path()
        files_to_check = glob.glob(
            "{}/*{}".format(out_folder, self.rel_file_ext))
        with Pool(8) as p:
            pf = partial(parallel_hist_eval, self.hist_paths, self.__class__)
            r = p.map(pf, files_to_check)
            for i, rvals in enumerate(r):
                msg = "Verifying: {}".format(files_to_check[i])
                with self.subTest(msg=msg):
                    self.assertNotEqual(rvals, Failures.ARTEFACT_MISSING,
                                        "No comparison histograms found")
                if rvals is Failures.ARTEFACT_MISSING:
                    continue
                elif rvals is Failures.EMPTY_ARTEFACT:
                    continue
                else:
                    for rval in rvals:  # inner loop
                        mpath, htype, th, h = rval

                        # straight-forward all equal
                        msg = "Test all values equal for: {}/{}".format(mpath,
                                                                        htype)
                        with self.subTest(msg=msg):
                            self.assertTrue(np.allclose(h, th))

                        confidence_levels = [0.9, 0.95, 0.99]
                        for cl in confidence_levels:
                            # run additional tests and report on them
                            msg = "KS-Test for: {}/{}, confidence-level: {}"
                            msg = msg.format(mpath, htype, cl)
                            with self.subTest(msg=msg):
                                D, p = stats.ks_2samp(h, th)
                                self.assertGreaterEqual(p, cl)

                            idx = (h != 0) & (th != 0)
                            msg = "Chi2-Test for: {}/{}, confidence-level: {}"
                            msg = msg.format(mpath, htype, cl)
                            with self.subTest(msg=msg):
                                chisq, p = stats.chisquare(h[idx], th[idx])
                                self.assertGreaterEqual(p, cl)

                            msg = ("Shapiro-Wilks-Test for: {}/{}, " +
                                   "confidence-level: {}")
                            msg = msg.format(mpath, htype, cl)
                            with self.subTest(msg=msg):
                                enorm = (h[idx] - th[idx]) / np.sqrt(th[idx])
                                t, p = stats.shapiro(enorm)
                                self.assertGreaterEqual(p, cl)

    @unittest.skipUnless(_do_generate() and not args.skip_karabo_data,
                         "artefact generation is not requested")
    def test_generate_karabo_data(self):
        """ Generate artefacts for the Karabo Data test of notebook output

        Note that karabo_data related imports are done inline in this test
        so that other tests may be executed without karabo_data installed.
        """
        out_folder = self._output_to_path()
        kdata = "{}/karabo.data".format(self.artefact_dir)
        # we inline the import to be able to skip if not installed
        import karabo_data as kd
        rd = kd.RunDirectory(out_folder)
        d = {}
        d["datasources"] = rd.detector_sources
        d["detector_infos"] = {}
        for source in rd.detector_sources:
            d["detector_infos"][source] = rd.detector_info(source)
        d["trainids"] = rd.train_ids

        def write_train_info(train, identifier):
            d[identifier] = {}
            d[identifier]["sources"] = list(train.keys())
            for source in train.keys():
                d[identifier][source] = {}
                d[identifier][source]["keys"] = list(train[source].keys())
                for key in train[source].keys():
                    if key in self.karabo_data_inspects:
                        d[identifier][source][key] = train[source][key]

        _, first_train = next(rd.trains())  # also tests iteration
        write_train_info(first_train, "first_train")

        _, last_train = rd.train_from_id(rd.train_ids[-1])
        write_train_info(last_train, "last_train")

        with open(kdata, 'wb') as f:
            pickle.dump(d, f, pickle.HIGHEST_PROTOCOL)

    @unittest.skipIf(args.skip_karabo_data,
                     "User requested to skip karabo data test")
    def test_karabo_data(self):
        """ Test Karabo Data compatibility for notebook output

        The following tests are performed:

        * test that output files can be loaded as a `RunDirectory`
        * verify that detector data sources are the same
        * verify that detector info for all sources is the same
        * verify that train ids are unchanged
        * for the first and last train:
          - verify that all sources and all keys are the same
          - verify that all data paths defined in `karabo_data_inspects`
            are equal. These should be metadata, not data processed as
            part of the notebook.
          These checks also exercise train iteration (`next(rd.trains())`)
          and selection (`rd.train_from_id(rd.train_ids[-1])`) on the
          output.

        Note that karabo_data related imports are done inline in this test
        so that other tests may be executed without karabo_data installed.
        """
        out_folder = self._output_to_path()
        kdata = "{}/karabo.data".format(get_artefact_comp_dir(self.__class__))
        # we inline the import to be able to skip if not installed
        import karabo_data as kd
        rd = kd.RunDirectory(out_folder)

        # test against artefacts
        with open(kdata, 'rb') as f:
            d = pickle.load(f)

            self.assertEqual(d["datasources"], rd.detector_sources)
            for source, info in d["detector_infos"].items():
                self.assertEqual(info, rd.detector_info(source))
            self.assertEqual(d["trainids"], rd.train_ids)

            def test_train_info(train, identifier):

                self.assertEqual(sorted(d[identifier]["sources"]),
                                 sorted(list(train.keys())))
                for source in train.keys():
                    self.assertEqual(sorted(d[identifier][source]["keys"]),
                                     sorted(list(train[source].keys())))
                    for key in train[source].keys():
                        if key in self.karabo_data_inspects:
                            val = train[source][key]
                            cval = d[identifier][source][key]
                            if isinstance(val, np.ndarray):
                                self.assertTrue(np.equal(cval, val).all())
                            else:
                                self.assertEqual(cval, val)

            _, first_train = next(rd.trains())  # also tests iteration
            test_train_info(first_train, "first_train")

            _, last_train = rd.train_from_id(rd.train_ids[-1])
            test_train_info(last_train, "last_train")
            
    @unittest.skipIf(args.skip_karabo_data,
                     "User requested to skip karabo data test")
    def test_karabo_data_self_test(self):
        """ Runs validation included in karabo data
        """
        out_folder = self._output_to_path()
        sp.check_call([sys.executable, '-m',
                       'karabo_data.validation', out_folder])

    @unittest.skipIf(args.skip_report_gen,
                     "User requested to skip report generation test")
    def test_report_gen(self):
        """ Verify expected reports are generated
        
        Also verifies that no additional reports are present, and copies
        the report to the artefact dir.
        """
        out_folder = self._output_to_path()
        for report in self.expected_reports:
            msg = "Verifying report exists: {}".format(report)
            with self.subTest(msg=msg):
                rpath = "{}/{}".format(out_folder, report)
                self.assertTrue(os.path.exists(rpath))
                # copy report to artefacts
                dpath = "{}/{}".format(get_artefact_dir(self.__class__),
                                       os.path.basename(rpath))
                shutil.copyfile(rpath, dpath)
        # verify no additional reports exist
        pdfs = glob.glob("{}/*.pdf".format(out_folder))
        for pdf in pdfs:
            self.assertIn(os.path.basename(pdf), self.expected_reports)
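

# Sketch of a concrete test class built on CorrectionTestBase. The detector,
# task, parameters, paths, inspected keys and report name below are
# placeholders for illustration only, not values from an actual calibration
# run:
#
# class TestExampleCorrection(CorrectionTestBase, unittest.TestCase):
#     detector = "AGIPD"
#     task = "CORRECT"
#     parms = {"in-folder": "/gpfs/exfel/exp/XMPL/201800/p000000/raw",
#              "out-folder": "/gpfs/exfel/data/scratch/xcal/test/",
#              "run": 1}
#     hist_paths = {
#         "INSTRUMENT/(.+)AGIPD(.+)/DET/[0-9]+CH0:xtdf/image/data": {
#             "counts": {"bins": 100, "range": (0, 10000)},
#         },
#     }
#     karabo_data_inspects = ["image.cellId"]
#     expected_reports = ["AGIPD_Correct.pdf"]
#
#     def _output_to_path(self):
#         return self.parms["out-folder"]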