Refactor stacking for reuse and overlappability

Merged David Hammer requested to merge refactor-stacking into master
Files: 2 (+400, −0)
import collections
import concurrent.futures
import enum
import re
from karabo.bound import (
BOOL_ELEMENT,
INT32_ELEMENT,
STRING_ELEMENT,
TABLE_ELEMENT,
Hash,
Schema,
State,
)
import numpy as np
from . import utils
class GroupType(enum.Enum):
MULTISOURCE = "sources" # same key stacked from multiple sources in new source
MULTIKEY = "keys" # multiple keys within each matched source is stacked in new key
class MergeMethod(enum.Enum):
STACK = "stack"
INTERLEAVE = "interleave"
def merge_schema():
schema = Schema()
(
BOOL_ELEMENT(schema)
.key("select")
.displayedName("Select")
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("sourcePattern")
.displayedName("Source pattern")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("keyPattern")
.displayedName("Key pattern")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("replacement")
.displayedName("Replacement")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("groupType")
.displayedName("Group type")
.options(",".join(option.value for option in GroupType))
.assignmentOptional()
.defaultValue(GroupType.MULTISOURCE.value)
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("mergeMethod")
.displayedName("Merge method")
.options(",".join(option.value for option in MergeMethod))
.assignmentOptional()
.defaultValue(MergeMethod.STACK.value)
.reconfigurable()
.commit(),
INT32_ELEMENT(schema)
.key("axis")
.displayedName("Axis")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
STRING_ELEMENT(schema)
.key("missingValue")
.displayedName("Missing value default")
.description(
"If some sources are missing within one group in multi-source stacking*, "
"the corresponding parts of the resulting stacked array will be set to "
"this value. Note that if no sources are present, the array is not created "
"at all. This field is a string to allow special values like float nan / "
"inf; it is your responsibility to make sure that data types match (i.e. "
"if this is 'nan', the stacked data better be floats or doubles). *Missing "
"value handling is not yet implementedo for multi-key stacking."
)
.assignmentOptional()
.defaultValue("0")
.reconfigurable()
.commit(),
)
return schema
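

# Sketch of one row for the table this schema describes, as it might be built
# programmatically (the column keys are the ones defined above; the pattern and
# replacement values here are made up for illustration):
#
#     row = Hash(
#         "select", True,
#         "sourcePattern", r"DET/DATA/(\d+)",  # hypothetical source names
#         "keyPattern", "image.data",
#         "replacement", "DET/DATA/STACKED",   # hypothetical new source name
#         "groupType", GroupType.MULTISOURCE.value,
#         "mergeMethod", MergeMethod.STACK.value,
#         "axis", 0,
#         "missingValue", "nan",
#     )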
class StackingFriend:
@staticmethod
def add_schema(schema):
(
TABLE_ELEMENT(schema)
.key("merge")
.displayedName("Array stacking")
.allowedStates(State.PASSIVE)
.description(
"Specify which source(s) or key(s) to stack or interleave."
"When stacking sources, the 'Source pattern' is interpreted as a "
"regular expression and the 'Key pattern' is interpreted as an "
"ordinary string. From all sources matching the source pattern, the "
"data under this key (should be array with same dimensions across all "
"stacked sources) is stacked in the same order as the sources are "
"listed in 'Data sources' and the result is under the same key name in "
"a new source named by 'Replacement'. "
"When stacking keys, both the 'Source pattern' and the 'Key pattern' "
"are regular expressions. Within each source matching the source "
"pattern, all keys matching the key pattern are stacked and the result "
"is put under the key named by 'Replacement'. "
"While source stacking is optimized and can use thread pool, key "
"stacking will iterate over all paths in matched sources and naively "
"call np.stack for each key pattern. In either case, data that is used "
"for stacking is removed from its original location (e.g. key is "
"erased from hash). "
"In both cases, the data can alternatively be interleaved. This is "
"essentially equivalent to stacking except followed by a reshape such "
"that the output is shaped like concatenation."
)
.setColumns(merge_schema())
.assignmentOptional()
.defaultValue([])
.reconfigurable()
.commit(),
)
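
    # Typical call site (sketch): a device using this friend class would inject
    # the table into its static schema, e.g. in its expectedParameters hook:
    #
    #     @staticmethod
    #     def expectedParameters(expected):
    #         StackingFriend.add_schema(expected)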
def __init__(self, device, source_config, merge_config):
# used during pre-processing to set up buffers
# (new source, key, method, group size, axis, missing) -> {original sources}
self._new_sources_inputs = collections.defaultdict(set)
# used for source stacking
# (old source, new source, key) -> (index (maybe index_exp), axis)
self._source_stacking_indices = {}
# list of (old source, key, new source)
self._source_stacking_parts = []
# used for key stacking
# list of entries used in _handle_key_stacking_entry
self._key_stacking_entries = []
self._merge_config = None
self._source_config = None
self._device = device
self.reconfigure(source_config, merge_config)
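
    # Example of the lookup state after reconfigure (illustrative names): with
    # sources "a" and "b" stacked along axis 0 under key "image.data" into a
    # new source "merged",
    #
    #     _source_stacking_indices == {
    #         ("a", "merged", "image.data"): (0, 0),  # (index, axis)
    #         ("b", "merged", "image.data"): (1, 0),
    #     }
    #     _source_stacking_parts == [
    #         ("a", "image.data", "merged"),
    #         ("b", "image.data", "merged"),
    #     ]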
def reconfigure(self, source_config, merge_config):
print("merge_config", type(merge_config))
print("source_config", type(source_config))
if source_config is not None:
self._source_config = source_config
if merge_config is not None:
self._merge_config = merge_config
self._source_stacking_indices.clear()
self._source_stacking_parts.clear()
self._new_sources_inputs.clear()
self._key_stacking_entries.clear()
# not filtering by row["select"] to allow unselected sources to create gaps
source_names = [row["source"].partition("@")[0] for row in self._source_config]
for row in self._merge_config:
if not row["select"]:
continue
source_re = re.compile(row["sourcePattern"])
group_type = GroupType(row["groupType"])
merge_method = MergeMethod(row["mergeMethod"])
axis = row["axis"]
merge_sources = [
source for source in source_names if source_re.match(source)
]
if not merge_sources:
                self._device.log.WARN(
                    f"Source pattern {source_re.pattern} did not match any known sources"
                )
continue
if group_type is GroupType.MULTISOURCE:
key = row["keyPattern"]
new_source = row["replacement"]
missing = row["missingValue"]
for i, source in enumerate(merge_sources):
                    if merge_method is MergeMethod.STACK:
                        index = i
                    else:
                        # interleaving: every len(merge_sources)-th slot, offset by i
                        index = np.index_exp[slice(i, None, len(merge_sources))]
                    self._source_stacking_indices[(source, new_source, key)] = (
                        index,
                        axis,
                    )
self._new_sources_inputs[
(
new_source,
key,
merge_method,
len(merge_sources),
axis,
missing,
)
].add(source)
self._source_stacking_parts.append((source, key, new_source))
else:
key_re = re.compile(row["keyPattern"])
new_key = row["replacement"]
for source in merge_sources:
self._key_stacking_entries.append(
(source, key_re, new_key, merge_method, axis)
)
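
    # Sketch of the interleaving index stored above: for source i of n, the
    # index expression selects every n-th slot along the merge axis, starting
    # at i. For i=1, n=3 on a flat buffer:
    #
    #     >>> import numpy as np
    #     >>> buf = np.zeros(6)
    #     >>> buf[np.index_exp[slice(1, None, 3)]] = 7
    #     >>> buf
    #     array([0., 7., 0., 0., 7., 0.])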
def process(self, sources, thread_pool=None):
merge_buffers = {}
new_source_map = {}
missing_value_defaults = {}
# prepare for source stacking where sources are present
for (
new_source,
data_key,
merge_method,
group_size,
axis,
missing_value_str,
), original_sources in self._new_sources_inputs.items():
for source in original_sources:
if source not in sources:
continue
data_hash, timestamp = sources[source]
data = data_hash.get(data_key)
if not isinstance(data, np.ndarray):
continue
# didn't hit continue => first source in group present and with data
if merge_method is MergeMethod.STACK:
expected_shape = utils.stacking_buffer_shape(
data.shape, group_size, axis=axis
)
else:
expected_shape = utils.interleaving_buffer_shape(
data.shape, group_size, axis=axis
)
merge_buffer = np.empty(expected_shape, dtype=data.dtype)
merge_buffers[(new_source, data_key)] = merge_buffer
if new_source not in new_source_map:
new_source_map[new_source] = (Hash(), timestamp)
new_source_map[new_source][0][data_key] = merge_buffer
try:
missing_value = data.dtype.type(missing_value_str)
except ValueError:
                    self._device.log.WARN(
                        f"Invalid missing data value for {new_source}.{data_key} "
                        f"(tried making a {data.dtype} out of "
                        f"'{missing_value_str}'), using 0"
                    )
missing_value = 0
missing_value_defaults[(new_source, data_key)] = missing_value
# now we have set up the buffer for this new source, so break
break
else:
                # for/else: no source in this group provided usable data
                self._device.log.WARN(
                    f"None of the sources needed for {new_source}.{data_key} "
                    f"were present"
                )
# now actually do some work
        if thread_pool is None:
            for part in self._source_stacking_parts:
                self._handle_source_stacking_part(
                    merge_buffers, missing_value_defaults, sources, *part
                )
            for entry in self._key_stacking_entries:
                self._handle_key_stacking_entry(sources, *entry)
        else:
            awaitables = []
            for part in self._source_stacking_parts:
                awaitables.append(
                    thread_pool.submit(
                        self._handle_source_stacking_part,
                        merge_buffers,
                        missing_value_defaults,
                        sources,
                        *part,
                    )
                )
            for entry in self._key_stacking_entries:
                awaitables.append(
                    thread_pool.submit(
                        self._handle_key_stacking_entry, sources, *entry
                    )
                )
            concurrent.futures.wait(awaitables)
        # note: new source names for group stacking are not checked against
        # existing source names; a collision would overwrite the original here
sources.update(new_source_map)
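
    # Usage sketch (names are illustrative): `sources` maps source name to a
    # (data hash, timestamp) pair and is modified in place; stacked inputs are
    # erased from their original hashes and the merged sources are added.
    #
    #     friend = StackingFriend(device, source_config, merge_config)
    #     with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    #         friend.process(sources, thread_pool=pool)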
def _handle_source_stacking_part(
self,
merge_buffers,
missing_values,
actual_sources,
old_source,
data_key,
new_source,
):
"""Helper function used in processing. Note that it should be called for each
(original source, data key, new source) triple expected in the stacking scheme
regardless of whether that original source is present - so it can decide whether
to move data in for stacking, fill missing data in case a source is missing,
or skip in case no buffer was created (none of the necessary sources were
present)."""
if (new_source, data_key) not in merge_buffers:
# preparatory step did not create buffer
# (i.e. no sources from this group were present)
return
merge_buffer = merge_buffers[(new_source, data_key)]
if old_source in actual_sources:
data_hash = actual_sources[old_source][0]
else:
data_hash = None
if data_hash is None or (merge_data := data_hash.get(data_key)) is None:
merge_data = missing_values[(new_source, data_key)]
merge_index, merge_axis = self._source_stacking_indices[
(old_source, new_source, data_key)
]
try:
utils.set_on_axis(merge_buffer, merge_data, merge_index, merge_axis)
if data_hash is not None:
data_hash.erase(data_key)
except Exception as ex:
self._device.log.WARN(
f"Failed to stack {data_key} from {old_source} into {new_source}: {ex}"
)
utils.set_on_axis(
merge_buffer,
missing_values[(new_source, data_key)],
merge_index,
merge_axis,
)
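
    # `utils.set_on_axis` is not defined in this file; judging from its usage
    # here, it behaves roughly like this sketch (an assumption, not the actual
    # helper): assign `data` into `buffer` at `index` along `axis`, where
    # `index` may be a plain integer or an index expression tuple.
    #
    #     def set_on_axis(buffer, data, index, axis):
    #         if not isinstance(index, tuple):
    #             index = (index,)
    #         buffer[(slice(None),) * axis + index] = data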
def _handle_key_stacking_entry(
self, sources, source, key_re, new_key, merge_method, axis
):
        # note: different key_re patterns should not match overlapping sets of keys
        # note: if a later key_re matches an earlier new_key, the results are
        # order-dependent and probably not what you want
if source not in sources:
self._device.log.WARN(f"Source {source} not found for key stacking")
return
data_hash = sources[source][0]
keys = [key for key in data_hash.paths() if key_re.match(key)]
if not keys:
            self._device.log.WARN(
                f"Source {source} had no keys matching {key_re.pattern} for stacking"
            )
return
try:
# note: maybe we could reuse buffers here, too?
if merge_method is MergeMethod.STACK:
stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
else:
first = data_hash.get(keys[0])
stacked = np.empty(
shape=utils.stacking_buffer_shape(
first.shape, len(keys), axis=axis
),
dtype=first.dtype,
)
for i, key in enumerate(keys):
utils.set_on_axis(
stacked,
data_hash.get(key),
np.index_exp[slice(i, None, len(keys))],
axis,
)
except Exception as e:
self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
else:
for key in keys:
data_hash.erase(key)
data_hash[new_key] = stacked
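
    # Key-stacking sketch (illustrative): with a key_re matching "image.cell0"
    # and "image.cell1" within one source and new_key "image.stacked", the two
    # (2, 2) arrays are replaced by one stacked array:
    #
    #     >>> import numpy as np
    #     >>> np.stack([np.zeros((2, 2)), np.ones((2, 2))], axis=0).shape
    #     (2, 2, 2)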