
Simplify reduction operation implementations and error handling

Merged Philipp Schmidt requested to merge feat/error-handling into feat/compress-keys
1 file   +116  -118
@@ -14,6 +14,43 @@ from exdf.write import SourceDataWriter
from ..write.datafile import write_compressed_frames
def apply_by_source(op_name):
def op_decorator(op_func):
def op_handler(self):
assert isinstance(self, ReduceWriter)
for source_glob, *args in self._filter_ops(op_name):
for source in fnmatch.filter(self._sources, source_glob):
op_func(self, source, *args)
self._touched_sources.add(source)
return op_handler
return op_decorator
def apply_by_key(op_name):
def op_decorator(op_func):
def op_handler(self):
assert isinstance(self, ReduceWriter)
for source_glob, key_glob, *args in self._filter_ops(op_name):
for source in fnmatch.filter(self._sources, source_glob):
keys = self._custom_keys.get(
source, set(self._data[source].keys()))
for key in fnmatch.filter(keys, key_glob):
op_func(self, source, key, *args)
self._touched_sources.add(source)
return op_handler
return op_decorator
class ReduceInitError(RuntimeError):
def __init__(self, msg):
super().__init__(msg)
ReduceWriter.log.error(msg)
class ReduceWriter(SourceDataWriter):
log = logging.getLogger('exdf.data_reduction.ReduceWriter')
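The two decorators factor out the glob matching that every operation otherwise repeats: apply_by_source resolves a source pattern against self._sources, apply_by_key additionally resolves a key pattern against the source's (possibly already reduced) key set. A minimal standalone sketch of the dispatch, with a toy class standing in for ReduceWriter and invented source names, could look like this:

import fnmatch

# Standalone sketch of the dispatch pattern (not part of this MR);
# ToyWriter mimics just enough of ReduceWriter to run the decorator.
def apply_by_source(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            for source_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    op_func(self, source, *args)
                    self._touched_sources.add(source)
        return op_handler
    return op_decorator

class ToyWriter:
    def __init__(self, ops, sources):
        self._ops = ops
        self._sources = list(sources)
        self._touched_sources = set()

    def _filter_ops(self, op):
        return [args[1:] for args in self._ops if args[0] == op]

    @apply_by_source('remove-sources')
    def _handle_remove_sources(self, source):
        self._sources.remove(source)

writer = ToyWriter(ops=[('remove-sources', 'SA1_XTD2_XGM/*')],
                   sources=['SA1_XTD2_XGM/XGM/DOOCS', 'SPB_IRU_CAM/CAM/1:output'])
writer._handle_remove_sources()
print(writer._sources)  # ['SPB_IRU_CAM/CAM/1:output']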
@@ -28,8 +65,10 @@ class ReduceWriter(SourceDataWriter):
input_version = Version(metadata.get('dataFormatVersion', '1.0'))
if input_version < Version('1.0'):
raise ValueError('Currently input files are required to be '
'EXDF-v1.0+')
raise ReduceInitError('Currently input files are required to be '
'EXDF-v1.0+')
self.log.debug(f'Input data EXDF version {input_version}')
if version == 'same':
version = input_version
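ReduceInitError wraps RuntimeError so that constructing the exception also writes a record to the ReduceWriter logger; call sites then only need to raise. A minimal sketch of the pattern, using a plain module-level logger in place of the class attribute:

import logging

log = logging.getLogger('exdf.data_reduction.ReduceWriter')

class ReduceInitError(RuntimeError):
    """Raised for invalid input or reduction sequences; logs on creation."""
    def __init__(self, msg):
        super().__init__(msg)
        log.error(msg)

logging.basicConfig(level=logging.ERROR)
try:
    raise ReduceInitError('Currently input files are required to be EXDF-v1.0+')
except ReduceInitError as exc:
    print(f'Caught after logging: {exc}')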
@@ -39,7 +78,8 @@ class ReduceWriter(SourceDataWriter):
try:
self.run_number = int(metadata['runNumber'])
except KeyError:
raise ValueError('runNumber dataset required in input METADATA')
raise ReduceInitError('runNumber dataset required to be present '
'in input METADATA')
self._ops = sum(methods.values(), [])
@@ -50,7 +90,7 @@ class ReduceWriter(SourceDataWriter):
self._sources = sorted(data.all_sources)
self._touched_sources = set()
# Only populated if trains/keys are selected/removed for sources.
# Only populated for sources/keys that are modified.
self._custom_keys = {} # source -> set(<keys>)
self._custom_trains = {} # source -> list(<trains>)
self._custom_xtdf_masks = {} # source -> dict(train_id -> mask)
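These containers record, per source, which keys survive, which trains are selected and which XTDF image entries are kept. Purely illustrative contents (source names and train IDs are invented) might look like:

import numpy as np

_custom_keys = {'SA1_XTD2_XGM/XGM/DOOCS': {'pulseEnergy.photonFlux'}}
_custom_trains = {'SA1_XTD2_XGM/XGM/DOOCS': [10001, 10002, 10005]}
_custom_xtdf_masks = {
    'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf': {
        10001: np.array([True, False, True])  # keep image entries 0 and 2
    }
}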
@@ -60,124 +100,27 @@ class ReduceWriter(SourceDataWriter):
self._subsliced_keys = {} # (source, key) -> list(<regions>)
self._compressed_keys = {} # (source, key) -> level
# TODO: Raise error if rechunking is overwritten!
# TODO: make partial copies a list of slices!
# Collect reductions resulting from operations.
for source_glob, in self._filter_ops('remove-sources'):
for source in fnmatch.filter(self._sources, source_glob):
self._touched_sources.add(source)
self._sources.remove(source)
for source_glob, key_glob in self._filter_ops('remove-keys'):
for source in fnmatch.filter(self._sources, source_glob):
self._touched_sources.add(source)
keys = self._custom_keys.setdefault(
source, set(self._data[source].keys()))
for key in fnmatch.filter(keys, key_glob):
keys.remove(key)
for source_glob, train_sel in self._filter_ops('select-trains'):
for source in fnmatch.filter(self._sources, source_glob):
self._touched_sources.add(source)
train_ids = self._custom_trains.setdefault(
source, list(self._data.train_ids))
self._custom_trains[source] = select_train_ids(
train_ids, train_sel)
for source_glob, index_group, train_sel, entry_sel in self._filter_ops(
'select-entries'
):
for source in fnmatch.filter(self._sources, source_glob):
if index_group not in self._data[source].index_groups:
raise ValueError(f'{index_group} not index group of '
f'{source}')
new_mask = self._get_entry_masks(
source, index_group, train_sel, entry_sel)
self._touched_sources.add(source)
self._custom_entry_masks.setdefault(
(source, index_group), {}).update(new_mask)
for source_glob, train_sel, entry_sel in self._filter_ops(
'select-xtdf'
):
for source in fnmatch.filter(self._sources, source_glob):
if not source.endswith(':xtdf'):
# Simply ignore matches without trailing :xtdf.
continue
if not self._is_xtdf_source(source):
# Raise exception if essentials are missing.
raise ValueError(f'{source} is not a valid XTDF source')
new_mask = self._get_entry_masks(
source, 'image', train_sel, entry_sel)
self._touched_sources.add(source)
self._custom_xtdf_masks.setdefault(source, {}).update(new_mask)
if (
{x[0] for x in self._custom_entry_masks.keys()} &
self._custom_xtdf_masks.keys()
):
raise ValueError('source may not be affected by both '
'select-entries and select-xtdf operations')
for source_glob, key_glob, chunking in self._filter_ops(
'rechunk-keys'
):
for source in fnmatch.filter(self._sources, source_glob):
if not self._data[source].is_instrument:
raise ValueError(
f'rechunking keys only supported for instrument '
f'sources, but {source_glob} matches '
f'{self._data[source].section}/{source}')
self._touched_sources.add(source)
keys = self._custom_keys.get(
source, set(self._data[source].keys()))
for key in fnmatch.filter(keys, key_glob):
old_chunking = self._rechunked_keys.setdefault(
(source, key), chunking)
if old_chunking != chunking:
raise ValueError(
f'reduction sequence yields conflicting chunks '
f'for {source}.{key}: {old_chunking}, {chunking}')
self._rechunked_keys[(source, key)] = chunking
for source_glob, key_glob, region in self._filter_ops('subslice-keys'):
for source in fnmatch.filter(self._sources, source_glob):
self._touched_sources.add(source)
keys = self._custom_keys.get(
source, set(self._data[source].keys()))
for key in fnmatch.filter(keys, key_glob):
self._subsliced_keys.setdefault((source, key), []).append(
region)
for source_glob, key_glob, level in self._filter_ops('compress-keys'):
for source in fnmatch.filter(self._sources, source_glob):
self._touched_sources.add(source)
keys = self._custom_keys.get(
source, set(self._data[source].keys()))
for key in fnmatch.filter(keys, key_glob):
self._compressed_keys[source, key] = level
if (self._rechunked_keys.keys() & self._compressed_keys.keys()):
raise ValueError('keys may not be affected by both compress-keys '
'and rechunk-keys operations')
# This is the most efficient order of operations to minimize
# more expensive operations for sources or trains that may not
# end up being selected.
self._handle_remove_sources()
self._handle_remove_keys()
self._handle_select_trains()
self._handle_select_entries()
self._handle_select_xtdf()
self._handle_rechunk_keys()
self._handle_subslice_keys()
self._handle_compress_keys()
custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}
if custom_entry_sources & self._custom_xtdf_masks.keys():
raise ReduceInitError('Source may not be affected by both '
'select-entries and select-xtdf operations')
if self._rechunked_keys.keys() & self._compressed_keys.keys():
raise ReduceInitError('Key may not be affected by both '
'compress-keys and rechunk-keys')
if self._scope == 'sources':
self._sources = sorted(
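The handlers above all consume entries of self._ops, whose leading element names the operation and selects which handler unpacks the remaining arguments. A hypothetical reduction sequence (glob patterns, selections, index group names and chunk shapes are invented for illustration, but the tuple layouts match the unpacking used by the handlers):

import numpy as np

ops = [
    ('remove-sources', 'SA3_*'),                                    # source_glob
    ('remove-keys', '*_XGM/XGM/DOOCS', 'pulseEnergy.wavelength*'),  # source_glob, key_glob
    ('select-trains', '*', np.s_[:500]),                            # source_glob, train_sel
    ('select-entries', 'SQS_DIGITIZER*', 'raw', np.s_[:], np.s_[::2]),
    ('select-xtdf', 'SPB_DET_AGIPD1M-1/DET/*CH0:xtdf', np.s_[:], np.s_[:352]),
    ('rechunk-keys', '*:xtdf', 'image.data', (1, -1, -1)),          # source_glob, key_glob, chunking
    ('compress-keys', '*_XGM/XGM/DOOCS:output', 'data.intensityTD', 1),
    ('subslice-keys', '*:xtdf', 'image.data', np.s_[:, 128:256]),   # source_glob, key_glob, region
]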
@@ -192,8 +135,12 @@ class ReduceWriter(SourceDataWriter):
if (self._data[source].aggregator in touched_aggregators)})
if not self._sources:
raise ValueError('reduction sequence yields empty source '
'selection')
raise ReduceInitError('Reduction operations and output scope '
'yield an empty dataset')
else:
self.log.debug(
f'Sources being modified: {sorted(self._touched_sources)}')
self.log.debug(f'Sources included in output: {self._sources}')
def _filter_ops(self, op):
return [args[1:] for args in self._ops if args[0] == op]
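_filter_ops simply strips the operation name and returns the remaining argument tuples, so each decorated handler only sees its own entries. A small sketch using a module-level ops list rather than self._ops:

ops = [('remove-sources', 'SA3_*'),
       ('rechunk-keys', '*:xtdf', 'image.data', (1, -1, -1))]

def _filter_ops(op):
    return [args[1:] for args in ops if args[0] == op]

print(_filter_ops('rechunk-keys'))  # [('*:xtdf', 'image.data', (1, -1, -1))]
print(_filter_ops('remove-keys'))   # [] -> the corresponding handler is a no-op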
@@ -222,8 +169,8 @@ class ReduceWriter(SourceDataWriter):
if count == 0:
continue
elif max_entry >= count:
raise ValueError(
f'entry index exceeds data counts of train {train_id}')
raise ReduceInitError(f'Entry index exceeds data counts '
f'of train {train_id}')
masks[train_id] = np.zeros(count, dtype=bool)
masks[train_id][entry_sel] = True
@@ -235,16 +182,18 @@ class ReduceWriter(SourceDataWriter):
if count == 0:
continue
elif mask_len != counts.get(train_id, 0):
raise ValueError(
f'mask length mismatch for train {train_id}')
raise ReduceInitError(f'Mask length mismatch for '
f'train {train_id}')
masks[train_id] = entry_sel
else:
raise ValueError('unknown entry mask format')
raise ReduceInitError('Unknown entry mask format')
return masks
# Public API
def write_collection(self, output_path):
outp_data = self._data.select([(s, '*') for s in self._sources])
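For reference, _get_entry_masks above accepts two mask formats per train: an index expression applied to a fresh boolean array sized by that train's entry count, or a ready-made boolean array whose length must match that count. A small sketch with invented counts and selections:

import numpy as np

count = 5                    # entries recorded for one train
entry_sel = np.s_[1:4]       # index-style selection
mask = np.zeros(count, dtype=bool)
mask[entry_sel] = True
print(mask)                  # [False  True  True  True False]

ready_mask = np.array([True, False, True, False, True])
assert len(ready_mask) == count   # otherwise: 'Mask length mismatch for train ...'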
@@ -359,10 +308,12 @@ class ReduceWriter(SourceDataWriter):
# Keys are guaranteed to never use both custom chunking and
# compression.
if (source, key) in self._rechunked_keys:
sourcekey = source, key
if sourcekey in self._rechunked_keys:
orig_chunks = kwargs['chunks']
chunks = list(self._rechunked_keys[source, key])
chunks = list(self._rechunked_keys[sourcekey])
assert len(chunks) == len(orig_chunks)
for i, dim_len in enumerate(chunks):
@@ -375,14 +326,14 @@ class ReduceWriter(SourceDataWriter):
kwargs['chunks'] = tuple(chunks)
elif (source, key) in self._compressed_keys or orig_dset.compression:
elif sourcekey in self._compressed_keys or orig_dset.compression:
# TODO: Maintain more of existing properties, for now it is
# forced to use gzip and (1, *entry) chunking.
kwargs['chunks'] = (1,) + kwargs['shape'][1:]
kwargs['shuffle'] = True
kwargs['compression'] = 'gzip'
kwargs['compression_opts'] = self._compressed_keys.setdefault(
(source, key), orig_dset.compression_opts)
sourcekey, orig_dset.compression_opts)
return kwargs
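The net effect for a compressed key is a dataset creation call with one entry per chunk, byte shuffling and gzip at the requested level. A sketch of the resulting kwargs with an assumed dataset shape:

kwargs = dict(shape=(1000, 512, 128), dtype='float32')
level = 1                                  # from a compress-keys operation
kwargs['chunks'] = (1,) + kwargs['shape'][1:]
kwargs['shuffle'] = True
kwargs['compression'] = 'gzip'
kwargs['compression_opts'] = level
print(kwargs['chunks'])                    # (1, 512, 128)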
@@ -434,3 +385,75 @@ class ReduceWriter(SourceDataWriter):
else:
dest[:] = data
# Reduction operation handlers.
@apply_by_source('remove-sources')
def _handle_remove_sources(self, source):
self._touched_sources.add(source)
self.log.debug(f'Removing {source}')
@apply_by_key('remove-keys')
def _handle_remove_keys(self, source, key):
self._custom_keys[source].remove(key)
self.log.debug(f'Removing {source}, {key}')
@apply_by_source('select-trains')
def _handle_select_trains(self, source, train_sel):
self._custom_trains[source] = select_train_ids(
self._custom_trains.setdefault(source, list(self._data.train_ids)),
train_sel)
self.log.debug(f'Selecting {len(self._custom_trains[source])} trains '
f'for {source}')
@apply_by_source('select-entries')
def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
if idx_group not in self._data[source].index_groups:
raise ReduceInitError(
f'{idx_group} not an index group of {source}')
self._custom_entry_masks.setdefault((source, idx_group), {}).update(
self._get_entry_masks(source, idx_group, train_sel, entry_sel))
self.log.debug(f'Applying entry selection to {source}, {idx_group}')
@apply_by_source('select-xtdf')
def _handle_select_xtdf(self, source, train_sel, entry_sel):
if not source.endswith(':xtdf'):
self.log.warning(
f'Ignoring non-XTDF source {source} based on name')
return
if not self._is_xtdf_source(source):
self.log.warning(
f'Ignoring non-XTDF source {source} based on structure')
return
self._custom_xtdf_masks.setdefault(source, {}).update(
self._get_entry_masks(source, 'image', train_sel, entry_sel))
self.log.debug(f'Applying XTDF selection to {source}')
@apply_by_key('rechunk-keys')
def _handle_rechunk_keys(self, source, key, chunking):
if not self._data[source].is_instrument:
# Ignore CONTROL sources.
return
old_chunking = self._rechunked_keys.setdefault((source, key), chunking)
if old_chunking != chunking:
raise ReduceInitError(
f'Reduction sequence yields conflicting chunks for '
f'{source}.{key}: {old_chunking}, {chunking}')
self._rechunked_keys[(source, key)] = chunking
self.log.debug(f'Rechunking {source}, {key} to {chunking}')
@apply_by_key('subslice-keys')
def _handle_subslice_keys(self, source, key, region):
self._subsliced_keys.setdefault((source, key), []).append(region)
self.log.debug(f'Subslicing {region} of {source}, {key}')
@apply_by_key('compress-keys')
def _handle_compress_keys(self, source, key, level):
self._compressed_keys[source, key] = level
self.log.debug(f'Compressing {source}, {key}')
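A side effect of the decorator pattern is that adding another per-key operation only requires a handler; the glob matching and touched-source bookkeeping are shared. A hypothetical example to be placed inside ReduceWriter, reusing apply_by_key from above (the 'scale-keys' operation and _scaled_keys container are invented and not part of this MR):

@apply_by_key('scale-keys')
def _handle_scale_keys(self, source, key, factor):
    # Hypothetical: record a per-key scale factor for later application.
    self._scaled_keys[source, key] = factor
    self.log.debug(f'Scaling {source}, {key} by {factor}')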