Skip to content
Snippets Groups Projects

Refactor stacking for reuse and overlappability

Merged David Hammer requested to merge refactor-stacking into master
+ 159
0
import concurrent.futures
from karabo.bound import Hash
import numpy as np
import pytest
from calng import stacking_utils
class NotALog:
def __init__(self, parent):
self.parent = parent
def WARN(self, s):
print(f"Warning: {s}")
self.parent.warnings.append(s)
class NotADevice:
def __init__(self):
self.log = NotALog(self)
self.warnings = []
datas = [np.arange(i * 100, i * 100 + 8).reshape(4, 2) for i in range(3)]
class CommonTestFixtureGuy:
@pytest.fixture(params=[False, True])
def thread_pool(self, request):
if request.param:
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
yield pool
else:
yield None
class TestSourceStacking(CommonTestFixtureGuy):
@pytest.fixture
def friend(self):
source_table = [
{
"select": True,
"source": f"source{i}@device{i}:channel",
}
for i in range(3)
]
merge_rules = [
{
"select": True,
"sourcePattern": "source\\d+",
"keyPattern": "keyToStack",
"replacement": "newSource",
"groupType": "sources",
"mergeMethod": "stack",
"axis": 1,
"missingValue": "0",
}
]
device = NotADevice()
return stacking_utils.StackingFriend(device, source_table, merge_rules)
@pytest.fixture
def sources(self):
return {
f"source{i}": (Hash("keyToStack", data), None)
for i, data in enumerate(datas)
}
def test_simple(self, friend, sources, thread_pool):
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
assert stacked.shape == (4, 3, 2)
for i, data in enumerate(datas):
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
def test_missing_source(self, friend, sources, thread_pool):
del sources["source0"]
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
assert np.all(stacked[:, 0] == 0)
for i, data in enumerate(datas[1:], start=1):
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
def test_missing_data(self, friend, sources, thread_pool):
sources["source0"][0].erase("keyToStack")
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
assert np.all(stacked[:, 0] == 0)
for i, data in enumerate(datas[1:], start=1):
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
def test_source_stacking_no_sources(self, friend, sources, thread_pool):
sources = {}
friend.process(sources, thread_pool)
assert friend._device.warnings
assert not sources
def test_source_stacking_erroneous_data(self, friend, sources, thread_pool):
sources["source1"][0][
"keyToStack"
] = "and now for something completely different"
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
for i, data in enumerate(datas):
if i == 1:
assert np.all(stacked[:, i] == 0)
else:
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert friend._device.warnings
class TestKeyStacking(CommonTestFixtureGuy):
@pytest.fixture
def friend(self):
source_table = [
{
"select": True,
"source": "source@device:channel",
}
]
merge_rules = [
{
"select": True,
"sourcePattern": "source",
"keyPattern": "key\\d+",
"replacement": "newKey",
"groupType": "keys",
"mergeMethod": "stack",
"axis": 1,
"missingValue": "0",
}
]
device = NotADevice()
return stacking_utils.StackingFriend(device, source_table, merge_rules)
@pytest.fixture
def sources(self):
h = Hash()
for i, data in enumerate(datas):
h[f"key{i}"] = data
return {"source": (h, None)}
def test_simple(self, friend, sources, thread_pool):
friend.process(sources, thread_pool)
assert sources["source"][0].has("newKey")
stacked = sources["source"][0]["newKey"]
for i, data in enumerate(datas):
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
Loading