def read_keydata(kd, *, parallel_threshold=1 << 30):
    """Read all entries of a single key into one array.

    Keys whose ``kd.nbytes`` exceeds ``parallel_threshold`` are read
    through ``psh.map`` into a buffer obtained from ``psh.alloc``; all
    other keys are read directly with ``kd.ndarray()``.

    Args:
        kd: Key data object providing ``nbytes``, ``shape``, ``dtype``,
            ``data_counts()`` and ``ndarray()`` (presumably an
            ``extra_data.KeyData`` -- confirm against the caller).
        parallel_threshold: Size in bytes above which the parallel read
            is used. Defaults to 1 GiB (``1 << 30``), the value that
            was previously hard-coded in the function body.

    Returns:
        The result of ``kd.ndarray()`` for keys at or below the
        threshold, otherwise the ``psh.alloc`` buffer created with
        ``kd.shape`` and ``kd.dtype`` after it has been filled.
    """
    if kd.nbytes <= parallel_threshold:
        # Simple read for small datasets.
        return kd.ndarray()

    # Use parallelization for large (by default GiB-sized) datasets.
    data = psh.alloc(shape=kd.shape, dtype=kd.dtype)

    # Half-open entry range [start, end) for every position in counts:
    # the end is the running total, the start is the end minus the
    # number of entries at that position.
    counts = kd.data_counts(labelled=False)
    entry_ends = np.cumsum(counts)
    entry_starts = entry_ends - counts

    # NOTE(review): this relies on the ``index`` passed by psh.map being
    # the same position as in ``counts``. Confirm this still holds when
    # some trains have no entries for this key.
    def read_data(worker_id, index, train_id, entries):
        data[entry_starts[index]:entry_ends[index]] = entries

    psh.map(read_data, kd)
    return data