From e8cb2bf978123d529b6fd84cf18f1006e2dbfc8a Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Sun, 17 Sep 2023 15:06:08 +0200
Subject: [PATCH] Add computation of autocorrelation in different boundary
 conditions

---
 .../correction_addons/autocorrelation.py      | 119 +++++++++++++-----
 1 file changed, 89 insertions(+), 30 deletions(-)

diff --git a/src/calng/correction_addons/autocorrelation.py b/src/calng/correction_addons/autocorrelation.py
index b32f11de..9cfbf7be 100644
--- a/src/calng/correction_addons/autocorrelation.py
+++ b/src/calng/correction_addons/autocorrelation.py
@@ -1,21 +1,91 @@
 import numpy as np
 
-from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT
+from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT, STRING_ELEMENT
 
 from .base_addon import BaseCorrectionAddon
 from .. import shmem_utils
 
 
+def block(b):
+    return np.concatenate([np.concatenate(row, axis=-1) for row in b], axis=-2)
+
+
+def autocorr2_fft(f, mode='full', backend='scipy'):
+    """Computes 2D autocorrelation function for real input.
+
+        This is equivalent to
+        fftconvolve(f, np.flip(f, axis=(-2, -1)), mode)
+    """
+    nx, ny = f.shape[-2:]
+
+    from scipy.fft import next_fast_len
+    fshape = [next_fast_len(sz, True) for sz in (2 * nx - 1, 2 * ny - 1)]
+    r = fft.rfft2(f, fshape)
+    r = fft.irfft2((r * r.conj()).real, fshape)
+
+    if mode == "same":
+        ux = (nx + 1) // 2
+        lx = (1 - nx) // 2
+        uy = (ny + 1) // 2
+        ly = (1 - ny) // 2
+    elif mode == "full":
+        ux = nx
+        lx = (1 - nx)
+        uy = ny
+        ly = (1 - ny)
+    else:
+        raise ValueError("acceptable mode flags are 'same', or 'full'")
+
+    a11 = r[..., :ux, :uy]
+    a12 = r[..., :ux, ly:]
+    a21 = r[..., lx:, :uy]
+    a22 = r[..., lx:, ly:]
+
+    return block([[a22, a21], [a12, a11]])
+
+
+def autocorr2symm_fft(f, backend='scipy'):
+    """Computes 2D autocorrelation function for real input and
+       symmetrical boundary conditions.
+    """
+    r = fft.dctn(f, 3, axes=(-2, -1))
+    r = fft.idctn(r * r, 3, axes=(-2, -1))
+    return r
+
+
+def autocorr2wrap_fft(f, backend='scipy'):
+    """Computes 2D autocorrelation function for real input and
+       circular boundary conditions.
+    """
+    s = f.shape[-2:]
+    r = fft.rfft2(f, s)
+    r = fft.irfft2((r * r.conj()).real, s)
+    r = fft.fftshift(r)
+    return r
+
+
+AUTOCORR_FUN = {
+    "symm": autocorr2symm_fft,
+    "wrap": autocorr2wrap_fft,
+    "fill": lambda f: autocorr2_fft(f, "same"),
+}
+
+
 class Autocorrelation(BaseCorrectionAddon):
     _device = None  # will be set to host device *after* init
 
     @staticmethod
     def extend_device_schema(schema, prefix):
-        """Will be given the device schema where everything should be put under a
-        prefix. This prefix should be something like 'addons.nodeName', is given to this
-        class by the base correction device, and the root node of prefix will already be
-        in place (as will an 'enable' flag under it)."""
         (
+            STRING_ELEMENT(schema)
+            .key(f"{prefix}.boundary")
+            .tags("managed")
+            .options(",".join(AUTOCORR_FUN.keys()))
+            .assignmentOptional()
+            .defaultValue("symm")
+            .reconfigurable()
+            .commit(),
+
             DOUBLE_ELEMENT(schema)
             .key(f"{prefix}.intenstityPerPhoton")
             .tags("managed")
@@ -27,8 +97,6 @@ class Autocorrelation(BaseCorrectionAddon):
 
     @staticmethod
     def extend_output_schema(schema):
-        """Will be given the regular output schema, can create arbitrary nodes in there
-        and add properties to it."""
         (
             NDARRAY_ELEMENT(schema)
             .key("image.autocorr")
@@ -37,11 +105,11 @@ class Autocorrelation(BaseCorrectionAddon):
         )
 
     def __init__(self, config):
-        """Will be given the node from extend_device_schema, no prefix needed here"""
         global cupy
         import cupy
 
         self._intensity_per_photon = config["intenstityPerPhoton"]
+        self._autocorr = AUTOCORR_FUN[config["boundary"]]
 
         self._shmem_buffer = None
 
@@ -50,11 +118,15 @@ class Autocorrelation(BaseCorrectionAddon):
         super().__del__()
 
     def _initialization(self):
+        global fft
+        if self._device.kernel_runner._gpu_based:
+            import cupyx.scipy.fft as fft
+        else:
+            import scipy.fft as fft
+
         self._update_buffer()
 
     def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
-        """Called directly after correction has happened. Processed data will still be
-        on GPU if the correction device is generally running in GPU mode."""
         nf, nx, ny = processed_data.shape
         if nf != self._nframe:
             self._update_buffer()
@@ -62,23 +134,13 @@ class Autocorrelation(BaseCorrectionAddon):
         data = np.around(processed_data / self._intensity_per_photon)
         data[np.isnan(data) | (data < 0.0)] = 0.0
 
-        nf, nx, ny = data.shape
-        g = np.fft.rfft2(data, s=(2 * nx, 2 * ny))
-        autocorr = np.fft.irfft2((g * g.conj()).real, s=(nx, ny))
-        # !!! with fftconvolve
-        # from cupyx.scipy.signal import fftconvolve
-        # autocorr = fftconvolve(
-        #     data, data[:, ::-1, ::-1], 'same', axes=(-2, -1))
+        autocorr = self._autocorr(data)
 
         buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        autocorr.get(out=buffer_array[:])
-
-        # !!! alternatively send the squared Fourier transform
-        # !!! shape (nf, 2 * nx, ny + 1)
-        #
-        # gg = (g * g.conj()).real
-        # buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        # gg.get(out=buffer_array[:])
+        if hasattr(autocorr, 'get'):
+            autocorr.get(out=buffer_array[:])
+        else:
+            buffer_array[:] = autocorr
 
         if self._device._use_shmem_handles:
             output_hash.set("image.autocorr", buffer_handle)
@@ -96,12 +158,10 @@ class Autocorrelation(BaseCorrectionAddon):
     #     pass
 
     def reconfigure(self, changed_config):
-        """Will be given the node from extend_device_schema, no prefix needed here. Note
-        that only the changed properties will be in this update hash; see peakfinder9
-        for an example of caching and reinstanting complex kernels using many parameters
-        when any change."""
         if changed_config.has("intenstityPerPhoton"):
             self._intensity_per_photon = changed_config["intenstityPerPhoton"]
+        if changed_config.has("boundary"):
+            self._autocorr = AUTOCORR_FUN[changed_config["boundary"]]
 
     def _update_buffer(self):
         nf, _, nx, ny = self._device.input_data_shape
@@ -121,4 +181,3 @@ class Autocorrelation(BaseCorrectionAddon):
                 self._shmem_buffer.cuda_pin()
         else:
             self._shmem_buffer.change_shape(shape)
-
-- 
GitLab