Skip to content
Snippets Groups Projects
Commit 28142409 authored by Philipp Schmidt's avatar Philipp Schmidt
Browse files

Add implementation for select-xtdf operation

parent 207b9989
No related branches found
No related tags found
No related merge requests found
...@@ -52,6 +52,8 @@ class ReduceWriter(SourceDataWriter): ...@@ -52,6 +52,8 @@ class ReduceWriter(SourceDataWriter):
# Only populated if trains/keys are selected/removed for sources. # Only populated if trains/keys are selected/removed for sources.
self._custom_keys = {} # source -> set(<keys>) self._custom_keys = {} # source -> set(<keys>)
self._custom_trains = {} # source -> list(<trains>) self._custom_trains = {} # source -> list(<trains>)
self._custom_xtdf_masks = {} # source -> dict(train_id -> mask)
self._custom_xtdf_counts = {} # source -> ndarray
self._custom_rows = {} # source -> dict(train_id -> mask) self._custom_rows = {} # source -> dict(train_id -> mask)
self._rechunked_keys = {} # (source, key) -> chunks self._rechunked_keys = {} # (source, key) -> chunks
self._partial_copies = {} # (source, key) -> list(<regions>) self._partial_copies = {} # (source, key) -> list(<regions>)
...@@ -97,6 +99,27 @@ class ReduceWriter(SourceDataWriter): ...@@ -97,6 +99,27 @@ class ReduceWriter(SourceDataWriter):
self._get_row_masks(source, index_group, self._get_row_masks(source, index_group,
train_sel, row_sel)) train_sel, row_sel))
for source_glob, train_sel, row_sel in self._filter_ops('select-xtdf'):
for source in fnmatch.filter(self._sources, source_glob):
if not source.endswith(':xtdf'):
# Simply ignore matches without trailing :xtdf.
continue
if not self._is_xtdf_source(source):
# Raise exception if essentials are missing.
raise ValueError(f'{source} is not a valid XTDF source')
self._touched_sources.add(source)
self._custom_xtdf_masks.setdefault(source, {}).update(
self._get_row_masks(source, 'image', train_sel, row_sel))
if (
{x[0] for x in self._custom_rows.keys()} &
self._custom_xtdf_masks.keys()
):
raise ValueError('source may not be affected by both select-rows '
'and select-xtdf operations')
for source_glob, key_glob, chunking in self._filter_ops( for source_glob, key_glob, chunking in self._filter_ops(
'rechunk-keys' 'rechunk-keys'
): ):
...@@ -153,6 +176,9 @@ class ReduceWriter(SourceDataWriter): ...@@ -153,6 +176,9 @@ class ReduceWriter(SourceDataWriter):
def _filter_ops(self, op): def _filter_ops(self, op):
return [args[1:] for args in self._ops if args[0] == op] return [args[1:] for args in self._ops if args[0] == op]
def _is_xtdf_source(self, source):
return self._data[source].keys() > {'header.pulseCount', 'image.data'}
def _get_row_masks(self, source, index_group, train_sel, row_sel): def _get_row_masks(self, source, index_group, train_sel, row_sel):
train_ids = select_train_ids( train_ids = select_train_ids(
self._custom_trains.get(source, list(self._data.train_ids)), self._custom_trains.get(source, list(self._data.train_ids)),
...@@ -325,7 +351,9 @@ class ReduceWriter(SourceDataWriter): ...@@ -325,7 +351,9 @@ class ReduceWriter(SourceDataWriter):
return orig_chunks return orig_chunks
def mask_instrument_data(self, source, index_group, train_ids, counts): def mask_instrument_data(self, source, index_group, train_ids, counts):
if (source, index_group) in self._custom_rows: if source in self._custom_xtdf_masks and index_group == 'image':
custom_masks = self._custom_xtdf_masks[source]
elif (source, index_group) in self._custom_rows:
custom_masks = self._custom_rows[source, index_group] custom_masks = self._custom_rows[source, index_group]
else: else:
return # None efficiently selects all rows. return # None efficiently selects all rows.
...@@ -340,9 +368,25 @@ class ReduceWriter(SourceDataWriter): ...@@ -340,9 +368,25 @@ class ReduceWriter(SourceDataWriter):
masks.append(mask) masks.append(mask)
if source in self._custom_xtdf_masks:
# Sources are guaranteed to never use both XTDF and general
# row slicing. In the XTDF case, the new data counts for the
# image index group must be determined to be filled into
# the respective header field.
self._custom_xtdf_counts[source] = {
train_id: mask.sum() for train_id, mask
in zip(train_ids, masks) if mask.any()}
return masks return masks
def copy_instrument_data(self, source, key, dest, train_ids, data): def copy_instrument_data(self, source, key, dest, train_ids, data):
if source in self._custom_xtdf_counts and key == 'header.pulseCount':
custom_counts = self._custom_xtdf_counts[source]
for i, train_id in enumerate(train_ids):
data[i] = custom_counts.get(train_id, data[i])
try: try:
regions = self._partial_copies[source, key] regions = self._partial_copies[source, key]
except KeyError: except KeyError:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment