from collections import defaultdict
from pathlib import Path
import fnmatch
import logging

from packaging.version import Version
import numpy as np

from extra_data import by_id
from extra_data.read_machinery import select_train_ids

from exdf.write import SourceDataWriter
from ..write.datafile import write_compressed_frames


# Patch SourceData object.
import h5py
from extra_data.sourcedata import SourceData

def _SourceData_get_index_group_sample(self, index_group):
    if self.is_control and not index_group:
        # Shortcut for CONTROL data.
        return self.one_key()

    if self.sel_keys is not None:
        for key in self.sel_keys:
            if key.startswith(index_group):
                return key

    def get_key(key, value):
        if isinstance(value, h5py.Dataset):
            return index_group + '.' + key.replace('/', '.')

    group = f'/INSTRUMENT/{self.source}/{index_group}'
    for f in self.files:
        return f.file[group].visititems(get_key)

SourceData._get_index_group_sample = _SourceData_get_index_group_sample


def apply_by_source(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            assert isinstance(self, ReduceWriter)
            for source_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    if op_func(self, source, *args) is False:
                        continue

                    self._touched_sources.add(source)

        return op_handler
    return op_decorator


def apply_by_key(op_name):
    def op_decorator(op_func):
        def op_handler(self):
            assert isinstance(self, ReduceWriter)
            for source_glob, key_glob, *args in self._filter_ops(op_name):
                for source in fnmatch.filter(self._sources, source_glob):
                    keys = self._custom_keys.get(
                        source, set(self._data[source].keys()))

                    num_touched_keys = 0
                    for key in fnmatch.filter(keys, key_glob):
                        if op_func(self, source, key, *args) is False:
                            continue

                        num_touched_keys += 1

                    if num_touched_keys > 0:
                        self._touched_sources.add(source)

        return op_handler
    return op_decorator


class ReduceInitError(RuntimeError):
    def __init__(self, msg):
        super().__init__(msg)
        ReduceWriter.log.error(msg)


class ReduceWriter(SourceDataWriter):
    log = logging.getLogger('exdf.data_reduction.ReduceWriter')

    def __init__(self, data, methods, scope, sequence_len=-1, version=None):
        self._data = data
        self._methods = methods
        self._scope = scope
        self._sequence_len = sequence_len

        metadata = self._data.run_metadata()
        input_version = Version(metadata.get('dataFormatVersion', '1.0'))

        if input_version < Version('1.0'):
            raise ReduceInitError('Currently input files are required to be '
                                  'EXDF-v1.0+')
        elif input_version == Version('1.2') and data.control_sources:
            # Check for mislabeled EXDF-v1.3 files.
            ctrl_sd = data[next(iter(data.control_sources))]

            if ctrl_sd[ctrl_sd.one_key()].attributes():
                input_version = Version('1.3')
                self.log.warning('Detected EXDF-v1.3 file with attributes '
                                 'mislabeled as EXDF-v1.2')

        self.log.debug(f'Input data EXDF version {input_version}')

        if version == 'same':
            self._version = input_version
        else:
            self._version = Version(version)

        try:
            self.run_number = int(metadata['runNumber'])
        except KeyError:
            raise ReduceInitError('runNumber dataset required to be present '
                                  'in input METADATA')

        self._ops = sum(methods.values(), [])

        if not self._ops:
            self.log.warning('Sum of reduction methods yielded no operations '
                             'to apply')

        self._sources = sorted(data.all_sources)
        self._touched_sources = set()

        # Only populated for sources/keys that are modified.
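        # These mappings are filled in by the operation handlers below;
        # where no entry exists for a source, the handlers fall back to
        # the full set of keys or trains present in the input data.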
        self._custom_keys = {}  # source -> set(<keys>)
        self._custom_trains = {}  # source -> list(<trains>)
        self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
        self._custom_xtdf_counts = {}  # source -> ndarray
        self._custom_entry_masks = {}  # source -> dict(train_id -> mask)
        self._rechunked_keys = {}  # (source, key) -> chunks
        self._subsliced_keys = {}  # (source, key) -> list(<regions>)
        self._compressed_keys = {}  # (source, key) -> level

        # Collect reductions resulting from operations. This is the
        # most efficient order of operations to minimize more expensive
        # operations for sources or trains that may not end up being
        # selected.
        self._handle_remove_sources()
        self._handle_remove_keys()
        self._handle_remove_trains()
        self._handle_select_trains()
        self._handle_select_entries()
        self._handle_select_xtdf()
        self._handle_rechunk_keys()
        self._handle_subslice_keys()
        self._handle_compress_keys()

        custom_entry_sources = {x[0] for x in self._custom_entry_masks.keys()}
        if custom_entry_sources & self._custom_xtdf_masks.keys():
            raise ReduceInitError('Source may not be affected by both '
                                  'select-entries and select-xtdf operations')

        if self._rechunked_keys.keys() & self._compressed_keys.keys():
            raise ReduceInitError('Key may not be affected by both '
                                  'compress-keys and rechunk-keys')

        if self._scope == 'sources':
            self._sources = sorted(
                self._touched_sources.intersection(self._sources))
        elif self._scope == 'aggregators':
            touched_aggregators = {self._data[source].aggregator
                                   for source in self._touched_sources}

            self._sources = sorted(
                {source for source in self._sources
                 if (self._data[source].aggregator in touched_aggregators)})

        if not self._sources:
            raise ReduceInitError('Reduction operations and output scope '
                                  'yield an empty dataset')
        else:
            self.log.debug(
                f'Sources being modified: {sorted(self._touched_sources)}')
            self.log.debug(f'Sources included in output: {self._sources}')

    def _filter_ops(self, op):
        return [args[1:] for args in self._ops if args[0] == op]

    def _get_entry_masks(self, source, index_group, train_sel, entry_sel):
        """Generate bool vectors from any accepted selection."""
        train_ids = select_train_ids(
            self._custom_trains.get(source, list(self._data.train_ids)),
            train_sel)
        counts = self._data[source].select_trains(by_id[train_ids]) \
            .data_counts(index_group=index_group)
        masks = {}

        if isinstance(entry_sel, slice):
            for train_id, count in counts.items():
                if count > 0:
                    masks[train_id] = np.zeros(count, dtype=bool)
                    masks[train_id][entry_sel] = True

        elif np.issubdtype(type(entry_sel[0]), np.integer):
            max_entry = max(entry_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif max_entry >= count:
                    raise ReduceInitError(f'Entry index exceeds data counts '
                                          f'of train {train_id}')

                masks[train_id] = np.zeros(count, dtype=bool)
                masks[train_id][entry_sel] = True

        elif np.issubdtype(type(entry_sel[0]), bool):
            mask_len = len(entry_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif mask_len != count:
                    raise ReduceInitError(f'Mask length mismatch for '
                                          f'train {train_id}')

                masks[train_id] = entry_sel

        else:
            raise ReduceInitError('Unknown entry mask format')

        return masks

    # Public API

    def write_collection(self, output_path):
        outp_data = self._data.select([(s, '*') for s in self._sources])

        # Collect all items (combination of data category and
        # aggregator) and the sources they contain.
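        # An item key may look like ('RAW', 'DA01') or ('PROC', 'AGIPD00');
        # these category and aggregator names are purely illustrative.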
        sources_by_item = defaultdict(list)

        for source in self._sources:
            sd = outp_data[source]
            sources_by_item[(sd.data_category, sd.aggregator)].append(source)

        for (data_category, aggregator), sources in sources_by_item.items():
            self.write_item(
                output_path, sources, f'{data_category}-{aggregator}',
                dict(data_category=data_category, aggregator=aggregator))

    def write_collapsed(self, output_path):
        self.write_item(output_path, self._sources, 'COLLAPSED')

    def write_voview(self, output_path):
        raise NotImplementedError('voview output layout')

    def write_item(self, output_path, source_names, name, filename_fields={}):
        """Write sources to a single item."""

        # Select output data down to what's in this item both in terms
        # of sources and trains (via require_any).
        item_data = self._data.select({
            s: self._custom_keys[s] if s in self._custom_keys else set()
            for s in source_names}, require_any=True)

        # Switch to representation of SourceData objects for
        # per-source tracking of trains.
        item_sources = [item_data[source]
                        for source in item_data.all_sources]

        # Determine input sequence length if no explicit value was
        # given for output.
        if self._sequence_len < 1:
            sequence_len = max({
                len(sd._get_first_source_file().train_ids)
                for sd in item_sources})
        else:
            sequence_len = self._sequence_len

        # Apply custom train selections, if any.
        for i, sd in enumerate(item_sources):
            train_sel = self._custom_trains.get(sd.source, None)

            if train_sel is not None:
                item_sources[i] = sd.select_trains(by_id[train_sel])

        # Find the union of trains across all sources as total trains
        # for this item.
        item_train_ids = np.zeros(0, dtype=np.uint64)
        for sd in item_sources:
            item_train_ids = np.union1d(
                item_train_ids, sd.drop_empty_trains().train_ids)

        num_trains = len(item_train_ids)
        num_sequences = int(np.ceil(num_trains / sequence_len))

        self.log.info(
            f'{name} containing {len(item_sources)} sources with '
            f'{num_trains} trains over {num_sequences} sequences')

        for seq_no in range(num_sequences):
            seq_slice = np.s_[
                (seq_no * sequence_len):((seq_no + 1) * sequence_len)]

            # Slice out the train IDs and timestamps for this sequence.
            seq_train_ids = item_train_ids[seq_slice]

            # Select item data down to what's in this sequence.
            seq_sources = [sd.select_trains(by_id[seq_train_ids])
                           for sd in item_sources]

            # Build explicit output path for this sequence.
            seq_path = Path(str(output_path).format(
                run=self.run_number, sequence=seq_no, **filename_fields))

            self.log.debug(f'{seq_path.stem} containing {len(seq_sources)} '
                           f'sources with {len(seq_train_ids)} trains')

            self.write_sequence(seq_path, seq_sources, seq_no)

    # SourceDataWriter hooks.

    def write_base(self, f, sources, sequence):
        super().write_base(f, sources, sequence)

        # Add reduction-specific METADATA.
        red_group = f.require_group('METADATA/reduction')

        for name, method in self._methods.items():
            ops = np.array([
                '\t'.join([str(x) for x in op]).encode('ascii')
                for op in method])

            red_group.create_dataset(name, shape=len(method), data=ops)

    def get_data_format_version(self):
        return str(self._version)

    def with_origin(self):
        return self._version >= Version('1.2')

    def with_attrs(self):
        return self._version >= Version('1.3')

    def create_instrument_key(self, source, key, orig_dset, kwargs):
        # Keys are guaranteed to never use both custom chunking and
        # compression.
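        # Chunk spec semantics, illustrated: a None entry keeps the
        # original chunk length for that axis, while a single -1 entry
        # is inferred from the remaining axes so that the number of
        # elements per chunk stays close to the original. E.g. original
        # chunks (1, 512, 1024) rechunked with (16, -1, None) become
        # (16, 32, 1024).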
        sourcekey = source, key

        if sourcekey in self._rechunked_keys:
            orig_chunks = kwargs['chunks']

            chunks = list(self._rechunked_keys[sourcekey])
            assert len(chunks) == len(orig_chunks)

            for i, dim_len in enumerate(chunks):
                if dim_len is None:
                    chunks[i] = orig_chunks[i]

            if -1 in chunks:
                chunks[chunks.index(-1)] = \
                    np.prod(orig_chunks) // -np.prod(chunks)

            kwargs['chunks'] = tuple(chunks)

        elif sourcekey in self._compressed_keys or orig_dset.compression:
            # TODO: Maintain more of existing properties, for now it is
            # forced to use gzip and (1, *entry) chunking.
            kwargs['chunks'] = (1,) + kwargs['shape'][1:]
            kwargs['shuffle'] = True
            kwargs['compression'] = 'gzip'
            kwargs['compression_opts'] = self._compressed_keys.setdefault(
                sourcekey, orig_dset.compression_opts)

        return kwargs

    def mask_instrument_data(self, source, index_group, train_ids, counts):
        if source in self._custom_xtdf_masks and index_group == 'image':
            custom_masks = self._custom_xtdf_masks[source]
        elif (source, index_group) in self._custom_entry_masks:
            custom_masks = self._custom_entry_masks[source, index_group]
        else:
            return  # None efficiently selects all entries.

        masks = []

        for train_id, count_all in zip(train_ids, counts):
            if train_id in custom_masks:
                mask = custom_masks[train_id]
            else:
                mask = np.ones(count_all, dtype=bool)

            masks.append(mask)

        if source in self._custom_xtdf_masks:
            # Sources are guaranteed to never use both XTDF and general
            # entry slicing. In the XTDF case, the new data counts for
            # the image index group must be determined to be filled
            # into the respective header field.
            self._custom_xtdf_counts[source] = {
                train_id: mask.sum() for train_id, mask
                in zip(train_ids, masks) if mask.any()}

        return masks

    def copy_instrument_data(self, source, key, dest, train_ids, data):
        if source in self._custom_xtdf_counts and key == 'header.pulseCount':
            custom_counts = self._custom_xtdf_counts[source]

            for i, train_id in enumerate(train_ids):
                data[i] = custom_counts.get(train_id, data[i])

        if (source, key) in self._subsliced_keys:
            for region in self._subsliced_keys[source, key]:
                sel = (np.s_[:], *region)
                dest[sel] = data[sel]

        elif (source, key) in self._compressed_keys:
            write_compressed_frames(
                data, dest, self._compressed_keys[source, key], 8)

        else:
            dest[:] = data

    # Reduction operation handlers.
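    # Each handler below receives the operation arguments with the
    # source (and key) globs already expanded to concrete names by
    # apply_by_source/apply_by_key. For example, an operation tuple
    # ('select-entries', source_glob, index_group, train_sel, entry_sel)
    # reaches _handle_select_entries() as
    # (source, index_group, train_sel, entry_sel). Returning False from
    # a handler skips marking the source as touched.
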
    @apply_by_source('remove-sources')
    def _handle_remove_sources(self, source):
        self._sources.remove(source)
        self.log.debug(f'Removing {source}')

    @apply_by_key('remove-keys')
    def _handle_remove_keys(self, source, key):
        if source not in self._custom_keys:
            self._custom_keys[source] = set(self._data[source].keys())

        self._custom_keys[source].remove(key)
        self.log.debug(f'Removing {source}, {key}')

    @apply_by_source('select-trains')
    def _handle_select_trains(self, source, train_sel):
        self._custom_trains[source] = select_train_ids(
            self._custom_trains.setdefault(source,
                                           list(self._data.train_ids)),
            train_sel)
        self.log.debug(f'Selecting {len(self._custom_trains[source])} trains '
                       f'for {source}')

    @apply_by_source('remove-trains')
    def _handle_remove_trains(self, source, train_sel):
        trains = self._custom_trains.setdefault(source,
                                                list(self._data.train_ids))
        trains_to_remove = select_train_ids(trains, train_sel)
        trains_to_keep = np.setdiff1d(trains, trains_to_remove)

        self._custom_trains[source] = trains_to_keep
        self.log.debug(f'Keeping {len(trains_to_keep)} trains '
                       f'for {source}')

    @apply_by_source('select-entries')
    def _handle_select_entries(self, source, idx_group, train_sel, entry_sel):
        if idx_group not in self._data[source].index_groups:
            raise ReduceInitError(
                f'{idx_group} not an index group of {source}')

        masks = self._custom_entry_masks.setdefault((source, idx_group), {})
        new_masks = self._get_entry_masks(
            source, idx_group, train_sel, entry_sel)

        for train_id, train_mask in new_masks.items():
            if train_id in masks:
                train_mask &= masks[train_id]

            masks[train_id] = train_mask

        self.log.debug(f'Applying entry selection to {source}, {idx_group}')

    @apply_by_source('select-xtdf')
    def _handle_select_xtdf(self, source, train_sel, entry_sel):
        if not source.endswith(':xtdf'):
            self.log.warning(
                f'Ignoring non-XTDF source {source} based on name')
            return False

        masks = self._custom_xtdf_masks.setdefault(source, {})
        new_masks = self._get_entry_masks(
            source, 'image', train_sel, entry_sel)

        for train_id, train_mask in new_masks.items():
            if train_id in masks:
                train_mask &= masks[train_id]

            masks[train_id] = train_mask

        self.log.debug(f'Applying XTDF selection to {source}')

    @apply_by_key('rechunk-keys')
    def _handle_rechunk_keys(self, source, key, chunking):
        if not self._data[source].is_instrument:
            # Ignore CONTROL sources.
            return False

        old_chunking = self._rechunked_keys.setdefault((source, key),
                                                       chunking)

        if old_chunking != chunking:
            raise ReduceInitError(
                f'Reduction sequence yields conflicting chunks for '
                f'{source}.{key}: {old_chunking}, {chunking}')

        self._rechunked_keys[(source, key)] = chunking
        self.log.debug(f'Rechunking {source}, {key} to {chunking}')

    @apply_by_key('subslice-keys')
    def _handle_subslice_keys(self, source, key, region):
        self._subsliced_keys.setdefault((source, key), []).append(region)
        self.log.debug(f'Subslicing {region} of {source}, {key}')

    @apply_by_key('compress-keys')
    def _handle_compress_keys(self, source, key, level):
        self._compressed_keys[source, key] = level
        self.log.debug(f'Compressing {source}, {key}')
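
# Example usage (illustrative sketch, not part of this module): a driver
# script could use ReduceWriter roughly as below. The proposal/run numbers
# and the output filename template are placeholders; operation tuples follow
# the (op_name, source_glob, ...) layout consumed by _filter_ops() above.
#
#     import numpy as np
#     from extra_data import open_run
#
#     run = open_run(proposal=700000, run=1)  # placeholder proposal/run
#     methods = {'example': [
#         ('select-trains', '*', np.s_[:500]),
#         ('compress-keys', '*:xtdf', 'image.data', 1),
#     ]}
#
#     writer = ReduceWriter(run, methods, scope='aggregators', version='1.3')
#     writer.write_collection('r{run:04d}-{aggregator}-S{sequence:05d}.h5')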