Commit 4d59902c authored by Philipp Schmidt

Simplify reduction operation handlers

parent 958316d3
1 merge request: !6 Simplify reduction operation implementations and error handling
@@ -14,6 +14,37 @@ from exdf.write import SourceDataWriter
from ..write.datafile import write_compressed_frames
def apply_by_source(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            assert isinstance(self, ReduceWriter)
            for source_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    op_func(self, source, *args)
                    self._touched_sources.add(source)

        return op_handler
    return op_decorator


def apply_by_key(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            assert isinstance(self, ReduceWriter)
            for source_glob, key_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    keys = self._custom_keys.get(
                        source, set(self._data[source].keys()))

                    for key in fnmatch.filter(keys, key_glob):
                        op_func(self, source, key, *args)

                    self._touched_sources.add(source)

        return op_handler
    return op_decorator
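
To see the dispatch in isolation: a minimal, runnable sketch, assuming (as the loop above suggests) that _filter_ops yields the argument tuples of all queued operations with a given name. ToyWriter, its _filter_ops and the operations list are hypothetical stand-ins for ReduceWriter's internals.

import fnmatch

def apply_by_source(op_name):
    # Same shape as above, minus the ReduceWriter assert.
    def op_decorator(op_func):
        def op_handler(self):
            for source_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    op_func(self, source, *args)
                    self._touched_sources.add(source)
        return op_handler
    return op_decorator

class ToyWriter:
    def __init__(self, sources, ops):
        self._sources = list(sources)
        self._touched_sources = set()
        self._ops = ops

    def _filter_ops(self, op_name):
        # Assumed behaviour: return the remaining tuple elements of
        # every queued operation matching this name.
        return [op[1:] for op in self._ops if op[0] == op_name]

    @apply_by_source('remove-sources')
    def _handle_remove_sources(self, source):
        self._sources.remove(source)

w = ToyWriter(['SA1_XTD2_XGM/XGM/DOOCS', 'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf'],
              [('remove-sources', 'SA1_*')])
w._handle_remove_sources()
print(w._sources)          # ['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf']
print(w._touched_sources)  # {'SA1_XTD2_XGM/XGM/DOOCS'}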
class ReduceWriter(SourceDataWriter):
    log = logging.getLogger('exdf.data_reduction.ReduceWriter')
@@ -50,7 +81,7 @@ class ReduceWriter(SourceDataWriter):
        self._sources = sorted(data.all_sources)
        self._touched_sources = set()

        # Only populated if trains/keys are selected/removed for sources.
        # Only populated for sources/keys that are modified.
        self._custom_keys = {}  # source -> set(<keys>)
        self._custom_trains = {}  # source -> list(<trains>)
        self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
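
For illustration only, the recorded state might look like this after a few operations; the source names, keys and train IDs below are made up:

import numpy as np

# Hypothetical state after a remove-keys, a select-trains and a
# select-xtdf operation have been recorded:
custom_keys = {'SA1_XTD2_XGM/XGM/DOOCS': {'data.intensityTD'}}
custom_trains = {'SA1_XTD2_XGM/XGM/DOOCS': [10001, 10002, 10003]}
custom_xtdf_masks = {
    'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf': {
        10001: np.array([True, False, True])}}  # per-train entry mask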
@@ -60,124 +91,29 @@ class ReduceWriter(SourceDataWriter):
        self._subsliced_keys = {}  # (source, key) -> list(<regions>)
        self._compressed_keys = {}  # (source, key) -> level

        # TODO: Raise error if rechunking is overwritten!
        # TODO: make partial copies a list of slices!

        # Collect reductions resulting from operations.
        for source_glob, in self._filter_ops('remove-sources'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)
                self._sources.remove(source)

        for source_glob, key_glob in self._filter_ops('remove-keys'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.setdefault(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    keys.remove(key)

        for source_glob, train_sel in self._filter_ops('select-trains'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                train_ids = self._custom_trains.setdefault(
                    source, list(self._data.train_ids))

                self._custom_trains[source] = select_train_ids(
                    train_ids, train_sel)

        for source_glob, index_group, train_sel, entry_sel in self._filter_ops(
            'select-entries'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if index_group not in self._data[source].index_groups:
                    raise ValueError(f'{index_group} not index group of '
                                     f'{source}')

                new_mask = self._get_entry_masks(
                    source, index_group, train_sel, entry_sel)

                self._touched_sources.add(source)
                self._custom_entry_masks.setdefault(
                    (source, index_group), {}).update(new_mask)

        for source_glob, train_sel, entry_sel in self._filter_ops(
            'select-xtdf'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if not source.endswith(':xtdf'):
                    # Simply ignore matches without trailing :xtdf.
                    continue

                if not self._is_xtdf_source(source):
                    # Raise exception if essentials are missing.
                    raise ValueError(f'{source} is not a valid XTDF source')

                new_mask = self._get_entry_masks(
                    source, 'image', train_sel, entry_sel)

                self._touched_sources.add(source)
                self._custom_xtdf_masks.setdefault(source, {}).update(new_mask)

        if (
            {x[0] for x in self._custom_entry_masks.keys()} &
            self._custom_xtdf_masks.keys()
        ):
            raise ValueError('source may not be affected by both '
                             'select-entries and select-xtdf operations')

        for source_glob, key_glob, chunking in self._filter_ops(
            'rechunk-keys'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if not self._data[source].is_instrument:
                    raise ValueError(
                        f'rechunking keys only supported for instrument '
                        f'sources, but {source_glob} matches '
                        f'{self._data[source].section}/{source}')

                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    old_chunking = self._rechunked_keys.setdefault(
                        (source, key), chunking)

                    if old_chunking != chunking:
                        raise ValueError(
                            f'reduction sequence yields conflicting chunks '
                            f'for {source}.{key}: {old_chunking}, {chunking}')

                    self._rechunked_keys[(source, key)] = chunking

        for source_glob, key_glob, region in self._filter_ops('subslice-keys'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    self._subsliced_keys.setdefault((source, key), []).append(
                        region)

        for source_glob, key_glob, level in self._filter_ops('compress-keys'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    self._compressed_keys[source, key] = level

        if (self._rechunked_keys.keys() & self._compressed_keys.keys()):
            raise ValueError('keys may not be affected by both compress-keys '
                             'and rechunk-keys operations')
        # This is the most efficient order of operations to minimize
        # more expensive operations for sources or trains that may not
        # end up being selected.
        self._handle_remove_sources()
        self._handle_remove_keys()
        self._handle_select_trains()
        self._handle_select_entries()
        self._handle_select_xtdf()
        self._handle_rechunk_keys()
        self._handle_subslice_keys()
        self._handle_compress_keys()

        custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}

        if custom_entry_sources & self._custom_xtdf_masks.keys():
            raise ValueError(
                'Source may not be affected by both select-entries and '
                'select-xtdf operations')

        if self._rechunked_keys.keys() & self._compressed_keys.keys():
            raise ValueError('Key may not be affected by both '
                             'compress-keys and rechunk-keys')

        if self._scope == 'sources':
            self._sources = sorted(
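
The ordering matters because every handler matches its globs against self._sources as it stands at that point: once remove-sources has pruned the list, later and potentially more expensive per-key or per-entry operations no longer see the removed sources at all. A toy illustration with made-up source names:

import fnmatch

sources = ['SA1_XTD2_XGM/XGM/DOOCS', 'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf']

# remove-sources runs first and prunes the list ...
for source in fnmatch.filter(sources, 'SA1_*'):
    sources.remove(source)

# ... so a later, broader glob (e.g. from select-trains) never visits
# the removed source and skips its more expensive work.
print(fnmatch.filter(sources, '*'))  # ['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf']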
@@ -245,6 +181,8 @@ class ReduceWriter(SourceDataWriter):
        return masks

    # Public API

    def write_collection(self, output_path):
        outp_data = self._data.select([(s, '*') for s in self._sources])
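
The select call narrows the collection to the surviving sources before writing. With EXtra-data's DataCollection.select, a list of (source, key) glob pairs keeps all matching keys of each matching source; a short sketch, with a made-up run path:

from extra_data import RunDirectory

run = RunDirectory('/path/to/run')  # hypothetical run directory
kept = ['SA1_XTD2_XGM/XGM/DOOCS']   # e.g. self._sources after reduction

# Selecting with (source, '*') pairs keeps every key of each source.
sel = run.select([(s, '*') for s in kept])
print(sel.all_sources)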
@@ -434,3 +372,63 @@ class ReduceWriter(SourceDataWriter):
            else:
                dest[:] = data
    # Reduction operation handlers.

    @apply_by_source('remove-sources')
    def _handle_remove_sources(self, source):
        self._sources.remove(source)
    @apply_by_key('remove-keys')
    def _handle_remove_keys(self, source, key):
        self._custom_keys.setdefault(
            source, set(self._data[source].keys())).remove(key)
    @apply_by_source('select-trains')
    def _handle_select_trains(self, source, train_sel):
        self._custom_trains[source] = select_train_ids(
            self._custom_trains.setdefault(source, list(self._data.train_ids)),
            train_sel)

    @apply_by_source('select-entries')
    def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
        if idx_group not in self._data[source].index_groups:
            raise ValueError(f'{idx_group} not index group of {source}')

        self._custom_entry_masks.setdefault((source, idx_group), {}).update(
            self._get_entry_masks(source, idx_group, train_sel, entry_sel))

    @apply_by_source('select-xtdf')
    def _handle_select_xtdf(self, source, train_sel, entry_sel):
        if not source.endswith(':xtdf'):
            # Simply ignore matches without trailing :xtdf.
            return

        if not self._is_xtdf_source(source):
            # Raise exception if essentials are missing.
            raise ValueError(f'{source} is not a valid XTDF source')

        self._custom_xtdf_masks.setdefault(source, {}).update(
            self._get_entry_masks(source, 'image', train_sel, entry_sel))

    @apply_by_key('rechunk-keys')
    def _handle_rechunk_keys(self, source, key, chunking):
        if not self._data[source].is_instrument:
            # Ignore CONTROL sources.
            return

        old_chunking = self._rechunked_keys.setdefault((source, key), chunking)

        if old_chunking != chunking:
            raise ValueError(
                f'Reduction sequence yields conflicting chunks for '
                f'{source}.{key}: {old_chunking}, {chunking}')

        self._rechunked_keys[(source, key)] = chunking

    @apply_by_key('subslice-keys')
    def _handle_subslice_keys(self, source, key, region):
        self._subsliced_keys.setdefault((source, key), []).append(region)

    @apply_by_key('compress-keys')
    def _handle_compress_keys(self, source, key, level):
        self._compressed_keys[source, key] = level
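
Worth noting: _handle_select_trains composes repeated operations, since each select_train_ids call starts from the previously narrowed train list rather than from the full run. A sketch of that composition using plain slices in place of EXtra-data train selectors:

train_ids = list(range(10000, 10010))

# The first op keeps the first five trains, the second keeps every other
# remaining one; the selections compose instead of replacing each other.
for train_sel in [slice(0, 5), slice(None, None, 2)]:
    train_ids = train_ids[train_sel]

print(train_ids)  # [10000, 10002, 10004]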