
Simplify reduction operation implementations and error handling

Merged Philipp Schmidt requested to merge feat/error-handling into feat/compress-keys
1 file  +116  -118
@@ -14,6 +14,37 @@ from exdf.write import SourceDataWriter
 from ..write.datafile import write_compressed_frames
 
 
+def apply_by_source(op_name):
+    def op_decorator(op_func):
+        def op_handler(self):
+            assert isinstance(self, ReduceWriter)
+            for source_glob, *args in self._filter_ops(op_name):
+                for source in fnmatch.filter(self._sources, source_glob):
+                    op_func(self, source, *args)
+                    self._touched_sources.add(source)
+
+        return op_handler
+    return op_decorator
+
+
+def apply_by_key(op_name):
+    def op_decorator(op_func):
+        def op_handler(self):
+            assert isinstance(self, ReduceWriter)
+            for source_glob, key_glob, *args in self._filter_ops(op_name):
+                for source in fnmatch.filter(self._sources, source_glob):
+                    keys = self._custom_keys.get(source,
+                                                 set(self._data[source].keys()))
+
+                    for key in fnmatch.filter(keys, key_glob):
+                        op_func(self, source, key, *args)
+
+                    self._touched_sources.add(source)
+
+        return op_handler
+    return op_decorator
+
+
 class ReduceWriter(SourceDataWriter):
     log = logging.getLogger('exdf.data_reduction.ReduceWriter')
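
The two decorators above turn a handler written for a single source, or a single source and key, into a no-argument method that iterates over all queued operations of the given name itself. Below is a minimal standalone sketch of the dispatch pattern, assuming a toy DummyWriter class with an _all_keys mapping and an inline ops list (all illustrative, not part of this MR; apply_by_source works the same way without the key loop):

    import fnmatch

    def apply_by_key(op_name):
        def op_decorator(op_func):
            def op_handler(self):
                # Resolve each queued operation of this name against the
                # current sources and keys, calling the decorated function
                # once per matching (source, key) pair.
                for source_glob, key_glob, *args in self._filter_ops(op_name):
                    for source in fnmatch.filter(self._sources, source_glob):
                        keys = self._custom_keys.get(source,
                                                     self._all_keys[source])
                        for key in fnmatch.filter(keys, key_glob):
                            op_func(self, source, key, *args)
            return op_handler
        return op_decorator

    class DummyWriter:
        """Illustrative stand-in for ReduceWriter."""

        def __init__(self, all_keys, ops):
            self._sources = sorted(all_keys)
            self._all_keys = all_keys    # source -> set of key names
            self._custom_keys = {}       # source -> reduced set of keys
            self._ops = ops
            self._compressed_keys = {}
            self._handle_compress_keys()  # decorated method, takes no args

        def _filter_ops(self, name):
            # Toy version: ops are stored as (name, *args) tuples.
            return [op[1:] for op in self._ops if op[0] == name]

        @apply_by_key('compress-keys')
        def _handle_compress_keys(self, source, key, level):
            self._compressed_keys[(source, key)] = level

    w = DummyWriter(
        {'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf': {'image.data', 'image.mask'}},
        ops=[('compress-keys', '*:xtdf', 'image.data', 1)])
    print(w._compressed_keys)
    # {('SPB_DET_AGIPD1M-1/DET/0CH0:xtdf', 'image.data'): 1}
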
@@ -50,7 +81,7 @@
         self._sources = sorted(data.all_sources)
         self._touched_sources = set()
 
-        # Only populated if trains/keys are selected/removed for sources.
+        # Only populated for sources/keys that are modified.
         self._custom_keys = {}  # source -> set(<keys>)
         self._custom_trains = {}  # source -> list(<trains>)
         self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
@@ -60,124 +91,29 @@
         self._subsliced_keys = {}  # (source, key) -> list(<regions>)
         self._compressed_keys = {}  # (source, key) -> level
 
-        # TODO: Raise error if rechunking is overwritten!
-        # TODO: make partial copies a list of slices!
-
         # Collect reductions resulting from operations.
-        for source_glob, in self._filter_ops('remove-sources'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                self._sources.remove(source)
-
-        for source_glob, key_glob in self._filter_ops('remove-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-
-                keys = self._custom_keys.setdefault(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    keys.remove(key)
-
-        for source_glob, train_sel in self._filter_ops('select-trains'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                train_ids = self._custom_trains.setdefault(
-                    source, list(self._data.train_ids))
-
-                self._custom_trains[source] = select_train_ids(
-                    train_ids, train_sel)
-
-        for source_glob, index_group, train_sel, entry_sel in self._filter_ops(
-            'select-entries'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if index_group not in self._data[source].index_groups:
-                    raise ValueError(f'{index_group} not index group of '
-                                     f'{source}')
-
-                new_mask = self._get_entry_masks(
-                    source, index_group, train_sel, entry_sel)
-
-                self._touched_sources.add(source)
-                self._custom_entry_masks.setdefault(
-                    (source, index_group), {}).update(new_mask)
-
-        for source_glob, train_sel, entry_sel in self._filter_ops(
-            'select-xtdf'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if not source.endswith(':xtdf'):
-                    # Simply ignore matches without trailing :xtdf.
-                    continue
-
-                if not self._is_xtdf_source(source):
-                    # Raise exception if essentials are missing.
-                    raise ValueError(f'{source} is not a valid XTDF source')
-
-                new_mask = self._get_entry_masks(
-                    source, 'image', train_sel, entry_sel)
-
-                self._touched_sources.add(source)
-                self._custom_xtdf_masks.setdefault(source, {}).update(new_mask)
-
-        if (
-            {x[0] for x in self._custom_entry_masks.keys()} &
-            self._custom_xtdf_masks.keys()
-        ):
-            raise ValueError('source may not be affected by both '
-                             'select-entries and select-xtdf operations')
-
-        for source_glob, key_glob, chunking in self._filter_ops(
-            'rechunk-keys'
-        ):
-            for source in fnmatch.filter(self._sources, source_glob):
-                if not self._data[source].is_instrument:
-                    raise ValueError(
-                        f'rechunking keys only supported for instrument '
-                        f'sources, but {source_glob} matches '
-                        f'{self._data[source].section}/{source}')
-
-                self._touched_sources.add(source)
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    old_chunking = self._rechunked_keys.setdefault(
-                        (source, key), chunking)
-
-                    if old_chunking != chunking:
-                        raise ValueError(
-                            f'reduction sequence yields conflicting chunks '
-                            f'for {source}.{key}: {old_chunking}, {chunking}')
-
-                    self._rechunked_keys[(source, key)] = chunking
-
-        for source_glob, key_glob, region in self._filter_ops('subslice-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    self._subsliced_keys.setdefault((source, key), []).append(
-                        region)
-
-        for source_glob, key_glob, level in self._filter_ops('compress-keys'):
-            for source in fnmatch.filter(self._sources, source_glob):
-                self._touched_sources.add(source)
-                keys = self._custom_keys.get(
-                    source, set(self._data[source].keys()))
-
-                for key in fnmatch.filter(keys, key_glob):
-                    self._compressed_keys[source, key] = level
-
-        if (self._rechunked_keys.keys() & self._compressed_keys.keys()):
-            raise ValueError('keys may not be affected by both compress-keys '
-                             'and rechunk-keys operations')
+        # This is the most efficient order of operations to minimize
+        # more expensive operations for source or trains that may not
+        # end up being selected.
+        self._handle_remove_sources()
+        self._handle_remove_keys()
+        self._handle_select_trains()
+        self._handle_select_entries()
+        self._handle_select_xtdf()
+        self._handle_rechunk_keys()
+        self._handle_subslice_keys()
+        self._handle_compress_keys()
+
+        custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}
+        if custom_entry_sources & self._custom_xtdf_masks.keys():
+            raise ValueError(
+                'Source may not be affected by both select-entries and '
+                'select-xtdf operations')
+
+        if self._rechunked_keys.keys() & self._compressed_keys.keys():
+            raise ValueError('Key may not be affected by both '
+                             'compress-keys and rechunk-keys')
 
         if self._scope == 'sources':
             self._sources = sorted(
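
The two new consistency checks at the end of __init__ lean on the fact that dict key views behave like sets, so an overlap between two operation maps is a single intersection. A small illustration with made-up contents:

    rechunked_keys = {('SPB_DET/0CH0:xtdf', 'image.data'): (1, 512, 128)}
    compressed_keys = {('SPB_DET/0CH0:xtdf', 'image.data'): 1}

    # dict.keys() returns a set-like view; '&' yields the (source, key) pairs
    # present in both maps, and any overlap means conflicting operations.
    conflicts = rechunked_keys.keys() & compressed_keys.keys()
    if conflicts:
        raise ValueError(f'conflicting operations for {sorted(conflicts)}')
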
@@ -245,6 +181,8 @@
         return masks
 
+    # Public API
+
     def write_collection(self, output_path):
         outp_data = self._data.select([(s, '*') for s in self._sources])
@@ -434,3 +372,63 @@
         else:
             dest[:] = data
+
+    # Reduction operation handlers.
+
+    @apply_by_source('remove-sources')
+    def _handle_remove_sources(self, source):
+        self._sources.remove(source)
+
+    @apply_by_key('remove-keys')
+    def _handle_remove_keys(self, source, key):
+        self._custom_keys[source].remove(key)
+
+    @apply_by_source('select-trains')
+    def _handle_select_trains(self, source, train_sel):
+        self._custom_trains[source] = select_train_ids(
+            self._custom_trains.setdefault(source, list(self._data.train_ids)),
+            train_sel)
+
+    @apply_by_source('select-entries')
+    def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
+        if idx_group not in self._data[source].index_groups:
+            raise ValueError(f'{idx_group} not index group of {source}')
+
+        self._custom_entry_masks.setdefault((source, idx_group), {}).update(
+            self._get_entry_masks(source, idx_group, train_sel, entry_sel))
+
+    @apply_by_source('select-xtdf')
+    def _handle_select_xtdf(self, source, train_sel, entry_sel):
+        if not source.endswith(':xtdf'):
+            # Simply ignore matches without trailing :xtdf.
+            return
+
+        if not self._is_xtdf_source(source):
+            # Raise exception if essentials are missing.
+            raise ValueError(f'{source} is not a valid XTDF source')
+
+        self._custom_xtdf_masks.setdefault(source, {}).update(
+            self._get_entry_masks(source, 'image', train_sel, entry_sel))
+
+    @apply_by_key('rechunk-keys')
+    def _handle_rechunk_keys(self, source, key, chunking):
+        if not self._data[source].is_instrument:
+            # Ignore CONTROL sources.
+            return
+
+        old_chunking = self._rechunked_keys.setdefault((source, key), chunking)
+
+        if old_chunking != chunking:
+            raise ValueError(
+                f'Reduction sequence yields conflicting chunks for '
+                f'{source}.{key}: {old_chunking}, {chunking}')
+
+        self._rechunked_keys[(source, key)] = chunking
+
+    @apply_by_key('subslice-keys')
+    def _handle_subslice_keys(self, source, key, region):
+        self._subsliced_keys.setdefault((source, key), []).append(region)
+
+    @apply_by_key('compress-keys')
+    def _handle_compress_keys(self, source, key, level):
+        self._compressed_keys[source, key] = level
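
With this structure, supporting an additional reduction operation only requires a decorated handler plus one call in the ordered handler sequence in __init__. A hypothetical sketch of what such a handler could look like inside ReduceWriter (a 'scale-keys' operation and a _scaled_keys attribute do not exist in exdf; they are only meant to show the shape of a new handler and this fragment is not runnable on its own):

    @apply_by_key('scale-keys')
    def _handle_scale_keys(self, source, key, factor):
        # Hypothetical operation: remember a per-key scale factor so it can
        # be applied when the data is copied out.
        self._scaled_keys[(source, key)] = factor

together with a matching self._handle_scale_keys() call added to the handler sequence in __init__.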