diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py
index 8fddfae1a28a5e1c818d594be9f5145190de8031..b666dc468705a01ebb2d11bcb8f43f7b6c4b6e94 100644
--- a/src/calng/base_kernel_runner.py
+++ b/src/calng/base_kernel_runner.py
@@ -1,6 +1,7 @@
 import enum
 import functools
 import itertools
+import operator
 import pathlib
 
 import jinja2
@@ -11,6 +12,16 @@ from . import utils
 
 class BaseKernelRunner:
     _gpu_based = True
+    _xp = None  # subclass sets numpy or cupy
+
+    def _pre_init(self):
+        # can be used, for example, to import cupy and set as _xp at runtime
+        pass
+
+    def _post_init(self):
+        # can be used to set GPU / CPU-specific buffers in subclasses
+        # without overriding __init__
+        pass
 
     def __init__(
         self,
@@ -20,6 +31,7 @@ class BaseKernelRunner:
         constant_memory_cells,
         output_data_dtype=np.float32,
     ):
+        self._pre_init()
         self.pixels_x = pixels_x
         self.pixels_y = pixels_y
         self.frames = frames
@@ -29,6 +41,7 @@ class BaseKernelRunner:
         else:
             self.constant_memory_cells = constant_memory_cells
         self.output_data_dtype = output_data_dtype
+        self._post_init()
 
     @property
     def preview_shape(self):
@@ -87,22 +100,19 @@ class BaseKernelRunner:
             raise ValueError(f"Memory cell index {preview_index} out of range")
 
         if preview_index >= 0:
-
-            def fun(a):
-                return a[preview_index]
-
+            fun = operator.itemgetter(preview_index)
         elif preview_index == -1:
             # note: separate from next case because dtype not applicable here
             fun = functools.partial(self._xp.nanmax, axis=0)
         elif preview_index in (-2, -3, -4):
             fun = functools.partial(
                 {
-                    -2: np.nanmean,
-                    -3: np.nansum,
-                    -4: np.nanstd,
+                    -2: self._xp.nanmean,
+                    -3: self._xp.nansum,
+                    -4: self._xp.nanstd,
                 }[preview_index],
                 axis=0,
-                dtype=np.float32,
+                dtype=self._xp.float32,
             )
         # TODO: reuse output buffers
         # TODO: integrate multithreading
diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py
index d745f0c8098d95afeed63b195853045e68905909..294fc447b26a0f2f39bcaa279153f2f116722ffc 100644
--- a/src/calng/corrections/AgipdCorrection.py
+++ b/src/calng/corrections/AgipdCorrection.py
@@ -91,11 +91,6 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
         hg_hard_threshold=2000,
         override_md_additional_offset=None,
     ):
-        self.gain_mode = gain_mode
-        if self.gain_mode is GainModes.ADAPTIVE_GAIN:
-            self.default_gain = self._xp.uint8(gain_mode)
-        else:
-            self.default_gain = self._xp.uint8(gain_mode - 1)
         super().__init__(
             pixels_x,
             pixels_y,
@@ -103,6 +98,11 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
             constant_memory_cells,
             output_data_dtype,
         )
+        self.gain_mode = gain_mode
+        if self.gain_mode is GainModes.ADAPTIVE_GAIN:
+            self.default_gain = self._xp.uint8(gain_mode)
+        else:
+            self.default_gain = self._xp.uint8(gain_mode - 1)
         self.gain_map = self._xp.empty(self.processed_shape, dtype=np.float32)
 
         # constants
@@ -291,31 +291,7 @@ class AgipdCpuRunner(AgipdBaseRunner):
             self.processed_data,
         )
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-        gain_mode=GainModes.ADAPTIVE_GAIN,
-        g_gain_value=1,
-        mg_hard_threshold=2000,
-        hg_hard_threshold=2000,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-            bad_pixel_mask_value,
-            gain_mode,
-            g_gain_value,
-            mg_hard_threshold,
-            hg_hard_threshold,
-        )
+    def _post_init(self):
         self.input_data = None
         self.cell_table = None
         # NOTE: CPU kernel does not yet support anything other than float32
