diff --git a/src/exdf/data_reduction/red_writer.py b/src/exdf/data_reduction/red_writer.py
index d25422de21179a8b0c27baa17e8dd78c5bca142c..2b55b27414fe63cea7d464f770aa8e63f161bac5 100644
--- a/src/exdf/data_reduction/red_writer.py
+++ b/src/exdf/data_reduction/red_writer.py
@@ -14,6 +14,37 @@ from exdf.write import SourceDataWriter
 from ..write.datafile import write_compressed_frames
 
 
+def apply_by_source(op_name):
+    def op_decorator(op_func):
+        def op_handler(self):
+            assert isinstance(self, ReduceWriter)
+            for source_glob, *args in self._filter_ops(op_name):
+                for source in fnmatch.filter(self._sources, source_glob):
+                    op_func(self, source, *args)
+                    self._touched_sources.add(source)
+
+        return op_handler
+    return op_decorator
+
+
+def apply_by_key(op_name):
+    def op_decorator(op_func):
+        def op_handler(self):
+            assert isinstance(self, ReduceWriter)
+            for source_glob, key_glob, *args in self._filter_ops(op_name):
+                for source in fnmatch.filter(self._sources, source_glob):
+                    keys = self._custom_keys.get(source,
+                                                 set(self._data[source].keys()))
+
+                    for key in fnmatch.filter(keys, key_glob):
+                        op_func(self, source, key, *args)
+
+                    self._touched_sources.add(source)
+
+        return op_handler
+    return op_decorator
+
+
 class ReduceWriter(SourceDataWriter):
     log = logging.getLogger('exdf.data_reduction.ReduceWriter')
 
@@ -50,7 +81,7 @@ class ReduceWriter(SourceDataWriter):
         self._sources = sorted(data.all_sources)
         self._touched_sources = set()
 
-        # Only populated if trains/keys are selected/removed for sources.
+        # Only populated for sources/keys that are modified.
         self._custom_keys = {}  # source -> set(<keys>)
         self._custom_trains = {}  # source -> list(<trains>)
         self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
@@ -60,124 +91,27 @@ class ReduceWriter(SourceDataWriter):
         self._subsliced_keys = {}  # (source, key) -> list(<regions>)
         self._compressed_keys = {}  # (source, key) -> level
 
-        # TODO: Raise error if rechunking is overwritten!
-        # TODO: make partial copies a list of slices!
-
-        # Collect reductions resulting from operations.
-        for source_glob, in self._filter_ops('remove-sources'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                self._sources.remove(source)
-
-        for source_glob, key_glob in self._filter_ops('remove-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.setdefault(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    keys.remove(key)
-
-        for source_glob, train_sel in self._filter_ops('select-trains'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                train_ids = self._custom_trains.setdefault(
-                    source, list(self._data.train_ids))
-
-                self._custom_trains[source] = select_train_ids(
-                    train_ids, train_sel)
-
-        for source_glob, index_group, train_sel, entry_sel in self._filter_ops(
-            'select-entries'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if index_group not in self._data[source].index_groups:
-                    raise ValueError(f'{index_group} not index group of '
-                                     f'{source}')
-
-                new_mask = self._get_entry_masks(
-                    source, index_group, train_sel, entry_sel)
-
-                self._touched_sources.add(source)
-                self._custom_entry_masks.setdefault(
-                    (source, index_group), {}).update(new_mask)
-
-        for source_glob, train_sel, entry_sel in self._filter_ops(
-            'select-xtdf'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if not source.endswith(':xtdf'):
-                    # Simply ignore matches without trailing :xtdf.
-                    continue
-
-                if not self._is_xtdf_source(source):
-                    # Raise exception if essentials are missing.
-                    raise ValueError(f'{source} is not a valid XTDF source')
-
-                new_mask = self._get_entry_masks(
-                    source, 'image', train_sel, entry_sel)
-
-                self._touched_sources.add(source)
-                self._custom_xtdf_masks.setdefault(source, {}).update(new_mask)
-
-        if (
-            {x[0] for x in self._custom_entry_masks.keys()} &
-            self._custom_xtdf_masks.keys()
-        ):
-            raise ValueError('source may not be affected by both '
-                             'select-entries and select-xtdf operations')
-
-        for source_glob, key_glob, chunking in self._filter_ops(
-            'rechunk-keys'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if not self._data[source].is_instrument:
-                    raise ValueError(
-                        f'rechunking keys only supported for instrument '
-                        f'sources, but {source_glob} matches '
-                        f'{self._data[source].section}/{source}')
-
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    old_chunking = self._rechunked_keys.setdefault(
-                        (source, key), chunking)
-
-                    if old_chunking != chunking:
-                        raise ValueError(
-                            f'reduction sequence yields conflicting chunks '
-                            f'for {source}.{key}: {old_chunking}, {chunking}')
-
-                    self._rechunked_keys[(source, key)] = chunking
-
-        for source_glob, key_glob, region in self._filter_ops('subslice-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    self._subsliced_keys.setdefault((source, key), []).append(
-                        region)
-
-        for source_glob, key_glob, level in self._filter_ops('compress-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    self._compressed_keys[source, key] = level
-
-        if (self._rechunked_keys.keys() & self._compressed_keys.keys()):
-            raise ValueError('keys may not be affected by both compress-keys '
-                             'and rechunk-keys operations')
+        # This is the most efficient order of operations, as it avoids
+        # more expensive operations for sources or trains that may not
+        # end up being selected.
+        self._handle_remove_sources()
+        self._handle_remove_keys()
+        self._handle_select_trains()
+        self._handle_select_entries()
+        self._handle_select_xtdf()
+        self._handle_rechunk_keys()
+        self._handle_subslice_keys()
+        self._handle_compress_keys()
+
+        custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}
+        if custom_entry_sources & self._custom_xtdf_masks.keys():
+            raise ValueError(
+                'Source may not be affected by both select-entries and '
+                'select-xtdf operations')
+
+        if self._rechunked_keys.keys() & self._compressed_keys.keys():
+            raise ValueError('Key may not be affected by both '
+                             'compress-keys and rechunk-keys')
 
         if self._scope == 'sources':
             self._sources = sorted(
@@ -245,6 +179,8 @@ class ReduceWriter(SourceDataWriter):
 
         return masks
 
+    # Public API
+
     def write_collection(self, output_path):
         outp_data = self._data.select([(s, '*') for s in self._sources])
 
@@ -434,3 +370,64 @@ class ReduceWriter(SourceDataWriter):
 
         else:
             dest[:] = data
+
+    # Reduction operation handlers.
+
+    @apply_by_source('remove-sources')
+    def _handle_remove_sources(self, source):
+        self._sources.remove(source)
+
+    @apply_by_key('remove-keys')
+    def _handle_remove_keys(self, source, key):
+        self._custom_keys.setdefault(
+            source, set(self._data[source].keys())).remove(key)
+
+    @apply_by_source('select-trains')
+    def _handle_select_trains(self, source, train_sel):
+        self._custom_trains[source] = select_train_ids(
+            self._custom_trains.setdefault(source, list(self._data.train_ids)),
+            train_sel)
+
+    @apply_by_source('select-entries')
+    def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
+        if idx_group not in self._data[source].index_groups:
+            raise ValueError(f'{idx_group} not index group of {source}')
+
+        self._custom_entry_masks.setdefault((source, idx_group), {}).update(
+            self._get_entry_masks(source, idx_group, train_sel, entry_sel))
+
+    @apply_by_source('select-xtdf')
+    def _handle_select_xtdf(self, source, train_sel, entry_sel):
+        if not source.endswith(':xtdf'):
+            # Simply ignore matches without trailing :xtdf.
+            return
+
+        if not self._is_xtdf_source(source):
+            # Raise exception if essentials are missing.
+            raise ValueError(f'{source} is not a valid XTDF source')
+
+        self._custom_xtdf_masks.setdefault(source, {}).update(
+            self._get_entry_masks(source, 'image', train_sel, entry_sel))
+
+    @apply_by_key('rechunk-keys')
+    def _handle_rechunk_keys(self, source, key, chunking):
+        if not self._data[source].is_instrument:
+            # Ignore CONTROL sources.
+            return
+
+        old_chunking = self._rechunked_keys.setdefault((source, key), chunking)
+
+        if old_chunking != chunking:
+            raise ValueError(
+                f'Reduction sequence yields conflicting chunks for '
+                f'{source}.{key}: {old_chunking}, {chunking}')
+
+        self._rechunked_keys[(source, key)] = chunking
+
+    @apply_by_key('subslice-keys')
+    def _handle_subslice_keys(self, source, key, region):
+        self._subsliced_keys.setdefault((source, key), []).append(region)
+
+    @apply_by_key('compress-keys')
+    def _handle_compress_keys(self, source, key, level):
+        self._compressed_keys[source, key] = level
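The decorators introduced in this change implement a small dispatch pattern: each `_handle_*` method declares which reduction operation it consumes, and the decorator supplies the boilerplate loop over matching sources (and keys) plus the `_touched_sources` bookkeeping. Below is a minimal, self-contained sketch of the same pattern for illustration only; the `MiniReducer` class, its `ops` list and the sample source names are hypothetical stand-ins for the real `ReduceWriter` machinery, and the `_touched_sources` bookkeeping is omitted for brevity.

```python
import fnmatch


def apply_by_source(op_name):
    """Call the wrapped handler once per source matching the glob."""
    def op_decorator(op_func):
        def op_handler(self):
            for source_glob, *args in self._filter_ops(op_name):
                # fnmatch.filter() returns a new list, so handlers may
                # safely mutate self._sources while it is iterated.
                for source in fnmatch.filter(self._sources, source_glob):
                    op_func(self, source, *args)

        return op_handler
    return op_decorator


class MiniReducer:
    def __init__(self, sources, ops):
        self._sources = list(sources)
        self._ops = ops  # list of (op_name, source_glob, ...) tuples

    def _filter_ops(self, op_name):
        # Return the argument tuples of all operations with this name.
        return [op[1:] for op in self._ops if op[0] == op_name]

    @apply_by_source('remove-sources')
    def _handle_remove_sources(self, source):
        self._sources.remove(source)


reducer = MiniReducer(
    sources=['SA1_XTD2_XGM/XGM/DOOCS', 'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf'],
    ops=[('remove-sources', '*XGM*')])
reducer._handle_remove_sources()
print(reducer._sources)  # ['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf']
```

A consequence of this design worth noting: the decorated methods no longer receive their operation arguments from the caller but pull them from `_filter_ops()` themselves, which is what allows `__init__` to simply invoke the handlers in the fixed order shown above.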