diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py index 8fddfae1a28a5e1c818d594be9f5145190de8031..b666dc468705a01ebb2d11bcb8f43f7b6c4b6e94 100644 --- a/src/calng/base_kernel_runner.py +++ b/src/calng/base_kernel_runner.py @@ -1,6 +1,7 @@ import enum import functools import itertools +import operator import pathlib import jinja2 @@ -11,6 +12,16 @@ from . import utils class BaseKernelRunner: _gpu_based = True + _xp = None # subclass sets numpy or cupy + + def _pre_init(self): + # can be used, for example, to import cupy and set as _xp at runtime + pass + + def _post_init(self): + # can be used to set GPU / CPU-specific buffers in subclasses + # without overriding __init__ + pass def __init__( self, @@ -20,6 +31,7 @@ class BaseKernelRunner: constant_memory_cells, output_data_dtype=np.float32, ): + self._pre_init() self.pixels_x = pixels_x self.pixels_y = pixels_y self.frames = frames @@ -29,6 +41,7 @@ class BaseKernelRunner: else: self.constant_memory_cells = constant_memory_cells self.output_data_dtype = output_data_dtype + self._post_init() @property def preview_shape(self): @@ -87,22 +100,19 @@ class BaseKernelRunner: raise ValueError(f"Memory cell index {preview_index} out of range") if preview_index >= 0: - - def fun(a): - return a[preview_index] - + fun = operator.itemgetter(preview_index) elif preview_index == -1: # note: separate from next case because dtype not applicable here fun = functools.partial(self._xp.nanmax, axis=0) elif preview_index in (-2, -3, -4): fun = functools.partial( { - -2: np.nanmean, - -3: np.nansum, - -4: np.nanstd, + -2: self._xp.nanmean, + -3: self._xp.nansum, + -4: self._xp.nanstd, }[preview_index], axis=0, - dtype=np.float32, + dtype=self._xp.float32, ) # TODO: reuse output buffers # TODO: integrate multithreading diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py index d745f0c8098d95afeed63b195853045e68905909..294fc447b26a0f2f39bcaa279153f2f116722ffc 100644 --- a/src/calng/corrections/AgipdCorrection.py +++ b/src/calng/corrections/AgipdCorrection.py @@ -91,11 +91,6 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner): hg_hard_threshold=2000, override_md_additional_offset=None, ): - self.gain_mode = gain_mode - if self.gain_mode is GainModes.ADAPTIVE_GAIN: - self.default_gain = self._xp.uint8(gain_mode) - else: - self.default_gain = self._xp.uint8(gain_mode - 1) super().__init__( pixels_x, pixels_y, @@ -103,6 +98,11 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner): constant_memory_cells, output_data_dtype, ) + self.gain_mode = gain_mode + if self.gain_mode is GainModes.ADAPTIVE_GAIN: + self.default_gain = self._xp.uint8(gain_mode) + else: + self.default_gain = self._xp.uint8(gain_mode - 1) self.gain_map = self._xp.empty(self.processed_shape, dtype=np.float32) # constants @@ -291,31 +291,7 @@ class AgipdCpuRunner(AgipdBaseRunner): self.processed_data, ) - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - bad_pixel_mask_value=np.nan, - gain_mode=GainModes.ADAPTIVE_GAIN, - g_gain_value=1, - mg_hard_threshold=2000, - hg_hard_threshold=2000, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - bad_pixel_mask_value, - gain_mode, - g_gain_value, - mg_hard_threshold, - hg_hard_threshold, - ) + def _post_init(self): self.input_data = None self.cell_table = None # NOTE: CPU kernel does not yet support anything other than float32 @@ -333,41 +309,19 @@ class AgipdCpuRunner(AgipdBaseRunner): class AgipdGpuRunner(AgipdBaseRunner): _gpu_based = True - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - bad_pixel_mask_value=np.nan, - gain_mode=GainModes.ADAPTIVE_GAIN, - g_gain_value=1, - mg_hard_threshold=2000, - hg_hard_threshold=2000, - ): + def _pre_init(self): import cupy as cp self._xp = cp - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - bad_pixel_mask_value, - gain_mode, - g_gain_value, - mg_hard_threshold, - hg_hard_threshold, - ) - self.input_data = cp.empty(self.input_shape, dtype=np.uint16) - self.cell_table = cp.empty(frames, dtype=np.uint16) + + def _post_init(self): + self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16) + self.cell_table = self._xp.empty(self.frames, dtype=np.uint16) self.block_shape = (1, 1, 64) self.grid_shape = base_kernel_runner.grid_to_cover_shape_with_blocks( self.processed_shape, self.block_shape ) - self.correction_kernel = cp.RawModule( + self.correction_kernel = self._xp.RawModule( code=base_kernel_runner.get_kernel_template("agipd_gpu.cu").render( { "pixels_x": self.pixels_x, diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py index 911329607a81d8bd55bb96d4282f711e05a9b686..7d22253b1cf1d8e0f81f2ed0ef2ab6628190b1b5 100644 --- a/src/calng/corrections/DsscCorrection.py +++ b/src/calng/corrections/DsscCorrection.py @@ -82,21 +82,7 @@ class DsscCpuRunner(DsscBaseRunner): _gpu_based = False _xp = np - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) + def _post_init(self): self.input_data = None self.cell_table = None from ..kernels import dssc_cython @@ -120,31 +106,19 @@ class DsscCpuRunner(DsscBaseRunner): class DsscGpuRunner(DsscBaseRunner): _gpu_based = True - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - ): + def _pre_init(self): import cupy as cp self._xp = cp - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - self.input_data = cp.empty(self.input_shape, dtype=np.uint16) - self.cell_table = cp.empty(self.frames, dtype=np.uint16) + + def _post_init(self): + self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16) + self.cell_table = self._xp.empty(self.frames, dtype=np.uint16) self.block_shape = (1, 1, 64) self.grid_shape = base_kernel_runner.grid_to_cover_shape_with_blocks( self.input_shape, self.block_shape ) - self.correction_kernel = cp.RawModule( + self.correction_kernel = self._xp.RawModule( code=base_kernel_runner.get_kernel_template("dssc_gpu.cu").render( { "pixels_x": self.pixels_x, diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py index 68004d3b349c7bbe6f6d921c00a98d28a43ef7b5..79314a48eba76234f9651a897bd98847921b4205 100644 --- a/src/calng/corrections/JungfrauCorrection.py +++ b/src/calng/corrections/JungfrauCorrection.py @@ -161,33 +161,20 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): class JungfrauGpuRunner(JungfrauBaseRunner): _gpu_based = True - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - bad_pixel_mask_value=np.nan, - ): + def _pre_init(self): import cupy as cp self._xp = cp - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - self.input_data = cp.empty(self.input_shape, dtype=np.uint16) - self.cell_table = cp.empty(self.frames, dtype=np.uint8) + + def _post_init(self): + self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16) + self.cell_table = self._xp.empty(self.frames, dtype=np.uint8) self.block_shape = (1, 1, 64) self.grid_shape = base_kernel_runner.grid_to_cover_shape_with_blocks( self.input_shape, self.block_shape ) - source_module = cp.RawModule( + source_module = self._xp.RawModule( code=base_kernel_runner.get_kernel_template("jungfrau_gpu.cu").render( { "pixels_x": self.pixels_x, @@ -245,23 +232,7 @@ class JungfrauCpuRunner(JungfrauBaseRunner): _gpu_based = False _xp = np - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, # TODO: configurable - bad_pixel_mask_value=np.nan, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - + def _post_init(self): # not actually allocating, will just point to incoming data self.input_data = None self.input_gain_stage = None