From 77ab74ff7136acd93d3c063eb2799644081f6061 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Thu, 29 Feb 2024 15:24:37 +0100
Subject: [PATCH] Parallelize large dataset reads via pasha

---
 setup.py                    |  1 +
 src/exdf/write/sd_writer.py | 28 ++++++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 09e8fc0..faaca82 100644
--- a/setup.py
+++ b/setup.py
@@ -51,6 +51,7 @@ setup(
     python_requires='>=3.8',
     install_requires=[
         'extra_data>=1.13',
+        'pasha',

         # These are pulled in by EXtra-data but listed here for
         # completeness until they may require pinning.
diff --git a/src/exdf/write/sd_writer.py b/src/exdf/write/sd_writer.py
index 9e05ac7..ba8067f 100644
--- a/src/exdf/write/sd_writer.py
+++ b/src/exdf/write/sd_writer.py
@@ -18,11 +18,13 @@ from time import perf_counter

 import numpy as np
+import pasha as psh
 from extra_data import FileAccess

 from .datafile import DataFile, get_pulse_offsets

 log = getLogger('exdf.write.SourceDataWriter')
+psh.set_default_context('processes', num_workers=24)


 class SourceDataWriter:
@@ -297,10 +299,11 @@ class SourceDataWriter:
             for key in iter_index_group_keys(keys, index_group):
                 # TODO: Copy by chunk / file if too large
+                kd = sd[key]

                 start_key = perf_counter()
-                full_data = sd[key].ndarray()
+                full_data = read_keydata(kd)
                 after_read = perf_counter()

                 masked_data = full_data[mask]
@@ -308,7 +311,7 @@ class SourceDataWriter:

                 self.copy_instrument_data(
                     sd.source, key, h5source.key[key],
-                    sd[key].train_id_coordinates()[mask],
+                    kd.train_id_coordinates()[mask],
                     masked_data)

                 after_copy = perf_counter()
@@ -508,3 +511,24 @@ def mask_index(g, counts, masks_by_train):
         g['count'][:] = counts

     return full_mask
+
+
+def read_keydata(kd):
+    if kd.nbytes > 1073741824:
+        data = psh.alloc(shape=kd.shape, dtype=kd.dtype)
+
+        counts = kd.data_counts(labelled=False)
+        entry_starts = np.zeros_like(counts)
+        entry_starts[1:] = np.cumsum(counts[:-1])
+        entry_ends = entry_starts + counts
+
+        # Use parallelization for GiB-sized datasets.
+        def read_data(worker_id, index, train_id, entries):
+            data[entry_starts[index]:entry_ends[index]] = entries
+
+        psh.map(read_data, kd)
+        return data
+
+    else:
+        # Simple read for small datasets.
+        return kd.ndarray()
-- 
GitLab
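
The read_keydata() helper added above follows pasha's shared-memory map
pattern: psh.alloc() creates an array that all worker processes can write
into, and psh.map() invokes a kernel once per element of the mapped
collection, so each worker fills only its own rows and no locking is
needed. The four-argument kernel signature in the patch is the form pasha
uses when mapping over EXtra-data key data; below is a minimal sketch of
the same pattern over a plain array, using pasha's generic
(worker_id, index, value) kernel signature. Datasets at or below the
1073741824-byte (1 GiB) threshold skip all of this and fall through to a
plain kd.ndarray() read, where spinning up the worker pool would likely
cost more than it saves.

    import numpy as np
    import pasha as psh

    # Same kind of worker setup as in the patch, with a smaller pool.
    psh.set_default_context('processes', num_workers=4)

    inputs = np.arange(16)

    # Shared output array, writable from every worker process.
    outputs = psh.alloc(shape=(16,), dtype=np.float64)

    def kernel(worker_id, index, value):
        # Each call owns exactly one slot, so writes never collide.
        outputs[index] = value ** 2

    psh.map(kernel, inputs)
    assert (outputs == inputs ** 2).all()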
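
The entry_starts/entry_ends bookkeeping in read_keydata() turns per-train
entry counts into slice boundaries in the flat output array: an exclusive
cumulative sum gives each train's start offset, and adding the counts
gives the end offsets, so every worker knows which rows belong to its
trains. A small worked example with made-up counts:

    import numpy as np

    counts = np.array([3, 0, 2, 4])   # entries recorded per train
    entry_starts = np.zeros_like(counts)
    entry_starts[1:] = np.cumsum(counts[:-1])
    entry_ends = entry_starts + counts

    # entry_starts == [0, 3, 3, 5], entry_ends == [3, 3, 5, 9]:
    # train 0 fills rows 0:3, train 1 has no data, train 2 fills
    # rows 3:5, and train 3 fills rows 5:9.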