From 207b998921dc896f39c242050011c0cafe2f4397 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Wed, 25 Oct 2023 16:55:59 +0200
Subject: [PATCH] Add implementation for select-rows operation
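
The new select-rows operation (and its XTDF-specific variant
select-xtdf) takes a source glob, an index group, a train selection
and a row selection, where rows may be given as a slice, a list or
array of integer indices, or a boolean mask. A rough usage sketch,
called on a ReductionMethod instance here named method for
illustration; the source pattern and selections below are made-up
examples and numpy is assumed to be imported as np:

    # Keep only the first ten rows of the 'image' index group.
    method.select_rows('SPB_DET_AGIPD1M-1/DET/*CH0:xtdf', 'image',
                       np.s_[:], np.s_[:10])

    # Same rows, but additionally adjusts the XTDF header structures.
    method.select_xtdf('SPB_DET_AGIPD1M-1/DET/*CH0:xtdf',
                       np.s_[:], np.s_[:10])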

---
 src/exdf/data_reduction/method.py |  53 ++++++++--
 src/exdf/data_reduction/writer.py |  78 ++++++++++++++-
 src/exdf/write/sd_writer.py       | 156 +++++++++++++++++++++++-------
 3 files changed, 244 insertions(+), 43 deletions(-)
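
The SourceDataWriter.mask_instrument_data() hook added below returns
one boolean mask per passed train (each as long as that train's data
count) or None to perform no masking, and is meant to be overridden by
writer implementations such as ReduceWriter. A minimal overriding
sketch; the subclass name and mask logic are hypothetical examples:

    import numpy as np

    from exdf.write.sd_writer import SourceDataWriter

    class EveryOtherRowWriter(SourceDataWriter):
        def mask_instrument_data(self, source, index_group, train_ids,
                                 counts):
            # One mask per train, same length as that train's count;
            # keeping every other row is an arbitrary example.
            masks = []
            for count in counts:
                mask = np.zeros(count, dtype=bool)
                mask[::2] = True
                masks.append(mask)
            return masks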

diff --git a/src/exdf/data_reduction/method.py b/src/exdf/data_reduction/method.py
index d106ed4..fd75a00 100644
--- a/src/exdf/data_reduction/method.py
+++ b/src/exdf/data_reduction/method.py
@@ -3,11 +3,13 @@ from typing import TypeVar
 from logging import getLogger
 import warnings
 
+import numpy as np
 from extra_data.read_machinery import select_train_ids
 
 
 log = getLogger('exdf.data_reduction.ReductionMethod')
 train_sel = TypeVar('train_sel')
+row_sel = TypeVar('row_sel')
 index_exp = TypeVar('index_exp')
 
 
@@ -22,6 +24,22 @@ def is_train_selection(x):
         return True
 
 
+def is_row_selection(x):
+    if isinstance(x, slice):
+        return True
+
+    if isinstance(x, list) and all(isinstance(y, (int, bool)) for y in x):
+        return True
+
+    if (
+        isinstance(x, np.ndarray) and x.ndim == 1 and
+        (np.issubdtype(x.dtype, np.integer) or np.issubdtype(x.dtype, bool))
+    ):
+        return True
+
+    return False
+
+
 def is_index_expression(x):
     if isinstance(x, (slice, list)):
         return True
@@ -54,19 +72,42 @@ class ReductionMethod(list):
         assert is_train_selection(trains)
         self._emit('select-trains', source_glob, trains)
 
-    def select_pulses(
+    def select_rows(
         self,
         source_glob: str,
-        index_group: str,  # May be xtdf
+        index_group: str,
         trains: train_sel,
-        rows: index_exp
+        rows: row_sel
     ):
-        raise NotImplementedError('select-pulses')
         assert isinstance(source_glob, str)
         assert isinstance(index_group, str)
         assert is_train_selection(trains)
-        assert is_index_expression(rows)
-        self._emit('select-pulses', source_glob, index_group, trains, rows)
+        assert is_row_selection(rows)
+        self._emit('select-rows', source_glob, index_group, trains, rows)
+
+    def select_xtdf(
+        self,
+        source_glob: str,
+        trains: train_sel,
+        rows: row_sel
+    ):
+        """Slice XTDF data by row.
+
+        Roughly equivalent to select_rows(source_glob, 'image',
+        trains, rows), but only acts on XTDF sources and modifies
+        header data structures according to slicing.
+
+        Requires sources to end with :xtdf and have all XTDF keys.
+
+        Args:
+            source_glob (str): Source glob pattern.
+            trains (train_sel): Train selection.
+            rows (row_sel): Row selection.
+        """
+        assert isinstance(source_glob, str)
+        assert is_train_selection(trains)
+        assert is_row_selection(rows)
+        self._emit('select-xtdf', source_glob, trains, rows)
 
     def remove_sources(
         self,
diff --git a/src/exdf/data_reduction/writer.py b/src/exdf/data_reduction/writer.py
index a025092..9282f30 100644
--- a/src/exdf/data_reduction/writer.py
+++ b/src/exdf/data_reduction/writer.py
@@ -52,6 +52,7 @@ class ReduceWriter(SourceDataWriter):
         # Only populated if trains/keys are selected/removed for sources.
         self._custom_keys = {}  # source -> set(<keys>)
         self._custom_trains = {}  # source -> list(<trains>)
+        self._custom_rows = {}  # (source, index_group) -> dict(train_id -> mask)
         self._rechunked_keys = {}  # (source, key) -> chunks
         self._partial_copies = {}  # (source, key) -> list(<regions>)
 
@@ -83,6 +84,19 @@ class ReduceWriter(SourceDataWriter):
                 self._custom_trains[source] = select_train_ids(
                     train_ids, train_sel)
 
+        for source_glob, index_group, train_sel, row_sel in self._filter_ops(
+            'select-rows'
+        ):
+            for source in fnmatch.filter(self._sources, source_glob):
+                if index_group not in self._data[source].index_groups:
+                    raise ValueError(f'{index_group} not index group of '
+                                     f'{source}')
+
+                self._touched_sources.add(source)
+                self._custom_rows.setdefault((source, index_group), {}).update(
+                    self._get_row_masks(source, index_group,
+                                        train_sel, row_sel))
+
         for source_glob, key_glob, chunking in self._filter_ops(
             'rechunk-keys'
         ):
@@ -139,6 +153,50 @@ class ReduceWriter(SourceDataWriter):
     def _filter_ops(self, op):
         return [args[1:] for args in self._ops if args[0] == op]
 
+    def _get_row_masks(self, source, index_group, train_sel, row_sel):
+        train_ids = select_train_ids(
+            self._custom_trains.get(source, list(self._data.train_ids)),
+            train_sel)
+        counts = self._data[source].select_trains(by_id[train_ids]) \
+            .data_counts(index_group=index_group)
+        masks = {}
+
+        if isinstance(row_sel, slice):
+            for train_id, count in counts.items():
+                if count > 0:
+                    masks[train_id] = np.zeros(count, dtype=bool)
+                    masks[train_id][row_sel] = True
+
+        elif np.issubdtype(type(row_sel[0]), np.integer):
+            max_row = max(row_sel)
+
+            for train_id, count in counts.items():
+                if count == 0:
+                    continue
+                elif max_row >= count:
+                    raise ValueError(
+                        f'row index exceeds data counts of train {train_id}')
+
+                masks[train_id] = np.zeros(count, dtype=bool)
+                masks[train_id][row_sel] = True
+
+        elif np.issubdtype(type(row_sel[0]), bool):
+            mask_len = len(row_sel)
+
+            for train_id, count in counts.items():
+                if count == 0:
+                    continue
+                elif mask_len != count:
+                    raise ValueError(
+                        f'mask length mismatch for train {train_id}')
+
+                masks[train_id] = np.asarray(row_sel)
+
+        else:
+            raise ValueError('unknown row mask format')
+
+        return masks
+
     def write_collection(self, output_path):
         outp_data = self._data.select([(s, '*') for s in self._sources])
 
@@ -266,7 +324,25 @@ class ReduceWriter(SourceDataWriter):
         except KeyError:
             return orig_chunks
 
-    def copy_instrument_data(self, source, key, dest, data):
+    def mask_instrument_data(self, source, index_group, train_ids, counts):
+        if (source, index_group) in self._custom_rows:
+            custom_masks = self._custom_rows[source, index_group]
+        else:
+            return  # None efficiently selects all rows.
+
+        masks = []
+
+        for train_id, count_all in zip(train_ids, counts):
+            if train_id in custom_masks:
+                mask = custom_masks[train_id]
+            else:
+                mask = np.ones(count_all, dtype=bool)
+
+            masks.append(mask)
+
+        return masks
+
+    def copy_instrument_data(self, source, key, dest, train_ids, data):
         try:
             regions = self._partial_copies[source, key]
         except KeyError:
diff --git a/src/exdf/write/sd_writer.py b/src/exdf/write/sd_writer.py
index 46489ae..d9354af 100644
--- a/src/exdf/write/sd_writer.py
+++ b/src/exdf/write/sd_writer.py
@@ -17,7 +17,7 @@ from operator import or_
 import numpy as np
 
 from extra_data import FileAccess
-from . import DataFile
+from .datafile import DataFile, get_pulse_offsets
 
 
 log = getLogger('exdf.write.SourceDataWriter')
@@ -50,13 +50,37 @@ class SourceDataWriter:
 
         return orig_chunks
 
-    def copy_instrument_data(self, source, key, dest, data):
+    def mask_instrument_data(self, source, index_group, train_ids, counts):
+        """Mask INSTRUMENT data.
+
+        Each returned mask array must have the same length as the
+        original data count of its train.
+
+        Args:
+            source (str): Source name.
+            index_group (str): Index group.
+            train_ids (ndarray): Train IDs.
+            counts (ndarray): Data counts per train.
+
+        Returns:
+            (Iterable of ndarray): One boolean mask per passed
+                train ID, each equal in length to its data count,
+                or None to perform no masking.
+        """
+
+        return
+
+    def copy_instrument_data(self, source, key, dest, train_ids, data):
         """Copy INSTRUMENT data from input to output.
 
+        The destination dataset is guaranteed to align with the shape of
+        train_ids and data.
+
         Args:
             source (str): Source name.
             key (str): Key name.
             dset (h5py.Dataset): Destination dataset.
+            train_ids (ndarray): Train ID coordinates.
             data (ndarray): Source data.
 
         Returns:
@@ -110,9 +134,9 @@ class SourceDataWriter:
             data_format_version=self.get_data_format_version(),
             control_sources=control_indices.keys(),
             instrument_channels=[
-                f'{source}/{channel}'
-                for source, channels in instrument_indices.items()
-                for channel in channels.keys()])
+                f'{source}/{index_group}'
+                for source, index_group_counts in instrument_indices.items()
+                for index_group in index_group_counts.keys()])
         f.create_dataset('METADATA/dataWriter', data=b'exdf-tools', shape=(1,))
 
         if not self.with_origin():
@@ -124,9 +148,10 @@ class SourceDataWriter:
             control_src = f.create_control_source(source)
             control_src.create_index(len(train_ids), per_train=True)
 
-        for source, channel_counts in instrument_indices.items():
+        for source, index_group_counts in instrument_indices.items():
+            # May be overwritten later as a result of masking.
             instrument_src = f.create_instrument_source(source)
-            instrument_src.create_index(**channel_counts)
+            instrument_src.create_index(**index_group_counts)
 
     def write_control(self, f, sources):
         """Write CONTROL and RUN data.
@@ -143,7 +168,7 @@ class SourceDataWriter:
         """
 
         for sd in sources:
-            source = f.source[sd.source]
+            h5source = f.source[sd.source]
 
             attrs = get_key_attributes(sd) if self.with_attrs() else {}
             run_data_leafs = {}
@@ -164,23 +189,24 @@ class SourceDataWriter:
                 ctrl_values = sd[f'{key}.value'].ndarray()
                 ctrl_timestamps = sd[f'{key}.timestamp'].ndarray()
 
-                source.create_key(
+                h5source.create_key(
                     key, values=ctrl_values, timestamps=ctrl_timestamps,
                     run_entry=run_entry, attrs=attrs.pop(key, None))
 
             # Write remaining RUN-only keys.
             for key, leafs in run_data_leafs.items():
-                source.create_run_key(key, **leafs, attrs=attrs.pop(key, None))
+                h5source.create_run_key(
+                    key, **leafs, attrs=attrs.pop(key, None))
 
             # Fill in the missing attributes for nodes.
             for path, attrs in attrs.items():
-                source.run_key[path].attrs.update(attrs)
-                source.key[path].attrs.update(attrs)
+                h5source.run_key[path].attrs.update(attrs)
+                h5source.key[path].attrs.update(attrs)
 
     def write_instrument(self, f, sources):
         """Write INSTRUMENT data.
 
-        This method assumes the source datasets already exist.
+        This method assumes the INDEX and source datasets already exist.
 
         Args:
             f (exdf.DataFile): Output file.
@@ -191,40 +217,67 @@ class SourceDataWriter:
             None
         """
 
-        for sd in sources:
-            source = f.source[sd.source]
+        # Must be re-read at this point, as additional trains could have
+        # been introduced in this sequence.
+        train_ids = np.array(f['INDEX/trainId'])
 
+        # Stores the row mask per source and index group.
+        masks = {}
+
+        for sd in sources:
             attrs = get_key_attributes(sd) if self.with_attrs() else {}
+            h5source = f.source[sd.source]
+            keys = sd.keys()
 
-            for key in sd.keys():
-                kd = sd[key]
+            for index_group in sd.index_groups:
+                # Must be re-read same as train IDs.
+                h5index = f[f'INDEX/{sd.source}/{index_group}']
+                counts = np.array(h5index['count'])
 
-                shape = (kd.data_counts(labelled=False).sum(), *kd.entry_shape)
-                chunks = self.chunk_instrument_data(
-                    sd.source, key,
-                    kd.files[0].file[kd.hdf5_data_path].chunks)
+                # Obtain mask for this index group.
+                masks_by_train = self.mask_instrument_data(
+                    sd.source, index_group, train_ids, counts)
 
-                source.create_key(
-                    key, shape=shape, maxshape=(None,) + shape[1:],
-                    chunks=chunks, dtype=kd.dtype, attrs=attrs.pop(key, None))
+                if masks_by_train is not None:
+                    masks[sd.source, index_group] = mask_index(
+                        h5index, counts, masks_by_train)
+                    num_entries = masks[sd.source, index_group].sum()
+                else:
+                    num_entries = counts.sum()
 
-            for path, attrs in attrs.items():
-                source.key[path].attrs.update(attrs)
+                for key in iter_index_group_keys(keys, index_group):
+                    kd = sd[key]
 
-            # Update tableSize for each index group to the correct
-            # number of trains.
-            for index_group in sd.index_groups:
-                source[index_group].attrs['tableSize'] = sd.data_counts(
-                    labelled=False, index_group=index_group).sum()
+                    shape = (num_entries, *kd.entry_shape)
+                    chunks = self.chunk_instrument_data(
+                        sd.source, key,
+                        kd.files[0].file[kd.hdf5_data_path].chunks)
+
+                    h5source.create_key(
+                        key, shape=shape, maxshape=(None,) + shape[1:],
+                        chunks=chunks, dtype=kd.dtype,
+                        attrs=attrs.pop(key, None))
+
+                # Update tableSize to the correct number of records.
+                h5source[index_group].attrs['tableSize'] = num_entries
+
+            for path, attrs in attrs.items():
+                h5source.key[path].attrs.update(attrs)
 
         # Copy INSTRUMENT data.
         for sd in sources:
-            source = f.source[sd.source]
+            h5source = f.source[sd.source]
+
+            for index_group in sd.index_groups:
+                mask = masks.get((sd.source, index_group), np.s_[:])
+
+                for key in iter_index_group_keys(sd.keys(), index_group):
+                    # TODO: Copy by chunk / file if too large
 
-            for key in sd.keys():
-                # TODO: Copy by chunk / file if too large
-                self.copy_instrument_data(sd.source, key, source.key[key],
-                                          sd[key].ndarray())
+                    self.copy_instrument_data(
+                        sd.source, key, h5source.key[key],
+                        sd[key].train_id_coordinates()[mask],
+                        sd[key].ndarray()[mask])
 
 
 def get_index_root_data(sources):
@@ -375,3 +428,34 @@ def get_key_attributes(sd):
                                          f'{sd.source}.{path}')
 
     return source_attrs
+
+def iter_index_group_keys(keys, index_group):
+    for key in keys:
+        if key[:key.index('.')] == index_group:
+            yield key
+
+
+def mask_index(g, counts, masks_by_train):
+    full_mask = np.concatenate(masks_by_train)
+    num_entries = counts.sum()
+
+    assert len(full_mask) == num_entries, \
+        'incompatible INSTRUMENT mask shape'
+
+    # Modify INDEX entry if necessary.
+    if full_mask.sum() != num_entries:
+        g.create_dataset(
+            'original/first', data=get_pulse_offsets(counts))
+        g.create_dataset(
+            'original/count', data=counts)
+        g.create_dataset(
+            'original/position',
+            data=np.concatenate([np.flatnonzero(mask)
+                                 for mask in masks_by_train]))
+
+        # Compute new data counts.
+        counts = [mask.sum() for mask in masks_by_train]
+        g['first'][:] = get_pulse_offsets(counts)
+        g['count'][:] = counts
+
+    return full_mask
-- 
GitLab