From c0610688b45efae051118fd700caaa5e8d1ad986 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 27 Oct 2021 15:54:05 +0200
Subject: [PATCH] Properly send gain map as main data / preview

---
 src/calng/AgipdCorrection.py    | 89 ++++++++++++++++++++++++++++++---
 src/calng/DsscCorrection.py     |  9 +++-
 src/calng/agipd_gpu.py          | 38 +++++++++++++-
 src/calng/agipd_gpu_kernels.cpp | 11 ++--
 src/calng/base_correction.py    |  7 +--
 5 files changed, 135 insertions(+), 19 deletions(-)

diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index d3f42390..df102741 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -8,7 +8,9 @@ from karabo.bound import (
     NODE_ELEMENT,
     STRING_ELEMENT,
     VECTOR_STRING_ELEMENT,
-    ImageData,
+    Schema,
+    NDARRAY_ELEMENT,
+    OUTPUT_CHANNEL,
 )
 from karabo.common.states import State
 
@@ -16,6 +18,7 @@ from ._version import version as deviceVersion
 from .agipd_gpu import AgipdGainMode, AgipdGpuRunner, BadPixelValues, CorrectionFlags
 from .base_correction import BaseCorrection, add_correction_step_schema
 from .calcat_utils import AgipdCalcatFriend, AgipdConstants
+from . import shmem_utils
 
 
 @KARABO_CLASSINFO("AgipdCorrection", deviceVersion)
@@ -54,10 +57,10 @@ class AgipdCorrection(BaseCorrection):
             .displayedName("Send gain map on dataOutput")
             .assignmentOptional()
             .defaultValue(False)
-            .reconfigurable()
             .commit(),
         )
         AgipdCorrection._managed_keys.append("sendGainMap")
+        # TODO: make sendGainMap reconfigurable
 
         (
             STRING_ELEMENT(expected)
@@ -67,6 +70,17 @@ class AgipdCorrection(BaseCorrection):
             .defaultValue("")
             .commit()
         )
+        preview_schema = Schema()
+        (
+            NODE_ELEMENT(preview_schema).key("data").commit(),
+            NDARRAY_ELEMENT(preview_schema).key("data.adc").dtype("FLOAT").commit(),
+        )
+        (
+            OUTPUT_CHANNEL(expected)
+            .key("preview.outputGain")
+            .dataSchema(preview_schema)
+            .commit(),
+        )
 
         AgipdCalcatFriend.add_schema(expected, AgipdCorrection._managed_keys)
         # this is not automatically done by superclass for complicated class reasons
@@ -192,8 +206,10 @@ class AgipdCorrection(BaseCorrection):
             "gain_mode": self.gain_mode,
             "bad_pixel_mask_value": self.bad_pixel_mask_value,
             "g_gain_value": config.get("corrections.relGainXray.gGainValue"),
+            "output_gain_map": config.get("sendGainMap"),
         }
 
+        self._shmem_buffer_gain_map = None
         self._update_shapes()
 
         # configurability: overriding md_additional_offset
@@ -293,19 +309,51 @@ class AgipdCorrection(BaseCorrection):
                 preview_raw, preview_corrected = self.gpu_runner.compute_preview(
                     preview_slice_index
                 )
+                if self._schema_cache["sendGainMap"]:
+                    preview_gain = self.gpu_runner.compute_preview_gain(
+                        preview_slice_index
+                    )
 
+        data.set("image.data", buffer_handle)
         if self._schema_cache["sendGainMap"]:
-            data.set("image.gainMap", ImageData(self.gpu_runner.get_gain_map()))
+            buffer_handle, buffer_array = self._shmem_buffer_gain_map.next_slot()
+            self.gpu_runner.get_gain_map(
+                output_order=self._schema_cache["dataFormat.outputAxisOrder"],
+                out=buffer_array,
+            )
+            data.set(
+                "image.gainMap",
+                buffer_handle,
+            )
+            data.set("calngShmemPaths", ["image.data", "image.gainMap"])
+        else:
+            data.set("calngShmemPaths", ["image.data"])
 
-        data.set("image.data", buffer_handle)
         data.set("image.cellId", cell_table[:, np.newaxis])
         data.set("image.pulseId", pulse_table[:, np.newaxis])
-        data.set("calngShmemPaths", ["image.data"])
+
         self._write_output(data, metadata)
         if do_generate_preview:
