From cdacb78764e548383d8aa7f11888aba2c5215400 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Mon, 13 Sep 2021 12:32:48 +0200
Subject: [PATCH] Reuse correction flag enum definition

---
 src/calng/agipd_gpu.py          |  1 +
 src/calng/agipd_gpu_kernels.cpp | 17 ++++++-----------
 src/calng/utils.py              |  8 ++++++++
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/calng/agipd_gpu.py b/src/calng/agipd_gpu.py
index 3c6f010c..c445ae6a 100644
--- a/src/calng/agipd_gpu.py
+++ b/src/calng/agipd_gpu.py
@@ -145,6 +145,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
                 "constant_memory_cells": self.constant_memory_cells,
                 "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype),
                 "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype),
+                "corr_enum": utils.enum_to_c_template(CorrectionFlags),
             }
         )
         self.source_module = cupy.RawModule(code=kernel_source)
diff --git a/src/calng/agipd_gpu_kernels.cpp b/src/calng/agipd_gpu_kernels.cpp
index d11af973..e5fc623d 100644
--- a/src/calng/agipd_gpu_kernels.cpp
+++ b/src/calng/agipd_gpu_kernels.cpp
@@ -1,12 +1,7 @@
 #include <cuda_fp16.h>
 #include <math_constants.h>
 
-const unsigned char CORR_THRESHOLD = 1;
-const unsigned char CORR_OFFSET = 2;
-const unsigned char CORR_BLSHIFT = 4;
-const unsigned char CORR_REL_GAIN_PC = 8;
-const unsigned char CORR_REL_GAIN_XRAY = 16;
-const unsigned char CORR_BPMASK = 32;
+{{corr_enum}}
 
 extern "C" {
 	/*
@@ -82,7 +77,7 @@ extern "C" {
 
 		if (map_cell < map_cells) {
 			unsigned char gain = 0;
-			if (corr_flags & CORR_THRESHOLD) {
+			if (corr_flags & THRESHOLD) {
 				const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold +
 														map_cell * threshold_map_stride_cell +
 														y * threshold_map_stride_y +
@@ -110,20 +105,20 @@ extern "C" {
 				y * gm_map_stride_y +
 				x * gm_map_stride_x;
 
-			if ((corr_flags & CORR_BPMASK) && bad_pixel_map[gm_map_index]) {
+			if ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index]) {
 				corrected = CUDART_NAN_F;
 			} else {
-				if (corr_flags & CORR_OFFSET) {
+				if (corr_flags & OFFSET) {
 					corrected -= offset_map[gm_map_index];
 				}
 				// TODO: baseline shift
-				if (corr_flags & CORR_REL_GAIN_PC) {
+				if (corr_flags & REL_GAIN_PC) {
 					corrected *= rel_gain_pc_map[gm_map_index];
 					if (gain == 1) {
 						corrected += md_additional_offset[map_index];
 					}
 				}
-				if (corr_flags & CORR_REL_GAIN_XRAY) {
+				if (corr_flags & REL_GAIN_XRAY) {
 					// TODO
 					//corrected *= rel_gain_xray_map[map_index];
 				}
diff --git a/src/calng/utils.py b/src/calng/utils.py
index 424e45de..9b9a3786 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -31,6 +31,14 @@ def np_dtype_to_c_type(dtype):
     return _np_typechar_to_c_typestring[as_char]
 
 
+def enum_to_c_template(enum_class):
+    res = [f"enum {enum_class.__name__} {{"]
+    for field in enum_class:
+        res.append(f"\t{field.name} = {field.value},")
+    res.append("};")
+    return "\n".join(res)
+
+
 def ceil_div(num, denom):
     return (num + denom - 1) // denom
 
-- 
GitLab