from collections import defaultdict
from pathlib import Path
import fnmatch
import logging

from packaging.version import Version
import numpy as np

from extra_data import by_id
from extra_data.read_machinery import select_train_ids

from exdf.write import SourceDataWriter


class ReduceWriter(SourceDataWriter):
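    """Writes reduced data to EXDF files.

    Collects the operations of all configured reduction methods
    (removing sources or keys, selecting trains or rows, rechunking
    keys and copying partial regions) and applies them to the input
    data while writing it back out to EXDF files.
    """
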
    log = logging.getLogger('exdf.data_reduction.ReduceWriter')

    def __init__(self, data, methods, scope, sequence_len=-1, version=None):
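        """
        Args:
            data (extra_data.DataCollection): Input data to reduce.
            methods (dict): Maps reduction method names to lists of
                operations, each a tuple starting with the operation
                name followed by its arguments.
            scope (str): Output scope, 'sources' or 'aggregators'
                limit the output to the sources affected by any
                operation or to all sources of their aggregators,
                respectively.
            sequence_len (int, optional): Number of trains per output
                sequence file, determined from the input files if
                less than 1 (default).
            version (str, optional): Data format version written to
                output, or 'same' to retain the input version.
        """
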
        self._data = data
        self._methods = methods
        self._scope = scope
        self._sequence_len = sequence_len

        metadata = self._data.run_metadata()

        input_version = Version(metadata.get('dataFormatVersion', '1.0'))

        if input_version < Version('1.0'):
            raise ValueError('Currently input files are required to be '
                             'EXDF-v1.0+')

        if version == 'same':
            self._version = input_version
        else:
            self._version = Version(version)

        try:
            self.run_number = int(metadata['runNumber'])
        except KeyError:
            raise ValueError('runNumber dataset required in input METADATA')

        self._ops = sum(methods.values(), [])

        if not self._ops:
            self.log.warning('Sum of reduction methods yielded no operations '
                             'to apply')

        self._sources = sorted(data.all_sources)
        self._touched_sources = set()

        # Only populated if trains/keys are selected/removed for sources.
        self._custom_keys = {}  # source -> set(<keys>)
        self._custom_trains = {}  # source -> list(<trains>)
        self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
        self._custom_xtdf_counts = {}  # source -> dict(train_id -> count)
        self._custom_rows = {}  # (source, index_group) -> dict(train_id -> mask)
        self._rechunked_keys = {}  # (source, key) -> chunks
        self._partial_copies = {}  # (source, key) -> list(<regions>)

        # TODO: Raise error if rechunking is overwritten!
        # TODO: make partial copies a list of slices!

        # Collect reductions resulting from operations.
        for source_glob, in self._filter_ops('remove-sources'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)
                self._sources.remove(source)

        for source_glob, key_glob in self._filter_ops('remove-keys'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.setdefault(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    keys.remove(key)

        for source_glob, train_sel in self._filter_ops('select-trains'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)
                train_ids = self._custom_trains.setdefault(
                    source, list(self._data.train_ids))

                self._custom_trains[source] = select_train_ids(
                    train_ids, train_sel)

        for source_glob, index_group, train_sel, row_sel in self._filter_ops(
            'select-rows'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if index_group not in self._data[source].index_groups:
                    raise ValueError(f'{index_group} not index group of '
                                     f'{source}')

                self._touched_sources.add(source)
                self._custom_rows.setdefault((source, index_group), {}).update(
                    self._get_row_masks(source, index_group,
                                        train_sel, row_sel))

        for source_glob, train_sel, row_sel in self._filter_ops('select-xtdf'):
            for source in fnmatch.filter(self._sources, source_glob):
                if not source.endswith(':xtdf'):
                    # Simply ignore matches without trailing :xtdf.
                    continue

                if not self._is_xtdf_source(source):
                    # Raise exception if essentials are missing.
                    raise ValueError(f'{source} is not a valid XTDF source')

                self._touched_sources.add(source)
                self._custom_xtdf_masks.setdefault(source, {}).update(
                    self._get_row_masks(source, 'image', train_sel, row_sel))

        if (
            {x[0] for x in self._custom_rows.keys()} &
            self._custom_xtdf_masks.keys()
        ):
            raise ValueError('source may not be affected by both select-rows '
                             'and select-xtdf operations')

        for source_glob, key_glob, chunking in self._filter_ops(
            'rechunk-keys'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if not self._data[source].is_instrument:
                    raise ValueError(
                        f'rechunking keys only supported for instrument '
                        f'sources, but {source_glob} matches '
                        f'{self._data[source].section}/{source}')

                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    old_chunking = self._rechunked_keys.setdefault(
                        (source, key), chunking)

                    if old_chunking != chunking:
                        raise ValueError(
                            f'reduction sequence yields conflicting chunks '
                            f'for {source}.{key}: {old_chunking}, {chunking}')


        for source_glob, key_glob, region in self._filter_ops('partial-copy'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    self._partial_copies.setdefault((source, key), []).append(
                        region)

        if self._scope == 'sources':
            self._sources = sorted(
                self._touched_sources.intersection(self._sources))

        elif self._scope == 'aggregators':
            touched_aggregators = {self._data[source].aggregator
                                   for source in self._touched_sources}

            self._sources = sorted(
                {source for source in self._sources
                 if (self._data[source].aggregator in touched_aggregators)})

        if not self._sources:
            raise ValueError('reduction sequence yields empty source '
                             'selection')

    def _filter_ops(self, op):
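        """Return the argument tuples of all operations of a given type."""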
        return [args[1:] for args in self._ops if args[0] == op]

    def _is_xtdf_source(self, source):
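        """Return whether a source provides the essential XTDF keys."""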
        essential_keys = {'header.pulseCount', 'image.data'}
        return self._data[source].keys() >= essential_keys

    def _get_row_masks(self, source, index_group, train_sel, row_sel):
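        """Build boolean row masks per train for an index group.

        The row selection may be a slice, a sequence of integer row
        indices or a boolean mask covering all rows of a train.
        """
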
        train_ids = select_train_ids(
            self._custom_trains.get(source, list(self._data.train_ids)),
            train_sel)
        counts = self._data[source].select_trains(by_id[train_ids]) \
            .data_counts(index_group=index_group)
        masks = {}

        if isinstance(row_sel, slice):
            for train_id, count in counts.items():
                if count > 0:
                    masks[train_id] = np.zeros(count, dtype=bool)
                    masks[train_id][row_sel] = True

        elif np.issubdtype(type(row_sel[0]), np.integer):
            max_row = max(row_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif max_row >= count:
                    raise ValueError(
                        f'row index exceeds data counts of train {train_id}')

                masks[train_id] = np.zeros(count, dtype=bool)
                masks[train_id][row_sel] = True

        elif np.issubdtype(type(row_sel[0]), bool):
            mask_len = len(row_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif mask_len != count:
                    raise ValueError(
                        f'mask length mismatch for train {train_id}')

                # Ensure an ndarray mask to support .sum()/.any() later.
                masks[train_id] = np.asarray(row_sel, dtype=bool)

        else:
            raise ValueError('unknown row mask format')

        return masks

    def write_collection(self, output_path):
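        """Write sources to one item per data category and aggregator."""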
        outp_data = self._data.select([(s, '*') for s in self._sources])

        # Collect all items (combination of data category and
        # aggregator) and the sources they contain.
        sources_by_item = defaultdict(list)
        for source in self._sources:
            sd = outp_data[source]
            sources_by_item[(sd.data_category, sd.aggregator)].append(source)

        for (data_category, aggregator), sources in sources_by_item.items():
            self.write_item(
                output_path, sources, f'{data_category}-{aggregator}',
                dict(data_category=data_category, aggregator=aggregator))

    def write_collapsed(self, output_path):
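        """Write all sources to a single item."""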
        self.write_item(output_path, self._sources, 'COLLAPSED')

    def write_voview(self, output_path):
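        """Write a voview file for all sources (not implemented yet)."""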
        raise NotImplementedError('voview output layout')

    def write_item(self, output_path, source_names, name, filename_fields={}):
        """Write sources to a single item."""

        # Select output data down to what's in this item both in terms
        # of sources and trains (via require_any).
        item_data = self._data.select({
            s: self._custom_keys.get(s, set())
            for s in source_names
        }, require_any=True)

        # Switch to representation of SourceData objects for
        # per-source tracking of trains.
        item_sources = [item_data[source]
                        for source in item_data.all_sources]

        # Determine input sequence length if no explicit value was given
        # for output.
        if self._sequence_len < 1:
            sequence_len = max({
                len(sd._get_first_source_file().train_ids)
                for sd in item_sources
            })
        else:
            sequence_len = self._sequence_len

        # Apply custom train selections, if any.
        for i, sd in enumerate(item_sources):
            train_sel = self._custom_trains.get(sd.source, None)

            if train_sel is not None:
                item_sources[i] = sd.select_trains(by_id[train_sel])

        # Find the union of trains across all sources as total
        # trains for this item.
        item_train_ids = np.zeros(0, dtype=np.uint64)
        for sd in item_sources:
            item_train_ids = np.union1d(
                item_train_ids, sd.drop_empty_trains().train_ids)

        num_trains = len(item_train_ids)
        num_sequences = int(np.ceil(num_trains / sequence_len))

        self.log.info(
            f'{name} containing {len(item_sources)} sources with {num_trains} '
            f'trains over {num_sequences} sequences')

        for seq_no in range(num_sequences):
            seq_slice = np.s_[
                (seq_no * sequence_len):((seq_no + 1) * sequence_len)]

            # Slice out the train IDs for this sequence.
            seq_train_ids = item_train_ids[seq_slice]

            # Select item data down to what's in this sequence.
            seq_sources = [sd.select_trains(by_id[seq_train_ids])
                           for sd in item_sources]

            # Build explicit output path for this sequence.
            seq_path = Path(str(output_path).format(
                run=self.run_number, sequence=seq_no, **filename_fields))

            self.log.debug(f'{seq_path.stem} containing {len(seq_sources)} '
                           f'sources with {len(seq_train_ids)} trains')
            self.write_sequence(seq_path, seq_sources, seq_no)

    # SourceDataWriter hooks.

    def write_base(self, f, sources, sequence):
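        """Write the base file structure plus reduction-specific METADATA."""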
        super().write_base(f, sources, sequence)

        # Add reduction-specific METADATA
        red_group = f.require_group('METADATA/reduction')

        for name, method in self._methods.items():
            ops = np.array([
                '\t'.join([str(x) for x in op]).encode('ascii')
                for op in method
            ])
            red_group.create_dataset(name, shape=len(method), data=ops)

    def get_data_format_version(self):
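        """Return the data format version written to output files."""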
        return str(self._version)

    def with_origin(self):
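        """Return whether to write origin data, requires version 1.2+."""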
        return self._version >= Version('1.2')

    def with_attrs(self):
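        """Return whether to write key attributes, requires version 1.3+."""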
        return self._version >= Version('1.3')

    def chunk_instrument_data(self, source, key, orig_chunks):
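        """Return the chunk shape to use for an INSTRUMENT key.

        Applies any rechunking configured for this key on top of the
        original chunks, or returns the original chunks unchanged.
        """
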
        try:
            chunks = list(self._rechunked_keys[source, key])
            assert len(chunks) == len(orig_chunks)

            # None entries keep the original chunk size in a dimension.
            for i, dim_len in enumerate(chunks):
                if dim_len is None:
                    chunks[i] = orig_chunks[i]

            # A single -1 entry is inferred such that the total number
            # of elements per chunk stays close to the original chunks.
            if -1 in chunks:
                chunks[chunks.index(-1)] = \
                    np.prod(orig_chunks) // -np.prod(chunks)

            return tuple(chunks)
        except KeyError:
            return orig_chunks

    def mask_instrument_data(self, source, index_group, train_ids, counts):
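        """Return boolean row masks per train for an INSTRUMENT group.

        Returns None if no row selection applies to this source and
        index group, which keeps all rows. For XTDF sources, the
        reduced data counts of the image index group are recorded to
        later patch header.pulseCount.
        """
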
        if source in self._custom_xtdf_masks and index_group == 'image':
            custom_masks = self._custom_xtdf_masks[source]
        elif (source, index_group) in self._custom_rows:
            custom_masks = self._custom_rows[source, index_group]
        else:
            return  # None efficiently selects all rows.

        masks = []

        for train_id, count_all in zip(train_ids, counts):
            if train_id in custom_masks:
                mask = custom_masks[train_id]
            else:
                mask = np.ones(count_all, dtype=bool)

            masks.append(mask)

        if source in self._custom_xtdf_masks:
            # Sources are guaranteed to never use both XTDF and general
            # row slicing. In the XTDF case, the new data counts for the
            # image index group must be determined to be filled into
            # the respective header field.

            self._custom_xtdf_counts[source] = {
                train_id: mask.sum() for train_id, mask
                in zip(train_ids, masks) if mask.any()}

        return masks

    def copy_instrument_data(self, source, key, dest, train_ids, data):
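        """Copy INSTRUMENT data for a key into the output dataset.

        Patches header.pulseCount for XTDF sources with row selections
        and restricts the copy to any configured partial regions.
        """
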
        if source in self._custom_xtdf_counts and key == 'header.pulseCount':
            custom_counts = self._custom_xtdf_counts[source]

            for i, train_id in enumerate(train_ids):
                data[i] = custom_counts.get(train_id, data[i])

        try:
            regions = self._partial_copies[source, key]
        except KeyError:
            dest[:] = data
        else:
            for region in regions:
                sel = (np.s_[:], *region)
                dest[sel] = data[sel]