From 47314d479672ba003251e4d92bdd8c74fcc8dd6d Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Tue, 10 May 2022 13:18:54 +0200
Subject: [PATCH] LPD: DRYing constant loading, swapping axes for FF constants

---
 src/calng/LpdCorrection.py | 65 +++++++++++++++++---------------------
 1 file changed, 29 insertions(+), 36 deletions(-)

diff --git a/src/calng/LpdCorrection.py b/src/calng/LpdCorrection.py
index 3a9ca72f..b906d795 100644
--- a/src/calng/LpdCorrection.py
+++ b/src/calng/LpdCorrection.py
@@ -14,15 +14,6 @@ from ._version import version as deviceVersion
 from .base_correction import BaseCorrection, add_correction_step_schema
 
 
-class CorrectionFlags(enum.IntFlag):
-    NONE = 0
-    OFFSET = 1
-    GAIN_AMP = 2
-    REL_GAIN = 4
-    FF_CORR = 8
-    BPMASK = 16
-
-
 class LpdConstants(enum.Enum):
     Offset = enum.auto()
     BadPixelsDark = enum.auto()
@@ -32,6 +23,15 @@ class LpdConstants(enum.Enum):
     BadPixelsFF = enum.auto()
 
 
+class CorrectionFlags(enum.IntFlag):
+    NONE = 0
+    OFFSET = 1
+    GAIN_AMP = 2
+    REL_GAIN = 4
+    FF_CORR = 8
+    BPMASK = 16
+
+
 class LpdGpuRunner(base_gpu.BaseGpuRunner):
     _kernel_source_filename = "lpd_gpu.cu"
     _corrected_axis_order = "cxy"
@@ -91,43 +91,36 @@ class LpdGpuRunner(base_gpu.BaseGpuRunner):
         )
 
     def load_constant(self, constant_type, constant_data):
-        if constant_type is LpdConstants.Offset:
-            self.offset_map_gpu.set(
-                np.transpose(
-                    constant_data.astype(np.float32),
-                    (2, 1, 0, 3),
-                )
-            )
-        elif constant_type in (LpdConstants.BadPixelsDark, LpdConstants.BadPixelsFF):
+        # constant type → transpose order
+        bad_pixel_loading = {
+            LpdConstants.BadPixelsDark: (2, 1, 0, 3),
+            LpdConstants.BadPixelsFF: (2, 0, 1, 3),
+        }
+        # constant type → transpose order, GPU buffer
+        other_constant_loading = {
+            LpdConstants.Offset: ((2, 1, 0, 3), self.offset_map_gpu),
+            LpdConstants.GainAmpMap: ((2, 1, 0, 3), self.gain_amp_map_gpu),
+            LpdConstants.FFMap: ((2, 0, 1, 3), self.flatfield_map_gpu),
+            LpdConstants.RelativeGain: ((2, 1, 0, 3), self.rel_gain_slopes_map_gpu),
+        }
+        if constant_type in bad_pixel_loading:
             self.bad_pixel_map_gpu |= cupy.asarray(
                 np.transpose(
                     constant_data,
-                    (2, 1, 0, 3),
+                    bad_pixel_loading[constant_type],
                 ),
                 dtype=np.uint32,
             )
-        elif constant_type is LpdConstants.GainAmpMap:
-            self.gain_amp_map_gpu.set(
-                np.transpose(
-                    constant_data.astype(np.float32),
-                    (2, 1, 0, 3),
-                )
-            )
-        elif constant_type is LpdConstants.FFMap:
-            self.flatfield_map_gpu.set(
-                np.transpose(
-                    constant_data.astype(np.float32),
-                    (2, 1, 0, 3),
-                )
-            )
-        elif constant_type is LpdConstants.RelativeGain:
-            # TODO: figure out where to get
-            self.rel_gain_slopes_map_gpu.set(
+        elif constant_type in other_constant_loading:
+            transpose_order, gpu_buffer = other_constant_loading[constant_type]
+            gpu_buffer.set(
                 np.transpose(
                     constant_data.astype(np.float32),
-                    (2, 1, 0, 3),
+                    transpose_order,
                 )
             )
+        else:
+            raise ValueError(f"Unhandled constant type {constant_type}")
 
     def _init_kernels(self):
         kernel_source = self._kernel_template.render(
-- 
GitLab