
Refactor stacking for reuse and overlappability

Merged David Hammer requested to merge refactor-stacking into master
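For context, a rough usage sketch of the refactored interface as it appears in the diff below; the device, config tables, and sources dict here are placeholders rather than code from this MR:

# StackingFriend is constructed with the device and the two config tables and can
# be reconfigured later; passing None for either table keeps the previous one.
friend = StackingFriend(device, source_config, merge_config)
friend.reconfigure(source_config, None)

# sources maps source name -> (Hash, timestamp). process() fills the merge
# buffers, adds the new stacked sources to the dict, and erases the keys it consumed.
friend.process(sources, thread_pool=thread_pool)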
Files changed: 2 (+124 −117)
import collections
import concurrent.futures
import enum
import functools
import re
from karabo.bound import (
@@ -139,41 +141,42 @@ class StackingFriend:
)
def __init__(self, device, source_config, merge_config):
# (old source, new source, key) -> (index (maybe index_exp), axis)
self._source_stacking_indices = {}
self._source_stacking_sources = collections.defaultdict(list)
self._source_stacking_group_sizes = {}
# (new source name, key) -> {original sources used}
# list of (old source, key, new source)
self._source_stacking_parts = []
# (new source, key, method, group size, axis, missing) -> {original sources}
self._new_sources_inputs = collections.defaultdict(set)
self._key_stacking_sources = collections.defaultdict(list)
self._merge_config = None
self._source_config = None
self._device = device
self.reconfigure(merge_config, source_config)
self.reconfigure(source_config, merge_config)
def reconfigure(self, merge_config, source_config):
def reconfigure(self, source_config, merge_config):
print("merge_config", type(merge_config))
print("source_config", type(source_config))
if merge_config is not None:
self._merge_config = merge_config
if source_config is not None:
self._source_config = source_config
if merge_config is not None:
self._merge_config = merge_config
self._source_stacking_indices.clear()
self._source_stacking_sources.clear()
self._source_stacking_group_sizes.clear()
self._key_stacking_sources.clear()
self._source_stacking_parts.clear()
self._new_sources_inputs.clear()
self._key_stacking_sources.clear()
# not filtering by row["select"] to allow unselected sources to create gaps
source_names = [row["source"].partition("@")[0] for row in self._source_config]
source_stacking_groups = [
row
for row in self._merge_config
if row["select"] and row["groupType"] == GroupType.MULTISOURCE.name
if row["select"] and row["groupType"] == GroupType.MULTISOURCE.value
]
key_stacking_groups = [
row
for row in self._merge_config
if row["select"] and row["groupType"] == GroupType.MULTIKEY.name
if row["select"] and row["groupType"] == GroupType.MULTIKEY.value
]
for row in source_stacking_groups:
@@ -181,6 +184,7 @@ class StackingFriend:
merge_method = MergeMethod(row["mergeMethod"])
key = row["keyPattern"]
new_source = row["replacement"]
axis = row["axis"]
merge_sources = [
source for source in source_names if source_re.match(source)
]
@@ -189,11 +193,7 @@ class StackingFriend:
f"Group pattern {source_re} did not match any known sources"
)
continue
self._source_stacking_group_sizes[(new_source, key)] = len(merge_sources)
for i, source in enumerate(merge_sources):
self._source_stacking_sources[source].append(
(key, new_source, merge_method, row["axis"])
)
self._source_stacking_indices[(source, new_source, key)] = (
i
if merge_method is MergeMethod.STACK
@@ -204,7 +204,18 @@ class StackingFriend:
len(merge_sources),
)
]
)
), axis
self._new_sources_inputs[
(
new_source,
key,
merge_method,
len(merge_sources),
axis,
row["missingValue"],
)
].add(source)
self._source_stacking_parts.append((source, key, new_source))
for row in key_stacking_groups:
key_re = re.compile(row["keyPattern"])
@@ -214,25 +225,27 @@ class StackingFriend:
)
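To illustrate the bookkeeping that reconfigure() builds up above, here is a hedged sketch with invented source names: a single MULTISOURCE group matching sources "det/0" and "det/1" on key "image.data", replacement "det/stacked", MergeMethod.STACK along axis 0.

# Hypothetical contents after reconfigure() under those assumptions:
# _source_stacking_indices[("det/0", "det/stacked", "image.data")] == (0, 0)  # (index, axis)
# _source_stacking_indices[("det/1", "det/stacked", "image.data")] == (1, 0)
# _source_stacking_parts == [("det/0", "image.data", "det/stacked"),
#                            ("det/1", "image.data", "det/stacked")]
# _new_sources_inputs[("det/stacked", "image.data", MergeMethod.STACK, 2, 0, row["missingValue"])]
#     == {"det/0", "det/1"}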
def process(self, sources, thread_pool=None):
stacking_data_shapes = {}
stacking_buffers = {}
new_source_map = collections.defaultdict(Hash)
merge_buffers = {}
new_source_map = {}
missing_value_defaults = {}
# prepare for source stacking where sources are present
source_set = set(sources.keys())
for (
new_source,
data_key,
merge_method,
group_size,
axis,
missing_value,
missing_value_str,
), original_sources in self._new_sources_inputs.items():
for present_source in source_set & original_sources:
data = sources[present_source].get(data_key)[0]
for source in original_sources:
if source not in sources:
continue
data_hash, timestamp = sources[source]
data = data_hash.get(data_key)
if data is None:
continue
# didn't hit continue => first source in group present and with data
if merge_method is MergeMethod.STACK:
expected_shape = utils.stacking_buffer_shape(
data.shape, group_size, axis=axis
@@ -241,119 +254,113 @@ class StackingFriend:
expected_shape = utils.interleaving_buffer_shape(
data.shape, group_size, axis=axis
)
stacking_buffer = np.empty(
expected_shape, dtype=data.dtype
)
stacking_buffers[(new_source, data_key)] = stacking_buffer
new_source_map[new_source][data_key] = stacking_buffer
merge_buffer = np.empty(expected_shape, dtype=data.dtype)
merge_buffers[(new_source, data_key)] = merge_buffer
if new_source not in new_source_map:
new_source_map[new_source] = (Hash(), timestamp)
new_source_map[new_source][0][data_key] = merge_buffer
try:
missing_value_defaults[(new_source, data_key)] = data.dtype.type(
missing_value
)
missing_value = data.dtype.type(missing_value_str)
except ValueError:
self._device.log.WARN(
f"Invalid missing data value for {new_source}.{data_key}, using 0"
f"Invalid missing data value for {new_source}.{data_key} "
f"(tried making a {data.dtype} out of "
f" '{missing_value_str}', using 0"
)
missing_value = 0
missing_value_defaults[(new_source, data_key)] = missing_value
# now we have set up the buffer for this new source, so break
break
else:
# in this case: no present_source (if any) had data_key
# in this case: no source (if any) had data_key
self._device.log.WARN(
f"No sources needed for {new_source}.{data_key} were present"
)
for (new_source, key), attr in stacking_data_shapes.items():
merge_data_shape, merge_method, axis, dtype, timestamp = attr
group_size = parent._source_stacking_group_sizes[(new_source, key)]
merge_buffer = self._stacking_buffers.get((new_source, key))
if merge_buffer is None or merge_buffer.shape != expected_shape:
merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
self._stacking_buffers[(new_source, key)] = merge_buffer
for merge_index in self._fill_missed_data.get((new_source, key), []):
utils.set_on_axis(
merge_buffer,
0,
merge_index
if merge_method is MergeMethod.STACK
else np.index_exp[
slice(
merge_index,
None,
parent._source_stacking_group_sizes[(new_source, key)],
)
],
axis,
)
if new_source not in self.new_source_map:
self.new_source_map[new_source] = (Hash(), timestamp)
new_source_hash = self.new_source_map[new_source][0]
if not new_source_hash.has(key):
new_source_hash[key] = merge_buffer
# now actually do some work
fun = functools.partial(self._handle_source, ...)
fun = functools.partial(
self._handle_source_stacking_part,
merge_buffers,
missing_value_defaults,
sources,
)
if thread_pool is None:
for _ in map(fun, ...):
for _ in map(fun, self._source_stacking_parts):
pass
else:
concurrent.futures.wait(thread_pool.map(fun, ...))
concurrent.futures.wait(thread_pool.map(fun, self._source_stacking_parts))
# note: new source names for group stacking may not match existing source names
sources.update(new_source_map)
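For reference, the two merge methods handled in process() lay data out differently in the preallocated buffer. The following plain-numpy sketch shows the assumed layout; the project's utils.stacking_buffer_shape / utils.interleaving_buffer_shape / utils.set_on_axis helpers are taken to behave this way, and the shapes are made up:

import numpy as np

group_size = 3
part = np.ones((2, 4))  # data from one source, stacking along axis 0

# MergeMethod.STACK: a new dimension of length group_size is inserted at the axis;
# the source with stacking index 1 owns slot 1.
stack_buf = np.empty((group_size, 2, 4), dtype=part.dtype)
stack_buf[1] = part

# Interleaving: the existing axis grows by a factor of group_size; the source with
# stacking index 1 owns the strided slice, as in np.index_exp[slice(1, None, group_size)].
inter_buf = np.empty((2 * group_size, 4), dtype=part.dtype)
inter_buf[1::group_size] = part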
def _handle_expected_source(self, merge_buffers, missing_values, actual_sources, expected_source):
def _handle_source_stacking_part(
self,
merge_buffers,
missing_values,
actual_sources,
stacking_triple,
):
"""Helper function used in processing. Note that it should be called for each
source that was supposed to be there - so it can decide whether to move data in
for stacking, fill missing data in case a source is missing, or skip in case no
buffer was created (none of the necessary sources were present)."""
(original source, data key, new source) triple expected in the stacking scheme
regardless of whether that original source is present - so it can decide whether
to move data in for stacking, fill missing data in case a source is missing,
or skip in case no buffer was created (none of the necessary sources were
present)."""
if expected_source not in actual_sources:
for (
key,
new_source,
merge_method,
axis,
) in self._source_stacking_sources.get(source, ()):
if (new_source, key) in self._ignore_stacking:
continue
merge_data = data_hash.get(key)
merge_index = self._source_stacking_indices[(source, new_source, key)]
merge_buffer = self._stacking_buffers[(new_source, key)]
old_source, data_key, new_source = stacking_triple
if (new_source, data_key) not in merge_buffers:
# preparatory step did not create buffer
# (i.e. no sources from this group were present)
return
merge_buffer = merge_buffers[(new_source, data_key)]
if old_source in actual_sources:
data_hash = actual_sources[old_source][0]
else:
data_hash = None
if data_hash is None or (merge_data := data_hash.get(data_key)) is None:
merge_data = missing_values[(new_source, data_key)]
try:
utils.set_on_axis(
merge_buffer,
merge_data,
merge_index,
axis,
*self._source_stacking_indices[(old_source, new_source, data_key)],
)
if data_hash is not None:
data_hash.erase(data_key)
except Exception as ex:
self._device.log.WARN(
f"Failed to stack {data_key} from {old_source} into {new_source}: {ex}"
)
data_hash.erase(key)
# stack keys (multiple keys within this source)
for key_re, new_key, merge_method, axis in self._key_stacking_sources.get(
source, ()
):
# note: please no overlap between different key_re
# note: if later key_re match earlier new_key, this gets spicy
keys = [key for key in data_hash.paths() if key_re.match(key)]
try:
# note: maybe we could reuse buffers here, too?
if merge_method is MergeMethod.STACK:
stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
else:
first = data_hash.get(keys[0])
stacked = np.empty(
shape=utils.stacking_buffer_shape(
first.shape, len(keys), axis=axis
),
dtype=first.dtype,
)
for i, key in enumerate(keys):
utils.set_on_axis(
stacked,
data_hash.get(key),
np.index_exp[slice(i, None, len(keys))],
axis,
)
except Exception as e:
self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
def _handle_key_stacking_entry(
self, source, data_hash, key_re, new_key, merge_method, axis
):
# note: please no overlap between different key_re
# note: if later key_re match earlier new_key, this gets spicy
keys = [key for key in data_hash.paths() if key_re.match(key)]
try:
# note: maybe we could reuse buffers here, too?
if merge_method is MergeMethod.STACK:
stacked = np.stack([data_hash.get(key) for key in keys], axis=axis)
else:
for key in keys:
data_hash.erase(key)
data_hash[new_key] = stacked
first = data_hash.get(keys[0])
stacked = np.empty(
shape=utils.stacking_buffer_shape(
first.shape, len(keys), axis=axis
),
dtype=first.dtype,
)
for i, key in enumerate(keys):
utils.set_on_axis(
stacked,
data_hash.get(key),
np.index_exp[slice(i, None, len(keys))],
axis,
)
except Exception as e:
self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
else:
for key in keys:
data_hash.erase(key)
data_hash[new_key] = stacked
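_handle_key_stacking_entry combines several keys of a single source into one new key. A minimal sketch of the effect, using a plain dict instead of a Karabo Hash and invented key names:

import re
import numpy as np

data = {"modules.0.adc": np.zeros((16,)), "modules.1.adc": np.ones((16,))}
key_re = re.compile(r"modules\.\d+\.adc")

keys = [key for key in data if key_re.match(key)]
stacked = np.stack([data[key] for key in keys], axis=0)  # MergeMethod.STACK, axis=0
for key in keys:
    del data[key]          # originals are erased, as in the method above
data["modules.adc"] = stacked  # shape (2, 16) under the new key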