-            self._write_combiner_preview(
-                preview_raw, preview_corrected, train_id, source
-            )
+            if self._schema_cache["sendGainMap"]:
+                self._write_combiner_previews(
+                    (
+                        ("preview.outputRaw", preview_raw),
+                        ("preview.outputCorrected", preview_corrected),
+                        ("preview.outputGain", preview_gain),
+                    ),
+                    train_id,
+                    source,
+                )
+                # TODO: DRY
+            else:
+                self._write_combiner_previews(
+                    (
+                        ("preview.outputRaw", preview_raw),
+                        ("preview.outputCorrected", preview_corrected),
+                    ),
+                    train_id,
+                    source,
+                )
 
         # update rate etc.
         self._buffered_status_update.set("trainId", train_id)
@@ -358,6 +406,31 @@ class AgipdCorrection(BaseCorrection):
         assert np.max(new_filter) < self.get("dataFormat.memoryCells")
         self.pulse_filter = new_filter
 
+    def _update_shapes(self):
+        super()._update_shapes()
+        # TODO: pack four pixels per byte
+        if self._schema_cache["sendGainMap"]:
+            if self._shmem_buffer_gain_map is None:
+                buffer_name = self.getInstanceId() + ":dataOutput-gain"
+                # try to match number of trains in image data buffer
+                memory_budget = (
+                    self.get("outputShmemBufferSize")
+                    * 2 ** 30
+                    // (
+                        np.dtype(self.output_data_dtype).itemsize
+                        // np.dtype(np.uint8).itemsize
+                    )
+                )
+                self.log.INFO(f"Opening new shmem buffer for gain: {buffer_name}")
+                self._shmem_buffer_gain_map = shmem_utils.ShmemCircularBuffer(
+                    memory_budget,
+                    self.output_data_shape,
+                    np.uint8,
+                    buffer_name,
+                )
+            else:
+                self._shmem_buffer_gain_map.change_shape(self.output_data_shape)
+
     def _update_bad_pixel_selection(self):
         selection = 0
         for field in BadPixelValues:
diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py
index 04916db1..d70f172d 100644
--- a/src/calng/DsscCorrection.py
+++ b/src/calng/DsscCorrection.py
@@ -144,8 +144,13 @@ class DsscCorrection(BaseCorrection):
         data.set("calngShmemPaths", ["image.data"])
         self._write_output(data, metadata)
         if do_generate_preview:
-            self._write_combiner_preview(
-                preview_raw, preview_corrected, train_id, source
+            self._write_combiner_previews(
+                (
+                    ("preview.outputRaw", preview_raw),
+                    ("preview.outputCorrected", preview_corrected),
+                ),
+                train_id,
+                source,
             )
 
         # update rate etc.
diff --git a/src/calng/agipd_gpu.py b/src/calng/agipd_gpu.py
index 3b3802ab..5ae1aa9c 100644
--- a/src/calng/agipd_gpu.py
+++ b/src/calng/agipd_gpu.py
@@ -39,6 +39,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
         bad_pixel_mask_value=cupy.nan,
         gain_mode=AgipdGainMode.ADAPTIVE_GAIN,
         g_gain_value=1,
+        output_gain_map=False,
     ):
         self.gain_mode = gain_mode
         if self.gain_mode is AgipdGainMode.ADAPTIVE_GAIN:
@@ -47,6 +48,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
             self.default_gain = cupy.uint8(gain_mode - 1)
         self.input_shape = (memory_cells, 2, pixels_x, pixels_y)
         self.processed_shape = (memory_cells, pixels_x, pixels_y)
+        self.output_gain_map = output_gain_map
         super().__init__(
             pixels_x,
             pixels_y,
@@ -55,7 +57,12 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
             input_data_dtype,
             output_data_dtype,
         )
-        self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=cupy.uint8)
+        if self.output_gain_map:
+            self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=cupy.uint8)
+            self.preview_gain = np.empty(self.preview_shape, dtype=np.float32)
+        else:
+            # TODO: don't even set in this case
+            self.gain_map_gpu = cupy.empty(0, dtype=cupy.uint8)
 
         self.map_shape = (self.constant_memory_cells, self.pixels_x, self.pixels_y)
         self.gm_map_shape = self.map_shape + (3,)  # for gain-mapped constants
