from collections import defaultdict
from pathlib import Path
import fnmatch
import logging

from packaging.version import Version
import numpy as np

from extra_data import by_id
from extra_data.read_machinery import select_train_ids

from exdf.write import SourceDataWriter


class ReduceWriter(SourceDataWriter):
    """Write reduced EXDF data based on a sequence of reduction operations."""

    log = logging.getLogger('exdf.data_reduction.ReduceWriter')

    def __init__(self, data, methods, scope, sequence_len=-1, version=None):
        self._data = data
        self._methods = methods
        self._scope = scope
        self._sequence_len = sequence_len

        metadata = self._data.run_metadata()
        input_version = Version(metadata.get('dataFormatVersion', '1.0'))

        if input_version < Version('1.0'):
            raise ValueError('Currently input files are required to be '
                             'EXDF-v1.0+')

        if version == 'same':
            self._version = input_version
        else:
            self._version = Version(version)

        try:
            self.run_number = int(metadata['runNumber'])
        except KeyError:
            raise ValueError('runNumber dataset required in input METADATA')

        self._ops = sum(methods.values(), [])

        if not self._ops:
            self.log.warning('Sum of reduction methods yielded no operations '
                             'to apply')

        self._sources = sorted(data.all_sources)
        self._touched_sources = set()

        # Only populated if trains/keys are selected/removed for sources.
        self._custom_keys = {}  # source -> set(<keys>)
        self._custom_trains = {}  # source -> list(<trains>)
        self._custom_xtdf_masks = {}  # source -> dict(train_id -> mask)
        self._custom_xtdf_counts = {}  # source -> dict(train_id -> count)
        self._custom_rows = {}  # source -> dict(train_id -> mask)
        self._rechunked_keys = {}  # (source, key) -> chunks
        self._partial_copies = {}  # (source, key) -> list(<regions>)

        # TODO: Raise error if rechunking is overwritten!
        # TODO: make partial copies a list of slices!

        # Collect reductions resulting from operations.
        for source_glob, in self._filter_ops('remove-sources'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)
                self._sources.remove(source)

        for source_glob, key_glob in self._filter_ops('remove-keys'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.setdefault(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    keys.remove(key)

        for source_glob, train_sel in self._filter_ops('select-trains'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                train_ids = self._custom_trains.setdefault(
                    source, list(self._data.train_ids))

                self._custom_trains[source] = select_train_ids(
                    train_ids, train_sel)

        for source_glob, index_group, train_sel, row_sel in self._filter_ops(
            'select-rows'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if index_group not in self._data[source].index_groups:
                    raise ValueError(f'{index_group} not index group of '
                                     f'{source}')

                self._touched_sources.add(source)
                self._custom_rows.setdefault((source, index_group), {}).update(
                    self._get_row_masks(source, index_group,
                                        train_sel, row_sel))

        for source_glob, train_sel, row_sel in self._filter_ops('select-xtdf'):
            for source in fnmatch.filter(self._sources, source_glob):
                if not source.endswith(':xtdf'):
                    # Simply ignore matches without trailing :xtdf.
                    continue

                if not self._is_xtdf_source(source):
                    # Raise exception if essentials are missing.
                    raise ValueError(
                        f'{source} is not a valid XTDF source')

                self._touched_sources.add(source)
                self._custom_xtdf_masks.setdefault(source, {}).update(
                    self._get_row_masks(source, 'image', train_sel, row_sel))

        if (
            {x[0] for x in self._custom_rows.keys()}
            & self._custom_xtdf_masks.keys()
        ):
            raise ValueError('source may not be affected by both select-rows '
                             'and select-xtdf operations')

        for source_glob, key_glob, chunking in self._filter_ops(
            'rechunk-keys'
        ):
            for source in fnmatch.filter(self._sources, source_glob):
                if not self._data[source].is_instrument:
                    raise ValueError(
                        f'rechunking keys only supported for instrument '
                        f'sources, but {source_glob} matches '
                        f'{self._data[source].section}/{source}')

                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    old_chunking = self._rechunked_keys.setdefault(
                        (source, key), chunking)

                    if old_chunking != chunking:
                        raise ValueError(
                            f'reduction sequence yields conflicting chunks '
                            f'for {source}.{key}: {old_chunking}, {chunking}')

                    self._rechunked_keys[(source, key)] = chunking

        for source_glob, key_glob, region in self._filter_ops('partial-copy'):
            for source in fnmatch.filter(self._sources, source_glob):
                self._touched_sources.add(source)

                keys = self._custom_keys.get(
                    source, set(self._data[source].keys()))

                for key in fnmatch.filter(keys, key_glob):
                    self._partial_copies.setdefault((source, key), []).append(
                        region)

        if self._scope == 'sources':
            self._sources = sorted(
                self._touched_sources.intersection(self._sources))

        elif self._scope == 'aggregators':
            touched_aggregators = {self._data[source].aggregator
                                   for source in self._touched_sources}

            self._sources = sorted(
                {source for source in self._sources
                 if (self._data[source].aggregator in touched_aggregators)})

        if not self._sources:
            raise ValueError('reduction sequence yields empty source '
                             'selection')

    def _filter_ops(self, op):
        # Return the argument tuples of all collected operations whose
        # first element matches the given operation name.
        return [args[1:] for args in self._ops if args[0] == op]

    def _is_xtdf_source(self, source):
        # An XTDF source must provide at least pulse count and image data.
        return self._data[source].keys() > {'header.pulseCount', 'image.data'}

    def _get_row_masks(self, source, index_group, train_sel, row_sel):
        train_ids = select_train_ids(
            self._custom_trains.get(source, list(self._data.train_ids)),
            train_sel)
        counts = self._data[source].select_trains(by_id[train_ids]) \
            .data_counts(index_group=index_group)
        masks = {}

        if isinstance(row_sel, slice):
            for train_id, count in counts.items():
                if count > 0:
                    masks[train_id] = np.zeros(count, dtype=bool)
                    masks[train_id][row_sel] = True

        elif np.issubdtype(type(row_sel[0]), np.integer):
            max_row = max(row_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif max_row >= count:
                    raise ValueError(
                        f'row index exceeds data counts of train {train_id}')

                masks[train_id] = np.zeros(count, dtype=bool)
                masks[train_id][row_sel] = True

        elif np.issubdtype(type(row_sel[0]), bool):
            mask_len = len(row_sel)

            for train_id, count in counts.items():
                if count == 0:
                    continue
                elif mask_len != count:
                    raise ValueError(
                        f'mask length mismatch for train {train_id}')

                masks[train_id] = row_sel

        else:
            raise ValueError('unknown row mask format')

        return masks

    def write_collection(self, output_path):
        outp_data = self._data.select([(s, '*') for s in self._sources])

        # Collect all items (combination of data category and
        # aggregator) and the sources they contain.
        sources_by_item = defaultdict(list)

        for source in self._sources:
            sd = outp_data[source]
            sources_by_item[(sd.data_category, sd.aggregator)].append(source)

        for (data_category, aggregator), sources in sources_by_item.items():
            self.write_item(
                output_path, sources, f'{data_category}-{aggregator}',
                dict(data_category=data_category, aggregator=aggregator))

    def write_collapsed(self, output_path):
        self.write_item(output_path, self._sources, 'COLLAPSED')

    def write_voview(self, output_path):
        raise NotImplementedError('voview output layout')

    def write_item(self, output_path, source_names, name, filename_fields={}):
        """Write sources to a single item."""

        # Select output data down to what's in this item both in terms
        # of sources and trains (via require_any).
        item_data = self._data.select({
            s: self._custom_keys[s] if s in self._custom_keys else set()
            for s in source_names
        }, require_any=True)

        # Switch to representation of SourceData objects for
        # per-source tracking of trains.
        item_sources = [item_data[source]
                        for source in item_data.all_sources]

        # Determine input sequence length if no explicit value was given
        # for output.
        if self._sequence_len < 1:
            sequence_len = max({
                len(sd._get_first_source_file().train_ids)
                for sd in item_sources
            })
        else:
            sequence_len = self._sequence_len

        # Apply custom train selections, if any.
        for i, sd in enumerate(item_sources):
            train_sel = self._custom_trains.get(sd.source, None)

            if train_sel is not None:
                item_sources[i] = sd.select_trains(by_id[train_sel])

        # Find the union of trains across all sources as total
        # trains for this item.
        item_train_ids = np.zeros(0, dtype=np.uint64)
        for sd in item_sources:
            item_train_ids = np.union1d(
                item_train_ids, sd.drop_empty_trains().train_ids)

        num_trains = len(item_train_ids)
        num_sequences = int(np.ceil(num_trains / sequence_len))

        self.log.info(
            f'{name} containing {len(item_sources)} sources with '
            f'{num_trains} trains over {num_sequences} sequences')

        for seq_no in range(num_sequences):
            seq_slice = np.s_[
                (seq_no * sequence_len):((seq_no + 1) * sequence_len)]

            # Slice out the train IDs and timestamps for this sequence.
            seq_train_ids = item_train_ids[seq_slice]

            # Select item data down to what's in this sequence.
            seq_sources = [sd.select_trains(by_id[seq_train_ids])
                           for sd in item_sources]

            # Build explicit output path for this sequence.
            seq_path = Path(str(output_path).format(
                run=self.run_number, sequence=seq_no, **filename_fields))

            self.log.debug(f'{seq_path.stem} containing {len(seq_sources)} '
                           f'sources with {len(seq_train_ids)} trains')

            self.write_sequence(seq_path, seq_sources, seq_no)

    # SourceDataWriter hooks.
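    # The hooks below customize SourceDataWriter output: write_base() adds
    # reduction-specific METADATA, get_data_format_version() / with_origin() /
    # with_attrs() select format features based on the output version, and
    # the *_instrument_data() hooks apply the rechunking, row masking and
    # partial-copy operations collected in __init__.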
    def write_base(self, f, sources, sequence):
        super().write_base(f, sources, sequence)

        # Add reduction-specific METADATA.
        red_group = f.require_group('METADATA/reduction')

        for name, method in self._methods.items():
            ops = np.array([
                '\t'.join([str(x) for x in op[:]]).encode('ascii')
                for op in method
            ])
            red_group.create_dataset(name, shape=len(method), data=ops)

    def get_data_format_version(self):
        return str(self._version)

    def with_origin(self):
        return self._version >= Version('1.2')

    def with_attrs(self):
        return self._version >= Version('1.3')

    def chunk_instrument_data(self, source, key, orig_chunks):
        try:
            chunks = list(self._rechunked_keys[source, key])
            assert len(chunks) == len(orig_chunks)

            for i, dim_len in enumerate(chunks):
                if dim_len is None:
                    chunks[i] = orig_chunks[i]

            if -1 in chunks:
                chunks[chunks.index(-1)] = \
                    np.prod(orig_chunks) // -np.prod(chunks)

            return tuple(chunks)
        except KeyError:
            return orig_chunks

    def mask_instrument_data(self, source, index_group, train_ids, counts):
        if source in self._custom_xtdf_masks and index_group == 'image':
            custom_masks = self._custom_xtdf_masks[source]
        elif (source, index_group) in self._custom_rows:
            custom_masks = self._custom_rows[source, index_group]
        else:
            return  # None efficiently selects all rows.

        masks = []

        for train_id, count_all in zip(train_ids, counts):
            if train_id in custom_masks:
                mask = custom_masks[train_id]
            else:
                mask = np.ones(count_all, dtype=bool)

            masks.append(mask)

        if source in self._custom_xtdf_masks:
            # Sources are guaranteed to never use both XTDF and general
            # row slicing. In the XTDF case, the new data counts for the
            # image index group must be determined to be filled into
            # the respective header field.
            self._custom_xtdf_counts[source] = {
                train_id: mask.sum() for train_id, mask
                in zip(train_ids, masks) if mask.any()}

        return masks

    def copy_instrument_data(self, source, key, dest, train_ids, data):
        if source in self._custom_xtdf_counts and key == 'header.pulseCount':
            custom_counts = self._custom_xtdf_counts[source]

            for i, train_id in enumerate(train_ids):
                data[i] = custom_counts.get(train_id, data[i])

        try:
            regions = self._partial_copies[source, key]
        except KeyError:
            dest[:] = data
        else:
            for region in regions:
                sel = (np.s_[:], *region)
                dest[sel] = data[sel]
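
# Minimal usage sketch (illustrative only): the proposal/run numbers, method
# name, operation globs and output pattern below are assumptions, not part of
# this module. Each operation is a tuple whose first element is the operation
# name checked by _filter_ops(); the remaining elements are its arguments.
#
#     from extra_data import open_run
#
#     run = open_run(proposal=700000, run=1)  # hypothetical proposal/run
#     methods = {
#         'example': [
#             ('remove-keys', '*/XGM/*', '*.intensitySa3TD'),
#             ('select-trains', '*', np.s_[:500]),
#         ]
#     }
#     writer = ReduceWriter(run, methods, scope='sources', version='1.3')
#     writer.write_collection('r{run:04d}-{aggregator}-S{sequence:05d}.h5')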