Skip to content
Snippets Groups Projects

Refactor stacking for reuse and overlappability

Merged David Hammer requested to merge refactor-stacking into master
+ 26
14
import concurrent.futures
from karabo.bound import Hash
import numpy as np
import pytest
@@ -23,7 +25,17 @@ class NotADevice:
datas = [np.arange(i * 100, i * 100 + 8).reshape(4, 2) for i in range(3)]
class TestSourceStacking:
class CommonTestFixtureGuy:
@pytest.fixture(params=[False, True])
def thread_pool(self, request):
if request.param:
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
yield pool
else:
yield None
class TestSourceStacking(CommonTestFixtureGuy):
@pytest.fixture
def friend(self):
source_table = [
@@ -56,8 +68,8 @@ class TestSourceStacking:
for i, data in enumerate(datas)
}
def test_simple(self, friend, sources):
friend.process(sources)
def test_simple(self, friend, sources, thread_pool):
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
assert stacked.shape == (4, 3, 2)
@@ -65,9 +77,9 @@ class TestSourceStacking:
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
def test_missing_source(self, friend, sources):
def test_missing_source(self, friend, sources, thread_pool):
del sources["source0"]
friend.process(sources)
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
assert np.all(stacked[:, 0] == 0)
@@ -75,9 +87,9 @@ class TestSourceStacking:
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
def test_missing_data(self, friend, sources):
def test_missing_data(self, friend, sources, thread_pool):
sources["source0"][0].erase("keyToStack")
friend.process(sources)
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
assert np.all(stacked[:, 0] == 0)
@@ -85,17 +97,17 @@ class TestSourceStacking:
assert np.array_equal(stacked[:, i], data, equal_nan=True)
assert not friend._device.warnings
def test_source_stacking_no_sources(self, friend, sources):
def test_source_stacking_no_sources(self, friend, sources, thread_pool):
sources = {}
friend.process(sources)
friend.process(sources, thread_pool)
assert friend._device.warnings
assert not sources
def test_source_stacking_erroneous_data(self, friend, sources):
def test_source_stacking_erroneous_data(self, friend, sources, thread_pool):
sources["source1"][0][
"keyToStack"
] = "and now for something completely different"
friend.process(sources)
friend.process(sources, thread_pool)
assert "newSource" in sources
stacked = sources["newSource"][0]["keyToStack"]
for i, data in enumerate(datas):
@@ -106,7 +118,7 @@ class TestSourceStacking:
assert friend._device.warnings
class TestKeyStacking:
class TestKeyStacking(CommonTestFixtureGuy):
@pytest.fixture
def friend(self):
source_table = [
@@ -138,8 +150,8 @@ class TestKeyStacking:
h[f"key{i}"] = data
return {"source": (h, None)}
def test_simple(self, friend, sources):
friend.process(sources)
def test_simple(self, friend, sources, thread_pool):
friend.process(sources, thread_pool)
assert sources["source"][0].has("newKey")
stacked = sources["source"][0]["newKey"]
for i, data in enumerate(datas):
Loading