From e1e80999a21f5e7fae24491852eb3daff0774e45 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Sat, 13 May 2023 18:12:57 +0200
Subject: [PATCH] Replace FileWriter by DataFile API

---
 src/exdf_tools/exdf.py   | 629 +++++++++++++++++++++++++++++++++++++++
 src/exdf_tools/reduce.py | 289 ++++++++++++------
 2 files changed, 823 insertions(+), 95 deletions(-)
 create mode 100644 src/exdf_tools/exdf.py

diff --git a/src/exdf_tools/exdf.py b/src/exdf_tools/exdf.py
new file mode 100644
index 0000000..89a8684
--- /dev/null
+++ b/src/exdf_tools/exdf.py
@@ -0,0 +1,629 @@
+
+
+from datetime import datetime
+from itertools import chain
+from numbers import Integral
+from pathlib import Path
+import re
+
+import numpy as np
+import h5py
+
+
+def get_pulse_offsets(pulses_per_train):
+    """Compute pulse offsets from pulse counts.
+
+    Given an array of number of pulses per train (INDEX/<source>/count),
+    computes the offsets (INDEX/<source>/first) for the first pulse of a
+    train in the data array.
+
+    Args:
+        pulses_per_train (array_like): Pulse count per train.
+
+    Returns:
+        (array_like) Offset of first pulse for each train.
+    """
+
+    pulse_offsets = np.zeros_like(pulses_per_train)
+    np.cumsum(pulses_per_train[:-1], out=pulse_offsets[1:])
+
+    return pulse_offsets
+
+
+def sequence_trains(train_ids, trains_per_sequence=256):
+    """Iterate over sequences for a list of trains.
+
+    For pulse-resolved data, sequence_pulses may be used instead.
+
+    Args:
+        train_ids (array_like): Train IDs to sequence.
+        trains_per_sequence (int, optional): Number of trains
+            per sequence, 256 by default.
+
+    Yields:
+        (int, slice) Current sequence ID, train mask.
+    """
+
+    num_trains = len(train_ids)
+
+    for seq_id, start in enumerate(range(0, num_trains, trains_per_sequence)):
+        train_mask = slice(
+            *np.s_[start:start+trains_per_sequence].indices(num_trains))
+        yield seq_id, train_mask
+
+
+def sequence_pulses(train_ids, pulses_per_train=1, pulse_offsets=None,
+                    trains_per_sequence=256):
+    """Split trains into sequences.
+
+    Args:
+        train_ids (array_like): Train IDs to sequence.
+        pulses_per_train (int or array_like, optional): Pulse count per
+            train. If scalar, it is assumed to be constant for all
+            trains. If omitted, it is 1 by default.
+        pulse_offsets (array_like, optional): Offsets for the first
+            pulse in each train, computed from pulses_per_train if
+            omitted.
+        trains_per_sequence (int, optional): Number of trains
+            per sequence, 256 by default.
+
+    Yields:
+        (int, array_like, array_like)
+        Current sequence ID, train mask, pulse mask.
+    """
+
+    if isinstance(pulses_per_train, Integral):
+        pulses_per_train = np.full_like(train_ids, pulses_per_train,
+                                        dtype=np.uint64)
+
+    if pulse_offsets is None:
+        pulse_offsets = get_pulse_offsets(pulses_per_train)
+
+    for seq_id, train_mask in sequence_trains(train_ids, trains_per_sequence):
+        start = train_mask.start
+        stop = train_mask.stop - 1
+        pulse_mask = np.s_[
+            pulse_offsets[start]:pulse_offsets[stop]+pulses_per_train[stop]]
+
+        yield seq_id, train_mask, pulse_mask
+
+
+def escape_key(key):
+    """Escapes a key name from Karabo to HDF notation."""
+    return key.replace('.', '/')
+
+
+class CustomIndexer:
+    __slots__ = ('parent',)
+
+    def __init__(self, parent):
+        self.parent = parent
+
+
+class SourceIndexer(CustomIndexer):
+    def __getitem__(self, source):
+        if ':' in source:
+            root = 'INSTRUMENT'
+            cls = InstrumentSource
+        else:
+            root = 'CONTROL'
+            cls = ControlSource
+
+        return cls(self.parent[f'{root}/{source}'].id, source)
+
+
+class KeyIndexer(CustomIndexer):
+    def __getitem__(self, key):
+        return self.parent[escape_key(key)]
+
+
+class File(h5py.File):
+    """European XFEL HDF5 data file.
+
+    This class extends h5py.File with methods specific to writing
+    data in the European XFEL file format. The internal state does not
+    depend on using any of these methods, and the underlying file may be
+    manipulated by any of the regular h5py methods, too.
+
+    Please refer to
+    https://extra-data.readthedocs.io/en/latest/data_format.html for
+    details of the file format.
+    """
+
+    filename_format = '{prefix}-R{run:04d}-{aggregator}-S{sequence:05d}.h5'
+    aggregator_pattern = re.compile(r'^\w{2,}\d{2}$')
+    instrument_source_pattern = re.compile(r'^[\w\/-]+:[\w.]+$')
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__control_sources = set()
+        self.__instrument_sources = set()
+        self.__run = 0
+        self.__sequence = 0
+
+    @classmethod
+    def from_details(cls, folder, aggregator, run, sequence, prefix='CORR',
+                     mode='w', *args, **kwargs):
+        """Open or create a file based on European XFEL details.
+
+        This method is a wrapper to construct the filename based on its
+        components.
+
+        Args:
+            folder (Path, str): Parent location for this file.
+            aggregator (str): Name of the data aggregator, must satisfy
+                DataFile.aggregator_pattern.
+            run (int): Run number.
+            sequence (int): Sequence number.
+            prefix (str, optional): First filename component, 'CORR' by
+                default.
+            args, kwargs: Any additional arguments are passed on to
+                h5py.File.
+
+        Returns:
+            (DataFile) Opened file object.
+        """
+
+        if not isinstance(folder, Path):
+            folder = Path(folder)
+
+        if not cls.aggregator_pattern.match(aggregator):
+            raise ValueError(f'invalid aggregator format, must satisfy '
+                             f'{cls.aggregator_pattern.pattern}')
+
+        filename = cls.filename_format.format(
+            prefix=prefix, aggregator=aggregator, run=run, sequence=sequence)
+
+        self = cls((folder / filename).resolve(), mode, *args, **kwargs)
+        self.__run = run
+        self.__sequence = sequence
+
+        return self
+
+    @property
+    def source(self):
+        return SourceIndexer(self)
+
+    def create_index(self, train_ids, timestamps=None, flags=None,
+                     origins=None, from_file=None):
+        """Create global INDEX datasets.
+
+        These datasets are agnostic of any source and describe the
+        trains contained in this file.
+
+        Args:
+            train_ids (array_like): Train IDs contained in this file.
+            timestamps (array_like, optional): Timestamp of each train,
+                0 if omitted.
+            flags (array_like, optional): Whether the time server is the
+                initial origin of each train, 1 if omitted.
+            origins (array_like, optional): Which source is the initial
+                origin of each train, -1 (time server) if omitted.
+            from_file (str, Path or extra_data.FileAccess, optional):
+                Existing data file to take timestamps, flags and origins
+                information from if present.
+
+        Returns:
+            None
+        """
+
+        if from_file is not None:
+            from extra_data import FileAccess
+
+            if not isinstance(from_file, FileAccess):
+                from_file = FileAccess(from_file)
+
+            sel_trains = np.isin(from_file.train_ids, train_ids)
+
+            if 'INDEX/timestamp' in from_file.file:
+                timestamps = from_file.file['INDEX/timestamp'][sel_trains]
+
+            flags = from_file.validity_flag[sel_trains]
+
+            if 'INDEX/origin' in from_file.file:
+                origins = from_file.file['INDEX/origin'][sel_trains]
+
+        self.create_dataset('INDEX/trainId', data=train_ids, dtype=np.uint64)
+
+        if timestamps is None:
+            timestamps = np.zeros_like(train_ids, dtype=np.uint64)
+        elif len(timestamps) != len(train_ids):
+            raise ValueError('timestamps and train_ids must be same length')
+
+        self.create_dataset('INDEX/timestamp', data=timestamps,
+                            dtype=np.uint64)
+
+        if flags is None:
+            flags = np.ones_like(train_ids, dtype=np.int32)
+        elif len(flags) != len(train_ids):
+            raise ValueError('flags and train_ids must be same length')
+
+        self.create_dataset('INDEX/flag', data=flags, dtype=np.int32)
+
+        if origins is None:
+            origins = np.full_like(train_ids, -1, dtype=np.int32)
+        elif len(origins) != len(train_ids):
+            raise ValueError('origins and train_ids must be same length')
+
+        self.create_dataset('INDEX/origin', data=origins, dtype=np.int32)
+
+    def create_control_source(self, source):
+        """Create group for a control source ("slow data").
+
+        Control sources created via this method are not required to be
+        passed to create_metadata() again.
+
+        Args:
+            source (str): Karabo device ID.
+
+        Returns:
+            (ControlSource) Created group in CONTROL.
+        """
+
+        self.__control_sources.add(source)
+        return ControlSource(self.create_group(f'CONTROL/{source}').id, source)
+
+    def create_instrument_source(self, source):
+        """Create group for an instrument source ("fast data").
+
+        Instrument sources created via this method are not required to be
+        passed to create_metadata() again.
+
+        Args:
+            source (str): Karabo pipeline path, must satisfy
+                DataFile.instrument_source_pattern.
+
+        Returns:
+            (InstrumentSource) Created group in INSTRUMENT.
+        """
+
+        if not self.instrument_source_pattern.match(source):
+            raise ValueError(f'invalid source format, must satisfy '
+                             f'{self.instrument_source_pattern.pattern}')
+
+        self.__instrument_sources.add(source)
+        return InstrumentSource(self.create_group(f'INSTRUMENT/{source}').id,
+                                source)
+
+    def create_metadata(self, like=None, *,
+                        creation_date=None, update_date=None, proposal=0,
+                        run=0, sequence=None, daq_library='1.x',
+                        karabo_framework='2.x', control_sources=(),
+                        instrument_channels=()):
+        """Create METADATA datasets.
+
+        Args:
+            like (DataCollection or SourceData, optional): Take
+                proposal, run, daq_library, karabo_framework from an
+                EXtra-data data collection, overwriting any of these
+                arguments passed.
+            creation_date (datetime, optional): Creation date and time,
+                now if omitted.
+            update_date (datetime, optional): Update date and time,
+                now if omitted.
+            proposal (int, optional): Proposal number, 0 if omitted and
+                no DataCollection passed.
+ run (int, optional): Run number, 0 if omitted, no + DataCollection is passed or object not created via + from_details. + sequence (int, optional): Sequence number, 0 if omitted and + object not created via from_details. + daq_library (str, optional): daqLibrary field, '1.x' if + omitted and no DataCollection passed. + karabo_framework (str, optional): karaboFramework field, + '2.x' if omitted and no DataCollection is passed. + control_sources (Iterable, optional): Control sources in + this file, sources created via create_control_source are + automatically included. + instrument_channels (Iterable, optional): Instrument + channels (source and first component of data hash) in + this file, channels created via create_instrument_source + are automatically included. + + Returns: + None + """ + + if like is not None: + metadata = like.run_metadata() + proposal = metadata.get('proposalNumber', proposal) + run = metadata.get('runNumber', run) + daq_library = metadata.get('daqLibrary', daq_library) + karabo_framework = metadata.get('karaboFramework', + karabo_framework) + + else: + if run is None: + run = self.__run + + if sequence is None: + sequence = self.__sequence + + if creation_date is None: + creation_date = datetime.now() + + if update_date is None: + update_date = creation_date + + md_group = self.require_group('METADATA') + md_group.create_dataset( + 'creationDate', shape=(1,), + data=creation_date.strftime('%Y%m%dT%H%M%SZ').encode('ascii')) + md_group.create_dataset('daqLibrary', shape=(1,), + data=daq_library.encode('ascii')) + md_group.create_dataset('dataFormatVersion', shape=(1,), data=b'1.2') + + # Start with the known and specified control sources + sources = {name: 'CONTROL' + for name in chain(self.__control_sources, control_sources)} + + # Add in the specified instrument data channels. + sources.update({full_channel: 'INSTRUMENT' + for full_channel in instrument_channels}) + + # Add in those already in the file, if not already passed. + sources.update({f'{name}/{channel}': 'INSTRUMENT' + for name in self.__instrument_sources + for channel in self[f'INSTRUMENT/{name}']}) + + source_names = sorted(sources.keys()) + data_sources_shape = (len(sources),) + md_group.create_dataset('dataSources/dataSourceId', + shape=data_sources_shape, + data=[f'{sources[name]}/{name}'.encode('ascii') + for name in source_names]) + md_group.create_dataset('dataSources/deviceId', + shape=data_sources_shape, + data=[name.encode('ascii') + for name in source_names]) + md_group.create_dataset('dataSources/root', shape=data_sources_shape, + data=[sources[name].encode('ascii') + for name in source_names]) + + md_group.create_dataset( + 'karaboFramework', shape=(1,), + data=karabo_framework.encode('ascii')) + md_group.create_dataset( + 'proposalNumber', shape=(1,), dtype=np.uint32, data=proposal) + md_group.create_dataset( + 'runNumber', shape=(1,), dtype=np.uint32, data=run) + md_group.create_dataset( + 'sequenceNumber', shape=(1,), dtype=np.uint32, data=sequence) + md_group.create_dataset( + 'updateDate', shape=(1,), + data=update_date.strftime('%Y%m%dT%H%M%SZ').encode('ascii')) + + +class Source(h5py.Group): + @property + def key(self): + return KeyIndexer(self) + + +class ControlSource(Source): + """Group for a control source ("slow data"). + + This class extends h5py.Group with methods specific to writing data + of a control source in the European XFEL file format. 
+    The internal
+    state does not depend on using any of these methods, and the
+    underlying file may be manipulated by any of the regular h5py
+    methods, too.
+    """
+
+    ascii_dt = h5py.string_dtype('ascii')
+
+    def __init__(self, group_id, source):
+        super().__init__(group_id)
+
+        self.__source = source
+        self.__run_group = self.file.require_group(f'RUN/{source}')
+
+    def get_run_group(self):
+        return self.__run_group
+
+    def get_index_group(self):
+        return self.file.require_group(f'INDEX/{self.__source}')
+
+    def create_key(self, key, values, timestamps=None, run_entry=None):
+        """Create datasets for a key varying each train.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            values (array_like): Source values for each train.
+            timestamps (array_like, optional): Timestamps for each
+                source value, 0 if omitted.
+            run_entry (tuple of array_like, optional): Value and
+                timestamp for the corresponding value in the RUN
+                section. The first entry for the train values is used if
+                omitted. No run key is created if exactly False.
+
+        Returns:
+            None
+        """
+
+        key = escape_key(key)
+
+        if timestamps is None:
+            timestamps = np.zeros_like(values, dtype=np.uint64)
+        elif len(values) != len(timestamps):
+            raise ValueError('values and timestamps must be the same length')
+
+        self.create_dataset(f'{key}/value', data=values)
+        self.create_dataset(f'{key}/timestamp', data=timestamps)
+
+        if run_entry is False:
+            return
+        elif run_entry is None:
+            run_entry = (values[0], timestamps[0])
+
+        self.create_run_key(key, *run_entry)
+
+    def create_run_key(self, key, value, timestamp=None):
+        """Create datasets for a key constant over a run.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            value (Any): Key value.
+            timestamp (int, optional): Timestamp of the value,
+                0 if omitted.
+
+        Returns:
+            None
+        """
+
+        # TODO: Some types/shapes are still not fully correct here.
+
+        key = escape_key(key)
+
+        if timestamp is None:
+            timestamp = 0
+
+        if isinstance(value, list):
+            shape = (1, len(value))
+
+            try:
+                dtype = type(value[0])
+            except IndexError:
+                # Assume empty lists are string-typed.
+                dtype = self.ascii_dt
+        elif isinstance(value, np.ndarray):
+            shape = value.shape
+            dtype = value.dtype
+        else:
+            shape = (1,)
+            dtype = type(value)
+
+        if dtype is str:
+            dtype = self.ascii_dt
+
+        self.__run_group.create_dataset(
+            f'{key}/value', data=value, shape=shape, dtype=dtype)
+        self.__run_group.create_dataset(
+            f'{key}/timestamp', data=timestamp, shape=(1,), dtype=np.uint64)
+
+    def create_index(self, num_trains):
+        """Create source-specific INDEX datasets.
+
+        Depending on whether this source has train-varying data or not,
+        different count/first datasets are written.
+
+        Args:
+            num_trains (int): Total number of trains in this file.
+
+        Returns:
+            None
+        """
+
+        if len(self) > 0:
+            count_func = np.ones
+            first_func = np.arange
+        else:
+            count_func = np.zeros
+            first_func = np.zeros
+
+        index_group = self.get_index_group()
+        index_group.create_dataset(
+            'count', data=count_func(num_trains, dtype=np.uint64))
+        index_group.create_dataset(
+            'first', data=first_func(num_trains, dtype=np.uint64))
+
+
+class InstrumentSource(Source):
+    """Group for an instrument source ("fast data").
+
+    This class extends h5py.Group with methods specific to writing data
+    of an instrument source in the European XFEL file format.
The internal + state does not depend on using any of these methods, and the + underlying file may be manipulated by any of the regular h5py + methods, too. + """ + + key_pattern = re.compile(r'^\w+\/[\w\/]+$') + + def __init__(self, group_id, source): + super().__init__(group_id) + + self.__source = source + + def get_index_group(self, channel): + return self.file.require_group(f'INDEX/{self.__source}/{channel}') + + def create_key(self, key, data=None, **kwargs): + """Create dataset for a key. + + Args: + key (str): Source key, dots are automatically replaced by + slashes. + data (array_like, optional): Key data to initialize the + dataset to. + kwargs: Any additional keyword arguments are passed to + create_dataset. + + Returns: + (h5py.Dataset) Created dataset + """ + + key = escape_key(key) + + if not self.key_pattern.match(key): + raise ValueError(f'invalid key format, must satisfy ' + f'{self.key_pattern.pattern}') + + return self.create_dataset(key, data=data, **kwargs) + + def create_compressed_key(self, key, data, comp_threads=8): + """Create a compressed dataset for a key. + + This method makes use of lower-level access in h5py to compress + the data separately in multiple threads and write it directly to + file rather than go through HDF's compression filters. + + Args: + key (str): Source key, dots are automatically replaced by + slashes. + data (np.ndarray): Key data.ss + comp_threads (int, optional): Number of threads to use for + compression, 8 by default. + + Returns: + (h5py.Dataset) Created dataset + """ + + key = escape_key(key) + + if not self.key_pattern.match(key): + raise ValueError(f'invalid key format, must satisfy ' + f'{self.key_pattern.pattern}') + + from cal_tools.tools import write_compressed_frames + return write_compressed_frames(data, self, key, + comp_threads=comp_threads) + + def create_index(self, *args, **channels): + """Create source-specific INDEX datasets. + + Instrument data is indexed by channel, which is the first + component in its key. If channels have already been created, the + index may be applied to all channels by passing them as a + positional argument. + """ + + if not channels: + try: + count = int(args[0]) + except IndexError: + raise ValueError('positional arguments required if no ' + 'explicit channels are passed') from None + # Allow ValueError to propagate directly. + + channels = {channel: count for channel in self} + + for channel, count in channels.items(): + index_group = self.get_index_group(channel) + index_group.create_dataset('count', data=count, dtype=np.uint64) + index_group.create_dataset( + 'first', data=get_pulse_offsets(count), dtype=np.uint64) diff --git a/src/exdf_tools/reduce.py b/src/exdf_tools/reduce.py index 365212c..85dfca7 100644 --- a/src/exdf_tools/reduce.py +++ b/src/exdf_tools/reduce.py @@ -9,11 +9,15 @@ import logging import re import sys +import numpy as np + from pkg_resources import iter_entry_points from extra_data import RunDirectory, open_run, by_id from extra_data.writer import FileWriter, VirtualFileWriter from extra_data.read_machinery import select_train_ids +from . 
import exdf + exdf_filename_pattern = re.compile( r'([A-Z]+)-R(\d{4})-(\w+)-S(\d{5}).h5') @@ -26,7 +30,7 @@ def get_source_origin(sourcedata): except AttributeError: files = sourcedata.files - storage_classes = set() + data_category = set() run_numbers = set() aggregators = set() @@ -34,17 +38,47 @@ def get_source_origin(sourcedata): m = exdf_filename_pattern.match(Path(f.filename).name) if m: - storage_classes.add(m[1]) + data_category.add(m[1]) run_numbers.add(int(m[2])) aggregators.add(m[3]) - result_sets = [storage_classes, run_numbers, aggregators] + result_sets = [data_category, run_numbers, aggregators] if any([len(x) != 1 for x in result_sets]): - raise ValueError('different or no classes or aggregators recognized ' - 'from source filenames') + raise ValueError('none or multiple sets of traits recognized ' + 'in source filenames, inconsistent input dataset?') + + return data_category.pop(), run_numbers.pop(), aggregators.pop() + + +def sourcedata_drop_empty_trains(sd): + if sd._is_control: + train_sel = sd[next(iter(sd.keys()))].drop_empty_trains().train_ids + else: + sample_keys = dict(zip( + [key[:key.find('.')] for key in sd.keys()], sd.keys())) + + train_sel = np.zeros(0, dtype=np.uint64) + + for key in sample_keys.values(): + train_sel = np.union1d( + train_sel, sd[key].drop_empty_trains().train_ids) + + return sd.select_trains(by_id[train_sel]) + + +def sourcedata_data_counts(sd): + if sd._is_control: + return sd[next(iter(sd.keys()))].data_counts(labelled=False) + else: + sample_keys = dict(zip( + [key[:key.find('.')] for key in sd.keys()], sd.keys())) - return storage_classes.pop(), run_numbers.pop(), aggregators.pop() + data_counts = { + prefix: sd[key].data_counts(labelled=False) + for prefix, key in sample_keys.items() + } + return data_counts class ReductionSequence: @@ -62,7 +96,7 @@ class ReductionSequence: self._sources = sorted(data.all_sources) self._touched_sources = set() - # Only populated if trains/keys are removed for select sources. + # Only populated if trains/keys are selected/removed for sources. self._custom_keys = {} # source -> set(<keys>) self._custom_trains = {} # source -> list(<trains>) self._partial_copies = {} # (source, key) -> (region, chunks) @@ -76,7 +110,7 @@ class ReductionSequence: for source_glob, in self._filter_ops('removeSources'): for source in fnmatch.filter(self._sources, source_glob): self._touched_sources.add(source) - self._sources.remove(source) + self._sources.remove(source) for source_glob, key_glob in self._filter_ops('removeKeys'): for source in fnmatch.filter(self._sources, source_glob): @@ -121,19 +155,18 @@ class ReductionSequence: def write_collection(self, output_path, sequence_len=512): outp_data = self._data.select([(s, '*') for s in self._sources]) - origin_sources = defaultdict(list) - print(len(self._data.train_ids), len(outp_data.train_ids)) - return + # An origin is a pair of (storage_class, aggregator) + origin_sources = defaultdict(list) run_numbers = set() - # Collect all origins (combination of storage class and aggregator) - # and the sources they contain. + # Collect all origins (combination of data category and + # aggregator) and the sources they contain. 
for source in self._sources: - storage_class, run_no, aggregator = get_source_origin( - outp_data[source]) - origin_sources[(storage_class, aggregator)].append(source) + data_category, run_no, aggregator = get_source_origin( + outp_data[source]) # TODO: Use support in EXtra-data + origin_sources[(data_category, aggregator)].append(source) run_numbers.add(run_no) if len(run_numbers) != 1: @@ -142,47 +175,58 @@ class ReductionSequence: run_no = run_numbers.pop() - for (storage_class, aggregator), sources in origin_sources.items(): - origin_tids = sorted(set(outp_data.train_ids).intersection(*[ - tids for s, tids in self._custom_trains.items() - if s in sources])) + for (data_category, aggregator), sources in origin_sources.items(): + # Build source & key selection for this origin. origin_sel = { s: self._custom_keys[s] if s in self._custom_keys else set() for s in sources} - # Select output data down to what's in this origin. + # Select output data down to what's in this origin both in + # terms of sources and trains (via require_any). origin_data = outp_data \ - .select_trains(by_id[origin_tids]) \ - .select(origin_sel) - - num_trains = len(origin_data.train_ids) - num_sequences = int(num_trains / sequence_len) + 1 + .select(origin_sel, require_any=True) + + # Switch to representation of SourceData objects for + # per-source tracking of trains. + origin_sources = {source: origin_data[source] + for source in origin_data.all_sources} + + # Apply custom train selections. + for source, train_sel in self._custom_trains.items(): + if source in origin_sources: + origin_sources[source] = origin_sources[source] \ + .select_trains(by_id[train_sel]) + + # Find the union of trains across all sources as total + # trains for this origin. + origin_train_ids = np.zeros(0, dtype=np.uint64) + for sd in origin_sources.values(): + origin_train_ids = np.union1d( + origin_train_ids, + sourcedata_drop_empty_trains(sd).train_ids) + + num_trains = len(origin_train_ids) + num_sequences = int(np.ceil(num_trains / sequence_len)) self.log.info( - f'{storage_class}-{aggregator}: ' - f'{len(origin_data.all_sources)} sources with {num_trains} ' - f'trains over {num_sequences} sequences') + f'{data_category}-{aggregator}: {len(origin_sources)} sources ' + f'with {num_trains} trains over {num_sequences} sequences') for seq_no in range(num_sequences): - train_slice = slice( - seq_no * sequence_len, - (seq_no + 1) * sequence_len) + # Slice out the train IDs for this sequence. + seq_train_ids = origin_train_ids[ + (seq_no * sequence_len):((seq_no + 1) * sequence_len)] # Select origin data down to what's in this file. - file_data = origin_data.select_trains(train_slice) - - # file_data now can contain trains for which no source has - # any data. 
.select(require_all) Unfortunately does the - # opposite of what we want - - file_path = output_path / '{}-R{:04d}-{}-S{:06d}.h5'.format( - storage_class, run_no, aggregator, seq_no) + seq_sources = {source: sd.select_trains(by_id[seq_train_ids]) + for source, sd in origin_sources.items()} + seq_path = self._write_sequence( + output_path, aggregator, run_no, seq_no, data_category, + seq_sources, seq_train_ids) self.log.debug( - f'{file_path.stem}: {len(file_data.all_sources)} sources ' - f'with {len(file_data.train_ids)} trains') - fw = ReductionFileWriter(file_path, file_data, self) - fw.write() + f'{seq_path.stem}: {len(seq_sources)} sources with ' + f'{len(seq_train_ids)} trains') def write_voview(self, output_path): raise NotImplementedError('voview output layout') @@ -190,6 +234,106 @@ class ReductionSequence: def write_collapsed(self, output_path, sequence_len): raise NotImplementedError('collapsed collection output layout') + def _write_sequence(self, folder, aggregator, run, sequence, data_category, + sources, train_ids=None): + """Write collection of SourceData to file.""" + + if train_ids is None: + pass # Figure out as union of source_train_ids + + sd_example = next(iter(sources.values())) + + # Build sources and indices. + control_indices = {} + instrument_indices = {} + + for source, sd in sources.items(): + if sd._is_control: + control_indices[sd.source] = sourcedata_data_counts(sd) + else: + instrument_indices[sd.source] = sourcedata_data_counts(sd) + + with exdf.File.from_details(folder, aggregator, run, sequence, + prefix=data_category, mode='w') as f: + path = Path(f.filename) + + # TODO: Handle old files without extended METADATA? + # TODO: Handle timestamps, origin + # TODO: Attributes + + # Create METADATA section. + f.create_metadata( + like=sd_example, + control_sources=control_indices.keys(), + instrument_channels=[ + f'{source}/{channel}' + for source, channels in instrument_indices.items() + for channel in channels.keys()]) + + # Create INDEX section. + f.create_index(train_ids) + + for source, counts in control_indices.items(): + control_src = f.create_control_source(source) + control_src.create_index(len(train_ids)) + + for source, channel_counts in instrument_indices.items(): + instrument_src = f.create_instrument_source(source) + instrument_src.create_index(**channel_counts) + + # Create CONTROL and RUN sections. + for source in control_indices.keys(): + exdf_source = f.source[source] + sd = sources[source] + + for key, value in sd.run_values().items(): + exdf_source.create_run_key(key, value) + + for key in sd.keys(False): + exdf_source.create_key(key, + sd[f'{key}.value'].ndarray(), + sd[f'{key}.timestamp'].ndarray(), + run_entry=False) + + # Create INSTRUMENT datasets. + for source in instrument_indices.keys(): + exdf_source = f.source[source] + sd = sources[source] + + for key in sd.keys(): + kd = sd[key] + + shape = (kd.data_counts(labelled=False).sum(), + *kd.entry_shape) + try: + _, chunks = self._partial_copies[source, key] + except KeyError: + chunks = kd.files[0].file[kd.hdf5_data_path].chunks + + exdf_source.create_key( + key, shape=shape, maxshape=(None,) + shape[1:], + chunks=chunks, dtype=kd.dtype) + + # Copy INSTRUMENT data. 
+ for source in instrument_indices.keys(): + exdf_source = f.source[source] + sd = sources[source] + + for key in sd.keys(): + # TODO: Copy by chunk / file if too large + + data = sd[key].ndarray() + + try: + region, _ = self._partial_copies[source, key] + except KeyError: + exdf_source.key[key][:] = data + else: + full_region = (np.s_[:], *region) + exdf_source.key[key][full_region] = data[full_region] + + return path + class ReductionMethod: log = logging.getLogger('ReductionMethod') @@ -234,54 +378,6 @@ class ReductionMethod: self.log.debug(f'Emitted {self._instructions[-1]}') -class ReductionFileWriter(FileWriter): - def __init__(self, path, data, sequence): - super().__init__(path, data) - self.sequence = sequence - - def prepare_source(self, source): - for key in sorted(self.data.keys_for_source(source)): - path = f"{self._section(source)}/{source}/{key.replace('.', '/')}" - nentries = self._guess_number_of_storing_entries(source, key) - src_ds1 = self.data[source].files[0].file[path] - - try: - chunking_kwargs = dict( - chunks=self.sequence._partial_copies[source, key][1]) - except KeyError: - chunking_kwargs = dict() - - self.file.create_dataset_like( - path, src_ds1, shape=(nentries,) + src_ds1.shape[1:], - # Corrected detector data has maxshape==shape, but if any max - # dim is smaller than the chunk size, h5py complains. Making - # the first dimension unlimited avoids this. - maxshape=(None,) + src_ds1.shape[1:], - **chunking_kwargs - ) - if source in self.data.instrument_sources: - self.data_sources.add(f"INSTRUMENT/{source}/{key.partition('.')[0]}") - - if source not in self.data.instrument_sources: - self.data_sources.add(f"CONTROL/{source}") - - def copy_dataset(self, source, key): - path = f"{self._section(source)}/{source}/{key.replace('.', '/')}" - - try: - region = self.sequence._partial_copies[source, key][0] - except KeyError: - a = self.data.get_array(source, key) - self.file[path][:] = a.values - else: - a = self.data.get_array(source, key) - full_region = (slice(None, None, None), *region) - - self.file[path][full_region] = a.values[full_region] - - self._make_index(source, key, a.coords['trainId'].values) - - proposal_run_input_pattern = re.compile( r'^p(\d{4}|9\d{5}):r(\d{1,3}):*([a-z]*)$') @@ -362,6 +458,9 @@ def main(argv=None): help='export reduction operations to a script file' ) + # TODO: Whether to use suspect trains or not + # TODO: Whether to drop entirely empty sources + args = ap.parse_args(argv) if args.verbose: @@ -383,14 +482,14 @@ def main(argv=None): inp_data = open_run(proposal=m[1], run=m[2], data=m[3] or 'raw') log.info(f'Found proposal run at ' f'{Path(inp_data.files[0].filename).parent}') - + else: inp_data = RunDirectory(args.input) log.info(f'Opened data with {len(inp_data.all_sources)} sources with ' f'{len(inp_data.train_ids)} trains across {len(inp_data.files)} ' f'files') - + methods = {} for ep in iter_entry_points('exdf_tools.reduction_method'): @@ -445,7 +544,7 @@ def main(argv=None): log.debug(f'Writing collapsed collection to {args.output}') seq.write_collapsed(args.output, args.output_sequence_len) - + if __name__ == '__main__': main() -- GitLab
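
As a usage reference for the DataFile-style API introduced in exdf.py above, a minimal sketch follows. It assumes the package layout from this patch (from exdf_tools import exdf); the source names, proposal and run numbers, shapes and output folder below are made up for illustration only and are not part of the patch.

import numpy as np
from exdf_tools import exdf

train_ids = np.arange(10000, 10016, dtype=np.uint64)

with exdf.File.from_details('/tmp', 'DA01', run=23, sequence=0,
                            prefix='CORR') as f:
    # Global train index shared by all sources in this file.
    f.create_index(train_ids)

    # Slow data: one entry per train plus an implicit RUN entry.
    ctrl = f.create_control_source('SA1_XTD2_XGM/XGM/DOOCS')
    ctrl.create_key('pulseEnergy.photonFlux',
                    values=np.zeros(len(train_ids), dtype=np.float32))
    ctrl.create_index(len(train_ids))

    # Fast data: indexed per channel ('data' here), one entry per train.
    inst = f.create_instrument_source('SA1_XTD2_XGM/XGM/DOOCS:output')
    inst.create_key('data.intensityTD',
                    data=np.zeros((len(train_ids), 1000), dtype=np.float32))
    inst.create_index(data=np.ones(len(train_ids), dtype=np.uint64))

    # Control sources and instrument channels created above are
    # picked up automatically when writing METADATA.
    f.create_metadata(proposal=700000, run=23)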