diff --git a/src/calng/agipd_gpu.py b/src/calng/agipd_gpu.py index 3c6f010c60261bfee8c9eef7fe7eb327d76d8abc..c445ae6a2358cb4812495a67bd2703a8dbb00201 100644 --- a/src/calng/agipd_gpu.py +++ b/src/calng/agipd_gpu.py @@ -145,6 +145,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): "constant_memory_cells": self.constant_memory_cells, "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), + "corr_enum": utils.enum_to_c_template(CorrectionFlags), } ) self.source_module = cupy.RawModule(code=kernel_source) diff --git a/src/calng/agipd_gpu_kernels.cpp b/src/calng/agipd_gpu_kernels.cpp index d11af9737db09bbcb6eaf8f62a0881c1388feefc..e5fc623de00e232fffc7325712c18c08e45b2a69 100644 --- a/src/calng/agipd_gpu_kernels.cpp +++ b/src/calng/agipd_gpu_kernels.cpp @@ -1,12 +1,7 @@ #include <cuda_fp16.h> #include <math_constants.h> -const unsigned char CORR_THRESHOLD = 1; -const unsigned char CORR_OFFSET = 2; -const unsigned char CORR_BLSHIFT = 4; -const unsigned char CORR_REL_GAIN_PC = 8; -const unsigned char CORR_REL_GAIN_XRAY = 16; -const unsigned char CORR_BPMASK = 32; +{{corr_enum}} extern "C" { /* @@ -82,7 +77,7 @@ extern "C" { if (map_cell < map_cells) { unsigned char gain = 0; - if (corr_flags & CORR_THRESHOLD) { + if (corr_flags & THRESHOLD) { const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold + map_cell * threshold_map_stride_cell + y * threshold_map_stride_y + @@ -110,20 +105,20 @@ extern "C" { y * gm_map_stride_y + x * gm_map_stride_x; - if ((corr_flags & CORR_BPMASK) && bad_pixel_map[gm_map_index]) { + if ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index]) { corrected = CUDART_NAN_F; } else { - if (corr_flags & CORR_OFFSET) { + if (corr_flags & OFFSET) { corrected -= offset_map[gm_map_index]; } // TODO: baseline shift - if (corr_flags & CORR_REL_GAIN_PC) { + if (corr_flags & REL_GAIN_PC) { corrected *= rel_gain_pc_map[gm_map_index]; if (gain == 1) { corrected += md_additional_offset[map_index]; } } - if (corr_flags & CORR_REL_GAIN_XRAY) { + if (corr_flags & REL_GAIN_XRAY) { // TODO //corrected *= rel_gain_xray_map[map_index]; } diff --git a/src/calng/utils.py b/src/calng/utils.py index 424e45de992a6f2e7599ed089d27f4ba72ec798c..9b9a3786acbc4d620a93f511ae7acd7f0e8dcbbf 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -31,6 +31,14 @@ def np_dtype_to_c_type(dtype): return _np_typechar_to_c_typestring[as_char] +def enum_to_c_template(enum_class): + res = [f"enum {enum_class.__name__} {{"] + for field in enum_class: + res.append(f"\t{field.name} = {field.value},") + res.append("};") + return "\n".join(res) + + def ceil_div(num, denom): return (num + denom - 1) // denom