From 10fbab6fe6e45b03568a02864739b4301b038fc4 Mon Sep 17 00:00:00 2001 From: David Hammer <david.hammer@xfel.eu> Date: Thu, 16 Nov 2023 11:35:18 +0100 Subject: [PATCH] Avoid mangling shmem data when changing shape --- src/calng/shmem_utils.py | 13 ++++++++++--- src/tests/test_shmem_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 src/tests/test_shmem_utils.py diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py index e02e3fd0..d8a7c096 100644 --- a/src/calng/shmem_utils.py +++ b/src/calng/shmem_utils.py @@ -1,6 +1,8 @@ import numpy as np import posixshmem +from . import utils + def parse_shmem_handle(handle_string): buffer_name, dtype, shape, index = handle_string.split("$") @@ -78,12 +80,13 @@ class ShmemCircularBuffer: ) self._buffer_ary = None self._update_shape(array_shape, dtype) + self._next_slot_index = 0 self._cuda_pinned = False def _update_shape(self, array_shape, dtype): array_shape = tuple(array_shape) - array_bytes = np.dtype(dtype).itemsize * np.product(array_shape) - num_slots = self._shared_memory.size // array_bytes + self._array_bytes = np.dtype(dtype).itemsize * np.product(array_shape) + num_slots = self._shared_memory.size // self._array_bytes if num_slots == 0: raise ValueError("Array size exceeds size of allocated memory block") full_shape = (num_slots,) + array_shape @@ -98,16 +101,20 @@ class ShmemCircularBuffer: self.shmem_handle_template = ( f"{self.shmem_name}${np.dtype(dtype)}${shape_str}${{index}}" ) - self._next_slot_index = 0 def change_shape(self, array_shape, dtype=None): """Set new array shape to buffer. Note that the existing SharedMemory object is still used. Old data in there will be mangled and number of slots will depend upon new array shape and original memory budget. """ + old_array_bytes = self._array_bytes if dtype is None: dtype = self._buffer_ary.dtype self._update_shape(array_shape, dtype) + # continue from "next" (least recently touched) slot aligned to new array size + self._next_slot_index = ( + utils.ceil_div(old_array_bytes * self._next_slot_index, self._array_bytes) + ) % self.num_slots def cuda_pin(self): import cupy diff --git a/src/tests/test_shmem_utils.py b/src/tests/test_shmem_utils.py new file mode 100644 index 00000000..da077aff --- /dev/null +++ b/src/tests/test_shmem_utils.py @@ -0,0 +1,27 @@ +import numpy as np + +from calng import shmem_utils + + +def test_change_shape(): + my_buffer = shmem_utils.ShmemCircularBuffer( + 1024 * 4, + (2, 3), + np.uint32, + "test_shmem_buffer", + ) + handles = [] + for i in range(3): + handle, ary = my_buffer.next_slot() + ary.fill(i) + handles.append(handle) + my_buffer.change_shape((5, 7)) + for i in range(3, 5): + handle, ary = my_buffer.next_slot() + ary.fill(i) + handles.append(handle) + receiver = shmem_utils.ShmemCircularBufferReceiver() + # old handles don't get immediately mangled after change_shape + for i, handle in enumerate(handles): + ary = receiver.get(handle) + assert np.all(ary == i) -- GitLab