From e1e80999a21f5e7fae24491852eb3daff0774e45 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Sat, 13 May 2023 18:12:57 +0200
Subject: [PATCH] Replace FileWriter by DataFile API

---
 src/exdf_tools/exdf.py   | 629 +++++++++++++++++++++++++++++++++++++++
 src/exdf_tools/reduce.py | 289 ++++++++++++------
 2 files changed, 823 insertions(+), 95 deletions(-)
 create mode 100644 src/exdf_tools/exdf.py

diff --git a/src/exdf_tools/exdf.py b/src/exdf_tools/exdf.py
new file mode 100644
index 0000000..89a8684
--- /dev/null
+++ b/src/exdf_tools/exdf.py
@@ -0,0 +1,629 @@
+
+
+from datetime import datetime
+from itertools import chain
+from numbers import Integral
+from pathlib import Path
+import re
+
+import numpy as np
+import h5py
+
+
+def get_pulse_offsets(pulses_per_train):
+    """Compute pulse offsets from pulse counts.
+
+    Given an array of number of pulses per train (INDEX/<source>/count),
+    computes the offsets (INDEX/<source>/first) for the first pulse of a
+    train in the data array.
+
+    Args:
+        pulses_per_train (array_like): Pulse count per train.
+
+    Returns:
+        (array_like) Offset of the first pulse for each train.
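+
+    Example (illustrative pulse counts):
+        >>> get_pulse_offsets([3, 0, 5])
+        array([0, 3, 3])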
+    """
+
+    pulse_offsets = np.zeros_like(pulses_per_train)
+    np.cumsum(pulses_per_train[:-1], out=pulse_offsets[1:])
+
+    return pulse_offsets
+
+
+def sequence_trains(train_ids, trains_per_sequence=256):
+    """Iterate over sequences for a list of trains.
+
+    For pulse-resolved data, sequence_pulses may be used instead.
+
+    Args:
+        train_ids (array_like): Train IDs to sequence.
+        trains_per_sequence (int, optional): Number of trains
+            per sequence, 256 by default.
+
+    Yields:
+        (int, slice) Current sequence ID, train mask.
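+
+    Example (illustrative train IDs):
+        >>> list(sequence_trains([10001, 10002, 10003],
+        ...                      trains_per_sequence=2))
+        [(0, slice(0, 2, 1)), (1, slice(2, 3, 1))]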
+    """
+
+    num_trains = len(train_ids)
+
+    for seq_id, start in enumerate(range(0, num_trains, trains_per_sequence)):
+        train_mask = slice(
+            *np.s_[start:start+trains_per_sequence].indices(num_trains))
+        yield seq_id, train_mask
+
+
+def sequence_pulses(train_ids, pulses_per_train=1, pulse_offsets=None,
+                    trains_per_sequence=256):
+    """Split trains into sequences.
+
+    Args:
+        train_ids (array_like): Train IDs to sequence.
+        pulses_per_train (int or array_like, optional): Pulse count per
+            train. If scalar, it is assumed to be constant for all
+            trains. If omitted, it is 1 by default.
+        pulse_offsets (array_like, optional): Offsets for the first
+            pulse in each train, computed from pulses_per_train if
+            omitted.
+        trains_per_sequence (int, optional): Number of trains
+            per sequence, 256 by default.
+
+    Yields:
+        (int, slice, slice)
+            Current sequence ID, train mask, pulse mask.
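+
+    Example (illustrative pulse counts per train):
+        >>> for seq_id, train_mask, pulse_mask in sequence_pulses(
+        ...         [10001, 10002, 10003], pulses_per_train=[3, 0, 5],
+        ...         trains_per_sequence=2):
+        ...     print(seq_id, train_mask.stop - train_mask.start,
+        ...           pulse_mask.stop - pulse_mask.start)
+        0 2 3
+        1 1 5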
+    """
+
+    if isinstance(pulses_per_train, Integral):
+        pulses_per_train = np.full_like(train_ids, pulses_per_train,
+                                        dtype=np.uint64)
+
+    if pulse_offsets is None:
+        pulse_offsets = get_pulse_offsets(pulses_per_train)
+
+    for seq_id, train_mask in sequence_trains(train_ids, trains_per_sequence):
+        start = train_mask.start
+        stop = train_mask.stop - 1
+        pulse_mask = np.s_[
+            pulse_offsets[start]:pulse_offsets[stop]+pulses_per_train[stop]]
+
+        yield seq_id, train_mask, pulse_mask
+
+
+def escape_key(key):
+    """Escapes a key name from Karabo to HDF notation."""
+    return key.replace('.', '/')
+
+
+class CustomIndexer:
+    __slots__ = ('parent',)
+
+
+    def __init__(self, parent):
+        self.parent = parent
+
+
+class SourceIndexer(CustomIndexer):
+    def __getitem__(self, source):
+        if ':' in source:
+            root = 'INSTRUMENT'
+            cls = InstrumentSource
+        else:
+            root = 'CONTROL'
+            cls = ControlSource
+
+        return cls(self.parent[f'{root}/{source}'].id, source)
+
+
+class KeyIndexer(CustomIndexer):
+    def __getitem__(self, key):
+        return self.parent[escape_key(key)]
+
+
+class File(h5py.File):
+    """European XFEL HDF5 data file.
+
+    This class extends the h5py.File with methods specific to writing
+    data in the European XFEL file format. The internal state does not
+    depend on using any of these methods, and the underlying file may be
+    manipulated by any of the regular h5py methods, too.
+
+    Please refer to
+    https://extra-data.readthedocs.io/en/latest/data_format.html for
+    details of the file format.
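+
+    A minimal writing sketch (aggregator, source and key names are made
+    up for illustration):
+
+        import numpy as np
+
+        train_ids = np.arange(10001, 10011, dtype=np.uint64)
+
+        with File.from_details('.', 'DA01', run=23, sequence=0) as f:
+            src = f.create_control_source('MID_EXP_SAM/MOTOR/X')
+            src.create_key('actualPosition',
+                           values=np.zeros(len(train_ids)))
+            src.create_index(len(train_ids))
+            f.create_index(train_ids)
+            f.create_metadata(proposal=700000, run=23)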
+    """
+
+    filename_format = '{prefix}-R{run:04d}-{aggregator}-S{sequence:05d}.h5'
+    aggregator_pattern = re.compile(r'^\w{2,}\d{2}$')
+    instrument_source_pattern = re.compile(r'^[\w\/-]+:[\w.]+$')
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__control_sources = set()
+        self.__instrument_sources = set()
+        self.__run = 0
+        self.__sequence = 0
+
+    @classmethod
+    def from_details(cls, folder, aggregator, run, sequence, prefix='CORR',
+                     mode='w', *args, **kwargs):
+        """Open or create a file based on European XFEL details.
+
+        This method is a wrapper that constructs the filename from its
+        individual components.
+
+        Args:
+            folder (Path, str): Parent location for this file.
+            aggregator (str): Name of the data aggregator, must satisfy
+                File.aggregator_pattern.
+            run (int): Run number.
+            sequence (int): Sequence number.
+            prefix (str, optional): First filename component, 'CORR' by
+                default.
+            args, kwargs: Any additional arguments are passed on to
+                h5py.File.
+
+        Returns:
+            (File) Opened file object.
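+
+        Example of the resulting filename (details assumed):
+
+            >>> File.filename_format.format(
+            ...     prefix='CORR', aggregator='DA01', run=23, sequence=0)
+            'CORR-R0023-DA01-S00000.h5'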
+        """
+
+        if not isinstance(folder, Path):
+            folder = Path(folder)
+
+        if not cls.aggregator_pattern.match(aggregator):
+            raise ValueError(f'invalid aggregator format, must satisfy '
+                             f'{cls.aggregator_pattern.pattern}')
+
+        filename = cls.filename_format.format(
+            prefix=prefix, aggregator=aggregator, run=run, sequence=sequence)
+
+        self = cls((folder / filename).resolve(), mode, *args, **kwargs)
+        self.__run = run
+        self.__sequence = sequence
+
+        return self
+
+    @property
+    def source(self):
+        return SourceIndexer(self)
+
+    def create_index(self, train_ids, timestamps=None, flags=None,
+                     origins=None, from_file=None):
+        """Create global INDEX datasets.
+
+        These datasets are agnostic of any source and describe the
+        trains contained in this file.
+
+        Args:
+            train_ids (array_like): Train IDs contained in this file.
+            timestamps (array_like, optional): Timestamp of each train,
+                0 if omitted.
+            flags (array_like, optional): Whether the time server is the
+                initial origin of each train, 1 if omitted.
+            origins (array_like, optional): Which source is the initial
+                origin of each train, -1 (time server) if omitted.
+            from_file (str, Path or extra_data.FileAccess, optional):
+                Existing data file to take timestamps, flags and origins
+                information from if present.
+
+        Returns:
+            None
+        """
+
+        if from_file is not None:
+            from extra_data import FileAccess
+
+            if not isinstance(from_file, FileAccess):
+                from_file = FileAccess(from_file)
+
+            sel_trains = np.isin(from_file.train_ids, train_ids)
+
+            if 'INDEX/timestamp' in from_file.file:
+                timestamps = from_file.file['INDEX/timestamp'][sel_trains]
+
+            flags = from_file.validity_flag[sel_trains]
+
+            if 'INDEX/origin' in from_file.file:
+                origins = from_file.file['INDEX/origin'][sel_trains]
+
+        self.create_dataset('INDEX/trainId', data=train_ids, dtype=np.uint64)
+
+        if timestamps is None:
+            timestamps = np.zeros_like(train_ids, dtype=np.uint64)
+        elif len(timestamps) != len(train_ids):
+            raise ValueError('timestamps and train_ids must be same length')
+
+        self.create_dataset('INDEX/timestamp', data=timestamps,
+                            dtype=np.uint64)
+
+        if flags is None:
+            flags = np.ones_like(train_ids, dtype=np.int32)
+        elif len(flags) != len(train_ids):
+            raise ValueError('flags and train_ids must be same length')
+
+        self.create_dataset('INDEX/flag', data=flags, dtype=np.int32)
+
+        if origins is None:
+            origins = np.full_like(train_ids, -1, dtype=np.int32)
+        elif len(origins) != len(train_ids):
+            raise ValueError('origins and train_ids must be same length')
+
+        self.create_dataset('INDEX/origin', data=origins, dtype=np.int32)
+
+    def create_control_source(self, source):
+        """Create group for a control source ("slow data").
+
+        Control sources created via this method are not required to be
+        passed to create_metadata() again.
+
+        Args:
+            source (str): Karabo device ID.
+
+        Returns:
+            (ControlSource) Created group in CONTROL.
+        """
+
+        self.__control_sources.add(source)
+        return ControlSource(self.create_group(f'CONTROL/{source}').id, source)
+
+    def create_instrument_source(self, source):
+        """Create group for an instrument source ("fast data").
+
+        Instrument sources created via this method are not required to be
+        passed to create_metadata() again.
+
+        Args:
+            source (str): Karabo pipeline path, must satisfy
+                File.instrument_source_pattern.
+
+        Returns:
+            (InstrumentSource) Created group in INSTRUMENT.
+        """
+
+        if not self.instrument_source_pattern.match(source):
+            raise ValueError(f'invalid source format, must satisfy '
+                             f'{self.instrument_source_pattern.pattern}')
+
+        self.__instrument_sources.add(source)
+        return InstrumentSource(self.create_group(f'INSTRUMENT/{source}').id,
+                                source)
+
+    def create_metadata(self, like=None, *,
+                        creation_date=None, update_date=None, proposal=0,
+                        run=None, sequence=None, daq_library='1.x',
+                        karabo_framework='2.x', control_sources=(),
+                        instrument_channels=()):
+        """Create METADATA datasets.
+
+        Args:
+            like (DataCollection or SourceData, optional): Take
+                proposal, run, daq_library, karabo_framework from an
+                EXtra-data data collection, overwriting any of these
+                arguments passed.
+            creation_date (datetime, optional): Creation date and time,
+                now if omitted.
+            update_date (datetime, optional): Update date and time,
+                now if omitted.
+            proposal (int, optional): Proposal number, 0 if omitted and
+                no DataCollection passed.
+            run (int, optional): Run number, taken from the
+                DataCollection or from_details if omitted, 0 as a last
+                resort.
+            sequence (int, optional): Sequence number, taken from
+                from_details if omitted, 0 as a last resort.
+            daq_library (str, optional): daqLibrary field, '1.x' if
+                omitted and no DataCollection passed.
+            karabo_framework (str, optional): karaboFramework field,
+                '2.x' if omitted and no DataCollection is passed.
+            control_sources (Iterable, optional): Control sources in
+                this file, sources created via create_control_source are
+                automatically included.
+            instrument_channels (Iterable, optional): Instrument
+                channels (source and first component of data hash) in
+                this file, channels created via create_instrument_source
+                are automatically included.
+
+        Returns:
+            None
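+
+        Example (assuming f is an open File, numbers made up):
+
+            f.create_metadata(proposal=700000, run=23, sequence=0,
+                              control_sources=['MID_EXP_SAM/MOTOR/X'])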
+        """
+
+        if like is not None:
+            metadata = like.run_metadata()
+            proposal = metadata.get('proposalNumber', proposal)
+            run = metadata.get('runNumber', run)
+            daq_library = metadata.get('daqLibrary', daq_library)
+            karabo_framework = metadata.get('karaboFramework',
+                                            karabo_framework)
+
+        # Fall back to the run/sequence passed to from_details, if any.
+        if run is None:
+            run = self.__run
+
+        if sequence is None:
+            sequence = self.__sequence
+
+        if creation_date is None:
+            creation_date = datetime.now()
+
+        if update_date is None:
+            update_date = creation_date
+
+        md_group = self.require_group('METADATA')
+        md_group.create_dataset(
+            'creationDate', shape=(1,),
+            data=creation_date.strftime('%Y%m%dT%H%M%SZ').encode('ascii'))
+        md_group.create_dataset('daqLibrary', shape=(1,),
+                                data=daq_library.encode('ascii'))
+        md_group.create_dataset('dataFormatVersion', shape=(1,), data=b'1.2')
+
+        # Start with the known and specified control sources
+        sources = {name: 'CONTROL'
+                   for name in chain(self.__control_sources, control_sources)}
+
+        # Add in the specified instrument data channels.
+        sources.update({full_channel: 'INSTRUMENT'
+                        for full_channel in instrument_channels})
+
+        # Add in those already in the file, if not already passed.
+        sources.update({f'{name}/{channel}': 'INSTRUMENT'
+                        for name in self.__instrument_sources
+                        for channel in self[f'INSTRUMENT/{name}']})
+
+        source_names = sorted(sources.keys())
+        data_sources_shape = (len(sources),)
+        md_group.create_dataset('dataSources/dataSourceId',
+                                shape=data_sources_shape,
+                                data=[f'{sources[name]}/{name}'.encode('ascii')
+                                      for name in source_names])
+        md_group.create_dataset('dataSources/deviceId',
+                                shape=data_sources_shape,
+                                data=[name.encode('ascii')
+                                      for name in source_names])
+        md_group.create_dataset('dataSources/root', shape=data_sources_shape,
+                                data=[sources[name].encode('ascii')
+                                      for name in source_names])
+
+        md_group.create_dataset(
+            'karaboFramework', shape=(1,),
+            data=karabo_framework.encode('ascii'))
+        md_group.create_dataset(
+            'proposalNumber', shape=(1,), dtype=np.uint32, data=proposal)
+        md_group.create_dataset(
+            'runNumber', shape=(1,), dtype=np.uint32, data=run)
+        md_group.create_dataset(
+            'sequenceNumber', shape=(1,), dtype=np.uint32, data=sequence)
+        md_group.create_dataset(
+            'updateDate', shape=(1,),
+            data=update_date.strftime('%Y%m%dT%H%M%SZ').encode('ascii'))
+
+
+class Source(h5py.Group):
+    @property
+    def key(self):
+        return KeyIndexer(self)
+
+
+class ControlSource(Source):
+    """Group for a control source ("slow data").
+
+    This class extends h5py.Group with methods specific to writing data
+    of a control source in the European XFEL file format. The internal
+    state does not depend on using any of these methods, and the
+    underlying file may be manipulated by any of the regular h5py
+    methods, too.
+    """
+
+    ascii_dt = h5py.string_dtype('ascii')
+
+    def __init__(self, group_id, source):
+        super().__init__(group_id)
+
+        self.__source = source
+        self.__run_group = self.file.require_group(f'RUN/{source}')
+
+    def get_run_group(self):
+        return self.__run_group
+
+    def get_index_group(self):
+        return self.file.require_group(f'INDEX/{self.__source}')
+
+    def create_key(self, key, values, timestamps=None, run_entry=None):
+        """Create datasets for a key varying each train.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            values (array_like): Source values for each train.
+            timestamps (array_like, optional): Timestamps for each
+                source value, 0 if omitted.
+            run_entry (tuple of array_like, optional): Value and
+                timestamp for the corresponding value in the RUN
+                section. The first entry of the train values is used if
+                omitted. If set to False, no RUN key is created at all.
+
+        Returns:
+            None
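+
+        Example (assuming src is a ControlSource with ten trains of
+        data, values made up):
+
+            src.create_key('actualPosition', values=np.zeros(10),
+                           timestamps=np.arange(10, dtype=np.uint64))
+
+            # No RUN entry is written for this key.
+            src.create_key('targetPosition', values=np.ones(10),
+                           run_entry=False)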
+        """
+
+        key = escape_key(key)
+
+        if timestamps is None:
+            timestamps = np.zeros_like(values, dtype=np.uint64)
+        elif len(values) != len(timestamps):
+            raise ValueError('values and timestamps must be the same length')
+
+        self.create_dataset(f'{key}/value', data=values)
+        self.create_dataset(f'{key}/timestamp', data=timestamps)
+
+        if run_entry is False:
+            return
+        elif run_entry is None:
+            run_entry = (values[0], timestamps[0])
+
+        self.create_run_key(key, *run_entry)
+
+    def create_run_key(self, key, value, timestamp=None):
+        """Create datasets for a key constant over a run.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            value (Any): Key value.
+            timestamp (int, optional): Timestamp of the value,
+                0 if omitted.
+
+        Returns:
+            None
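+
+        Example (assuming src is a ControlSource, value made up):
+
+            src.create_run_key('triggerMode', 'EXTERNAL')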
+        """
+
+        # TODO: Some types/shapes are still not fully correct here.
+
+        key = escape_key(key)
+
+        if timestamp is None:
+            timestamp = 0
+
+        if isinstance(value, list):
+            shape = (1, len(value))
+
+            try:
+                dtype = type(value[0])
+            except IndexError:
+                # Assume empty lists are string-typed.
+                dtype = self.ascii_dt
+        elif isinstance(value, np.ndarray):
+            shape = value.shape
+            dtype = value.dtype
+        else:
+            shape = (1,)
+            dtype = type(value)
+
+        if dtype is str:
+            dtype = self.ascii_dt
+
+        self.__run_group.create_dataset(
+            f'{key}/value', data=value, shape=shape, dtype=dtype)
+        self.__run_group.create_dataset(
+            f'{key}/timestamp', data=timestamp, shape=(1,), dtype=np.uint64)
+
+    def create_index(self, num_trains):
+        """Create source-specific INDEX datasets.
+
+        Depending on whether this source has train-varying data or not,
+        different count/first datasets are written.
+
+        Args:
+            num_trains (int): Total number of trains in this file.
+
+        Returns:
+            None
+        """
+
+        if len(self) > 0:
+            count_func = np.ones
+            first_func = np.arange
+        else:
+            count_func = np.zeros
+            first_func = np.zeros
+
+        index_group = self.get_index_group()
+        index_group.create_dataset(
+            'count', data=count_func(num_trains, dtype=np.uint64))
+        index_group.create_dataset(
+            'first', data=first_func(num_trains, dtype=np.uint64))
+
+
+class InstrumentSource(Source):
+    """Group for an instrument source ("fast data").
+
+    This class extends h5py.Group with methods specific to writing data
+    of an instrument source in the European XFEL file format. The internal
+    state does not depend on using any of these methods, and the
+    underlying file may be manipulated by any of the regular h5py
+    methods, too.
+    """
+
+    key_pattern = re.compile(r'^\w+\/[\w\/]+$')
+
+    def __init__(self, group_id, source):
+        super().__init__(group_id)
+
+        self.__source = source
+
+    def get_index_group(self, channel):
+        return self.file.require_group(f'INDEX/{self.__source}/{channel}')
+
+    def create_key(self, key, data=None, **kwargs):
+        """Create dataset for a key.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            data (array_like, optional): Key data to initialize the
+                dataset to.
+            kwargs: Any additional keyword arguments are passed to
+                create_dataset.
+
+        Returns:
+            (h5py.Dataset) Created dataset
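+
+        Example (assuming src is an InstrumentSource, shape and dtype
+        made up):
+
+            src.create_key('data.adc', shape=(1000, 512),
+                           dtype=np.uint16)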
+        """
+
+        key = escape_key(key)
+
+        if not self.key_pattern.match(key):
+            raise ValueError(f'invalid key format, must satisfy '
+                             f'{self.key_pattern.pattern}')
+
+        return self.create_dataset(key, data=data, **kwargs)
+
+    def create_compressed_key(self, key, data, comp_threads=8):
+        """Create a compressed dataset for a key.
+
+        This method makes use of lower-level access in h5py to compress
+        the data separately in multiple threads and write it directly to
+        file rather than go through HDF's compression filters.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            data (np.ndarray): Key data.
+            comp_threads (int, optional): Number of threads to use for
+                compression, 8 by default.
+
+        Returns:
+            (h5py.Dataset) Created dataset
+        """
+
+        key = escape_key(key)
+
+        if not self.key_pattern.match(key):
+            raise ValueError(f'invalid key format, must satisfy '
+                             f'{self.key_pattern.pattern}')
+
+        from cal_tools.tools import write_compressed_frames
+        return write_compressed_frames(data, self, key,
+                                       comp_threads=comp_threads)
+
+    def create_index(self, *args, **channels):
+        """Create source-specific INDEX datasets.
+
+        Instrument data is indexed by channel, which is the first
+        component of its key. The per-train data counts are passed as
+        keyword arguments named after each channel. If the channels
+        have already been created in this group, the same count may be
+        applied to all of them by passing it as a single positional
+        argument instead.
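+
+        Example (assuming src is an InstrumentSource, counts made up):
+
+            # One entry per train for this source's 'data' channel.
+            src.create_index(data=np.ones(10, dtype=np.uint64))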
+        """
+
+        if not channels:
+            try:
+                # Allow ValueError from the conversion to propagate directly.
+                count = int(args[0])
+            except IndexError:
+                raise ValueError('positional argument required if no '
+                                 'explicit channels are passed') from None
+
+            channels = {channel: count for channel in self}
+
+        for channel, count in channels.items():
+            index_group = self.get_index_group(channel)
+            index_group.create_dataset('count', data=count, dtype=np.uint64)
+            index_group.create_dataset(
+                'first', data=get_pulse_offsets(count), dtype=np.uint64)
diff --git a/src/exdf_tools/reduce.py b/src/exdf_tools/reduce.py
index 365212c..85dfca7 100644
--- a/src/exdf_tools/reduce.py
+++ b/src/exdf_tools/reduce.py
@@ -9,11 +9,15 @@ import logging
 import re
 import sys
 
+import numpy as np
+
 from pkg_resources import iter_entry_points
 from extra_data import RunDirectory, open_run, by_id
 from extra_data.writer import FileWriter, VirtualFileWriter
 from extra_data.read_machinery import select_train_ids
 
+from . import exdf
+
 
 exdf_filename_pattern = re.compile(
     r'([A-Z]+)-R(\d{4})-(\w+)-S(\d{5}).h5')
@@ -26,7 +30,7 @@ def get_source_origin(sourcedata):
     except AttributeError:
         files = sourcedata.files
 
-    storage_classes = set()
+    data_category = set()
     run_numbers = set()
     aggregators = set()
 
@@ -34,17 +38,47 @@ def get_source_origin(sourcedata):
         m = exdf_filename_pattern.match(Path(f.filename).name)
 
         if m:
-            storage_classes.add(m[1])
+            data_category.add(m[1])
             run_numbers.add(int(m[2]))
             aggregators.add(m[3])
 
-    result_sets = [storage_classes, run_numbers, aggregators]
+    result_sets = [data_category, run_numbers, aggregators]
 
     if any([len(x) != 1 for x in result_sets]):
-        raise ValueError('different or no classes or aggregators recognized '
-                         'from source filenames')
+        raise ValueError('none or multiple sets of traits recognized '
+                         'in source filenames, inconsistent input dataset?')
+
+    return data_category.pop(), run_numbers.pop(), aggregators.pop()
+
+
+def sourcedata_drop_empty_trains(sd):
+    if sd._is_control:
+        train_sel = sd[next(iter(sd.keys()))].drop_empty_trains().train_ids
+    else:
+        sample_keys = {key[:key.find('.')]: key for key in sd.keys()}
+
+        train_sel = np.zeros(0, dtype=np.uint64)
+
+        for key in sample_keys.values():
+            train_sel = np.union1d(
+                train_sel, sd[key].drop_empty_trains().train_ids)
+
+    return sd.select_trains(by_id[train_sel])
+
+
+def sourcedata_data_counts(sd):
+    if sd._is_control:
+        return sd[next(iter(sd.keys()))].data_counts(labelled=False)
+    else:
+        sample_keys = {key[:key.find('.')]: key for key in sd.keys()}
 
-    return storage_classes.pop(), run_numbers.pop(), aggregators.pop()
+        data_counts = {
+            prefix: sd[key].data_counts(labelled=False)
+            for prefix, key in sample_keys.items()
+        }
+        return data_counts
 
 
 class ReductionSequence:
@@ -62,7 +96,7 @@ class ReductionSequence:
         self._sources = sorted(data.all_sources)
         self._touched_sources = set()
 
-        # Only populated if trains/keys are removed for select sources.
+        # Only populated if trains/keys are selected/removed for sources.
         self._custom_keys = {}  # source -> set(<keys>)
         self._custom_trains = {}  # source -> list(<trains>)
         self._partial_copies = {}  # (source, key) -> (region, chunks)
@@ -76,7 +110,7 @@ class ReductionSequence:
         for source_glob, in self._filter_ops('removeSources'):
             for source in fnmatch.filter(self._sources, source_glob):
                 self._touched_sources.add(source)
-                self._sources.remove(source)                
+                self._sources.remove(source)
 
         for source_glob, key_glob in self._filter_ops('removeKeys'):
             for source in fnmatch.filter(self._sources, source_glob):
@@ -121,19 +155,18 @@ class ReductionSequence:
 
     def write_collection(self, output_path, sequence_len=512):
         outp_data = self._data.select([(s, '*') for s in self._sources])
-        origin_sources = defaultdict(list)
 
-        print(len(self._data.train_ids), len(outp_data.train_ids))
-        return
+        # An origin is a pair of (storage_class, aggregator)
+        origin_sources = defaultdict(list)
 
         run_numbers = set()
 
-        # Collect all origins (combination of storage class and aggregator)
-        # and the sources they contain.
+        # Collect all origins (combination of data category and
+        # aggregator) and the sources they contain.
         for source in self._sources:
-            storage_class, run_no, aggregator = get_source_origin(
-                outp_data[source])
-            origin_sources[(storage_class, aggregator)].append(source)
+            data_category, run_no, aggregator = get_source_origin(
+                outp_data[source])  # TODO: Use support in EXtra-data
+            origin_sources[(data_category, aggregator)].append(source)
             run_numbers.add(run_no)
 
         if len(run_numbers) != 1:
@@ -142,47 +175,58 @@ class ReductionSequence:
 
         run_no = run_numbers.pop()
 
-        for (storage_class, aggregator), sources in origin_sources.items():
-            origin_tids = sorted(set(outp_data.train_ids).intersection(*[
-                tids for s, tids in self._custom_trains.items()
-                if s in sources]))
+        for (data_category, aggregator), sources in origin_sources.items():
+            # Build source & key selection for this origin.
             origin_sel = {
                 s: self._custom_keys[s] if s in self._custom_keys else set()
                 for s in sources}
 
-            # Select output data down to what's in this origin.
+            # Select output data down to what's in this origin both in
+            # terms of sources and trains (via require_any).
             origin_data = outp_data \
-                .select_trains(by_id[origin_tids]) \
-                .select(origin_sel)
-
-            num_trains = len(origin_data.train_ids)
-            num_sequences = int(num_trains / sequence_len) + 1
+                .select(origin_sel, require_any=True)
+
+            # Switch to a representation as SourceData objects for
+            # per-source tracking of trains.
+            origin_sourcedata = {source: origin_data[source]
+                                 for source in origin_data.all_sources}
+
+            # Apply custom train selections.
+            for source, train_sel in self._custom_trains.items():
+                if source in origin_sourcedata:
+                    origin_sourcedata[source] = origin_sourcedata[source] \
+                        .select_trains(by_id[train_sel])
+
+            # Find the union of trains across all sources as total
+            # trains for this origin.
+            origin_train_ids = np.zeros(0, dtype=np.uint64)
+            for sd in origin_sourcedata.values():
+                origin_train_ids = np.union1d(
+                    origin_train_ids,
+                    sourcedata_drop_empty_trains(sd).train_ids)
+
+            num_trains = len(origin_train_ids)
+            num_sequences = int(np.ceil(num_trains / sequence_len))
 
             self.log.info(
-                f'{storage_class}-{aggregator}: '
-                f'{len(origin_data.all_sources)} sources with {num_trains} '
-                f'trains over {num_sequences} sequences')
+                f'{data_category}-{aggregator}: '
+                f'{len(origin_sourcedata)} sources with {num_trains} trains '
+                f'over {num_sequences} sequences')
 
             for seq_no in range(num_sequences):
-                train_slice = slice(
-                    seq_no * sequence_len,
-                    (seq_no + 1) * sequence_len)
+                # Slice out the train IDs for this sequence.
+                seq_train_ids = origin_train_ids[
+                    (seq_no * sequence_len):((seq_no + 1) * sequence_len)]
 
                 # Select origin data down to what's in this file.
-                file_data = origin_data.select_trains(train_slice)
-
-                # file_data now can contain trains for which no source has
-                # any data. .select(require_all) Unfortunately does the
-                # opposite of what we want
-
-                file_path = output_path / '{}-R{:04d}-{}-S{:06d}.h5'.format(
-                    storage_class, run_no, aggregator, seq_no)
+                seq_sources = {source: sd.select_trains(by_id[seq_train_ids])
+                               for source, sd in origin_sourcedata.items()}
 
+                seq_path = self._write_sequence(
+                    output_path, aggregator, run_no, seq_no, data_category,
+                    seq_sources, seq_train_ids)
                 self.log.debug(
-                    f'{file_path.stem}: {len(file_data.all_sources)} sources '
-                    f'with {len(file_data.train_ids)} trains')
-                fw = ReductionFileWriter(file_path, file_data, self)
-                fw.write()
+                    f'{seq_path.stem}: {len(seq_sources)} sources with '
+                    f'{len(seq_train_ids)} trains')
 
     def write_voview(self, output_path):
         raise NotImplementedError('voview output layout')
@@ -190,6 +234,106 @@ class ReductionSequence:
     def write_collapsed(self, output_path, sequence_len):
         raise NotImplementedError('collapsed collection output layout')
 
+    def _write_sequence(self, folder, aggregator, run, sequence, data_category,
+                        sources, train_ids=None):
+        """Write collection of SourceData to file."""
+
+        if train_ids is None:
+            # Default to the union of train IDs across all sources.
+            train_ids = np.zeros(0, dtype=np.uint64)
+
+            for sd in sources.values():
+                train_ids = np.union1d(train_ids, sd.train_ids)
+
+        sd_example = next(iter(sources.values()))
+
+        # Build sources and indices.
+        control_indices = {}
+        instrument_indices = {}
+
+        for source, sd in sources.items():
+            if sd._is_control:
+                control_indices[sd.source] = sourcedata_data_counts(sd)
+            else:
+                instrument_indices[sd.source] = sourcedata_data_counts(sd)
+
+        with exdf.File.from_details(folder, aggregator, run, sequence,
+                                    prefix=data_category, mode='w') as f:
+            path = Path(f.filename)
+
+            # TODO: Handle old files without extended METADATA?
+            # TODO: Handle timestamps, origin
+            # TODO: Attributes
+
+            # Create METADATA section.
+            f.create_metadata(
+                like=sd_example,
+                control_sources=control_indices.keys(),
+                instrument_channels=[
+                    f'{source}/{channel}'
+                    for source, channels in instrument_indices.items()
+                    for channel in channels.keys()])
+
+            # Create INDEX section.
+            f.create_index(train_ids)
+
+            for source, counts in control_indices.items():
+                control_src = f.create_control_source(source)
+                control_src.create_index(len(train_ids))
+
+            for source, channel_counts in instrument_indices.items():
+                instrument_src = f.create_instrument_source(source)
+                instrument_src.create_index(**channel_counts)
+
+            # Create CONTROL and RUN sections.
+            for source in control_indices.keys():
+                exdf_source = f.source[source]
+                sd = sources[source]
+
+                for key, value in sd.run_values().items():
+                    exdf_source.create_run_key(key, value)
+
+                for key in sd.keys(False):
+                    exdf_source.create_key(key,
+                                           sd[f'{key}.value'].ndarray(),
+                                           sd[f'{key}.timestamp'].ndarray(),
+                                           run_entry=False)
+
+            # Create INSTRUMENT datasets.
+            for source in instrument_indices.keys():
+                exdf_source = f.source[source]
+                sd = sources[source]
+
+                for key in sd.keys():
+                    kd = sd[key]
+
+                    shape = (kd.data_counts(labelled=False).sum(),
+                             *kd.entry_shape)
+                    try:
+                        _, chunks = self._partial_copies[source, key]
+                    except KeyError:
+                        chunks = kd.files[0].file[kd.hdf5_data_path].chunks
+
+                    exdf_source.create_key(
+                        key, shape=shape, maxshape=(None,) + shape[1:],
+                        chunks=chunks, dtype=kd.dtype)
+
+            # Copy INSTRUMENT data.
+            for source in instrument_indices.keys():
+                exdf_source = f.source[source]
+                sd = sources[source]
+
+                for key in sd.keys():
+                    # TODO: Copy by chunk / file if too large
+
+                    data = sd[key].ndarray()
+
+                    try:
+                        region, _ = self._partial_copies[source, key]
+                    except KeyError:
+                        exdf_source.key[key][:] = data
+                    else:
+                        full_region = (np.s_[:], *region)
+                        exdf_source.key[key][full_region] = data[full_region]
+
+        return path
+
 
 class ReductionMethod:
     log = logging.getLogger('ReductionMethod')
@@ -234,54 +378,6 @@ class ReductionMethod:
         self.log.debug(f'Emitted {self._instructions[-1]}')
 
 
-class ReductionFileWriter(FileWriter):
-    def __init__(self, path, data, sequence):
-        super().__init__(path, data)
-        self.sequence = sequence
-
-    def prepare_source(self, source):
-        for key in sorted(self.data.keys_for_source(source)):
-            path = f"{self._section(source)}/{source}/{key.replace('.', '/')}"
-            nentries = self._guess_number_of_storing_entries(source, key)
-            src_ds1 = self.data[source].files[0].file[path]
-
-            try:
-                chunking_kwargs = dict(
-                    chunks=self.sequence._partial_copies[source, key][1])
-            except KeyError:
-                chunking_kwargs = dict()
-
-            self.file.create_dataset_like(
-                path, src_ds1, shape=(nentries,) + src_ds1.shape[1:],
-                # Corrected detector data has maxshape==shape, but if any max
-                # dim is smaller than the chunk size, h5py complains. Making
-                # the first dimension unlimited avoids this.
-                maxshape=(None,) + src_ds1.shape[1:],
-                **chunking_kwargs
-            )
-            if source in self.data.instrument_sources:
-                self.data_sources.add(f"INSTRUMENT/{source}/{key.partition('.')[0]}")
-
-        if source not in self.data.instrument_sources:
-            self.data_sources.add(f"CONTROL/{source}")        
-
-    def copy_dataset(self, source, key):
-        path = f"{self._section(source)}/{source}/{key.replace('.', '/')}"
-
-        try:
-            region = self.sequence._partial_copies[source, key][0]
-        except KeyError:
-            a = self.data.get_array(source, key)
-            self.file[path][:] = a.values
-        else:
-            a = self.data.get_array(source, key)
-            full_region = (slice(None, None, None), *region)
-            
-            self.file[path][full_region] = a.values[full_region]
-        
-        self._make_index(source, key, a.coords['trainId'].values)
-
-
 proposal_run_input_pattern = re.compile(
     r'^p(\d{4}|9\d{5}):r(\d{1,3}):*([a-z]*)$')
 
@@ -362,6 +458,9 @@ def main(argv=None):
         help='export reduction operations to a script file'
     )
 
+    # TODO: Whether to use suspect trains or not
+    # TODO: Whether to drop entirely empty sources
+
     args = ap.parse_args(argv)
 
     if args.verbose:
@@ -383,14 +482,14 @@ def main(argv=None):
             inp_data = open_run(proposal=m[1], run=m[2], data=m[3] or 'raw')
             log.info(f'Found proposal run at '
                      f'{Path(inp_data.files[0].filename).parent}')
-    
+
     else:
         inp_data = RunDirectory(args.input)
 
     log.info(f'Opened data with {len(inp_data.all_sources)} sources with '
              f'{len(inp_data.train_ids)} trains across {len(inp_data.files)} '
              f'files')
-    
+
     methods = {}
 
     for ep in iter_entry_points('exdf_tools.reduction_method'):
@@ -445,7 +544,7 @@ def main(argv=None):
 
         log.debug(f'Writing collapsed collection to {args.output}')
         seq.write_collapsed(args.output, args.output_sequence_len)
-        
+
 
 if __name__ == '__main__':
     main()
-- 
GitLab