diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py index e02e3fd0b3bbf9a46172dc14735fd33e8c630eac..d8a7c0964a17f724cfe7e757e6d5adc495b962bb 100644 --- a/src/calng/shmem_utils.py +++ b/src/calng/shmem_utils.py @@ -1,6 +1,8 @@ import numpy as np import posixshmem +from . import utils + def parse_shmem_handle(handle_string): buffer_name, dtype, shape, index = handle_string.split("$") @@ -78,12 +80,13 @@ class ShmemCircularBuffer: ) self._buffer_ary = None self._update_shape(array_shape, dtype) + self._next_slot_index = 0 self._cuda_pinned = False def _update_shape(self, array_shape, dtype): array_shape = tuple(array_shape) - array_bytes = np.dtype(dtype).itemsize * np.product(array_shape) - num_slots = self._shared_memory.size // array_bytes + self._array_bytes = np.dtype(dtype).itemsize * np.product(array_shape) + num_slots = self._shared_memory.size // self._array_bytes if num_slots == 0: raise ValueError("Array size exceeds size of allocated memory block") full_shape = (num_slots,) + array_shape @@ -98,16 +101,20 @@ class ShmemCircularBuffer: self.shmem_handle_template = ( f"{self.shmem_name}${np.dtype(dtype)}${shape_str}${{index}}" ) - self._next_slot_index = 0 def change_shape(self, array_shape, dtype=None): """Set new array shape to buffer. Note that the existing SharedMemory object is still used. Old data in there will be mangled and number of slots will depend upon new array shape and original memory budget. """ + old_array_bytes = self._array_bytes if dtype is None: dtype = self._buffer_ary.dtype self._update_shape(array_shape, dtype) + # continue from "next" (least recently touched) slot aligned to new array size + self._next_slot_index = ( + utils.ceil_div(old_array_bytes * self._next_slot_index, self._array_bytes) + ) % self.num_slots def cuda_pin(self): import cupy diff --git a/src/tests/test_shmem_utils.py b/src/tests/test_shmem_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da077aff07a1effae0a62124a8d216d7051067f6 --- /dev/null +++ b/src/tests/test_shmem_utils.py @@ -0,0 +1,27 @@ +import numpy as np + +from calng import shmem_utils + + +def test_change_shape(): + my_buffer = shmem_utils.ShmemCircularBuffer( + 1024 * 4, + (2, 3), + np.uint32, + "test_shmem_buffer", + ) + handles = [] + for i in range(3): + handle, ary = my_buffer.next_slot() + ary.fill(i) + handles.append(handle) + my_buffer.change_shape((5, 7)) + for i in range(3, 5): + handle, ary = my_buffer.next_slot() + ary.fill(i) + handles.append(handle) + receiver = shmem_utils.ShmemCircularBufferReceiver() + # old handles don't get immediately mangled after change_shape + for i, handle in enumerate(handles): + ary = receiver.get(handle) + assert np.all(ary == i)