diff --git a/src/exdf/data_reduction/method.py b/src/exdf/data_reduction/method.py index d106ed4c56b16f21728e69439269e26ec2a4d5e7..f5ce70010e1aff4895a254895a224cc967f4e8c0 100644 --- a/src/exdf/data_reduction/method.py +++ b/src/exdf/data_reduction/method.py @@ -3,11 +3,13 @@ from typing import TypeVar from logging import getLogger import warnings +import numpy as np from extra_data.read_machinery import select_train_ids log = getLogger('exdf.data_reduction.ReductionMethod') train_sel = TypeVar('train_sel') +entry_sel = TypeVar('entry_sel') index_exp = TypeVar('index_exp') @@ -22,6 +24,22 @@ def is_train_selection(x): return True +def is_entry_selection(x): + if isinstance(x, slice): + return True + + if isinstance(x, list) and all([isinstance(y, (int, bool)) for y in x]): + return True + + if ( + isinstance(x, np.ndarray) and x.ndim == 1 and + (np.issubdtype(x.dtype, np.integer) or np.issubdtype(x.dtype, bool)) + ): + return True + + return False + + def is_index_expression(x): if isinstance(x, (slice, list)): return True @@ -54,19 +72,42 @@ class ReductionMethod(list): assert is_train_selection(trains) self._emit('select-trains', source_glob, trains) - def select_pulses( + def select_entries( self, source_glob: str, - index_group: str, # May be xtdf + index_group: str, trains: train_sel, - rows: index_exp + entries: entry_sel ): - raise NotImplementedError('select-pulses') assert isinstance(source_glob, str) assert isinstance(index_group, str) assert is_train_selection(trains) - assert is_index_expression(rows) - self._emit('select-pulses', source_glob, index_group, trains, rows) + assert is_entry_selection(entries) + self._emit('select-entries', source_glob, index_group, trains, entries) + + def select_xtdf( + self, + source_glob: str, + trains: train_sel, + entries: entry_sel + ): + """Slice XTDF data by entry. + + Roughly equivalent to select_entries(source_glob, 'image', + trains, entries), but only acts on XTDF sources and modifies + header data structures according to slicing. + + Requires sources to end with :xtdf and have all XTDF keys. + + Args: + source_glob (str): Source glob pattern. + trains (train_sel): Train selection. + entries (entry_sel): Entry selection. + """ + assert isinstance(source_glob, str) + assert is_train_selection(trains) + assert is_entry_selection(entries) + self._emit('select-xtdf', source_glob, trains, entries) def remove_sources( self, diff --git a/src/exdf/data_reduction/writer.py b/src/exdf/data_reduction/writer.py index a02509297d77981c8f107fbee704876bb000e998..27a50fff17572dd0a3d551dbd5a649e00bad2950 100644 --- a/src/exdf/data_reduction/writer.py +++ b/src/exdf/data_reduction/writer.py @@ -52,6 +52,9 @@ class ReduceWriter(SourceDataWriter): # Only populated if trains/keys are selected/removed for sources. 
self._custom_keys = {} # source -> set(<keys>) self._custom_trains = {} # source -> list(<trains>) + self._custom_xtdf_masks = {} # source -> dict(train_id -> mask) + self._custom_xtdf_counts = {} # source -> ndarray + self._custom_entry_masks = {} # source -> dict(train_id -> mask) self._rechunked_keys = {} # (source, key) -> chunks self._partial_copies = {} # (source, key) -> list(<regions>) @@ -83,6 +86,46 @@ class ReduceWriter(SourceDataWriter): self._custom_trains[source] = select_train_ids( train_ids, train_sel) + for source_glob, index_group, train_sel, entry_sel in self._filter_ops( + 'select-entries' + ): + for source in fnmatch.filter(self._sources, source_glob): + if index_group not in self._data[source].index_groups: + raise ValueError(f'{index_group} not index group of ' + f'{source}') + + new_mask = self._get_entry_masks( + source, index_group, train_sel, entry_sel) + + self._touched_sources.add(source) + self._custom_entry_masks.setdefault( + (source, index_group), {}).update(new_mask) + + for source_glob, train_sel, entry_sel in self._filter_ops( + 'select-xtdf' + ): + for source in fnmatch.filter(self._sources, source_glob): + if not source.endswith(':xtdf'): + # Simply ignore matches without trailing :xtdf. + continue + + if not self._is_xtdf_source(source): + # Raise exception if essentials are missing. + raise ValueError(f'{source} is not a valid XTDF source') + + new_mask = self._get_entry_masks( + source, 'image', train_sel, entry_sel) + + self._touched_sources.add(source) + self._custom_xtdf_masks.setdefault(source, {}).update(new_mask) + + if ( + {x[0] for x in self._custom_entry_masks.keys()} & + self._custom_xtdf_masks.keys() + ): + raise ValueError('source may not be affected by both ' + 'select-entries and select-xtdf operations') + for source_glob, key_glob, chunking in self._filter_ops( 'rechunk-keys' ): @@ -139,6 +182,53 @@ class ReduceWriter(SourceDataWriter): def _filter_ops(self, op): return [args[1:] for args in self._ops if args[0] == op] + def _is_xtdf_source(self, source): + return self._data[source].keys() > {'header.pulseCount', 'image.data'} + + def _get_entry_masks(self, source, index_group, train_sel, entry_sel): + train_ids = select_train_ids( + self._custom_trains.get(source, list(self._data.train_ids)), + train_sel) + counts = self._data[source].select_trains(by_id[train_ids]) \ + .data_counts(index_group=index_group) + masks = {} + + if isinstance(entry_sel, slice): + for train_id, count in counts.items(): + if count > 0: + masks[train_id] = np.zeros(count, dtype=bool) + masks[train_id][entry_sel] = True + + elif np.issubdtype(type(entry_sel[0]), np.integer): + max_entry = max(entry_sel) + + for train_id, count in counts.items(): + if count == 0: + continue + elif max_entry >= count: + raise ValueError( + f'entry index exceeds data counts of train {train_id}') + + masks[train_id] = np.zeros(count, dtype=bool) + masks[train_id][entry_sel] = True + + elif np.issubdtype(type(entry_sel[0]), bool): + mask_len = len(entry_sel) + + for train_id, count in counts.items(): + if count == 0: + continue + elif mask_len != counts.get(train_id, 0): + raise ValueError( + f'mask length mismatch for train {train_id}') + + masks[train_id] = entry_sel + + else: + raise ValueError('unknown entry mask format') + + return masks + def write_collection(self, output_path): outp_data = self._data.select([(s, '*') for s in self._sources]) @@ -266,7 +356,43 @@ class ReduceWriter(SourceDataWriter): except KeyError: return orig_chunks - def copy_instrument_data(self, 
source, key, dest, data): + def mask_instrument_data(self, source, index_group, train_ids, counts): + if source in self._custom_xtdf_masks and index_group == 'image': + custom_masks = self._custom_xtdf_masks[source] + elif (source, index_group) in self._custom_entry_masks: + custom_masks = self._custom_entry_masks[source, index_group] + else: + return # None efficiently selects all entries. + + masks = [] + + for train_id, count_all in zip(train_ids, counts): + if train_id in custom_masks: + mask = custom_masks[train_id] + else: + mask = np.ones(count_all, dtype=bool) + + masks.append(mask) + + if source in self._custom_xtdf_masks: + # Sources are guaranteed to never use both XTDF and general + # entry slicing. In the XTDF case, the new data counts for + # the image index group must be determined to be filled into + # the respective header field. + + self._custom_xtdf_counts[source] = { + train_id: mask.sum() for train_id, mask + in zip(train_ids, masks) if mask.any()} + + return masks + + def copy_instrument_data(self, source, key, dest, train_ids, data): + if source in self._custom_xtdf_counts and key == 'header.pulseCount': + custom_counts = self._custom_xtdf_counts[source] + + for i, train_id in enumerate(train_ids): + data[i] = custom_counts.get(train_id, data[i]) + try: regions = self._partial_copies[source, key] except KeyError: diff --git a/src/exdf/write/sd_writer.py b/src/exdf/write/sd_writer.py index 46489ae66c4cb740c8dcdcdc691519ca5966dca1..0d04245890443c727a7018188b1d1fd6854447b8 100644 --- a/src/exdf/write/sd_writer.py +++ b/src/exdf/write/sd_writer.py @@ -17,7 +17,7 @@ from operator import or_ import numpy as np from extra_data import FileAccess -from . import DataFile +from .datafile import DataFile, get_pulse_offsets log = getLogger('exdf.write.SourceDataWriter') @@ -50,13 +50,37 @@ class SourceDataWriter: return orig_chunks - def copy_instrument_data(self, source, key, dest, data): + def mask_instrument_data(self, source, index_group, train_ids, counts): + """Mask INSTRUMENT data. + + Each mask array must have the same length as the original data + counts. + + Args: + source (str): Source name. + index_group (str): Index group. + train_ids (ndarray): Train IDs. + counts (ndarray): Data counts per train. + + Returns: + (Iterable of ndarray): Boolean masks for each passed + train ID and equal in length to respective data counts, + or None to perform no masking. + """ + + return + + def copy_instrument_data(self, source, key, dest, train_ids, data): """Copy INSTRUMENT data from input to output. + The destination dataset is guaranteed to align with the shape of + train_ids and data. + Args: source (str): Source name. key (str): Key name. dset (h5py.Dataset): Destination dataset. + train_ids (ndarray): Train ID coordinates. data (ndarray): Source data. 
Returns: @@ -110,9 +134,9 @@ class SourceDataWriter: data_format_version=self.get_data_format_version(), control_sources=control_indices.keys(), instrument_channels=[ - f'{source}/{channel}' - for source, channels in instrument_indices.items() - for channel in channels.keys()]) + f'{source}/{index_group}' + for source, index_group_counts in instrument_indices.items() + for index_group in index_group_counts.keys()]) f.create_dataset('METADATA/dataWriter', data=b'exdf-tools', shape=(1,)) if not self.with_origin(): @@ -124,9 +148,10 @@ class SourceDataWriter: control_src = f.create_control_source(source) control_src.create_index(len(train_ids), per_train=True) - for source, channel_counts in instrument_indices.items(): + for source, index_group_counts in instrument_indices.items(): + # May be overwritten later as a result of masking. instrument_src = f.create_instrument_source(source) - instrument_src.create_index(**channel_counts) + instrument_src.create_index(**index_group_counts) def write_control(self, f, sources): """Write CONTROL and RUN data. @@ -143,7 +168,7 @@ class SourceDataWriter: """ for sd in sources: - source = f.source[sd.source] + h5source = f.source[sd.source] attrs = get_key_attributes(sd) if self.with_attrs() else {} run_data_leafs = {} @@ -164,23 +189,24 @@ class SourceDataWriter: ctrl_values = sd[f'{key}.value'].ndarray() ctrl_timestamps = sd[f'{key}.timestamp'].ndarray() - source.create_key( + h5source.create_key( key, values=ctrl_values, timestamps=ctrl_timestamps, run_entry=run_entry, attrs=attrs.pop(key, None)) # Write remaining RUN-only keys. for key, leafs in run_data_leafs.items(): - source.create_run_key(key, **leafs, attrs=attrs.pop(key, None)) + h5source.create_run_key( + key, **leafs, attrs=attrs.pop(key, None)) # Fill in the missing attributes for nodes. for path, attrs in attrs.items(): - source.run_key[path].attrs.update(attrs) - source.key[path].attrs.update(attrs) + h5source.run_key[path].attrs.update(attrs) + h5source.key[path].attrs.update(attrs) def write_instrument(self, f, sources): """Write INSTRUMENT data. - This method assumes the source datasets already exist. + This method assumes the INDEX and source datasets already exist. Args: f (exdf.DataFile): Output file. @@ -191,40 +217,67 @@ class SourceDataWriter: None """ - for sd in sources: - source = f.source[sd.source] + # Must be re-read at this point, as additional trains could have + # been introduced in this sequence. + train_ids = np.array(f['INDEX/trainId']) + # Stores mask for each entry per index group. + masks = {} + + for sd in sources: attrs = get_key_attributes(sd) if self.with_attrs() else {} + h5source = f.source[sd.source] + keys = sd.keys() - for key in sd.keys(): - kd = sd[key] + for index_group in sd.index_groups: + # Must be re-read same as train IDs. + h5index = f[f'INDEX/{sd.source}/{index_group}'] + counts = np.array(h5index['count']) - shape = (kd.data_counts(labelled=False).sum(), *kd.entry_shape) - chunks = self.chunk_instrument_data( - sd.source, key, - kd.files[0].file[kd.hdf5_data_path].chunks) + # Obtain mask for this index group. 
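+                # mask_instrument_data() returns None to keep every entry,
+                # or one boolean mask per train, each matching that train's
+                # original data count.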
+ masks_by_train = self.mask_instrument_data( + sd.source, index_group, train_ids, counts) - source.create_key( - key, shape=shape, maxshape=(None,) + shape[1:], - chunks=chunks, dtype=kd.dtype, attrs=attrs.pop(key, None)) + if masks_by_train is not None: + masks[index_group] = mask_index( + h5index, counts, masks_by_train) + num_entries = masks[index_group].sum() + else: + num_entries = counts.sum() - for path, attrs in attrs.items(): - source.key[path].attrs.update(attrs) + for key in iter_index_group_keys(keys, index_group): + kd = sd[key] - # Update tableSize for each index group to the correct - # number of trains. - for index_group in sd.index_groups: - source[index_group].attrs['tableSize'] = sd.data_counts( - labelled=False, index_group=index_group).sum() + shape = (num_entries, *kd.entry_shape) + chunks = self.chunk_instrument_data( + sd.source, key, + kd.files[0].file[kd.hdf5_data_path].chunks) + + h5source.create_key( + key, shape=shape, maxshape=(None,) + shape[1:], + chunks=chunks, dtype=kd.dtype, + attrs=attrs.pop(key, None)) + + # Update tableSize to the correct number of records. + h5source[index_group].attrs['tableSize'] = num_entries + + for path, attrs in attrs.items(): + h5source.key[path].attrs.update(attrs) # Copy INSTRUMENT data. for sd in sources: - source = f.source[sd.source] + h5source = f.source[sd.source] + + for index_group in sd.index_groups: + mask = masks.get(index_group, np.s_[:]) + + for key in iter_index_group_keys(keys, index_group): + # TODO: Copy by chunk / file if too large - for key in sd.keys(): - # TODO: Copy by chunk / file if too large - self.copy_instrument_data(sd.source, key, source.key[key], - sd[key].ndarray()) + self.copy_instrument_data( + sd.source, key, h5source.key[key], + sd[key].train_id_coordinates()[mask], + sd[key].ndarray()[mask]) def get_index_root_data(sources): @@ -375,3 +428,34 @@ def get_key_attributes(sd): f'{sd.source}.{path}') return source_attrs + +def iter_index_group_keys(keys, index_group): + for key in keys: + if key[:key.index('.')] == index_group: + yield key + + +def mask_index(g, counts, masks_by_train): + full_mask = np.concatenate(masks_by_train) + num_entries = counts.sum() + + assert len(full_mask) == num_entries, \ + 'incompatible INSTRUMENT mask shape' + + # Modify INDEX entry if necessary. + if full_mask.sum() != num_entries: + g.create_dataset( + f'original/first', data=get_pulse_offsets(counts)) + g.create_dataset( + f'original/count', data=counts) + g.create_dataset( + f'original/position', + data=np.concatenate([np.flatnonzero(mask) + for mask in masks_by_train])) + + # Compute new data counts. + counts = [mask.sum() for mask in masks_by_train] + g['first'][:] = get_pulse_offsets(counts) + g['count'][:] = counts + + return full_mask
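
Usage sketch (illustrative, not part of the patch): how the new select_entries and select_xtdf operations might be recorded by a reduction method. This assumes exdf is importable, that ReductionMethod can be constructed without arguments (it subclasses list), and that the source patterns below are placeholders for whatever the run actually contains.

    import numpy as np
    from exdf.data_reduction.method import ReductionMethod

    method = ReductionMethod()

    # Keep every other image entry of all matching XTDF detector modules,
    # for all trains. Placeholder source pattern.
    method.select_xtdf('SPB_DET_AGIPD1M-1/DET/*CH0:xtdf', np.s_[:], np.s_[::2])

    # Keep only the first 10 entries per train in the 'data' index group
    # of a generic instrument source. Placeholder source name.
    method.select_entries('SA1_XTD2_XGM/XGM/DOOCS:output', 'data',
                          np.s_[:], np.s_[:10])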
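
For reference, a minimal standalone sketch of the per-train masking semantics implemented by _get_entry_masks and consumed by mask_instrument_data, using only numpy. The function name demo_entry_masks and the example counts are made up for illustration; it mirrors the intent of the patched logic rather than reproducing it verbatim.

    import numpy as np

    def demo_entry_masks(counts_by_train, entry_sel):
        # Build one boolean mask per train, each as long as that train's
        # entry count, from a slice, an integer list or a boolean array.
        masks = {}

        if isinstance(entry_sel, slice):
            for train_id, count in counts_by_train.items():
                if count > 0:
                    mask = np.zeros(count, dtype=bool)
                    mask[entry_sel] = True
                    masks[train_id] = mask

        elif np.issubdtype(np.asarray(entry_sel).dtype, np.integer):
            for train_id, count in counts_by_train.items():
                if count == 0:
                    continue
                if max(entry_sel) >= count:
                    raise ValueError(
                        f'entry index exceeds data counts of train {train_id}')
                mask = np.zeros(count, dtype=bool)
                mask[entry_sel] = True
                masks[train_id] = mask

        else:  # Boolean mask, applied verbatim to every non-empty train.
            entry_sel = np.asarray(entry_sel, dtype=bool)
            for train_id, count in counts_by_train.items():
                if count == 0:
                    continue
                if len(entry_sel) != count:
                    raise ValueError(
                        f'mask length mismatch for train {train_id}')
                masks[train_id] = entry_sel

        return masks

    # Hypothetical pulse counts for three trains.
    counts = {10001: 5, 10002: 0, 10003: 5}
    print(demo_entry_masks(counts, np.s_[1::2]))             # every other entry
    print(demo_entry_masks(counts, [0, 4]))                  # explicit indices
    print(demo_entry_masks(counts, np.ones(5, dtype=bool)))  # keep everything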