@@ -333,41 +309,19 @@ class AgipdCpuRunner(AgipdBaseRunner):
 class AgipdGpuRunner(AgipdBaseRunner):
     _gpu_based = True
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-        gain_mode=GainModes.ADAPTIVE_GAIN,
-        g_gain_value=1,
-        mg_hard_threshold=2000,
-        hg_hard_threshold=2000,
-    ):
+    def _pre_init(self):
         import cupy as cp
 
         self._xp = cp
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-            bad_pixel_mask_value,
-            gain_mode,
-            g_gain_value,
-            mg_hard_threshold,
-            hg_hard_threshold,
-        )
-        self.input_data = cp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = cp.empty(frames, dtype=np.uint16)
+
+    def _post_init(self):
+        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
+        self.cell_table = self._xp.empty(self.frames, dtype=np.uint16)
         self.block_shape = (1, 1, 64)
         self.grid_shape = base_kernel_runner.grid_to_cover_shape_with_blocks(
             self.processed_shape, self.block_shape
         )
-        self.correction_kernel = cp.RawModule(
+        self.correction_kernel = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("agipd_gpu.cu").render(
                 {
                     "pixels_x": self.pixels_x,
diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py
index 911329607a81d8bd55bb96d4282f711e05a9b686..7d22253b1cf1d8e0f81f2ed0ef2ab6628190b1b5 100644
--- a/src/calng/corrections/DsscCorrection.py
+++ b/src/calng/corrections/DsscCorrection.py
@@ -82,21 +82,7 @@ class DsscCpuRunner(DsscBaseRunner):
     _gpu_based = False
     _xp = np
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
+    def _post_init(self):
         self.input_data = None
         self.cell_table = None
         from ..kernels import dssc_cython
@@ -120,31 +106,19 @@ class DsscCpuRunner(DsscBaseRunner):
 class DsscGpuRunner(DsscBaseRunner):
     _gpu_based = True
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-    ):
+    def _pre_init(self):
         import cupy as cp
 
         self._xp = cp
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.input_data = cp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = cp.empty(self.frames, dtype=np.uint16)
+
+    def _post_init(self):
+        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
+        self.cell_table = self._xp.empty(self.frames, dtype=np.uint16)
         self.block_shape = (1, 1, 64)
         self.grid_shape = base_kernel_runner.grid_to_cover_shape_with_blocks(
             self.input_shape, self.block_shape
         )
-        self.correction_kernel = cp.RawModule(
+        self.correction_kernel = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("dssc_gpu.cu").render(
                 {
                     "pixels_x": self.pixels_x,
diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py
index 68004d3b349c7bbe6f6d921c00a98d28a43ef7b5..79314a48eba76234f9651a897bd98847921b4205 100644
--- a/src/calng/corrections/JungfrauCorrection.py
+++ b/src/calng/corrections/JungfrauCorrection.py
@@ -161,33 +161,20 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
 class JungfrauGpuRunner(JungfrauBaseRunner):
     _gpu_based = True
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-    ):
+    def _pre_init(self):
         import cupy as cp
 
         self._xp = cp
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.input_data = cp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = cp.empty(self.frames, dtype=np.uint8)
+
+    def _post_init(self):
+        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
+        self.cell_table = self._xp.empty(self.frames, dtype=np.uint8)
         self.block_shape = (1, 1, 64)
         self.grid_shape = base_kernel_runner.grid_to_cover_shape_with_blocks(
             self.input_shape, self.block_shape
         )
 
-        source_module = cp.RawModule(
+        source_module = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("jungfrau_gpu.cu").render(
                 {
                     "pixels_x": self.pixels_x,
@@ -245,23 +232,7 @@ class JungfrauCpuRunner(JungfrauBaseRunner):
     _gpu_based = False
     _xp = np
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,  # TODO: configurable
-        bad_pixel_mask_value=np.nan,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-
+    def _post_init(self):
         # not actually allocating, will just point to incoming data
         self.input_data = None
         self.input_gain_stage = None