From c15bbc65d500362174d576b281f823d6a4395853 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Mon, 20 Mar 2023 19:26:22 +0100
Subject: [PATCH] Add rough peakfinder9 implementation

---
 setup.py                                   |   1 +
 src/calng/correction_addons/peakfinder9.py | 228 +++++++++++++++++++++
 src/calng/kernels/peakfinder9_gpu.cu       | 184 +++++++++++++++++
 3 files changed, 413 insertions(+)
 create mode 100644 src/calng/correction_addons/peakfinder9.py
 create mode 100644 src/calng/kernels/peakfinder9_gpu.cu

diff --git a/setup.py b/setup.py
index 0c624f95..cf1dc7f0 100644
--- a/setup.py
+++ b/setup.py
@@ -52,6 +52,7 @@ setup(
         "calng.correction_addon": [
             "IntegratedIntensity = calng.correction_addons.integrated_intensity:IntegratedIntensityAddon [agipd]",  # noqa
             "RandomFrames = calng.correction_addons.random_frames:RandomFramesAddon",
+            "Peakfinder9 = calng.correction_addons.peakfinder9:Peakfinder9Addon [agipd]",  # noqa
         ],
     },
     package_data={"": ["kernels/*"]},
diff --git a/src/calng/correction_addons/peakfinder9.py b/src/calng/correction_addons/peakfinder9.py
new file mode 100644
index 00000000..3c25ba30
--- /dev/null
+++ b/src/calng/correction_addons/peakfinder9.py
@@ -0,0 +1,228 @@
+import functools
+import pathlib
+
+from karabo.bound import (
+    FLOAT_ELEMENT,
+    NDARRAY_ELEMENT,
+    NODE_ELEMENT,
+    UINT32_ELEMENT,
+)
+from .base_addon import BaseCorrectionAddon
+from .. import utils
+
+
+class Peakfinder9Addon(BaseCorrectionAddon):
+    _node_name = "peakfinder9"
+
+    @staticmethod
+    def extend_device_schema(schema, managed_keys, prefix):
+        (
+            UINT32_ELEMENT(schema)
+            .key(f"{prefix}.windowRadius")
+            .assignmentOptional()
+            .defaultValue(2)
+            .reconfigurable()
+            .commit(),
+
+            UINT32_ELEMENT(schema)
+            .key(f"{prefix}.maxPeaks")
+            .assignmentOptional()
+            .defaultValue(500)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key(f"{prefix}.minPeakValueOverNeighbors")
+            .assignmentOptional()
+            .defaultValue(10)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key(f"{prefix}.minSnrMaxPixel")
+            .assignmentOptional()
+            .defaultValue(5)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key(f"{prefix}.minSnrPeakPixels")
+            .assignmentOptional()
+            .defaultValue(4)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key(f"{prefix}.minSnrWholePeak")
+            .assignmentOptional()
+            .defaultValue(6)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key(f"{prefix}.minSigma")
+            .assignmentOptional()
+            .defaultValue(5)
+            .reconfigurable()
+            .commit(),
+
+            UINT32_ELEMENT(schema)
+            .key(f"{prefix}.blockX")
+            .assignmentOptional()
+            .defaultValue(1)
+            .reconfigurable()
+            .commit(),
+
+            UINT32_ELEMENT(schema)
+            .key(f"{prefix}.blockY")
+            .assignmentOptional()
+            .defaultValue(1)
+            .reconfigurable()
+            .commit(),
+
+            UINT32_ELEMENT(schema)
+            .key(f"{prefix}.blockZ")
+            .assignmentOptional()
+            .defaultValue(64)
+            .reconfigurable()
+            .commit(),
+        )
+        managed_keys |= {
+            f"{prefix}.{key}"
+            for key in {
+                "windowRadius",
+                "maxPeaks",
+                "minPeakValueOverNeighbors",
+                "minSnrMaxPixel",
+                "minSnrPeakPixels",
+                "minSnrWholePeak",
+                "minSigma",
+            }
+        }
+
+    @staticmethod
+    def extend_output_schema(schema):
+        (
+            NODE_ELEMENT(schema)
+            .key("peakfinding")
+            .commit(),
+
+            NDARRAY_ELEMENT(schema)
+            .key("peakfinding.numPeaks")
+            .dtype("UINT32")
+            .commit(),
+
+            NDARRAY_ELEMENT(schema)
+            .key("peakfinding.peakX")
+            .dtype("FLOAT")
+            .commit(),
+
+            NDARRAY_ELEMENT(schema)
+            .key("peakfinding.peakY")
+            .dtype("FLOAT")
+            .commit(),
+
+            NDARRAY_ELEMENT(schema)
+            .key("peakfinding.peakIntensity")
+            .dtype("FLOAT")
+            .commit(),
+        )
+
+    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
+        # assumes processed data shape is frames, pixels, pixels
+        if self._input_shape != processed_data.shape:
+            try:
+                del self._peakfinding_parameters
+            except AttributeError:
+                pass
+            try:
+                del self._grid_and_block
+            except AttributeError:
+                pass
+            self._input_shape = processed_data.shape
+            self._rebuild_buffers()
+        kernel_params = self._peakfinding_parameters  # this will create buffers
+        self._peak_counts.fill(0)
+        self.kernel(
+            *self._grid_and_block,
+            (
+                *kernel_params,
+                processed_data.astype(cupy.float32, copy=False),
+                self._peak_counts,
+                self._peak_x,
+                self._peak_y,
+                self._peak_intensity,
+            ),
+        )
+        output_hash.set("peakfinding.numPeaks", self._peak_counts.get())
+        output_hash.set("peakfinding.peakX", self._peak_x.get())
+        output_hash.set("peakfinding.peakY", self._peak_y.get())
+        output_hash.set("peakfinding.peakIntensity", self._peak_intensity.get())
+
+    def reconfigure(self, changed_config):
+        self._config.merge(changed_config)
+        try:
+            del self._peakfinding_parameters
+        except AttributeError:
+            pass
+        if changed_config.has("maxPeaks"):
+            self._rebuild_buffers
+        if any(changed_config.has(key) for key in ("blockX", "blockY", "blockZ")):
+            try:
+                del self._grid_and_block
+            except AttributeError:
+                pass
+
+    def __init__(self, config):
+        global cupy
+        import cupy
+
+        self._config = config
+
+        _src_dir = pathlib.Path(__file__).absolute().parent.parent
+        with (_src_dir / "kernels" / "peakfinder9_gpu.cu").open("r") as fd:
+            self.kernel = cupy.RawKernel(
+                code=fd.read(),
+                name="pf9",
+                options=("--std=c++11",),
+                backend="nvcc",
+            )
+
+        self._input_shape = (0, 0, 0)  # frames, ss, fs
+
+    @functools.cached_property
+    def _grid_and_block(self):
+        # TODO: optimize
+        arbitrary_block_shape = (
+            self._config["blockX"],
+            self._config["blockY"],
+            self._config["blockZ"],
+        )
+        return (
+            utils.grid_to_cover_shape_with_blocks(
+                self._input_shape, arbitrary_block_shape
+            ),
+            arbitrary_block_shape,
+        )
+
+    @functools.cached_property
+    def _peakfinding_parameters(self):
+        return (
+            cupy.uint16(self._input_shape[0]),
+            cupy.uint16(self._input_shape[1]),
+            cupy.uint16(self._input_shape[2]),
+            cupy.uint16(self._config["windowRadius"]),
+            cupy.float32(self._config["minPeakValueOverNeighbors"]),
+            cupy.float32(self._config["minSnrMaxPixel"]),
+            cupy.float32(self._config["minSnrPeakPixels"]),
+            cupy.float32(self._config["minSnrWholePeak"]),
+            cupy.float32(self._config["minSigma"]),
+            cupy.uint32(self._config["maxPeaks"]),
+        )
+
+    def _rebuild_buffers(self):
+        output_shape = (self._input_shape[0], self._config["maxPeaks"])
+        self._peak_counts = cupy.zeros(self._input_shape[0], dtype=cupy.uint32)
+        self._peak_x = cupy.empty(output_shape, dtype=cupy.float32)
+        self._peak_y = cupy.empty(output_shape, dtype=cupy.float32)
+        self._peak_intensity = cupy.empty(output_shape, dtype=cupy.float32)
diff --git a/src/calng/kernels/peakfinder9_gpu.cu b/src/calng/kernels/peakfinder9_gpu.cu
new file mode 100644
index 00000000..ff65d2ea
--- /dev/null
+++ b/src/calng/kernels/peakfinder9_gpu.cu
@@ -0,0 +1,184 @@
+#include <cmath>
+#include <nvfunctional>
+
+class MaskedImageFrame {
+public:
+	const float* data_start;
+	const unsigned short num_rows;
+	const unsigned short num_cols;
+	__device__ MaskedImageFrame(const float* data_start, const unsigned short num_rows, const unsigned short num_cols):
+		data_start(data_start), num_rows(num_rows), num_cols(num_cols) {}
+	__device__ bool is_masked(const unsigned short i, const unsigned short j) {
+		return isnan(data_start[i * num_cols + j]);
+	}
+	__device__ float get(const unsigned short i, const unsigned short j) {
+		return data_start[i * num_cols + j];
+	}
+	__device__ float get(const unsigned short i, const unsigned short j, const float fallback) {
+		if (is_masked(i, j)) {
+			return fallback;
+		}
+		return get(i, j);
+	}
+	__device__ void maybe_call(const unsigned short i, const unsigned short j, const nvstd::function<void(unsigned short, unsigned short, float)> &fun) {
+		if (!is_masked(i, j)) {
+			fun(i, j, get(i, j));
+		}
+	}
+
+	/* Applies fun to all unmasked elements of frame lying on a ring (well, at least
+	   a circle in the infinity norm) with given center and radius.
+	   TODO: consider making a custom iterator, too
+	*/
+	__device__ void fun_ring(const unsigned short center_i,
+	                         const unsigned short center_j,
+	                         const unsigned short radius,
+	                         const nvstd::function<void(unsigned short, unsigned short, float)> &fun) {
+		if (radius==0) {
+			maybe_call(center_i, center_j, fun);
+			return;
+		}
+
+		// top row
+		for (short j=-radius; j<=radius; ++j) {
+			maybe_call(center_i - radius, center_j + j, fun);
+		}
+		// left / right side (not overlapping with top/bottom)
+		for (short i=-radius+1; i<radius; ++i) {
+			maybe_call(center_i + i, center_j - radius, fun);
+			maybe_call(center_i + i, center_j + radius, fun);
+		}
+		// bottom row
+		for (short j=-radius; j<=radius; ++j) {
+			maybe_call(center_i + radius, center_j + j, fun);
+		}
+	}
+};
+
+extern "C" __global__ void pf9(const unsigned short num_frames,
+                               const unsigned short num_rows,
+                               const unsigned short num_cols,
+                               const unsigned short window_radius,
+                               const float min_peak_over_border,
+                               const float min_snr_biggest_pixel,
+                               const float min_snr_peak_pixels,
+                               const float min_snr_whole_peak,
+                               const float min_sigma,
+                               const unsigned int max_peaks,
+                               const float* image,
+                               unsigned int* output_counts,
+                               float* output_x,
+                               float* output_y,
+                               float* output_intensity) {
+
+	// execution model: one thread handles one pixel in one frame
+	const unsigned short frame = blockDim.x * blockIdx.x + threadIdx.x;
+	const unsigned short row = blockDim.y * blockIdx.y + threadIdx.y;
+	const unsigned short col = blockDim.z * blockIdx.z + threadIdx.z;
+
+	if (frame >= num_frames ||
+	    row < window_radius || row >= num_rows - window_radius ||
+	    col < window_radius || col >= num_cols - window_radius) {
+		return;
+	}
+
+	// wrap thin helper class around for convenience
+	MaskedImageFrame masked_frame(image + frame * (num_rows * num_cols),
+	                              num_rows,
+	                              num_cols);
+
+	// candidate should not be masked
+	if (masked_frame.is_masked(row, col)) {
+		return;
+	}
+
+	float pixel_val = masked_frame.get(row, col);
+
+	// candidate should be greater than immediate neighbors
+	// with tie breaking: in case of equality, lowest row, col is candidate
+	if (pixel_val <= masked_frame.get(row-1, col-1, -INFINITY) ||
+	    pixel_val <= masked_frame.get(row-1, col, -INFINITY) ||
+	    pixel_val <= masked_frame.get(row-1, col+1, -INFINITY) ||
+	    pixel_val <= masked_frame.get(row, col-1, -INFINITY) ||
+	    pixel_val < masked_frame.get(row, col+1, -INFINITY) ||
+	    pixel_val < masked_frame.get(row+1, col, -INFINITY) ||
+	    pixel_val < masked_frame.get(row+1, col, -INFINITY) ||
+	    pixel_val < masked_frame.get(row+1, col, -INFINITY)) {
+		return;
+	}
+
+	// candidate should be greater than pixels on border
+	// note: original PF9 only checks three pixels per side around window
+	// (full border for window radius 2, less for higher)
+	{
+		float border_max = -INFINITY;
+		masked_frame.fun_ring(row, col, window_radius, [&] (unsigned short, unsigned short, float val) {
+			border_max = max(border_max, val);
+		});
+		if (pixel_val - min_peak_over_border <= border_max) {
+			return;
+		}
+	}
+
+	// candidate should have sufficient SNR over border
+	float mean = 0;
+	float sigma = 0;
+	{
+		// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+		unsigned short count = 0;
+		float M2 = 0;
+		masked_frame.fun_ring(row, col, window_radius, [&] (unsigned short, unsigned short, float val) {
+			++count;
+			float delta = val - mean;
+			mean += delta / static_cast<float>(count);
+			float delta2 = val - mean;
+			M2 += delta * (delta2);
+		});
+		if (count < 10) {
+			// should this be configurable? (check for validPixelCount < 10 is hardcoded in PF9)
+			return;
+		}
+		// TODO: form opinion on /(n-1) for dealing with biased sample stdev
+		sigma = max(min_sigma, sqrtf(M2 / (static_cast<float>(count) - 1)));
+
+		if (pixel_val <= mean + min_snr_biggest_pixel * sigma) {
+			return;
+		}
+	}
+
+	// whole peak should have sufficent SNR
+	float peak_weighted_row;
+	float peak_weighted_col;
+	float peak_total_mass = pixel_val;
+	{
+		/* TODO: more compact form */
+		float peak_weighted_row_nom = static_cast<float>(row) * pixel_val;
+		float peak_weighted_col_nom = static_cast<float>(col) * pixel_val;
+		const float peak_pixel_threshold = mean + min_snr_peak_pixels * sigma;
+		for (unsigned short layer=1; layer<=window_radius; ++layer) {
+			float total_mass_before = peak_total_mass;
+			masked_frame.fun_ring(row, col, layer, [&] (unsigned short i, unsigned short j, float val) {
+				if (val > peak_pixel_threshold) {
+					peak_total_mass += val;
+					peak_weighted_row_nom += val * static_cast<float>(i);
+					peak_weighted_col_nom += val * static_cast<float>(j);
+				}
+			});
+			// in case nothing was added, stop expanding
+			if (peak_total_mass == total_mass_before) {
+				break;
+			}
+		}
+		if (peak_total_mass <= mean + min_snr_whole_peak * sigma) {
+			return;
+		}
+		peak_weighted_row = peak_weighted_row_nom / peak_total_mass;
+		peak_weighted_col = peak_weighted_col_nom / peak_total_mass;
+	}
+
+	unsigned int output_index = atomicInc(output_counts + frame, max_peaks);
+	unsigned int output_pos = frame * max_peaks + output_index;
+	output_x[output_pos] = peak_weighted_row;
+	output_y[output_pos] = peak_weighted_col;
+	output_intensity[output_pos] = peak_total_mass;
+}
-- 
GitLab