from argparse import ArgumentParser
from pathlib import Path

import numpy as np

# by_id and by_index are referenced indirectly via eval() further below.
from extra_data import RunDirectory, by_id, by_index


show_true_cond = True
show_false_cond = True
num_true_cond = 0
num_false_cond = 0


def cmp(label, cond):
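    """Count a comparison result, print it according to the configured
    verbosity and return the condition."""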
    global num_true_cond, num_false_cond

    if cond:
        num_true_cond += 1

        if show_true_cond:
            print(f'✅ {label}')
    else:
        num_false_cond += 1

        if show_false_cond:
            print(f'❌ {label}')

    return cond


def main(argv=None):
    ap = ArgumentParser(
        description='Compare data collections structured in the '
                    'European XFEL Data Format (EXDF).')

    ap.add_argument(
        'input1', metavar='INPUT1', type=Path,
        help='folder of input data to compare with INPUT2')

    ap.add_argument(
        'input2', metavar='INPUT2', type=Path,
        help='folder of input data to compare with INPUT1')

    output_group = ap.add_mutually_exclusive_group()

    output_group.add_argument(
        '--verbose', '-v', action='store_true',
        help='show all compared items rather than only those that differ')

    output_group.add_argument(
        '--quiet', '-q', action='store_true',
        help='only print the count of unequal items')

    select_group = ap.add_argument_group(
        'Selection arguments',
        'Allows selecting only part of the data collections before comparing.')

    src_select_group = select_group.add_mutually_exclusive_group()

    src_select_group.add_argument(
        '--select',
        metavar='SRC,KEY', action='store', type=str, nargs='*',
        help='only compare the data collections after selecting the '
             'specified sources and/or keys')

    src_select_group.add_argument(
        '--deselect',
        metavar='SRC,KEY', action='store', type=str, nargs='*',
        help='only compare the data collections after deselecting the '
             'specified sources and/or keys')

    train_select_group = select_group.add_mutually_exclusive_group()

    train_select_group.add_argument(
        '--trains-by-id',
        metavar='SLICE_EXPR', action='store', type=str,
        help='only compare the data collections after selecting the '
             'specified trains by ID')

    train_select_group.add_argument(
        '--trains-by-index',
        metavar='SLICE_EXPR', action='store', type=str,
        help='only compare the data collections after selecting the '
             'specified trains by index')

    scope_group = ap.add_argument_group(
        'Scope of comparison arguments',
        'Allows restricting the scope of the comparison, which by default '
        'includes everything, including the data itself.'
    ).add_mutually_exclusive_group()

    scope_group.add_argument(
        '--only-metadata', '-m', action='store_true',
        help='check only metadata independent of individual sources')

    scope_group.add_argument(
        '--only-index', '-i', action='store_true',
        help='check only metadata and sources\' index entries')

    scope_group.add_argument(
        '--only-control', '-c', action='store_true',
        help='check metadata and index entries of all sources, but actual '
             'data only for control sources')

    args = ap.parse_args(argv)

    global show_true_cond, show_false_cond
    if args.verbose:
        show_true_cond = True
        show_false_cond = True
    elif args.quiet:
        show_true_cond = False
        show_false_cond = False
    else:
        show_true_cond = False
        show_false_cond = True

    data1 = RunDirectory(args.input1)
    data2 = RunDirectory(args.input2)

    if args.select:
        select_strs = args.select
        select_method = 'select'
    elif args.deselect:
        select_strs = args.deselect
        select_method = 'deselect'
    else:
        select_strs = []

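    # Turn each SRC,KEY argument into a (source, key) glob pair, using '*'
    # as the key glob if only a source is given.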
    sel = [tuple(select_str.split(',')) if ',' in select_str
           else (select_str, '*')
           for select_str in select_strs]

    if sel:
        data1 = getattr(data1, select_method)(sel)
        data2 = getattr(data2, select_method)(sel)

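    # Evaluate the slice expression (e.g. '0:100') into an extra_data train
    # selector; note that this passes the argument string through eval().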
    if args.trains_by_id:
        sel = eval(f'by_id[{args.trains_by_id}]')
    elif args.trains_by_index:
        sel = eval(f'by_index[{args.trains_by_index}]')
    else:
        sel = None

    if sel is not None:
        data1 = data1.select_trains(sel)
        data2 = data2.select_trains(sel)

    meta1 = data1.run_metadata()
    meta2 = data2.run_metadata()

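    # Exclude dates and software versions from the metadata comparison.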
    for meta in [meta1, meta2]:
        meta.pop('creationDate', None)
        meta.pop('updateDate', None)
        meta.pop('dataFormatVersion', None)
        meta.pop('karaboFramework', None)
        meta.pop('daqLibrary', None)
        meta.pop('dataWriter', None)

    cmp('Metadata excluding dates and versions', meta1 == meta2)
    cmp('Train IDs', data1.train_ids == data2.train_ids)

    # This is sometimes not equal.
    cmp('Train timestamps',
        np.array_equal(data1.train_timestamps(), data2.train_timestamps()))

    cmp('Control source names',
        data1.control_sources == data2.control_sources)
    cmp('Instrument source names',
        data1.instrument_sources == data2.instrument_sources)

    if args.only_metadata:
        return

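    # Compare every source present in both data collections.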
    for source in sorted(data1.all_sources & data2.all_sources):
        sd1 = data1[source]
        sd2 = data2[source]

        cmp(f'{source} keys', sd1.keys() == sd2.keys())

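        # Compare the per-train data counts for each index group.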
        counts1 = {grp: sd1.data_counts(labelled=False, index_group=grp)
                   for grp in sd1.index_groups}
        counts2 = {grp: sd2.data_counts(labelled=False, index_group=grp)
                   for grp in sd2.index_groups}

        for index_group in sorted(counts1.keys() & counts2.keys()):
            index_group_str = f'/{index_group}' if index_group else ''
            cmp(f'{source}{index_group_str} counts',
                np.array_equal(counts1[index_group], counts2[index_group]))

        if args.only_index or (sd1.is_instrument and args.only_control):
            continue

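        # For control sources, additionally compare their run values.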
        if not sd1.is_instrument:
            run_values1 = sd1.run_values()
            run_values2 = sd2.run_values()

            cmp(f'{source} run keys', run_values1.keys() == run_values2.keys())

            for key in sorted(run_values1.keys() & run_values2.keys()):
                value1 = run_values1[key]
                value2 = run_values2[key]

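                # Use NaN-aware comparison for floating-point values and
                # exact comparison otherwise.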
                if isinstance(value1, np.ndarray):
                    is_equal = np.array_equal(
                        value1, value2,
                        equal_nan=np.issubdtype(value1.dtype, np.floating))
                elif isinstance(value1, np.floating):
                    is_equal = np.array_equal(value1, value2, equal_nan=True)
                else:
                    is_equal = value1 == value2

                cmp(f'{source}, {key} run value', is_equal)

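        # Compare the actual data for every key present in both sources,
        # treating NaNs as equal.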
        for key in sorted(sd1.keys() & sd2.keys()):
            cmp(f'{source}, {key} data',
                np.array_equal(sd1[key].ndarray(), sd2[key].ndarray(),
                               equal_nan=True))

    if not args.quiet:
        num_total_cond = num_true_cond + num_false_cond
        print(f'Compared {num_total_cond} items, {num_true_cond} are equal '
              f'and {num_false_cond} are not')
    else:
        print(num_false_cond)
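

if __name__ == '__main__':
    # Allow running this module directly as a script.
    main()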