Skip to content
Snippets Groups Projects
Commit e8cb2bf9 authored by Egor Sobolev's avatar Egor Sobolev Committed by spbonc
Browse files

Add computation of autocorrelation in different boundary conditions

parent b6c9a2be
No related branches found
No related tags found
1 merge request!70Autocorrelation addon
import numpy as np import numpy as np
from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT, STRING_ELEMENT
from .base_addon import BaseCorrectionAddon from .base_addon import BaseCorrectionAddon
from .. import shmem_utils from .. import shmem_utils
def block(b):
return np.concatenate([np.concatenate(row, axis=-1) for row in b], axis=-2)
def autocorr2_fft(f, mode='full', backend='scipy'):
"""Computes 2D autocorrelation function for real input.
This is equivalent to
fftconvolve(f, np.flip(f, axis=(-2, -1)), mode)
"""
nx, ny = f.shape[-2:]
from scipy.fft import next_fast_len
fshape = [next_fast_len(sz, True) for sz in (2 * nx - 1, 2 * ny - 1)]
r = fft.rfft2(f, fshape)
r = fft.irfft2((r * r.conj()).real, fshape)
if mode == "same":
ux = (nx + 1) // 2
lx = (1 - nx) // 2
uy = (ny + 1) // 2
ly = (1 - ny) // 2
elif mode == "full":
ux = nx
lx = (1 - nx)
uy = ny
ly = (1 - ny)
else:
raise ValueError("acceptable mode flags are 'same', or 'full'")
a11 = r[..., :ux, :uy]
a12 = r[..., :ux, ly:]
a21 = r[..., lx:, :uy]
a22 = r[..., lx:, ly:]
return block([[a22, a21], [a12, a11]])
def autocorr2symm_fft(f, backend='scipy'):
"""Computes 2D autocorrelation function for real input and
symmetrical boundary conditions.
"""
r = fft.dctn(f, 3, axes=(-2, -1))
r = fft.idctn(r * r, 3, axes=(-2, -1))
return r
def autocorr2wrap_fft(f, backend='scipy'):
"""Computes 2D autocorrelation function for real input and
circular boundary conditions.
"""
s = f.shape[-2:]
r = fft.rfft2(f, s)
r = fft.irfft2((r * r.conj()).real, s)
r = fft.fftshift(r)
return r
AUTOCORR_FUN = {
"symm": autocorr2symm_fft,
"wrap": autocorr2wrap_fft,
"fill": lambda f: autocorr2_fft(f, "same"),
}
class Autocorrelation(BaseCorrectionAddon): class Autocorrelation(BaseCorrectionAddon):
_device = None # will be set to host device *after* init _device = None # will be set to host device *after* init
@staticmethod @staticmethod
def extend_device_schema(schema, prefix): def extend_device_schema(schema, prefix):
"""Will be given the device schema where everything should be put under a
prefix. This prefix should be something like 'addons.nodeName', is given to this
class by the base correction device, and the root node of prefix will already be
in place (as will an 'enable' flag under it)."""
( (
STRING_ELEMENT(schema)
.key(f"{prefix}.boundary")
.tags("managed")
.options(",".join(AUTOCORR_FUN.keys()))
.assignmentOptional()
.defaultValue("symm")
.reconfigurable()
.commit(),
DOUBLE_ELEMENT(schema) DOUBLE_ELEMENT(schema)
.key(f"{prefix}.intenstityPerPhoton") .key(f"{prefix}.intenstityPerPhoton")
.tags("managed") .tags("managed")
...@@ -27,8 +97,6 @@ class Autocorrelation(BaseCorrectionAddon): ...@@ -27,8 +97,6 @@ class Autocorrelation(BaseCorrectionAddon):
@staticmethod @staticmethod
def extend_output_schema(schema): def extend_output_schema(schema):
"""Will be given the regular output schema, can create arbitrary nodes in there
and add properties to it."""
( (
NDARRAY_ELEMENT(schema) NDARRAY_ELEMENT(schema)
.key("image.autocorr") .key("image.autocorr")
...@@ -37,11 +105,11 @@ class Autocorrelation(BaseCorrectionAddon): ...@@ -37,11 +105,11 @@ class Autocorrelation(BaseCorrectionAddon):
) )
def __init__(self, config): def __init__(self, config):
"""Will be given the node from extend_device_schema, no prefix needed here"""
global cupy global cupy
import cupy import cupy
self._intensity_per_photon = config["intenstityPerPhoton"] self._intensity_per_photon = config["intenstityPerPhoton"]
self._autocorr = AUTOCORR_FUN[config["boundary"]]
self._shmem_buffer = None self._shmem_buffer = None
...@@ -50,11 +118,15 @@ class Autocorrelation(BaseCorrectionAddon): ...@@ -50,11 +118,15 @@ class Autocorrelation(BaseCorrectionAddon):
super().__del__() super().__del__()
def _initialization(self): def _initialization(self):
global fft
if self._device.kernel_runner._gpu_based:
import cupyx.scipy.fft as fft
else:
import scipy.fft as fft
self._update_buffer() self._update_buffer()
def post_correction(self, processed_data, cell_table, pulse_table, output_hash): def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
"""Called directly after correction has happened. Processed data will still be
on GPU if the correction device is generally running in GPU mode."""
nf, nx, ny = processed_data.shape nf, nx, ny = processed_data.shape
if nf != self._nframe: if nf != self._nframe:
self._update_buffer() self._update_buffer()
...@@ -62,23 +134,13 @@ class Autocorrelation(BaseCorrectionAddon): ...@@ -62,23 +134,13 @@ class Autocorrelation(BaseCorrectionAddon):
data = np.around(processed_data / self._intensity_per_photon) data = np.around(processed_data / self._intensity_per_photon)
data[np.isnan(data) | (data < 0.0)] = 0.0 data[np.isnan(data) | (data < 0.0)] = 0.0
nf, nx, ny = data.shape autocorr = self._autocorr(data)
g = np.fft.rfft2(data, s=(2 * nx, 2 * ny))
autocorr = np.fft.irfft2((g * g.conj()).real, s=(nx, ny))
# !!! with fftconvolve
# from cupyx.scipy.signal import fftconvolve
# autocorr = fftconvolve(
# data, data[:, ::-1, ::-1], 'same', axes=(-2, -1))
buffer_handle, buffer_array = self._shmem_buffer.next_slot() buffer_handle, buffer_array = self._shmem_buffer.next_slot()
autocorr.get(out=buffer_array[:]) if hasattr(autocorr, 'get'):
autocorr.get(out=buffer_array[:])
# !!! alternatively send the squared Fourier transform else:
# !!! shape (nf, 2 * nx, ny + 1) buffer_array[:] = autocorr
#
# gg = (g * g.conj()).real
# buffer_handle, buffer_array = self._shmem_buffer.next_slot()
# gg.get(out=buffer_array[:])
if self._device._use_shmem_handles: if self._device._use_shmem_handles:
output_hash.set("image.autocorr", buffer_handle) output_hash.set("image.autocorr", buffer_handle)
...@@ -96,12 +158,10 @@ class Autocorrelation(BaseCorrectionAddon): ...@@ -96,12 +158,10 @@ class Autocorrelation(BaseCorrectionAddon):
# pass # pass
def reconfigure(self, changed_config): def reconfigure(self, changed_config):
"""Will be given the node from extend_device_schema, no prefix needed here. Note
that only the changed properties will be in this update hash; see peakfinder9
for an example of caching and reinstanting complex kernels using many parameters
when any change."""
if changed_config.has("intenstityPerPhoton"): if changed_config.has("intenstityPerPhoton"):
self._intensity_per_photon = changed_config["intenstityPerPhoton"] self._intensity_per_photon = changed_config["intenstityPerPhoton"]
if changed_config.has("boundary"):
self._autocorr = AUTOCORR_FUN[changed_config["boundary"]]
def _update_buffer(self): def _update_buffer(self):
nf, _, nx, ny = self._device.input_data_shape nf, _, nx, ny = self._device.input_data_shape
...@@ -121,4 +181,3 @@ class Autocorrelation(BaseCorrectionAddon): ...@@ -121,4 +181,3 @@ class Autocorrelation(BaseCorrectionAddon):
self._shmem_buffer.cuda_pin() self._shmem_buffer.cuda_pin()
else: else:
self._shmem_buffer.change_shape(shape) self._shmem_buffer.change_shape(shape)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment