Commit f1ad47d1 authored by Philipp Schmidt

Merge branch 'feat/compress-keys' into 'master'

Support additional and maintain existing data compression

See merge request !3
Parents: c2713e1b b823aa91
@@ -232,7 +232,7 @@ class AgipdGain(ReductionMethod):
             return

         self.rechunk_keys(agipd_sources, 'image.data', (-1, 1, None, None))
-        self.partial_copy(agipd_sources, 'image.data', np.s_[0, :, :])
+        self.subslice_keys(agipd_sources, 'image.data', np.s_[0, :, :])


 class LpdMini(ReductionMethod):
@@ -263,5 +263,5 @@ class LpdMini(ReductionMethod):
         self.rechunk_keys(lpdmini_sources, 'image.data', (-1, 32, None))

         for mini_pos in args.lpd_mini_select_slots:
-            self.partial_copy(lpdmini_sources, 'image.data',
+            self.subslice_keys(lpdmini_sources, 'image.data',
                               np.s_[(mini_pos-1)*32:mini_pos*32, :, :])
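
Note: the region arguments above are plain numpy index expressions. A minimal standalone sketch (array shape is invented) of what np.s_ builds and how such a region is applied per entry:

    import numpy as np

    # np.s_ just builds the tuple of indices/slices that would appear
    # inside square brackets.
    region = np.s_[0:32, :, :]
    print(region)  # (slice(0, 32, None), slice(None, None, None), ...)

    # Applied to a hypothetical (entries, slots*32, slow, fast) stack,
    # prepending a full slice keeps every entry and subslices the rest.
    data = np.zeros((10, 64, 32, 256))
    sel = (np.s_[:], *region)
    print(data[sel].shape)  # (10, 32, 32, 256)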
@@ -8,6 +8,7 @@ from extra_data.read_machinery import select_train_ids

 log = getLogger('exdf.data_reduction.ReductionMethod')

 train_sel = TypeVar('train_sel')
 entry_sel = TypeVar('entry_sel')
 index_exp = TypeVar('index_exp')
@@ -137,7 +138,7 @@ class ReductionMethod(list):
         assert all([x is None or isinstance(x, int) for x in chunks])
         self._emit('rechunk-keys', source_glob, key_glob, chunks)

-    def partial_copy(
+    def subslice_keys(
         self,
         source_glob: str,
         key_glob: str,
@@ -146,7 +147,18 @@ class ReductionMethod(list):
         assert isinstance(source_glob, str)
         assert isinstance(key_glob, str)
         assert is_index_expression(region)
-        self._emit('partial-copy', source_glob, key_glob, region)
+        self._emit('subslice-keys', source_glob, key_glob, region)
+
+    def compress_keys(
+        self,
+        source_glob: str,
+        key_glob: str,
+        level: int = 1
+    ):
+        assert isinstance(source_glob, str)
+        assert isinstance(key_glob, str)
+        assert isinstance(level, int)
+        self._emit('compress-keys', source_glob, key_glob, level)

     @staticmethod
     def arguments(ap):
...
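
A hedged sketch of how a reduction method could emit the new operation (the class name, constructor signature and source glob below are assumptions for illustration, modelled on the built-in methods above):

    # Hypothetical reduction method, not part of this merge request.
    class CompressRaw(ReductionMethod):
        def __init__(self, data, args):
            super().__init__()

            # Emit a compress-keys operation for matching keys; such keys
            # may not also be targeted by a rechunk-keys operation (see the
            # check added to ReduceWriter further down).
            self.compress_keys('*/DET/*CH0:xtdf', 'image.data', level=1)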
@@ -11,6 +11,7 @@ from extra_data import by_id
 from extra_data.read_machinery import select_train_ids

 from exdf.write import SourceDataWriter
+from ..write.datafile import write_compressed_frames


 class ReduceWriter(SourceDataWriter):
@@ -56,7 +57,8 @@ class ReduceWriter(SourceDataWriter):
         self._custom_xtdf_counts = {}  # source -> ndarray
         self._custom_entry_masks = {}  # source -> dict(train_id -> mask)
         self._rechunked_keys = {}  # (source, key) -> chunks
-        self._partial_copies = {}  # (source, key) -> list(<regions>)
+        self._subsliced_keys = {}  # (source, key) -> list(<regions>)
+        self._compressed_keys = {}  # (source, key) -> level

         # TODO: Raise error if rechunking is overwritten!
         # TODO: make partial copies a list of slices!
@@ -152,7 +154,7 @@ class ReduceWriter(SourceDataWriter):
                     self._rechunked_keys[(source, key)] = chunking

-        for source_glob, key_glob, region in self._filter_ops('partial-copy'):
+        for source_glob, key_glob, region in self._filter_ops('subslice-keys'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)
@@ -160,9 +162,23 @@ class ReduceWriter(SourceDataWriter):
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
-                    self._partial_copies.setdefault((source, key), []).append(
+                    self._subsliced_keys.setdefault((source, key), []).append(
                        region)

+        for source_glob, key_glob, level in self._filter_ops('compress-keys'):
+            for source in fnmatch.filter(self._sources, source_glob):
+                self._touched_sources.add(source)
+                keys = self._custom_keys.get(
+                    source, set(self._data[source].keys()))
+
+                for key in fnmatch.filter(keys, key_glob):
+                    self._compressed_keys[source, key] = level
+
+        if (self._rechunked_keys.keys() & self._compressed_keys.keys()):
+            raise ValueError('keys may not be affected by both compress-keys '
+                             'and rechunk-keys operations')
+
         if self._scope == 'sources':
             self._sources = sorted(
                 self._touched_sources.intersection(self._sources))
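
The new conflict check leans on dict key views being set-like; a tiny standalone illustration of the same test:

    rechunked_keys = {('SRC/DET', 'image.data'): (-1, 1, None, None)}
    compressed_keys = {('SRC/DET', 'image.data'): 1,
                       ('SRC/DET', 'image.mask'): 1}

    # dict.keys() returns a set-like view, so & yields the (source, key)
    # pairs targeted by both operations.
    conflicts = rechunked_keys.keys() & compressed_keys.keys()
    if conflicts:
        raise ValueError('keys may not be affected by both compress-keys '
                         'and rechunk-keys operations')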
@@ -339,8 +355,13 @@ class ReduceWriter(SourceDataWriter):
     def with_attrs(self):
         return self._version >= Version('1.3')

-    def chunk_instrument_data(self, source, key, orig_chunks):
-        try:
+    def create_instrument_key(self, source, key, orig_dset, kwargs):
+        # Keys are guaranteed to never use both custom chunking and
+        # compression.
+        if (source, key) in self._rechunked_keys:
+            orig_chunks = kwargs['chunks']
             chunks = list(self._rechunked_keys[source, key])
             assert len(chunks) == len(orig_chunks)
@@ -352,9 +373,18 @@ class ReduceWriter(SourceDataWriter):
                 chunks[chunks.index(-1)] = \
                     np.prod(orig_chunks) // -np.prod(chunks)

-            return tuple(chunks)
-        except KeyError:
-            return orig_chunks
+            kwargs['chunks'] = tuple(chunks)
+
+        elif (source, key) in self._compressed_keys or orig_dset.compression:
+            # TODO: Maintain more of existing properties, for now it is
+            # forced to use gzip and (1, *entry) chunking.
+            kwargs['chunks'] = (1,) + kwargs['shape'][1:]
+            kwargs['shuffle'] = True
+            kwargs['compression'] = 'gzip'
+            kwargs['compression_opts'] = self._compressed_keys.setdefault(
+                (source, key), orig_dset.compression_opts)
+
+        return kwargs

     def mask_instrument_data(self, source, index_group, train_ids, counts):
         if source in self._custom_xtdf_masks and index_group == 'image':
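
For reference, the kwargs assembled for a compressed key amount to an h5py dataset created roughly like this (a standalone sketch; file name, key path, shape and dtype are invented):

    import h5py
    import numpy as np

    shape = (1000, 512, 128)  # (entries, *entry_shape), invented

    with h5py.File('example.h5', 'w') as f:
        f.create_dataset(
            'INSTRUMENT/SRC/image/data',
            shape=shape, maxshape=(None,) + shape[1:],
            chunks=(1,) + shape[1:],   # one entry per chunk
            dtype=np.uint16,
            shuffle=True,              # byte-shuffle before deflate
            compression='gzip', compression_opts=1)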
@@ -393,11 +423,14 @@ class ReduceWriter(SourceDataWriter):
             for i, train_id in enumerate(train_ids):
                 data[i] = custom_counts.get(train_id, data[i])

-        try:
-            regions = self._partial_copies[source, key]
-        except KeyError:
-            dest[:] = data
-        else:
-            for region in regions:
+        if (source, key) in self._subsliced_keys:
+            for region in self._subsliced_keys[source, key]:
                 sel = (np.s_[:], *region)
                 dest[sel] = data[sel]
+
+        elif (source, key) in self._compressed_keys:
+            write_compressed_frames(
+                data, dest, self._compressed_keys[source, key], 8)
+
+        else:
+            dest[:] = data
@@ -17,6 +17,31 @@ import numpy as np
 import h5py


+def write_compressed_frames(arr, dset, level=1, comp_threads=1):
+    """Compress gain/mask frames in multiple threads, and save their data
+
+    This is significantly faster than letting HDF5 do the compression
+    in a single thread.
+    """
+    import zlib
+    from multiprocessing.pool import ThreadPool
+
+    def _compress_frame(idx):
+        # Equivalent to the HDF5 'shuffle' filter: transpose bytes for better
+        # compression.
+        shuffled = np.ascontiguousarray(
+            arr[idx].view(np.uint8).reshape((-1, arr.itemsize)).transpose()
+        )
+        return idx, zlib.compress(shuffled, level=level)
+
+    with ThreadPool(comp_threads) as pool:
+        for i, compressed in pool.imap(_compress_frame, range(len(arr))):
+            # Each frame is 1 complete chunk
+            chunk_start = (i,) + (0,) * (dset.ndim - 1)
+            dset.id.write_direct_chunk(chunk_start, compressed)
+
+
 def get_pulse_offsets(pulses_per_train):
     """Compute pulse offsets from pulse counts.
@@ -459,6 +484,7 @@ class ControlSource(Source):
                 timestamp for the corresponding value in the RUN
                 section. The first entry for the train values is used if
                 omitted. No run key is created if exactly False.
+            attrs (dict, optional): Attributes to add to this key.

         Returns:
             None
@@ -589,6 +615,7 @@ class InstrumentSource(Source):
                 slashes.
             data (array_like, optional): Key data to initialize the
                 dataset to.
+            attrs (dict, optional): Attributes to add to this key.
             kwargs: Any additional keyword arguments are passed to
                 create_dataset.
@@ -620,6 +647,7 @@ class InstrumentSource(Source):
             key (str): Source key, dots are automatically replaced by
                 slashes.
             data (np.ndarray): Key data.ss
+            attrs (dict, optional): Attributes to add to this key.
             comp_threads (int, optional): Number of threads to use for
                 compression, 8 by default.
@@ -627,15 +655,18 @@ class InstrumentSource(Source):
             (h5py.Dataset) Created dataset
         """

+        dset = self.create_key(
+            key, shape=data.shape, chunks=((1,) + data.shape[1:]),
+            dtype=data.dtype, shuffle=True,
+            compression='gzip', compression_opts=1)
+
         key = escape_key(key)

         if not self.key_pattern.match(key):
             raise ValueError(f'invalid key format, must satisfy '
                              f'{self.key_pattern.pattern}')

-        from cal_tools.tools import write_compressed_frames
-        dset = write_compressed_frames(data, self, key,
-                                       comp_threads=comp_threads)
+        write_compressed_frames(data, dset, level=1, comp_threads=comp_threads)

         if attrs is not None:
             dset.attrs.update(attrs)
...
@@ -36,19 +36,21 @@ class SourceDataWriter:
         """Determine whether to write key attributes."""
         return True

-    def chunk_instrument_data(self, source, key, orig_chunks):
-        """Determine chunk size for INSTRUMENT key.
+    def create_instrument_key(self, source, key, orig_dset, kwargs):
+        """Determine creation arguments for INSTRUMENT key.

         Args:
             source (str): Source name.
             key (str): Key name.
-            orig_chunks (tuple of int): Chunk size as found in input.
+            orig_dset (h5py.Dataset): Original dataset sample.
+            kwargs (dict of Any): Keyword arguments passed to
+                h5py.Group.create_dataset.

         Returns:
-            (tuple of int): Chunk size to use for output.
+            (dict of Any): Chunk size to use for output.
         """

-        return orig_chunks
+        return kwargs

     def mask_instrument_data(self, source, index_group, train_ids, counts):
         """Mask INSTRUMENT data.
@@ -249,14 +251,13 @@ class SourceDataWriter:
                kd = sd[key]
                shape = (num_entries, *kd.entry_shape)

-                chunks = self.chunk_instrument_data(
-                    sd.source, key,
-                    kd.files[0].file[kd.hdf5_data_path].chunks)
-
-                h5source.create_key(
-                    key, shape=shape, maxshape=(None,) + shape[1:],
-                    chunks=chunks, dtype=kd.dtype,
-                    attrs=attrs.pop(key, None))
+                orig_dset = kd.files[0].file[kd.hdf5_data_path]
+                kwargs = {
+                    'shape': shape, 'maxshape': (None,) + shape[1:],
+                    'chunks': orig_dset.chunks, 'dtype': kd.dtype,
+                    'attrs': attrs.pop(key, None)}
+
+                h5source.create_key(key, **self.create_instrument_key(
+                    sd.source, key, orig_dset, kwargs))

            # Update tableSize to the correct number of records.
            h5source[index_group].attrs['tableSize'] = num_entries
@@ -429,6 +430,7 @@ def get_key_attributes(sd):

     return source_attrs

 def iter_index_group_keys(keys, index_group):
     for key in keys:
         if key[:key.index('.')] == index_group:
@@ -445,11 +447,11 @@ def mask_index(g, counts, masks_by_train):
     # Modify INDEX entry if necessary.
     if full_mask.sum() != num_entries:
         g.create_dataset(
-            f'original/first', data=get_pulse_offsets(counts))
+            'original/first', data=get_pulse_offsets(counts))
         g.create_dataset(
-            f'original/count', data=counts)
+            'original/count', data=counts)
         g.create_dataset(
-            f'original/position',
+            'original/position',
             data=np.concatenate([np.flatnonzero(mask)
                                  for mask in masks_by_train]))
...