From 10fbab6fe6e45b03568a02864739b4301b038fc4 Mon Sep 17 00:00:00 2001
From: David Hammer <david.hammer@xfel.eu>
Date: Thu, 16 Nov 2023 11:35:18 +0100
Subject: [PATCH] Avoid mangling shmem data when changing shape

---
 src/calng/shmem_utils.py      | 13 ++++++++++---
 src/tests/test_shmem_utils.py | 27 +++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 3 deletions(-)
 create mode 100644 src/tests/test_shmem_utils.py

diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py
index e02e3fd0..d8a7c096 100644
--- a/src/calng/shmem_utils.py
+++ b/src/calng/shmem_utils.py
@@ -1,6 +1,8 @@
 import numpy as np
 import posixshmem
 
+from . import utils
+
 
 def parse_shmem_handle(handle_string):
     buffer_name, dtype, shape, index = handle_string.split("$")
@@ -78,12 +80,13 @@ class ShmemCircularBuffer:
         )
         self._buffer_ary = None
         self._update_shape(array_shape, dtype)
+        self._next_slot_index = 0
         self._cuda_pinned = False
 
     def _update_shape(self, array_shape, dtype):
         array_shape = tuple(array_shape)
-        array_bytes = np.dtype(dtype).itemsize * np.product(array_shape)
-        num_slots = self._shared_memory.size // array_bytes
+        self._array_bytes = np.dtype(dtype).itemsize * np.product(array_shape)
+        num_slots = self._shared_memory.size // self._array_bytes
         if num_slots == 0:
             raise ValueError("Array size exceeds size of allocated memory block")
         full_shape = (num_slots,) + array_shape
@@ -98,16 +101,20 @@ class ShmemCircularBuffer:
         self.shmem_handle_template = (
             f"{self.shmem_name}${np.dtype(dtype)}${shape_str}${{index}}"
         )
-        self._next_slot_index = 0
 
     def change_shape(self, array_shape, dtype=None):
         """Set new array shape to buffer. Note that the existing SharedMemory object is
         still used. Old data in there will be mangled and number of slots will depend
         upon new array shape and original memory budget.
         """
+        old_array_bytes = self._array_bytes
         if dtype is None:
             dtype = self._buffer_ary.dtype
         self._update_shape(array_shape, dtype)
+        # continue from "next" (least recently touched) slot aligned to new array size
+        self._next_slot_index = (
+            utils.ceil_div(old_array_bytes * self._next_slot_index, self._array_bytes)
+        ) % self.num_slots
 
     def cuda_pin(self):
         import cupy
diff --git a/src/tests/test_shmem_utils.py b/src/tests/test_shmem_utils.py
new file mode 100644
index 00000000..da077aff
--- /dev/null
+++ b/src/tests/test_shmem_utils.py
@@ -0,0 +1,27 @@
+import numpy as np
+
+from calng import shmem_utils
+
+
+def test_change_shape():
+    my_buffer = shmem_utils.ShmemCircularBuffer(
+        1024 * 4,
+        (2, 3),
+        np.uint32,
+        "test_shmem_buffer",
+    )
+    handles = []
+    for i in range(3):
+        handle, ary = my_buffer.next_slot()
+        ary.fill(i)
+        handles.append(handle)
+    my_buffer.change_shape((5, 7))
+    for i in range(3, 5):
+        handle, ary = my_buffer.next_slot()
+        ary.fill(i)
+        handles.append(handle)
+    receiver = shmem_utils.ShmemCircularBufferReceiver()
+    # old handles don't get immediately mangled after change_shape
+    for i, handle in enumerate(handles):
+        ary = receiver.get(handle)
+        assert np.all(ary == i)
-- 
GitLab