Skip to content
Snippets Groups Projects
Commit 62f3038e authored by David Hammer's avatar David Hammer
Browse files

Merge branch 'use-multiprocessing-shared-memory' into 'master'

Use multiprocessing.shared_memory instead of posixshmem

See merge request !83
parents 487bf37c 77f74e79
No related branches found
No related tags found
1 merge request: !83 "Use multiprocessing.shared_memory instead of posixshmem"
......@@ -17,7 +17,4 @@ h5py:
extra-geom:
$(PYPI) extra_geom==1.11.0
posixshmem:
$(PROXIED) "git+https://github.com/European-XFEL/posixshmem"
calng: cupy jinja2 h5py extra-geom posixshmem
calng: cupy jinja2 h5py extra-geom
import multiprocessing.shared_memory
import numpy as np
import posixshmem
from . import utils
......@@ -14,37 +15,56 @@ def parse_shmem_handle(handle_string):
def open_shmem_from_handle(handle_string):
    """Conveniently open existing SharedMemory with ndarray view from a handle.

    The handle is decoded with `parse_shmem_handle`; its name, dtype and shape
    are used to attach to the (already existing) shared memory and to build an
    ndarray on top of it. The slot index encoded in the handle is ignored.

    Returns a tuple `(shared_memory, array)`. The caller has to keep the
    SharedMemory object alive for as long as the array is in use and is
    responsible for closing it.

    Raises whatever `multiprocessing.shared_memory.SharedMemory` raises when no
    shared memory with the given name exists.

    NOTE(review): multiprocessing.shared_memory has no read-only mode (unlike
    the `rw=False` of the previous posixshmem-based version), so nothing here
    prevents writing through the returned array.
    """
    shm_name, dtype, shape, _ = parse_shmem_handle(handle_string)
    shm_mem = multiprocessing.shared_memory.SharedMemory(
        name=shm_name, create=False
    )
    array = np.ndarray(
        shape=shape,
        dtype=dtype,
        buffer=shm_mem.buf,
    )
    return shm_mem, array
class ShmemCircularBufferReceiver:
"""The receiving end of ShmemCircularBuffer. Will receive shmem handles and open
the corresponding buffers automatically when needed in `get`. For convenience,
includes `dereference_shmem_handles` for hashes."""
def __init__(self):
self._name_to_mem = {}
self._name_to_ary = {}
def __del__(self):
for mem in self._name_to_mem.values():
mem.close()
def get(self, handle_string):
    """Return the buffer entry that a shmem handle refers to.

    The handle is decoded with `parse_shmem_handle` into shmem name, dtype,
    shape and index. The shared memory is opened (attached, not created) the
    first time a name is seen and cached afterwards. The ndarray view on top of
    it is cached as well and rebuilt whenever the handle announces a different
    shape or dtype than the cached view has.

    Returns `ary[index]`, i.e. a view into the shared memory, not a copy.

    Raises whatever `multiprocessing.shared_memory.SharedMemory` raises when no
    shared memory with the given name exists.
    """
    shm_name, dtype, shape, index = parse_shmem_handle(handle_string)

    mem = self._name_to_mem.get(shm_name)
    if mem is None:
        mem = multiprocessing.shared_memory.SharedMemory(
            name=shm_name, create=False
        )
        self._name_to_mem[shm_name] = mem

    # A missing view is handled like an outdated one; this also covers the case
    # where the memory got cached but building the view failed earlier on.
    ary = self._name_to_ary.get(shm_name)
    if ary is None or ary.shape != shape or ary.dtype != dtype:
        ary = np.ndarray(
            shape=shape,
            dtype=dtype,
            buffer=mem.buf,
        )
        self._name_to_ary[shm_name] = ary

    return ary[index]
