From fc3b3be25560dca61e367ff432bdf65266a13629 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Tue, 18 Jan 2022 14:18:00 +0100
Subject: [PATCH] Fill in more mandatory methods

---
 src/calng/JungfrauCorrection.py   | 138 ++++++++++++++++++++++++++++--
 src/calng/kernels/jungfrau_gpu.cu |  11 +--
 2 files changed, 134 insertions(+), 15 deletions(-)

diff --git a/src/calng/JungfrauCorrection.py b/src/calng/JungfrauCorrection.py
index 1a0b20c9..bf7b9815 100644
--- a/src/calng/JungfrauCorrection.py
+++ b/src/calng/JungfrauCorrection.py
@@ -1,5 +1,4 @@
 import enum
-import timeit
 
 import cupy
 import numpy as np
@@ -7,15 +6,17 @@ from karabo.bound import (
     DOUBLE_ELEMENT,
     KARABO_CLASSINFO,
     OVERWRITE_ELEMENT,
+    STRING_ELEMENT,
     VECTOR_STRING_ELEMENT,
 )
-from karabo.common.states import State
 
 from . import base_gpu, calcat_utils, utils
 from ._version import version as deviceVersion
 from .base_correction import BaseCorrection, add_correction_step_schema
 
 
+_pretend_pulse_table = np.arange(16, dtype=np.uint8)
+
 class JungfrauConstants(enum.Enum):
     Offset10Hz = enum.auto()
     BadPixelsDark10Hz = enum.auto()
@@ -53,6 +54,7 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner):
         output_data_dtype=cupy.float32,
         bad_pixel_mask_value=cupy.nan,
         burst_mode=False,
+        gain_mode=JungfrauGainMode.dynamicgain,
     ):
         self.burst_mode = burst_mode
         self.input_shape = (memory_cells, pixels_y, pixels_x)
@@ -65,6 +67,9 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner):
             input_data_dtype,
             output_data_dtype,
         )
+        # TODO: avoid superclass creating cell table with wrong dtype first
+        self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=cupy.uint8)
+        self.input_gain_map_gpu = cupy.empty(self.input_shape, dtype=cupy.uint8)
         self.map_shape = self.input_shape + (3,)
         # is jungfrau stuff gain mapped?
         self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32)
@@ -84,9 +89,41 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner):
                 "burst_mode": self.burst_mode,
             }
         )
-        print(kernel_source)
+        for i, line in enumerate(kernel_source.split("\n")):
+            print(f"{i}: {line}")
         self.source_module = cupy.RawModule(code=kernel_source)
+        print("Got raw module")
         self.correction_kernel = self.source_module.get_function("correct")
+        print("Got kernel")
+
+    def _get_raw_for_preview(self):
+        return self.input_data_gpu.transpose(0, 2, 1)
+
+    def _get_corrected_for_preview(self):
+        return self.processed_data_gpu.transpose(0, 2, 1)
+
+    def load_data(self, image_data, input_gain_map, cell_table):
+        """Experiment: loading all three in one function as they are tied"""
+        self.input_data_gpu.set(image_data)
+        self.input_gain_map_gpu.set(input_gain_map)
+        self.cell_table_gpu.set(cell_table)
+
+    def correct(self, flags):
+        self.correction_kernel(
+            self.full_grid,
+            self.full_block,
+            (
+                self.input_data_gpu,
+                self.input_gain_map_gpu,
+                self.cell_table_gpu,
+                cupy.uint8(flags),
+                self.offset_map_gpu,
+                self.rel_gain_map_gpu,
+                self.bad_pixel_map_gpu,
+                self.bad_pixel_mask_value,
+                self.processed_data_gpu,
+            )
+        )
 
 
 class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend):
@@ -187,6 +224,10 @@ class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend):
         res["Integration Time"] = self._get_param("integrationTime")
         res["Sensor Temperature"] = self._get_param("sensorTemperature")
         res["Gain Setting"] = self._get_param("gainSetting")
+        gain_mode = JungfrauGainMode[self._get_param("gainMode")]
+        if gain_mode is not JungfrauGainMode.dynamicgain:
+            # TODO: figure out what to set
+            res["Gain mode"] = 1
         return res
 
 
