From 207b998921dc896f39c242050011c0cafe2f4397 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Wed, 25 Oct 2023 16:55:59 +0200
Subject: [PATCH] Add implementation for select-rows operation

---
 src/exdf/data_reduction/method.py |  53 ++++++++--
 src/exdf/data_reduction/writer.py |  78 ++++++++++++++-
 src/exdf/write/sd_writer.py       | 156 +++++++++++++++++++++++-------
 3 files changed, 244 insertions(+), 43 deletions(-)

diff --git a/src/exdf/data_reduction/method.py b/src/exdf/data_reduction/method.py
index d106ed4..fd75a00 100644
--- a/src/exdf/data_reduction/method.py
+++ b/src/exdf/data_reduction/method.py
@@ -3,11 +3,13 @@
 from typing import TypeVar
 from logging import getLogger
 import warnings
+import numpy as np
 
 from extra_data.read_machinery import select_train_ids
 
 log = getLogger('exdf.data_reduction.ReductionMethod')
 
 train_sel = TypeVar('train_sel')
+row_sel = TypeVar('row_sel')
 index_exp = TypeVar('index_exp')
 
 
@@ -22,6 +24,22 @@ def is_train_selection(x):
     return True
 
 
+def is_row_selection(x):
+    if isinstance(x, slice):
+        return True
+
+    if isinstance(x, list) and all([isinstance(y, (int, bool)) for y in x]):
+        return True
+
+    if (
+        isinstance(x, np.ndarray) and x.ndim == 1 and
+        (np.issubdtype(x.dtype, np.integer) or np.issubdtype(x.dtype, bool))
+    ):
+        return True
+
+    return False
+
+
 def is_index_expression(x):
     if isinstance(x, (slice, list)):
         return True
@@ -54,19 +72,42 @@ class ReductionMethod(list):
         assert is_train_selection(trains)
         self._emit('select-trains', source_glob, trains)
 
-    def select_pulses(
+    def select_rows(
         self,
         source_glob: str,
-        index_group: str,  # May be xtdf
+        index_group: str,
         trains: train_sel,
-        rows: index_exp
+        rows: row_sel
     ):
-        raise NotImplementedError('select-pulses')
         assert isinstance(source_glob, str)
         assert isinstance(index_group, str)
         assert is_train_selection(trains)
-        assert is_index_expression(rows)
-        self._emit('select-pulses', source_glob, index_group, trains, rows)
+        assert is_row_selection(rows)
+        self._emit('select-rows', source_glob, index_group, trains, rows)
+
+    def select_xtdf(
+        self,
+        source_glob: str,
+        trains: train_sel,
+        rows: row_sel
+    ):
+        """Slice XTDF data by row.
+
+        Roughly equivalent to select_rows(source_glob, 'image',
+        trains, rows), but only acts on XTDF sources and modifies
+        header data structures according to slicing.
+
+        Requires sources to end with :xtdf and have all XTDF keys.
+
+        Args:
+            source_glob (str): Source glob pattern.
+            trains (train_sel): Train selection.
+            rows (row_sel): Row selection.
+        """
+        assert isinstance(source_glob, str)
+        assert is_train_selection(trains)
+        assert is_row_selection(rows)
+        self._emit('select-xtdf', source_glob, trains, rows)
 
     def remove_sources(
         self,
diff --git a/src/exdf/data_reduction/writer.py b/src/exdf/data_reduction/writer.py
index a025092..9282f30 100644
--- a/src/exdf/data_reduction/writer.py
+++ b/src/exdf/data_reduction/writer.py
@@ -52,6 +52,7 @@ class ReduceWriter(SourceDataWriter):
         # Only populated if trains/keys are selected/removed for sources.
         self._custom_keys = {}  # source -> set(<keys>)
         self._custom_trains = {}  # source -> list(<trains>)
+        self._custom_rows = {}  # (source, index_group) -> dict(train_id -> mask)
         self._rechunked_keys = {}  # (source, key) -> chunks
         self._partial_copies = {}  # (source, key) -> list(<regions>)
 
@@ -83,6 +84,19 @@ class ReduceWriter(SourceDataWriter):
             self._custom_trains[source] = select_train_ids(
                 train_ids, train_sel)
 
+        for source_glob, index_group, train_sel, row_sel in self._filter_ops(
+            'select-rows'
+        ):
+            for source in fnmatch.filter(self._sources, source_glob):
+                if index_group not in self._data[source].index_groups:
+                    raise ValueError(f'{index_group} not index group of '
+                                     f'{source}')
+
+                self._touched_sources.add(source)
+                self._custom_rows.setdefault((source, index_group), {}).update(
+                    self._get_row_masks(source, index_group,
+                                        train_sel, row_sel))
+
         for source_glob, key_glob, chunking in self._filter_ops(
             'rechunk-keys'
         ):
@@ -139,6 +153,50 @@ class ReduceWriter(SourceDataWriter):
     def _filter_ops(self, op):
         return [args[1:] for args in self._ops if args[0] == op]
 
+    def _get_row_masks(self, source, index_group, train_sel, row_sel):
+        train_ids = select_train_ids(
+            self._custom_trains.get(source, list(self._data.train_ids)),
+            train_sel)
+        counts = self._data[source].select_trains(by_id[train_ids]) \
+            .data_counts(index_group=index_group)
+        masks = {}
+
+        if isinstance(row_sel, slice):
+            for train_id, count in counts.items():
+                if count > 0:
+                    masks[train_id] = np.zeros(count, dtype=bool)
+                    masks[train_id][row_sel] = True
+
+        elif np.issubdtype(type(row_sel[0]), np.integer):
+            max_row = max(row_sel)
+
+            for train_id, count in counts.items():
+                if count == 0:
+                    continue
+                elif max_row >= count:
+                    raise ValueError(
+                        f'row index exceeds data counts of train {train_id}')
+
+                masks[train_id] = np.zeros(count, dtype=bool)
+                masks[train_id][row_sel] = True
+
+        elif np.issubdtype(type(row_sel[0]), bool):
+            mask_len = len(row_sel)
+
+            for train_id, count in counts.items():
+                if count == 0:
+                    continue
+                elif mask_len != count:
+                    raise ValueError(
+                        f'mask length mismatch for train {train_id}')
+
+                masks[train_id] = np.asarray(row_sel)
+
+        else:
+            raise ValueError('unknown row mask format')
+
+        return masks
+
     def write_collection(self, output_path):
         outp_data = self._data.select([(s, '*') for s in self._sources])
 
@@ -266,7 +324,25 @@ class ReduceWriter(SourceDataWriter):
         except KeyError:
             return orig_chunks
 
-    def copy_instrument_data(self, source, key, dest, data):
+    def mask_instrument_data(self, source, index_group, train_ids, counts):
+        if (source, index_group) in self._custom_rows:
+            custom_masks = self._custom_rows[source, index_group]
+        else:
+            return  # None efficiently selects all rows.
+
+        masks = []
+
+        for train_id, count_all in zip(train_ids, counts):
+            if train_id in custom_masks:
+                mask = custom_masks[train_id]
+            else:
+                mask = np.ones(count_all, dtype=bool)
+
+            masks.append(mask)
+
+        return masks
+
+    def copy_instrument_data(self, source, key, dest, train_ids, data):
         try:
             regions = self._partial_copies[source, key]
         except KeyError:
diff --git a/src/exdf/write/sd_writer.py b/src/exdf/write/sd_writer.py
index 46489ae..d9354af 100644
--- a/src/exdf/write/sd_writer.py
+++ b/src/exdf/write/sd_writer.py
@@ -17,7 +17,7 @@ from operator import or_
 import numpy as np
 
 from extra_data import FileAccess
-from . import DataFile
+from .datafile import DataFile, get_pulse_offsets
 
 log = getLogger('exdf.write.SourceDataWriter')
 
@@ -50,13 +50,37 @@ class SourceDataWriter:
 
         return orig_chunks
 
-    def copy_instrument_data(self, source, key, dest, data):
+    def mask_instrument_data(self, source, index_group, train_ids, counts):
+        """Mask INSTRUMENT data.
+
+        Each mask array must have the same length as the original data
+        count of its train.
+
+        Args:
+            source (str): Source name.
+            index_group (str): Index group.
+            train_ids (ndarray): Train IDs.
+            counts (ndarray): Data counts per train.
+
+        Returns:
+            (Iterable of ndarray): Boolean masks for each passed
+                train ID, each equal in length to the respective data
+                count, or None to perform no masking.
+        """
+
+        return
+
+    def copy_instrument_data(self, source, key, dest, train_ids, data):
         """Copy INSTRUMENT data from input to output.
 
+        The destination dataset is guaranteed to align with the shape of
+        train_ids and data.
+
         Args:
             source (str): Source name.
             key (str): Key name.
-            dset (h5py.Dataset): Destination dataset.
+            dest (h5py.Dataset): Destination dataset.
+            train_ids (ndarray): Train ID coordinates.
             data (ndarray): Source data.
 
         Returns:
@@ -110,9 +134,9 @@ class SourceDataWriter:
             data_format_version=self.get_data_format_version(),
             control_sources=control_indices.keys(),
             instrument_channels=[
-                f'{source}/{channel}'
-                for source, channels in instrument_indices.items()
-                for channel in channels.keys()])
+                f'{source}/{index_group}'
+                for source, index_group_counts in instrument_indices.items()
+                for index_group in index_group_counts.keys()])
 
         f.create_dataset('METADATA/dataWriter', data=b'exdf-tools', shape=(1,))
         if not self.with_origin():
@@ -124,9 +148,10 @@ class SourceDataWriter:
             control_src = f.create_control_source(source)
             control_src.create_index(len(train_ids), per_train=True)
 
-        for source, channel_counts in instrument_indices.items():
+        for source, index_group_counts in instrument_indices.items():
+            # May be overwritten later as a result of masking.
             instrument_src = f.create_instrument_source(source)
-            instrument_src.create_index(**channel_counts)
+            instrument_src.create_index(**index_group_counts)
 
     def write_control(self, f, sources):
         """Write CONTROL and RUN data.
@@ -143,7 +168,7 @@ class SourceDataWriter:
         """
 
         for sd in sources:
-            source = f.source[sd.source]
+            h5source = f.source[sd.source]
             attrs = get_key_attributes(sd) if self.with_attrs() else {}
             run_data_leafs = {}
 
@@ -164,23 +189,24 @@ class SourceDataWriter:
                     ctrl_values = sd[f'{key}.value'].ndarray()
                     ctrl_timestamps = sd[f'{key}.timestamp'].ndarray()
 
-                source.create_key(
+                h5source.create_key(
                     key, values=ctrl_values, timestamps=ctrl_timestamps,
                     run_entry=run_entry, attrs=attrs.pop(key, None))
 
             # Write remaining RUN-only keys.
             for key, leafs in run_data_leafs.items():
-                source.create_run_key(key, **leafs, attrs=attrs.pop(key, None))
+                h5source.create_run_key(
+                    key, **leafs, attrs=attrs.pop(key, None))
 
             # Fill in the missing attributes for nodes.
             for path, attrs in attrs.items():
-                source.run_key[path].attrs.update(attrs)
-                source.key[path].attrs.update(attrs)
+                h5source.run_key[path].attrs.update(attrs)
+                h5source.key[path].attrs.update(attrs)
 
     def write_instrument(self, f, sources):
         """Write INSTRUMENT data.
 
-        This method assumes the source datasets already exist.
+        This method assumes the INDEX and source datasets already exist.
 
         Args:
             f (exdf.DataFile): Output file.
@@ -191,40 +217,67 @@ class SourceDataWriter:
             None
         """
 
-        for sd in sources:
-            source = f.source[sd.source]
+        # Must be re-read at this point, as additional trains could have
+        # been introduced in this sequence.
+        train_ids = np.array(f['INDEX/trainId'])
 
+        # Stores row mask per source and index group.
+        masks = {}
+
+        for sd in sources:
             attrs = get_key_attributes(sd) if self.with_attrs() else {}
+            h5source = f.source[sd.source]
+            keys = sd.keys()
 
-            for key in sd.keys():
-                kd = sd[key]
+            for index_group in sd.index_groups:
+                # Must be re-read same as train IDs.
+                h5index = f[f'INDEX/{sd.source}/{index_group}']
+                counts = np.array(h5index['count'])
 
-                shape = (kd.data_counts(labelled=False).sum(), *kd.entry_shape)
-                chunks = self.chunk_instrument_data(
-                    sd.source, key,
-                    kd.files[0].file[kd.hdf5_data_path].chunks)
+                # Obtain mask for this index group.
+                masks_by_train = self.mask_instrument_data(
+                    sd.source, index_group, train_ids, counts)
 
-                source.create_key(
-                    key, shape=shape, maxshape=(None,) + shape[1:],
-                    chunks=chunks, dtype=kd.dtype, attrs=attrs.pop(key, None))
+                if masks_by_train is not None:
+                    masks[sd.source, index_group] = mask_index(
+                        h5index, counts, masks_by_train)
+                    num_entries = masks[sd.source, index_group].sum()
+                else:
+                    num_entries = counts.sum()
 
-            for path, attrs in attrs.items():
-                source.key[path].attrs.update(attrs)
+                for key in iter_index_group_keys(keys, index_group):
+                    kd = sd[key]
 
-            # Update tableSize for each index group to the correct
-            # number of trains.
-            for index_group in sd.index_groups:
-                source[index_group].attrs['tableSize'] = sd.data_counts(
-                    labelled=False, index_group=index_group).sum()
+                    shape = (num_entries, *kd.entry_shape)
+                    chunks = self.chunk_instrument_data(
+                        sd.source, key,
+                        kd.files[0].file[kd.hdf5_data_path].chunks)
+
+                    h5source.create_key(
+                        key, shape=shape, maxshape=(None,) + shape[1:],
+                        chunks=chunks, dtype=kd.dtype,
+                        attrs=attrs.pop(key, None))
+
+                # Update tableSize to the correct number of records.
+                h5source[index_group].attrs['tableSize'] = num_entries
+
+            for path, attrs in attrs.items():
+                h5source.key[path].attrs.update(attrs)
 
         # Copy INSTRUMENT data.
         for sd in sources:
-            source = f.source[sd.source]
+            h5source = f.source[sd.source]
+
+            for index_group in sd.index_groups:
+                mask = masks.get((sd.source, index_group), np.s_[:])
+
+                for key in iter_index_group_keys(sd.keys(), index_group):
+                    # TODO: Copy by chunk / file if too large
 
-            for key in sd.keys():
-                # TODO: Copy by chunk / file if too large
-                self.copy_instrument_data(sd.source, key, source.key[key],
-                                          sd[key].ndarray())
+                    self.copy_instrument_data(
+                        sd.source, key, h5source.key[key],
+                        sd[key].train_id_coordinates()[mask],
+                        sd[key].ndarray()[mask])
 
 
 def get_index_root_data(sources):
@@ -375,3 +428,34 @@ def get_key_attributes(sd):
                     f'{sd.source}.{path}')
 
     return source_attrs
+
+def iter_index_group_keys(keys, index_group):
+    for key in keys:
+        if key[:key.index('.')] == index_group:
+            yield key
+
+
+def mask_index(g, counts, masks_by_train):
+    full_mask = np.concatenate(masks_by_train)
+    num_entries = counts.sum()
+
+    assert len(full_mask) == num_entries, \
+        'incompatible INSTRUMENT mask shape'
+
+    # Modify INDEX entry if necessary.
+    if full_mask.sum() != num_entries:
+        g.create_dataset(
+            'original/first', data=get_pulse_offsets(counts))
+        g.create_dataset(
+            'original/count', data=counts)
+        g.create_dataset(
+            'original/position',
+            data=np.concatenate([np.flatnonzero(mask)
+                                 for mask in masks_by_train]))
+
+        # Compute new data counts.
+        counts = [mask.sum() for mask in masks_by_train]
+        g['first'][:] = get_pulse_offsets(counts)
+        g['count'][:] = counts
+
+    return full_mask
-- 
GitLab
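
As a usage sketch (not part of the patch itself), a reduction recipe built on
ReductionMethod could emit the new operations roughly as follows. The source
patterns and selections are made up, and it assumes ReductionMethod can be
instantiated directly and that slices are accepted for both train and row
selections:

    import numpy as np
    from exdf.data_reduction.method import ReductionMethod

    method = ReductionMethod()

    # Keep only the first 100 rows per train in the 'data' index group
    # of a (hypothetical) XGM source.
    method.select_rows('SA1_XTD2_XGM/XGM/DOOCS:output', 'data',
                       np.s_[:], np.s_[:100])

    # For XTDF detector sources, select-xtdf also fixes up the header
    # data structures after slicing the 'image' rows.
    method.select_xtdf('SPB_DET_AGIPD1M-1/DET/*CH0:xtdf',
                       np.s_[:], np.s_[:352])

    # ReductionMethod is a list; each call records an operation tuple
    # that ReduceWriter later applies.
    print(method)
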
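The three row selection forms accepted by is_row_selection all reduce to
per-train boolean masks. A standalone numpy sketch of how they coincide for a
train with five rows:

    import numpy as np

    count = 5  # rows of one train in this index group

    by_slice = np.zeros(count, dtype=bool)
    by_slice[np.s_[::2]] = True    # slice: keeps rows 0, 2 and 4

    by_index = np.zeros(count, dtype=bool)
    by_index[[0, 2, 4]] = True     # integer rows: every index must be < count

    # boolean mask: its length must equal the data count of the train
    by_mask = np.array([True, False, True, False, True])

    assert (by_slice == by_index).all() and (by_slice == by_mask).all()
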
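The work is split between the two SourceDataWriter hooks: mask_instrument_data
decides which rows survive (one boolean mask per train, or None to keep
everything), while copy_instrument_data only sees data and train ID
coordinates already reduced to those rows, with a destination dataset of
matching shape. A minimal illustrative subclass, assuming SourceDataWriter is
importable from exdf.write.sd_writer as in this patch and leaving out
everything needed to actually drive the writer:

    import numpy as np

    from exdf.write.sd_writer import SourceDataWriter

    class EveryOtherRowWriter(SourceDataWriter):
        """Illustrative writer that keeps every second row of each train."""

        def mask_instrument_data(self, source, index_group, train_ids, counts):
            # One boolean mask per train, each as long as that train's count.
            return [np.arange(count) % 2 == 0 for count in counts]

        def copy_instrument_data(self, source, key, dest, train_ids, data):
            # data and train_ids already contain only the masked rows, and
            # dest was created with the matching number of entries.
            dest[:] = data
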
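When rows are dropped, mask_index preserves the original first/count values
(plus the kept row positions) under original/ and recomputes first/count from
the per-train mask sums. Assuming get_pulse_offsets returns the exclusive
cumulative sum of the counts, i.e. the offset of each train's first entry, the
arithmetic looks like this:

    import numpy as np

    def pulse_offsets(counts):
        # Assumed equivalent of get_pulse_offsets(): offset of each
        # train's first entry, the exclusive cumulative sum of counts.
        offsets = np.zeros_like(counts)
        offsets[1:] = np.cumsum(counts[:-1])
        return offsets

    counts = np.array([3, 0, 2])  # rows per train before masking
    masks_by_train = [np.array([True, False, True]),
                      np.array([], dtype=bool),
                      np.array([False, True])]

    new_counts = np.array([m.sum() for m in masks_by_train])

    print(pulse_offsets(counts))      # [0 3 3] -> original/first
    print(counts)                     # [3 0 2] -> original/count
    print(np.concatenate([np.flatnonzero(m) for m in masks_by_train]))
                                      # [0 2 1] -> original/position
    print(pulse_offsets(new_counts))  # [0 2 2] -> new first
    print(new_counts)                 # [2 0 1] -> new count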