......@@ -62,7 +82,7 @@ class ShmemCircularBufferReceiver:
class ShmemCircularBuffer:
"""Convenience wrapper around posixshmem-backed ndarray buffers
"""Convenience wrapper around shmem-backed ndarray buffers
The underlying memory will be opened as an ndarray with shape (buffer_size, ) +
array_shape where buffer_size is memory_budget // dtype * array size. Each call
......@@ -73,15 +93,33 @@ class ShmemCircularBuffer:
def __init__(self, memory_budget, array_shape, dtype, shmem_name):
    """Create (or take over) the shared memory and set up the ndarray view.

    memory_budget is passed as `size` to
    `multiprocessing.shared_memory.SharedMemory`; array_shape and dtype are
    forwarded to `_update_shape`; shmem_name is normalized (see below) and
    stored as `self.shmem_name`.

    If shared memory with this name already exists, it is reused when its size
    equals memory_budget and otherwise unlinked and created anew. An error
    from that second creation attempt is deliberately not caught.
    """
    # for portable use: name has leading slash and no other slashes
    self.shmem_name = "/" + shmem_name.lstrip("/").replace("/", "_")

    # Everything __del__ looks at is set before anything below can raise, so
    # that cleaning up a half-constructed object does not fail.
    self._cuda_pinned = False
    self._shared_memory = None
    self._buffer_ary = None

    try:
        self._shared_memory = multiprocessing.shared_memory.SharedMemory(
            name=self.shmem_name,
            size=memory_budget,
            create=True,
        )
    except FileExistsError:
        # maybe device was restarted uncleanly and there's a lingering shmem "file"
        lingering = multiprocessing.shared_memory.SharedMemory(
            name=self.shmem_name,
            create=False,
        )
        if lingering.size == memory_budget:
            self._shared_memory = lingering
        else:
            # existing one is not suitable, so it needs to be recreated; it is
            # only kept in a local so that self._shared_memory never refers
            # to memory which is already closed and unlinked
            lingering.close()
            lingering.unlink()
            # if it fails again, we're in real trouble, so not catching this
            self._shared_memory = multiprocessing.shared_memory.SharedMemory(
                name=self.shmem_name,
                size=memory_budget,
                create=True,
            )

    self._update_shape(array_shape, dtype)
    self._next_slot_index = 0
