From e8cb2bf978123d529b6fd84cf18f1006e2dbfc8a Mon Sep 17 00:00:00 2001 From: Egor Sobolev <egor.sobolev@xfel.eu> Date: Sun, 17 Sep 2023 15:06:08 +0200 Subject: [PATCH] Add computation of autocorrelation in different boundary conditions --- .../correction_addons/autocorrelation.py | 119 +++++++++++++----- 1 file changed, 89 insertions(+), 30 deletions(-) diff --git a/src/calng/correction_addons/autocorrelation.py b/src/calng/correction_addons/autocorrelation.py index b32f11de..9cfbf7be 100644 --- a/src/calng/correction_addons/autocorrelation.py +++ b/src/calng/correction_addons/autocorrelation.py @@ -1,21 +1,91 @@ import numpy as np -from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT +from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT, STRING_ELEMENT from .base_addon import BaseCorrectionAddon from .. import shmem_utils +def block(b): + return np.concatenate([np.concatenate(row, axis=-1) for row in b], axis=-2) + + +def autocorr2_fft(f, mode='full', backend='scipy'): + """Computes 2D autocorrelation function for real input. + + This is equivalent to + fftconvolve(f, np.flip(f, axis=(-2, -1)), mode) + """ + nx, ny = f.shape[-2:] + + from scipy.fft import next_fast_len + fshape = [next_fast_len(sz, True) for sz in (2 * nx - 1, 2 * ny - 1)] + r = fft.rfft2(f, fshape) + r = fft.irfft2((r * r.conj()).real, fshape) + + if mode == "same": + ux = (nx + 1) // 2 + lx = (1 - nx) // 2 + uy = (ny + 1) // 2 + ly = (1 - ny) // 2 + elif mode == "full": + ux = nx + lx = (1 - nx) + uy = ny + ly = (1 - ny) + else: + raise ValueError("acceptable mode flags are 'same', or 'full'") + + a11 = r[..., :ux, :uy] + a12 = r[..., :ux, ly:] + a21 = r[..., lx:, :uy] + a22 = r[..., lx:, ly:] + + return block([[a22, a21], [a12, a11]]) + + +def autocorr2symm_fft(f, backend='scipy'): + """Computes 2D autocorrelation function for real input and + symmetrical boundary conditions. + """ + r = fft.dctn(f, 3, axes=(-2, -1)) + r = fft.idctn(r * r, 3, axes=(-2, -1)) + return r + + +def autocorr2wrap_fft(f, backend='scipy'): + """Computes 2D autocorrelation function for real input and + circular boundary conditions. + """ + s = f.shape[-2:] + r = fft.rfft2(f, s) + r = fft.irfft2((r * r.conj()).real, s) + r = fft.fftshift(r) + return r + + +AUTOCORR_FUN = { + "symm": autocorr2symm_fft, + "wrap": autocorr2wrap_fft, + "fill": lambda f: autocorr2_fft(f, "same"), +} + + class Autocorrelation(BaseCorrectionAddon): _device = None # will be set to host device *after* init @staticmethod def extend_device_schema(schema, prefix): - """Will be given the device schema where everything should be put under a - prefix. This prefix should be something like 'addons.nodeName', is given to this - class by the base correction device, and the root node of prefix will already be - in place (as will an 'enable' flag under it).""" ( + STRING_ELEMENT(schema) + .key(f"{prefix}.boundary") + .tags("managed") + .options(",".join(AUTOCORR_FUN.keys())) + .assignmentOptional() + .defaultValue("symm") + .reconfigurable() + .commit(), + DOUBLE_ELEMENT(schema) .key(f"{prefix}.intenstityPerPhoton") .tags("managed") @@ -27,8 +97,6 @@ class Autocorrelation(BaseCorrectionAddon): @staticmethod def extend_output_schema(schema): - """Will be given the regular output schema, can create arbitrary nodes in there - and add properties to it.""" ( NDARRAY_ELEMENT(schema) .key("image.autocorr") @@ -37,11 +105,11 @@ class Autocorrelation(BaseCorrectionAddon): ) def __init__(self, config): - """Will be given the node from extend_device_schema, no prefix needed here""" global cupy import cupy self._intensity_per_photon = config["intenstityPerPhoton"] + self._autocorr = AUTOCORR_FUN[config["boundary"]] self._shmem_buffer = None @@ -50,11 +118,15 @@ class Autocorrelation(BaseCorrectionAddon): super().__del__() def _initialization(self): + global fft + if self._device.kernel_runner._gpu_based: + import cupyx.scipy.fft as fft + else: + import scipy.fft as fft + self._update_buffer() def post_correction(self, processed_data, cell_table, pulse_table, output_hash): - """Called directly after correction has happened. Processed data will still be - on GPU if the correction device is generally running in GPU mode.""" nf, nx, ny = processed_data.shape if nf != self._nframe: self._update_buffer() @@ -62,23 +134,13 @@ class Autocorrelation(BaseCorrectionAddon): data = np.around(processed_data / self._intensity_per_photon) data[np.isnan(data) | (data < 0.0)] = 0.0 - nf, nx, ny = data.shape - g = np.fft.rfft2(data, s=(2 * nx, 2 * ny)) - autocorr = np.fft.irfft2((g * g.conj()).real, s=(nx, ny)) - # !!! with fftconvolve - # from cupyx.scipy.signal import fftconvolve - # autocorr = fftconvolve( - # data, data[:, ::-1, ::-1], 'same', axes=(-2, -1)) + autocorr = self._autocorr(data) buffer_handle, buffer_array = self._shmem_buffer.next_slot() - autocorr.get(out=buffer_array[:]) - - # !!! alternatively send the squared Fourier transform - # !!! shape (nf, 2 * nx, ny + 1) - # - # gg = (g * g.conj()).real - # buffer_handle, buffer_array = self._shmem_buffer.next_slot() - # gg.get(out=buffer_array[:]) + if hasattr(autocorr, 'get'): + autocorr.get(out=buffer_array[:]) + else: + buffer_array[:] = autocorr if self._device._use_shmem_handles: output_hash.set("image.autocorr", buffer_handle) @@ -96,12 +158,10 @@ class Autocorrelation(BaseCorrectionAddon): # pass def reconfigure(self, changed_config): - """Will be given the node from extend_device_schema, no prefix needed here. Note - that only the changed properties will be in this update hash; see peakfinder9 - for an example of caching and reinstanting complex kernels using many parameters - when any change.""" if changed_config.has("intenstityPerPhoton"): self._intensity_per_photon = changed_config["intenstityPerPhoton"] + if changed_config.has("boundary"): + self._autocorr = AUTOCORR_FUN[changed_config["boundary"]] def _update_buffer(self): nf, _, nx, ny = self._device.input_data_shape @@ -121,4 +181,3 @@ class Autocorrelation(BaseCorrectionAddon): self._shmem_buffer.cuda_pin() else: self._shmem_buffer.change_shape(shape) - -- GitLab