@@ -75,6 +82,33 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
 
         self.update_block_size((1, 1, 64))
 
+    def compute_preview_gain(self, preview_index):
+        assert self.output_gain_map
+        # TODO: abstract most of this in base_gpu to DRY
+        if preview_index < -4:
+            raise ValueError(f"No statistic with code {preview_index} defined")
+        elif preview_index >= self.memory_cells:
+            raise ValueError(f"Memory cell index {preview_index} out of range")
+
+        if preview_index >= 0:
+            self.gain_map_gpu[preview_index].astype(np.float32).get(
+                out=self.preview_gain
+            )
+        elif preview_index == -1:
+            # TODO: confirm that max is pixel and not integrated intensity
+            # separate from next case because dtype not applicable here
+            cupy.max(self.gain_map_gpu, axis=0).astype(cupy.float32).get(
+                out=self.preview_gain
+            )
+        elif preview_index in (-2, -3, -4):
+            stat_fun = {-1: cupy.max, -2: cupy.mean, -3: cupy.sum, -4: cupy.std}[
+                preview_index
+            ]
+            stat_fun(self.gain_map_gpu, axis=0, dtype=cupy.float32).get(
+                out=self.preview_gain
+            )
+        return self.preview_gain
+
     def _preview_preprocess_raw(self):
         return self.input_data_gpu[:, 0]
 
@@ -247,6 +281,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
         )
 
     def get_gain_map(self, output_order, out=None):
+        assert self.output_gain_map
         return cupy.ascontiguousarray(
             cupy.transpose(
                 self.gain_map_gpu,
@@ -264,6 +299,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
                 "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype),
                 "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype),
                 "corr_enum": utils.enum_to_c_template(CorrectionFlags),
+                "output_gain_map": self.output_gain_map,
             }
         )
         self.source_module = cupy.RawModule(code=kernel_source)
diff --git a/src/calng/agipd_gpu_kernels.cpp b/src/calng/agipd_gpu_kernels.cpp
index 7548ac40..0561e84b 100644
--- a/src/calng/agipd_gpu_kernels.cpp
+++ b/src/calng/agipd_gpu_kernels.cpp
@@ -127,21 +127,26 @@ extern "C" {
 					corrected = (corrected / rel_gain_xray_map[map_index]) * g_gain_value;
 				}
 			}
-
-			gain_map[output_index] = gain;
 			{% if output_data_dtype == "half" %}
 			output[output_index] = __float2half(corrected);
 			{% else %}
 			output[output_index] = ({{output_data_dtype}})corrected;
 			{% endif %}
+
+			{% if output_gain_map %}
+			gain_map[output_index] = gain;
+			{% endif %}
 		} else {
 			// TODO: decide what to do when we cannot threshold
-			gain_map[data_index] = 255;
 			{% if output_data_dtype == "half" %}
 			output[data_index] = __float2half(corrected);
 			{% else %}
 			output[data_index] = ({{output_data_dtype}})corrected;
 			{% endif %}
+
+			{% if output_gain_map %}
+			gain_map[data_index] = 255;
+			{% endif %}
 		}
 	}
 }
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index c99d82eb..9d393830 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -618,7 +618,7 @@ class BaseCorrection(PythonDevice):
         channel.write(data, metadata, False)
         channel.update()
 
-    def _write_combiner_preview(self, data_raw, data_corrected, train_id, source):
+    def _write_combiner_previews(self, channel_data_pairs, train_id, source):
         # TODO: take into account updated pulse table after pulse filter
         # TODO: send as ImageData (requires updated assembler)
         # TODO: allow sending *all* frames for commissioning (request: Jola)
@@ -629,10 +629,7 @@ class BaseCorrection(PythonDevice):
         # note: have to construct because setting .tid after init is broken
         timestamp = Timestamp(Epochstamp(), Trainstamp(train_id))
         metadata = ChannelMetaData(source, timestamp)
-        for channel_name, data in (
-            ("preview.outputRaw", data_raw),
-            ("preview.outputCorrected", data_corrected),
-        ):
+        for channel_name, data in channel_data_pairs:
             preview_hash.set("data.adc", data[..., np.newaxis])
             channel = self.signalSlotable.getOutputChannel(channel_name)
             channel.write(preview_hash, metadata, False)
-- 
GitLab