def _update_shape(self, array_shape, dtype):
array_shape = tuple(array_shape)
......@@ -93,9 +131,10 @@ class ShmemCircularBuffer:
if self._buffer_ary is not None:
del self._buffer_ary
self._buffer_ary = self._shared_memory.ndarray(
self._buffer_ary = np.ndarray(
shape=full_shape,
dtype=dtype,
buffer=self._shared_memory.buf,
)
shape_str = ",".join(str(n) for n in full_shape)
self.shmem_handle_template = (
......@@ -125,11 +164,16 @@ class ShmemCircularBuffer:
)
def __del__(self):
if self._shared_memory is None:
return
if self._cuda_pinned:
import cupy
cupy.cuda.runtime.hostUnregister(self._memory_pointer)
del self._buffer_ary
self._shared_memory.close()
self._shared_memory.unlink()
del self._shared_memory
@property
......
import multiprocessing
import pathlib
import time
import numpy as np
import pytest
from karabo.bound import Hash
from calng import shmem_utils
def test_change_shape():
shm_fn = "test_shmem_buffer"
my_buffer = shmem_utils.ShmemCircularBuffer(
1024 * 4,
(2, 3),
np.uint32,
"test_shmem_buffer",
shm_fn,
)
handles = []
for i in range(3):
......@@ -25,3 +32,106 @@ def test_change_shape():
for i, handle in enumerate(handles):
ary = receiver.get(handle)
assert np.all(ary == i)
del my_buffer
assert not (pathlib.Path("/dev/shm") / shm_fn).exists()
def test_multiprocessing():
    """Send arrays from one process to another via shmem handles and compare.

    A sender process writes arrays of varying shape into a ShmemCircularBuffer
    and passes the handles on through a queue; a receiver process dereferences
    them with ShmemCircularBufferReceiver. The test process compares what was
    sent to what was received and checks that the shmem "file" exists while the
    sender is alive and is gone after both processes have finished.
    """
    # note: test doesn't use hashes because they can't easily be pickled :(
    # note: will ignore resource_tracker warning for now; might be CPython3.8 bug
    # (we check that "file" is at least gone)
    shm_fn = "test_multiproc_shmem_buffer"
    shm_path = pathlib.Path("/dev/shm") / shm_fn
    num_messages = 10

    def sender(handle_q, sent_q, barrier):
        # handle_q: queue replacing channel communication
        # sent_q: queue with original data for comparison
        shm_buffer = shmem_utils.ShmemCircularBuffer(
            memory_budget=1_000_000,
            array_shape=(1, 10, 20),
            dtype=np.float64,
            shmem_name=shm_fn,
        )
        # switches "number of frames" for every train
        for num_frames in range(1, num_messages + 1):
            some_array = np.random.random(size=(num_frames, 10, 20))
            # could consider adding a convenience "put" function to buffer
            shm_buffer.change_shape(
                array_shape=some_array.shape, dtype=some_array.dtype
            )
            handle, shm_array = shm_buffer.next_slot()
            shm_array[:] = some_array
            handle_q.put(handle)
            sent_q.put(some_array)
        handle_q.close()
        sent_q.close()
        # wait before dying so shmem buffer is not removed
        barrier.wait(timeout=5)

    def receiver(handle_q, received_q, barrier):
        shm_recv = shmem_utils.ShmemCircularBufferReceiver()
        for _ in range(num_messages):
            handle = handle_q.get(timeout=5)
            received_q.put(shm_recv.get(handle))
        received_q.close()
        barrier.wait(timeout=5)

    # three parties: sender, receiver and this test process
    barrier = multiprocessing.Barrier(3)
    sent_q = multiprocessing.Queue()
    received_q = multiprocessing.Queue()
    handle_q = multiprocessing.Queue()
    send_proc = multiprocessing.Process(
        target=sender, args=(handle_q, sent_q, barrier)
    )
    recv_proc = multiprocessing.Process(
        target=receiver, args=(handle_q, received_q, barrier)
    )
    send_proc.start()
    recv_proc.start()

    for _ in range(num_messages):
        expected = sent_q.get(timeout=5)
        got = received_q.get(timeout=5)
        assert np.array_equal(expected, got)

    # both children are still waiting at the barrier, so the buffer is alive
    assert shm_path.exists()
    barrier.wait(timeout=5)
    send_proc.join()
    recv_proc.join()
    assert not shm_path.exists()
def test_lingering_file():
    """Check that a buffer can be created although a stale shmem "file" exists.

    A child process creates a ShmemCircularBuffer and is then killed, so it
    never gets to clean up. Afterwards, a new buffer with the same name but a
    different memory budget is created in the test process, which is expected
    to end up with shared memory of the newly requested size.
    """
    # will kill a process ungracefully to mess up cleanup
    shm_fn = "test_lingering_shmem_buffer"
    shm_path = pathlib.Path("/dev/shm") / shm_fn

    def tragic_function(barrier):
        shm_buffer = shmem_utils.ShmemCircularBuffer(
            memory_budget=1_000_000,
            array_shape=(1, 10, 20),
            dtype=np.float64,
            shmem_name=shm_fn,
        )
        barrier.wait(timeout=5)
        # will not survive this sleep
        time.sleep(10)

    barrier = multiprocessing.Barrier(2)
    poor_proc = multiprocessing.Process(target=tragic_function, args=(barrier,))
    poor_proc.start()
    barrier.wait(timeout=5)
    assert shm_path.exists()
    poor_proc.kill()
    assert shm_path.exists()
    # reap the killed child, otherwise it stays around as a zombie process
    poor_proc.join(timeout=5)

    # so now there's a lingering file; that's a problem
    # but we can just make a new one
    shm_buffer = shmem_utils.ShmemCircularBuffer(
        memory_budget=2_000_000,
        array_shape=(1, 10, 20),
        dtype=np.float64,
        shmem_name=shm_fn,
    )
    assert shm_buffer._shared_memory.size == 2_000_000
    # note: would be nice to prevent multiple writers on same memory...
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment