diff --git a/setup.py b/setup.py
index 09e8fc08cff0396620dd7447a4bc9cf97d2905a3..faaca82107c624e6300097aa66a861f17efda9a1 100644
--- a/setup.py
+++ b/setup.py
@@ -51,6 +51,7 @@ setup(
     python_requires='>=3.8',
     install_requires=[
         'extra_data>=1.13',
+        'pasha',
 
         # These are pulled in by EXtra-data but listed here for
         # completeness until they may require pinning.
diff --git a/src/exdf/data_reduction/red_writer.py b/src/exdf/data_reduction/red_writer.py
index 9817be904dd45e71268a1fdc2b8d46f63333cda6..64ed7c4a2951bfff1c4ef75de1cc6c0306251b6f 100644
--- a/src/exdf/data_reduction/red_writer.py
+++ b/src/exdf/data_reduction/red_writer.py
@@ -14,6 +14,27 @@ from exdf.write import SourceDataWriter
 from ..write.datafile import write_compressed_frames
 
 
+# Patch SourceData to pick an index group's sample key directly from file.
+import h5py
+from extra_data.sourcedata import SourceData
+
+def _SourceData_get_index_group_sample(self, index_group):
+    if self.is_control and not index_group:
+        # Shortcut for CONTROL data.
+        return self.one_key()
+
+    def get_key(key, value):
+        if isinstance(value, h5py.Dataset):
+            return index_group + '.' + key.replace('/', '.')
+
+    group = f'/INSTRUMENT/{self.source}/{index_group}'
+
+    for f in self.files:
+        return f.file[group].visititems(get_key)
+
+SourceData._get_index_group_sample = _SourceData_get_index_group_sample
+
+
 def apply_by_source(op_name):
     def op_decorator(op_func):
         def op_handler(self):
@@ -145,9 +166,6 @@ class ReduceWriter(SourceDataWriter):
     def _filter_ops(self, op):
         return [args[1:] for args in self._ops if args[0] == op]
 
-    def _is_xtdf_source(self, source):
-        return self._data[source].keys() > {'header.pulseCount', 'image.data'}
-
     def _get_entry_masks(self, source, index_group, train_sel, entry_sel):
         train_ids = select_train_ids(
             self._custom_trains.get(source, list(self._data.train_ids)),
@@ -423,11 +441,6 @@ class ReduceWriter(SourceDataWriter):
                 f'Ignoring non-XTDF source {source} based on name')
             return
 
-        if not self._is_xtdf_source(source):
-            self.log.warning(
-                f'Ignoring non-XTDF source {source} based on structure')
-            return
-
         self._custom_xtdf_masks.setdefault(source, {}).update(
             self._get_entry_masks(source, 'image', train_sel, entry_sel))
         self.log.debug(f'Applying XTDF selection to {source}')
diff --git a/src/exdf/write/sd_writer.py b/src/exdf/write/sd_writer.py
index 9e05ac7f55349b19e1f502e2ab964cc95acf396f..ba8067f576762c5a5bfea4d52b0ffedb8aaa4080 100644
--- a/src/exdf/write/sd_writer.py
+++ b/src/exdf/write/sd_writer.py
@@ -18,11 +18,13 @@ from time import perf_counter
 
 import numpy as np
 
+import pasha as psh
 from extra_data import FileAccess
 
 from .datafile import DataFile, get_pulse_offsets
 
 log = getLogger('exdf.write.SourceDataWriter')
+psh.set_default_context('processes', num_workers=24)
 
 
 class SourceDataWriter:
@@ -297,10 +299,11 @@ class SourceDataWriter:
 
         for key in iter_index_group_keys(keys, index_group):
             # TODO: Copy by chunk / file if too large
+            kd = sd[key]
 
             start_key = perf_counter()
-            full_data = sd[key].ndarray()
+            full_data = read_keydata(kd)
             after_read = perf_counter()
 
             masked_data = full_data[mask]
 
@@ -308,7 +311,7 @@ class SourceDataWriter:
 
             self.copy_instrument_data(
                 sd.source, key, h5source.key[key],
-                sd[key].train_id_coordinates()[mask],
+                kd.train_id_coordinates()[mask],
                 masked_data)
 
             after_copy = perf_counter()
@@ -508,3 +511,24 @@ def mask_index(g, counts, masks_by_train):
         g['count'][:] = counts
 
     return full_mask
+
+
+def read_keydata(kd):
+    if kd.nbytes > 1073741824:  # > 1 GiB
+        # Use parallelization for GiB-sized datasets.
+        data = psh.alloc(shape=kd.shape, dtype=kd.dtype)
+
+        counts = kd.data_counts(labelled=False)
+        entry_starts = np.zeros_like(counts)
+        entry_starts[1:] = np.cumsum(counts[:-1])
+        entry_ends = entry_starts + counts
+
+        def read_data(worker_id, index, train_id, entries):
+            data[entry_starts[index]:entry_ends[index]] = entries
+
+        psh.map(read_data, kd)
+        return data
+
+    else:
+        # Simple read for small datasets.
+        return kd.ndarray()
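
Note on the new pasha usage: read_keydata builds on pasha's shared-memory map, where psh.alloc creates an array that all worker processes can write into and psh.map runs a kernel over an input collection. A minimal standalone sketch of that pattern over a plain NumPy array (the input array and squaring kernel are purely illustrative, not part of this change):

    import numpy as np
    import pasha as psh

    # Same worker model as in sd_writer.py, fewer workers for the example.
    psh.set_default_context('processes', num_workers=4)

    inp = np.random.rand(100)
    out = psh.alloc(shape=inp.shape, dtype=inp.dtype)

    def square(worker_id, index, value):
        # Each worker writes its results into the shared output array.
        out[index] = value ** 2

    psh.map(square, inp)

When mapping over an extra_data KeyData instead, as read_keydata does above, the kernel receives (worker_id, index, train_id, entries) with one call per train; since a train may hold several entries, the entry_starts/entry_ends offsets are precomputed from data_counts() so each call can place its train's block in the shared array.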