From 251c5d5a76c933458655ffd94928abd5b7cf0456 Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Wed, 7 Dec 2022 15:38:10 +0100
Subject: [PATCH] Use indexing to access cells row by row and get rid of
 reshaping

---
 src/cal_tools/agipdalgs.pyx | 84 ++++++++++++++++++++-----------------
 src/cal_tools/agipdlib.py   | 30 ++++++-------
 2 files changed, 58 insertions(+), 56 deletions(-)

diff --git a/src/cal_tools/agipdalgs.pyx b/src/cal_tools/agipdalgs.pyx
index 649511df6..53766a906 100644
--- a/src/cal_tools/agipdalgs.pyx
+++ b/src/cal_tools/agipdalgs.pyx
@@ -147,63 +147,69 @@ def gain_choose(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[choices_t, ndim=
 
 @boundscheck(False)
 @wraparound(False)
-def sum_and_count_in_range_asic(cnp.ndarray[float, ndim=4] arr, float lower, float upper):
+def sum_and_count_in_range_asic(float[:, :, :] arr, long[:] ix,
+                                float lower, float upper):
     """
     Return the sum & count of values where lower <= x <= upper,
     across axes 2 & 3 (pixels within an ASIC, as reshaped by AGIPD correction code).
     Specialised function for performance.
     """
-
     cdef int i, j, k, l, m
     cdef float value
-    cdef cnp.ndarray[unsigned long long, ndim=2] count
-    cdef cnp.ndarray[double, ndim=2] sum_
-
+    # arr is 3-D; ix picks the axis-0 entries to use (one output per entry).
+    cdef int nfrm = ix.size
+    cdef int nx = arr.shape[1]
+    cdef int ny = arr.shape[2]
+
     # Drop axes -2 & -1 (pixel dimensions within each ASIC)
-    out_shape = arr[:, :, 0, 0].shape
-    count = np.zeros(out_shape, dtype=np.uint64)
-    sum_ = np.zeros(out_shape, dtype=np.float64)
+    count_arr = np.zeros(nfrm, dtype=np.uint64)
+    sum_arr = np.zeros(nfrm, dtype=np.float64)
 
+    cdef unsigned long long[:] count = count_arr
+    cdef double[:] sum_ = sum_arr
     with nogil:
-        for i in range(arr.shape[0]):
-            for k in range(arr.shape[1]):
-                for l in range(arr.shape[2]):
-                    for m in range(arr.shape[3]):
-                        value = arr[i, k, l, m]
-                        if lower <= value <= upper:
-                            sum_[i, k] += value
-                            count[i, k] += 1
-    return sum_, count
+        for i in range(nfrm):
+            k = ix[i]
+            for l in range(nx):
+                for m in range(ny):
+                    value = arr[k, l, m]
+                    if lower <= value <= upper:
+                        sum_[i] += value
+                        count[i] += 1
+
+    return sum_arr, count_arr
 
 
 @boundscheck(False)
 @wraparound(False)
-def sum_and_count_in_range_cell(cnp.ndarray[float, ndim=4] arr, float lower, float upper):
+def sum_and_count_in_range_cell(float[:, :, :] arr, long[:] ix,
+                                float lower, float upper):
     """
     Return the sum & count of values where lower <= x <= upper,
     across axes 0 & 1 (memory cells in the same row, as reshaped by AGIPD correction code).
     Specialised function for performance.
     """
-
-    cdef int i, j, k, l, m,
+    cdef int i, j, k, l, m
     cdef float value
-    cdef cnp.ndarray[unsigned long long, ndim=2] count
-    cdef cnp.ndarray[double, ndim=2] sum_
-
-    # Drop axes 0 & 1
-    out_shape = arr[0, 0, :, :].shape
-    count = np.zeros(out_shape, dtype=np.uint64)
-    sum_ = np.zeros(out_shape, dtype=np.float64)
-
-
+    # arr is 3-D; ix picks the axis-0 entries that are accumulated together.
+    cdef int nfrm = ix.size
+    cdef int nx = arr.shape[1]
+    cdef int ny = arr.shape[2]
+
+    # Drop axis 0: outputs keep the trailing (nx, ny) shape of arr
+    count_arr = np.zeros([nx, ny], dtype=np.uint64)
+    sum_arr = np.zeros([nx, ny], dtype=np.float64)
+
+    cdef unsigned long long[:, :] count = count_arr
+    cdef double[:, :] sum_ = sum_arr
     with nogil:
-        for i in range(arr.shape[0]):
-            for k in range(arr.shape[1]):
-                for l in range(arr.shape[2]):
-                    for m in range(arr.shape[3]):
-                        value = arr[i, k, l, m]
-                        if lower <= value <= upper:
-                            sum_[l, m] += value
-                            count[l, m] += 1
-
-    return sum_, count
+        for i in range(nfrm):
+            k = ix[i]
+            for l in range(nx):
+                for m in range(ny):
+                    value = arr[k, l, m]
+                    if lower <= value <= upper:
+                        sum_[l, m] += value
+                        count[l, m] += 1
+
+    return sum_arr, count_arr
diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index b0d9d2845..53385959d 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -647,39 +647,35 @@ class AgipdCorrections:
         if n_img == 0:
             return
         cell_id = self.shared_dict[i_proc]['cellId'][:n_img]
+        row_id = cell_id // 32
+        data = self.shared_dict[i_proc]['data'][:n_img]
+        data = data.reshape(-1, 8, 64, 2, 64)
 
-        data_arr = self.shared_dict[i_proc]['data'][:n_img]
-        data = data_arr.reshape(-1, 32, 8, 64, 2, 64)
-        row_id = cell_id[::32] // 32
-
-        xasic, yasic = asic % 8, asic // 8
+        asic_data = data[:, asic % 8, :, asic // 8, :]
+        # Loop over rows of cells
         for cell_row in range(11):
             irow = np.flatnonzero(row_id == cell_row)
             if not irow.size:
                 continue
-            asic_data = data[irow, :, xasic, :, yasic, :]
             # Loop over iterations
             for _ in range(n_itr):
                 # Cell common mode
                 cell_cm_sum, cell_cm_count = \
-                    calgs.sum_and_count_in_range_cell(asic_data, dark_min,
-                                                      dark_max)
+                    calgs.sum_and_count_in_range_cell(
+                        asic_data, irow, dark_min, dark_max)
                 cell_cm = cell_cm_sum / cell_cm_count
 
-                # TODO: check files with less 256 trains
-                cell_cm[cell_cm_count < fraction * 32 * asic_data.shape[0]] = 0
-                asic_data -= cell_cm[None, None, :, :]
+                cell_cm[cell_cm_count < fraction * irow.size] = 0
+                asic_data[irow] -= cell_cm[None, :, :]
 
                 # Asics common mode
                 asic_cm_sum, asic_cm_count = \
-                    calgs.sum_and_count_in_range_asic(asic_data, dark_min,
-                                                      dark_max)
+                    calgs.sum_and_count_in_range_asic(
+                        asic_data, irow, dark_min, dark_max)
                 asic_cm = asic_cm_sum / asic_cm_count
 
-                asic_cm[asic_cm_count < fraction * 64 * 64] = 0
-                asic_data -= asic_cm[:, :, None, None]
-
-            data[irow, :, xasic, :, yasic, :] = asic_data
+                asic_cm[asic_cm_count < fraction * 4096] = 0
+                asic_data[irow] -= asic_cm[:, None, None]
 
     def mask_zero_std(self, i_proc, cells):
         """
-- 
GitLab