from collections import defaultdict
from pathlib import Path
import fnmatch
import logging

from packaging.version import Version
import numpy as np

from extra_data import by_id
from extra_data.read_machinery import select_train_ids

from exdf.write import SourceDataWriter
from ..write.datafile import write_compressed_frames


def apply_by_source(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            assert isinstance(self, ReduceWriter)
            for source_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    op_func(self, source, *args)
                    self._touched_sources.add(source)

        return op_handler
    return op_decorator


def apply_by_key(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            assert isinstance(self, ReduceWriter)
            for source_glob, key_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    keys = self._custom_keys.get(source,
                                                set(self._data[source].keys()))

                    for key in fnmatch.filter(keys, key_glob):
                        op_func(self, source, key, *args)

                    self._touched_sources.add(source)

        return op_handler
    return op_decorator
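
# Illustrative sketch of the operation tuples these decorators consume.
# The globs and arguments below are hypothetical; each entry of the
# flattened operation list starts with the operation name, followed by a
# source glob and, for key-based operations, a key glob plus arguments:
#
#     ('remove-sources', 'SA1_XTD2_XGM/XGM/*')
#     ('remove-keys', 'SPB_DET_AGIPD1M-1/CORR/*:output', 'image.gain')
#     ('rechunk-keys', '*/CORR/*:output', 'image.data', (1, None, -1))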


class ReduceWriter(SourceDataWriter):
    log = logging.getLogger('exdf.data_reduction.ReduceWriter')

    def __init__(self, data, methods, scope, sequence_len=-1, version=None):
        self._data = data
        self._methods = methods
        self._scope = scope
        self._sequence_len = sequence_len

        metadata = self._data.run_metadata()

        input_version = Version(metadata.get('dataFormatVersion', '1.0'))

        if input_version < Version('1.0'):
            raise ValueError('Currently input files are required to be '
                             'EXDF-v1.0+')

        if version == 'same':
            self._version = input_version
        else:
            self._version = Version(version)

        try:
            self.run_number = int(metadata['runNumber'])
        except KeyError:
            raise ValueError('runNumber dataset required in input METADATA')

        self._ops = sum(methods.values(), [])

        if not self._ops:
            self.log.warning('Sum of reduction methods yielded no operations '
                             'to apply')

        self._sources = sorted(data.all_sources)
        self._touched_sources = set()

        # Only populated for sources/keys that are modified.
        self._custom_keys = {}  # source -> set(<keys>)
        self._custom_trains = {}  # source -> list(<trains>)
        self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
        self._custom_xtdf_counts = {}  # source -> ndarray
        self._custom_entry_masks = {}  # source -> dict(train_id -> mask)
        self._rechunked_keys = {}  # (source, key) -> chunks
        self._subsliced_keys = {}  # (source, key) -> list(<regions>)
        self._compressed_keys = {}  # (source, key) -> level

        # Collect reductions resulting from operations.
        # This is the most efficient order of operations to minimize
        # more expensive operations for source or trains that may not
        # end up being selected.
        self._handle_remove_sources()
        self._handle_remove_keys()
        self._handle_select_trains()
        self._handle_select_entries()
        self._handle_select_xtdf()
        self._handle_rechunk_keys()
        self._handle_subslice_keys()
        self._handle_compress_keys()

        custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}
        if custom_entry_sources & self._custom_xtdf_masks.keys():
            raise ValueError(
                'Source may not be affected by both select-entries and '
                'select-xtdf operations')

        if self._rechunked_keys.keys() & self._compressed_keys.keys():
            raise ValueError('Key may not be affected by both '
                             'compress-keys and rechunk-keys')

        if self._scope == 'sources':
            self._sources = sorted(
                self._touched_sources.intersection(self._sources))

        elif self._scope == 'aggregators':
            touched_aggregators = {self._data[source].aggregator
                                   for source in self._touched_sources}

            self._sources = sorted(
                {source for source in self._sources
                 if (self._data[source].aggregator in touched_aggregators)})

        if not self._sources:
            raise ValueError('reduction sequence yields empty source '
                             'selection')

    def _filter_ops(self, op):
        return [args[1:] for args in self._ops if args[0] == op]

    def _is_xtdf_source(self, source):
        return self._data[source].keys() > {'header.pulseCount', 'image.data'}

    def _get_entry_masks(self, source, index_group, train_sel, entry_sel):
        train_ids = select_train_ids(
            self._custom_trains.get(source, list(self._data.train_ids)),
            train_sel)
        counts = self._data[source].select_trains(by_id[train_ids]) \
            .data_counts(index_group=index_group)
        masks = {}

        if isinstance(entry_sel, slice):
            for train_id, count in counts.items():
                if count > 0:
                    masks[train_id] = np.zeros(count, dtype=bool)
                    masks[train_id][entry_sel] = True
        elif np.issubdtype(type(entry_sel[0]), np.integer):
            max_entry = max(entry_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif max_entry >= count:
                    raise ValueError(
                        f'entry index exceeds data counts of train {train_id}')

                masks[train_id] = np.zeros(count, dtype=bool)
                masks[train_id][entry_sel] = True
        elif np.issubdtype(type(entry_sel[0]), bool):
            mask_len = len(entry_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif mask_len != count:
                    raise ValueError(
                        f'mask length mismatch for train {train_id}')

                masks[train_id] = entry_sel
        else:
            raise ValueError('unknown entry mask format')

        return masks

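    # Illustrative note on accepted entry_sel values (assumed callers):
    # a slice such as np.s_[:10], a sequence of integer indices such as
    # [0, 2, 4], or a boolean mask whose length matches the per-train
    # data count; anything else raises ValueError.
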
    def write_collection(self, output_path):
        outp_data = self._data.select([(s, '*') for s in self._sources])

        # Collect all items (combination of data category and
        # aggregator) and the sources they contain.
        sources_by_item = defaultdict(list)
        for source in self._sources:
            sd = outp_data[source]
            sources_by_item[(sd.data_category, sd.aggregator)].append(source)

        for (data_category, aggregator), sources in sources_by_item.items():
            self.write_item(
                output_path, sources, f'{data_category}-{aggregator}',
                dict(data_category=data_category, aggregator=aggregator))

    def write_collapsed(self, output_path):
        self.write_item(output_path, self._sources, 'COLLAPSED')

    def write_voview(self, output_path):
        raise NotImplementedError('voview output layout')

    def write_item(self, output_path, source_names, name, filename_fields={}):
        """Write sources to a single item."""

        # Select output data down to what's in this item both in terms
        # of sources and trains (via require_any).
        item_data = self._data.select({
            s: self._custom_keys[s] if s in self._custom_keys else set()
            for s in source_names
        }, require_any=True)

        # Switch to representation of SourceData objects for
        # per-source tracking of trains.
        item_sources = [item_data[source]
                        for source in item_data.all_sources]

        # Determine input sequence length if no explicit value was given
        # for output.
        if self._sequence_len < 1:
            sequence_len = max({
                len(sd._get_first_source_file().train_ids)
                for sd in item_sources
            })
        else:
            sequence_len = self._sequence_len

        # Apply custom train selections, if any.
        for i, sd in enumerate(item_sources):
            train_sel = self._custom_trains.get(sd.source, None)

            if train_sel is not None:
                item_sources[i] = sd.select_trains(by_id[train_sel])

        # Find the union of trains across all sources as total
        # trains for this item.
        item_train_ids = np.zeros(0, dtype=np.uint64)
        for sd in item_sources:
            item_train_ids = np.union1d(
                item_train_ids, sd.drop_empty_trains().train_ids)

        num_trains = len(item_train_ids)
        num_sequences = int(np.ceil(num_trains / sequence_len))

        self.log.info(
            f'{name} containing {len(item_sources)} sources with {num_trains} '
            f'trains over {num_sequences} sequences')

        for seq_no in range(num_sequences):
            seq_slice = np.s_[
                (seq_no * sequence_len):((seq_no + 1) * sequence_len)]

            # Slice out the train IDs and timestamps for this sequence.
            seq_train_ids = item_train_ids[seq_slice]

            # Select item data down to what's in this sequence.
            seq_sources = [sd.select_trains(by_id[seq_train_ids])
                           for sd in item_sources]

            # Build explicit output path for this sequence.
            seq_path = Path(str(output_path).format(
                run=self.run_number, sequence=seq_no, **filename_fields))

            self.log.debug(f'{seq_path.stem} containing {len(seq_sources)} '
                           f'sources with {len(seq_train_ids)} trains')
            self.write_sequence(seq_path, seq_sources, seq_no)
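
    # The output_path passed in is expected to be a str/Path template; an
    # illustrative (not prescribed) pattern would be
    #     '{data_category}-R{run:04d}-{aggregator}-S{sequence:05d}.h5'
    # where {data_category} and {aggregator} are only filled in via
    # write_collection() and {run}/{sequence} are supplied here.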

    # SourceDataWriter hooks.

    def write_base(self, f, sources, sequence):
        super().write_base(f, sources, sequence)

        # Add reduction-specific METADATA
        red_group = f.require_group('METADATA/reduction')

        for name, method in self._methods.items():
            ops = np.array([
                '\t'.join([str(x) for x in op[:]]).encode('ascii')
                for op in method
            ])
            red_group.create_dataset(name, shape=len(method), data=ops)
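
        # Illustrative example (hypothetical glob): the operation tuple
        # ('remove-sources', 'SA1_XTD2_XGM/XGM/*') is stored as the ASCII
        # entry b'remove-sources\tSA1_XTD2_XGM/XGM/*'.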

    def get_data_format_version(self):
        return str(self._version)

    def with_origin(self):
        return self._version >= Version('1.2')

    def with_attrs(self):
        return self._version >= Version('1.3')

    def create_instrument_key(self, source, key, orig_dset, kwargs):
        # Keys are guaranteed to never use both custom chunking and
        # compression.

        if (source, key) in self._rechunked_keys:
            orig_chunks = kwargs['chunks']

            chunks = list(self._rechunked_keys[source, key])
            assert len(chunks) == len(orig_chunks)

            for i, dim_len in enumerate(chunks):
                if dim_len is None:
                    chunks[i] = orig_chunks[i]

            if -1 in chunks:
                chunks[chunks.index(-1)] = \
                    np.prod(orig_chunks) // -np.prod(chunks)

            kwargs['chunks'] = tuple(chunks)
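            # Worked example (assumed shapes): with original chunks of
            # (32, 512, 128) and a requested chunking of (1, None, -1),
            # None falls back to the original dimension and -1 is inferred
            # to keep the total chunk size constant, giving (1, 512, 4096).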

        elif (source, key) in self._compressed_keys or orig_dset.compression:
            # TODO: Maintain more of existing properties, for now it is
            # forced to use gzip and (1, *entry) chunking.
            kwargs['chunks'] = (1,) + kwargs['shape'][1:]
            kwargs['shuffle'] = True
            kwargs['compression'] = 'gzip'
            kwargs['compression_opts'] = self._compressed_keys.setdefault(
                (source, key), orig_dset.compression_opts)

    def mask_instrument_data(self, source, index_group, train_ids, counts):
        if source in self._custom_xtdf_masks and index_group == 'image':
            custom_masks = self._custom_xtdf_masks[source]
        elif (source, index_group) in self._custom_entry_masks:
            custom_masks = self._custom_entry_masks[source, index_group]
        else:
            return  # None efficiently selects all entries.

        masks = []

        for train_id, count_all in zip(train_ids, counts):
            if train_id in custom_masks:
                mask = custom_masks[train_id]
            else:
                mask = np.ones(count_all, dtype=bool)

            masks.append(mask)

        if source in self._custom_xtdf_masks:
            # Sources are guaranteed to never use both XTDF and general
            # entry slicing. In the XTDF case, the new data counts for
            # the image index group must be determined to be filled into
            # the respective header field.

            self._custom_xtdf_counts[source] = {
                train_id: mask.sum() for train_id, mask
                in zip(train_ids, masks) if mask.any()}

        return masks

    def copy_instrument_data(self, source, key, dest, train_ids, data):
        if source in self._custom_xtdf_counts and key == 'header.pulseCount':
            custom_counts = self._custom_xtdf_counts[source]

            for i, train_id in enumerate(train_ids):
                data[i] = custom_counts.get(train_id, data[i])

        if (source, key) in self._subsliced_keys:
            for region in self._subsliced_keys[source, key]:
                sel = (np.s_[:], *region)
                dest[sel] = data[sel]

        elif (source, key) in self._compressed_keys:
            write_compressed_frames(
                data, dest, self._compressed_keys[source, key], 8)

        else:
            dest[:] = data

    # Reduction operation handlers.

    @apply_by_source('remove-sources')
    def _handle_remove_sources(self, source):
        self._touched_sources.add(source)
        self._sources.remove(source)

    @apply_by_key('remove-keys')
    def _handle_remove_keys(self, source, key):
        self._custom_keys.setdefault(
            source, set(self._data[source].keys())).remove(key)

    @apply_by_source('select-trains')
    def _handle_select_trains(self, source, train_sel):
        self._custom_trains[source] = select_train_ids(
            self._custom_trains.setdefault(source, list(self._data.train_ids)),
            train_sel)

    @apply_by_source('select-entries')
    def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
        if idx_group not in self._data[source].index_groups:
            raise ValueError(f'{idx_group} not index group of {source}')

        self._custom_entry_masks.setdefault((source, idx_group), {}).update(
            self._get_entry_masks(source, idx_group, train_sel, entry_sel))

    @apply_by_source('select-xtdf')
    def _handle_select_xtdf(self, source, train_sel, entry_sel):
        if not source.endswith(':xtdf'):
            # Simply ignore matches without trailing :xtdf.
            return

        if not self._is_xtdf_source(source):
            # Raise exception if essentials are missing.
            raise ValueError(f'{source} is not a valid XTDF source')

        self._custom_xtdf_masks.setdefault(source, {}).update(
            self._get_entry_masks(source, 'image', train_sel, entry_sel))

    @apply_by_key('rechunk-keys')
    def _handle_rechunk_keys(self, source, key, chunking):
        if not self._data[source].is_instrument:
            # Ignore CONTROL sources.
            return

        old_chunking = self._rechunked_keys.setdefault((source, key), chunking)

        if old_chunking != chunking:
            raise ValueError(
                f'Reduction sequence yields conflicting chunks for '
                f'{source}.{key}: {old_chunking}, {chunking}')

        self._rechunked_keys[(source, key)] = chunking

    @apply_by_key('subslice-keys')
    def _handle_subslice_keys(self, source, key, region):
        self._subsliced_keys.setdefault((source, key), []).append(region)

    @apply_by_key('compress-keys')
    def _handle_compress_keys(self, source, key, level):
        self._compressed_keys[source, key] = level
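

# Illustrative usage sketch (assumed call site; in practice this class is
# driven by the exdf reduction tooling, and the source/key names and
# selections below are hypothetical):
#
#     from extra_data import RunDirectory
#     import numpy as np
#
#     run = RunDirectory('/path/to/run')
#     methods = {'manual': [
#         ('select-trains', '*', np.s_[:500]),
#         ('remove-keys', 'SPB_XTD9_XGM/DOOCS/MAIN:output', 'data.intensityTD'),
#     ]}
#
#     writer = ReduceWriter(run, methods, scope='sources', version='1.3')
#     writer.write_collection(
#         '{data_category}-R{run:04d}-{aggregator}-S{sequence:05d}.h5')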