diff --git a/DEPENDS b/DEPENDS
index 6dbf7d2159b1530f96346ea641ac0f5116aacd47..53d4298919b3f4496601016e71e66b28d9ae9d4f 100644
--- a/DEPENDS
+++ b/DEPENDS
@@ -1,4 +1,4 @@
 TrainMatcher, 2.4.5
 calibrationClient, 11.3.0
 calibration/geometryDevices, 0.0.2
-calibration/calngUtils, 0.0.3
+calibration/calngUtils, 0.0.4
diff --git a/docs/devices.md b/docs/devices.md
index bbdaa9b88e13fd1b68c96996869217ce82327f70..242caf3a832a3a10f6ddef6b47c4563ed2c5d348 100644
--- a/docs/devices.md
+++ b/docs/devices.md
@@ -50,15 +50,13 @@ Therefore, settings related to how this is done should typically be managed via
 The parameters seen in the "Preview" box on the manager overview scene control:
 
 - How to select which frame to preview (for non-negative indices)
-    - See [Index selection mode](schemas/BaseCorrection.md#preview.selectionMode)
 	- Frame can be extracted directly from image ndarray (`frame` mode), by looking up corresponding frame in cell table (`cell` mode), or by looking up in pulse table (`pulse` mode)
     - For XTDF detectors, the mapping tables are typically `image.cellId` and `image.pulseId`.
 	  Which selection mode makes sense depends on detector, veto pattern, and experimental setup.
     - If the specified cell or pulse is not found in the respective table, a warning is issued and the first frame is sent instead.
 - Which frame or statistic to send
-    - See [Index for preview](schemas/BaseCorrection.md#preview.index)
     - If the index is non-negative, a single frame is sliced from the image data.
-	  How this frame is found depends on the the [index selection mode](schemas/BaseCorrection.md#preview.selectionMode).
+	  How this frame is found depends on the index selection mode (previous point).
     - If the index is negative, a statistic is computed across all frames for a summary preview.
       Note that this is done individually per pixel and that `NaN` values are ignored.
       Options are:
diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py
index 8de76153d77e6b7ebf8386bbba8867fb7e81a532..4ae0ed05eee7663944351f709690c56b975a2458 100644
--- a/src/calng/CalibrationManager.py
+++ b/src/calng/CalibrationManager.py
@@ -18,7 +18,7 @@ import re
 from tornado.httpclient import AsyncHTTPClient, HTTPError
 from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
 
-from calngUtils import scene_utils
+from calngUtils.scene_utils import recursive_subschema_scene
 from karabo.middlelayer import (
     KaraboError, Device, DeviceClientBase, Descriptor, Hash, Configurable,
     Slot, Node, Type, Schema, ProxyFactory,
@@ -158,8 +158,8 @@ class PreviewLayerRow(Configurable):
         defaultValue='',
         accessMode=AccessMode.RECONFIGURABLE)
 
-    outputPipeline = String(
-        displayedName='Output pipeline',
+    previewName = String(
+        displayedName='Preview name',
         defaultValue='',
         accessMode=AccessMode.RECONFIGURABLE)
 
@@ -307,7 +307,7 @@ class CalibrationManager(DeviceClientBase, Device):
                 prefix = name[len('browse_schema:'):]
             else:
                 prefix = 'managedKeys'
-            scene_data = scene_utils.recursive_subschema_scene(
+            scene_data = recursive_subschema_scene(
                 self.deviceId,
                 self.getDeviceSchema(),
                 prefix,
@@ -382,12 +382,6 @@ class CalibrationManager(DeviceClientBase, Device):
         accessMode=AccessMode.RECONFIGURABLE,
         assignment=Assignment.MANDATORY)
 
-    previewLayers = VectorHash(
-        displayedName='Preview layers',
-        rows=PreviewLayerRow,
-        accessMode=AccessMode.RECONFIGURABLE,
-        assignment=Assignment.MANDATORY)
-
     @VectorHash(
         displayedName='Device servers',
         description='WARNING: It is strongly recommended to perform a full '
@@ -402,6 +396,21 @@ class CalibrationManager(DeviceClientBase, Device):
         # Switch to UNKNOWN state to suggest operator to restart pipeline.
         self.state = State.UNKNOWN
 
+    previewServer = String(
+        displayedName='Preview device server',
+        description='The server (from the "Device servers" list) to use for '
+                    'preview assemblers',
+        accessMode=AccessMode.RECONFIGURABLE,
+        assignment=Assignment.MANDATORY)
+
+    previewLayers = VectorString(
+        displayedName='Preview layers',
+        description='List of previews (like raw, corrected, and anything else '
+                    'detector-specific) to show. Automatically populated from '
+                    'the correction device schema; listed here for debugging.',
+        defaultValue=[],
+        accessMode=AccessMode.READONLY)
+
     geometryDevice = String(
         displayedName='Geometry device',
         description='Device ID for a geometry device defining the detector '
@@ -598,6 +607,10 @@ class CalibrationManager(DeviceClientBase, Device):
         # Inject schema for configuration of managed devices.
         await self._inject_managed_keys()
 
+        # Populate preview layers from correction device schema
+        self.previewLayers = list(self._correction_device_schema
+                                  .hash["preview"].getKeys())
+
         # Populate the device ID sets with what's out there right now.
         await self._check_topology()
 
@@ -1262,12 +1275,9 @@ class CalibrationManager(DeviceClientBase, Device):
         # Servers by group and layer.
         server_by_group = {group: server for group, server, _, _, _,
                            in self.moduleGroups.value}
-        server_by_layer = {layer: server for layer, _, server
-                           in self.previewLayers.value}
-
-        all_req_servers = set(server_by_group.values()).union(
-            server_by_layer.values())
+        preview_server = self.previewServer.value
 
+        all_req_servers = set(server_by_group.values()) | {preview_server}
         if all_req_servers != up_servers:
             return self._set_error('One or more device servers are not '
                                    'listed in the device servers '
@@ -1352,9 +1362,9 @@ class CalibrationManager(DeviceClientBase, Device):
             ))
 
         # Instantiate preview layer assemblers.
-        for layer, output_pipeline, server in self.previewLayers.value:
+        for preview_name in self.previewLayers.value:
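+            # each correction device exposes one output channel per preview:
+            # preview.<name>.output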
             sources = [Hash('select', True,
-                            'source', f'{device_id}:{output_pipeline}')
+                            'source', f'{device_id}:preview.{preview_name}.output')
                        for (_, device_id)
                        in correct_device_id_by_module.items()]
             config = Hash('sources', sources,
@@ -1367,9 +1377,10 @@ class CalibrationManager(DeviceClientBase, Device):
                     config[remote_key] = value.value
 
             awaitables.append(self._instantiate_device(
-                server,
+                preview_server,
                 self._class_ids['assembler'],
-                self._device_id_templates['assembler'].format(layer=layer),
+                self._device_id_templates['assembler'].format(
+                    layer=preview_name.upper()),
                 config
             ))
 
diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py
index 1f3fb09d0b6c795f4c14903f5e5ae227c56c49a3..b4f6c06d94be9aa7b1ceea21673951818068ed77 100644
--- a/src/calng/DetectorAssembler.py
+++ b/src/calng/DetectorAssembler.py
@@ -27,7 +27,8 @@ from karabo.bound import (
 )
 from TrainMatcher import TrainMatcher
 
-from . import preview_utils, scenes, schemas
+from . import scenes, schemas
+from .preview_utils import PreviewFriend, PreviewSpec
 from ._version import version as deviceVersion
 
 
@@ -73,8 +74,12 @@ def module_index_schema():
 @device_utils.with_unsafe_get
 @KARABO_CLASSINFO("DetectorAssembler", deviceVersion)
 class DetectorAssembler(TrainMatcher.TrainMatcher):
-    @staticmethod
-    def expectedParameters(expected):
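+    # a single assembled-image preview; flipping both axes here replaces the old
+    # preview.flipSS / preview.flipFS default overrides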
+    _preview_outputs = [
+        PreviewSpec("assembled", frame_reduction=False, flip_ss=True, flip_fs=True)
+    ]
+
+    @classmethod
+    def expectedParameters(cls, expected):
         (
             OVERWRITE_ELEMENT(expected)
             .key("availableScenes")
@@ -178,18 +183,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             .reconfigurable()
             .commit(),
         )
-        preview_utils.PreviewFriend.add_schema(expected, "preview", create_node=True)
-        (
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.flipSS")
-            .setNewDefaultValue(True)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.flipFS")
-            .setNewDefaultValue(True)
-            .commit(),
-        )
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
 
     def __init__(self, conf):
         super().__init__(conf)
@@ -204,7 +198,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
     def initialization(self):
         super().initialization()
 
-        self._preview_friend = preview_utils.PreviewFriend(self)
+        self._preview_friend = PreviewFriend(self, self._preview_outputs, "preview")
         self._image_data_path = self.get("imageDataPath")
         self._image_mask_path = self.get("imageMaskPath")
         self._geometry = None
@@ -469,8 +463,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             self.zmq_output.update()
 
         self._preview_friend.write_outputs(
-            my_timestamp,
-            np.ma.masked_array(data=assembled_data, mask=assembled_mask),
+            np.ma.masked_array(data=assembled_data, mask=assembled_mask)
         )
 
         self._processing_time_tracker.update(default_timer() - ts_start)
@@ -495,7 +488,8 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
         ):
             self._need_to_update_source_index_mapping = True
 
-        self._preview_friend.reconfigure(conf)
+        if conf.has("preview"):
+            self._preview_friend.reconfigure(conf["preview"])
 
     def postReconfigure(self):
         if self._need_to_update_source_index_mapping:
diff --git a/src/calng/LpdminiSplitter.py b/src/calng/LpdminiSplitter.py
index 5c32d679e305c8586d40e93457721789602586c5..d2fee86678176a520cab2fe521940942734cc710 100644
--- a/src/calng/LpdminiSplitter.py
+++ b/src/calng/LpdminiSplitter.py
@@ -105,7 +105,7 @@ class LpdminiSplitter(PythonDevice):
             (
                 OUTPUT_CHANNEL(expected)
                 .key(channel_name)
-                .dataSchema(schemas.xtdf_output_schema())
+                .dataSchema(schemas.xtdf_output_schema(use_shmem_handles=True))
                 .commit(),
             )
 
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 414b334211ecda2b7dfaa08b83c8a1477a0d7f35..a11f9b969a696211cc95742115b93dc76139eeef 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -1,7 +1,7 @@
 import concurrent.futures
-import enum
+import contextlib
 import functools
-import gc
+import itertools
 import math
 import pathlib
 import threading
@@ -13,17 +13,16 @@ import numpy as np
 from geometryDevices import utils as geom_utils
 from calngUtils import (
     device as device_utils,
-    misc,
     scene_utils,
     shmem_utils,
     timing,
     trackers,
 )
+from calngUtils.misc import ChainHash
 from karabo.bound import (
     BOOL_ELEMENT,
     DOUBLE_ELEMENT,
     INPUT_CHANNEL,
-    INT32_ELEMENT,
     KARABO_CLASSINFO,
     NODE_ELEMENT,
     OUTPUT_CHANNEL,
@@ -45,96 +44,33 @@ from karabo.bound import (
 )
 from karabo.common.api import KARABO_SCHEMA_DISPLAY_TYPE_SCENES as DT_SCENES
 
-from . import preview_utils, schemas, scenes, utils
+from . import scenes
+from .preview_utils import PreviewFriend
+from .utils import StateContext, WarningContextSystem, WarningLampType, subset_of_hash
 from ._version import version as deviceVersion
 
 PROCESSING_STATE_TIMEOUT = 10
 
 
-class FramefilterSpecType(enum.Enum):
-    NONE = "none"
-    RANGE = "range"
-    COMMASEPARATED = "commaseparated"
-
-
-class WarningLampType(enum.Enum):
-    FRAME_FILTER = enum.auto()
-    MEMORY_CELL_RANGE = enum.auto()
-    CONSTANT_OPERATING_PARAMETERS = enum.auto()
-    PREVIEW_SETTINGS = enum.auto()
-    CORRECTION_RUNNER = enum.auto()
-    OUTPUT_BUFFER = enum.auto()
-    GPU_MEMORY = enum.auto()
-    CALCAT_CONNECTION = enum.auto()
-    EMPTY_HASH = enum.auto()
-    MISC_INPUT_DATA = enum.auto()
-    TRAIN_ID = enum.auto()
-    TIMESERVER_CONNECTION = enum.auto()
+class TrainFromTheFutureException(BaseException):
+    pass
 
 
+@device_utils.with_config_overlay
 @device_utils.with_unsafe_get
 @KARABO_CLASSINFO("BaseCorrection", deviceVersion)
 class BaseCorrection(PythonDevice):
     _available_addons = []  # classes, filled by add_addon_nodes using entry_points
     _base_output_schema = None  # subclass must set constructor
     _constant_enum_class = None  # subclass must set
-    _correction_flag_class = None  # subclass must set (ex.: dssc_gpu.CorrectionFlags)
     _correction_steps = None  # subclass must set
     _kernel_runner_class = None  # subclass must set (ex.: dssc_gpu.DsscGpuRunner)
-    _kernel_runner_init_args = {}  # optional extra args for runner
     _image_data_path = "image.data"  # customize for *some* subclasses
     _cell_table_path = "image.cellId"
     _pulse_table_path = "image.pulseId"
     _warn_memory_cell_range = True  # can be disabled for some detectors
     _cuda_pin_buffers = False
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        """Subclass must define how to process constants into correction maps and store
-        into appropriate buffers in (GPU or main) memory."""
-        # note: aim to refactor this away as kernel runner handles loading logic
-        raise NotImplementedError()
-
-    def _successfully_loaded_constant_to_runner(self, constant):
-        for field_name in self._constant_to_correction_names[constant]:
-            key = f"corrections.{field_name}.available"
-            if not self.unsafe_get(key):
-                self.set(key, True)
-
-        self._update_correction_flags()
-        self.log_status_info(f"Done loading {constant.name} to runner")
-
-    @property
-    def input_data_shape(self):
-        """Subclass must define expected input data shape in terms of dataFormat.{
-        memoryCells,pixelsX,pixelsY} and any other axes."""
-        raise NotImplementedError()
-
-    @property
-    def output_data_shape(self):
-        """Shape of corrected image data sent on dataOutput. Depends on data format
-        parameters (optionally including frame filter)."""
-        axis_lengths = {
-            "x": self.unsafe_get("dataFormat.pixelsX"),
-            "y": self.unsafe_get("dataFormat.pixelsY"),
-            "f": self.unsafe_get("dataFormat.filteredFrames"),
-        }
-        return tuple(
-            axis_lengths[axis] for axis in self.unsafe_get("dataFormat.outputAxisOrder")
-        )
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-    ):
-        """Subclass must define data processing (presumably using the kernel runner).
-        Will be called by input_handler, which will take care of some common checks and
-        extracting the parameters given to process_data."""
-        raise NotImplementedError()
+    _use_shmem_handles = True
 
     @staticmethod
     def expectedParameters(expected):
@@ -171,57 +107,6 @@ class BaseCorrection(PythonDevice):
             .defaultValue([])
             .commit(),
 
-            NODE_ELEMENT(expected)
-            .key("frameFilter")
-            .displayedName("Frame filter")
-            .description(
-                "The frame filter - if set - slices the input data. Frames not in the "
-                "filter will be discarded before any processing happens and will not "
-                "get to dataOutput or preview. Note that this filter goes by frame "
-                "index rather than cell ID or pulse ID; set accordingly. Handle with "
-                "care - an invalid filter can prevent all processing. How the filter "
-                "is specified depends on frameFilter.type. See frameFilter.current to "
-                "inspect the currently set frame filter array (if any)."
-            )
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("frameFilter.type")
-            .tags("managed")
-            .displayedName("Type")
-            .description(
-                "Controls how frameFilter.spec is used. The default value of 'none' "
-                "means that no filter is set (regardless of frameFilter.spec). "
-                "'arange' allows between one and three integers separated by ',' which "
-                "are parsed and passed directly to numpy.arange. 'commaseparated' "
-                "reads a list of integers separated by commas."
-            )
-            .options(",".join(spectype.value for spectype in FramefilterSpecType))
-            .assignmentOptional()
-            .defaultValue("none")
-            .reconfigurable()
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("frameFilter.spec")
-            .tags("managed")
-            .displayedName("Specification")
-            .assignmentOptional()
-            .defaultValue("")
-            .reconfigurable()
-            .commit(),
-
-            VECTOR_UINT32_ELEMENT(expected)
-            .key("frameFilter.current")
-            .displayedName("Current filter")
-            .description(
-                "This read-only value is used to display the contents of the current "
-                "frame filter. An empty array means no filtering is done."
-            )
-            .readOnly()
-            .initialValue([])
-            .commit(),
-
             UINT32_ELEMENT(expected)
             .key("outputShmemBufferSize")
             .tags("managed")
@@ -285,14 +170,6 @@ class BaseCorrection(PythonDevice):
             .displayedName("Data format (in/out)")
             .commit(),
 
-            STRING_ELEMENT(expected)
-            .key("dataFormat.inputImageDtype")
-            .displayedName("Input image data dtype")
-            .description("The (numpy) dtype to expect for incoming image data.")
-            .readOnly()
-            .initialValue("uint16")
-            .commit(),
-
             STRING_ELEMENT(expected)
             .key("dataFormat.outputImageDtype")
             .tags("managed")
@@ -304,41 +181,16 @@ class BaseCorrection(PythonDevice):
                 "causes truncation rather than rounding."
             )
             # TODO: consider adding rounding / binning for integer output
-            .options("float16,float32,uint16")
+            .options("float32,float16,uint16")
             .assignmentOptional()
             .defaultValue("float32")
             .reconfigurable()
             .commit(),
 
-            # important: determines shape of data as going into correction
             UINT32_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .displayedName("Pixels x")
-            .description("Number of pixels of image data along X axis")
-            .assignmentOptional()
-            .defaultValue(512)
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .displayedName("Pixels y")
-            .description("Number of pixels of image data along Y axis")
-            .assignmentOptional()
-            .defaultValue(128)
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("dataFormat.frames")
+            .key("dataFormat.inputFrames")
             .displayedName("Frames")
-            .description("Number of image frames per train in incoming data")
-            .assignmentOptional()
-            .defaultValue(1)  # subclass will want to set a default value
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("dataFormat.filteredFrames")
-            .displayedName("Frames after filter")
-            .description("Number of frames left after applying frame filter")
+            .description("Number of frames on input")
             .readOnly()
             .initialValue(0)
             .commit(),
@@ -349,13 +201,15 @@ class BaseCorrection(PythonDevice):
             .displayedName("Output axis order")
             .description(
                 "Axes of main data output can be reordered after correction. Axis "
-                "order is specified as string consisting of 'x', 'y', and 'f', with "
-                "the latter indicating the image frame. The default value of 'fxy' "
-                "puts pixels on the fast axes."
+                "order is specified as string consisting of 'f' (frames), 'ss' (slow "
+                "scan), and 'ff' (fast scan). Anything but the default f-ss-fs order "
+                "implies reordering - this axis ordering is based on the data we get "
+                "from the detector (/ receiver), so how ss and fs maps to what may be "
+                "considered x and y is detector-dependent."
             )
-            .options("fxy,fyx,xfy,xyf,yfx,yxf")
+            .options("f-ss-fs,f-fs-ss,ss-f-fs,ss-fs-f,fs-f-ss,fs-ss-f")
             .assignmentOptional()
-            .defaultValue("fxy")
+            .defaultValue("f-ss-fs")
             .reconfigurable()
             .commit(),
 
@@ -363,9 +217,9 @@ class BaseCorrection(PythonDevice):
             .key("dataFormat.inputDataShape")
             .displayedName("Input data shape")
             .description(
-                "Image data shape in incoming data (from reader / DAQ). This value is "
-                "computed from pixelsX, pixelsY, and frames - this field just "
-                "shows what is currently expected."
+                "Image data shape in incoming data (from reader / DAQ). Updated based "
+                "on latest train processed. Note that axis order may look wonky due to "
+                "DAQ quirk."
             )
             .readOnly()
             .initialValue([])
@@ -375,9 +229,10 @@ class BaseCorrection(PythonDevice):
             .key("dataFormat.outputDataShape")
             .displayedName("Output data shape")
             .description(
-                "Image data shape for data output from this device. This value is "
-                "computed from pixelsX, pixelsY, and the size of the frame filter - "
-                "this field just shows what is currently expected."
+                "Image data shape for data output from this device. Takes into account "
+                "axis reordering, if applicable. Even without that, "
+                "workarounds.overrideInputAxisOrder likely makes this differ from "
+                "dataFormat.inputDataShape."
             )
             .readOnly()
             .initialValue([])
@@ -472,50 +327,6 @@ class BaseCorrection(PythonDevice):
             .commit(),
         )
 
-        (
-            NODE_ELEMENT(expected)
-            .key("preview")
-            .displayedName("Preview")
-            .commit(),
-
-            INT32_ELEMENT(expected)
-            .key("preview.index")
-            .tags("managed")
-            .displayedName("Index (or stat) for preview")
-            .description(
-                "If this value is ≥ 0, the corresponding index (frame, cell, or pulse) "
-                "will be sliced for the preview output. If this value is < 0, preview "
-                "will be one of the following stats: -1: max, -2: mean, -3: sum, -4: "
-                "stdev. These stats are computed across frames, ignoring NaN values."
-            )
-            .assignmentOptional()
-            .defaultValue(0)
-            .minInc(-4)
-            .reconfigurable()
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .tags("managed")
-            .displayedName("Index selection mode")
-            .description(
-                "The value of preview.index can be used in multiple ways, controlled "
-                "by this value. If this is set to 'frame', preview.index is sliced "
-                "directly from data. If 'cell' (or 'pulse') is selected, I will look "
-                "at cell (or pulse) table for the requested cell (or pulse ID). "
-                "Special (stat) index values <0 are not affected by this."
-            )
-            .options(
-                ",".join(
-                    spectype.value for spectype in utils.PreviewIndexSelectionMode
-                )
-            )
-            .assignmentOptional()
-            .defaultValue("frame")
-            .reconfigurable()
-            .commit(),
-        )
-
         # just measurements and counters to display
         (
             UINT64_ELEMENT(expected)
@@ -618,24 +429,8 @@ class BaseCorrection(PythonDevice):
             config["dataOutput.hostname"] = ib_ip
         super().__init__(config)
 
-        self.output_data_dtype = np.dtype(config["dataFormat.outputImageDtype"])
-
         self.sources = set(config.get("fastSources"))
 
-        self.kernel_runner = None  # must call _update_buffers to initialize
-        self._shmem_buffer = None  # ditto
-        self._use_shmem_handles = config.get("useShmemHandles")
-        self._preview_friend = None
-
-        self._correction_flag_enabled = self._correction_flag_class.NONE
-        self._correction_flag_preview = self._correction_flag_class.NONE
-        self._correction_applied_hash = Hash()
-        # note: does not handle one constant enabling multiple correction steps
-        # (do we need to generalize "decide if this correction step is available"?)
-        self._constant_to_correction_names = {}
-        for (name, _, constants) in self._correction_steps:
-            for constant in constants:
-                self._constant_to_correction_names.setdefault(constant, set()).add(name)
         self._buffer_lock = threading.Lock()
         self._last_processing_started = 0  # used for processing time and timeout
         self._last_train_id_processed = 0  # used to keep track (and as fallback)
@@ -657,43 +452,36 @@ class BaseCorrection(PythonDevice):
                         constant_data = getattr(self.calcat_friend, friend_fun)(
                             constant
                         )
-                        self._load_constant_to_runner(constant, constant_data)
+                        self.kernel_runner.load_constant(constant, constant_data)
                     except Exception as e:
                         warn(f"Failed to load {constant}: {e}")
-                    else:
-                        self._successfully_loaded_constant_to_runner(constant)
                 # note: consider if multi-constant buffers need special care
             self._lock_and_update(aux)
 
-        for constant in self._constant_enum_class:
-            self.KARABO_SLOT(
-                functools.partial(
-                    constant_override_fun,
-                    friend_fun="get_constant_version",
-                    constant=constant,
-                    preserve_fields=set(),
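+        # register three override slots per constant: load the most recent version,
+        # load a specific version by ID, or load from a file (all via the calcat friend)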
+        for constant, (friend_fun, preserved_fields, slot_name) in itertools.product(
+            self._constant_enum_class,
+            (
+                ("get_constant_version", set(), "loadMostRecent"),
+                (
+                    "get_constant_from_constant_version_id",
+                    {"constantVersionId"},
+                    "overrideConstantFromVersion",
                 ),
-                slotName=f"foundConstants_{constant.name}_loadMostRecent",
-                numArgs=0,
-            )
-            self.KARABO_SLOT(
-                functools.partial(
-                    constant_override_fun,
-                    friend_fun="get_constant_from_constant_version_id",
-                    constant=constant,
-                    preserve_fields={"constantVersionId"},
+                (
+                    "get_constant_from_file",
+                    {"dataFilePath", "dataSetName"},
+                    "overrideConstantFromFile",
                 ),
-                slotName=f"foundConstants_{constant.name}_overrideConstantFromVersion",
-                numArgs=0,
-            )
+            ),
+        ):
             self.KARABO_SLOT(
                 functools.partial(
                     constant_override_fun,
-                    friend_fun="get_constant_from_file",
-                    constant=constant,
-                    preserve_fields={"dataFilePath", "dataSetName"},
+                    friend_fun,
+                    constant,
+                    preserved_fields,
                 ),
-                slotName=f"foundConstants_{constant.name}_overrideConstantFromFile",
+                slotName=f"foundConstants_{constant.name}_{slot_name}",
                 numArgs=0,
             )
 
@@ -704,30 +492,27 @@ class BaseCorrection(PythonDevice):
 
     def _initialization(self):
         self.updateState(State.INIT)
-        self.warning_context = utils.WarningContextSystem(
+        self.warning_context = WarningContextSystem(
             self,
             on_success={
                 f"foundConstants.{constant.name}.state": "ON"
                 for constant in self._constant_enum_class
             },
         )
-        self._preview_friend = preview_utils.PreviewFriend(
-            self,
-            output_channels=self._preview_outputs,
+        self.log.DEBUG("Opening shmem buffer")
+        self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
+            self.get("outputShmemBufferSize") * 2**30,
+            # TODO: have it just allocate memory, then set_shape later per train
+            (1,),
+            np.float32,
+            self.getInstanceId() + ":dataOutput",
         )
-        self["availableScenes"] = self["availableScenes"] + [
-            f"preview:{channel}" for channel in self._preview_outputs
-        ]
-
-        self._geometry = None
-        if self.get("geometryDevice"):
-            self.signalSlotable.connect(
-                self.get("geometryDevice"),
-                "signalNewGeometry",
-                "",  # slot device ID (default: self)
-                "slotReceiveGeometry",
-            )
+        if self._cuda_pin_buffers:
+            self.log.DEBUG("Trying to pin the shmem buffer memory")
+            self._shmem_buffer.cuda_pin()
+        self._shmem_receiver = shmem_utils.ShmemCircularBufferReceiver()
 
+        # CalCat friend comes before kernel runner
         with self.warning_context(
             "deviceInternalsState", WarningLampType.CALCAT_CONNECTION
         ) as warn:
@@ -739,18 +524,21 @@ class BaseCorrection(PythonDevice):
                 warn(f"Failed to connect to CalCat: {e}")
                 # TODO: add raw fallback mode if CalCat fails (raw data still useful)
                 return
+        self.kernel_runner = self._kernel_runner_class(self)
+        self.preview_friend = PreviewFriend(self, self._preview_outputs)
+        # TODO: can be static OVERWRITE_ELEMENT in expectedParameters
+        self["availableScenes"] = self["availableScenes"] + [
+            f"preview:{spec.name}" for spec in self._preview_outputs
+        ]
 
-        with self.warning_context(
-            "processingState", WarningLampType.FRAME_FILTER
-        ) as warn:
-            try:
-                self._frame_filter = _parse_frame_filter(self._parameters)
-            except (ValueError, TypeError):
-                warn("Failed to parse initial frame filter, will not use")
-                self._frame_filter = None
-        # TODO: decide how persistent initial warning should be
-        self._update_correction_flags()
-        self._update_frame_filter()
+        self.geometry = None
+        if self.get("geometryDevice"):
+            self.signalSlotable.connect(
+                self.get("geometryDevice"),
+                "signalNewGeometry",
+                "",  # slot device ID (default: self)
+                "slotReceiveGeometry",
+            )
 
         self._buffered_status_update = Hash()
         self._processing_time_tracker = trackers.ExponentialMovingAverage(
@@ -766,7 +554,7 @@ class BaseCorrection(PythonDevice):
         )
         self._train_ratio_tracker = trackers.TrainRatioTracker()
 
-        self.KARABO_ON_INPUT("dataInput", self.input_handler)
+        self.KARABO_ON_DATA("dataInput", self.input_handler)
         self.KARABO_ON_EOS("dataInput", self.handle_eos)
 
         self._enabled_addons = [
@@ -776,9 +564,14 @@ class BaseCorrection(PythonDevice):
         ]
         for addon in self._enabled_addons:
             addon._device = self
-        if self._enabled_addons:
+        if (
+            (self.get("useShmemHandles") != self._use_shmem_handles)
+            or self._enabled_addons
+        ):
             schema_override = Schema()
-            output_schema_override = self._base_output_schema
+            output_schema_override = self._base_output_schema(
+                use_shmem_handles=self.get("useShmemHandles")
+            )
             for addon in self._enabled_addons:
                 addon.extend_output_schema(output_schema_override)
             (
@@ -794,7 +587,7 @@ class BaseCorrection(PythonDevice):
         self.updateState(State.ON)
 
     def __del__(self):
-        del self._shmem_buffer
+        del self._shmem_buffer
         super().__del__()
 
     def preReconfigure(self, config):
@@ -812,64 +605,31 @@ class BaseCorrection(PythonDevice):
                     timestamp = dateutil.parser.isoparse(ts_string)
                     config.set(ts_path, timestamp.isoformat())
 
-        if config.has("constantParameters.deviceMappingSnapshotAt"):
-            self.calcat_friend.flush_pdu_mapping()
+        with self.config_overlay(config):
+            if config.has("constantParameters.deviceMappingSnapshotAt"):
+                self.calcat_friend.flush_pdu_mapping()
 
-        # update device based on changes
-        if config.has("frameFilter"):
-            self._frame_filter = _parse_frame_filter(
-                misc.ChainHash(config, self._parameters)
-            )
-
-        self._prereconfigure_update_hash = config
-
-    def postReconfigure(self):
-        if not hasattr(self, "_prereconfigure_update_hash"):
-            self.log_status_warn("postReconfigure without knowing update hash")
-            return
-
-        update = self._prereconfigure_update_hash
-
-        if update.has("frameFilter"):
-            self._lock_and_update(self._update_frame_filter)
-        elif any(
-            update.has(shape_param)
-            for shape_param in (
-                "dataFormat.pixelsX",
-                "dataFormat.pixelsY",
-                "dataFormat.outputImageDtype",
-                "dataFormat.outputAxisOrder",
-                "dataFormat.frames",
-                "constantParameters.memoryCells",
-                "frameFilter",
-            )
-        ):
-            self._lock_and_update(self._update_buffers)
-
-        if any(
-            (
-                update.has(f"corrections.{field_name}.enable")
-                or update.has(f"corrections.{field_name}.preview")
-            )
-            for field_name, *_ in self._correction_steps
-        ):
-            self._update_correction_flags()
+            if (
+                runner_update := subset_of_hash(
+                    config, "corrections", "dataFormat", "constantParameters"
+                )
+            ):
+                self.kernel_runner.reconfigure(runner_update)
 
-        # TODO: only send subhash of reconfiguration (like with addons)
-        if update.has("preview"):
-            self._preview_friend.reconfigure(update)
+            if config.has("preview"):
+                self.preview_friend.reconfigure(config["preview"])
 
-        if update.has("addons"):
-            # note: can avoid iterating, but it's not that costly
-            for addon in self._enabled_addons:
-                full_path = f"addons.{addon.__class__.__name__}"
-                if update.has(full_path):
-                    addon.reconfigure(update[full_path])
+            if config.has("addons"):
+                # note: can avoid iterating, but it's not that costly
+                for addon in self._enabled_addons:
+                    full_path = f"addons.{addon.__class__.__name__}"
+                    if config.has(full_path):
+                        addon.reconfigure(config[full_path])
 
     def _lock_and_update(self, method, background=True):
         # TODO: securely handle errors (postReconfigure may succeed, device state not)
         def runner():
-            with self._buffer_lock, utils.StateContext(self, State.CHANGING):
+            with self._buffer_lock, StateContext(self, State.CHANGING):
                 method()
 
         if background:
@@ -896,11 +656,9 @@ class BaseCorrection(PythonDevice):
                     ) as warn:
                         try:
                             constant_data = future.result()
-                            self._load_constant_to_runner(constant, constant_data)
+                            self.kernel_runner.load_constant(constant, constant_data)
                         except Exception as e:
                             warn(f"Failed to load {constant}: {e}")
-                        else:
-                            self._successfully_loaded_constant_to_runner(constant)
         self._lock_and_update(aux)
 
     def flush_constants(self, constants=None, preserve_fields=None):
@@ -930,7 +688,6 @@ class BaseCorrection(PythonDevice):
                 self.set(f"corrections.{field_name}.available", False)
         self.calcat_friend.flush_constants(constants, preserve_fields)
         self._reload_constants_from_cache_to_runner(constants)
-        self._update_correction_flags()
 
     def log_status_info(self, msg):
         self.log.INFO(msg)
@@ -954,7 +711,7 @@ class BaseCorrection(PythonDevice):
             payload["data"] = scenes.correction_device_preview(
                 device_id=self.getInstanceId(),
                 schema=self.getFullSchema(),
-                preview_channel=channel_name,
+                name=channel_name,
             )
         elif name.startswith("browse_schema"):
             if ":" in name:
@@ -983,10 +740,35 @@ class BaseCorrection(PythonDevice):
     def slotReceiveGeometry(self, device_id, serialized_geometry):
         self.log.INFO(f"Received geometry from {device_id}")
         try:
-            self._geometry = geom_utils.deserialize_geometry(serialized_geometry)
+            self.geometry = geom_utils.deserialize_geometry(serialized_geometry)
         except Exception as e:
             self.log.WARN(f"Failed to deserialize geometry; {e}")
 
+    def _get_data_from_hash(self, data_hash):
+        """Will get image data, cell table, pulse table, and list of other arrays from
+        the data hash. Assumes XTDF (image.data, image.cellId, image.pulseId, and no
+        other data), non-XTDF subclass can override. Cell and pulse table can be
+        None."""
+
+        image_data = data_hash.get(self._image_data_path)
+        cell_table = data_hash[self._cell_table_path].ravel()
+        pulse_table = data_hash[self._pulse_table_path].ravel()
+        num_frames = cell_table.size
+        # DataAggregator typically tells us the wrong axis order
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            # TODO: check that frames are always axis 0
+            expected_shape = self.kernel_runner.expected_input_shape(
+                num_frames
+            )
+            if expected_shape != image_data.shape:
+                image_data.shape = expected_shape
+        return (
+            num_frames,
+            image_data,  # not converted to ndarray here (the kernel runner can do that)
+            cell_table,
+            pulse_table,
+        )
+
     def _write_output(self, data, old_metadata):
         """For dataOutput: reusing incoming data hash and setting source and timestamp
         to be same as input"""
@@ -994,103 +776,11 @@ class BaseCorrection(PythonDevice):
             old_metadata.get("source"),
             Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")),
         )
-        data["corrections"] = self._correction_applied_hash
 
         channel = self.signalSlotable.getOutputChannel("dataOutput")
         channel.write(data, metadata, copyAllData=False)
         channel.update(safeNDArray=True)
 
-    def _update_correction_flags(self):
-        """Based on constants loaded and settings, update bit mask flags for kernel"""
-        enabled = self._correction_flag_class.NONE
-        output = Hash()  # for informing downstream receivers what was applied
-        preview = self._correction_flag_class.NONE
-        for field_name, flag, constants in self._correction_steps:
-            output[field_name] = False
-            if self.get(f"corrections.{field_name}.available"):
-                if self.get(f"corrections.{field_name}.enable"):
-                    enabled |= flag
-                    output[field_name] = True
-                if self.get(f"corrections.{field_name}.preview"):
-                    preview |= flag
-        self._correction_flag_enabled = enabled
-        self._correction_flag_preview = preview
-        self._correction_applied_hash = output
-        self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}")
-        self.log.DEBUG(f"Corrections for preview: {str(preview)}")
-
-    def _update_frame_filter(self, update_buffers=True):
-        """Parse frameFilter string (if set) and update cached filter array. May update
-        dataFormat.filteredFrames - will therefore by default call _update_buffers
-        afterwards."""
-        # TODO: add some validation to preReconfigure
-        self.log.DEBUG("Updating frame filter")
-
-        if self._frame_filter is None:
-            self.set("dataFormat.filteredFrames", self.get("dataFormat.frames"))
-            self.set("frameFilter.current", [])
-        else:
-            self.set("dataFormat.filteredFrames", self._frame_filter.size)
-            self.set("frameFilter.current", list(map(int, self._frame_filter)))
-
-        with self.warning_context(
-            "deviceInternalsState", WarningLampType.FRAME_FILTER
-        ) as warn:
-            if self._frame_filter is not None and (
-                self._frame_filter.min() < 0
-                or self._frame_filter.max() >= self.get("dataFormat.frames")
-            ):
-                warn("Invalid frame filter set, expect exceptions!")
-                # note: dataFormat.frames could change from detector and
-                # that could make this warning invalid
-
-        if update_buffers:
-            self._update_buffers()
-
-    def _update_buffers(self):
-        """(Re)initialize buffers / kernel runner according to expected data shapes"""
-        self.log.INFO("Updating buffers according to data shapes")
-        # reflect the axis reordering in the expected output shape
-        self.set("dataFormat.inputDataShape", list(self.input_data_shape))
-        self.set("dataFormat.outputDataShape", list(self.output_data_shape))
-        self.log.INFO(f"Input shape: {self.input_data_shape}")
-        self.log.INFO(f"Output shape: {self.output_data_shape}")
-
-        if self._shmem_buffer is None:
-            shmem_buffer_name = self.getInstanceId() + ":dataOutput"
-            memory_budget = self.get("outputShmemBufferSize") * 2**30
-            self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}")
-            self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
-                memory_budget,
-                self.output_data_shape,
-                self.output_data_dtype,
-                shmem_buffer_name,
-            )
-            self._shmem_receiver = shmem_utils.ShmemCircularBufferReceiver()
-            if self._cuda_pin_buffers:
-                self.log.INFO("Trying to pin the shmem buffer memory")
-                self._shmem_buffer.cuda_pin()
-            self.log.INFO("Done, shmem buffer is ready")
-        else:
-            self._shmem_buffer.change_shape(self.output_data_shape)
-
-        # give CuPy a chance to at least start memory cleanup before this
-        if self.kernel_runner is not None:
-            del self.kernel_runner
-            self.kernel_runner = None
-            gc.collect()
-
-        self.kernel_runner = self._kernel_runner_class(
-            self.get("dataFormat.pixelsX"),
-            self.get("dataFormat.pixelsY"),
-            self.get("dataFormat.filteredFrames"),
-            int(self.get("constantParameters.memoryCells")),
-            output_data_dtype=self.output_data_dtype,
-            **self._kernel_runner_init_args,
-        )
-
-        self._reload_constants_from_cache_to_runner()
-
     def _reload_constants_from_cache_to_runner(self, constants=None):
         # TODO: gracefully handle change in constantParameters.memoryCells
         if constants is None:
@@ -1110,13 +800,11 @@ class BaseCorrection(PythonDevice):
                 ) as warn:
                     try:
                         self.log_status_info(f"Reload constant {constant.name}")
-                        self._load_constant_to_runner(constant, data)
+                        self.kernel_runner.load_constant(constant, data)
                     except Exception as e:
                         warn(f"Failed to reload {constant.name}: {e}")
-                    else:
-                        self._successfully_loaded_constant_to_runner(constant)
 
-    def input_handler(self, input_channel):
+    def input_handler(self, data_hash, metadata):
         """Main handler for data input: Do a few simple checks to determine whether to
         even try processing. If yes, will pass data and information to process_data
         method provided by subclass."""
@@ -1134,161 +822,179 @@ class BaseCorrection(PythonDevice):
                 warn("Received data before correction device was ready")
                 return
 
-        all_metadata = input_channel.getMetaData()
-        for input_index in range(input_channel.size()):
-            with self.warning_context(
-                "inputDataState", WarningLampType.MISC_INPUT_DATA
-            ) as warn:
-                data_hash = input_channel.read(input_index)
-                metadata = all_metadata[input_index]
-                source = metadata.get("source")
+        with self.warning_context(
+            "inputDataState", WarningLampType.MISC_INPUT_DATA
+        ) as warn:
+            source = metadata.get("source")
 
-                if source not in self.sources:
-                    continue
-                elif not data_hash.has(self._image_data_path):
-                    warn("Ignoring hash without image node")
-                    continue
+            if source not in self.sources:
+                return
+            elif not data_hash.has(self._image_data_path):
+                warn("Ignoring hash without image node")
+                return
 
-            with self.warning_context(
-                "inputDataState",
-                WarningLampType.EMPTY_HASH,
-            ) as warn:
-                self._shmem_receiver.dereference_shmem_handles(data_hash)
-                try:
-                    image_data = np.asarray(data_hash.get(self._image_data_path))
-                    cell_table = (
-                        np.array(
-                            # explicit copy to avoid mysterious segfault
-                            data_hash.get(self._cell_table_path), copy=True
-                        ).ravel()
-                        if self._cell_table_path is not None
-                        else None
-                    )
-                    pulse_table = (
-                        np.array(
-                            # explicit copy to avoid mysterious segfault
-                            data_hash.get(self._pulse_table_path), copy=True
-                        ).ravel()
-                        if self._pulse_table_path is not None
-                        else None
-                    )
-                except RuntimeError as err:
-                    warn(
-                        "Failed to load image data; "
-                        f"probably empty hash from DAQ: {err}"
-                    )
-                    self._maybe_update_cell_and_pulse_tables(None, None)
-                    continue
+        with self.warning_context(
+            "inputDataState",
+            WarningLampType.EMPTY_HASH,
+        ) as warn:
+            self._shmem_receiver.dereference_shmem_handles(data_hash)
+            try:
+                (
+                    num_frames,
+                    image_data,
+                    cell_table,
+                    pulse_table,
+                    *additional_data
+                ) = self._get_data_from_hash(data_hash)
+            except RuntimeError as err:
+                warn(
+                    "Failed to load data from hash, probably empy hash from DAQ. "
+                    f"(Is detector sending data?)\n{err}"
+                )
+                self._maybe_update_cell_and_pulse_tables(None, None)
+                return
 
-            # no more common reasons to skip input, so go to processing
-            self._last_processing_started = default_timer()
-            if state is State.ON:
-                self.updateState(State.PROCESSING)
-                self.log_status_info("Processing data")
+        # no more common reasons to skip input, so go to processing
+        self._last_processing_started = default_timer()
+        if state is State.ON:
+            self.updateState(State.PROCESSING)
+            self.log_status_info("Processing data")
 
-            self._maybe_update_cell_and_pulse_tables(cell_table, pulse_table)
-            timestamp = Timestamp.fromHashAttributes(
-                metadata.getAttributes("timestamp")
-            )
-            train_id = timestamp.getTrainId()
+        self._maybe_update_cell_and_pulse_tables(cell_table, pulse_table)
+        timestamp = Timestamp.fromHashAttributes(
+            metadata.getAttributes("timestamp")
+        )
 
-            # check time server connection
-            with self.warning_context(
-                "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION
-            ) as warn:
-                my_timestamp = self.getActualTimestamp()
-                my_train_id = my_timestamp.getTrainId()
-                self._input_delay_tracker.update(
-                    (my_timestamp.toTimestamp() - timestamp.toTimestamp()) * 1000
-                )
-                self._buffered_status_update.set(
-                    "performance.inputDelay", self._input_delay_tracker.get()
-                )
-                if my_train_id == 0:
-                    my_train_id = self._last_train_id_processed + 1
-                    warn(
-                        "Failed to get current train ID, using previously seen train "
-                        "ID for future train thresholding - if this persists, check "
-                        "connection to timeserver."
-                    )
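+        # updates the timeserver / train ID warning lamps; raises
+        # TrainFromTheFutureException to drop trains claiming to be from the future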
+        self._check_train_id_and_time(timestamp)
+
+        if self._warn_memory_cell_range:
             with self.warning_context(
-                "inputDataState", WarningLampType.TRAIN_ID
+                "processingState", WarningLampType.MEMORY_CELL_RANGE
             ) as warn:
-                if train_id > (
-                    my_train_id
-                    + self.unsafe_get("workarounds.trainFromFutureThreshold")
+                if (
+                    self.unsafe_get("constantParameters.memoryCells")
+                    <= cell_table.max()
                 ):
-                    warn(
-                        f"Suspecting train from the future: 'now' is {my_train_id}, "
-                        f"received train ID {train_id}, dropping data"
-                    )
-                    continue
-                try:
-                    self._train_ratio_tracker.update(train_id)
-                    self._buffered_status_update.set(
-                        "performance.ratioOfRecentTrainsReceived",
-                        self._train_ratio_tracker.get(
-                            current_train=my_train_id,
-                            expected_delay=math.ceil(
-                                self.unsafe_get("performance.inputDelay") / 100
-                            ),
-                        )
-                        if my_train_id != 0
-                        else self._train_ratio_tracker.get(),
-                    )
-                except trackers.NonMonotonicTrainIdWarning as ex:
-                    warn(
-                        f"Train ratio tracker noticed issue with train ID: {ex}\n"
-                        f"For the record, I think now is: {my_train_id}"
-                    )
-                    self._train_ratio_tracker.reset()
-                    self._train_ratio_tracker.update(train_id)
+                    warn("Input cell IDs out of range of constants")
 
-            if self._warn_memory_cell_range:
-                with self.warning_context(
-                    "processingState", WarningLampType.MEMORY_CELL_RANGE
-                ) as warn:
-                    if (
-                        self.unsafe_get("constantParameters.memoryCells")
-                        <= cell_table.max()
-                    ):
-                        warn("Input cell IDs out of range of constants")
-
-            # TODO: avoid branching multiple times on cell_table
-            if cell_table is not None and cell_table.size != self.unsafe_get(
-                "dataFormat.frames"
-            ):
-                self.log_status_info(
-                    f"Updating new input shape to account for {cell_table.size} cells"
-                )
-                self.set("dataFormat.frames", cell_table.size)
-                self._lock_and_update(self._update_frame_filter, background=False)
-
-            # DataAggregator typically tells us the wrong axis order
-            if self.unsafe_get("workarounds.overrideInputAxisOrder"):
-                expected_shape = self.input_data_shape
-                if expected_shape != image_data.shape:
-                    image_data.shape = expected_shape
-
-            with self._buffer_lock:
-                self.process_data(
-                    data_hash,
-                    metadata,
-                    source,
-                    train_id,
-                    image_data,
-                    cell_table,
-                    pulse_table,
+        # TODO: change ShmemCircularBuffer API so we don't have to poke like this
+        if (
+            (output_shape := self.kernel_runner.expected_output_shape(num_frames))
+            != self._shmem_buffer._buffer_ary.shape[1:]
+            or self.kernel_runner._output_dtype != self._shmem_buffer._buffer_ary.dtype
+        ):
+            self.set("dataFormat.outputDataShape", list(output_shape))
+            self.log.DEBUG("Updating shmem buffer shape / dtype")
+            self._shmem_buffer.change_shape(
+                output_shape, self.kernel_runner._output_dtype
+            )
+        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
+        with self.warning_context(
+            "processingState", WarningLampType.CORRECTION_RUNNER
+        ):
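+            # the runner returns the applied-corrections hash (attached to the output
+            # hash below), the processed data buffer, and the preview arrays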
+            corrections, processed_buffer, previews = self.kernel_runner.correct(
+                image_data, cell_table, *additional_data
+            )
+            data_hash["corrections"] = corrections
+            for addon in self._enabled_addons:
+                addon.post_correction(
+                    processed_buffer, cell_table, pulse_table, data_hash
                 )
-            self._last_train_id_processed = train_id
-            self._buffered_status_update.set("trainId", train_id)
-            self._processing_time_tracker.update(
-                default_timer() - self._last_processing_started
+            self.kernel_runner.reshape(processed_buffer, out=buffer_array)
+
+        with self.warning_context(
+            "processingState", WarningLampType.PREVIEW_SETTINGS
+        ) as warn:
+            self.preview_friend.write_outputs(
+                *previews,
+                timestamp=timestamp,
+                cell_table=cell_table,
+                pulse_table=pulse_table,
+                warn_fun=warn,
+            )
+
+        for addon in self._enabled_addons:
+            addon.post_reshape(
+                buffer_array, cell_table, pulse_table, data_hash
+            )
+
+        if self.unsafe_get("useShmemHandles"):
+            data_hash.set(self._image_data_path, buffer_handle)
+            data_hash.set("calngShmemPaths", [self._image_data_path])
+        else:
+            data_hash.set(self._image_data_path, buffer_array)
+            data_hash.set("calngShmemPaths", [])
+
+        self._write_output(data_hash, metadata)
+        self._processing_time_tracker.update(
+            default_timer() - self._last_processing_started
+        )
+        self._buffered_status_update.set(
+            "performance.processingTime", self._processing_time_tracker.get() * 1000
+        )
+        self._rate_tracker.update()
+
+    def _please_send_me_cached_constants(self, constants, callback):
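+        """For each of the given constants that the CalCat friend already has cached,
+        invoke callback (typically a runner's load_constant) with the cached data."""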
+        for constant in constants:
+            if constant in self.calcat_friend.cached_constants:
+                callback(constant, self.calcat_friend.cached_constants[constant])
+
+    def _check_train_id_and_time(self, timestamp):
+        train_id = timestamp.getTrainId()
+        # check time server connection
+        with self.warning_context(
+            "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION
+        ) as warn:
+            my_timestamp = self.getActualTimestamp()
+            my_train_id = my_timestamp.getTrainId()
+            self._input_delay_tracker.update(
+                (my_timestamp.toTimestamp() - timestamp.toTimestamp()) * 1000
             )
             self._buffered_status_update.set(
-                "performance.processingTime", self._processing_time_tracker.get() * 1000
+                "performance.inputDelay", self._input_delay_tracker.get()
             )
-            self._rate_tracker.update()
+            if my_train_id == 0:
+                my_train_id = self._last_train_id_processed + 1
+                warn(
+                    "Failed to get current train ID, using previously seen train "
+                    "ID for future train thresholding - if this persists, check "
+                    "connection to timeserver."
+                )
+        with self.warning_context(
+            "inputDataState", WarningLampType.TRAIN_ID
+        ) as warn:
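+            # drop trains claiming to be more than the configured threshold ahead of
+            # what the timeserver (or our fallback train ID) says "now" is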
+            if train_id > (
+                my_train_id
+                + self.unsafe_get("workarounds.trainFromFutureThreshold")
+            ):
+                warn(
+                    f"Suspecting train from the future: 'now' is {my_train_id}, "
+                    f"received train ID {train_id}, dropping data"
+                )
+                raise TrainFromTheFutureException()
+            try:
+                self._train_ratio_tracker.update(train_id)
+                self._buffered_status_update.set(
+                    "performance.ratioOfRecentTrainsReceived",
+                    self._train_ratio_tracker.get(
+                        current_train=my_train_id,
+                        expected_delay=math.ceil(
+                            self.unsafe_get("performance.inputDelay") / 100
+                        ),
+                    )
+                    if my_train_id != 0
+                    else self._train_ratio_tracker.get(),
+                )
+            except trackers.NonMonotonicTrainIdWarning as ex:
+                warn(
+                    f"Train ratio tracker noticed issue with train ID: {ex}\n"
+                    f"For the record, I think now is: {my_train_id}"
+                )
+                self._train_ratio_tracker.reset()
+                self._train_ratio_tracker.update(train_id)
+
+        self._last_train_id_processed = train_id
+        self._buffered_status_update.set("trainId", train_id)
 
     def _update_rate_and_state(self):
         if self.get("state") is State.PROCESSING:
@@ -1329,119 +1035,6 @@ class BaseCorrection(PythonDevice):
         self.signalEndOfStream("dataOutput")
 
 
-def add_correction_step_schema(schema, field_flag_constants_mapping):
-    """Using the fields in the provided mapping, will add nodes to schema
-
-    field_flag_mapping is assumed to be iterable of pairs where first entry in each
-    pair is the name of a correction step as it will appear in device schema (second
-    entry - typically an enum field - is ignored). For correction step, a node and some
-    booleans are added to the schema (all tagged as managed). Subclass can customize /
-    add additional keys under node later.
-
-    This method should be called in expectedParameters of subclass after the same for
-    BaseCorrection has been called. Would be nice to include in BaseCorrection instead,
-    but that is tricky: static method of superclass will need _correction_steps
-    of subclass or device server gets mad. A nice solution with classmethods would be
-    welcome.
-    """
-
-    for field_name, _, used_constants in field_flag_constants_mapping:
-        node_name = f"corrections.{field_name}"
-        (
-            NODE_ELEMENT(schema).key(node_name).commit(),
-
-            BOOL_ELEMENT(schema)
-            .key(f"{node_name}.available")
-            .displayedName("Available")
-            .description(
-                "This boolean indicates whether the necessary constants have been "
-                "loaded for this correction step to be applied. Enabling the "
-                "correction will have no effect unless this is True."
-            )
-            .readOnly()
-            # some corrections available without constants
-            .initialValue(not used_constants or None in used_constants)
-            .commit(),
-
-            BOOL_ELEMENT(schema)
-            .key(f"{node_name}.enable")
-            .tags("managed")
-            .displayedName("Enable")
-            .description(
-                "Controls whether to apply this correction step for main data "
-                "output - subject to availability."
-            )
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(schema)
-            .key(f"{node_name}.preview")
-            .tags("managed")
-            .displayedName("Preview")
-            .description(
-                "Whether to apply this correction step for corrected preview "
-                "output - subject to availability."
-            )
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-        )
-
-
-def add_bad_pixel_config_node(schema, prefix="corrections.badPixels"):
-    (
-        STRING_ELEMENT(schema)
-        .key("corrections.badPixels.maskingValue")
-        .tags("managed")
-        .displayedName("Bad pixel masking value")
-        .description(
-            "Any pixels masked by the bad pixel mask will have their value replaced "
-            "with this. Note that this parameter is to be interpreted as a "
-            "numpy.float32; use 'nan' to get NaN value."
-        )
-        .assignmentOptional()
-        .defaultValue("nan")
-        .reconfigurable()
-        .commit(),
-
-        NODE_ELEMENT(schema)
-        .key("corrections.badPixels.subsetToUse")
-        .displayedName("Bad pixel flags to use")
-        .description(
-            "The booleans under this node allow for selecting a subset of bad pixel "
-            "types to take into account when doing bad pixel masking. Upon updating "
-            "these flags, the map used for bad pixel masking will be ANDed with this "
-            "selection. Turning disabled flags back on causes reloading of cached "
-            "constants."
-        )
-        .commit(),
-    )
-    for field in utils.BadPixelValues:
-        (
-            BOOL_ELEMENT(schema)
-            .key(f"corrections.badPixels.subsetToUse.{field.name}")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit()
-        )
-
-
-def add_preview_outputs(schema, channels):
-    for channel in channels:
-        (
-            OUTPUT_CHANNEL(schema)
-            .key(f"preview.{channel}")
-            .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True))
-            .commit(),
-        )
-    preview_utils.PreviewFriend.add_schema(schema, output_channels=channels)
-
-
 def add_addon_nodes(schema, device_class, prefix="addons"):
     det_name = device_class.__name__[:-len("Correction")].lower()
     device_class._available_addons = [
@@ -1467,28 +1060,3 @@ def add_addon_nodes(schema, device_class, prefix="addons"):
         addon_class.extend_device_schema(
             schema, f"{prefix}.{addon_class.__name__}"
         )
-
-
-def get_bad_pixel_field_selection(self):
-    selection = 0
-    for field in utils.BadPixelValues:
-        if self.get(f"corrections.badPixels.subsetToUse.{field.name}"):
-            selection |= field
-    return selection
-
-
-def _parse_frame_filter(config):
-    filter_type = FramefilterSpecType(config["frameFilter.type"])
-    filter_string = config["frameFilter.spec"]
-
-    if filter_type is FramefilterSpecType.NONE or filter_string.strip() == "":
-        return None
-    elif filter_type is FramefilterSpecType.RANGE:
-        # allow exceptions
-        numbers = tuple(int(part) for part in filter_string.split(","))
-        return np.arange(*numbers, dtype=np.uint16)
-    elif filter_type is FramefilterSpecType.COMMASEPARATED:
-        # np.fromstring is too lenient I think
-        return np.array([int(s) for s in filter_string.split(",")], dtype=np.uint16)
-    else:
-        raise TypeError(f"Unknown frame filter type {filter_type}")
diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py
index 5eaaf10155d23e50b73044da62cad62a335f97f6..e04ad559453a2f8ab954ae0b8b60e529733f94d8 100644
--- a/src/calng/base_kernel_runner.py
+++ b/src/calng/base_kernel_runner.py
@@ -1,19 +1,163 @@
 import enum
 import functools
-import operator
+import itertools
 import pathlib
 
-from calngUtils import misc as utils
+from karabo.bound import (
+    BOOL_ELEMENT,
+    NODE_ELEMENT,
+    STRING_ELEMENT,
+    Hash,
+)
+from .utils import (
+    BadPixelValues,
+    maybe_get,
+    subset_of_hash,
+)
+from calngUtils import misc
 import jinja2
-import numpy as np
 
 
 class BaseKernelRunner:
-    _gpu_based = True
     _xp = None  # subclass sets numpy or cupy
+    num_pixels_ss = None  # subclass must set
+    num_pixels_fs = None  # subclass must set
+
+    _bad_pixel_constants = None  # should be set by subclasses using bad pixels
+    bad_pixel_subset = 0xFFFFFFFF  # will be set in reconfigure
+
+    @classmethod
+    def add_schema(cls, schema):
+        """Using the steps in the provided mapping, will add nodes to schema
+
+        correction_steps is assumed to be iterable of pairs where first entry in each
+        pair is the name of a correction step as it will appear in device schema (second
+        entry - typically an enum field - is ignored). For correction step, a node and
+        some booleans are added to the schema (all tagged as managed). Subclass can
+        customize / add additional keys under node later.
+        """
+        (
+            NODE_ELEMENT(schema)
+            .key("corrections")
+            .commit(),
+        )
+
+        for field_name, _, used_constants in cls._correction_steps:
+            node_name = f"corrections.{field_name}"
+            (
+                NODE_ELEMENT(schema).key(node_name).commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_name}.available")
+                .displayedName("Available")
+                .description(
+                    "This boolean indicates whether the necessary constants have been "
+                    "loaded for this correction step to be applied. Enabling the "
+                    "correction will have no effect unless this is True."
+                )
+                .readOnly()
+                # some corrections available without constants
+                .initialValue(not used_constants or None in used_constants)
+                .commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_name}.enable")
+                .tags("managed")
+                .displayedName("Enable")
+                .description(
+                    "Controls whether to apply this correction step for main data "
+                    "output - subject to availability."
+                )
+                .assignmentOptional()
+                .defaultValue(True)
+                .reconfigurable()
+                .commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_name}.preview")
+                .tags("managed")
+                .displayedName("Preview")
+                .description(
+                    "Whether to apply this correction step for corrected preview "
+                    "output - subject to availability."
+                )
+                .assignmentOptional()
+                .defaultValue(True)
+                .reconfigurable()
+                .commit(),
+            )
+
+    @staticmethod
+    def add_bad_pixel_config(schema):
+        (
+            STRING_ELEMENT(schema)
+            .key("corrections.badPixels.maskingValue")
+            .tags("managed")
+            .displayedName("Bad pixel masking value")
+            .description(
+                "Any pixels masked by the bad pixel mask will have their value "
+                "replaced with this. Note that this parameter is to be interpreted as "
+                "a numpy.float32; use 'nan' to get NaN value."
+            )
+            .assignmentOptional()
+            .defaultValue("nan")
+            .reconfigurable()
+            .commit(),
+
+            NODE_ELEMENT(schema)
+            .key("corrections.badPixels.subsetToUse")
+            .displayedName("Bad pixel flags to use")
+            .description(
+                "The booleans under this node allow for selecting a subset of bad "
+                "pixel types to take into account when doing bad pixel masking. Upon "
+                "updating these flags, the map used for bad pixel masking will be "
+                "ANDed with this selection. Turning disabled flags back on causes "
+                "reloading of cached constants."
+            )
+            .commit(),
+        )
+        for field in BadPixelValues:
+            (
+                BOOL_ELEMENT(schema)
+                .key(f"corrections.badPixels.subsetToUse.{field.name}")
+                .tags("managed")
+                .assignmentOptional()
+                .defaultValue(True)
+                .reconfigurable()
+                .commit()
+            )
+
+    def __init__(self, device):
+        self._device = device
+        # for now, we depend on multiple keys from device hash, so get multiple nodes
+        config = subset_of_hash(
+            self._device._parameters,
+            "corrections",
+            "dataFormat",
+            "constantParameters",
+        )
+        self._constant_memory_cells = config["constantParameters.memoryCells"]
+        self._pre_init()
+        # note: availability handling assumes a step becomes available as soon as any
+        # one of its required constants is loaded; steps needing multiple constants
+        # are not fully handled
+        # (do we need to generalize "decide if this correction step is available"?)
+        self._constant_to_correction_names = {}
+        for (name, _, constants) in self._correction_steps:
+            for constant in constants:
+                self._constant_to_correction_names.setdefault(constant, set()).add(name)
+        self._output_dtype = self._xp.float32
+        self._setup_constant_buffers()
+        self.reconfigure(config)
+        self._post_init()
 
     def _pre_init(self):
-        # can be used, for example, to import cupy and set as _xp at runtime
+        """Hook used to set up things which need to be in place before __init__ is
+        called. Note that __init__ will do reconfigure with full configuration, which
+        means creating constant buffers and such (__init__ already takes care to set
+        _constant_memory_cells so this can be used in _setup_constant_buffers).
+
+
+        See also _setup_constant_buffers
+        """
         pass
 
     def _post_init(self):
@@ -21,121 +165,230 @@ class BaseKernelRunner:
         # without overriding __init__
         pass
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-    ):
-        self._pre_init()
-        self.pixels_x = pixels_x
-        self.pixels_y = pixels_y
-        self.frames = frames
-        if constant_memory_cells == 0:
-            # if not set, guess same as input; may save one recompilation
-            self.constant_memory_cells = frames
-        else:
-            self.constant_memory_cells = constant_memory_cells
-        self.output_data_dtype = output_data_dtype
-        self._post_init()
+    def expected_input_shape(self, num_frames):
+        """Mostly used to inform the host device about what data shape should be.
+        Subclass should override as needed (ex. XTDF detectors with extra image/gain
+        axis)."""
+        return (num_frames, self.num_pixels_ss, self.num_pixels_fs)
 
-    @property
-    def preview_shape(self):
-        return (self.pixels_x, self.pixels_y)
-
-    def correct(self, flags):
-        """Correct (already loaded) image data according to flags
-
-        Detector-specific subclass must define this method. It should assume that image
-        data, cell table, and other data (including constants) has already been loaded.
-        It should probably run some {CPU,GPU} kernel and output should go into
-        self.processed_data{,_gpu}.
+    def expected_output_shape(self, num_frames):
+        """Used to inform the host device about what to expect (for reshaping the shmem
+        buffer). Takes the transpose setting into account; override
+        _expected_output_shape in subclasses which do funky shape things."""
+        base_shape = self._expected_output_shape(num_frames)
+        if self._output_transpose is None:
+            return base_shape
+        else:
+            return tuple(base_shape[i] for i in self._output_transpose)
 
-        Keep in mind that user only gets output from compute_preview or reshape
-        (either of these should come after correct).
+    def _expected_output_shape(self, num_frames):
+        return (num_frames, self.num_pixels_ss, self.num_pixels_fs)
 
-        The submodules providing subclasses should have some IntFlag enums defining
-        which flags are available to pass along to the kernel. A zero flag should allow
-        the kernel to do no actual correction - but still copy the data between buffers
-        and cast it to desired output type.
+    def _correct(self, flags, image_data, cell_table, *rest):
+        """Subclass will implement the core correcton routine with this name. Observe
+        the parameter list: the "rest" part will always contain at least one output
+        buffer, but may have more entries such as additional raw input buffers and
+        multiple output buffers. See how it is called by correct."""
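+        # minimal sketch, assuming a single output buffer and ignoring flags
+        # (illustrative only, not taken from a real subclass):
+        #     def _correct(self, flags, image_data, cell_table, processed_data):
+        #         processed_data[:] = image_data.astype(self._xp.float32)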
 
-        """
         raise NotImplementedError()
 
-    @property
-    def preview_data_views(self):
+    def _preview_data_views(self, *buffers):
         """Should provide tuple of views of data for different preview layers - as a
         minimum, raw and corrected.  View is expected to have axis order frame, slow
-        scan, fast scan with the latter two matching preview_shape (which is X and which
-        is Y may differ between detectors)."""
+        scan, fast scan, with the latter two forming the final preview shape (which is
+        X and which is Y may differ between detectors). The buffers passed vary from
+        subclass to subclass; they are all the inputs and outputs used in _correct. By
+        default, they are returned as-is."""
+
+        return buffers
+
+    def _load_constant(self, constant_type, constant_data):
+        """Should update the appropriate internal buffers for constant type (maps to
+        "calibration" type in CalCat) with the data in constant_tdata. This may include
+        sanitization; data passed here has the axis order found in the stored constant
+        files."""
+        raise NotImplementedError()
 
+    def _setup_constant_buffers(self):
         raise NotImplementedError()
 
+    def _make_output_buffers(self, num_frames, flags):
+        """Should, based on number of frames passed (may in the future be extended to
+        include other information), create the necessary buffers to store corrected
+        output(s). By default, will create a single processed data buffer of float32
+        with the same shape as input image data, skipping any extra axes (like the
+        image/gain axis for XTDF). Subclass should override if it wants more outputs.
+        First buffer must be for corrected image data (passed to addons)"""
+        return [
+            self._xp.empty(
+                (num_frames, self.num_pixels_ss, self.num_pixels_fs),
+                dtype=self._xp.float32,
+            )
+        ]
+
+    @functools.cached_property
+    def _all_relevant_constants(self):
+        return set(
+            itertools.chain.from_iterable(
+                constants for (_, _, constants) in self._correction_steps
+            )
+        ) - {None}
+
+    def reconfigure(self, config):
+        if config.has("constantParameters.memoryCells"):
+            self._constant_memory_cells = config.get("constantParameters.memoryCells")
+            self._setup_constant_buffers()
+            self._device._please_send_me_cached_constants(
+                self._all_relevant_constants, self.load_constant
+            )
+
+        if config.has("dataFormat.outputImageDtype"):
+            self._output_dtype = self._xp.dtype(
+                config.get("dataFormat.outputImageDtype")
+            )
+
+        if config.has("dataFormat.outputAxisOrder"):
+            order = config["dataFormat.outputAxisOrder"]
+            if order == "f-ss-fs":
+                self._output_transpose = None
+            else:
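+                # compute the axis permutation taking (frame, slow scan, fast scan)
+                # to the configured output axis order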
+                self._output_transpose = misc.transpose_order(
+                    ("f", "ss", "fs"), order.split("-")
+                )
+
+        if config.has("corrections.badPixels.maskingValue"):
+            self.bad_pixel_mask_value = self._xp.float32(
+                config["corrections.badPixels.maskingValue"]
+            )
+
+        if config.has("corrections.badPixels.subsetToUse"):
+            # note: now just always reloading from cache for convenience
+            self._update_bad_pixel_subset()
+            self.flush_buffers(self._bad_pixel_constants)
+            self._device._please_send_me_cached_constants(
+                self._bad_pixel_constants, self.load_constant
+            )
+
+        # TODO: only trigger flag update if changed (cheap though)
+        self._update_correction_flags()
+
+    def reshape(self, processed_data, out=None):
+        """Move axes to desired output order"""
+        if self._output_transpose is None:
+            return maybe_get(processed_data, out)
+        else:
+            return maybe_get(processed_data.transpose(self._output_transpose), out)
+
     def flush_buffers(self, constants=None):
         """Optional reset internal buffers (implement in subclasses which need this)"""
         pass
 
-    def compute_previews(self, preview_index):
-        """Generate single slice or reduction previews for raw and corrected data and
-        any other layers, determined by self.preview_data_views
+    def _update_correction_flags(self):
+        enabled = self._correction_flag_class.NONE
+        output = Hash()  # for informing downstream receivers what was applied
+        preview = self._correction_flag_class.NONE
+        for field_name, flag, constants in self._correction_steps:
+            output[field_name] = False
+            if self._get(f"corrections.{field_name}.available"):
+                if self._get(f"corrections.{field_name}.enable"):
+                    enabled |= flag
+                    output[field_name] = True
+                if self._get(f"corrections.{field_name}.preview"):
+                    preview |= flag
+        self._correction_flag_enabled = enabled
+        self._correction_flag_preview = preview
+        self._correction_applied_hash = output
+        self._device.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}")
+        self._device.log.DEBUG(f"Corrections for preview: {str(preview)}")
+
+    def _update_bad_pixel_subset(self):
+        selection = 0
+        for field in BadPixelValues:
+            if self._get(f"corrections.badPixels.subsetToUse.{field.name}"):
+                selection |= field
+        self.bad_pixel_subset = selection
 
-        Special values of preview_index are -1 for max, -2 for mean, -3 for sum, and
-        -4 for stdev (across cells).
+    def _get(self, key):
+        """Will look up key in host device's parameters.  Note that device calls
+        reconfigure during preReconfigure, but with new unmerged config overlaid
+        temporarily, making _get get the "new" state of things."""
+        return self._device.unsafe_get(key)
 
-        Note that preview_index is taken from data without checking cell table.
-        Caller has to figure out which index along memory cell dimension they
-        actually want to preview in case it needs to be a specific cell / pulse.
+    def _set(self, key, value):
+        """Helper function to host device's configuration. Note that this does *not*
+        call reconfigure, so either do so afterwards or make sure to update auxiliary
+        properties correctly yourself."""
+        self._device.set(key, value)
 
-        Will typically reuse data from corrected output buffer. Therefore,
-        correct(...) must have been called with the appropriate flags before
-        compute_preview(...).
+    @property
+    def _map_shape(self):
+        return (self._constant_memory_cells, self.num_pixels_ss, self.num_pixels_fs)
+
+    def load_constant(self, constant, constant_data):
+        self._load_constant(constant, constant_data)
+        self._post_load_constant(constant)
+
+    def _post_load_constant(self, constant):
+        for field_name in self._constant_to_correction_names[constant]:
+            key = f"corrections.{field_name}.available"
+            if not self._get(key):
+                self._set(key, True)
+
+        self._update_correction_flags()
+        self._device.log_status_info(f"Done loading {constant.name} to runner")
+
+    def correct(self, image_data, cell_table, *additional_data):
+        """Correct image data, providing both full corrected data and previews
+
+        This method relies on the following auxiliary methods which may need to be
+        overridden / defined in subclass:
+
+        _make_output_buffers
+        _preview_data_views
+        _correct (must be implemented by subclass)
+
+        Look at the signature of _correct; it is given image_data, cell_table,
+        *additional_data (anything the host device class' _get_data_from_hash provides
+        beyond image data, cell table, and pulse table), and *processed_buffers (the
+        result of _make_output_buffers). So these methods, together with
+        _preview_data_views, must agree on the relevant set of input and output
+        buffers.
+
+        The pulse_table is only used for preview index selection purposes as corrections
+        do not (yet, at any rate) take the pulse table into account.
         """
-        if preview_index < -4:
-            raise ValueError(f"No statistic with code {preview_index} defined")
-        elif preview_index >= self.frames:
-            raise ValueError(f"Memory cell index {preview_index} out of range")
-
-        if preview_index >= 0:
-            fun = operator.itemgetter(preview_index)
-        elif preview_index == -1:
-            # note: separate from next case because dtype not applicable here
-            fun = functools.partial(self._xp.nanmax, axis=0)
-        elif preview_index in (-2, -3, -4):
-            fun = functools.partial(
-                {
-                    -2: self._xp.nanmean,
-                    -3: self._xp.nansum,
-                    -4: self._xp.nanstd,
-                }[preview_index],
-                axis=0,
-                dtype=self._xp.float32,
+        # TODO: check if this is robust enough (also base_correction)
+        num_frames = image_data.shape[0]
+        processed_buffers = self._make_output_buffers(
+            num_frames, self._correction_flag_preview
+        )
+        image_data = self._xp.asarray(image_data)  # if cupy, will put on GPU
+        if cell_table is not None:
+            cell_table = self._xp.asarray(cell_table)
+        additional_data = [self._xp.asarray(data) for data in additional_data]
+        self._correct(
+            self._correction_flag_preview,
+            image_data,
+            cell_table,
+            *additional_data,
+            *processed_buffers,
+        )
+
+        preview_buffers = self._preview_data_views(
+            image_data, *additional_data, *processed_buffers
+        )
+
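+        # preview and main output may use different sets of corrections; if they
+        # differ, run the kernel again with the flags enabled for the main output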
+        if self._correction_flag_preview != self._correction_flag_enabled:
+            processed_buffers = self._make_output_buffers(
+                num_frames, self._correction_flag_enabled
             )
-        # TODO: reuse output buffers
-        # TODO: integrate multithreading
-        res = (fun(in_buffer_view) for in_buffer_view in self.preview_data_views)
-        if self._gpu_based:
-            res = (buf.get() for buf in res)
-        return res
-
-    def reshape(self, output_order, out=None):
-        """Move axes to desired output order"""
-        # TODO: avoid copy
-        if output_order == self._corrected_axis_order:
-            self.reshaped_data = self.processed_data
-        else:
-            self.reshaped_data = self.processed_data.transpose(
-                utils.transpose_order(self._corrected_axis_order, output_order)
+            self._correct(
+                self._correction_flag_enabled,
+                image_data,
+                cell_table,
+                *additional_data,
+                *processed_buffers,
             )
-
-        if self._gpu_based:
-            return self.reshaped_data.get(out=out)
-        else:
-            if out is None:
-                return self.reshaped_data
-            else:
-                out[:] = self.reshaped_data
+        return self._correction_applied_hash, processed_buffers[0], preview_buffers
 
 
 kernel_dir = pathlib.Path(__file__).absolute().parent / "kernels"
diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py
index d97f1fc4fc3738ef37bd94fa1bb3b8ee186b3bfb..6632b6ef2da64527d633b637d265125c7050583e 100644
--- a/src/calng/corrections/AgipdCorrection.py
+++ b/src/calng/corrections/AgipdCorrection.py
@@ -9,30 +9,36 @@ from karabo.bound import (
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
+    UINT32_ELEMENT,
 )
 
 from .. import (
     base_calcat,
     base_correction,
     base_kernel_runner,
+
     schemas,
     utils,
 )
+from ..preview_utils import FrameSelectionMode, PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 
 
 class Constants(enum.Enum):
     ThresholdsDark = enum.auto()
     Offset = enum.auto()
+    SlopesCS = enum.auto()
     SlopesPC = enum.auto()
     SlopesFF = enum.auto()
     BadPixelsDark = enum.auto()
+    BadPixelsCS = enum.auto()
     BadPixelsPC = enum.auto()
     BadPixelsFF = enum.auto()
 
 
 bad_pixel_constants = {
     Constants.BadPixelsDark,
+    Constants.BadPixelsCS,
     Constants.BadPixelsPC,
     Constants.BadPixelsFF,
 }
@@ -51,146 +57,389 @@ class CorrectionFlags(enum.IntFlag):
     NONE = 0
     THRESHOLD = 1
     OFFSET = 2
-    BLSHIFT = 4
-    REL_GAIN_PC = 8
-    GAIN_XRAY = 16
-    BPMASK = 32
-    FORCE_MG_IF_BELOW = 64
-    FORCE_HG_IF_BELOW = 128
+    REL_GAIN_PC = 4
+    GAIN_XRAY = 8
+    BPMASK = 16
+    FORCE_MG_IF_BELOW = 32
+    FORCE_HG_IF_BELOW = 64
+    COMMON_MODE = 128
+
+
+correction_steps = (
+    # step name (used in schema), flag to enable for kernel, constants required
+    ("thresholding", CorrectionFlags.THRESHOLD, {Constants.ThresholdsDark}),
+    ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
+    (  # TODO: specify that /both/ constants are needed, not just one or the other
+        "forceMgIfBelow",
+        CorrectionFlags.FORCE_MG_IF_BELOW,
+        {Constants.ThresholdsDark, Constants.Offset},
+    ),
+    (
+        "forceHgIfBelow",
+        CorrectionFlags.FORCE_HG_IF_BELOW,
+        {Constants.ThresholdsDark, Constants.Offset},
+    ),
+    ("relGain", CorrectionFlags.REL_GAIN_PC, {Constants.SlopesCS, Constants.SlopesPC}),
+    ("gainXray", CorrectionFlags.GAIN_XRAY, {Constants.SlopesFF}),
+    (
+        "badPixels",
+        CorrectionFlags.BPMASK,
+        bad_pixel_constants | {
+            None,  # means stay available even without constants loaded
+        },
+    ),
+    ("commonMode", CorrectionFlags.COMMON_MODE, {None}),
+)
 
 
 class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "fxy"
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 512
+    num_pixels_fs = 128
 
-    @property
-    def input_shape(self):
-        return (self.frames, 2, self.pixels_x, self.pixels_y)
+    @classmethod
+    def add_schema(cls, schema):
+        # will add node, enable / preview toggles
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
 
-    @property
-    def processed_shape(self):
-        return (self.frames, self.pixels_x, self.pixels_y)
+        # additional settings specific to AGIPD correction steps
+        # turn off some advanced corrections by default
+        for flag in ("enable", "preview"):
+            (
+                OVERWRITE_ELEMENT(schema)
+                .key(f"corrections.commonMode.{flag}")
+                .setNewDefaultValue(False)
+                .commit(),
+            )
+            for step in ("forceMgIfBelow", "forceHgIfBelow"):
+                (
+                    OVERWRITE_ELEMENT(schema)
+                    .key(f"corrections.{step}.{flag}")
+                    .setNewDefaultValue(False)
+                    .commit(),
+                )
+        (
+            FLOAT_ELEMENT(schema)
+            .key("corrections.forceMgIfBelow.hardThreshold")
+            .tags("managed")
+            .description(
+                "If enabled, any pixels assigned to low gain stage which would be "
+                "below this threshold after having medium gain offset subtracted will "
+                "be reassigned to medium gain."
+            )
+            .assignmentOptional()
+            .defaultValue(1000)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def map_shape(self):
-        return (self.constant_memory_cells, self.pixels_x, self.pixels_y)
+            FLOAT_ELEMENT(schema)
+            .key("corrections.forceHgIfBelow.hardThreshold")
+            .tags("managed")
+            .description(
+                "Like forceMgIfBelow, but potentially reassigning from medium gain to "
+                "high gain based on threshold and pixel value minus low gain offset. "
+                "Applied after forceMgIfBelow, pixels can theoretically be reassigned "
+                "twice, from LG to MG and to HG."
+            )
+            .assignmentOptional()
+            .defaultValue(1000)
+            .reconfigurable()
+            .commit(),
+
+            STRING_ELEMENT(schema)
+            .key("corrections.relGain.sourceOfSlopes")
+            .tags("managed")
+            .displayedName("Kind of slopes")
+            .description(
+                "Slopes for relative gain correction can be derived from pulse "
+                "capacitor (PC) scans or current source scans (CS). Choose one!"
+            )
+            .assignmentOptional()
+            .defaultValue("PC")
+            .options("PC,CS")
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.relGain.adjustMgBaseline")
+            .tags("managed")
+            .displayedName("Adjust MG baseline")
+            .description(
+                "If set, an additional offset is applied to pixels in medium gain. "
+                "The value of the offset is either computed per pixel (and memory "
+                "cell) based on currently loaded relative gain slopes (PC or CS) or "
+                "it is set to a specific value when overrideMdAdditionalOffset is set."
+            )
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.relGain.overrideMdAdditionalOffset")
+            .tags("managed")
+            .displayedName("Override md_additional_offset")
+            .description(
+                "If set, the additional offset applied to medium gain pixels "
+                "(assuming adjustMgBaseLine is set) will use the value from "
+                "mdAdditionalOffset globally for. Otherwise, the additional offset "
+                "is derived from the relevant constants."
+            )
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.relGain.mdAdditionalOffset")
+            .tags("managed")
+            .displayedName("Value for md_additional_offset (if overriding)")
+            .description(
+                "Normally, md_additional_offset (part of relative gain correction) is "
+                "computed when loading SlopesPC. In case you want to use a different "
+                "value (global for all medium gain pixels), you can specify it here "
+                "and set corrections.overrideMdAdditionalOffset to True."
+            )
+            .assignmentOptional()
+            .defaultValue(0)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.gainXray.gGainValue")
+            .tags("managed")
+            .displayedName("G_gain_value")
+            .description(
+                "Newer X-ray gain correction constants are absolute. The default "
+                "G_gain_value of 1 means that output is expected to be in keV. If "
+                "this is not desired, one can here specify the mean X-ray gain value "
+                "over all modules to get ADU values out - operator must manually "
+                "find this mean value."
+            )
+            .assignmentOptional()
+            .defaultValue(1)
+            .reconfigurable()
+            .commit(),
+
+            UINT32_ELEMENT(schema)
+            .key("corrections.commonMode.iterations")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(4)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.commonMode.minFrac")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(0.15)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.commonMode.noisePeakRange")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(35)
+            .reconfigurable()
+            .commit(),
+        )
+
+    def reconfigure(self, config):
+        super().reconfigure(config)
+
+        if config.has("corrections.gainXray.gGainValue"):
+            self.g_gain_value = self._xp.float32(
+                config["corrections.gainXray.gGainValue"]
+            )
+
+        if config.has("corrections.forceHgIfBelow.hardThreshold"):
+            self.hg_hard_threshold = self._xp.float32(
+                config["corrections.forceHgIfBelow.hardThreshold"]
+            )
+
+        if config.has("corrections.forceMgIfBelow.hardThreshold"):
+            self.mg_hard_threshold = self._xp.float32(
+                config["corrections.forceMgIfBelow.hardThreshold"]
+            )
+
+        if (
+            config.has("corrections.relGain.overrideMdAdditionalOffset")
+            or config.has("corrections.relGain.mdAdditionalOffset")
+        ):
+            if (
+                self._get("corrections.relGain.overrideMdAdditionalOffset")
+                and self._get("corrections.relGain.adjustMgBaseLine")
+            ):
+                self.md_additional_offset.fill(
+                    self._xp.float32(
+                        self._get("corrections.relGain.mdAdditionalOffset")
+                    )
+                )
+            else:
+                self._device._please_send_me_cached_constants(
+                    {
+                        Constants.SlopesPC
+                        if self._get("corrections.relGain.sourceOfSlopes") == "PC"
+                        else Constants.SlopesCS
+                    },
+                    self.load_constant,
+                )
+
+        if config.has("corrections.relGain.sourceOfSlopes"):
+            self.flush_buffers(
+                {
+                    Constants.SlopesPC,
+                    Constants.SlopesCS,
+                    Constants.BadPixelsCS,
+                    Constants.BadPixelsPC,
+                }
+            )
+
+            # note: will try to reload all bad pixels, but
+            # _load_constant will reject PC/CS based on which we want
+            if self._get("corrections.relGain.sourceOfSlopes") == "PC":
+                to_get = {Constants.SlopesPC} | bad_pixel_constants
+            else:
+                to_get = {Constants.SlopesCS} | bad_pixel_constants
+
+            self._device._please_send_me_cached_constants(to_get, self.load_constant)
+
+        if config.has("corrections.commonMode"):
+            self.cm_min_frac = self._xp.float32(
+                self._get("corrections.commonMode.minFrac")
+            )
+            self.cm_noise_peak = self._xp.float32(
+                self._get("corrections.commonMode.noisePeakRange")
+            )
+            self.cm_iter = self._xp.uint16(
+                self._get("corrections.commonMode.iterations")
+            )
+
+        if config.has("constantParameters.gainMode"):
+            gain_mode = GainModes[config["constantParameters.gainMode"]]
+            if gain_mode is GainModes.ADAPTIVE_GAIN:
+                self.default_gain = self._xp.uint8(gain_mode)
+            else:
+                self.default_gain = self._xp.uint8(gain_mode - 1)
+
+    def expected_input_shape(self, num_frames):
+        return (
+            num_frames,
+            2,
+            self.num_pixels_ss,
+            self.num_pixels_fs,
+        )
 
     @property
-    def gm_map_shape(self):
-        return self.map_shape + (3,)  # for gain-mapped constants
+    def _gm_map_shape(self):
+        return self._map_shape + (3,)  # for gain-mapped constants
 
     @property
     def threshold_map_shape(self):
-        return self.map_shape + (2,)
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-        gain_mode=GainModes.ADAPTIVE_GAIN,
-        g_gain_value=1,
-        mg_hard_threshold=2000,
-        hg_hard_threshold=2000,
-        override_md_additional_offset=None,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.gain_mode = gain_mode
-        if self.gain_mode is GainModes.ADAPTIVE_GAIN:
-            self.default_gain = self._xp.uint8(gain_mode)
-        else:
-            self.default_gain = self._xp.uint8(gain_mode - 1)
-        self.gain_map = self._xp.empty(self.processed_shape, dtype=np.float32)
+        return self._map_shape + (2,)
 
+    def _setup_constant_buffers(self):
         # constants
         self.gain_thresholds = self._xp.empty(
             self.threshold_map_shape, dtype=np.float32
         )
-        self.offset_map = self._xp.empty(self.gm_map_shape, dtype=np.float32)
-        self.rel_gain_pc_map = self._xp.empty(self.gm_map_shape, dtype=np.float32)
-        # not gm_map_shape because it only applies to medium gain pixels
-        self.md_additional_offset = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_xray_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = self._xp.empty(self.gm_map_shape, dtype=np.uint32)
-        self.override_md_additional_offset(override_md_additional_offset)
-        self.set_bad_pixel_mask_value(bad_pixel_mask_value)
-        self.set_mg_hard_threshold(mg_hard_threshold)
-        self.set_hg_hard_threshold(hg_hard_threshold)
-        self.set_g_gain_value(g_gain_value)
+        self.offset_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.rel_gain_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        # not _gm_map_shape because it only applies to medium gain pixels
+        self.md_additional_offset = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.xray_gain_map = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = self._xp.empty(self._gm_map_shape, dtype=np.uint32)
         self.flush_buffers(set(Constants))
-        self.processed_data = self._xp.empty(self.processed_shape, output_data_dtype)
 
-    @property
-    def preview_data_views(self):
+    def _make_output_buffers(self, num_frames, flags):
+        output_shape = (num_frames, self.num_pixels_ss, self.num_pixels_fs)
+        return [
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # image
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # gain
+        ]
+
+    def _preview_data_views(self, raw_data, processed_data, gain_map):
         return (
-            self.input_data[:, 0],  # raw
-            self.processed_data,  # corrected
-            self.input_data[:, 1],  # raw gain
-            self.gain_map,  # digitized gain
+            raw_data[:, 0],  # raw
+            processed_data,  # corrected
+            raw_data[:, 1],  # raw gain
+            gain_map,  # digitized gain
         )
 
-    def load_constant(self, constant, constant_data):
+    def _load_constant(self, constant, constant_data):
         if constant is Constants.ThresholdsDark:
             # shape: y, x, memory cell, thresholds and gain values
             # note: the gain values are something like means used to derive thresholds
             self.gain_thresholds[:] = self._xp.asarray(
                 constant_data[..., :2], dtype=np.float32
-            ).transpose((2, 1, 0, 3))[:self.constant_memory_cells]
+            ).transpose((2, 1, 0, 3))[:self._constant_memory_cells]
         elif constant is Constants.Offset:
             # shape: y, x, memory cell, gain stage
             self.offset_map[:] = self._xp.asarray(
                 constant_data, dtype=np.float32
-            ).transpose((2, 1, 0, 3))[:self.constant_memory_cells]
+            ).transpose((2, 1, 0, 3))[:self._constant_memory_cells]
+        elif constant is Constants.SlopesCS:
+            if self._get("corrections.relGain.sourceOfSlopes") == "PC":
+                return
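+            # build the per-gain-stage relative gain map multiplicatively: HG stays at
+            # 1, MG and LG are scaled by entries 6 and 7 of the CS constant in turn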
+            self.rel_gain_map.fill(1)
+            self.rel_gain_map[..., 1] = self.rel_gain_map[..., 0] * self._xp.asarray(
+                constant_data[..., :self._constant_memory_cells, 6]
+            ).T
+            self.rel_gain_map[..., 2] = self.rel_gain_map[..., 1] * self._xp.asarray(
+                constant_data[..., :self._constant_memory_cells, 7]
+            ).T
+            if self._get("corrections.relGain.adjustMgBaseline"):
+                self.md_additional_offset.fill(
+                    self._get("corrections.relGain.mdAdditionalOffset")
+                )
+            else:
+                self.md_additional_offset.fill(0)
         elif constant is Constants.SlopesPC:
-            # pc has funny shape (11, 352, 128, 512) from file
-            # this is (fi, memory cell, y, x)
-            # the following may contain NaNs, though...
+            if self._get("corrections.relGain.sourceOfSlopes") == "CS":
+                return
+            # pc has funny shape (11, 352, 128, 512) from file (fi, memory cell, y, x)
+            # assume newer variant (TODO: check) which is pre-sanitized
             hg_slope = constant_data[0]
             hg_intercept = constant_data[1]
             mg_slope = constant_data[3]
             mg_intercept = constant_data[4]
-            # TODO: remove sanitization (should happen in constant preparation notebook)
-            # from agipdlib.py: replace NaN with median (per memory cell)
-            # note: suffixes in agipdlib are "_m" and "_l", should probably be "_I"
-            for naughty_array in (hg_slope, hg_intercept, mg_slope, mg_intercept):
-                medians = np.nanmedian(naughty_array, axis=(1, 2))
-                nan_bool = np.isnan(naughty_array)
-                nan_cell, _, _ = np.where(nan_bool)
-                naughty_array[nan_bool] = medians[nan_cell]
-
-                too_low_bool = naughty_array < 0.8 * medians[:, np.newaxis, np.newaxis]
-                too_low_cell, _, _ = np.where(too_low_bool)
-                naughty_array[too_low_bool] = medians[too_low_cell]
-
-                too_high_bool = naughty_array > 1.2 * medians[:, np.newaxis, np.newaxis]
-                too_high_cell, _, _ = np.where(too_high_bool)
-                naughty_array[too_high_bool] = medians[too_high_cell]
+
+            hg_slope_median = np.nanmedian(hg_slope, axis=(1, 2))
+            mg_slope_median = np.nanmedian(mg_slope, axis=(1, 2))
 
             frac_hg_mg = hg_slope / mg_slope
             rel_gain_map = np.ones(
-                (3, self.constant_memory_cells, self.pixels_y, self.pixels_x),
+                (
+                    3,
+                    self._constant_memory_cells,
+                    self.num_pixels_fs,
+                    self.num_pixels_ss,
+                ),
                 dtype=np.float32,
             )
             rel_gain_map[1] = rel_gain_map[0] * frac_hg_mg
             rel_gain_map[2] = rel_gain_map[1] * 4.48
-            self.rel_gain_pc_map[:] = self._xp.asarray(
+            self.rel_gain_map[:] = self._xp.asarray(
                 rel_gain_map.transpose((1, 3, 2, 0)), dtype=np.float32
-            )[:self.constant_memory_cells]
-            if self._md_additional_offset_value is None:
-                md_additional_offset = (
-                    hg_intercept - mg_intercept * frac_hg_mg
-                ).astype(np.float32)
-                self.md_additional_offset[:] = self._xp.asarray(
-                    md_additional_offset.transpose((0, 2, 1)), dtype=np.float32
-                )[:self.constant_memory_cells]
+            )[:self._constant_memory_cells]
+            if self._get("corrections.relGain.adjustMgBaseline"):
+                if self._get("corrections.relGain.overrideMdAdditionalOffset"):
+                    self.md_additional_offset.fill(
+                        self._get("corrections.relGain.mdAdditionalOffset")
+                    )
+                else:
+                    self.md_additional_offset[:] = self._xp.asarray(
+                        (
+                            hg_intercept - mg_intercept * frac_hg_mg
+                        ).astype(np.float32).transpose((0, 2, 1)), dtype=np.float32
+                    )[:self._constant_memory_cells]
+            else:
+                self.md_additional_offset.fill(0)
         elif constant is Constants.SlopesFF:
             # constant shape: y, x, memory cell
             if constant_data.shape[2] == 2:
@@ -199,35 +448,46 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
                 # note: we should not support this in online
                 constant_data = self._xp.broadcast_to(
                     constant_data[..., 0][..., np.newaxis],
-                    (self.pixels_y, self.pixels_x, self.constant_memory_cells),
+                    (
+                        self.num_pixels_fs,
+                        self.num_pixels_ss,
+                        self._constant_memory_cells,
+                    ),
                 )
-            self.rel_gain_xray_map[:] = self._xp.asarray(
+            self.xray_gain_map[:] = self._xp.asarray(
                 constant_data.transpose(), dtype=np.float32
-            )[:self.constant_memory_cells]
+            )[:self._constant_memory_cells]
         else:
             assert constant in bad_pixel_constants
+            if (
+                constant is Constants.BadPixelsCS
+                and self._get("corrections.relGain.sourceOfSlopes") == "PC"
+                or constant is Constants.BadPixelsPC
+                and self._get("corrections.relGain.sourceOfSlopes") == "CS"
+            ):
+                return
             # will simply OR with already loaded, does not take into account which ones
             constant_data = self._xp.asarray(constant_data, dtype=np.uint32)
             if len(constant_data.shape) == 3:
                 if constant_data.shape == (
-                    self.pixels_y,
-                    self.pixels_x,
-                    self.constant_memory_cells,
+                    self.num_pixels_fs,
+                    self.num_pixels_ss,
+                    self._constant_memory_cells,
                 ):
                     # BadPixelsFF is not per gain stage - broadcasting along gain
                     constant_data = self._xp.broadcast_to(
                         constant_data.transpose()[..., np.newaxis],
-                        self.gm_map_shape,
+                        self._gm_map_shape,
                     )
                 elif constant_data.shape == (
-                    self.constant_memory_cells,
-                    self.pixels_y,
-                    self.pixels_x,
+                    self._constant_memory_cells,
+                    self.num_pixels_fs,
+                    self.num_pixels_ss,
                 ):
                     # old BadPixelsPC have different axis order
                     constant_data = self._xp.broadcast_to(
                         constant_data.transpose((0, 2, 1))[..., np.newaxis],
-                        self.gm_map_shape,
+                        self._gm_map_shape,
                     )
                 else:
                     raise ValueError(
@@ -236,142 +496,126 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
             else:
                 # gain mapped constants seem consistent
                 constant_data = constant_data.transpose((2, 1, 0, 3))
-            self.bad_pixel_map |= constant_data[:self.constant_memory_cells]
-
-    def override_md_additional_offset(self, override_value):
-        self._md_additional_offset_value = override_value
-        if override_value is not None:
-            self.md_additional_offset.fill(override_value)
-
-    def set_g_gain_value(self, override_value):
-        self.g_gain_value = self._xp.float32(override_value)
-
-    def set_bad_pixel_mask_value(self, mask_value):
-        self.bad_pixel_mask_value = self._xp.float32(mask_value)
-
-    def set_mg_hard_threshold(self, value):
-        self.mg_hard_threshold = self._xp.float32(value)
-
-    def set_hg_hard_threshold(self, value):
-        self.hg_hard_threshold = self._xp.float32(value)
+            constant_data &= self._xp.uint32(self.bad_pixel_subset)
+            self.bad_pixel_map |= constant_data[:self._constant_memory_cells]
 
     def flush_buffers(self, constants):
         if Constants.Offset in constants:
             self.offset_map.fill(0)
-        if Constants.SlopesPC in constants:
-            self.rel_gain_pc_map.fill(1)
+        if constants & {Constants.SlopesCS, Constants.SlopesPC}:
+            self.rel_gain_map.fill(1)
             self.md_additional_offset.fill(0)
         if Constants.SlopesFF:
-            self.rel_gain_xray_map.fill(1)
+            self.xray_gain_map.fill(1)
         if constants & bad_pixel_constants:
             self.bad_pixel_map.fill(0)
-            self.bad_pixel_map[
-                :, 64:512:64
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
-            self.bad_pixel_map[
-                :, 63:511:64
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+            if self.bad_pixel_subset & utils.BadPixelValues.NON_STANDARD_SIZE:
+                self._mask_asic_seams()
+
+    def _mask_asic_seams(self):
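+        # flag the columns on either side of each 64-pixel ASIC seam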
+        self.bad_pixel_map[:, 64:512:64] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+        self.bad_pixel_map[:, 63:511:64] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
 
 
 class AgipdCpuRunner(AgipdBaseRunner):
     _xp = np
-    _gpu_based = False
 
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        if flags & CorrectionFlags.COMMON_MODE:
+            raise NotImplementedError("Common mode not available on CPU yet")
         self.correction_kernel(
-            self.input_data,
-            self.cell_table,
+            image_data,
+            cell_table,
             flags,
             self.default_gain,
             self.gain_thresholds,
             self.mg_hard_threshold,
             self.hg_hard_threshold,
             self.offset_map,
-            self.rel_gain_pc_map,
+            self.rel_gain_map,
             self.md_additional_offset,
-            self.rel_gain_xray_map,
+            self.xray_gain_map,
             self.g_gain_value,
             self.bad_pixel_map,
             self.bad_pixel_mask_value,
-            self.processed_data,
+            processed_data,
         )
 
-    def _post_init(self):
-        self.input_data = None
-        self.cell_table = None
-        # NOTE: CPU kernel does not yet support anything other than float32
-        self.processed_data = self._xp.empty(
-            self.processed_shape, dtype=self.output_data_dtype
-        )
+    def _pre_init(self):
         from ..kernels import agipd_cython
         self.correction_kernel = agipd_cython.correct
 
-    def load_data(self, image_data, cell_table):
-        self.input_data = image_data
-        self.cell_table = cell_table
-
 
 class AgipdGpuRunner(AgipdBaseRunner):
-    _gpu_based = True
-
     def _pre_init(self):
-        import cupy as cp
-
-        self._xp = cp
+        import cupy
+        self._xp = cupy
 
     def _post_init(self):
-        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = self._xp.empty(self.frames, dtype=np.uint16)
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.processed_shape, self.block_shape
-        )
-        self.correction_kernel = self._xp.RawModule(
+        self.correction_kernel = self._xp.RawKernel(
             code=base_kernel_runner.get_kernel_template("agipd_gpu.cu").render(
-                {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
-                    "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
-                    ),
-                    "corr_enum": utils.enum_to_c_template(CorrectionFlags),
-                }
-            )
-        ).get_function("correct")
-
-    def load_data(self, image_data, cell_table):
-        self.input_data.set(image_data)
-        self.cell_table.set(cell_table)
+                pixels_x=self.num_pixels_ss,
+                pixels_y=self.num_pixels_fs,
+                output_data_dtype=utils.np_dtype_to_c_type(self._output_dtype),
+                corr_enum=utils.enum_to_c_template(CorrectionFlags),
+            ),
+            name="correct",
+        )
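+        # common mode kernel operates per 64x64 ASIC tile; see the launch configuration in _correct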
+        self.cm_kernel = self._xp.RawKernel(
+            code=base_kernel_runner.get_kernel_template("common_gpu.cu").render(
+                ss_dim=self.num_pixels_ss,
+                fs_dim=self.num_pixels_fs,
+                asic_dim=64,
+            ),
+            name="common_mode_asic",
+        )
 
-    def correct(self, flags):
-        if flags & CorrectionFlags.BLSHIFT:
-            raise NotImplementedError("Baseline shift not implemented yet")
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        num_frames = self._xp.uint16(image_data.shape[0])
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(
+            processed_data.shape, block_shape
+        )
 
         self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
+            grid_shape,
+            block_shape,
             (
-                self.input_data,
-                self.cell_table,
-                self._xp.uint8(flags),
+                image_data,
+                cell_table,
+                flags,
+                num_frames,
+                self._constant_memory_cells,
                 self.default_gain,
                 self.gain_thresholds,
                 self.mg_hard_threshold,
                 self.hg_hard_threshold,
                 self.offset_map,
-                self.rel_gain_pc_map,
+                self.rel_gain_map,
                 self.md_additional_offset,
-                self.rel_gain_xray_map,
+                self.xray_gain_map,
                 self.g_gain_value,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self.gain_map,
-                self.processed_data,
+                gain_map,
+                processed_data,
             ),
         )
 
+        if flags & CorrectionFlags.COMMON_MODE:
+            # grid: one block per ss asic, fs asic, and frame
+            self.cm_kernel(
+                (2, 8, processed_data.shape[0]),
+                (64, 1, 1),
+                (
+                    processed_data,
+                    num_frames,
+                    self.cm_iter,
+                    self.cm_min_frac,
+                    self.cm_noise_peak,
+                ),
+            )
+
 
 class AgipdCalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
@@ -381,9 +625,11 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend):
         return {
             Constants.ThresholdsDark: self.dark_condition,
             Constants.Offset: self.dark_condition,
+            Constants.SlopesCS: self.dark_condition,
             Constants.SlopesPC: self.dark_condition,
             Constants.SlopesFF: self.illuminated_condition,
             Constants.BadPixelsDark: self.dark_condition,
+            Constants.BadPixelsCS: self.dark_condition,
             Constants.BadPixelsPC: self.dark_condition,
             Constants.BadPixelsFF: self.illuminated_condition,
         }
@@ -489,69 +735,34 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend):
 @KARABO_CLASSINFO("AgipdCorrection", deviceVersion)
 class AgipdCorrection(base_correction.BaseCorrection):
     # subclass *must* set these attributes
-    _base_output_schema = schemas.xtdf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        # step name (used in schema), flag to enable for kernel, constants required
-        ("thresholding", CorrectionFlags.THRESHOLD, {Constants.ThresholdsDark}),
-        ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
-        (  # TODO: specify that /both/ constants are needed, not just one or the other
-            "forceMgIfBelow",
-            CorrectionFlags.FORCE_MG_IF_BELOW,
-            {Constants.ThresholdsDark, Constants.Offset},
-        ),
-        (
-            "forceHgIfBelow",
-            CorrectionFlags.FORCE_HG_IF_BELOW,
-            {Constants.ThresholdsDark, Constants.Offset},
-        ),
-        ("relGainPc", CorrectionFlags.REL_GAIN_PC, {Constants.SlopesPC}),
-        ("gainXray", CorrectionFlags.GAIN_XRAY, {Constants.SlopesFF}),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            bad_pixel_constants | {
-                None,  # means stay available even without constants loaded
-            },
-        ),
-    )
+    _base_output_schema = schemas.xtdf_output_schema
     _calcat_friend_class = AgipdCalcatFriend
     _constant_enum_class = Constants
+    _correction_steps = correction_steps
+
     _preview_outputs = [
-        "outputRaw",
-        "outputCorrected",
-        "outputRawGain",
-        "outputGainMap",
+        PreviewSpec(name, frame_reduction_selection_mode=FrameSelectionMode.CELL)
+        for name in ("raw", "corrected", "rawGain", "gainMap")
     ]
 
-    @staticmethod
-    def expectedParameters(expected):
+    @classmethod
+    def expectedParameters(cls, expected):
+        # not done automatically by the superclass for complicated class-hierarchy reasons
+        cls._calcat_friend_class.add_schema(expected)
+        AgipdBaseRunner.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(AgipdCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(352)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("cell")
+            .dataSchema(
+                AgipdCorrection._base_output_schema(
+                    use_shmem_handles=cls._use_shmem_handles
+                )
+            )
             .commit(),
         )
 
-        # this is not automatically done by superclass for complicated class reasons
-        AgipdCalcatFriend.add_schema(expected)
-        base_correction.add_addon_nodes(expected, AgipdCorrection)
-        base_correction.add_correction_step_schema(
-            expected,
-            AgipdCorrection._correction_steps,
-        )
-        base_correction.add_bad_pixel_config_node(expected)
-        base_correction.add_preview_outputs(expected, AgipdCorrection._preview_outputs)
         (
             # support both CPU and GPU kernels
             STRING_ELEMENT(expected)
@@ -569,130 +780,6 @@ class AgipdCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-        # turn off the force MG / HG steps by default
-        for step in ("forceMgIfBelow", "forceHgIfBelow"):
-            for flag in ("enable", "preview"):
-                (
-                    OVERWRITE_ELEMENT(expected)
-                    .key(f"corrections.{step}.{flag}")
-                    .setNewDefaultValue(False)
-                    .commit(),
-                )
-        # additional settings specific to AGIPD correction steps
-        (
-            FLOAT_ELEMENT(expected)
-            .key("corrections.forceMgIfBelow.hardThreshold")
-            .tags("managed")
-            .description(
-                "If enabled, any pixels assigned to low gain stage which would be "
-                "below this threshold after having medium gain offset subtracted will "
-                "be reassigned to medium gain."
-            )
-            .assignmentOptional()
-            .defaultValue(1000)
-            .reconfigurable()
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("corrections.forceHgIfBelow.hardThreshold")
-            .tags("managed")
-            .description(
-                "Like forceMgIfBelow, but potentially reassigning from medium gain to "
-                "high gain based on threshold and pixel value minus low gain offset. "
-                "Applied after forceMgIfBelow, pixels can theoretically be reassigned "
-                "twice, from LG to MG and to HG."
-            )
-            .assignmentOptional()
-            .defaultValue(1000)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.relGainPc.overrideMdAdditionalOffset")
-            .tags("managed")
-            .displayedName("Override md_additional_offset")
-            .description(
-                "Toggling this on will use the value in the next field globally for "
-                "md_additional_offset. Note that the correction map on GPU gets "
-                "overwritten as long as this boolean is True, so reload constants "
-                "after turning off."
-            )
-            .assignmentOptional()
-            .defaultValue(False)
-            .reconfigurable()
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("corrections.relGainPc.mdAdditionalOffset")
-            .tags("managed")
-            .displayedName("Value for md_additional_offset (if overriding)")
-            .description(
-                "Normally, md_additional_offset (part of relative gain correction) is "
-                "computed when loading SlopesPC. In case you want to use a different "
-                "value (global for all medium gain pixels), you can specify it here "
-                "and set corrections.overrideMdAdditionalOffset to True."
-            )
-            .assignmentOptional()
-            .defaultValue(0)
-            .reconfigurable()
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("corrections.gainXray.gGainValue")
-            .tags("managed")
-            .displayedName("G_gain_value")
-            .description(
-                "Newer X-ray gain correction constants are absolute. The default "
-                "G_gain_value of 1 means that output is expected to be in keV. If "
-                "this is not desired, one can here specify the mean X-ray gain value "
-                "over all modules to get ADU values out - operator must manually "
-                "find this mean value."
-            )
-            .assignmentOptional()
-            .defaultValue(1)
-            .reconfigurable()
-            .commit(),
-        )
-
-    @property
-    def input_data_shape(self):
-        return (
-            self.unsafe_get("dataFormat.frames"),
-            2,
-            self.unsafe_get("dataFormat.pixelsX"),
-            self.unsafe_get("dataFormat.pixelsY"),
-        )
-
-    def __init__(self, config):
-        super().__init__(config)
-        # note: gain mode single sourced from constant retrieval node
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-        self._has_updated_bad_pixel_selection = False
-
-    @property
-    def bad_pixel_mask_value(self):
-        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
-
-    _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection)
-
-    @property
-    def _kernel_runner_init_args(self):
-        return {
-            "gain_mode": self.gain_mode,
-            "bad_pixel_mask_value": self.bad_pixel_mask_value,
-            "g_gain_value": self.unsafe_get("corrections.gainXray.gGainValue"),
-            "mg_hard_threshold": self.unsafe_get(
-                "corrections.forceMgIfBelow.hardThreshold"
-            ),
-            "hg_hard_threshold": self.unsafe_get(
-                "corrections.forceHgIfBelow.hardThreshold"
-            ),
-        }
-
     @property
     def _kernel_runner_class(self):
         kernel_type = base_kernel_runner.KernelRunnerTypes[
@@ -702,159 +789,3 @@ class AgipdCorrection(base_correction.BaseCorrection):
             return AgipdCpuRunner
         else:
             return AgipdGpuRunner
-
-    @property
-    def gain_mode(self):
-        return GainModes[self.unsafe_get("constantParameters.gainMode")]
-
-    @property
-    def _override_md_additional_offset(self):
-        if self.unsafe_get("corrections.relGainPc.overrideMdAdditionalOffset"):
-            return self.unsafe_get("corrections.relGainPc.mdAdditionalOffset")
-        else:
-            return None
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        """Called by input_handler for each data hash. Should correct data, optionally
-        compute preview, write data output, and optionally write preview outputs."""
-        # original shape: frame, data/raw_gain, x, y
-        if self._frame_filter is not None:
-            try:
-                cell_table = cell_table[self._frame_filter]
-                pulse_table = pulse_table[self._frame_filter]
-                image_data = image_data[self._frame_filter]
-            except IndexError:
-                self.log_status_warn(
-                    "Failed to apply frame filter, please check that it is valid!"
-                )
-                return
-
-        try:
-            self.kernel_runner.load_data(image_data, cell_table.ravel())
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            self.log_status_warn(
-                f"For debugging - cell table: {type(cell_table)}, {cell_table}"
-            )
-            self.log_status_warn(
-                f"For debugging - data: {type(image_data)}, {image_data.shape}"
-            )
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        # first prepare previews (not affected by addons)
-        self.kernel_runner.correct(self._correction_flag_preview)
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-        (
-            preview_raw,
-            preview_corrected,
-            preview_raw_gain,
-            preview_gain_map,
-        ) = self.kernel_runner.compute_previews(preview_slice_index)
-        # TODO: start writing out previews asynchronously in the background
-
-        # then prepare full corrected data
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        if self._correction_flag_enabled != self._correction_flag_preview:
-            self.kernel_runner.correct(self._correction_flag_enabled)
-        for addon in self._enabled_addons:
-            addon.post_correction(
-                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
-            )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        for addon in self._enabled_addons:
-            addon.post_reshape(
-                buffer_array, cell_table, pulse_table, data_hash
-            )
-
-        # reusing input data hash for sending
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        data_hash.set(self._cell_table_path, cell_table[:, np.newaxis])
-        data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis])
-
-        self._write_output(data_hash, metadata)
-        self._preview_friend.write_outputs(
-            metadata, preview_raw, preview_corrected, preview_raw_gain, preview_gain_map
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        if constant in bad_pixel_constants:
-            constant_data &= self._override_bad_pixel_flags
-        self.kernel_runner.load_constant(constant, constant_data)
-
-    def preReconfigure(self, config):
-        super().preReconfigure(config)
-        if config.has("corrections.badPixels.maskingValue"):
-            # only check if it is valid; postReconfigure will use it
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-
-    def postReconfigure(self):
-        super().postReconfigure()
-
-        update = self._prereconfigure_update_hash
-
-        if update.has("constantParameters.gainMode"):
-            self.flush_constants()
-            self._update_buffers()
-
-        if update.has("corrections.forceMgIfBelow.hardThreshold"):
-            self.kernel_runner.set_mg_hard_threshold(
-                self.get("corrections.forceMgIfBelow.hardThreshold")
-            )
-
-        if update.has("corrections.forceHgIfBelow.hardThreshold"):
-            self.kernel_runner.set_hg_hard_threshold(
-                self.get("corrections.forceHgIfBelow.hardThreshold")
-            )
-
-        if update.has("corrections.gainXray.gGainValue"):
-            self.kernel_runner.set_g_gain_value(
-                self.get("corrections.gainXray.gGainValue")
-            )
-
-        if update.has("corrections.badPixels.maskingValue"):
-            self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value)
-
-        if update.has("corrections.badPixels.subsetToUse"):
-            self.log_status_info("Updating bad pixel maps based on subset specified")
-            # note: now just always reloading from cache for convenience
-            with self.calcat_friend.cached_constants_lock:
-                self.kernel_runner.flush_buffers(bad_pixel_constants)
-                for constant in bad_pixel_constants:
-                    if constant in self.calcat_friend.cached_constants:
-                        self._load_constant_to_runner(
-                            constant, self.calcat_friend.cached_constants[constant]
-                        )
diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py
index 9a6cc613b0360c93249d47aed841cdce85045869..3611b26721e42ac4f500a7940e3bea9489cba4d8 100644
--- a/src/calng/corrections/DsscCorrection.py
+++ b/src/calng/corrections/DsscCorrection.py
@@ -8,7 +8,14 @@ from karabo.bound import (
     STRING_ELEMENT,
 )
 
-from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils
+from .. import (
+    base_calcat,
+    base_correction,
+    base_kernel_runner,
+    schemas,
+    utils,
+)
+from ..preview_utils import FrameSelectionMode, PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 
 
@@ -21,131 +28,98 @@ class Constants(enum.Enum):
     Offset = enum.auto()
 
 
-class DsscBaseRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "fyx"
-    _xp = None
+correction_steps = (("offset", CorrectionFlags.OFFSET, {Constants.Offset}),)
 
-    @property
-    def map_shape(self):
-        return (self.constant_memory_cells, self.pixels_y, self.pixels_x)
 
-    @property
-    def input_shape(self):
-        return (self.frames, self.pixels_y, self.pixels_x)
+class DsscBaseRunner(base_kernel_runner.BaseKernelRunner):
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 128
+    num_pixels_fs = 512
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.offset_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.processed_data = self._xp.empty(
-            self.processed_shape, dtype=output_data_dtype
+    def expected_input_shape(self, num_frames):
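+        # raw DSSC data arrives as (frames, 1, slow scan, fast scan)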
+        return (
+            num_frames,
+            1,
+            self.num_pixels_ss,
+            self.num_pixels_fs,
         )
 
+    def _setup_constant_buffers(self):
+        self.offset_map = self._xp.empty(self._map_shape, dtype=np.float32)
+
     def flush_buffers(self, constants):
         if Constants.Offset in constants:
             self.offset_map.fill(0)
 
-    @property
-    def preview_data_views(self):
-        return (
-            self.input_data,
-            self.processed_data,
-        )
-
-    def load_offset_map(self, offset_map):
+    def _load_constant(self, constant_type, data):
+        assert constant_type is Constants.Offset
         # can have an extra dimension for some reason
-        if len(offset_map.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
-            offset_map = offset_map[..., 0]
+        if len(data.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
+            data = data[..., 0]
         # shape (now): x, y, memory cell
-        offset_map = self._xp.asarray(offset_map, dtype=np.float32).transpose()
-        self.offset_map[:] = offset_map[:self.constant_memory_cells]
+        data = self._xp.asarray(data, dtype=np.float32).transpose()
+        self.offset_map[:] = data[:self._constant_memory_cells]
+
+    def _preview_data_views(self, raw_data, processed_data):
+        return (
+            raw_data[:, 0],
+            processed_data,
+        )
 
 
 class DsscCpuRunner(DsscBaseRunner):
-    _gpu_based = False
     _xp = np
 
     def _post_init(self):
-        self.input_data = None
-        self.cell_table = None
         from ..kernels import dssc_cython
-
         self.correction_kernel = dssc_cython.correct
 
-    def load_data(self, image_data, cell_table):
-        self.input_data = image_data.astype(np.uint16, copy=False)
-        self.cell_table = cell_table.astype(np.uint16, copy=False)
-
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, output):
         self.correction_kernel(
-            self.input_data,
-            self.cell_table,
+            image_data,
+            cell_table,
             flags,
             self.offset_map,
-            self.processed_data,
+            output,
         )
 
 
 class DsscGpuRunner(DsscBaseRunner):
-    _gpu_based = True
-
     def _pre_init(self):
-        import cupy as cp
-
-        self._xp = cp
+        import cupy
+        self._xp = cupy
 
     def _post_init(self):
-        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = self._xp.empty(self.frames, dtype=np.uint16)
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.input_shape, self.block_shape
-        )
         self.correction_kernel = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("dssc_gpu.cu").render(
                 {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
+                    "pixels_x": self.num_pixels_fs,
+                    "pixels_y": self.num_pixels_ss,
                     "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
+                        self._output_dtype
                     ),
                     "corr_enum": utils.enum_to_c_template(CorrectionFlags),
                 }
             )
         ).get_function("correct")
 
-    def load_data(self, image_data, cell_table):
-        self.input_data.set(image_data)
-        self.cell_table.set(cell_table)
-
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, output):
+        num_frames = self._xp.uint16(image_data.shape[0])
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(output.shape, block_shape)
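+        # drop the singleton dimension before handing the data to the kernel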
+        image_data = image_data[:, 0]
         self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
+            grid_shape,
+            block_shape,
             (
-                self.input_data,
-                self.cell_table,
-                np.uint8(flags),
+                image_data,
+                cell_table,
+                num_frames,
+                self._constant_memory_cells,
+                flags,
                 self.offset_map,
-                self.processed_data,
+                output,
             ),
         )
 
@@ -185,44 +159,30 @@ class DsscCalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("DsscCorrection", deviceVersion)
 class DsscCorrection(base_correction.BaseCorrection):
-    # subclass *must* set these attributes
-    _base_output_schema = schemas.xtdf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (("offset", CorrectionFlags.OFFSET, {Constants.Offset}),)
+    _base_output_schema = schemas.xtdf_output_schema
+    _correction_steps = correction_steps
     _calcat_friend_class = DsscCalcatFriend
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected"]
 
-    @staticmethod
-    def expectedParameters(expected):
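+    # frames are selected by pulse ID, matching the previous preview.selectionMode default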
+    _preview_outputs = [
+        PreviewSpec(name, frame_reduction_selection_mode=FrameSelectionMode.PULSE)
+        for name in ("raw", "corrected")
+    ]
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        DsscBaseRunner.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(DsscCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(400)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.outputAxisOrder")
-            .setNewDefaultValue("fyx")
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("pulse")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
         )
-        DsscCalcatFriend.add_schema(expected)
-        base_correction.add_addon_nodes(expected, DsscCorrection)
-        base_correction.add_correction_step_schema(
-            expected,
-            DsscCorrection._correction_steps,
-        )
-        base_correction.add_preview_outputs(expected, DsscCorrection._preview_outputs)
         (
             # support both CPU and GPU kernels
             STRING_ELEMENT(expected)
@@ -240,15 +200,6 @@ class DsscCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-    @property
-    def input_data_shape(self):
-        return (
-            self.get("dataFormat.frames"),
-            1,
-            self.get("dataFormat.pixelsY"),
-            self.get("dataFormat.pixelsX"),
-        )
-
     @property
     def _kernel_runner_class(self):
         kernel_type = base_kernel_runner.KernelRunnerTypes[
@@ -258,80 +209,3 @@ class DsscCorrection(base_correction.BaseCorrection):
             return DsscCpuRunner
         else:
             return DsscGpuRunner
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        if self._frame_filter is not None:
-            try:
-                cell_table = cell_table[self._frame_filter]
-                pulse_table = pulse_table[self._frame_filter]
-                image_data = image_data[self._frame_filter]
-            except IndexError:
-                self.log_status_warn(
-                    "Failed to apply frame filter, please check that it is valid!"
-                )
-                return
-
-        try:
-            # drop the singleton dimension for kernel runner
-            self.kernel_runner.load_data(image_data[:, 0], cell_table.ravel())
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled)
-        for addon in self._enabled_addons:
-            addon.post_correction(
-                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
-            )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-            preview_raw, preview_corrected = self.kernel_runner.compute_previews(
-                preview_slice_index,
-            )
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        data_hash.set(self._cell_table_path, cell_table[:, np.newaxis])
-        data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis])
-
-        self._write_output(data_hash, metadata)
-        self._preview_friend.write_outputs(metadata, preview_raw, preview_corrected)
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_offset_map(constant_data)
diff --git a/src/calng/corrections/Epix100Correction.py b/src/calng/corrections/Epix100Correction.py
index 5bd981bb7839ed41dff0135daccb47d1b2335e74..b541a72af5a96d70cd320140026284e0bfb40ace 100644
--- a/src/calng/corrections/Epix100Correction.py
+++ b/src/calng/corrections/Epix100Correction.py
@@ -1,6 +1,5 @@
 import concurrent.futures
 import enum
-import functools
 
 import numpy as np
 from karabo.bound import (
@@ -19,6 +18,7 @@ from .. import (
     utils,
 )
 from .._version import version as deviceVersion
+from ..preview_utils import PreviewFriend, PreviewSpec
 from ..kernels import common_cython
 
 
@@ -37,6 +37,14 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 2**4
 
 
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.OffsetEPix100}),
+    ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainEPix100}),
+    ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseEPix100}),
+    ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkEPix100}),
+)
+
+
 class Epix100CalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
 
@@ -138,84 +146,119 @@ class Epix100CalcatFriend(base_calcat.BaseCalcatFriend):
 
 
 class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "xy"
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
+    _bad_pixel_constants = {Constants.BadPixelsDarkEPix100}
     _xp = np
-    _gpu_based = False
 
-    @property
-    def input_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    num_pixels_ss = 708
+    num_pixels_fs = 768
 
-    @property
-    def preview_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
+        (
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.noiseSigma")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(5)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.minFrac")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(0.25)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def map_shape(self):
-        return (self.pixels_x, self.pixels_y)
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableRow")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,  # will be 1, will be ignored
-        constant_memory_cells,
-        input_data_dtype=np.uint16,
-        output_data_dtype=np.float32,
-    ):
-        assert (
-            output_data_dtype == np.float32
-        ), "Alternative output types not supported yet"
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableCol")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableBlock")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
 
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
+    def reconfigure(self, config):
+        super().reconfigure(config)
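+        # cache common mode parameters here so the correction path avoids per-train config lookups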
+        if config.has("corrections.commonMode.noiseSigma"):
+            self._cm_sigma = np.float32(config["corrections.commonMode.noiseSigma"])
+        if config.has("corrections.commonMode.minFrac"):
+            self._cm_minfrac = np.float32(config["corrections.commonMode.minFrac"])
+        if config.has("corrections.commonMode.enableRow"):
+            self._cm_row = config["corrections.commonMode.enableRow"]
+        if config.has("corrections.commonMode.enableCol"):
+            self._cm_col = config["corrections.commonMode.enableCol"]
+        if config.has("corrections.commonMode.enableBlock"):
+            self._cm_block = config["corrections.commonMode.enableBlock"]
+
+    def expected_input_shape(self, num_frames):
+        assert num_frames == 1
+        return (
+            self.num_pixels_ss,
+            self.num_pixels_fs,
         )
 
-        self.input_data = None
-        self.processed_data = np.empty(self.processed_shape, dtype=np.float32)
+    def _expected_output_shape(self, num_frames):
+        return (self.num_pixels_ss, self.num_pixels_fs)
 
-        self.offset_map = np.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32)
-        self.noise_map = np.empty(self.map_shape, dtype=np.float32)
+    @property
+    def _map_shape(self):
+        return (self.num_pixels_ss, self.num_pixels_fs)
+
+    def _make_output_buffers(self, num_frames, flags):
+        # num_frames and flags are ignored; ePix100 always produces a single 2D frame
+        return [
+            self._xp.empty(
+                (self.num_pixels_ss, self.num_pixels_fs),
+                dtype=self._xp.float32,
+            )
+        ]
+
+    def _setup_constant_buffers(self):
+        assert (
+            self._output_dtype == np.float32
+        ), "Alternative output types not supported yet"
+
+        self.offset_map = np.empty(self._map_shape, dtype=np.float32)
+        self.rel_gain_map = np.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = np.empty(self._map_shape, dtype=np.uint32)
+        self.noise_map = np.empty(self._map_shape, dtype=np.float32)
         # will do everything by quadrant
         self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
-        self._q_input_data = None
-        self._q_processed_data = utils.quadrant_views(self.processed_data)
         self._q_offset_map = utils.quadrant_views(self.offset_map)
         self._q_rel_gain_map = utils.quadrant_views(self.rel_gain_map)
         self._q_bad_pixel_map = utils.quadrant_views(self.bad_pixel_map)
         self._q_noise_map = utils.quadrant_views(self.noise_map)
 
-    def __del__(self):
-        self.thread_pool.shutdown()
-
-    @property
-    def preview_data_views(self):
-        # NOTE: apparently there are "calibration rows"
-        # geometry assumes these are cut out already
-        return (self.input_data[2:-2], self.processed_data[2:-2])
-
-    def load_data(self, image_data):
-        # should almost never squeeze, but ePix doesn't do burst mode, right?
-        self.input_data = image_data.astype(np.uint16, copy=False).squeeze()
-        self._q_input_data = utils.quadrant_views(self.input_data)
-
-    def load_constant(self, constant_type, data):
+    def _load_constant(self, constant_type, data):
         if constant_type is Constants.OffsetEPix100:
             self.offset_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.RelativeGainEPix100:
             self.rel_gain_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.BadPixelsDarkEPix100:
-            self.bad_pixel_map[:] = data.squeeze()
+            self.bad_pixel_map[:] = (data.squeeze() & self.bad_pixel_subset)
         elif constant_type is Constants.NoiseEPix100:
             self.noise_map[:] = data.squeeze()
         else:
@@ -231,127 +274,115 @@ class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner):
         if Constants.NoiseEPix100 in constants:
             self.noise_map.fill(np.inf)
 
+    def _preview_data_views(self, raw_data, processed_data):
+        # NOTE: apparently there are "calibration rows"
+        # geometry assumes these are cut out already
+        return (raw_data[2:-2], processed_data[2:-2])
+
     def _correct_quadrant(
         self,
-        q,
         flags,
-        bad_pixel_mask_value,
-        cm_noise_sigma,
-        cm_min_frac,
-        cm_row,
-        cm_col,
-        cm_block,
+        input_data,
+        offset_map,
+        noise_map,
+        bad_pixel_map,
+        rel_gain_map,
+        output,
     ):
-        output = self._q_processed_data[q]
-        output[:] = self._q_input_data[q].astype(np.float32)
+        output[:] = input_data.astype(np.float32)
         if flags & CorrectionFlags.OFFSET:
-            output -= self._q_offset_map[q]
+            output -= offset_map
 
         if flags & CorrectionFlags.COMMONMODE:
             # per rectangular block that looks like something is going on
             cm_mask = (
-                (self._q_bad_pixel_map[q] != 0)
-                | (output > self._q_noise_map[q] * cm_noise_sigma)
+                (bad_pixel_map != 0)
+                | (output > noise_map * self._cm_sigma)
             ).astype(np.uint8, copy=False)
             masked = np.ma.masked_array(data=output, mask=cm_mask)
-            if cm_block:
+            if self._cm_block:
                 for block in np.hsplit(masked, 4):
-                    if block.count() < block.size * cm_min_frac:
+                    if block.count() < block.size * self._cm_minfrac:
                         continue
                     block.data[:] -= np.ma.median(block)
 
-            if cm_row:
-                common_cython.cm_fs(output, cm_mask, cm_noise_sigma, cm_min_frac)
+            if self._cm_row:
+                common_cython.cm_fs(output, cm_mask, self._cm_sigma, self._cm_minfrac)
 
-            if cm_col:
-                common_cython.cm_ss(output, cm_mask, cm_noise_sigma, cm_min_frac)
+            if self._cm_col:
+                common_cython.cm_ss(output, cm_mask, self._cm_sigma, self._cm_minfrac)
 
         if flags & CorrectionFlags.RELGAIN:
-            output *= self._q_rel_gain_map[q]
+            output *= rel_gain_map
 
         if flags & CorrectionFlags.BPMASK:
-            output[self._q_bad_pixel_map[q] != 0] = bad_pixel_mask_value
+            output[bad_pixel_map != 0] = self.bad_pixel_mask_value
+
+    def _correct(self, flags, image_data, cell_table, output):
+        # cell_table is None for ePix100 and is simply ignored
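+        # the four quadrants are corrected in parallel on the thread pool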
+        concurrent.futures.wait(
+            [
+                self.thread_pool.submit(
+                    self._correct_quadrant, flags, *parts
+                )
+                for parts in zip(
+                        utils.quadrant_views(image_data),
+                        self._q_offset_map,
+                        self._q_noise_map,
+                        self._q_bad_pixel_map,
+                        self._q_rel_gain_map,
+                        utils.quadrant_views(output),
+                )
+            ]
+        )
 
-    def correct(
-        self,
-        flags,
-        bad_pixel_mask_value=np.nan,
-        cm_noise_sigma=5,
-        cm_min_frac=0.25,
-        cm_row=True,
-        cm_col=True,
-        cm_block=True,
-    ):
-        # NOTE: how to best clean up all these duplicated parameters?
-        for result in self.thread_pool.map(
-            functools.partial(
-                self._correct_quadrant,
-                flags=flags,
-                bad_pixel_mask_value=bad_pixel_mask_value,
-                cm_noise_sigma=cm_noise_sigma,
-                cm_min_frac=cm_min_frac,
-                cm_row=cm_row,
-                cm_col=cm_col,
-                cm_block=cm_block,
-            ),
-            range(4),
-        ):
-            pass  # just run through to await map
+    def __del__(self):
+        self.thread_pool.shutdown()
 
 
 @KARABO_CLASSINFO("Epix100Correction", deviceVersion)
 class Epix100Correction(base_correction.BaseCorrection):
-    _base_output_schema = schemas.jf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.OffsetEPix100}),
-        ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainEPix100}),
-        ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseEPix100}),
-        ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkEPix100}),
-    )
-    _image_data_path = "data.image.pixels"
+    _base_output_schema = schemas.jf_output_schema
     _kernel_runner_class = Epix100CpuRunner
-    _calcat_friend_class = Epix100CalcatFriend
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected"]
+    _calcat_friend_class = Epix100CalcatFriend
+
+    _image_data_path = "data.image.pixels"
     _cell_table_path = None
     _pulse_table_path = None
-    _warn_memory_cell_range = False
 
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec(
+            name,
+            dimensions=2,
+            frame_reduction=False,
+        )
+        for name in ("raw", "corrected")
+    ]
+    _warn_memory_cell_range = False
+    _use_shmem_handles = False
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        cls._kernel_runner_class.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(Epix100Correction._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(708)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(768)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.outputAxisOrder")
-            .setNewOptions("xy,yx")
-            .setNewDefaultValue("xy")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             # TODO: disable preview selection mode
 
             OVERWRITE_ELEMENT(expected)
             .key("useShmemHandles")
-            .setNewDefaultValue(False)
+            .setNewDefaultValue(cls._use_shmem_handles)
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
@@ -360,112 +391,11 @@ class Epix100Correction(base_correction.BaseCorrection):
             .commit(),
         )
 
-        base_correction.add_preview_outputs(
-            expected, Epix100Correction._preview_outputs
-        )
-        base_correction.add_correction_step_schema(
-            expected,
-            Epix100Correction._correction_steps,
-        )
-        (
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.noiseSigma")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(5)
-            .reconfigurable()
-            .commit(),
-
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.minFrac")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(0.25)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableRow")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableCol")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableBlock")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-        )
-        Epix100CalcatFriend.add_schema(expected)
-        # TODO: bad pixel node?
-
-    @property
-    def input_data_shape(self):
-        # TODO: check
+    def _get_data_from_hash(self, data_hash):
+        image_data = data_hash.get(self._image_data_path)
         return (
-            self.unsafe_get("dataFormat.pixelsX"),
-            self.unsafe_get("dataFormat.pixelsY"),
-        )
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,  # will be None
-        pulse_table,  # ditto
-    ):
-        self.kernel_runner.load_data(image_data)
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        args_which_should_be_cached = dict(
-            cm_noise_sigma=self.unsafe_get("corrections.commonMode.noiseSigma"),
-            cm_min_frac=self.unsafe_get("corrections.commonMode.minFrac"),
-            cm_row=self.unsafe_get("corrections.commonMode.enableRow"),
-            cm_col=self.unsafe_get("corrections.commonMode.enableCol"),
-            cm_block=self.unsafe_get("corrections.commonMode.enableBlock"),
+            1,
+            image_data,
+            None,
+            None,
         )
-        self.kernel_runner.correct(
-            flags=self._correction_flag_enabled, **args_which_should_be_cached
-        )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        if self._correction_flag_enabled != self._correction_flag_preview:
-            self.kernel_runner.correct(
-                flags=self._correction_flag_preview,
-                **args_which_should_be_cached,
-            )
-
-        if self._use_shmem_handles:
-            # TODO: consider better name for key...
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        # note: base class preview machinery assumes burst mode, shortcut it
-        self._preview_friend.write_outputs(
-            metadata, *self.kernel_runner.preview_data_views
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
diff --git a/src/calng/corrections/Gotthard2Correction.py b/src/calng/corrections/Gotthard2Correction.py
index aca011ccdbdb16eaad0b1c747061164304968186..f8fa85bb6aa2684a9f33194bf7bdcbc77fd034ca 100644
--- a/src/calng/corrections/Gotthard2Correction.py
+++ b/src/calng/corrections/Gotthard2Correction.py
@@ -6,17 +6,18 @@ from karabo.bound import (
     KARABO_CLASSINFO,
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
-    Hash,
-    Timestamp,
 )
 
-from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils
+from .. import (
+    base_calcat,
+    base_correction,
+    base_kernel_runner,
+    schemas,
+)
+from ..preview_utils import PreviewFriend, FrameSelectionMode, PreviewSpec
 from .._version import version as deviceVersion
 
 
-_pretend_pulse_table = np.arange(2720, dtype=np.uint8)
-
-
 class Constants(enum.Enum):
     LUTGotthard2 = enum.auto()
     OffsetGotthard2 = enum.auto()
@@ -25,7 +26,7 @@ class Constants(enum.Enum):
     BadPixelsFFGotthard2 = enum.auto()
 
 
-bp_constant_types = {
+bad_pixel_constants = {
     Constants.BadPixelsDarkGotthard2,
     Constants.BadPixelsFFGotthard2,
 }
@@ -39,106 +40,130 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 8
 
 
-class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
-    _gpu_based = False
-    _xp = np
+correction_steps = (
+    (
+        "lut",
+        CorrectionFlags.LUT,
+        {Constants.LUTGotthard2},
+    ),
+    (
+        "offset",
+        CorrectionFlags.OFFSET,
+        {Constants.OffsetGotthard2},
+    ),
+    (
+        "gain",
+        CorrectionFlags.GAIN,
+        {Constants.RelativeGainGotthard2},
+    ),
+    (
+        "badPixels",
+        CorrectionFlags.BPMASK,
+        {
+            Constants.BadPixelsDarkGotthard2,
+            Constants.BadPixelsFFGotthard2,
+        }
+    ),
+)
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-        bad_pixel_subset=None,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
 
-        from ..kernels import gotthard2_cython
+def framesums_reduction_fun(data, cell_table, pulse_table, warn_fun):
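+    # per-frame sum over all pixels for the frame sums preview; NaN values are ignored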
+    return np.nansum(data, axis=1)
 
-        self.correction_kernel = gotthard2_cython.correct
-        self.input_shape = (frames, pixels_x)
-        self.processed_shape = self.input_shape
+
+class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
+    _xp = np
+
+    num_pixels_ss = None  # 1d detector; this attribute should never be accessed
+    num_pixels_fs = 1280
+
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
+
+    def expected_input_shape(self, num_frames):
+        return (num_frames, self.num_pixels_fs)
+
+    def _expected_output_shape(self, num_frames):
+        return (num_frames, self.num_pixels_fs)
+
+    def _make_output_buffers(self, num_frames, flags):
+        return [
+            self._xp.empty(
+                (num_frames, self.num_pixels_fs), dtype=np.float32)
+        ]
+
+    def _preview_data_views(self, raw_data, raw_gain, processed_data):
+        return [
+            raw_data,  # raw 1d
+            raw_gain,  # gain 1d
+            processed_data,  # corrected 1d
+            processed_data,  # frame sums
+            raw_data,  # raw streak (2d)
+            raw_gain,  # gain streak
+            processed_data,  # corrected streak
+        ]
+
+    def _setup_constant_buffers(self):
         # model: 2 buffers (corresponding to actual memory cells), 2720 frames
         # lut maps from uint12 to uint10 values
-        self.lut_shape = (2, 4096, pixels_x)
-        self.map_shape = (3, self.constant_memory_cells, self.pixels_x)
+        # TODO: override superclass properties instead
+        self.lut_shape = (2, 4096, self.num_pixels_fs)
+        self.map_shape = (3, self._constant_memory_cells, self.num_pixels_fs)
         self.lut = np.empty(self.lut_shape, dtype=np.uint16)
         self.offset_map = np.empty(self.map_shape, dtype=np.float32)
         self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
         self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32)
-        self.bad_pixel_mask_value = bad_pixel_mask_value
-        self.flush_buffers()
-        self._bad_pixel_subset = bad_pixel_subset
+        self.flush_buffers(set(Constants))
 
-        self.input_data = None  # will just point to data coming in
-        self.input_gain_stage = None  # will just point to data coming in
-        self.processed_data = None  # will just point to buffer we're given
-
-    @property
-    def preview_data_views(self):
-        return (self.input_data, self.input_gain_stage, self.processed_data)
+    def _post_init(self):
+        from ..kernels import gotthard2_cython
+        self.correction_kernel = gotthard2_cython.correct
 
-    def load_constant(self, constant_type, data):
+    def _load_constant(self, constant_type, data):
         if constant_type is Constants.LUTGotthard2:
             self.lut[:] = np.transpose(data.astype(np.uint16, copy=False), (1, 2, 0))
         elif constant_type is Constants.OffsetGotthard2:
             self.offset_map[:] = np.transpose(data.astype(np.float32, copy=False))
         elif constant_type is Constants.RelativeGainGotthard2:
             self.rel_gain_map[:] = np.transpose(data.astype(np.float32, copy=False))
-        elif constant_type in bp_constant_types:
+        elif constant_type in bad_pixel_constants:
             # TODO: add the regular bad pixel subset configuration
-            data = data.astype(np.uint32, copy=False)
-            if self._bad_pixel_subset is not None:
-                data = data & self._bad_pixel_subset
-            self.bad_pixel_map |= np.transpose(data)
+            self.bad_pixel_map |= (np.transpose(data) & self.bad_pixel_subset)
         else:
             raise ValueError(f"What is this constant '{constant_type}'?")
 
-    def set_bad_pixel_subset(self, subset, apply_now=True):
-        self._bad_pixel_subset = subset
-        if apply_now:
-            self.bad_pixel_map &= self._bad_pixel_subset
-
-    def load_data(self, image_data, input_gain_stage):
-        """Experiment: loading both in one function as they are tied"""
-        self.input_data = image_data.astype(np.uint16, copy=False)
-        self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False)
-
-    def flush_buffers(self, constants=None):
-        # for now, always flushing all here...
-        default_lut = (
-            np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12
-        ).astype(np.uint16)
-        self.lut[:] = np.stack([np.stack([default_lut] * 2)] * self.pixels_x, axis=2)
-        self.offset_map.fill(0)
-        self.rel_gain_map.fill(1)
-        self.bad_pixel_map.fill(0)
-
-    def correct(self, flags, out=None):
-        if out is None:
-            out = np.empty(self.processed_shape, dtype=np.float32)
-
+    def flush_buffers(self, constants):
+        if Constants.LUTGotthard2 in constants:
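+            # default LUT: linear 12-bit -> 10-bit mapping, replicated for both
+            # memory cells and all pixels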
+            default_lut = (
+                np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12
+            ).astype(np.uint16)
+            self.lut[:] = np.stack(
+                [np.stack([default_lut] * 2)] * self.num_pixels_fs, axis=2
+            )
+        if Constants.OffsetGotthard2 in constants:
+            self.offset_map.fill(0)
+        if Constants.RelativeGainGotthard2 in constants:
+            self.rel_gain_map.fill(1)
+        if bad_pixel_constants & constants:
+            self.bad_pixel_map.fill(0)
+
+    def _correct(self, flags, raw_data, cell_table, gain_stages, processed_data):
         self.correction_kernel(
-            self.input_data,
-            self.input_gain_stage,
-            np.uint8(flags),
+            raw_data,
+            gain_stages,
+            flags,
             self.lut,
             self.offset_map,
             self.rel_gain_map,
             self.bad_pixel_map,
             self.bad_pixel_mask_value,
-            out,
+            processed_data,
         )
-        self.processed_data = out
-        return out
 
 
 class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend):
@@ -227,236 +252,109 @@ class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("Gotthard2Correction", deviceVersion)
 class Gotthard2Correction(base_correction.BaseCorrection):
-    _base_output_schema = schemas.jf_output_schema(use_shmem_handle=False)
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        (
-            "lut",
-            CorrectionFlags.LUT,
-            {Constants.LUTGotthard2},
-        ),
-        (
-            "offset",
-            CorrectionFlags.OFFSET,
-            {Constants.OffsetGotthard2},
-        ),
-        (
-            "gain",
-            CorrectionFlags.GAIN,
-            {Constants.RelativeGainGotthard2},
-        ),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            {
-                Constants.BadPixelsDarkGotthard2,
-                Constants.BadPixelsFFGotthard2,
-            }
-        ),
-    )
+    _base_output_schema = schemas.jf_output_schema
     _kernel_runner_class = Gotthard2CpuRunner
     _calcat_friend_class = Gotthard2CalcatFriend
+    _correction_steps = correction_steps
     _constant_enum_class = Constants
+
     _image_data_path = "data.adc"
-    _cell_table_path = "data.memoryCell"
+    _cell_table_path = None
     _pulse_table_path = None
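+    # no cell or pulse table for GOTTHARD-II, so the previews below can only use
+    # FrameSelectionMode.FRAME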
+
     _warn_memory_cell_range = False  # for now, receiver always writes 255
-    _preview_outputs = ["outputStreak", "outputGainStreak"]
     _cuda_pin_buffers = False
+    _use_shmem_handles = False
+
+    _preview_outputs = [
+        PreviewSpec(
+            "raw",
+            dimensions=1,
+            frame_reduction=True,
+            wrap_in_imagedata=False,
+            valid_frame_selection_modes=[FrameSelectionMode.FRAME],
+        ),
+        PreviewSpec(
+            "gain",
+            dimensions=1,
+            frame_reduction=True,
+            wrap_in_imagedata=False,
+            valid_frame_selection_modes=[FrameSelectionMode.FRAME],
+        ),
+        PreviewSpec(
+            "corrected",
+            dimensions=1,
+            frame_reduction=True,
+            wrap_in_imagedata=False,
+            valid_frame_selection_modes=[FrameSelectionMode.FRAME],
+        ),
+        PreviewSpec(
+            "framesums",
+            dimensions=1,
+            frame_reduction=False,
+            frame_reduction_fun=framesums_reduction_fun,
+            wrap_in_imagedata=False,
+        ),
+        PreviewSpec(
+            "rawStreak",
+            dimensions=2,
+            frame_reduction=False,
+            wrap_in_imagedata=True,
+        ),
+        PreviewSpec(
+            "gainStreak",
+            dimensions=2,
+            frame_reduction=False,
+            wrap_in_imagedata=True,
+        ),
+        PreviewSpec(
+            "correctedStreak",
+            dimensions=2,
+            frame_reduction=False,
+            wrap_in_imagedata=True,
+        ),
+    ]
 
-    @staticmethod
-    def expectedParameters(expected):
-        super(Gotthard2Correction, Gotthard2Correction).expectedParameters(expected)
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        cls._kernel_runner_class.add_schema(expected)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(Gotthard2Correction._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(1280)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(2720)  # note: actually just frames...
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("frame")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
             .key("outputShmemBufferSize")
             .setNewDefaultValue(2)
             .commit(),
-        )
-
-        base_correction.add_preview_outputs(
-            expected, Gotthard2Correction._preview_outputs
-        )
-        for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"):
-            # add these "manually" as the automated bits wrap ImageData
-            (
-                OUTPUT_CHANNEL(expected)
-                .key(f"preview.{channel}")
-                .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=False))
-                .commit(),
-            )
-        base_correction.add_correction_step_schema(
-            expected,
-            Gotthard2Correction._correction_steps,
-        )
-        base_correction.add_bad_pixel_config_node(expected)
-        Gotthard2CalcatFriend.add_schema(expected)
 
-    @property
-    def input_data_shape(self):
-        return (
-            self.unsafe_get("dataFormat.frames"),
-            self.unsafe_get("dataFormat.pixelsX"),
+            OVERWRITE_ELEMENT(expected)
+            .key("useShmemHandles")
+            .setNewDefaultValue(cls._use_shmem_handles)
+            .commit(),
         )
 
-    @property
-    def output_data_shape(self):
+    def _get_data_from_hash(self, data_hash):
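+        # returns (num_frames, image_data, cell_table, pulse_table, gain_data);
+        # GOTTHARD-II provides neither cell nor pulse table here, hence the Nones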
+        image_data = np.asarray(
+            data_hash.get(self._image_data_path)
+        ).astype(np.uint16, copy=False)
+        gain_data = np.asarray(
+            data_hash.get("data.gain")
+        ).astype(np.uint8, copy=False)
+        num_frames = image_data.shape[0]
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            expected_shape = self.kernel_runner.expected_input_shape(num_frames)
+            image_data.shape = expected_shape
+            gain_data.shape = expected_shape
         return (
-            self.unsafe_get("dataFormat.frames"),
-            self.unsafe_get("dataFormat.pixelsX"),
+            num_frames,
+            image_data,
+            None,  # no cell table
+            None,  # no pulse table
+            gain_data,
         )
-
-    @property
-    def _kernel_runner_init_args(self):
-        return {
-            "bad_pixel_mask_value": self._bad_pixel_mask_value,
-            "bad_pixel_subset": self._bad_pixel_subset,
-        }
-
-    @property
-    def _bad_pixel_mask_value(self):
-        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
-
-    _bad_pixel_subset = property(base_correction.get_bad_pixel_field_selection)
-
-    def postReconfigure(self):
-        super().postReconfigure()
-        update = self._prereconfigure_update_hash
-        if update.has("corrections.badPixels.subsetToUse"):
-            if any(
-                update.get(
-                    f"corrections.badPixels.subsetToUse.{field.name}", default=False
-                )
-                for field in utils.BadPixelValues
-            ):
-                self.log_status_info(
-                    "Some fields reenabled, reloading cached bad pixel constants"
-                )
-                self.kernel_runner.set_bad_pixel_subset(
-                    self._bad_pixel_subset, apply_now=False
-                )
-                with self.calcat_friend.cached_constants_lock:
-                    self.kernel_runner.flush_buffers(bp_constant_types)
-                    for constant_type in (
-                        bp_constant_types & self.calcat_friend.cached_constants.keys()
-                    ):
-                        self._load_constant_to_runner(
-                            constant_type,
-                            self.calcat_friend.cached_constants[constant_type],
-                        )
-            else:
-                # just narrowing the subset - no reload, just AND
-                self.kernel_runner.set_bad_pixel_subset(
-                    self._bad_pixel_subset, apply_now=True
-                )
-        if update.has("corrections.badPixels.maskingvalue"):
-            self.kernel_runner.bad_pixel_mask_value = self._bad_pixel_mask_value
-
-    def __init__(self, config):
-        super().__init__(config)
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        # cell table currently not used for GOTTHARD2 (assume alternating)
-        gain_map = np.asarray(data_hash.get("data.gain"))
-        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
-            gain_map.shape = self.input_data_shape
-        try:
-            self.kernel_runner.load_data(image_data, gain_map)
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled, out=buffer_array)
-
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                _pretend_pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-        (
-            preview_raw,
-            preview_gain,
-            preview_corrected,
-        ) = self.kernel_runner.compute_previews(preview_slice_index)
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        frame_sums = np.nansum(buffer_array, axis=1)
-        timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp"))
-        for channel, data in (
-            ("outputRaw", preview_raw),
-            ("outputGain", preview_gain),
-            ("outputCorrected", preview_corrected),
-            ("outputFrameSums", frame_sums),
-        ):
-            self.writeChannel(
-                f"preview.{channel}",
-                Hash(
-                    "image.data",
-                    data,
-                    "image.mask",
-                    (~np.isfinite(data)).astype(np.uint8),
-                ),
-                timestamp=timestamp,
-            )
-        self._preview_friend.write_outputs(metadata, buffer_array, gain_map)
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py
index ab5af5f688fa74c32ab02195fffde302cf305ef9..c51cbe0865be0c9da5034f11785a70f219edf754 100644
--- a/src/calng/corrections/JungfrauCorrection.py
+++ b/src/calng/corrections/JungfrauCorrection.py
@@ -9,7 +9,6 @@ from karabo.bound import (
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
-    Schema,
 )
 
 from .. import (
@@ -19,12 +18,10 @@ from .. import (
     schemas,
     utils,
 )
+from ..preview_utils import PreviewFriend, FrameSelectionMode, PreviewSpec
 from .._version import version as deviceVersion
 
 
-_pretend_pulse_table = np.arange(16, dtype=np.uint8)
-
-
 class Constants(enum.Enum):
     Offset10Hz = enum.auto()
     BadPixelsDark10Hz = enum.auto()
@@ -57,66 +54,95 @@ class CorrectionFlags(enum.IntFlag):
     STRIXEL = 8
 
 
-class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "fyx"
-    _xp = None  # subclass sets CuPy (import at runtime) or numpy
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.Offset10Hz}),
+    ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain10Hz}),
+    (
+        "badPixels",
+        CorrectionFlags.BPMASK,
+        {
+            Constants.BadPixelsDark10Hz,
+            Constants.BadPixelsFF10Hz,
+            None,
+        },
+    ),
+    ("strixel", CorrectionFlags.STRIXEL, set()),
+)
 
-    @property
-    def input_shape(self):
-        return self.frames, self.pixels_y, self.pixels_x
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
+class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 512
+    num_pixels_fs = 1024
 
     @property
-    def map_shape(self):
-        return (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3)
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        config,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
+    def _map_shape(self):
+        return (self._constant_memory_cells, self.num_pixels_ss, self.num_pixels_fs, 3)
+
+    def _setup_constant_buffers(self):
         # note: superclass creates cell table with wrong dtype
-        self.input_gain_stage = self._xp.empty(self.input_shape, dtype=np.uint8)
-        self.offset_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = self._xp.empty(self.map_shape, dtype=np.uint32)
-        self.bad_pixel_mask_value = bad_pixel_mask_value
-
-        self.output_dtype = output_data_dtype
-        self._processed_data_regular = self._xp.empty(
-            self.processed_shape, dtype=output_data_dtype
-        )
-        self._processed_data_strixel = None
+        self.offset_map = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.rel_gain_map = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = self._xp.empty(self._map_shape, dtype=np.uint32)
         self.flush_buffers(set(Constants))
-        self.correction_kernel_strixel = None
-        self.reconfigure(config)
+
+    def _make_output_buffers(self, num_frames, flags):
+        if flags & CorrectionFlags.STRIXEL:
+            return [
+                self._xp.empty(
+                    (num_frames,) + self._strixel_out_shape, dtype=self._output_dtype
+                )
+            ]
+        else:
+            return [
+                self._xp.empty(
+                    (num_frames, self.num_pixels_ss, self.num_pixels_fs),
+                    dtype=self._output_dtype,
+                )
+            ]
+
+    def _expected_output_shape(self, num_frames):
+        # TODO: think of how to unify with _make_output_buffers
+        if self._correction_flag_enabled & CorrectionFlags.STRIXEL:
+            return (num_frames,) + self._strixel_out_shape
+        else:
+            return (num_frames, self.num_pixels_ss, self.num_pixels_fs)
+
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
+        (
+            OVERWRITE_ELEMENT(schema)
+            .key("corrections.strixel.enable")
+            .setNewDefaultValue(False)
+            .commit(),
+
+            OVERWRITE_ELEMENT(schema)
+            .key("corrections.strixel.preview")
+            .setNewDefaultValue(False)
+            .commit(),
+
+            STRING_ELEMENT(schema)
+            .key("corrections.strixel.type")
+            .description(
+                "Which kind of strixel layout is used for this module? cols_A0123 is "
+                "the first strixel layout deployed at HED and rows_A1256 is the one "
+                "later deployed at SCS."
+            )
+            .assignmentOptional()
+            .defaultValue("cols_A0123(HED-type)")
+            .options("cols_A0123(HED-type),rows_A1256(SCS-type)")
+            .reconfigurable()
+            .commit(),
+        )
 
     def reconfigure(self, config):
-        # note: regular bad pixel masking uses device property (TODO: unify)
-        if (mask_value := config.get("badPixels.maskingValue")) is not None:
-            self._bad_pixel_mask_value = self._xp.float32(mask_value)
-            # this is a functools.partial, can just update the captured parameter
-            if self.correction_kernel_strixel is not None:
-                self.correction_kernel_strixel.keywords[
-                    "missing"
-                ] = self._bad_pixel_mask_value
-
-        if (strixel_type := config.get("strixel.type")) is not None:
+        super().reconfigure(config)
+
+        if (strixel_type := config.get("corrections.strixel.type")) is not None:
             # drop the friendly parenthesized name
             strixel_type = strixel_type.partition("(")[0]
             strixel_package = np.load(
@@ -124,18 +150,17 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
                 / f"strixel_{strixel_type}-lut_mask.npz"
             )
             self._strixel_out_shape = tuple(strixel_package["frame_shape"])
-            self._processed_data_strixel = None
             # TODO: use bad pixel masking config here
             self.correction_kernel_strixel = functools.partial(
                 utils.apply_partial_lut,
                 lut=self._xp.asarray(strixel_package["lut"]),
                 mask=self._xp.asarray(strixel_package["mask"]),
-                missing=self._bad_pixel_mask_value,
+                missing=self.bad_pixel_mask_value,
             )
             # note: always masking unmapped pixels (not respecting NON_STANDARD_SIZE)
 
-    def load_constant(self, constant, constant_data):
-        if constant_data.shape[0] == self.pixels_x:
+    def _load_constant(self, constant, constant_data):
+        if constant_data.shape[0] == self.num_pixels_fs:
             constant_data = np.transpose(constant_data, (2, 1, 0, 3))
         else:
             constant_data = np.transpose(constant_data, (2, 0, 1, 3))
@@ -145,17 +170,14 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
         elif constant is Constants.RelativeGain10Hz:
             self.rel_gain_map[:] = self._xp.asarray(constant_data, dtype=np.float32)
         elif constant in bad_pixel_constants:
-            self.bad_pixel_map |= self._xp.asarray(constant_data, dtype=np.uint32)
+            self.bad_pixel_map |= self._xp.asarray(
+                constant_data, dtype=np.uint32
+            ) & self._xp.uint32(self.bad_pixel_subset)
         else:
             raise ValueError(f"Unexpected constant type {constant}")
 
-    @property
-    def preview_data_views(self):
-        return (self.input_data, self.processed_data, self.input_gain_stage)
-
-    @property
-    def burst_mode(self):
-        return self.frames > 1
+    def _preview_data_views(self, raw_data, gain_map, processed_data):
+        return (raw_data, gain_map, processed_data)
 
     def flush_buffers(self, constants):
         if Constants.Offset10Hz in constants:
@@ -167,106 +189,80 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
             Constants.BadPixelsFF10Hz,
         }:
             self.bad_pixel_map.fill(0)
-            self.bad_pixel_map[
-                :, :, 255:1023:256
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
-            self.bad_pixel_map[
-                :, :, 256:1024:256
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
-            self.bad_pixel_map[
-                :, [255, 256]
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+            if self.bad_pixel_subset & utils.BadPixelValues.NON_STANDARD_SIZE:
+                self._mask_asic_seams()
 
-    def override_bad_pixel_flags_to_use(self, override_value):
-        self.bad_pixel_map &= self._xp.uint32(override_value)
+    def _mask_asic_seams(self):
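+        # mark pixels along the internal ASIC boundaries
+        # (columns 255/256, 511/512, 767/768 and rows 255/256) as NON_STANDARD_SIZE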
+        self.bad_pixel_map[
+            :, :, 255:1023:256
+        ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+        self.bad_pixel_map[
+            :, :, 256:1024:256
+        ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+        self.bad_pixel_map[
+            :, [255, 256]
+        ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
 
 
 class JungfrauGpuRunner(JungfrauBaseRunner):
-    _gpu_based = True
-
     def _pre_init(self):
         import cupy as cp
 
         self._xp = cp
 
     def _post_init(self):
-        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = self._xp.empty(self.frames, dtype=np.uint8)
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.input_shape, self.block_shape
-        )
-
         source_module = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("jungfrau_gpu.cu").render(
                 {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
                     "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
+                        self._output_dtype
                     ),
                     "corr_enum": utils.enum_to_c_template(CorrectionFlags),
-                    "burst_mode": self.burst_mode,
                 }
             )
         )
         self.correction_kernel = source_module.get_function("correct")
 
-    def load_data(self, image_data, input_gain_stage, cell_table):
-        self.input_data.set(image_data)
-        self.input_gain_stage.set(input_gain_stage)
-        if self.burst_mode:
-            self.cell_table.set(cell_table)
-
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, gain_map, output):
+        num_frames = image_data.shape[0]
+        if flags & CorrectionFlags.STRIXEL:
+            # run the "regular" correction into a regular-shaped buffer first,
+            # then strixel-transform into the actual output
+            final_output = output
+            output = self._xp.empty_like(image_data, dtype=np.float32)
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(
+            output.shape, block_shape
+        )
         self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
+            grid_shape,
+            block_shape,
             (
-                self.input_data,
-                self.input_gain_stage,
-                self.cell_table,
+                image_data,
+                gain_map,
+                cell_table,
                 self._xp.uint8(flags),
+                self._xp.uint16(num_frames),
+                self._xp.uint16(self._constant_memory_cells),
+                self._xp.uint8(num_frames > 1),
                 self.offset_map,
                 self.rel_gain_map,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self._processed_data_regular,
+                output,
             ),
         )
         if flags & CorrectionFlags.STRIXEL:
-            if (
-                self._processed_data_strixel is None
-                or self.frames != self._processed_data_strixel.shape[0]
-            ):
-                self._processed_data_strixel = self._xp.empty(
-                    (self.frames,) + self._strixel_out_shape,
-                    dtype=self.output_dtype,
-                )
-            for pixel_frame, strixel_frame in zip(
-                self._processed_data_regular, self._processed_data_strixel
-            ):
+            for pixel_frame, strixel_frame in zip(output, final_output):
                 self.correction_kernel_strixel(
                     data=pixel_frame,
                     out=strixel_frame,
                 )
-            self.processed_data = self._processed_data_strixel
-        else:
-            self.processed_data = self._processed_data_regular
 
 
 class JungfrauCpuRunner(JungfrauBaseRunner):
-    _gpu_based = False
     _xp = np
 
-    def _post_init(self):
-        # not actually allocating, will just point to incoming data
-        self.input_data = None
-        self.input_gain_stage = None
-        self.input_cell_table = None
-
+    def _pre_init(self):
         # for computing previews faster
         self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
 
@@ -278,45 +274,39 @@ class JungfrauCpuRunner(JungfrauBaseRunner):
     def __del__(self):
         self.thread_pool.shutdown()
 
-    def load_data(self, image_data, input_gain_stage, cell_table):
-        self.input_data = image_data.astype(np.uint16, copy=False)
-        self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False)
-        self.input_cell_table = cell_table.astype(np.uint8, copy=False)
+    def _correct(self, flags, image_data, cell_table, gain_map, output):
+        if flags & CorrectionFlags.STRIXEL:
+            # do "regular" correction into "regular" buffer, strixel transform to output
+            final_output = output
+            output = self._xp.empty_like(image_data, dtype=np.float32)
 
-    def correct(self, flags):
-        if self.burst_mode:
+        # burst mode
+        if image_data.shape[0] > 1:
             self.correction_kernel_burst(
-                self.input_data,
-                self.input_gain_stage,
-                self.input_cell_table,
+                image_data,
+                gain_map,
+                cell_table,
                 flags,
                 self.offset_map,
                 self.rel_gain_map,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self._processed_data_regular,
+                output,
             )
         else:
             self.correction_kernel_single(
-                self.input_data,
-                self.input_gain_stage,
+                image_data,
+                gain_map,
                 flags,
                 self.offset_map,
                 self.rel_gain_map,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self._processed_data_regular,
+                output,
             )
 
         if flags & CorrectionFlags.STRIXEL:
-            if (
-                self._processed_data_strixel is None
-                or self.frames != self._processed_data_strixel.shape[0]
-            ):
-                self._processed_data_strixel = self._xp.empty(
-                    (self.frames,) + self._strixel_out_shape,
-                    dtype=self.output_dtype,
-                )
+            # now do strixel transformation into actual output buffer
             concurrent.futures.wait(
                 [
                     self.thread_pool.submit(
@@ -324,40 +314,9 @@ class JungfrauCpuRunner(JungfrauBaseRunner):
                         data=pixel_frame,
                         out=strixel_frame,
                     )
-                    for pixel_frame, strixel_frame in zip(
-                        self._processed_data_regular, self._processed_data_strixel
-                    )
+                    for pixel_frame, strixel_frame in zip(output, final_output)
                 ]
             )
-            self.processed_data = self._processed_data_strixel
-        else:
-            self.processed_data = self._processed_data_regular
-
-    def compute_previews(self, preview_index):
-        if preview_index < -4:
-            raise ValueError(f"No statistic with code {preview_index} defined")
-        elif preview_index >= self.frames:
-            raise ValueError(f"Memory cell index {preview_index} out of range")
-
-        if preview_index >= 0:
-
-            def fun(a):
-                return a[preview_index]
-
-        elif preview_index == -1:
-            # note: separate from next case because dtype not applicable here
-            fun = functools.partial(np.nanmax, axis=0)
-        elif preview_index in (-2, -3, -4):
-            fun = functools.partial(
-                {
-                    -2: np.nanmean,
-                    -3: np.nansum,
-                    -4: np.nanstd,
-                }[preview_index],
-                axis=0,
-                dtype=np.float32,
-            )
-        return self.thread_pool.map(fun, self.preview_data_views)
 
 
 class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend):
@@ -475,59 +434,40 @@ class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("JungfrauCorrection", deviceVersion)
 class JungfrauCorrection(base_correction.BaseCorrection):
-    _base_output_schema = schemas.jf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.Offset10Hz}),
-        ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain10Hz}),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            {
-                Constants.BadPixelsDark10Hz,
-                Constants.BadPixelsFF10Hz,
-                None,
-            },
-        ),
-        ("strixel", CorrectionFlags.STRIXEL, set()),
-    )
+    _base_output_schema = schemas.jf_output_schema
+    _correction_steps = correction_steps
     _calcat_friend_class = JungfrauCalcatFriend
     _constant_enum_class = Constants
-    _preview_outputs = [
-        "outputRaw",
-        "outputCorrected",
-        "outputGainMap",
-    ]
+
     _image_data_path = "data.adc"
     _cell_table_path = "data.memoryCell"
     _pulse_table_path = None
 
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec(
+            name,
+            dimensions=2,
+            frame_reduction=True,
+            valid_frame_selection_modes=(
+                FrameSelectionMode.FRAME,
+                FrameSelectionMode.CELL,
+            ),
+        ) for name in ("raw", "gain", "corrected")
+    ]
+    _warn_memory_cell_range = False
+    _use_shmem_handles = False
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        JungfrauBaseRunner.add_schema(expected)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(JungfrauCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(1024)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(512)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("frame")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             # JUNGFRAU data is small, can fit plenty of trains in here
@@ -535,6 +475,11 @@ class JungfrauCorrection(base_correction.BaseCorrection):
             .key("outputShmemBufferSize")
             .setNewDefaultValue(2)
             .commit(),
+
+            OVERWRITE_ELEMENT(expected)
+            .key("useShmemHandles")
+            .setNewDefaultValue(cls._use_shmem_handles)
+            .commit(),
         )
         (
             # support both CPU and GPU kernels
@@ -553,53 +498,30 @@ class JungfrauCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-        base_correction.add_preview_outputs(
-            expected, JungfrauCorrection._preview_outputs
-        )
-        base_correction.add_correction_step_schema(
-            expected,
-            JungfrauCorrection._correction_steps,
-        )
-        (
-            OVERWRITE_ELEMENT(expected)
-            .key("corrections.strixel.enable")
-            .setNewDefaultValue(False)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("corrections.strixel.preview")
-            .setNewDefaultValue(False)
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("corrections.strixel.type")
-            .description(
-                "Which kind of strixel layout is used for this module? cols_A0123 is "
-                "the first strixel layout deployed at HED and rows_A1256 is the one "
-                "later deployed at SCS."
+    def _get_data_from_hash(self, data_hash):
+        image_data = data_hash.get(self._image_data_path)
+        gain_data = np.asarray(data_hash.get("data.gain")).astype(np.uint8, copy=False)
+        # explicit np.asarray because types are special here
+        cell_table = np.asarray(
+            data_hash[self._cell_table_path]
+        ).astype(np.uint8, copy=False)
+        num_frames = cell_table.size
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            expected_shape = self.kernel_runner.expected_input_shape(
+                num_frames
             )
-            .assignmentOptional()
-            .defaultValue("cols_A0123(HED-type)")
-            .options("cols_A0123(HED-type),rows_A1256(SCS-type)")
-            .reconfigurable()
-            .commit(),
-        )
-        base_correction.add_bad_pixel_config_node(expected)
-        JungfrauCalcatFriend.add_schema(expected)
+            if expected_shape != image_data.shape:
+                image_data.shape = expected_shape
+                gain_data.shape = expected_shape
 
-    @property
-    def input_data_shape(self):
         return (
-            self.unsafe_get("dataFormat.frames"),
-            self.unsafe_get("dataFormat.pixelsY"),
-            self.unsafe_get("dataFormat.pixelsX"),
+            num_frames,
+            image_data,  # not converted to ndarray here; the kernel runner handles that
+            cell_table,
+            None,
+            gain_data,
         )
 
-    @property
-    def _warn_memory_cell_range(self):
-        # disable warning in normal operation as cell 15 is expected
-        return self.unsafe_get("dataFormat.frames") > 1
-
     @property
     def output_data_shape(self):
         if self._correction_flag_enabled & CorrectionFlags.STRIXEL:
@@ -623,134 +545,3 @@ class JungfrauCorrection(base_correction.BaseCorrection):
             return JungfrauCpuRunner
         else:
             return JungfrauGpuRunner
-
-    @property
-    def _kernel_runner_init_args(self):
-        return {
-            "bad_pixel_mask_value": self.bad_pixel_mask_value,
-            # temporary: will refactor base class to always pass config node
-            "config": self.get("corrections"),
-        }
-
-    @property
-    def bad_pixel_mask_value(self):
-        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
-
-    _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection)
-
-    def __init__(self, config):
-        super().__init__(config)
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue", default="nan"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-        if config.get("useShmemHandles", default=True):
-            schema_override = Schema()
-            (
-                OUTPUT_CHANNEL(schema_override)
-                .key("dataOutput")
-                .dataSchema(schemas.jf_output_schema(use_shmem_handle=False))
-                .commit(),
-            )
-            self.updateSchema(schema_override)
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        if len(cell_table.shape) == 0:
-            cell_table = cell_table[np.newaxis]
-        try:
-            gain_map = data_hash["data.gain"]
-            if self.unsafe_get("workarounds.overrideInputAxisOrder"):
-                gain_map.shape = self.input_data_shape
-            self.kernel_runner.load_data(image_data, gain_map, cell_table)
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled)
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                _pretend_pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-        (
-            preview_raw,
-            preview_corrected,
-            preview_gain_map,
-        ) = self.kernel_runner.compute_previews(preview_slice_index)
-
-        # reusing input data hash for sending
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            # TODO: use shmem for data.gain, too
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        self._preview_friend.write_outputs(
-            metadata, preview_raw, preview_corrected, preview_gain_map
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        if constant in bad_pixel_constants:
-            constant_data &= self._override_bad_pixel_flags
-        self.kernel_runner.load_constant(constant, constant_data)
-
-    def preReconfigure(self, config):
-        super().preReconfigure(config)
-        if config.has("corrections"):
-            self.kernel_runner.reconfigure(config["corrections"])
-
-    def postReconfigure(self):
-        super().postReconfigure()
-
-        if not hasattr(self, "_prereconfigure_update_hash"):
-            return
-
-        update = self._prereconfigure_update_hash
-
-        if update.has("corrections.strixel.enable"):
-            self._lock_and_update(self._update_frame_filter)
-
-        if update.has("corrections.badPixels.subsetToUse"):
-            self.log_status_info("Updating bad pixel maps based on subset specified")
-            # note: now just always reloading from cache for convenience
-            with self.calcat_friend.cached_constants_lock:
-                self.kernel_runner.flush_buffers(bad_pixel_constants)
-                for constant in bad_pixel_constants:
-                    if constant in self.calcat_friend.cached_constants:
-                        self._load_constant_to_runner(
-                            constant, self.calcat_friend.cached_constants[constant]
-                        )
diff --git a/src/calng/corrections/LpdCorrection.py b/src/calng/corrections/LpdCorrection.py
index 98e81e1e5e97bcc19f0c84dfe9c0f65635815282..325f5b2193de7abe0190b99b8ee26ced09358283 100644
--- a/src/calng/corrections/LpdCorrection.py
+++ b/src/calng/corrections/LpdCorrection.py
@@ -8,7 +8,6 @@ from karabo.bound import (
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
-    Schema,
 )
 
 from .. import (
@@ -18,6 +17,7 @@ from .. import (
     schemas,
     utils,
 )
+from ..preview_utils import PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 
 
@@ -30,6 +30,12 @@ class Constants(enum.Enum):
     BadPixelsFF = enum.auto()
 
 
+bad_pixel_constants = {
+    Constants.BadPixelsDark,
+    Constants.BadPixelsFF,
+}
+
+
 class CorrectionFlags(enum.IntFlag):
     NONE = 0
     OFFSET = 1
@@ -39,130 +45,85 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 16
 
 
-class LpdGpuRunner(base_kernel_runner.BaseKernelRunner):
-    _gpu_based = True
-    _corrected_axis_order = "fyx"
-
-    @property
-    def input_shape(self):
-        return (self.frames, 1, self.pixels_y, self.pixels_x)
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
+    ("gainAmp", CorrectionFlags.GAIN_AMP, {Constants.GainAmpMap}),
+    ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain}),
+    ("flatfield", CorrectionFlags.FF_CORR, {Constants.FFMap}),
+    ("badPixels", CorrectionFlags.BPMASK, bad_pixel_constants),
+)
 
-    @property
-    def processed_shape(self):
-        return (self.frames, self.pixels_y, self.pixels_x)
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-    ):
-        import cupy as cp
-        self._xp = cp
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.gain_map = cp.empty(self.processed_shape, dtype=np.float32)
-        self.input_data = cp.empty(self.input_shape, dtype=np.uint16)
-        self.processed_data = cp.empty(self.processed_shape, dtype=output_data_dtype)
-        self.cell_table = cp.empty(frames, dtype=np.uint16)
-
-        self.map_shape = (constant_memory_cells, pixels_y, pixels_x, 3)
-        self.offset_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.gain_amp_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_slopes_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.flatfield_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = cp.empty(self.map_shape, dtype=np.uint32)
-        self.bad_pixel_mask_value = bad_pixel_mask_value
-        self.flush_buffers(set(Constants))
 
-        self.correction_kernel = cp.RawModule(
-            code=base_kernel_runner.get_kernel_template("lpd_gpu.cu").render(
-                {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
-                    "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
-                    ),
-                    "corr_enum": utils.enum_to_c_template(CorrectionFlags),
-                }
-            )
-        ).get_function("correct")
+class LpdBaseRunner(base_kernel_runner.BaseKernelRunner):
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 256
+    num_pixels_fs = 256
 
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.processed_shape, self.block_shape
-        )
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
 
-    def load_data(self, image_data, cell_table):
-        self.input_data.set(image_data)
-        self.cell_table.set(cell_table)
+    def expected_input_shape(self, num_frames):
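+        # raw LPD data carries a singleton axis after the frame axis
+        # (stripped again in _preview_data_views via raw_data[:, 0])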
+        return (num_frames, 1, self.num_pixels_ss, self.num_pixels_fs)
 
     @property
-    def preview_data_views(self):
+    def _gm_map_shape(self):
+        return self._map_shape + (3,)  # for gain-mapped constants
+
+    def _setup_constant_buffers(self):
+        self.offset_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.gain_amp_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.rel_gain_slopes_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.flatfield_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.bad_pixel_map = self._xp.empty(self._gm_map_shape, dtype=np.uint32)
+        self.flush_buffers(set(Constants))
+
+    def _make_output_buffers(self, num_frames, flags):
+        output_shape = (num_frames, self.num_pixels_ss, self.num_pixels_fs)
+        return [
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # image
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # gain
+        ]
+
+    def _preview_data_views(self, raw_data, processed_data, gain_map):
         # TODO: always split off gain from raw to avoid messing up preview?
         return (
-            self.input_data[:, 0],  # raw
-            self.processed_data,  # corrected
-            self.gain_map,  # gain (split from raw)
+            raw_data[:, 0],  # raw
+            processed_data,  # corrected
+            gain_map,  # gain (split from raw)
         )
 
-    def correct(self, flags):
-        self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
-            (
-                self.input_data,
-                self.cell_table,
-                self._xp.uint8(flags),
-                self.offset_map,
-                self.gain_amp_map,
-                self.rel_gain_slopes_map,
-                self.flatfield_map,
-                self.bad_pixel_map,
-                self.bad_pixel_mask_value,
-                self.gain_map,
-                self.processed_data,
-            ),
-        )
-
-    def load_constant(self, constant_type, constant_data):
-        # constant type → transpose order
-        bad_pixel_loading = {
-            Constants.BadPixelsDark: (2, 1, 0, 3),
-            Constants.BadPixelsFF: (2, 0, 1, 3),
-        }
+    def _load_constant(self, constant_type, constant_data):
         # constant type → transpose order, GPU buffer
-        other_constant_loading = {
+        transpose_order, my_buffer = {
+            Constants.BadPixelsDark: ((2, 1, 0, 3), self.bad_pixel_map),
+            Constants.BadPixelsFF: ((2, 0, 1, 3), self.bad_pixel_map),
             Constants.Offset: ((2, 1, 0, 3), self.offset_map),
             Constants.GainAmpMap: ((2, 0, 1, 3), self.gain_amp_map),
             Constants.FFMap: ((2, 0, 1, 3), self.flatfield_map),
             Constants.RelativeGain: ((2, 0, 1, 3), self.rel_gain_slopes_map),
-        }
-        if constant_type in bad_pixel_loading:
-            self.bad_pixel_map |= self._xp.asarray(
-                constant_data.transpose(bad_pixel_loading[constant_type]),
-                dtype=np.uint32,
-            )[:self.constant_memory_cells]
-        elif constant_type in other_constant_loading:
-            transpose_order, gpu_buffer = other_constant_loading[constant_type]
-            gpu_buffer.set(
-                np.transpose(
-                    constant_data.astype(np.float32),
-                    transpose_order,
-                )[:self.constant_memory_cells]
+        }[constant_type]
+
+        if constant_type in bad_pixel_constants:
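+        # bad pixel maps accumulate (OR) across dark and FF constants, restricted to
+        # the configured bad pixel subset; other constants overwrite their buffer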
+            my_buffer |= (
+                self._xp.asarray(
+                    constant_data.transpose(transpose_order),
+                    dtype=np.uint32,
+                )[:self._constant_memory_cells]
+                & self._xp.uint32(self.bad_pixel_subset)
             )
         else:
-            raise ValueError(f"Unhandled constant type {constant_type}")
+            my_buffer[:] = (
+                self._xp.asarray(
+                    self._xp.transpose(
+                        constant_data.astype(np.float32),
+                        transpose_order,
+                    )[:self._constant_memory_cells]
+                )
+            )
 
     def flush_buffers(self, constants):
         if Constants.Offset in constants:
@@ -173,10 +134,79 @@ class LpdGpuRunner(base_kernel_runner.BaseKernelRunner):
             self.rel_gain_slopes_map.fill(1)
         if Constants.FFMap in constants:
             self.flatfield_map.fill(1)
-        if constants & {Constants.BadPixelsDark, Constants.BadPixelsFF}:
+        if constants & bad_pixel_constants:
             self.bad_pixel_map.fill(0)
 
 
+class LpdGpuRunner(LpdBaseRunner):
+    def _pre_init(self):
+        import cupy
+        self._xp = cupy
+
+    def _post_init(self):
+        self.correction_kernel = self._xp.RawModule(
+            code=base_kernel_runner.get_kernel_template("lpd_gpu.cu").render(
+                {
+                    "ss_dim": self.num_pixels_ss,
+                    "fs_dim": self.num_pixels_fs,
+                    "output_data_dtype": utils.np_dtype_to_c_type(
+                        self._output_dtype
+                    ),
+                    "corr_enum": utils.enum_to_c_template(self._correction_flag_class),
+                }
+            )
+        ).get_function("correct")
+
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        num_frames = self._xp.uint16(image_data.shape[0])
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(
+            processed_data.shape, block_shape
+        )
+        self.correction_kernel(
+            grid_shape,
+            block_shape,
+            (
+                image_data,
+                cell_table,
+                flags,
+                num_frames,
+                # note: assuming the kernel takes this as uint16, like num_frames
+                self._xp.uint16(self._constant_memory_cells),
+                self.offset_map,
+                self.gain_amp_map,
+                self.rel_gain_slopes_map,
+                self.flatfield_map,
+                self.bad_pixel_map,
+                self.bad_pixel_mask_value,
+                gain_map,
+                processed_data,
+            ),
+        )
+
+
+class LpdCpuRunner(LpdBaseRunner):
+    _xp = np
+
+    def _post_init(self):
+        from ..kernels import lpd_cython
+        self.correction_kernel = lpd_cython.correct
+
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
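+        # single call into the Cython kernel, which parallelises over frames internally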
+        self.correction_kernel(
+            image_data,
+            cell_table,
+            flags,
+            self.offset_map,
+            self.gain_amp_map,
+            self.rel_gain_slopes_map,
+            self.flatfield_map,
+            self.bad_pixel_map,
+            self.bad_pixel_mask_value,
+            gain_map,
+            processed_data,
+        )
+
+
 class LpdCalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
 
@@ -294,216 +324,25 @@ class LpdCalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("LpdCorrection", deviceVersion)
 class LpdCorrection(base_correction.BaseCorrection):
-    _base_output_schema = schemas.xtdf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
-        ("gainAmp", CorrectionFlags.GAIN_AMP, {Constants.GainAmpMap}),
-        ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain}),
-        ("flatfield", CorrectionFlags.FF_CORR, {Constants.FFMap}),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            {
-                Constants.BadPixelsDark,
-                Constants.BadPixelsFF,
-            }
-        ),
-    )
+    _base_output_schema = schemas.xtdf_output_schema
+    _correction_steps = correction_steps
     _kernel_runner_class = LpdGpuRunner
     _calcat_friend_class = LpdCalcatFriend
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected", "outputGainMap"]
 
-    @staticmethod
-    def expectedParameters(expected):
-        (
-            OUTPUT_CHANNEL(expected)
-            .key("dataOutput")
-            .dataSchema(LpdCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(256)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(256)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(512)
-            .commit(),
-
-            # TODO: determine ideal default
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("frame")
-            .commit(),
-        )
+    _preview_outputs = [PreviewSpec(name) for name in ("raw", "corrected", "gainMap")]
 
-        # TODO: add bad pixel config node
+    @classmethod
+    def expectedParameters(cls, expected):
         LpdCalcatFriend.add_schema(expected)
-        base_correction.add_addon_nodes(expected, LpdCorrection)
-        base_correction.add_correction_step_schema(
-            expected, LpdCorrection._correction_steps
-        )
-        base_correction.add_preview_outputs(expected, LpdCorrection._preview_outputs)
-
-        # additional settings for correction steps
+        LpdBaseRunner.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
-            STRING_ELEMENT(expected)
-            .key("corrections.badPixels.maskingValue")
-            .tags("managed")
-            .displayedName("Bad pixel masking value")
-            .description(
-                "Any pixels masked by the bad pixel mask will have their value "
-                "replaced with this. Note that this parameter is to be interpreted as "
-                "a numpy.float32; use 'nan' to get NaN value."
+            OUTPUT_CHANNEL(expected)
+            .key("dataOutput")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
             )
-            .assignmentOptional()
-            .defaultValue("nan")
-            .reconfigurable()
             .commit(),
         )
-
-    @property
-    def input_data_shape(self):
-        return (
-            self.unsafe_get("dataFormat.frames"),
-            1,
-            self.unsafe_get("dataFormat.pixelsY"),
-            self.unsafe_get("dataFormat.pixelsX"),
-        )
-
-    def __init__(self, config):
-        super().__init__(config)
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-        if config.get("useShmemHandles", default=False):
-            def aux():
-                schema_override = Schema()
-                (
-                    OUTPUT_CHANNEL(schema_override)
-                    .key("dataOutput")
-                    .dataSchema(schemas.xtdf_output_schema(use_shmem_handle=False))
-                    .commit(),
-                )
-
-            self.registerInitialFunction(aux)
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        with self.warning_context(
-            "processingState",
-            base_correction.WarningLampType.CONSTANT_OPERATING_PARAMETERS,
-        ) as warn:
-            if (
-                cell_table_string := utils.cell_table_to_string(cell_table)
-            ) != self.unsafe_get(
-                "constantParameters.memoryCellOrder"
-            ) and self.unsafe_get(
-                "constantParameters.useMemoryCellOrder"
-            ):
-                warn(f"Cell order does not match input; input: {cell_table_string}")
-        if self._frame_filter is not None:
-            try:
-                cell_table = cell_table[self._frame_filter]
-                pulse_table = pulse_table[self._frame_filter]
-                image_data = image_data[self._frame_filter]
-            except IndexError:
-                self.log_status_warn(
-                    "Failed to apply frame filter, please check that it is valid!"
-                )
-                return
-
-        try:
-            # drop the singleton dimension for kernel runner
-            self.kernel_runner.load_data(image_data, cell_table)
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled)
-        for addon in self._enabled_addons:
-            addon.post_correction(
-                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
-            )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-            (
-                preview_raw,
-                preview_corrected,
-                preview_gain_map,
-            ) = self.kernel_runner.compute_previews(preview_slice_index)
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        data_hash.set(self._cell_table_path, cell_table[:, np.newaxis])
-        data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis])
-
-        self._write_output(data_hash, metadata)
-        self._preview_friend.write_outputs(
-            metadata, preview_raw, preview_corrected, preview_gain_map
-        )
-
-    def preReconfigure(self, config):
-        # TODO: DRY (taken from AGIPD device)
-        super().preReconfigure(config)
-        if config.has("corrections.badPixels.maskingValue"):
-            # only check if it is valid (let raise exception)
-            # if valid, postReconfigure will use it
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-
-    def postReconfigure(self):
-        super().postReconfigure()
-        if not hasattr(self, "_prereconfigure_update_hash"):
-            return
-
-        update = self._prereconfigure_update_hash
-        if update.has("corrections.badPixels.maskingValue"):
-            self.kernel_runner.bad_pixel_mask_value = self.bad_pixel_mask_value
diff --git a/src/calng/corrections/LpdminiCorrection.py b/src/calng/corrections/LpdminiCorrection.py
index f43650ad78577b9b9c0be8d85ae62116017f7302..4a5d5c24582427c429b923a73bf69bdf373a1998 100644
--- a/src/calng/corrections/LpdminiCorrection.py
+++ b/src/calng/corrections/LpdminiCorrection.py
@@ -6,13 +6,13 @@ from karabo.bound import (
 )
 
 from .._version import version as deviceVersion
-from ..base_correction import add_correction_step_schema
 from . import LpdCorrection
 
 
 class LpdminiGpuRunner(LpdCorrection.LpdGpuRunner):
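+    # an LPD Mini module is only 32 pixels along the slow-scan axis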
+    num_pixels_ss = 32
+
     def load_constant(self, constant_type, constant_data):
-        print(f"Given: {constant_type} with shape {constant_data.shape}")
         # constant type → transpose order
         constant_buffer_map = {
             LpdCorrection.Constants.Offset: self.offset_map,
@@ -74,20 +74,3 @@ class LpdminiCalcatFriend(LpdCorrection.LpdCalcatFriend):
 class LpdminiCorrection(LpdCorrection.LpdCorrection):
     _calcat_friend_class = LpdminiCalcatFriend
     _kernel_runner_class = LpdminiGpuRunner
-
-    @classmethod
-    def expectedParameters(cls, expected):
-        (
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(32)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(512)
-            .commit(),
-        )
-        cls._calcat_friend_class.add_schema(expected)
-        # warning: this is redundant, but needed for now to get managed keys working
-        add_correction_step_schema(expected, cls._correction_steps)
diff --git a/src/calng/corrections/PnccdCorrection.py b/src/calng/corrections/PnccdCorrection.py
index 0b941e2ec9fd244ade05ad2a7cc90963a477e82a..0156e2001508255245b8ca66490f07f01b759ec0 100644
--- a/src/calng/corrections/PnccdCorrection.py
+++ b/src/calng/corrections/PnccdCorrection.py
@@ -1,6 +1,5 @@
 import concurrent.futures
 import enum
-import functools
 
 import numpy as np
 from karabo.bound import (
@@ -9,7 +8,6 @@ from karabo.bound import (
     KARABO_CLASSINFO,
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
-    Schema,
 )
 
 from .. import (
@@ -19,6 +17,7 @@ from .. import (
     schemas,
     utils,
 )
+from ..preview_utils import PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 from ..kernels import common_cython
 
@@ -39,6 +38,14 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 2**4
 
 
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.OffsetCCD}),
+    ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainCCD}),
+    ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseCCD}),
+    ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkCCD}),
+)
+
+
 class PnccdCalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
 
@@ -51,30 +58,30 @@ class PnccdCalcatFriend(base_calcat.BaseCalcatFriend):
             Constants.RelativeGainCCD: self.illuminated_condition,
         }
 
-    @staticmethod
-    def add_schema(schema):
-        super(PnccdCalcatFriend, PnccdCalcatFriend).add_schema(schema, "pnCCD-Type")
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema, "pnCCD-Type")
 
         # set some defaults for common parameters
         (
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.memoryCells")
-            .setNewDefaultValue(1)
+            .key("constantParameters.pixelsX")
+            .setNewDefaultValue(1024)
             .commit(),
 
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.biasVoltage")
-            .setNewDefaultValue(300)
+            .key("constantParameters.pixelsY")
+            .setNewDefaultValue(1024)
             .commit(),
 
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.pixelsX")
-            .setNewDefaultValue(1024)
+            .key("constantParameters.memoryCells")
+            .setNewDefaultValue(1)
             .commit(),
 
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.pixelsY")
-            .setNewDefaultValue(1024)
+            .key("constantParameters.biasVoltage")
+            .setNewDefaultValue(300)
             .commit(),
         )
 
@@ -148,82 +155,104 @@ class PnccdCalcatFriend(base_calcat.BaseCalcatFriend):
 
 
 class PnccdCpuRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "xy"
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
+    _bad_pixel_constants = {Constants.BadPixelsDarkCCD}
     _xp = np
-    _gpu_based = False
 
-    @property
-    def input_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    num_pixels_ss = 1024
+    num_pixels_fs = 1024
 
-    @property
-    def preview_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    @classmethod
+    def add_schema(cls, schema):
+        super().add_schema(schema)
+        super().add_bad_pixel_config(schema)
+        (
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.noiseSigma")
+            .assignmentOptional()
+            .defaultValue(5)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.minFrac")
+            .assignmentOptional()
+            .defaultValue(0.25)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableRow")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableCol")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
+
+    def reconfigure(self, config):
+        super().reconfigure(config)
+        if config.has("corrections.commonMode.noiseSigma"):
+            self._cm_sigma = np.float32(config["corrections.commonMode.noiseSigma"])
+        if config.has("corrections.commonMode.minFrac"):
+            self._cm_minfrac = np.float32(config["corrections.commonMode.minFrac"])
+        if config.has("corrections.commonMode.enableRow"):
+            self._cm_row = config["corrections.commonMode.enableRow"]
+        if config.has("corrections.commonMode.enableCol"):
+            self._cm_col = config["corrections.commonMode.enableCol"]
+
+    def expected_input_shape(self, num_frames):
+        assert num_frames == 1, "pnCCD not expected to have multiple frames"
+        return (self.num_pixels_ss, self.num_pixels_fs)
+
+    def _expected_output_shape(self, num_frames):
+        return (self.num_pixels_ss, self.num_pixels_fs)
 
     @property
-    def map_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    def _map_shape(self):
+        return (self.num_pixels_ss, self.num_pixels_fs)
+
+    def _make_output_buffers(self, num_frames, flags):
+        # ignore parameters
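+        # the pnCCD output is always a single full 2D frame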
+        return [
+            self._xp.empty(
+                (self.num_pixels_ss, self.num_pixels_fs),
+                dtype=self._xp.float32,
+            )
+        ]
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,  # will be 1, will be ignored
-        constant_memory_cells,
-        input_data_dtype=np.uint16,
-        output_data_dtype=np.float32,
-    ):
+    def _post_init(self):
+        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
+
+    def _setup_constant_buffers(self):
         assert (
-            output_data_dtype == np.float32
+            self._output_dtype == np.float32
         ), "Alternative output types not supported yet"
 
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-
-        self.input_data = None
-        self.processed_data = np.empty(self.processed_shape, dtype=np.float32)
-
-        self.offset_map = np.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32)
-        self.noise_map = np.empty(self.map_shape, dtype=np.float32)
+        self.offset_map = np.empty(self._map_shape, dtype=np.float32)
+        self.rel_gain_map = np.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = np.empty(self._map_shape, dtype=np.uint32)
+        self.noise_map = np.empty(self._map_shape, dtype=np.float32)
         # will do everything by quadrant
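+        # per-quadrant views of the constant maps let the four quadrants be corrected in parallel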
-        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
-        self._q_input_data = None
-        self._q_processed_data = utils.quadrant_views(self.processed_data)
         self._q_offset_map = utils.quadrant_views(self.offset_map)
         self._q_rel_gain_map = utils.quadrant_views(self.rel_gain_map)
         self._q_bad_pixel_map = utils.quadrant_views(self.bad_pixel_map)
         self._q_noise_map = utils.quadrant_views(self.noise_map)
 
-    def __del__(self):
-        self.thread_pool.shutdown()
-
-    @property
-    def preview_data_views(self):
-        return (self.input_data, self.processed_data)
-
-    def load_data(self, image_data):
-        # should almost never squeeze, but ePix doesn't do burst mode, right?
-        self.input_data = image_data.astype(np.uint16, copy=False).squeeze()
-        self._q_input_data = utils.quadrant_views(self.input_data)
-
-    def load_constant(self, constant_type, data):
+    def _load_constant(self, constant_type, data):
         if constant_type is Constants.OffsetCCD:
             self.offset_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.RelativeGainCCD:
             self.rel_gain_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.BadPixelsDarkCCD:
-            self.bad_pixel_map[:] = data.squeeze()
+            self.bad_pixel_map[:] = data.squeeze() & self.bad_pixel_subset
         elif constant_type is Constants.NoiseCCD:
             self.noise_map[:] = data.squeeze()
         else:
@@ -241,123 +270,116 @@ class PnccdCpuRunner(base_kernel_runner.BaseKernelRunner):
 
     def _correct_quadrant(
         self,
-        q,
         flags,
-        bad_pixel_mask_value,
-        cm_noise_sigma,
-        cm_min_frac,
-        cm_row,
-        cm_col,
+        input_data,
+        offset_map,
+        noise_map,
+        bad_pixel_map,
+        rel_gain_map,
+        output,
     ):
-        output = self._q_processed_data[q]
-        output[:] = self._q_input_data[q].astype(np.float32)
+        output[:] = input_data.astype(np.float32)
         if flags & CorrectionFlags.OFFSET:
-            output -= self._q_offset_map[q]
+            output -= offset_map
 
         if flags & CorrectionFlags.COMMONMODE:
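+            # mask bad pixels and pixels well above the noise level (likely signal)
+            # for the common mode routines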
             cm_mask = (
-                (self._q_bad_pixel_map[q] != 0)
-                | (output > self._q_noise_map[q] * cm_noise_sigma)
+                (bad_pixel_map != 0)
+                | (output > noise_map * self._cm_sigma)
             ).astype(np.uint8, copy=False)
-            if cm_row:
+            if self._cm_row:
                 common_cython.cm_fs(
                     output,
                     cm_mask,
-                    cm_noise_sigma,
-                    cm_min_frac,
+                    self._cm_sigma,
+                    self._cm_minfrac,
                 )
 
-            if cm_col:
+            if self._cm_col:
                 common_cython.cm_ss(
                     output,
                     cm_mask,
-                    cm_noise_sigma,
-                    cm_min_frac,
+                    self._cm_sigma,
+                    self._cm_minfrac,
                 )
 
         if flags & CorrectionFlags.RELGAIN:
-            output *= self._q_rel_gain_map[q]
+            output *= rel_gain_map
 
         if flags & CorrectionFlags.BPMASK:
-            output[self._q_bad_pixel_map[q] != 0] = bad_pixel_mask_value
+            output[bad_pixel_map != 0] = self.bad_pixel_mask_value
+
+    def _correct(self, flags, image_data, cell_table, output):
+        # cell_table is going to be None for now, just ignore
+        # fan the four quadrant views out over the thread pool and wait for all of them
+        concurrent.futures.wait(
+            [
+                self.thread_pool.submit(self._correct_quadrant, flags, *parts)
+                for parts in zip(
+                    utils.quadrant_views(image_data),
+                    self._q_offset_map,
+                    self._q_noise_map,
+                    self._q_bad_pixel_map,
+                    self._q_rel_gain_map,
+                    utils.quadrant_views(output),
+                )
+            ]
+        )
 
-    def correct(
-        self,
-        flags,
-        bad_pixel_mask_value=np.nan,
-        cm_noise_sigma=5,
-        cm_min_frac=0.25,
-        cm_row=True,
-        cm_col=True,
-    ):
-        # NOTE: how to best clean up all these duplicated parameters?
-        for result in self.thread_pool.map(
-            functools.partial(
-                self._correct_quadrant,
-                flags=flags,
-                bad_pixel_mask_value=bad_pixel_mask_value,
-                cm_noise_sigma=cm_noise_sigma,
-                cm_min_frac=cm_min_frac,
-                cm_row=cm_row,
-                cm_col=cm_col,
-            ),
-            range(4),
-        ):
-            pass  # just run through to await map
+    def __del__(self):
+        self.thread_pool.shutdown()
+
+    def _preview_data_views(self, raw_data, processed_data):
+        return [
+            raw_data[np.newaxis],
+            processed_data[np.newaxis],
+        ]
 
 
 @KARABO_CLASSINFO("PnccdCorrection", deviceVersion)
 class PnccdCorrection(base_correction.BaseCorrection):
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.OffsetCCD}),
-        ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainCCD}),
-        ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseCCD}),
-        ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkCCD}),
-    )
-    _image_data_path = "data.image"
+    _base_output_schema = schemas.pnccd_output_schema
     _kernel_runner_class = PnccdCpuRunner
     _calcat_friend_class = PnccdCalcatFriend
+    _correction_steps = correction_steps
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected"]
+
+    _image_data_path = "data.image"
     _cell_table_path = None
     _pulse_table_path = None
-    _warn_memory_cell_range = False
 
-    @staticmethod
-    def expectedParameters(expected):
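+    # pnCCD previews are single 2D frames, so no frame-axis reduction is applied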
+    _preview_outputs = [
+        PreviewSpec(
+            name,
+            dimensions=2,
+            frame_reduction=False,
+        )
+        for name in ("raw", "corrected")
+    ]
+    _warn_memory_cell_range = False
+    _cuda_pin_buffers = False
+    _use_shmem_handles = False
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        cls._kernel_runner_class.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(schemas.pnccd_output_schema(use_shmem_handle=False))
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(1024)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(1024)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.outputAxisOrder")
-            .setNewOptions("xy,yx")
-            .setNewDefaultValue("xy")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             # TODO: disable preview selection mode
 
             OVERWRITE_ELEMENT(expected)
             .key("useShmemHandles")
-            .setNewDefaultValue(False)
+            .setNewDefaultValue(cls._use_shmem_handles)
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
@@ -366,113 +388,11 @@ class PnccdCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-        base_correction.add_preview_outputs(
-            expected, PnccdCorrection._preview_outputs
-        )
-        base_correction.add_correction_step_schema(
-            expected,
-            PnccdCorrection._correction_steps,
-        )
-        (
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.noiseSigma")
-            .assignmentOptional()
-            .defaultValue(5)
-            .reconfigurable()
-            .commit(),
-
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.minFrac")
-            .assignmentOptional()
-            .defaultValue(0.25)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableRow")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableCol")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-        )
-        PnccdCalcatFriend.add_schema(expected)
-        # TODO: bad pixel node?
-
-    @property
-    def input_data_shape(self):
-        # TODO: check
+    def _get_data_from_hash(self, data_hash):
+        image_data = data_hash.get(self._image_data_path)
         return (
-            self.unsafe_get("dataFormat.pixelsX"),
-            self.unsafe_get("dataFormat.pixelsY"),
-        )
-
-    def __init__(self, config):
-        super().__init__(config)
-        if config.get("useShmemHandles", default=False):
-            def aux():
-                schema_override = Schema()
-                (
-                    OUTPUT_CHANNEL(schema_override)
-                    .key("dataOutput")
-                    .dataSchema(schemas.pnccd_output_schema(use_shmem_handle=False))
-                    .commit(),
-                )
-                self.updateSchema(schema_override)
-
-            self.registerInitialFunction(aux)
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,  # will be None
-        pulse_table,  # ditto
-    ):
-        self.kernel_runner.load_data(image_data)
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        args_which_should_be_cached = dict(
-            cm_noise_sigma=self.unsafe_get("corrections.commonMode.noiseSigma"),
-            cm_min_frac=self.unsafe_get("corrections.commonMode.minFrac"),
-            cm_row=self.unsafe_get("corrections.commonMode.enableRow"),
-            cm_col=self.unsafe_get("corrections.commonMode.enableCol"),
+            1,
+            image_data,
+            None,
+            None,
         )
-        self.kernel_runner.correct(
-            flags=self._correction_flag_enabled, **args_which_should_be_cached
-        )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        if self._correction_flag_enabled != self._correction_flag_preview:
-            self.kernel_runner.correct(
-                flags=self._correction_flag_preview,
-                **args_which_should_be_cached,
-            )
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        # note: base class preview machinery assumes burst mode, shortcut it
-        self._preview_friend.write_outputs(
-            metadata, *self.kernel_runner.preview_data_views
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
diff --git a/src/calng/kernels/agipd_gpu.cu b/src/calng/kernels/agipd_gpu.cu
index ac5fc7435f82ae8889a9e6389352b884b5364f58..f638af17e5e8b23463effae43bc89887aeb4096b 100644
--- a/src/calng/kernels/agipd_gpu.cu
+++ b/src/calng/kernels/agipd_gpu.cu
@@ -13,6 +13,8 @@ extern "C" {
 	__global__ void correct(const unsigned short* data,
 	                        const unsigned short* cell_table,
 	                        const unsigned char corr_flags,
+	                        const unsigned short input_frames,
+	                        const unsigned short map_cells,
 	                        // default_gain can be 0, 1, or 2, and is relevant for fixed gain mode (no THRESHOLD)
 	                        const unsigned char default_gain,
 	                        const float* threshold_map,
@@ -27,71 +29,69 @@ extern "C" {
 	                        const float bad_pixel_mask_value,
 	                        float* gain_map, // TODO: more compact yet plottable representation
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_cells = {{constant_memory_cells}};
+		const size_t ss_dim = 512;
+		const size_t fs_dim = 128;
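+		// AGIPD module dimensions are fixed: 512 (ss) x 128 (fs) pixels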
 
-		const size_t cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t x = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t y = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t frame = blockIdx.x * blockDim.x + threadIdx.x;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (cell >= input_frames || y >= Y || x >= X) {
+		if (frame >= input_frames || fs >= fs_dim || ss >= ss_dim) {
 			return;
 		}
 
-		// data shape: memory cell, data/raw_gain (dim size 2), x, y
-		const size_t data_stride_y = 1;
-		const size_t data_stride_x = Y * data_stride_y;
-		const size_t data_stride_raw_gain = X * data_stride_x;
-		const size_t data_stride_cell = 2 * data_stride_raw_gain;
-		const size_t data_index = cell * data_stride_cell +
+		// data shape: frame, data/raw_gain (dim size 2), ss, fs
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_raw_gain = ss_dim * data_stride_ss;
+		const size_t data_stride_frame = 2 * data_stride_raw_gain;
+		const size_t data_index = frame * data_stride_frame +
 			0 * data_stride_raw_gain +
-			y * data_stride_y +
-			x * data_stride_x;
-		const size_t raw_gain_index = cell * data_stride_cell +
+			fs * data_stride_fs +
+			ss * data_stride_ss;
+		const size_t raw_gain_index = frame * data_stride_frame +
 			1 * data_stride_raw_gain +
-			y * data_stride_y +
-			x * data_stride_x;
+			fs * data_stride_fs +
+			ss * data_stride_ss;
 		float res = (float)data[data_index];
 		const float raw_gain_val = (float)data[raw_gain_index];
 
-		const size_t output_stride_y = 1;
-		const size_t output_stride_x = output_stride_y * Y;
-		const size_t output_stride_cell = output_stride_x * X;
-		const size_t output_index = cell * output_stride_cell + x * output_stride_x + y * output_stride_y;
+		const size_t output_stride_fs = 1;
+		const size_t output_stride_ss = output_stride_fs * fs_dim;
+		const size_t output_stride_frame = output_stride_ss * ss_dim;
+		const size_t output_index = frame * output_stride_frame + ss * output_stride_ss + fs * output_stride_fs;
 
-		// per-pixel only constant: cell, x, y
-		const size_t map_stride_y = 1;
-		const size_t map_stride_x = Y * map_stride_y;
-		const size_t map_stride_cell = X * map_stride_x;
+		// per-pixel only constant: cell, ss, fs
+		const size_t map_stride_fs = 1;
+		const size_t map_stride_ss = fs_dim * map_stride_fs;
+		const size_t map_stride_cell = ss_dim * map_stride_ss;
 
-		// threshold constant shape: cell, x, y, threshold (dim size 2)
+		// threshold constant shape: cell, ss, fs, threshold (dim size 2)
 		const size_t threshold_map_stride_threshold = 1;
-		const size_t threshold_map_stride_y = 2 * threshold_map_stride_threshold;
-		const size_t threshold_map_stride_x = Y * threshold_map_stride_y;
-		const size_t threshold_map_stride_cell = X * threshold_map_stride_x;
+		const size_t threshold_map_stride_fs = 2 * threshold_map_stride_threshold;
+		const size_t threshold_map_stride_ss = fs_dim * threshold_map_stride_fs;
+		const size_t threshold_map_stride_cell = ss_dim * threshold_map_stride_ss;
 
-		// gain mapped constant shape: cell, x, y, gain_level (dim size 3)
+		// gain mapped constant shape: cell, ss, fs, gain_level (dim size 3)
 		const size_t gm_map_stride_gain = 1;
-		const size_t gm_map_stride_y = 3 * gm_map_stride_gain;
-		const size_t gm_map_stride_x = Y * gm_map_stride_y;
-		const size_t gm_map_stride_cell = X * gm_map_stride_x;
-		// note: assuming all maps have same shape (in terms of cells / x / y)
+		const size_t gm_map_stride_fs = 3 * gm_map_stride_gain;
+		const size_t gm_map_stride_ss = fs_dim * gm_map_stride_fs;
+		const size_t gm_map_stride_cell = ss_dim * gm_map_stride_ss;
+		// note: assuming all maps have same shape (in terms of cells / ss / fs)
 
-		const size_t map_cell = cell_table[cell];
+		const size_t map_cell = cell_table[frame];
 
 		if (map_cell < map_cells) {
 			unsigned char gain = default_gain;
 			if (corr_flags & THRESHOLD) {
 				const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold +
 				                                        map_cell * threshold_map_stride_cell +
-				                                        y * threshold_map_stride_y +
-				                                        x * threshold_map_stride_x];
+				                                        fs * threshold_map_stride_fs +
+				                                        ss * threshold_map_stride_ss];
 				const float threshold_1 = threshold_map[1 * threshold_map_stride_threshold +
 				                                        map_cell * threshold_map_stride_cell +
-				                                        y * threshold_map_stride_y +
-				                                        x * threshold_map_stride_x];
+				                                        fs * threshold_map_stride_fs +
+				                                        ss * threshold_map_stride_ss];
 				// could consider making this const using ternaries / tiny function
 				if (raw_gain_val <= threshold_0) {
 					gain = 0;
@@ -103,8 +103,8 @@ extern "C" {
 			}
 
 			const size_t gm_map_index_without_gain = map_cell * gm_map_stride_cell +
-				y * gm_map_stride_y +
-				x * gm_map_stride_x;
+				fs * gm_map_stride_fs +
+				ss * gm_map_stride_ss;
 			if ((corr_flags & FORCE_MG_IF_BELOW) && (gain == 2) &&
 			    (res - offset_map[gm_map_index_without_gain + 1 * gm_map_stride_gain] < mg_hard_threshold)) {
 				gain = 1;
@@ -116,8 +116,8 @@ extern "C" {
 			gain_map[output_index] = (float)gain;
 
 			const size_t map_index = map_cell * map_stride_cell +
-				y * map_stride_y +
-				x * map_stride_x;
+				fs * map_stride_fs +
+				ss * map_stride_ss;
 
 			const size_t gm_map_index = gm_map_index_without_gain + gain * gm_map_stride_gain;
 
diff --git a/src/calng/kernels/common_gpu.cu b/src/calng/kernels/common_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..631ab7dd5d1cb966ee736bc39c8cae57ef6fd443
--- /dev/null
+++ b/src/calng/kernels/common_gpu.cu
@@ -0,0 +1,60 @@
+// https://excalidraw.com/#json=mgBWynet5WUbpgfSd0pzC,TMhzt6L-x5yShNCHRGKLvg
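+// Iterative per-ASIC common mode correction: each block handles one ASIC of one
+// frame, averages its "dark" pixels (finite values within +-noise_peak_range) via
+// a shared memory reduction, and subtracts that mean, provided at least
+// min_dark_fraction of the ASIC's pixels counted as dark.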
+extern "C" __global__ void common_mode_asic(
+    float* const data,
+    const unsigned short n_frames,
+    const unsigned short num_iter,
+    const float min_dark_fraction,
+    const float noise_peak_range
+) {
+    const int frame = blockIdx.z;
+    const int ss_asic = blockIdx.y;
+    const int fs_asic = blockIdx.x;
+    const int thread_in_asic = threadIdx.x;
+    assert(blockDim.x == {{asic_dim}});
+
+    __shared__ float shared_sum[{{asic_dim}}];
+    __shared__ unsigned int shared_count[{{asic_dim}}];
+
+    const unsigned int min_dark_pixels = static_cast<int>(min_dark_fraction * static_cast<float>({{asic_dim}} * {{asic_dim}}));
+    float* const asic_start = data + frame * {{ss_dim}} * {{fs_dim}} + ss_asic * {{asic_dim}} * {{fs_dim}} + fs_asic * {{asic_dim}}; 
+
+    for (int iter=0; iter<num_iter; ++iter) {
+        // accumulate (over rows) to shared
+        unsigned int my_count = 0;
+        float my_sum = 0;
+        for (int row=0; row<{{asic_dim}}; ++row) {
+            const float pixel = asic_start[row * {{fs_dim}} + thread_in_asic];
+            if (!isfinite(pixel) || fabs(pixel) > noise_peak_range) {
+                continue;
+            }
+            my_sum += pixel;
+            ++my_count;
+        }
+        shared_sum[thread_in_asic] = my_sum;
+        shared_count[thread_in_asic] = my_count;
+        __syncthreads();
+
+        // reduce (per ASIC) in shared
+        int active_threads = {{asic_dim}};
+        while (active_threads > 1) {
+            active_threads /= 2;
+            const int other = thread_in_asic + active_threads;
+            if (thread_in_asic < active_threads) {
+                shared_sum[thread_in_asic] += shared_sum[other];
+                shared_count[thread_in_asic] += shared_count[other];
+            }
+            __syncthreads();
+        }
+
+        // now, each ASIC result should be in first index of shared by ASIC
+        const unsigned int dark_count = shared_count[0];
+        if (dark_count < min_dark_pixels) {
+            return;
+        }
+        const float dark_sum = shared_sum[0];
+        const float dark_mean = dark_sum / static_cast<float>(dark_count);
+        for (int row=0; row<{{asic_dim}}; ++row) {
+            asic_start[row * {{fs_dim}} + thread_in_asic] -= dark_mean;
+        }
+        // make sure every thread has read this iteration's shared results before
+        // the next iteration overwrites them
+        __syncthreads();
+    }
+}
diff --git a/src/calng/kernels/dssc_cpu.pyx b/src/calng/kernels/dssc_cpu.pyx
index 04803d3f0a28a0a36721504c5e97177f609639b8..613ef30755c3b09dae5cac23cc6435e1a5e0b82d 100644
--- a/src/calng/kernels/dssc_cpu.pyx
+++ b/src/calng/kernels/dssc_cpu.pyx
@@ -8,7 +8,7 @@ cdef unsigned char OFFSET = 1
 from cython.parallel import prange
 
 def correct(
-    unsigned short[:, :, :] image_data,
+    unsigned short[:, :, :, :] image_data,
     unsigned short[:] cell_table,
     unsigned char flags,
     float[:, :, :] offset_map,
@@ -19,13 +19,13 @@ def correct(
     for frame in prange(image_data.shape[0], nogil=True):
         map_cell = cell_table[frame]
         if map_cell >= offset_map.shape[0]:
-            for x in range(image_data.shape[1]):
-                for y in range(image_data.shape[2]):
-                    output[frame, x, y] = <float>image_data[frame, x, y]
+            for x in range(image_data.shape[2]):
+                for y in range(image_data.shape[3]):
+                    output[frame, x, y] = <float>image_data[frame, 0, x, y]
             continue
-        for x in range(image_data.shape[1]):
-            for y in range(image_data.shape[2]):
-                res = image_data[frame, x, y]
+        for x in range(image_data.shape[2]):
+            for y in range(image_data.shape[3]):
+                res = image_data[frame, 0, x, y]
                 if flags & OFFSET:
                     res = res - offset_map[map_cell, x, y]
                 output[frame, x, y] = res
diff --git a/src/calng/kernels/dssc_gpu.cu b/src/calng/kernels/dssc_gpu.cu
index 4e21b02b4ade95779e1a31754c0ba8a54be280ae..6ebe7b9ae5f04941a175dc05b4266eabf36220ad 100644
--- a/src/calng/kernels/dssc_gpu.cu
+++ b/src/calng/kernels/dssc_gpu.cu
@@ -8,41 +8,41 @@ extern "C" {
 	  Take cell_table into account when getting correction values
 	  Converting to float while correcting
 	  Converting to output dtype at the end
-	  Shape of input data: memory cell, 1, y, x
-	  Shape of offset constant: x, y, memory cell
+	  Shape of input data: memory cell, 1, ss, fs
+	  Shape of offset constant: fs, ss, memory cell
 	*/
-	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, y, x
+	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, ss, fs
 	                        const unsigned short* cell_table,
+	                        const unsigned short input_frames,
+	                        const unsigned short map_memory_cells,
 	                        const unsigned char corr_flags,
 	                        const float* offset_map,
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_memory_cells = {{constant_memory_cells}};
+		const size_t ss_dim = 128;
+		const size_t fs_dim = 512;
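+		// DSSC module dimensions are fixed: 128 (ss) x 512 (fs) pixels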
 
 		const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (memory_cell >= input_frames || y >= Y || x >= X) {
+		if (memory_cell >= input_frames || ss >= ss_dim || fs >= fs_dim) {
 			return;
 		}
 
 		// note: strides differ from numpy strides because unit here is sizeof(...), not byte
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_cell = Y * data_stride_y;
-		const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x;
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_cell = ss_dim * data_stride_ss;
+		const size_t data_index = memory_cell * data_stride_cell + ss * data_stride_ss + fs * data_stride_fs;
 		float res = (float)data[data_index];
 
-		const size_t map_stride_x = 1;
-		const size_t map_stride_y = X * map_stride_x;
-		const size_t map_stride_cell = Y * map_stride_y;
+		const size_t map_stride_fs = 1;
+		const size_t map_stride_ss = fs_dim * map_stride_fs;
+		const size_t map_stride_cell = ss_dim * map_stride_ss;
 		const size_t map_cell = cell_table[memory_cell];
 
 		if (map_cell < map_memory_cells) {
-			const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x;
+			const size_t map_index = map_cell * map_stride_cell + ss * map_stride_ss + fs * map_stride_fs;
 			if (corr_flags & OFFSET) {
 				res -= offset_map[map_index];
 			}
diff --git a/src/calng/kernels/jungfrau_cpu.pyx b/src/calng/kernels/jungfrau_cpu.pyx
index 2a200063879007b54f92c762378863c0d3e71c22..5ff3148c779dc3395fb0472b444c0f08abad0157 100644
--- a/src/calng/kernels/jungfrau_cpu.pyx
+++ b/src/calng/kernels/jungfrau_cpu.pyx
@@ -24,20 +24,20 @@ def correct_burst(
     float badpixel_fill_value,
     float[:, :, :] output,
 ):
-    cdef int frame, map_cell, x, y
+    cdef int frame, map_cell, fs, ss
     cdef unsigned char gain
     cdef float corrected
     for frame in prange(image_data.shape[0], nogil=True):
         map_cell = cell_table[frame]
         if map_cell >= offset_map.shape[0]:
-            for y in range(image_data.shape[1]):
-                for x in range(image_data.shape[2]):
-                    output[frame, y, x] = <float>image_data[frame, y, x]
+            for ss in range(image_data.shape[1]):
+                for fs in range(image_data.shape[2]):
+                    output[frame, ss, fs] = <float>image_data[frame, ss, fs]
             continue
-        for y in range(image_data.shape[1]):
-            for x in range(image_data.shape[2]):
-                corrected = image_data[frame, y, x]
-                gain = gain_stage[frame, y, x]
+        for ss in range(image_data.shape[1]):
+            for fs in range(image_data.shape[2]):
+                corrected = image_data[frame, ss, fs]
+                gain = gain_stage[frame, ss, fs]
                 # legal values: 0, 1, or 3
                 if gain == 2:
                     corrected = badpixel_fill_value
@@ -45,14 +45,14 @@ def correct_burst(
                     if gain == 3:
                         gain = 2
 
-                    if (flags & BPMASK) and badpixel_mask[map_cell, y, x, gain] != 0:
+                    if (flags & BPMASK) and badpixel_mask[map_cell, ss, fs, gain] != 0:
                         corrected = badpixel_fill_value
                     else:
                         if (flags & OFFSET):
-                            corrected = corrected - offset_map[map_cell, y, x, gain]
+                            corrected = corrected - offset_map[map_cell, ss, fs, gain]
                         if (flags & REL_GAIN):
-                            corrected = corrected / relgain_map[map_cell, y, x, gain]
-                output[frame, y, x] = corrected
+                            corrected = corrected / relgain_map[map_cell, ss, fs, gain]
+                output[frame, ss, fs] = corrected
 
 
 def correct_single(
@@ -66,13 +66,13 @@ def correct_single(
     float[:, :, :] output,
 ):
     # ignore "cell table", constant pretends cell 0
-    cdef int x, y
+    cdef int fs, ss
     cdef unsigned char gain
     cdef float corrected
-    for y in range(image_data.shape[1]):
-        for x in range(image_data.shape[2]):
-            corrected = image_data[0, y, x]
-            gain = gain_stage[0, y, x]
+    for ss in range(image_data.shape[1]):
+        for fs in range(image_data.shape[2]):
+            corrected = image_data[0, ss, fs]
+            gain = gain_stage[0, ss, fs]
             # legal values: 0, 1, or 3
             if gain == 2:
                 corrected = badpixel_fill_value
@@ -80,11 +80,11 @@ def correct_single(
                 if gain == 3:
                     gain = 2
 
-                if (flags & BPMASK) and badpixel_mask[0, y, x, gain] != 0:
+                if (flags & BPMASK) and badpixel_mask[0, ss, fs, gain] != 0:
                     corrected = badpixel_fill_value
                 else:
                     if (flags & OFFSET):
-                        corrected = corrected - offset_map[0, y, x, gain]
+                        corrected = corrected - offset_map[0, ss, fs, gain]
                     if (flags & REL_GAIN):
-                        corrected = corrected / relgain_map[0, y, x, gain]
-            output[0, y, x] = corrected
+                        corrected = corrected / relgain_map[0, ss, fs, gain]
+            output[0, ss, fs] = corrected
diff --git a/src/calng/kernels/jungfrau_gpu.cu b/src/calng/kernels/jungfrau_gpu.cu
index a4f148f1c49d172b31f91a6a1a4ae4ec13c61c49..57a2d2a24a6adeaaeb5e4b04032782fb5f83502a 100644
--- a/src/calng/kernels/jungfrau_gpu.cu
+++ b/src/calng/kernels/jungfrau_gpu.cu
@@ -7,50 +7,45 @@ extern "C" {
 	                        const unsigned char* gain_stage, // same shape
 	                        const unsigned char* cell_table,
 	                        const unsigned char corr_flags,
+	                        const unsigned short input_frames,
+	                        const unsigned short map_memory_cells,
+	                        const unsigned char burst_mode,
 	                        const float* offset_map,
 	                        const float* rel_gain_map,
 	                        const unsigned int* bad_pixel_map,
 	                        const float bad_pixel_mask_value,
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_memory_cells = {{constant_memory_cells}};
+		const size_t ss_dim = 512;
+		const size_t fs_dim = 1024;
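+		// JUNGFRAU module dimensions are fixed: 512 (ss) x 1024 (fs) pixels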
 
 		const size_t current_frame = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (current_frame >= input_frames || y >= Y || x >= X) {
+		if (current_frame >= input_frames || ss >= ss_dim || fs >= fs_dim) {
 			return;
 		}
 
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_frame = Y * data_stride_y;
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_frame = ss_dim * data_stride_ss;
 		const size_t data_index = current_frame * data_stride_frame +
-			y * data_stride_y +
-			x * data_stride_x;
+			ss * data_stride_ss +
+			fs * data_stride_fs;
 		float res = (float)data[data_index];
 
 		// gain mapped constant shape: cell, y, x, gain_level (dim size 3)
 		// note: in fixed gain mode, constant still provides data for three stages
 		const size_t map_stride_gain = 1;
-		const size_t map_stride_x = 3 * map_stride_gain;
-		const size_t map_stride_y = X * map_stride_x;
-		const size_t map_stride_cell = Y * map_stride_y;
+		const size_t map_stride_fs = 3 * map_stride_gain;
+		const size_t map_stride_ss = fs_dim * map_stride_fs;
+		const size_t map_stride_cell = ss_dim * map_stride_ss;
 
 		// TODO: warn user about cell_table value of 255 in either mode
 		// note: cell table may contain 255 if data didn't arrive
-		{% if burst_mode %}
-		// burst mode: "cell 255" will get copied
-		// TODO: consider masking "cell 255"
-		const size_t map_cell = cell_table[current_frame];
-		{% else %}
+		// burst mode: "cell 255" will get copied (not corrected)
 		// single cell: "cell 255" will get "corrected"
-		const size_t map_cell = 0;
-		{% endif %}
-
+		const size_t map_cell = burst_mode ? cell_table[current_frame] : 0;
 		if (map_cell < map_memory_cells) {
 			unsigned char gain = gain_stage[data_index];
 			if (gain == 2) {
@@ -58,20 +53,24 @@ extern "C" {
 				res = bad_pixel_mask_value;
 			} else {
 				if (gain == 3) {
+					// raw gain stage 3 maps to index 2 in the constant maps
 					gain = 2;
 				}
-				const size_t map_index = map_cell * map_stride_cell +
-					y * map_stride_y +
-					x * map_stride_x +
-					gain * map_stride_gain;
-				if ((corr_flags & BPMASK) && bad_pixel_map[map_index]) {
-					res = bad_pixel_mask_value;
-				} else {
-					if (corr_flags & OFFSET) {
-						res -= offset_map[map_index];
-					}
-					if (corr_flags & REL_GAIN) {
-						res /= rel_gain_map[map_index];
+				if (map_cell < map_memory_cells) {
+					// only correct if in range of constant data
+					const size_t map_index = map_cell * map_stride_cell +
+						ss * map_stride_ss +
+						fs * map_stride_fs +
+						gain * map_stride_gain;
+					if ((corr_flags & BPMASK) && bad_pixel_map[map_index]) {
+						res = bad_pixel_mask_value;
+					} else {
+						if (corr_flags & OFFSET) {
+							res -= offset_map[map_index];
+						}
+						if (corr_flags & REL_GAIN) {
+							res /= rel_gain_map[map_index];
+						}
 					}
 				}
 			}
diff --git a/src/calng/kernels/lpd_cpu.pyx b/src/calng/kernels/lpd_cpu.pyx
index 0c3eb82d50c3e23d1414515fd967347ffe4a0876..3cd878fb847e7e1978620c63c3f8899dc3cec55f 100644
--- a/src/calng/kernels/lpd_cpu.pyx
+++ b/src/calng/kernels/lpd_cpu.pyx
@@ -12,6 +12,7 @@ cdef unsigned char BPMASK = 16
 from cython.parallel import prange
 from libc.math cimport isinf, isnan
 
+
 def correct(
     unsigned short[:, :, :, :] image_data,
     unsigned short[:] cell_table,
@@ -22,7 +23,7 @@ def correct(
     float[:, :, :, :] flatfield_map,
     unsigned[:, :, :, :] bad_pixel_map,
     float bad_pixel_mask_value,
-    # TODO: support spitting out gain map for preview purposes
+    float[:, :, :] gain_map,
     float[:, :, :] output
 ):
     cdef int frame, map_cell, ss, fs
@@ -31,25 +32,41 @@ def correct(
     cdef unsigned short raw_data_value
     for frame in prange(image_data.shape[0], nogil=True):
         map_cell = cell_table[frame]
-        if map_cell >= offset_map.shape[0]:
-            for ss in range(image_data.shape[1]):
-                for fs in range(image_data.shape[2]):
-                    output[frame, ss, fs] = <float>image_data[frame, 0, ss, fs]
-            continue
-        for ss in range(image_data.shape[1]):
-            for fs in range(image_data.shape[3]):
-                raw_data_value = image_data[frame, 0, ss, fs]
-        gain_stage = (raw_data_value >> 12) & 0x0003
-        res = <float>(raw_data_value & 0x0fff)
-        if gain_stage > 2 or (flags & BPMASK and bad_pixel_map[map_cell, ss, fs, gain_stage] != 0):
-            res = bad_pixel_mask_value
+        if map_cell < offset_map.shape[0]:
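+            # constants cover this memory cell: decode the 12-bit ADC value and
+            # 2-bit gain stage, then apply the enabled corrections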
+            for ss in range(image_data.shape[2]):
+                for fs in range(image_data.shape[3]):
+                    raw_data_value = image_data[frame, 0, ss, fs]
+                    gain_stage = (raw_data_value >> 12) & 0x0003
+                    res = <float>(raw_data_value & 0x0fff)
+                    if (
+                        gain_stage > 2
+                        or (
+                            flags & BPMASK
+                            and bad_pixel_map[map_cell, ss, fs, gain_stage] != 0
+                        )
+                    ):
+                        res = bad_pixel_mask_value
+                    else:
+                        if flags & OFFSET:
+                            res = res - offset_map[map_cell, ss, fs, gain_stage]
+                        if flags & GAIN_AMP:
+                            res = res * gain_amp_map[map_cell, ss, fs, gain_stage]
+                        if flags & REL_GAIN:
+                            res = res * rel_gain_slopes_map[map_cell, ss, fs, gain_stage]
+                        if flags & FF_CORR:
+                            res = res * flatfield_map[map_cell, ss, fs, gain_stage]
+                        if res < -1e7 or res > 1e7 or isnan(res) or isinf(res):
+                            res = bad_pixel_mask_value
+                    output[frame, ss, fs] = res
+                    gain_map[frame, ss, fs] = <float>gain_stage
         else:
-            if flags & OFFSET:
-                res = res - offset_map[map_cell, ss, fs, gain_stage]
-            if flags & GAIN_AMP:
-                res = res * gain_amp_map[map_cell, ss, fs, gain_stage]
-            if flags & FF_CORR:
-                res = res * rel_gain_slopes_map[map_cell, ss, fs, gain_stage]
-            if res < 1e-7 or res > 1e7 or isnan(res) or isinf(res):
-                res = bad_pixel_mask_value
-        output[frame, ss, fs] = res
+            for ss in range(image_data.shape[2]):
+                for fs in range(image_data.shape[3]):
+                    raw_data_value = image_data[frame, 0, ss, fs]
+                    gain_stage = (raw_data_value >> 12) & 0x0003
+                    res = <float>(raw_data_value & 0x0fff)
+                    if (
+                        gain_stage > 2
+                        or res < -1e7 or res > 1e7
+                    ):
+                        res = bad_pixel_mask_value
+                    output[frame, ss, fs] = res
+                    gain_map[frame, ss, fs] = <float>gain_stage
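
For reference, the per-pixel unpacking in the kernel above (gain stage in bits 12-13, 12-bit ADC value in bits 0-11 of each 16-bit word) can be sketched in plain NumPy; shapes below are arbitrary and not tied to LPD geometry:

    import numpy as np

    # vectorized equivalent of the per-pixel unpacking in the kernel above
    raw = np.random.randint(0, 1 << 14, size=(16, 256, 256), dtype=np.uint16)
    gain_stage = (raw >> 12) & 0x0003   # 0, 1, 2 are valid; 3 is masked as bad
    adc = (raw & 0x0FFF).astype(np.float32)
    assert gain_stage.max() <= 3 and adc.max() < 4096
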
diff --git a/src/calng/kernels/lpd_gpu.cu b/src/calng/kernels/lpd_gpu.cu
index 4e98cc851e836d83edffff1b449087dce8c3f190..7e935a38963af4c27038e7f40ee75ca9e92dcfa8 100644
--- a/src/calng/kernels/lpd_gpu.cu
+++ b/src/calng/kernels/lpd_gpu.cu
@@ -3,10 +3,12 @@
 {{corr_enum}}
 
 extern "C" {
-	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, y, x (I think)
+	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, ss, fs
 	                        const unsigned short* cell_table,
 	                        const unsigned char corr_flags,
-	                        const float* offset_map, // shape: cell, y, x, gain
+	                        const unsigned short input_frames,
+	                        const unsigned short map_memory_cells,
+	                        const float* offset_map, // shape: cell, ss, fs, gain
 	                        const float* gain_amp_map,
 	                        const float* rel_gain_slopes_map,
 	                        const float* flatfield_map,
@@ -14,39 +16,37 @@ extern "C" {
 	                        const float bad_pixel_mask_value,
 	                        float* gain_map, // similar to preview for AGIPD
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_memory_cells = {{constant_memory_cells}};
+		const size_t fs_dim = {{fs_dim}};
+		const size_t ss_dim = {{ss_dim}};
 
-		const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t frame = blockIdx.x * blockDim.x + threadIdx.x;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (memory_cell >= input_frames || y >= Y || x >= X) {
+		if (frame >= input_frames || ss >= ss_dim || fs >= fs_dim) {
 			return;
 		}
 
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_cell = Y * data_stride_y;
-		const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x;
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_frame = ss_dim * data_stride_ss;
+		const size_t data_index = frame * data_stride_frame + ss * data_stride_ss + fs * data_stride_fs;
 		const unsigned short raw_data_value = data[data_index];
 		const unsigned char gain = (raw_data_value >> 12) & 0x0003;
 		float corrected = (float)(raw_data_value & 0x0fff);
 		float gain_for_preview = (float)gain;
 
 		const size_t gm_map_stride_gain = 1;
-		const size_t gm_map_stride_x = 3 * gm_map_stride_gain;
-		const size_t gm_map_stride_y = X * gm_map_stride_x;
-		const size_t gm_map_stride_cell = Y * gm_map_stride_y;
+		const size_t gm_map_stride_fs = 3 * gm_map_stride_gain;
+		const size_t gm_map_stride_ss = fs_dim * gm_map_stride_fs;
+		const size_t gm_map_stride_cell = ss_dim * gm_map_stride_ss;
 
-		const size_t map_cell = cell_table[memory_cell];
+		const size_t map_cell = cell_table[frame];
 		if (map_cell < map_memory_cells) {
 			const size_t gm_map_index = gain * gm_map_stride_gain +
 				map_cell * gm_map_stride_cell +
-				y * gm_map_stride_y +
-				x * gm_map_stride_x;
+				ss * gm_map_stride_ss +
+				fs * gm_map_stride_fs;
 
 			if (gain > 2 || ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index])) {
 				// now also checking for illegal gain value
@@ -70,7 +70,13 @@ extern "C" {
 					corrected = bad_pixel_mask_value;
 				}
 			}
+		} else {
+			if (gain > 2 || corrected < -1e7 || corrected > 1e7) {
+				corrected = bad_pixel_mask_value;
+				gain_for_preview = bad_pixel_mask_value;
+			}
 		}
+
 		gain_map[data_index] = gain_for_preview;
 		{% if output_data_dtype == "half" %}
 		output[data_index] = __float2half(corrected);
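
The stride arithmetic in the kernel above reproduces plain C-order flat indexing of the (frame, ss, fs) buffer (the singleton axis in the input shape does not affect the strides). A quick NumPy cross-check, purely illustrative:

    import numpy as np

    frames, ss_dim, fs_dim = 4, 32, 64
    frame, ss, fs = 2, 7, 13
    stride_fs = 1
    stride_ss = fs_dim * stride_fs
    stride_frame = ss_dim * stride_ss
    flat = frame * stride_frame + ss * stride_ss + fs * stride_fs
    # same result as numpy's C-order ravelling of the multi-index
    assert flat == np.ravel_multi_index((frame, ss, fs), (frames, ss_dim, fs_dim))
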
diff --git a/src/calng/preview_utils.py b/src/calng/preview_utils.py
index 55e907423df5526e56a28a00bcf7eb4f88eb5671..12cb7e183a13133824a7964b852b84a524bf9e8b 100644
--- a/src/calng/preview_utils.py
+++ b/src/calng/preview_utils.py
@@ -1,81 +1,148 @@
+import enum
+import functools
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
 from karabo.bound import (
     BOOL_ELEMENT,
     FLOAT_ELEMENT,
+    INT32_ELEMENT,
     NODE_ELEMENT,
     OUTPUT_CHANNEL,
     STRING_ELEMENT,
     UINT32_ELEMENT,
-    ChannelMetaData,
     Dims,
     Encoding,
     Hash,
     ImageData,
+    Schema,
     Timestamp,
 )
 
 import numpy as np
 
-from . import schemas, utils
+from . import schemas
+from .utils import downsample_1d, downsample_2d, maybe_get
+
+
+class FrameSelectionMode(enum.Enum):
+    FRAME = "frame"
+    CELL = "cell"
+    PULSE = "pulse"
+
+
+@dataclass
+class PreviewSpec:
+    # used for output schema
+    name: str
+    channel_name: Optional[str] = None  # set after init
+    dimensions: int = 2
+    wrap_in_imagedata: bool = True
+
+    # reconfigurables in node schema
+    swap_axes: bool = False  # only appears in schema if dimensions > 1
+    flip_ss: bool = False  # only appears in schema if dimensions > 1
+    flip_fs: bool = False
+    nan_replacement: float = 0
+    downsampling_factor: int = 1
+    downsampling_function: Callable = np.nanmax
+    # if frame_reduction, frame_reduction_fun is made automatically from index and mode
+    # otherwise, leave None or provide custom
+    frame_reduction: bool = True
+    frame_reduction_selection_mode: FrameSelectionMode = FrameSelectionMode.FRAME
+    # not all detectors provide cell and / or pulse IDs
+    valid_frame_selection_modes: List[FrameSelectionMode] = tuple(FrameSelectionMode)
+    frame_reduction_index: int = 0
+    # either set explicitly or automatically with frame selection
+    frame_reduction_fun: Optional[Callable] = None
 
 
 class PreviewFriend:
     @staticmethod
     def add_schema(
-        schema, node_path="preview", output_channels=None, create_node=False
+        schema: Schema,
+        outputs: List[PreviewSpec],
+        node_name: str = "preview"
     ):
-        if output_channels is None:
-            output_channels = ["output"]
+        """Add the prerequisite schema configurables for all the outputs."""
+        (
+            NODE_ELEMENT(schema)
+            .key(node_name)
+            .displayedName("Preview")
+            .description(
+                "Output specifically intended for preview in Karabo GUI. Includes "
+                "some options for throttling and adjustments of the output data."
+            )
+            .commit(),
+        )
+
+        for spec in outputs:
+            PreviewFriend._add_subschema(schema, f"{node_name}.{spec.name}", spec)
+
+    @staticmethod
+    def _add_subschema(schema, node_path, spec):
+        # each preview gets its own node with config and channel
+        (
+            NODE_ELEMENT(schema)
+            .key(node_path)
+            .commit(),
+        )
 
-        if create_node:
+        if spec.dimensions > 1:
             (
-                NODE_ELEMENT(schema)
-                .key(node_path)
-                .displayedName("Preview")
+                BOOL_ELEMENT(schema)
+                .key(f"{node_path}.swapAxes")
+                .displayedName("Swap slow / fast scan")
                 .description(
-                    "Output specifically intended for preview in Karabo GUI. Includes "
-                    "some options for throttling and adjustments of the output data."
+                    "Swaps the two pixel axes around. Can be combined with "
+                    "flipping to rotate the image"
                 )
+                .assignmentOptional()
+                .defaultValue(spec.swap_axes)
+                .reconfigurable()
+                .commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_path}.flipSS")
+                .displayedName("Flip SS")
+                .description("Flip image data along slow scan axis.")
+                .assignmentOptional()
+                .defaultValue(spec.flip_ss)
+                .reconfigurable()
                 .commit(),
             )
-        (
-            BOOL_ELEMENT(schema)
-            .key(f"{node_path}.flipSS")
-            .displayedName("Flip SS")
-            .description("Flip image data along slow scan axis.")
-            .assignmentOptional()
-            .defaultValue(False)
-            .reconfigurable()
-            .commit(),
 
+        # configurable settings; these mirror fields of the PreviewSpec
+        (
             BOOL_ELEMENT(schema)
             .key(f"{node_path}.flipFS")
             .displayedName("Flip FS")
             .description("Flip image data along fast scan axis.")
             .assignmentOptional()
-            .defaultValue(False)
+            .defaultValue(spec.flip_fs)
             .reconfigurable()
             .commit(),
 
             UINT32_ELEMENT(schema)
             .key(f"{node_path}.downsamplingFactor")
-            .displayedName("Factor")
+            .displayedName("Downsampling factor")
             .description(
                 "If greater than 1, the preview image will be downsampled by this "
                 "factor before sending. This is mostly to save bandwidth in case GUI "
                 "updates start lagging."
             )
             .assignmentOptional()
-            .defaultValue(1)
+            .defaultValue(spec.downsampling_factor)
             .options("1,2,4,8")
             .reconfigurable()
             .commit(),
 
             STRING_ELEMENT(schema)
             .key(f"{node_path}.downsamplingFunction")
-            .displayedName("Function")
+            .displayedName("Downsampling function")
             .description("Reduction function used during downsampling.")
             .assignmentOptional()
-            .defaultValue("nanmax")
+            .defaultValue(spec.downsampling_function.__name__)
             .options("nanmax,nanmean,nanmin,nanmedian")
             .reconfigurable()
             .commit(),
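
A minimal sketch of how a device might use the PreviewSpec / PreviewFriend API added in this file; the preview names are purely illustrative, and the imports assume the calng package layout in this patch:

    from karabo.bound import Schema

    from calng.preview_utils import PreviewFriend, PreviewSpec

    previews = [
        PreviewSpec(name="raw"),
        PreviewSpec(name="corrected"),
    ]

    schema = Schema()
    PreviewFriend.add_schema(schema, previews)
    # later, in the device itself (hypothetical):
    # self._preview_friend = PreviewFriend(self, previews)
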
@@ -91,97 +158,255 @@ class PreviewFriend:
                 "from the image data you want to see."
             )
             .assignmentOptional()
-            .defaultValue(0)
+            .defaultValue(spec.nan_replacement)
             .reconfigurable()
             .commit(),
         )
-        for channel in output_channels:
+
+        if spec.frame_reduction:
             (
-                OUTPUT_CHANNEL(schema)
-                .key(f"{node_path}.{channel}")
-                .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True))
-                .description("See description of parent node, 'preview'.")
+                STRING_ELEMENT(schema)
+                .key(f"{node_path}.selectionMode")
+                .tags("managed")
+                .displayedName("Index selection mode")
+                .description(
+                    "The value of preview.index can be used in multiple ways, "
+                    "controlled by this value. If this is set to 'frame', "
+                    "preview.index is sliced directly from data. If 'cell' "
+                    "(or 'pulse') is selected, I will look at cell (or pulse) table "
+                    "for the requested cell (or pulse ID). Special (stat) index values "
+                    "<0 are not affected by this."
+                )
+                .options(
+                    ",".join(
+                        selectionmode.value
+                        for selectionmode in spec.valid_frame_selection_modes
+                    )
+                )
+                .assignmentOptional()
+                .defaultValue("frame")
+                .reconfigurable()
+                .commit(),
+                INT32_ELEMENT(schema)
+                .key(f"{node_path}.index")
+                .tags("managed")
+                .displayedName("Index (or stat) for preview")
+                .description(
+                    "If this value is ≥ 0, the corresponding index (frame, cell, or "
+                    "pulse) will be sliced for the preview output. If this value is "
+                    "<0, preview will be one of the following stats: -1: max, "
+                    "-2: mean, -3: sum, -4: stdev. These stats are computed across "
+                    "frames, ignoring NaN values."
+                )
+                .assignmentOptional()
+                .defaultValue(spec.frame_reduction_index)
+                .minInc(-4)
+                .reconfigurable()
                 .commit(),
             )
 
-    def __init__(self, device, node_name="preview", output_channels=None):
-        if output_channels is None:
-            output_channels = ["output"]
-        self.output_channels = output_channels
-        self.device = device
-        self.dev_id = self.device.getInstanceId()
-        self.node_name = node_name
-        self.outputs = [
-            self.device.signalSlotable.getOutputChannel(f"{self.node_name}.{channel}")
-            for channel in self.output_channels
-        ]
-        self.reconfigure(device._parameters)
-
-    def write_outputs(self, timestamp, *datas, inplace=True):
+        (
+            OUTPUT_CHANNEL(schema)
+            .key(f"{node_path}.output")
+            .dataSchema(
+                schemas.preview_schema(
+                    wrap_image_in_imagedata=spec.wrap_in_imagedata
+                )
+            )
+            .description("See description of grandparent node, 'preview'.")
+            .commit(),
+        )
+
+    def __init__(self, device, outputs, node_name="preview"):
+        self._device = device
+        self.output_specs = outputs
+        for spec in self.output_specs:
+            spec.channel_name = f"{node_name}.{spec.name}.output"
+        self.reconfigure(self._device.get(node_name))
+
+    def write_outputs(
+        self,
+        *datas,
+        timestamp=None,
+        inplace=True,
+        cell_table=None,
+        pulse_table=None,
+        warn_fun=None,
+    ):
         """Applies GUI-friendly preview settings (replace NaN, downsample, wrap as
         ImageData) and writes to output channels. Make sure datas length matches number
-        of channels!"""
+        of channels! If inplace, masking and NaN replacement may write to the
+        provided buffers (otherwise a copy is made)."""
         if isinstance(timestamp, Hash):
             timestamp = Timestamp.fromHashAttributes(
                 timestamp.getAttributes("timestamp")
             )
-        for data, output, channel_name in zip(
-            datas, self.outputs, self.output_channels
-        ):
-            if self.downsampling_factor > 1:
-                data = utils.downsample_2d(
-                    data,
-                    self.downsampling_factor,
-                    reduction_fun=self.downsampling_function,
-                )
+        if warn_fun is None:
+            warn_fun = self._device.log.WARN
+        for data, spec in zip(datas, self.output_specs):
+            if spec.frame_reduction_fun is not None:
+                data = spec.frame_reduction_fun(data, cell_table, pulse_table, warn_fun)
+            if spec.downsampling_factor > 1:
+                if spec.dimensions == 1:
+                    data = downsample_1d(
+                        data,
+                        spec.downsampling_factor,
+                        reduction_fun=spec.downsampling_function,
+                    )
+                else:
+                    data = downsample_2d(
+                        data,
+                        spec.downsampling_factor,
+                        reduction_fun=spec.downsampling_function,
+                    )
             elif not inplace:
                 data = data.copy()
-            if self.flip_ss:
-                data = np.flip(data, 0)
-            if self.flip_fs:
-                data = np.flip(data, 1)
+            if spec.flip_ss:
+                data = np.flip(data, -2)
+            if spec.flip_fs:
+                data = np.flip(data, -1)
             if isinstance(data, np.ma.MaskedArray):
                 data, mask = data.data, data.mask | ~np.isfinite(data)
             else:
                 mask = ~np.isfinite(data)
-            data[mask] = self.nan_replacement
+            data[mask] = spec.nan_replacement
             mask = mask.astype(np.uint8)
-            output_hash = Hash(
-                "image.data",
-                ImageData(
-                    data,
-                    Dims(*data.shape),
-                    Encoding.GRAY,
-                    bitsPerPixel=32,
-                ),
-                "image.mask",
-                ImageData(
-                    mask,
-                    Dims(*mask.shape),
-                    Encoding.GRAY,
-                    bitsPerPixel=8,
-                ),
-            )
-            output.write(
-                output_hash,
-                ChannelMetaData(
-                    f"{self.dev_id}:{self.node_name}.{channel_name}",
-                    timestamp,
-                ),
-                copyAllData=False,
+            if spec.wrap_in_imagedata:
+                output_hash = Hash(
+                    "image.data",
+                    ImageData(
+                        maybe_get(data),
+                        Dims(*data.shape),
+                        Encoding.GRAY,
+                        bitsPerPixel=32,
+                    ),
+                    "image.mask",
+                    ImageData(
+                        maybe_get(mask),
+                        Dims(*mask.shape),
+                        Encoding.GRAY,
+                        bitsPerPixel=8,
+                    ),
+                )
+            else:
+                output_hash = Hash(
+                    "image.data", maybe_get(data), "image.mask", maybe_get(mask)
+                )
+            self._device.writeChannel(
+                spec.channel_name, output_hash, timestamp=timestamp, safeNDArray=True
             )
-            output.update(safeNDArray=True)
 
     def reconfigure(self, conf):
-        if conf.has(f"{self.node_name}.downsamplingFunction"):
-            self.downsampling_function = getattr(
-                np, conf[f"{self.node_name}.downsamplingFunction"]
+        for spec in self.output_specs:
+            if not conf.has(spec.name):
+                continue
+            subconf = conf.get(spec.name)
+            if spec.frame_reduction and (
+                subconf.has("selectionMode") or subconf.has("index")
+            ):
+                if subconf.has("selectionMode"):
+                    spec.frame_reduction_selection_mode = FrameSelectionMode(
+                        subconf["selectionMode"]
+                    )
+                if subconf.has("index"):
+                    spec.frame_reduction_index = subconf["index"]
+                spec.frame_reduction_fun = _make_frame_reduction_fun(spec)
+            if subconf.has("downsamplingFunction"):
+                spec.downsampling_function = getattr(
+                    np, subconf["downsamplingFunction"]
+                )
+            if subconf.has("downsamplingFactor"):
+                spec.downsampling_factor = subconf["downsamplingFactor"]
+            if subconf.has("flipSS"):
+                spec.flip_ss = subconf["flipSS"]
+            if subconf.has("flipFS"):
+                spec.flip_fs = subconf["flipFS"]
+            if subconf.has("swapAxes"):
+                spec.flip_fs = subconf["flipFS"]
+            if subconf.has("replaceNanWith"):
+                spec.nan_replacement = subconf["replaceNanWith"]
+
+
+def _make_frame_reduction_fun(spec) -> Optional[Callable]:
+    """Will be called during configuration in case frame_reduction is True. Will
+    use frame_reduction_selection_mode and frame_reduction_index to create a
+    function to either slice or reduce input data."""
+    # note: in case of frame picking, signature includes cell and pulse table
+    if spec.frame_reduction_index < 0:
+        # so while those tables are irrelevant here, still include in wrapper
+        if spec.frame_reduction_index == -1:
+            # note: separate from next case because dtype not applicable here
+            reduction_fun = functools.partial(np.nanmax, axis=0)
+        elif spec.frame_reduction_index == -2:
+            # note: separate from next case because dtype not applicable here
+            reduction_fun = functools.partial(np.nanmean, axis=0, dtype=np.float32)
+        elif spec.frame_reduction_index == -3:
+            # note: separate from next case because dtype not applicable here
+            reduction_fun = functools.partial(np.nansum, axis=0, dtype=np.float32)
+        elif spec.frame_reduction_index == -4:
+            # note: separate from next case because dtype not applicable here
+            reduction_fun = functools.partial(np.nanstd, axis=0, dtype=np.float32)
+
+        def fun(data, cell_table, pulse_table, warn_fun):
+            try:
+                return reduction_fun(data)
+            except Exception as ex:
+                warn_fun(f"Frame reduction error: {ex}")
+
+        return fun
+    else:
+        # here, we pick out a frame
+        if spec.frame_reduction_selection_mode is FrameSelectionMode.FRAME:
+            return _pick_preview_index_frame(spec.frame_reduction_index)
+        elif spec.frame_reduction_selection_mode is FrameSelectionMode.CELL:
+            return _pick_preview_index_cell(spec.frame_reduction_index)
+        else:
+            return _pick_preview_index_pulse(spec.frame_reduction_index)
+
+
+def _pick_preview_index_frame(index):
+    def aux(data, cell_table, pulse_table, warn_fun):
+        if index >= data.shape[0]:
+            warn_fun(
+                f"Frame index {index} out of bounds for data with "
+                f"{data.shape[0]} frames, previewing frame 0"
+            )
+            return data[0]
+        return data[index]
+    return aux
+
+
+def _pick_preview_index_cell(index):
+    def aux(data, cell_table, pulse_table, warn_fun):
+        found = np.nonzero(cell_table == index)[0]
+        if found.size == 0:
+            warn_fun(
+                f"Cell ID {index} not found in cell mapping, "
+                "previewing frame 0 instead"
+            )
+            return data[0]
+        if found.size > 1:
+            warn_fun(
+                f"Cell ID {index} not unique in cell mapping, "
+                f"previewing first occurrence (frame {found[0]}"
+            )
+        return data[found[0]]
+    return aux
+
+
+def _pick_preview_index_pulse(index):
+    def aux(data, cell_table, pulse_table, warn_fun):
+        found = np.nonzero(pulse_table == index)[0]
+        if found.size == 0:
+            warn_fun(
+                f"Pulse ID {index} not found in pulse table, "
+                "previewing frame 0 instead"
+            )
+            return data[0]
+        if found.size > 1:
+            warn_fun(
+                f"Pulse ID {index} not unique in pulse table, "
+                f"previewing first occurrence (frame {found[0]})"
             )
-        if conf.has(f"{self.node_name}.downsamplingFactor"):
-            self.downsampling_factor = conf[f"{self.node_name}.downsamplingFactor"]
-        if conf.has(f"{self.node_name}.flipSS"):
-            self.flip_ss = conf[f"{self.node_name}.flipSS"]
-        if conf.has(f"{self.node_name}.flipFS"):
-            self.flip_fs = conf[f"{self.node_name}.flipFS"]
-        if conf.has(f"{self.node_name}.replaceNanWith"):
-            self.nan_replacement = conf[f"{self.node_name}.replaceNanWith"]
+        return data[found[0]]
+    return aux
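
To illustrate the frame-reduction behaviour configured above without using the helpers themselves: a negative index selects a NaN-ignoring per-pixel statistic across frames (-1 max, -2 mean, -3 sum, -4 stdev), while a non-negative index in 'cell' mode is looked up in the cell table. A self-contained NumPy sketch:

    import numpy as np

    data = np.random.random((8, 4, 4)).astype(np.float32)   # (frames, ss, fs)
    cell_table = np.arange(0, 16, 2, dtype=np.uint16)        # cell IDs per frame

    # index < 0: statistic across frames (-1 max, -2 mean, -3 sum, -4 stdev)
    stat_by_index = {-1: np.nanmax, -2: np.nanmean, -3: np.nansum, -4: np.nanstd}
    preview_mean = stat_by_index[-2](data, axis=0)

    # index >= 0 in 'cell' mode: find the frame whose cell ID matches
    frame_for_cell_6 = np.nonzero(cell_table == 6)[0][0]
    preview_cell_6 = data[frame_for_cell_6]
    assert preview_mean.shape == preview_cell_6.shape == (4, 4)
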
diff --git a/src/calng/scenes.py b/src/calng/scenes.py
index 56ae3bfd2f1098efce69a38f8f677e0f3ddd8fa7..3eb66b0fdc9ba8e5ce507023ddaa1c7caededafe 100644
--- a/src/calng/scenes.py
+++ b/src/calng/scenes.py
@@ -34,6 +34,8 @@ from karabo.common.scenemodel.api import (
     WebLinkModel,
 )
 
+docs_url = "https://rtd.xfel.eu/docs/calng/en/latest"
+
 
 @titled("Found constants", width=6 * NARROW_INC)
 @boxed
@@ -140,7 +142,7 @@ class ManagerDeviceStatus(VerticalLayout):
             text="Documentation",
             width=7 * BASE_INC,
             height=BASE_INC,
-            target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#calibration-manager",
+            target=f"{docs_url}/devices/#calibration-manager",
         )
         self.children.extend(
             [
@@ -446,63 +448,77 @@ class RoiBox(VerticalLayout):
 @titled("Preview settings")
 @boxed
 class PreviewSettings(HorizontalLayout):
-    def __init__(self, device_id, schema_hash, node_name="preview", extras=None):
+    def __init__(self, device_id, schema_hash, node_name):
         super().__init__()
-        self.children.extend(
-            [
-                VerticalLayout(
-                    EditableRow(
-                        device_id,
-                        schema_hash,
-                        f"{node_name}.replaceNanWith",
-                        6,
-                        4,
-                    ),
-                    HorizontalLayout(
-                        EditableRow(
-                            device_id,
-                            schema_hash,
-                            f"{node_name}.flipSS",
-                            3,
-                            2,
-                        ),
-                        EditableRow(
-                            device_id,
-                            schema_hash,
-                            f"{node_name}.flipFS",
-                            3,
-                            2,
-                        ),
-                        padding=0,
-                    ),
-                ),
-                Vline(height=3 * BASE_INC),
-                VerticalLayout(
-                    LabelModel(
-                        text="Image downsampling",
-                        width=8 * BASE_INC,
-                        height=BASE_INC,
-                    ),
+        tweaks = VerticalLayout(
+            EditableRow(
+                device_id,
+                schema_hash,
+                f"{node_name}.replaceNanWith",
+                6,
+                4,
+            ),
+        )
+        if schema_hash.has(f"{node_name}.flipSS"):
+            tweaks.children.append(
+                HorizontalLayout(
                     EditableRow(
                         device_id,
                         schema_hash,
-                        f"{node_name}.downsamplingFactor",
-                        5,
-                        5,
+                        f"{node_name}.flipSS",
+                        3,
+                        2,
                     ),
                     EditableRow(
                         device_id,
                         schema_hash,
-                        f"{node_name}.downsamplingFunction",
-                        5,
-                        5,
+                        f"{node_name}.flipFS",
+                        3,
+                        2,
                     ),
-                ),
+                    padding=0,
+                )
+            )
+        downsampling = VerticalLayout(
+            LabelModel(
+                text="Downsampling",
+                width=8 * BASE_INC,
+                height=BASE_INC,
+            ),
+            EditableRow(
+                device_id,
+                schema_hash,
+                f"{node_name}.downsamplingFactor",
+                5,
+                5,
+            ),
+            EditableRow(
+                device_id,
+                schema_hash,
+                f"{node_name}.downsamplingFunction",
+                5,
+                5,
+            ),
+        )
+        self.children.extend(
+            [
+                tweaks,
+                Vline(height=3 * BASE_INC),
+                downsampling,
             ]
         )
-        if extras is not None:
-            self.children.append(Vline(height=3 * BASE_INC))
-            self.children.extend(extras)
+        if schema_hash.has(f"{node_name}.index"):
+            self.children.extend(
+                [
+                    Vline(height=3 * BASE_INC),
+                    VerticalLayout(
+                        EditableRow(device_id, schema_hash, f"{node_name}.index", 8, 4),
+                        EditableRow(
+                            device_id, schema_hash, f"{node_name}.selectionMode", 8, 4
+                        ),
+                    ),
+                ]
+            )
 
 
 @titled("Histogram settings")
@@ -633,26 +649,23 @@ def correction_device_overview(device_id, schema):
     )
     return VerticalLayout(
         main_overview,
-        LabelModel(
-            text="Preview (corrected):",
-            width=20 * BASE_INC,
-            height=BASE_INC,
-        ),
-        PreviewDisplayArea(device_id, schema_hash, "preview.outputCorrected"),
         *(
             DeviceSceneLinkModel(
-                text=f"Preview: {channel}",
+                text=f"Preview: {preview_name}",
                 keys=[f"{device_id}.availableScenes"],
-                target=f"preview:preview.{channel}",
+                target=f"preview:{preview_name}",
                 target_window=SceneTargetWindow.Dialog,
                 width=16 * BASE_INC,
                 height=BASE_INC,
             )
-            for channel in schema_hash.get("preview").getKeys()
-            if schema_hash.hasAttribute(f"preview.{channel}", "classId")
-            and schema_hash.getAttribute(f"preview.{channel}", "classId")
-            == "OutputChannel"
+            for preview_name in schema_hash.get("preview").getKeys()
+        ),
+        LabelModel(
+            text="Preview (corrected):",
+            width=20 * BASE_INC,
+            height=BASE_INC,
         ),
+        PreviewDisplayArea(device_id, schema_hash, "preview.corrected.output"),
     )
 
 
@@ -666,25 +679,20 @@ def lpdmini_splitter_overview(device_id, schema):
 
 
 @scene_generator
-def correction_device_preview(device_id, schema, preview_channel):
+def correction_device_preview(device_id, schema, name):
     schema_hash = schema_to_hash(schema)
     return VerticalLayout(
         LabelModel(
-            text=f"Preview: {preview_channel}",
+            text=f"Preview: {name}",
             width=10 * BASE_INC,
             height=BASE_INC,
         ),
         PreviewSettings(
             device_id,
             schema_hash,
-            extras=[
-                VerticalLayout(
-                    EditableRow(device_id, schema_hash, "preview.index", 8, 4),
-                    EditableRow(device_id, schema_hash, "preview.selectionMode", 8, 4),
-                )
-            ],
+            f"preview.{name}",
         ),
-        PreviewDisplayArea(device_id, schema_hash, preview_channel),
+        PreviewDisplayArea(device_id, schema_hash, f"preview.{name}.output"),
     )
 
 
@@ -758,7 +766,8 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan
                     ),
                     DisplayCommandModel(
                         keys=[
-                            f"{device_id}.{prefix}.{constant}.overrideConstantFromVersion"
+                            f"{device_id}.{prefix}.{constant}"
+                            ".overrideConstantFromVersion"
                         ],
                         width=8 * BASE_INC,
                         height=BASE_INC,
@@ -829,67 +838,47 @@ def manager_device_overview(
     mds_hash = schema_to_hash(manager_device_schema)
     cds_hash = schema_to_hash(correction_device_schema)
 
-    data_throttling_children = [
-        LabelModel(
-            text="Frame filter",
-            width=11 * BASE_INC,
-            height=BASE_INC,
-        ),
-        EditableRow(
+    config_column = [
+        recursive_editable(
             manager_device_id,
             mds_hash,
-            "managedKeys.frameFilter.type",
-            7,
-            4,
+            "managedKeys.constantParameters",
         ),
-        EditableRow(
-            manager_device_id,
-            mds_hash,
-            "managedKeys.frameFilter.spec",
-            7,
-            4,
+        DisplayCommandModel(
+            keys=[f"{manager_device_id}.managedKeys.loadMostRecentConstants"],
+            width=10 * BASE_INC,
+            height=BASE_INC,
         ),
     ]
 
-    if "managedKeys.daqTrainStride" in mds_hash:
-        # Only add DAQ train stride if present on the schema, may be
-        # disabled on the manager.
-        data_throttling_children.insert(
-            0,
-            EditableRow(
+    if "managedKeys.preview" in mds_hash:
+        config_column.append(
+            recursive_editable(
                 manager_device_id,
                 mds_hash,
-                "managedKeys.daqTrainStride",
-                7,
-                4,
+                "managedKeys.preview",
+                max_depth=2,
             ),
         )
 
-    return VerticalLayout(
-        HorizontalLayout(
-            ManagerDeviceStatus(manager_device_id),
-            VerticalLayout(
-                recursive_editable(
-                    manager_device_id,
-                    mds_hash,
-                    "managedKeys.constantParameters",
-                ),
-                DisplayCommandModel(
-                    keys=[f"{manager_device_id}.managedKeys.loadMostRecentConstants"],
-                    width=10 * BASE_INC,
-                    height=BASE_INC,
-                ),
-                recursive_editable(
+    if "managedKeys.daqTrainStride" in mds_hash:
+        config_column.append(
+            titled("Data throttling")(boxed(VerticalLayout))(
+                EditableRow(
                     manager_device_id,
                     mds_hash,
-                    "managedKeys.preview",
-                    max_depth=2,
-                ),
-                titled("Data throttling")(boxed(VerticalLayout))(
-                    children=data_throttling_children,
-                    padding=0,
+                    "managedKeys.daqTrainStride",
+                    7,
+                    4,
                 ),
+                padding=0,
             ),
+        )
+
+    return VerticalLayout(
+        HorizontalLayout(
+            ManagerDeviceStatus(manager_device_id),
+            VerticalLayout(*config_column),
             recursive_editable(
                 manager_device_id,
                 mds_hash,
@@ -932,7 +921,7 @@ def manager_device_overview(
                             text="Documentation",
                             width=6 * BASE_INC,
                             height=BASE_INC,
-                            target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#correction-devices",
+                            target=f"{docs_url}/devices/#correction-devices",
                         ),
                         padding=0,
                     )
@@ -1043,12 +1032,12 @@ def detector_assembler_overview(device_id, schema):
     return VerticalLayout(
         HorizontalLayout(
             AssemblerDeviceStatus(device_id),
-            PreviewSettings(device_id, schema_hash),
+            PreviewSettings(device_id, schema_hash, "preview.assembled"),
         ),
         PreviewDisplayArea(
             device_id,
             schema_hash,
-            "preview.output",
+            "preview.assembled.output",
             data_width=40 * BASE_INC,
             data_height=40 * BASE_INC,
             mask_width=40 * BASE_INC,
diff --git a/src/calng/schemas.py b/src/calng/schemas.py
index 79524daade964745896b32a143491b8e8587a6d7..3b4820b0c7b90239bd70807a7d90cd250404d675 100644
--- a/src/calng/schemas.py
+++ b/src/calng/schemas.py
@@ -44,7 +44,7 @@ def preview_schema(wrap_image_in_imagedata=False):
     return res
 
 
-def xtdf_output_schema(use_shmem_handle=True):
+def xtdf_output_schema(use_shmem_handles):
     # TODO: trim output schema / adapt to specific detectors
     # currently: based on snapshot of actual output reusing AGIPD hash
     res = Schema()
@@ -178,7 +178,7 @@ def xtdf_output_schema(use_shmem_handle=True):
         .commit(),
     )
 
-    if use_shmem_handle:
+    if use_shmem_handles:
         (
             STRING_ELEMENT(res)
             .key("image.data")
@@ -194,7 +194,7 @@ def xtdf_output_schema(use_shmem_handle=True):
     return res
 
 
-def jf_output_schema(use_shmem_handle=True):
+def jf_output_schema(use_shmem_handles):
     res = Schema()
     (
         NODE_ELEMENT(res)
@@ -227,7 +227,7 @@ def jf_output_schema(use_shmem_handle=True):
         .readOnly()
         .commit(),
     )
-    if use_shmem_handle:
+    if use_shmem_handles:
         (
             STRING_ELEMENT(res)
             .key("data.adc")
@@ -243,7 +243,7 @@ def jf_output_schema(use_shmem_handle=True):
     return res
 
 
-def pnccd_output_schema(use_shmem_handle=True):
+def pnccd_output_schema(use_shmem_handles):
     res = Schema()
     (
         NODE_ELEMENT(res)
@@ -255,7 +255,7 @@ def pnccd_output_schema(use_shmem_handle=True):
         .readOnly()
         .commit(),
     )
-    if use_shmem_handle:
+    if use_shmem_handles:
         (
             STRING_ELEMENT(res)
             .key("data.image")
diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
index 4d6e7008afa8a68ae44deabfef3d651cf2f7f910..8c81d1fe3c1a498cd2a7774e71fd31f04ae41818 100644
--- a/src/calng/stacking_utils.py
+++ b/src/calng/stacking_utils.py
@@ -166,8 +166,6 @@ class StackingFriend:
         self.reconfigure(source_config, merge_config)
 
     def reconfigure(self, source_config, merge_config):
-        print("merge_config", type(merge_config))
-        print("source_config", type(source_config))
         if source_config is not None:
             self._source_config = source_config
         if merge_config is not None:
diff --git a/src/calng/utils.py b/src/calng/utils.py
index fd7c6c284b4409e33c5ea9816eac8cfc17d6744c..c955d80b38d4b9f01cfce1c6e799fa691695512a 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -8,6 +8,21 @@ import numpy as np
 from calngUtils import misc
 
 
+class WarningLampType(enum.Enum):
+    FRAME_FILTER = enum.auto()
+    MEMORY_CELL_RANGE = enum.auto()
+    CONSTANT_OPERATING_PARAMETERS = enum.auto()
+    PREVIEW_SETTINGS = enum.auto()
+    CORRECTION_RUNNER = enum.auto()
+    OUTPUT_BUFFER = enum.auto()
+    GPU_MEMORY = enum.auto()
+    CALCAT_CONNECTION = enum.auto()
+    EMPTY_HASH = enum.auto()
+    MISC_INPUT_DATA = enum.auto()
+    TRAIN_ID = enum.auto()
+    TIMESERVER_CONNECTION = enum.auto()
+
+
 class WarningContextSystem:
     """Helper object to trigger warning lamps based on different warning types in
     contexts. Intended use: something that is checked multiple times and is good or bad
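
The WarningLampType members are intended as the warn_type argument to a device's warning_context manager (the interface mirrored by the test double in tests/common_setup.py further down). A self-contained sketch of the pattern, with FakeDevice standing in for a real correction device:

    import contextlib
    import enum

    class WarningLampType(enum.Enum):          # subset, mirroring the enum above
        PREVIEW_SETTINGS = enum.auto()

    class FakeDevice:                          # stand-in for a correction device
        @contextlib.contextmanager
        def warning_context(self, key, warn_type):
            yield lambda message: print(f"[{key}/{warn_type.name}] {message}")

    device, preview_index, num_frames = FakeDevice(), 10, 8
    with device.warning_context(
        "preview.index", WarningLampType.PREVIEW_SETTINGS
    ) as warn:
        if preview_index >= num_frames:
            warn(f"preview index {preview_index} out of range for {num_frames} frames")
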
@@ -90,63 +105,6 @@ class WarningContextSystem:
         self.device.set("status", message)
 
 
-class PreviewIndexSelectionMode(enum.Enum):
-    FRAME = "frame"
-    CELL = "cell"
-    PULSE = "pulse"
-
-
-def pick_frame_index(selection_mode, index, cell_table, pulse_table):
-    """When selecting a single frame to preview, an obvious question is whether the
-    number the operator provides is a frame index, a cell ID, or a pulse ID. This
-    function allows any of the three, translating into frame index.
-
-    Indices below zero are special values and thus returned directly.
-
-    Returns: (frame index, cell ID, pulse ID), optional warning"""
-
-    if index < 0:
-        return (index, index, index), None
-
-    warning = None
-    selection_mode = PreviewIndexSelectionMode(selection_mode)
-
-    if selection_mode is PreviewIndexSelectionMode.FRAME:
-        if index < cell_table.size:
-            frame_index = index
-        else:
-            warning = (
-                f"Frame index {index} out of range for cell table of length "
-                f"{len(cell_table)}. Will use frame index 0 instead."
-            )
-            frame_index = 0
-    else:
-        if selection_mode is PreviewIndexSelectionMode.CELL:
-            index_table = cell_table
-        else:
-            index_table = pulse_table
-        found = np.where(index_table == index)[0]
-        if len(found) == 1:
-            frame_index = found[0]
-        elif len(found) == 0:
-            frame_index = 0
-            warning = (
-                f"{selection_mode.value.capitalize()} ID {index} not found in "
-                f"{selection_mode.name} table. Will use frame {frame_index}, which "
-                f"corresponds to {selection_mode.name} {index_table[0]} instead."
-            )
-        elif len(found) > 1:
-            warning = (
-                f"{selection_mode.value.capitalize()} ID {index} was not unique in "
-                f"{selection_mode} table. Will use first occurrence out of "
-                f"{len(found)} occurrences (frame {frame_index})."
-            )
-
-    cell_id = cell_table[frame_index]
-    pulse_id = pulse_table[frame_index]
-    return (frame_index, cell_id, pulse_id), warning
-
-
 _np_typechar_to_c_typestring = {
     "?": "bool",
     "B": "unsigned char",
@@ -234,22 +192,37 @@ class BadPixelValues(enum.IntFlag):
 def downsample_2d(arr, factor, reduction_fun=np.nanmax):
     """Generalization of downsampling from FemDataAssembler
 
-    Expects first two dimensions of arr to be multiple of 2 ** factor
-    Useful if you're sitting at home and ssh connection is slow to get full-resolution
-    previews."""
+    Expects last two dimensions of arr to be multiple of 2 ** factor.  Useful for
+    looking at detector previews over a slow connection (say, ssh)."""
 
     for i in range(factor // 2):
+        # downsample slow scan
         arr = reduction_fun(
             (
-                arr[:-1:2],
-                arr[1::2],
+                arr[..., :-1:2, :],
+                arr[..., 1::2, :],
             ),
             axis=0,
         )
+        # downsample fast scan
         arr = reduction_fun(
             (
-                arr[:, :-1:2],
-                arr[:, 1::2],
+                arr[..., :-1:2],
+                arr[..., 1::2],
+            ),
+            axis=0,
+        )
+    return arr
+
+
+def downsample_1d(arr, factor, reduction_fun=np.nanmax):
+    """Same as downsample_2d, but only 1d (applied to last axis of input)"""
+
+    for i in range(factor // 2):
+        arr = reduction_fun(
+            (
+                arr[..., :-1:2],
+                arr[..., 1::2],
             ),
             axis=0,
         )
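
To make the pairwise reduction concrete, a standalone snippet mirroring one pass of the loop in downsample_2d above (the helper name here is illustrative):

    import numpy as np

    def halve_both_axes(arr, reduction_fun=np.nanmax):
        # one pass of the loop body above: pair up neighbours along ss, then fs
        arr = reduction_fun((arr[..., :-1:2, :], arr[..., 1::2, :]), axis=0)
        arr = reduction_fun((arr[..., :-1:2], arr[..., 1::2]), axis=0)
        return arr

    arr = np.arange(16, dtype=np.float32).reshape(4, 4)
    assert np.array_equal(halve_both_axes(arr), [[5.0, 7.0], [13.0, 15.0]])
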
@@ -289,3 +262,25 @@ def apply_partial_lut(data, lut, mask, out, missing=np.nan):
     tmp = out.ravel()
     tmp[~mask] = data.ravel()[lut]
     tmp[mask] = missing
+
+
+def subset_of_hash(input_hash, *keys):
+    # construct an empty hash of the same class rather than importing Hash here
+    res = input_hash.__class__()
+    for key in keys:
+        if input_hash.has(key):
+            res.set(key, input_hash.get(key))
+    return res
+
+
+def maybe_get(a, out=None):
+    """For getting CuPy ndarray data to system memory - tries to behave like
+    cupy.ndarray.get with respect to the 'out' parameter"""
+    if hasattr(a, "get"):
+        return a.get(out=out)
+    else:
+        if out is None:
+            return a
+        else:
+            np.copyto(out, a)
+            return out
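
maybe_get passes host (NumPy) arrays through unchanged and only calls .get() on device (CuPy) arrays. A quick check of the host-only path, assuming the calng package from this repository is importable:

    import numpy as np

    from calng.utils import maybe_get

    a = np.arange(6, dtype=np.float32)
    assert maybe_get(a) is a                # plain ndarray: returned unchanged
    out = np.empty_like(a)
    assert maybe_get(a, out=out) is out     # copied into the provided buffer
    np.testing.assert_array_equal(out, a)
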
diff --git a/tests/common_setup.py b/tests/common_setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..fafe2cb7078a4a20b6ef4beb7d17b41d0c05d1b9
--- /dev/null
+++ b/tests/common_setup.py
@@ -0,0 +1,100 @@
+import contextlib
+import pathlib
+import threading
+
+import h5py
+import numpy as np
+from calngUtils import device as device_utils
+from calngUtils.misc import ChainHash
+from karabo.bound import Hash, Schema
+
+
+def maybe_get(b):
+    if hasattr(b, "get"):
+        return b.get()
+    else:
+        return b
+
+
+class DummyLogger:
+    DEBUG = print
+    INFO = print
+    WARN = print
+
+
+@device_utils.with_config_overlay
+@device_utils.with_unsafe_get
+class DummyCorrectionDevice:
+    log = DummyLogger()
+
+    def log_status_info(self, msg):
+        self.log.INFO(msg)
+
+    def log_status_warn(self, msg):
+        self.log.WARN(msg)
+
+    def get(self, key):
+        return self._parameters.get(key)
+
+    def set(self, key, value):
+        self._parameters[key] = value
+
+    def reconfigure(self, config):
+        with self.new_config_context(config):
+            self.runner.reconfigure(config)
+        self._parameters.merge(config)
+
+    def __init__(self, *friends):
+        self._parameters = Hash()
+        s = Schema()
+        self._temporary_config = []
+        self._temporary_config_lock = threading.RLock()
+        for friend in friends:
+            if isinstance(friend, dict):
+                friend_class = friend.pop("friend_class")
+                extra_args = friend
+            else:
+                friend_class = friend
+                extra_args = {}
+            friend_class.add_schema(s, **extra_args)
+        h = s.getParameterHash()
+        for path in h.getPaths():
+            if h.hasAttribute(path, "defaultValue"):
+                if "output" in path.split("."):
+                    continue
+                self._parameters[path] = h.getAttribute(path, "defaultValue")
+
+        self.warnings = {}
+
+    def _please_send_me_cached_constants(self, constants, callback):
+        pass
+
+    @contextlib.contextmanager
+    def warning_context(self, key, warn_type):
+        def aux(message):
+            print(message)
+            self.warnings[(key, warn_type)] = message
+
+        yield aux
+
+    @contextlib.contextmanager
+    def new_config_context(self, config):
+        with self._temporary_config_lock:
+            old_config = self._parameters
+            if isinstance(self._parameters, ChainHash):
+                # in case we need to support recursive case / configuration stacking
+                self._parameters = ChainHash(config, *old_config._hashes)
+            else:
+                self._parameters = ChainHash(config, old_config)
+            yield
+            self._parameters = old_config
+
+
+caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store")
+
+
+def get_constant_from_file(path_to_file, file_name, data_set_name):
+    """Intended to work with copy-pasting paths from CalCat. Specifially, the "Path to
+    File", "File Name", and "Data Set Name" fields displayed for a given CCV."""
+    with h5py.File(caldb_store / path_to_file / file_name, "r") as fd:
+        return np.array(fd[f"{data_set_name}/data"])
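
The new_config_context manager above temporarily layers an incoming reconfiguration on top of the stored parameters using calngUtils' ChainHash. The standard-library ChainMap shows the same lookup behaviour (an analogy only; Karabo Hashes are not plain dicts):

    from collections import ChainMap

    base = {"preview.corrected.index": 0, "preview.corrected.selectionMode": "frame"}
    update = {"preview.corrected.index": 3}
    overlay = ChainMap(update, base)
    assert overlay["preview.corrected.index"] == 3                 # new value wins
    assert overlay["preview.corrected.selectionMode"] == "frame"   # falls through
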
diff --git a/tests/test_agipd_kernels.py b/tests/test_agipd_kernels.py
index afbb2ca057dd3ec4b95aca4b2c16f235f59e9519..25021ae92bf2c221acec0f0ecf1a919a6659310a 100644
--- a/tests/test_agipd_kernels.py
+++ b/tests/test_agipd_kernels.py
@@ -1,19 +1,23 @@
-import h5py
 import numpy as np
-import pathlib
-
+import pytest
 from calng.corrections.AgipdCorrection import (
+    AgipdBaseRunner,
+    AgipdCalcatFriend,
     AgipdCpuRunner,
     AgipdGpuRunner,
     Constants,
     CorrectionFlags,
 )
+from karabo.bound import Hash
+
+from common_setup import DummyCorrectionDevice, get_constant_from_file, maybe_get
 
 output_dtype = np.float32
 corr_dtype = np.float32
 pixels_x = 512
 pixels_y = 128
 memory_cells = 352
+num_frames = memory_cells
 
 raw_data = np.random.randint(
     low=0, high=2000, size=(memory_cells, 2, pixels_x, pixels_y), dtype=np.uint16
@@ -23,29 +27,26 @@ raw_gain = raw_data[:, 1]
 cell_table = np.arange(memory_cells, dtype=np.uint16)
 np.random.shuffle(cell_table)
 
-caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store/xfel/cal")
-caldb_prefix = caldb_store / "agipd-type/agipd_siv1_agipdv11_m305"
-
-with h5py.File(caldb_prefix / "cal.1619543695.4679213.h5", "r") as fd:
-    thresholds = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0/data"])
-with h5py.File(caldb_prefix / "cal.1619543664.1545036.h5", "r") as fd:
-    offset_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/Offset/0/data"])
-with h5py.File(caldb_prefix / "cal.1615377705.8904035.h5", "r") as fd:
-    slopes_pc_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0/data"])
-
-kernel_runners = [
-    runner_class(
-        pixels_x,
-        pixels_y,
-        memory_cells,
-        constant_memory_cells=memory_cells,
-        output_data_dtype=output_dtype,
-    )
-    for runner_class in (AgipdCpuRunner, AgipdGpuRunner)
-]
+thresholds = get_constant_from_file(
+    "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305",
+    "cal.1619543695.4679213.h5",
+    "/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0",
+)
+offset_map = get_constant_from_file(
+    "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305",
+    "cal.1619543664.1545036.h5",
+    "/AGIPD_SIV1_AGIPDV11_M305/Offset/0",
+)
+slopes_pc_map = get_constant_from_file(
+    "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305",
+    "cal.1615377705.8904035.h5",
+    "/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0",
+)
 
+# first compute reference results naively (plain numpy, no kernel runners)
 
-def thresholding_cpu(data, cell_table, thresholds):
+
+def thresholding_naive(data, cell_table, thresholds):
     # get to memory_cell, x, y
     raw_gain = data[:, 1, ...].astype(corr_dtype, copy=False)
     # get to threshold, memory_cell, x, y
@@ -56,16 +57,13 @@ def thresholding_cpu(data, cell_table, thresholds):
     return res
 
 
-gain_map_cpu = thresholding_cpu(raw_data, cell_table, thresholds)
-
-
-def corr_offset_cpu(data, cell_table, gain_map, offset):
+def corr_offset_naive(data, cell_table, gain_map, offset):
     image_data = data[:, 0].astype(corr_dtype, copy=False)
     offset = np.transpose(offset)[:, cell_table]
     return (image_data - np.choose(gain_map, offset)).astype(output_dtype)
 
 
-def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
+def corr_rel_gain_pc_naive(data, cell_table, gain_map, slopes_pc):
     slopes_pc = slopes_pc.astype(np.float32, copy=False)
     pc_high_m = slopes_pc[0]
     pc_high_I = slopes_pc[1]
@@ -77,7 +75,7 @@ def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
     rel_gain_map[0] = 1  # rel xray gain can come after
     rel_gain_map[1] = rel_gain_map[0] * np.transpose(frac_high_med)
     rel_gain_map[2] = rel_gain_map[1] * 4.48
-    res = data[:, 0].astype(corr_dtype, copy=True)
+    res = data.astype(corr_dtype, copy=True)
     res *= np.choose(gain_map, np.transpose(rel_gain_map, (0, 3, 1, 2)))
     pixels_in_medium_gain = gain_map == 1
     res[pixels_in_medium_gain] += np.transpose(md_additional_offset, (0, 2, 1))[
@@ -86,35 +84,90 @@ def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
     return res
 
 
-def test_thresholding():
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.load_constant(Constants.ThresholdsDark, thresholds)
-        runner.correct(np.uint8(CorrectionFlags.THRESHOLD))
-    assert np.allclose(kernel_runners[0].gain_map, gain_map_cpu)
-    assert np.allclose(kernel_runners[1].gain_map.get(), gain_map_cpu)
-
-
-def test_offset():
-    reference = corr_offset_cpu(raw_data, cell_table, gain_map_cpu, offset_map)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.load_constant(Constants.ThresholdsDark, thresholds)
-        runner.load_constant(Constants.Offset, offset_map)
-        # have to do thresholding, otherwise all is treated as high gain
-        runner.correct(np.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.OFFSET))
-        res = runner.reshape(runner._corrected_axis_order)
-        assert np.allclose(res, reference), f"{runner.__class__}"
-
-
-def test_rel_gain_pc():
-    reference = corr_rel_gain_pc_cpu(raw_data, cell_table, gain_map_cpu, slopes_pc_map)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.load_constant(Constants.ThresholdsDark, thresholds)
-        runner.load_constant(Constants.SlopesPC, slopes_pc_map)
-        runner.correct(
-            np.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.REL_GAIN_PC)
+gain_map_naive = thresholding_naive(raw_data, cell_table, thresholds)
+offset_corrected_naive = corr_offset_naive(
+    raw_data, cell_table, gain_map_naive, offset_map
+)
+gain_corrected_naive = corr_rel_gain_pc_naive(
+    offset_corrected_naive, cell_table, gain_map_naive, slopes_pc_map
+)
+
+
+@pytest.fixture(params=[AgipdCpuRunner, AgipdGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(AgipdBaseRunner, AgipdCalcatFriend)
+    runner = request.param(device)
+    # TODO: also test ASIC seam masking
+    runner.reconfigure(
+        Hash(
+            "corrections.badPixels.subsetToUse.NON_STANDARD_SIZE",
+            False,
         )
-        res = runner.reshape(runner._corrected_axis_order)
-        assert np.allclose(res, reference)
+    )
+    return runner
+
+
+# then test each correction step individually against the naive reference
+
+
+def test_thresholding(kernel_runner):
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    # AgipdBaseRunner's buffers don't depend on flags
+    output, output_gain = kernel_runner._make_output_buffers(num_frames, None)
+    kernel_runner._correct(
+        kernel_runner._xp.uint8(CorrectionFlags.THRESHOLD),
+        kernel_runner._xp.asarray(raw_data),
+        kernel_runner._xp.asarray(cell_table),
+        output,
+        output_gain,
+    )
+    assert np.allclose(maybe_get(output_gain), gain_map_naive.astype(np.float32))
+
+
+def test_offset(kernel_runner):
+    # have to also do thresholding, otherwise all is treated as high gain
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    output, output_gain = kernel_runner._make_output_buffers(num_frames, None)
+    kernel_runner._correct(
+        kernel_runner._xp.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.OFFSET),
+        kernel_runner._xp.asarray(raw_data),
+        kernel_runner._xp.asarray(cell_table),
+        output,
+        output_gain,
+    )
+    assert np.allclose(maybe_get(output), offset_corrected_naive)
+
+
+def test_rel_gain_pc(kernel_runner):
+    # similarly, do previous steps first
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    kernel_runner.load_constant(Constants.SlopesPC, slopes_pc_map)
+    output, output_gain = kernel_runner._make_output_buffers(num_frames, None)
+    kernel_runner._correct(
+        kernel_runner._xp.uint8(
+            CorrectionFlags.THRESHOLD
+            | CorrectionFlags.OFFSET
+            | CorrectionFlags.REL_GAIN_PC
+        ),
+        kernel_runner._xp.asarray(raw_data),
+        kernel_runner._xp.asarray(cell_table),
+        output,
+        output_gain,
+    )
+    assert np.allclose(maybe_get(output), gain_corrected_naive)
+
+
+def test_with_preview(kernel_runner):
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    (
+        _,
+        processed_data,
+        (preview_raw, preview_corrected, preview_raw_gain, preview_gain_stage),
+    ) = kernel_runner.correct(raw_data, cell_table)
+    assert np.allclose(maybe_get(processed_data), offset_corrected_naive)
+    assert np.allclose(raw_data[:, 0], maybe_get(preview_raw))
+    assert np.allclose(offset_corrected_naive, maybe_get(preview_corrected))
+    assert np.allclose(raw_data[:, 1], maybe_get(preview_raw_gain))
diff --git a/tests/test_dssc_kernels.py b/tests/test_dssc_kernels.py
index 6ab37a8c0cadf40175b8a6acfede720d9651477a..f7df8ef9470460a28a57c2f85b60c3db1df3f3c5 100644
--- a/tests/test_dssc_kernels.py
+++ b/tests/test_dssc_kernels.py
@@ -1,29 +1,34 @@
 import numpy as np
 import pytest
-
 from calng.corrections.DsscCorrection import (
-    CorrectionFlags,
+    Constants,
+    DsscBaseRunner,
+    DsscCalcatFriend,
     DsscCpuRunner,
     DsscGpuRunner,
 )
+from karabo.bound import Hash
+
+from common_setup import DummyCorrectionDevice, maybe_get
 
 output_dtype = np.float32
 corr_dtype = np.float32
 pixels_x = 512
 pixels_y = 128
 memory_cells = 400
+num_frames = memory_cells
 offset_map = (
     np.random.random(size=(pixels_x, pixels_y, memory_cells)).astype(corr_dtype) * 20
 )
 cell_table = np.arange(memory_cells, dtype=np.uint16)
 np.random.shuffle(cell_table)
 raw_data = np.random.randint(
-    low=0, high=2000, size=(memory_cells, pixels_y, pixels_x), dtype=np.uint16
+    low=0, high=2000, size=(memory_cells, 1, pixels_y, pixels_x), dtype=np.uint16
 )
 
 
 # TODO: gather CPU implementations elsewhere
-def correct_cpu(data, cell_table, offset_map):
+def correct_naive(data, cell_table, offset_map):
     corr = np.squeeze(data).astype(corr_dtype, copy=True)
     safe_cell_bool = cell_table < offset_map.shape[-1]
     safe_cell_index = cell_table[safe_cell_bool]
@@ -31,127 +36,76 @@ def correct_cpu(data, cell_table, offset_map):
     return corr.astype(output_dtype, copy=False)
 
 
-corrected_data = correct_cpu(raw_data, cell_table, offset_map)
+corrected_data = correct_naive(raw_data, cell_table, offset_map)
 only_cast_data = np.squeeze(raw_data).astype(output_dtype)
 
 
-kernel_runners = [
-    runner_class(
-        pixels_x,
-        pixels_y,
-        memory_cells,
-        constant_memory_cells=memory_cells,
-        output_data_dtype=output_dtype,
-    )
-    for runner_class in (DsscCpuRunner, DsscGpuRunner)
-]
-
-
-def test_only_cast():
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.correct(CorrectionFlags.NONE)
-        assert np.allclose(runner.reshape(runner._corrected_axis_order), only_cast_data)
-
-
-def test_correct():
-    for runner in kernel_runners:
-        runner.load_offset_map(offset_map)
-        runner.load_data(raw_data, cell_table)
-        runner.correct(CorrectionFlags.OFFSET)
-        assert np.allclose(runner.reshape(runner._corrected_axis_order), corrected_data)
-
-
-def test_correct_oob_cells():
-    wild_cell_table = cell_table * 2
-    reference = correct_cpu(raw_data, wild_cell_table, offset_map)
-    for runner in kernel_runners:
-        runner.load_offset_map(offset_map)
-        # here, half the cell IDs will be out of bounds
-        runner.load_data(raw_data, wild_cell_table)
-        # should not crash
-        runner.correct(CorrectionFlags.OFFSET)
-        # should correct as much as possible
-        assert np.allclose(runner.reshape(runner._corrected_axis_order), reference)
-
-
-def test_reshape():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        assert np.allclose(
-            runner.reshape(output_order="xyf"), corrected_data.transpose()
-        )
-
-
-def test_preview_slice():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(42)
-        assert np.allclose(
-            preview_raw,
-            raw_data[42].astype(np.float32),
-        )
-        assert np.allclose(
-            preview_corrected,
-            corrected_data[42].astype(np.float32),
-        )
+@pytest.fixture(params=[DsscCpuRunner, DsscGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(DsscBaseRunner, DsscCalcatFriend)
+    runner = request.param(device)
+    device.runner = runner
+    return runner
 
 
-def test_preview_max():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        # note: in case correction failed, still test this separately
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-1)
-        assert np.allclose(preview_raw, np.max(raw_data, axis=0).astype(np.float32))
-        assert np.allclose(
-            preview_corrected, np.max(corrected_data, axis=0).astype(np.float32)
-        )
-
-
-def test_preview_mean():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-2)
-        assert np.allclose(preview_raw, np.nanmean(raw_data, axis=0, dtype=np.float32))
-        assert np.allclose(
-            preview_corrected, np.nanmean(corrected_data, axis=0, dtype=np.float32)
-        )
-
-
-def test_preview_sum():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-3)
-        assert np.allclose(preview_raw, np.nansum(raw_data, axis=0, dtype=np.float32))
-        assert np.allclose(
-            preview_corrected, np.nansum(corrected_data, axis=0, dtype=np.float32)
+def test_only_cast(kernel_runner):
+    assert raw_data.shape == kernel_runner.expected_input_shape(num_frames)
+    # note: previews are just buffers, not yet reduced
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), only_cast_data)
+    assert np.allclose(maybe_get(preview_raw), only_cast_data)
+    assert np.allclose(maybe_get(preview_corrected), only_cast_data)
+
+
+def test_correct(kernel_runner):
+    for preview in ("raw", "corrected"):
+        kernel_runner._device.reconfigure(
+            Hash(
+                f"preview.{preview}.index",
+                0,
+                f"preview.{preview}.selectionMode",
+                "frame",
+            )
         )
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), corrected_data)
+    assert np.allclose(maybe_get(preview_raw), raw_data[:, 0])
+    assert np.allclose(maybe_get(preview_corrected), corrected_data)
 
-
-def test_preview_std():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-4)
-        assert np.allclose(preview_raw, np.nanstd(raw_data, axis=0, dtype=np.float32))
-        assert np.allclose(
-            preview_corrected, np.nanstd(corrected_data, axis=0, dtype=np.float32)
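+    # disabling offset for the main output should leave the corrected preview
+    # path unaffected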
+    kernel_runner._device.reconfigure(Hash("corrections.offset.enable", False))
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), np.squeeze(raw_data))
+    assert np.allclose(maybe_get(preview_raw), np.squeeze(raw_data))
+    assert np.allclose(maybe_get(preview_corrected), corrected_data)
+
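+    # offset enabled for the main output but disabled for its preview: only
+    # processed_data should come out corrected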
+    kernel_runner._device.reconfigure(
+        Hash(
+            "corrections.offset.enable",
+            True,
+            "corrections.offset.preview",
+            False,
         )
+    )
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), corrected_data)
+    assert np.allclose(maybe_get(preview_raw), np.squeeze(raw_data))
+    assert np.allclose(maybe_get(preview_corrected), np.squeeze(raw_data))
 
 
-def test_preview_valid_index():
-    for runner in kernel_runners:
-        with pytest.raises(ValueError):
-            runner.compute_previews(-5)
-        with pytest.raises(ValueError):
-            runner.compute_previews(memory_cells)
+def test_correct_oob_cells(kernel_runner):
+    wild_cell_table = cell_table * 2
+    reference = correct_naive(raw_data, wild_cell_table, offset_map)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    _, processed_data, previews = kernel_runner.correct(
+        raw_data, wild_cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), reference)
diff --git a/tests/test_jungfrau_kernels.py b/tests/test_jungfrau_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f18794ce17fa0c4028e681256a09af49f5ec40
--- /dev/null
+++ b/tests/test_jungfrau_kernels.py
@@ -0,0 +1,108 @@
+import numpy as np
+import pytest
+from calng.corrections.JungfrauCorrection import (
+    Constants,
+    JungfrauBaseRunner,
+    JungfrauCalcatFriend,
+    JungfrauCpuRunner,
+    JungfrauGpuRunner,
+)
+from calng.utils import BadPixelValues
+from karabo.bound import Hash
+
+from common_setup import DummyCorrectionDevice, get_constant_from_file, maybe_get
+
+
+raw_data = np.rint(np.random.random((16, 512, 1024)) * 5000).astype(np.uint16)
+gain = np.random.choice(
+    np.uint8([0, 1, 3]), size=raw_data.shape
+)  # not testing gain value 2 (wrong) for now
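+# JUNGFRAU encodes its third gain stage as 3 in the gain data; the constants are
+# indexed 0/1/2 per stage, so remap 3 -> 2 for the lookups below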
+gain_as_index = gain.copy()
+gain_as_index[gain_as_index == 3] = 2
+constants = {
+    Constants.Offset10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1714364366.1478188.h5",
+        "/Jungfrau_M572/Offset10Hz/0",
+    ),
+    Constants.BadPixelsDark10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1714364369.7054155.h5",
+        "/Jungfrau_M572/BadPixelsDark10Hz/0",
+    ),
+    Constants.BadPixelsFF10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1711037632.1430287.h5",
+        "/Jungfrau_M572/BadPixelsFF10Hz/0",
+    ),
+    Constants.RelativeGain10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1711036758.5906427.h5",
+        "/Jungfrau_M572/RelativeGain10Hz/0",
+    ),
+}
+cell_table = np.arange(16, dtype=np.uint8)
+np.random.shuffle(cell_table)
+cast_data = raw_data.astype(np.float32)
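+# naive reference: np.choose picks, per pixel, the constant matching that pixel's
+# gain stage; apply offset subtraction, then relative gain division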
+offset_corrected = cast_data - np.choose(
+    gain_as_index,
+    np.transpose(constants[Constants.Offset10Hz][:, :, cell_table], (3, 2, 0, 1)),
+)
+gain_corrected = offset_corrected / np.choose(
+    gain_as_index, constants[Constants.RelativeGain10Hz][:, :, cell_table].T
+)
+masked_dark = gain_corrected.copy()
+masked_dark[
+    np.choose(
+        gain_as_index,
+        np.transpose(constants[Constants.BadPixelsDark10Hz], (3, 2, 0, 1))[
+            :, cell_table
+        ],
+    )
+    != 0
+] = np.nan
+masked_ff = masked_dark.copy()
+masked_ff[
+    np.choose(
+        gain_as_index,
+        np.transpose(constants[Constants.BadPixelsFF10Hz], (3, 2, 1, 0))[:, cell_table],
+    )
+    != 0
+] = np.nan
+
+
+@pytest.fixture(params=[JungfrauCpuRunner, JungfrauGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(JungfrauBaseRunner, JungfrauCalcatFriend)
+    runner = request.param(device)
+    device.runner = runner
+    # TODO: test masking around ASIC seams
+    # TODO: test single-frame mode
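+    # the synthetic data above spans 16 frames, so configure 16 memory cells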
+    device.reconfigure(
+        Hash(
+            "constantParameters.memoryCells",
+            16,
+        )
+    )
+    device.reconfigure(
+        Hash(
+            "corrections.badPixels.subsetToUse.NON_STANDARD_SIZE",
+            False,
+        )
+    )
+    return runner
+
+
+def test_cast(kernel_runner):
+    _, result, previews = kernel_runner.correct(raw_data, cell_table, gain)
+    result = maybe_get(result)
+    assert np.allclose(result, cast_data, equal_nan=True)
+
+
+def test_correct(kernel_runner):
+    for constant, constant_data in constants.items():
+        kernel_runner.load_constant(constant, constant_data)
+
+    _, result, previews = kernel_runner.correct(raw_data, cell_table, gain)
+    result = maybe_get(result)
+    assert np.allclose(result, masked_ff, equal_nan=True)
diff --git a/tests/test_lpd_kernels.py b/tests/test_lpd_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf51fb209a62d32bd7929c59af30464f7e2a8d2c
--- /dev/null
+++ b/tests/test_lpd_kernels.py
@@ -0,0 +1,98 @@
+import numpy as np
+import pytest
+from calng.corrections.LpdCorrection import (
+    Constants,
+    LpdBaseRunner,
+    LpdCalcatFriend,
+    LpdCpuRunner,
+    LpdGpuRunner,
+)
+
+from common_setup import DummyCorrectionDevice, maybe_get
+
+
+@pytest.fixture(params=[LpdCpuRunner, LpdGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend)
+    return request.param(device)
+
+
+@pytest.fixture
+def cpu_runner():
+    device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend)
+    return LpdCpuRunner(device)
+
+
+@pytest.fixture
+def gpu_runner():
+    device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend)
+    return LpdGpuRunner(device)
+
+
+# generate some "raw data"
+raw_data_shape = (100, 1, 256, 256)
+# the image data part
+raw_image = (np.arange(np.prod(raw_data_shape)).reshape(raw_data_shape)).astype(
+    np.uint16
+)
+# keeping in mind 12-bit ADC
+raw_image = np.bitwise_and(raw_image, 0x0FFF)
+raw_image_32 = raw_image.astype(np.float32)
+# the gain values (TODO: test with some being illegal)
+gain_data = (np.random.random(raw_data_shape) * 3).astype(np.uint16)
+assert np.array_equal(gain_data, np.bitwise_and(gain_data, 0x0003))
+gain_data_32 = gain_data.astype(np.float32)
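+# pack image data (bits 0-11) and gain stage (bits 12-13) into the raw word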
+raw_data = np.bitwise_or(raw_image, np.left_shift(gain_data, 12))
+# just checking that we constructed something reasonable
+assert np.array_equal(gain_data, np.bitwise_and(np.right_shift(raw_data, 12), 0x0003))
+assert np.array_equal(raw_image, np.bitwise_and(raw_data, 0x0FFF))
+
+# generate some defects
+wrong_gain_value_indices = (np.random.random(raw_data_shape) < 0.005).astype(np.bool_)
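+# OR 0b11 into the gain bits for a random subset of pixels; gain stage 3 is not a
+# valid LPD gain value, so these pixels are expected to come out as NaN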
+raw_data_with_some_wrong_gain = raw_data.copy()
+raw_data_with_some_wrong_gain[wrong_gain_value_indices] = np.bitwise_or(
+    raw_data_with_some_wrong_gain[wrong_gain_value_indices], 0x3000
+)
+
+# generate some constants
+dark_constant_shape = (256, 256, 512, 3)
+# note: slow scan and fast scan have the same size, so the shapes cannot be told apart here
+funky_constant_shape = (256, 256, 512, 3)
+offset_constant = np.random.random(dark_constant_shape).astype(np.float32)
+gain_amp_constant = np.random.random(funky_constant_shape).astype(np.float32)
+bp_dark_constant = (np.random.random(dark_constant_shape) < 0.01).astype(np.uint32)
+
+cell_table = np.linspace(0, 512, 100).astype(np.uint16)
+np.random.shuffle(cell_table)
+
+
+def test_only_cast(kernel_runner):
+    _, processed, previews = kernel_runner.correct(raw_data, cell_table)
+    assert np.allclose(maybe_get(processed), raw_image_32[:, 0])
+    # TODO: make raw preview show image data (remove gain bits)
+    assert np.allclose(maybe_get(previews[0]), raw_data[:, 0])  # raw: raw
+    assert np.allclose(maybe_get(previews[1]), raw_image_32[:, 0])  # corrected: cast only
+    assert np.allclose(maybe_get(previews[2]), gain_data_32[:, 0])  # gain: unchanged
+
+
+def test_only_cast_with_some_wrong_gain(kernel_runner):
+    _, processed, _ = kernel_runner.correct(
+        raw_data_with_some_wrong_gain, cell_table
+    )
+    assert np.all(np.isnan(maybe_get(processed)[wrong_gain_value_indices[:, 0]]))
+
+
+def test_correct(cpu_runner, gpu_runner):
+    # naive numpy version way too slow, just compare the two runners
+    # the CPU one uses kernel almost identical to pycalibration
+    cpu_runner.load_constant(Constants.Offset, offset_constant)
+    cpu_runner.load_constant(Constants.GainAmpMap, gain_amp_constant)
+    cpu_runner.load_constant(Constants.BadPixelsDark, bp_dark_constant)
+    gpu_runner.load_constant(Constants.Offset, offset_constant)
+    gpu_runner.load_constant(Constants.GainAmpMap, gain_amp_constant)
+    gpu_runner.load_constant(Constants.BadPixelsDark, bp_dark_constant)
+    _, processed_cpu, _ = cpu_runner.correct(raw_data_with_some_wrong_gain, cell_table)
+    _, processed_gpu, _ = gpu_runner.correct(raw_data_with_some_wrong_gain, cell_table)
+    assert np.allclose(processed_cpu, processed_gpu.get(), equal_nan=True)
diff --git a/tests/test_pnccd_kernels.py b/tests/test_pnccd_kernels.py
index 18f438fa7de2ce7cd77bb45245726689d2508563..35da5f926acec7e3a3bf80298025729f2b4922f2 100644
--- a/tests/test_pnccd_kernels.py
+++ b/tests/test_pnccd_kernels.py
@@ -1,50 +1,78 @@
+import itertools
+
 import numpy as np
+import pytest
 from calng import utils
-from calng.corrections.PnccdCorrection import Constants, CorrectionFlags, PnccdCpuRunner
+from calng.corrections.PnccdCorrection import (
+    Constants,
+    PnccdCalcatFriend,
+    PnccdCpuRunner,
+)
+from karabo.bound import Hash
 
+from common_setup import DummyCorrectionDevice
+
+
+@pytest.fixture
+def kernel_runner():
+    device = DummyCorrectionDevice(PnccdCpuRunner, PnccdCalcatFriend)
+    runner = PnccdCpuRunner(device)
+    device.runner = runner
+    return runner
 
-def test_common_mode():
-    def slow_common_mode(
-        data, bad_pixel_map, noise_map, cm_min_frac, cm_noise_sigma, cm_ss, cm_fs
-    ):
-        masked = np.ma.masked_array(
-            data=data,
-            mask=(bad_pixel_map != 0) | (data > noise_map * cm_noise_sigma),
-        )
 
-        if cm_fs:
-            subset_fs = masked.count(axis=1) >= masked.shape[1] * cm_min_frac
-            data[subset_fs] -= np.ma.median(masked[subset_fs], axis=1, keepdims=True)
+def slow_common_mode(
+    data, bad_pixel_map, noise_map, cm_min_frac, cm_noise_sigma, cm_ss, cm_fs
+):
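+    # reference implementation: mask bad pixels and outliers (above
+    # cm_noise_sigma * noise), then subtract the masked median per row/column
+    # in place wherever at least cm_min_frac of the pixels survive masking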
+    masked = np.ma.masked_array(
+        data=data,
+        mask=(bad_pixel_map != 0) | (data > noise_map * cm_noise_sigma),
+    )
 
-        if cm_ss:
-            subset_ss = masked.count(axis=0) >= masked.shape[0] * cm_min_frac
-            data[:, subset_ss] -= np.ma.median(masked[:, subset_ss], axis=0)
+    if cm_fs:
+        subset_fs = masked.count(axis=1) >= masked.shape[1] * cm_min_frac
+        data[subset_fs] -= np.ma.median(masked[subset_fs], axis=1, keepdims=True)
 
+    if cm_ss:
+        subset_ss = masked.count(axis=0) >= masked.shape[0] * cm_min_frac
+        data[:, subset_ss] -= np.ma.median(masked[:, subset_ss], axis=0)
+
+
+def test_common_mode(kernel_runner):
     data_shape = (1024, 1024)
     data = (np.random.random(data_shape) * 1000).astype(np.uint16)
-    runner = PnccdCpuRunner(1024, 1024, 1, 1)
-    for noise_level in (0, 5, 100, 1000):
-        noise = (np.random.random(data_shape) * noise_level).astype(np.float32)
-        runner.load_constant(Constants.NoiseCCD, noise)
-        for bad_pixel_percentage in (0, 1, 20, 80, 100):
-            bad_pixels = (
-                np.random.random(data_shape) < (bad_pixel_percentage / 100)
-            ).astype(np.uint32)
-            runner.load_constant(Constants.BadPixelsDarkCCD, bad_pixels)
-            runner.load_data(data.copy())
-            runner.correct(
-                CorrectionFlags.COMMONMODE,
-                cm_min_frac=0.25,
-                cm_noise_sigma=10,
-                cm_row=True,
-                cm_col=True,
-            )
-
-            slow_version = data.astype(np.float32, copy=True)
-            for q_data, q_noise, q_bad_pixels in zip(
-                utils.quadrant_views(slow_version),
-                utils.quadrant_views(noise),
-                utils.quadrant_views(bad_pixels),
-            ):
-                slow_common_mode(q_data, q_bad_pixels, q_noise, 0.25, 10, True, True)
-            assert np.allclose(runner.processed_data, slow_version, equal_nan=True)
+    noise_maps = [
+        (np.random.random(data_shape) * noise_level).astype(np.float32)
+        for noise_level in (0, 5, 100, 1000)
+    ]
+    bad_pixel_maps = [
+        (np.random.random(data_shape) < (bad_pixel_percentage / 100)).astype(np.uint32)
+        for bad_pixel_percentage in (0, 1, 20, 80, 100)
+    ]
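+    # enable row and column common mode with the same parameters as the slow
+    # reference below; disable bad pixel masking so the comparison only covers
+    # the common mode step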
+    kernel_runner._device.reconfigure(
+        Hash(
+            "corrections.commonMode.noiseSigma",
+            10,
+            "corrections.commonMode.minFrac",
+            0.25,
+            "corrections.commonMode.enableRow",
+            True,
+            "corrections.commonMode.enableCol",
+            True,
+            "corrections.badPixels.enable",
+            False,
+        )
+    )
+    for noise_map, bad_pixel_map in itertools.product(noise_maps, bad_pixel_maps):
+        kernel_runner.load_constant(Constants.NoiseCCD, noise_map)
+        kernel_runner.load_constant(Constants.BadPixelsDarkCCD, bad_pixel_map)
+        _, processed_data, _ = kernel_runner.correct(data, None)
+
+        slow_version = data.astype(np.float32, copy=True)
+        for q_data, q_noise, q_bad_pixels in zip(
+            utils.quadrant_views(slow_version),
+            utils.quadrant_views(noise_map),
+            utils.quadrant_views(bad_pixel_map),
+        ):
+            slow_common_mode(q_data, q_bad_pixels, q_noise, 0.25, 10, True, True)
+        assert np.allclose(processed_data, slow_version, equal_nan=True)
diff --git a/tests/test_preview_utils.py b/tests/test_preview_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..043d7ac6704b4384fb18a2152201984ae64d0969
--- /dev/null
+++ b/tests/test_preview_utils.py
@@ -0,0 +1,160 @@
+import numpy as np
+import pytest
+from karabo.bound import Hash
+from calng.preview_utils import PreviewFriend, PreviewSpec
+
+from common_setup import DummyCorrectionDevice
+
+
+cell_table = np.array([0, 2, 1], dtype=np.uint16)
+pulse_table = np.array([1, 2, 0], dtype=np.uint16)
+raw_data = np.array(
+    [
+        [1.0, 2.0, 3.0],
+        [4.0, np.nan, 5.0],
+        [6.0, 7.0, 8.0],
+    ],
+    dtype=np.float32,
+)
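+# single-frame previews are expected to zero out non-finite values, so the
+# frame/cell/pulse tests below compare against this cleaned copy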
+raw_without_nan = raw_data.copy()
+raw_without_nan[~np.isfinite(raw_data)] = 0
+
+
+@pytest.fixture
+def device():
+    output_specs = [
+        PreviewSpec(
+            "dummyPreview",
+            dimensions=1,
+            wrap_in_imagedata=False,
+            frame_reduction=True,
+        )
+    ]
+    device = DummyCorrectionDevice(
+        {"friend_class": PreviewFriend, "outputs": output_specs}
+    )
+
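+    # stand-in for the device's writeChannel: record what would be written to
+    # the output channel so the tests can inspect it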
+    def writeChannel(name, hsh, **kwargs):
+        if not hasattr(device, "_written"):
+            device._written = []
+        device._written.append(("name", hsh))
+
+    device.writeChannel = writeChannel
+    device.preview_friend = PreviewFriend(device, output_specs)
+
+    return device
+
+
+def test_preview_frame(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            1,
+            "dummyPreview.selectionMode",
+            "frame",
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    assert hasattr(device, "_written")
+    assert len(device._written) == 1
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], raw_without_nan[1])
+
+
+def test_preview_cell(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            1,
+            "dummyPreview.selectionMode",
+            "cell",
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], raw_without_nan[2])
+
+
+def test_preview_pulse(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            2,
+            "dummyPreview.selectionMode",
+            "pulse",
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], raw_without_nan[1])
+
+
+def test_preview_max(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -1,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], np.nanmax(raw_data, axis=0))
+
+
+def test_preview_mean(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -2,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(
+        preview["image.data"], np.nanmean(raw_data, axis=0, dtype=np.float32)
+    )
+
+
+def test_preview_sum(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -3,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(
+        preview["image.data"], np.nansum(raw_data, axis=0, dtype=np.float32)
+    )
+
+
+def test_preview_std(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -4,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(
+        preview["image.data"], np.nanstd(raw_data, axis=0, dtype=np.float32)
+    )
+
+
+# TODO: also test alternative preview index selection modes