@@ -207,6 +248,16 @@ class JungfrauCorrection(BaseCorrection):
     def expectedParameters(expected):
         super(JungfrauCorrection, JungfrauCorrection).expectedParameters(expected)
         (
+            OVERWRITE_ELEMENT(expected)
+            .key("dataFormat.pixelsX")
+            .setNewDefaultValue(1024)
+            .commit(),
+
+            OVERWRITE_ELEMENT(expected)
+            .key("dataFormat.pixelsY")
+            .setNewDefaultValue(512)
+            .commit(),
+
             OVERWRITE_ELEMENT(expected)
             .key("dataFormat.memoryCells")
             .setNewDefaultValue(1)
@@ -238,15 +289,27 @@ class JungfrauCorrection(BaseCorrection):
     def input_data_shape(self):
         return (
             self._schema_cache["dataFormat.memoryCells"],
-            self._schema_cache["dataFormat.pixelsX"],
             self._schema_cache["dataFormat.pixelsY"],
+            self._schema_cache["dataFormat.pixelsX"],
         )
 
     def __init__(self, config):
         super().__init__(config)
         # TODO: gain mode as constant parameter and / or device configuration
         self.gain_mode = JungfrauGainMode[config.get("constantParameters.gainMode")]
-        # TODO: rest of this
+
+        try:
+            self.bad_pixel_mask_value = np.float32(
+                config.get("corrections.badPixels.maskingValue")
+            )
+        except ValueError:
+            self.bad_pixel_mask_value = np.float32("nan")
+
+        self._kernel_runner_init_args = {
+            "gain_mode": self.gain_mode,
+            "bad_pixel_mask_value": self.bad_pixel_mask_value,
+            "burst_mode": False, # TODO
+        }
 
     def process_data(
         self,
@@ -258,5 +321,66 @@ class JungfrauCorrection(BaseCorrection):
         cell_table,
         do_generate_preview,
     ):
-        # TODO
-        ...
+        if self._frame_filter is not None:
+            try:
+                cell_table = cell_table[self._frame_filter]
+                image_data = image_data[self._frame_filter]
+            except IndexError:
+                self.log_status_warn(
+                    "Failed to apply frame filter, please check that it is valid!"
+                )
+                return
+
+        try:
+            self.kernel_runner.load_data(
+                image_data, data_hash.get("data.gain"), cell_table
+            )
+        except ValueError as e:
+            self.log_status_warn(f"Failed to load data: {e}")
+            return
+        except Exception as e:
+            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
+
+        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
+        self.kernel_runner.correct(self._correction_flag_enabled)
+        self.kernel_runner.reshape(
+            output_order=self._schema_cache["dataFormat.outputAxisOrder"],
+            out=buffer_array,
+        )
+
+        if do_generate_preview:
+            if self._correction_flag_enabled != self._correction_flag_preview:
+                self.kernel_runner.correct(self._correction_flag_preview)
+            (
+                preview_slice_index,
+                preview_cell,
+                preview_pulse,
+            ) = utils.pick_frame_index(
+                self._schema_cache["preview.selectionMode"],
+                self._schema_cache["preview.index"],
+                cell_table,
+                _pretend_pulse_table,
+                warn_func=self.log_status_warn,
+            )
+            preview_raw, preview_corrected = self.kernel_runner.compute_previews(
+                preview_slice_index
+            )
+
+        # reusing input data hash for sending
+        data_hash.set(self._image_data_path, buffer_handle)
+        data_hash.set("calngShmemPaths", [self._image_data_path])
+
+        data_hash.set(self._cell_table_path, cell_table)
+        data_hash.set("image.pulseId", pulse_table[:, np.newaxis])
+
+        self._write_output(data_hash, metadata)
+
+        if do_generate_preview:
+            self._write_combiner_previews(
+                (
+                    ("preview.outputRaw", preview_raw),
+                    ("preview.outputCorrected", preview_corrected),
+                ),
+                train_id,
+                source,
+            )
diff --git a/src/calng/kernels/jungfrau_gpu.cu b/src/calng/kernels/jungfrau_gpu.cu
index 2c84628f..e203b6c0 100644
--- a/src/calng/kernels/jungfrau_gpu.cu
+++ b/src/calng/kernels/jungfrau_gpu.cu
@@ -3,18 +3,13 @@
 {{corr_enum}}
 
 extern "C" {
-	/*
-	  TODO
-	  Shape of input data: memory cell, y, x
-	  Shape of offset constant: x, y, memory cell
-	*/
 	__global__ void correct(const {{input_data_dtype}}* data, // shape: memory cell, y, x
 	                        const unsigned char* gain_stage, // same shape
 	                        const unsigned char* cell_table,
 	                        const unsigned char corr_flags,
 	                        const float* offset_map,
 	                        const float* rel_gain_map,
-	                        const unsigned int bad_pixel_map,
+	                        const unsigned int* bad_pixel_map,
 	                        const float bad_pixel_mask_value,
 	                        {{output_data_dtype}}* output) {
 		const size_t X = {{pixels_x}};
@@ -76,8 +71,8 @@ extern "C" {
 				if (corr_flags & OFFSET) {
 					res -= offset_map[map_index];
 				}
-				if (corr_flags & GAIN) {
-					res /= gain_map[map_index];
+				if (corr_flags & REL_GAIN) {
+					res /= rel_gain_map[map_index];
 				}
 			}
 		}
-- 
GitLab