diff --git a/src/exdf/data_reduction/writer.py b/src/exdf/data_reduction/writer.py index 9282f304bb082d31aea574e9260c6c2fdbbba4ce..005f7dc9ebd07cf4ff408e32218f4f6799d6f058 100644 --- a/src/exdf/data_reduction/writer.py +++ b/src/exdf/data_reduction/writer.py @@ -52,6 +52,8 @@ class ReduceWriter(SourceDataWriter): # Only populated if trains/keys are selected/removed for sources. self._custom_keys = {} # source -> set(<keys>) self._custom_trains = {} # source -> list(<trains>) + self._custom_xtdf_masks = {} # source -> dict(train_id -> mask) + self._custom_xtdf_counts = {} # source -> ndarray self._custom_rows = {} # source -> dict(train_id -> mask) self._rechunked_keys = {} # (source, key) -> chunks self._partial_copies = {} # (source, key) -> list(<regions>) @@ -97,6 +99,27 @@ class ReduceWriter(SourceDataWriter): self._get_row_masks(source, index_group, train_sel, row_sel)) + for source_glob, train_sel, row_sel in self._filter_ops('select-xtdf'): + for source in fnmatch.filter(self._sources, source_glob): + if not source.endswith(':xtdf'): + # Simply ignore matches without trailing :xtdf. + continue + + if not self._is_xtdf_source(source): + # Raise exception if essentials are missing. + raise ValueError(f'{source} is not a valid XTDF source') + + self._touched_sources.add(source) + self._custom_xtdf_masks.setdefault(source, {}).update( + self._get_row_masks(source, 'image', train_sel, row_sel)) + + if ( + {x[0] for x in self._custom_rows.keys()} & + self._custom_xtdf_masks.keys() + ): + raise ValueError('source may not be affected by both select-rows ' + 'and select-xtdf operations') + for source_glob, key_glob, chunking in self._filter_ops( 'rechunk-keys' ): @@ -153,6 +176,9 @@ class ReduceWriter(SourceDataWriter): def _filter_ops(self, op): return [args[1:] for args in self._ops if args[0] == op] + def _is_xtdf_source(self, source): + return self._data[source].keys() > {'header.pulseCount', 'image.data'} + def _get_row_masks(self, source, index_group, train_sel, row_sel): train_ids = select_train_ids( self._custom_trains.get(source, list(self._data.train_ids)), @@ -325,7 +351,9 @@ class ReduceWriter(SourceDataWriter): return orig_chunks def mask_instrument_data(self, source, index_group, train_ids, counts): - if (source, index_group) in self._custom_rows: + if source in self._custom_xtdf_masks and index_group == 'image': + custom_masks = self._custom_xtdf_masks[source] + elif (source, index_group) in self._custom_rows: custom_masks = self._custom_rows[source, index_group] else: return # None efficiently selects all rows. @@ -340,9 +368,25 @@ class ReduceWriter(SourceDataWriter): masks.append(mask) + if source in self._custom_xtdf_masks: + # Sources are guaranteed to never use both XTDF and general + # row slicing. In the XTDF case, the new data counts for the + # image index group must be determined to be filled into + # the respective header field. + + self._custom_xtdf_counts[source] = { + train_id: mask.sum() for train_id, mask + in zip(train_ids, masks) if mask.any()} + return masks def copy_instrument_data(self, source, key, dest, train_ids, data): + if source in self._custom_xtdf_counts and key == 'header.pulseCount': + custom_counts = self._custom_xtdf_counts[source] + + for i, train_id in enumerate(train_ids): + data[i] = custom_counts.get(train_id, data[i]) + try: regions = self._partial_copies[source, key] except KeyError: