From 4d59902c3a6a02331eb9774e7b11c43aba3b9331 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Fri, 15 Dec 2023 16:00:06 +0100
Subject: [PATCH] Simplify reduction operation handlers

---
 src/exdf/data_reduction/red_writer.py | 234 +++++++++++++-------------
 1 file changed, 116 insertions(+), 118 deletions(-)

diff --git a/src/exdf/data_reduction/red_writer.py b/src/exdf/data_reduction/red_writer.py
index d25422d..2b55b27 100644
--- a/src/exdf/data_reduction/red_writer.py
+++ b/src/exdf/data_reduction/red_writer.py
@@ -14,6 +14,37 @@ from exdf.write import SourceDataWriter
 from ..write.datafile import write_compressed_frames
 
 
+def apply_by_source(op_name):
+    def op_decorator(op_func):
+        def op_handler(self):
+            assert isinstance(self, ReduceWriter)
+            for source_glob, *args in self._filter_ops(op_name):
+                for source in fnmatch.filter(self._sources, source_glob):
+                    op_func(self, source, *args)
+                    self._touched_sources.add(source)
+
+        return op_handler
+    return op_decorator
+
+
+def apply_by_key(op_name):
+    def op_decorator(op_func):
+        def op_handler(self):
+            assert isinstance(self, ReduceWriter)
+            for source_glob, key_glob, *args in self._filter_ops(op_name):
+                for source in fnmatch.filter(self._sources, source_glob):
+                    keys = self._custom_keys.get(source,
+                                                set(self._data[source].keys()))
+
+                    for key in fnmatch.filter(keys, key_glob):
+                        op_func(self, source, key, *args)
+
+                    self._touched_sources.add(source)
+
+        return op_handler
+    return op_decorator
+
+
 class ReduceWriter(SourceDataWriter):
     log = logging.getLogger('exdf.data_reduction.ReduceWriter')
 
@@ -50,7 +81,7 @@ class ReduceWriter(SourceDataWriter):
         self._sources = sorted(data.all_sources)
         self._touched_sources = set()
 
-        # Only populated if trains/keys are selected/removed for sources.
+        # Only populated for sources/keys that are modified.
         self._custom_keys = {}  # source -> set(<keys>)
         self._custom_trains = {}  # source -> list(<trains>)
         self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
@@ -60,124 +91,29 @@ class ReduceWriter(SourceDataWriter):
         self._subsliced_keys = {}  # (source, key) -> list(<regions>)
         self._compressed_keys = {}  # (source, key) -> level
 
-        # TODO: Raise error if rechunking is overwritten!
-        # TODO: make partial copies a list of slices!
-
         # Collect reductions resulting from operations.
