diff --git a/src/calng/LpdCorrection.py b/src/calng/LpdCorrection.py index 3a9ca72faf4e064a085ec4478a3a0d800218071b..b906d79528c4eea3bedaa67f37ac2972bfb60117 100644 --- a/src/calng/LpdCorrection.py +++ b/src/calng/LpdCorrection.py @@ -14,15 +14,6 @@ from ._version import version as deviceVersion from .base_correction import BaseCorrection, add_correction_step_schema -class CorrectionFlags(enum.IntFlag): - NONE = 0 - OFFSET = 1 - GAIN_AMP = 2 - REL_GAIN = 4 - FF_CORR = 8 - BPMASK = 16 - - class LpdConstants(enum.Enum): Offset = enum.auto() BadPixelsDark = enum.auto() @@ -32,6 +23,15 @@ class LpdConstants(enum.Enum): BadPixelsFF = enum.auto() +class CorrectionFlags(enum.IntFlag): + NONE = 0 + OFFSET = 1 + GAIN_AMP = 2 + REL_GAIN = 4 + FF_CORR = 8 + BPMASK = 16 + + class LpdGpuRunner(base_gpu.BaseGpuRunner): _kernel_source_filename = "lpd_gpu.cu" _corrected_axis_order = "cxy" @@ -91,43 +91,36 @@ class LpdGpuRunner(base_gpu.BaseGpuRunner): ) def load_constant(self, constant_type, constant_data): - if constant_type is LpdConstants.Offset: - self.offset_map_gpu.set( - np.transpose( - constant_data.astype(np.float32), - (2, 1, 0, 3), - ) - ) - elif constant_type in (LpdConstants.BadPixelsDark, LpdConstants.BadPixelsFF): + # constant type → transpose order + bad_pixel_loading = { + LpdConstants.BadPixelsDark: (2, 1, 0, 3), + LpdConstants.BadPixelsFF: (2, 0, 1, 3), + } + # constant type → transpose order, GPU buffer + other_constant_loading = { + LpdConstants.Offset: ((2, 1, 0, 3), self.offset_map_gpu), + LpdConstants.GainAmpMap: ((2, 1, 0, 3), self.gain_amp_map_gpu), + LpdConstants.FFMap: ((2, 0, 1, 3), self.flatfield_map_gpu), + LpdConstants.RelativeGain: ((2, 1, 0, 3), self.rel_gain_slopes_map_gpu), + } + if constant_type in bad_pixel_loading: self.bad_pixel_map_gpu |= cupy.asarray( np.transpose( constant_data, - (2, 1, 0, 3), + bad_pixel_loading[constant_type], ), dtype=np.uint32, ) - elif constant_type is LpdConstants.GainAmpMap: - self.gain_amp_map_gpu.set( - np.transpose( - constant_data.astype(np.float32), - (2, 1, 0, 3), - ) - ) - elif constant_type is LpdConstants.FFMap: - self.flatfield_map_gpu.set( - np.transpose( - constant_data.astype(np.float32), - (2, 1, 0, 3), - ) - ) - elif constant_type is LpdConstants.RelativeGain: - # TODO: figure out where to get - self.rel_gain_slopes_map_gpu.set( + elif constant_type in other_constant_loading: + transpose_order, gpu_buffer = other_constant_loading[constant_type] + gpu_buffer.set( np.transpose( constant_data.astype(np.float32), - (2, 1, 0, 3), + transpose_order, ) ) + else: + raise ValueError(f"Unhandled constant type {constant_type}") def _init_kernels(self): kernel_source = self._kernel_template.render(