Commit 26527b5d authored by Philipp Schmidt

Add custom exception type to propagate fatal errors from ReduceWriter

parent 82bb19d5
1 merge request: !4 Simplify reduction operation implementations and error handling
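At its core, the change swaps the generic ValueError for a dedicated ReduceInitError that logs its own message through the ReduceWriter class logger the moment it is constructed, so the command-line entry point can catch exactly this type and abort cleanly without a traceback. A minimal self-contained sketch of that pattern follows; build_writer and demo_main are hypothetical stand-ins for ReduceWriter.__init__ and the CLI's main(), and only the ReduceInitError class and the try/except handling mirror the diff below.

import logging
import sys

log = logging.getLogger('demo.cli')


class ReduceInitError(RuntimeError):
    """Fatal error during writer initialization, logged on construction."""

    def __init__(self, msg):
        super().__init__(msg)
        # Logging here means the detailed cause reaches the log exactly
        # once, no matter where the exception is raised or caught.
        logging.getLogger('demo.ReduceWriter').error(msg)


def build_writer(metadata):
    # Hypothetical stand-in for ReduceWriter.__init__: validation
    # failures raise the custom type instead of ValueError.
    if 'runNumber' not in metadata:
        raise ReduceInitError('runNumber dataset required to be present '
                              'in input METADATA')
    return metadata


def demo_main(metadata):
    try:
        writer = build_writer(metadata)
    except ReduceInitError:
        # Expected fatal error: emit a generic critical line and exit
        # instead of letting the traceback propagate to the user.
        log.critical('Failed to initialize reduction writer')
        return 1
    return 0


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    sys.exit(demo_main({}))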
@@ -12,7 +12,7 @@ import sys
 from pkg_resources import iter_entry_points

 from extra_data import RunDirectory, open_run

-from ..data_reduction.red_writer import ReduceWriter
+from ..data_reduction.red_writer import ReduceWriter, ReduceInitError

 def _parse_args(argv):
@@ -212,9 +212,12 @@ def main(argv=None):
     if args.to_recipe:
         _to_recipe(args.to_recipe, methods, inp_data, argv)

-    writer = ReduceWriter(
-        inp_data, methods,
-        args.output_scope, args.output_sequence_len, args.output_version)
+    try:
+        writer = ReduceWriter(inp_data, methods, args.output_scope,
+                              args.output_sequence_len, args.output_version)
+    except ReduceInitError:
+        log.critical('Failed to initialize reduction writer')
+        return

     if args.output_scope == 'none':
         log.info('Not writing out any data files')
@@ -45,6 +45,12 @@ def apply_by_key(op_name):
     return op_decorator


+class ReduceInitError(RuntimeError):
+    def __init__(self, msg):
+        super().__init__(msg)
+        ReduceWriter.log.error(msg)
+
+
 class ReduceWriter(SourceDataWriter):
     log = logging.getLogger('exdf.data_reduction.ReduceWriter')
@@ -59,8 +65,8 @@ class ReduceWriter(SourceDataWriter):
         input_version = Version(metadata.get('dataFormatVersion', '1.0'))

         if input_version < Version('1.0'):
-            raise ValueError('Currently input files are required to be '
-                             'EXDF-v1.0+')
+            raise ReduceInitError('Currently input files are required to be '
+                                  'EXDF-v1.0+')

         if version == 'same':
             version = input_version
@@ -70,7 +76,8 @@ class ReduceWriter(SourceDataWriter):
         try:
             self.run_number = int(metadata['runNumber'])
         except KeyError:
-            raise ValueError('runNumber dataset required in input METADATA')
+            raise ReduceInitError('runNumber dataset required to be present '
+                                  'in input METADATA')

         self._ops = sum(methods.values(), [])
@@ -106,14 +113,13 @@ class ReduceWriter(SourceDataWriter):
         custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}

         if custom_entry_sources & self._custom_xtdf_masks.keys():
-            raise ValueError(
-                'Source may not be affected by both select-entries and '
-                'select-xtdf operations')
+            raise ReduceInitError('Source may not be affected by both '
+                                  'select-entries and select-xtdf operations')

         if self._rechunked_keys.keys() & self._compressed_keys.keys():
-            raise ValueError('Key may not be affected by both '
-                             'compress-keys and rechunk-keys')
+            raise ReduceInitError('Key may not be affected by both '
+                                  'compress-keys and rechunk-keys')

         if self._scope == 'sources':
             self._sources = sorted(
@@ -128,8 +134,8 @@ class ReduceWriter(SourceDataWriter):
                 if (self._data[source].aggregator in touched_aggregators)})

         if not self._sources:
-            raise ValueError('reduction sequence yields empty source '
-                             'selection')
+            raise ReduceInitError('Reduction operations and output scope '
+                                  'yield an empty dataset')

     def _filter_ops(self, op):
         return [args[1:] for args in self._ops if args[0] == op]
@@ -158,8 +164,8 @@ class ReduceWriter(SourceDataWriter):
                 if count == 0:
                     continue
                 elif max_entry >= count:
-                    raise ValueError(
-                        f'entry index exceeds data counts of train {train_id}')
+                    raise ReduceInitError(f'Entry index exceeds data counts '
+                                          f'of train {train_id}')

                 masks[train_id] = np.zeros(count, dtype=bool)
                 masks[train_id][entry_sel] = True
@@ -171,13 +177,13 @@ class ReduceWriter(SourceDataWriter):
                 if count == 0:
                     continue
                 elif mask_len != counts.get(train_id, 0):
-                    raise ValueError(
-                        f'mask length mismatch for train {train_id}')
+                    raise ReduceInitError(f'Mask length mismatch for '
+                                          f'train {train_id}')

                 masks[train_id] = entry_sel

         else:
-            raise ValueError('unknown entry mask format')
+            raise ReduceInitError('Unknown entry mask format')

         return masks
@@ -392,7 +398,7 @@ class ReduceWriter(SourceDataWriter):
     @apply_by_source('select-entries')
     def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
         if idx_group not in self._data[source].index_groups:
-            raise ValueError(f'{idx_group} not index group of {source}')
+            raise ReduceInitError(f'{idx_group} not index group of {source}')

         self._custom_entry_masks.setdefault((source, idx_group), {}).update(
             self._get_entry_masks(source, idx_group, train_sel, entry_sel))
@@ -405,7 +411,7 @@ class ReduceWriter(SourceDataWriter):
         if not self._is_xtdf_source(source):
             # Raise exception if essentials are missing.
-            raise ValueError(f'{source} is not a valid XTDF source')
+            raise ReduceInitError(f'{source} is not a valid XTDF source')

         self._custom_xtdf_masks.setdefault(source, {}).update(
             self._get_entry_masks(source, 'image', train_sel, entry_sel))
@@ -419,7 +425,7 @@ class ReduceWriter(SourceDataWriter):
         old_chunking = self._rechunked_keys.setdefault((source, key), chunking)

         if old_chunking != chunking:
-            raise ValueError(
+            raise ReduceInitError(
                 f'Reduction sequence yields conflicting chunks for '
                 f'{source}.{key}: {old_chunking}, {chunking}')