-        for source_glob, in self._filter_ops('remove-sources'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                self._sources.remove(source)
-
-        for source_glob, key_glob in self._filter_ops('remove-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.setdefault(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    keys.remove(key)
-
-        for source_glob, train_sel in self._filter_ops('select-trains'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                train_ids = self._custom_trains.setdefault(
-                    source, list(self._data.train_ids))
-
-                self._custom_trains[source] = select_train_ids(
-                    train_ids, train_sel)
-
-        for source_glob, index_group, train_sel, entry_sel in self._filter_ops(
-            'select-entries'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if index_group not in self._data[source].index_groups:
-                    raise ValueError(f'{index_group} not index group of '
-                                     f'{source}')
-
-                new_mask = self._get_entry_masks(
-                    source, index_group, train_sel, entry_sel)
-
-                self._touched_sources.add(source)
-                self._custom_entry_masks.setdefault(
-                    (source, index_group), {}).update(new_mask)
-
-        for source_glob, train_sel, entry_sel in self._filter_ops(
-            'select-xtdf'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if not source.endswith(':xtdf'):
-                    # Simply ignore matches without trailing :xtdf.
-                    continue
-
-                if not self._is_xtdf_source(source):
-                    # Raise exception if essentials are missing.
-                    raise ValueError(f'{source} is not a valid XTDF source')
-
-                new_mask = self._get_entry_masks(
-                    source, 'image', train_sel, entry_sel)
-
-                self._touched_sources.add(source)
-                self._custom_xtdf_masks.setdefault(source, {}).update(new_mask)
-
-        if (
-            {x[0] for x in self._custom_entry_masks.keys()} &
-            self._custom_xtdf_masks.keys()
-        ):
-            raise ValueError('source may not be affected by both '
-                             'select-entries and select-xtdf operations')
-
-        for source_glob, key_glob, chunking in self._filter_ops(
-            'rechunk-keys'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if not self._data[source].is_instrument:
-                    raise ValueError(
-                        f'rechunking keys only supported for instrument '
-                        f'sources, but {source_glob} matches '
-                        f'{self._data[source].section}/{source}')
-
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    old_chunking = self._rechunked_keys.setdefault(
-                        (source, key), chunking)
-
-                    if old_chunking != chunking:
-                        raise ValueError(
-                            f'reduction sequence yields conflicting chunks '
-                            f'for {source}.{key}: {old_chunking}, {chunking}')
-
-                    self._rechunked_keys[(source, key)] = chunking
-
-        for source_glob, key_glob, region in self._filter_ops('subslice-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    self._subsliced_keys.setdefault((source, key), []).append(
-                        region)
-
-        for source_glob, key_glob, level in self._filter_ops('compress-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    self._compressed_keys[source, key] = level
-
-        if (self._rechunked_keys.keys() & self._compressed_keys.keys()):
-            raise ValueError('keys may not be affected by both compress-keys '
-                             'and rechunk-keys operations')
+        # This is the most efficient order of operations to minimize
+        # more expensive operations for source or trains that may not
+        # end up being selected.
+        self._handle_remove_sources()
+        self._handle_remove_keys()
+        self._handle_select_trains()
+        self._handle_select_entries()
+        self._handle_select_xtdf()
+        self._handle_rechunk_keys()
+        self._handle_subslice_keys()
+        self._handle_compress_keys()
+
+        custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}
+        if custom_entry_sources & self._custom_xtdf_masks.keys():
+            raise ValueError(
+                'Source may not be affected by both select-entries and '
+                'select-xtdf operations')
+
+
+        if self._rechunked_keys.keys() & self._compressed_keys.keys():
+            raise ValueError('Key may not be affected by both '
+                             'compress-keys and rechunk-keys')
 
         if self._scope == 'sources':
             self._sources = sorted(
@@ -245,6 +181,8 @@ class ReduceWriter(SourceDataWriter):
 
         return masks
 
+    # Public API
+
     def write_collection(self, output_path):
         outp_data = self._data.select([(s, '*') for s in self._sources])
 
@@ -434,3 +372,63 @@ class ReduceWriter(SourceDataWriter):
 
         else:
             dest[:] = data
+
+    # Reduction operation handlers.
+
+    @apply_by_source('remove-sources')
+    def _handle_remove_sources(self, source):
+        self._touched_sources.add(source)
+
+    @apply_by_key('remove-keys')
+    def _handle_remove_keys(self, source, key):
+        self._custom_keys[source].remove(key)
+
+    @apply_by_source('select-trains')
+    def _handle_select_trains(self, source, train_sel):
+        self._custom_trains[source] = select_train_ids(
+            self._custom_trains.setdefault(source, list(self._data.train_ids)),
+            train_sel)
+
+    @apply_by_source('select-entries')
+    def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
+        if idx_group not in self._data[source].index_groups:
+            raise ValueError(f'{idx_group} not index group of {source}')
+
+        self._custom_entry_masks.setdefault((source, idx_group), {}).update(
+            self._get_entry_masks(source, idx_group, train_sel, entry_sel))
+
+    @apply_by_source('select-xtdf')
+    def _handle_select_xtdf(self, source, train_sel, entry_sel):
+        if not source.endswith(':xtdf'):
+            # Simply ignore matches without trailing :xtdf.
+            return
+
+        if not self._is_xtdf_source(source):
+            # Raise exception if essentials are missing.
+            raise ValueError(f'{source} is not a valid XTDF source')
+
+        self._custom_xtdf_masks.setdefault(source, {}).update(
+            self._get_entry_masks(source, 'image', train_sel, entry_sel))
+
+    @apply_by_key('rechunk-keys')
+    def _handle_rechunk_keys(self, source, key, chunking):
+        if not self._data[source].is_instrument:
+            # Ignore CONTROL sources.
+            return
+
+        old_chunking = self._rechunked_keys.setdefault((source, key), chunking)
+
+        if old_chunking != chunking:
+            raise ValueError(
+                f'Reduction sequence yields conflicting chunks for '
+                f'{source}.{key}: {old_chunking}, {chunking}')
+
+        self._rechunked_keys[(source, key)] = chunking
+
+    @apply_by_key('subslice-keys')
+    def _handle_subslice_keys(self, source, key, region):
+        self._subsliced_keys.setdefault((source, key), []).append(region)
+
+    @apply_by_key('compress-keys')
+    def _handle_compress_keys(self, source, key, level):
+        self._compressed_keys[source, key] = level
-- 
GitLab