From a0928f80b6de182dd249dacb223042ae1d0c57fb Mon Sep 17 00:00:00 2001
From: David Hammer <david.hammer@xfel.eu>
Date: Thu, 4 Jul 2024 14:40:33 +0200
Subject: [PATCH] Correction runners as friends, major refactor

Correction runners get more responsibility for handling reconfiguration subhashes. Also, they don't get killed and replaced, but can respond to more changes, including number of frames.
Previews are overhauled, with more code moved to preview_utils; each individual preview output (the names of the OutputChannels change) now has its own preview reduction settings.
More tests added all around.
---
 DEPENDS                                      |    2 +-
 docs/devices.md                              |    4 +-
 src/calng/CalibrationManager.py              |   49 +-
 src/calng/DetectorAssembler.py               |   32 +-
 src/calng/LpdminiSplitter.py                 |    2 +-
 src/calng/base_correction.py                 | 1012 +++++-------------
 src/calng/base_kernel_runner.py              |  451 ++++++--
 src/calng/corrections/AgipdCorrection.py     |  981 ++++++++---------
 src/calng/corrections/DsscCorrection.py      |  272 ++---
 src/calng/corrections/Epix100Correction.py   |  422 +++-----
 src/calng/corrections/Gotthard2Correction.py |  478 ++++-----
 src/calng/corrections/JungfrauCorrection.py  |  587 ++++------
 src/calng/corrections/LpdCorrection.py       |  463 +++-----
 src/calng/corrections/LpdminiCorrection.py   |   21 +-
 src/calng/corrections/PnccdCorrection.py     |  426 +++-----
 src/calng/kernels/agipd_gpu.cu               |   90 +-
 src/calng/kernels/common_gpu.cu              |   60 ++
 src/calng/kernels/dssc_cpu.pyx               |   14 +-
 src/calng/kernels/dssc_gpu.cu                |   36 +-
 src/calng/kernels/jungfrau_cpu.pyx           |   42 +-
 src/calng/kernels/jungfrau_gpu.cu            |   69 +-
 src/calng/kernels/lpd_cpu.pyx                |   61 +-
 src/calng/kernels/lpd_gpu.cu                 |   46 +-
 src/calng/preview_utils.py                   |  415 +++++--
 src/calng/scenes.py                          |  225 ++--
 src/calng/schemas.py                         |   12 +-
 src/calng/stacking_utils.py                  |    2 -
 src/calng/utils.py                           |  123 +--
 tests/common_setup.py                        |  100 ++
 tests/test_agipd_kernels.py                  |  175 +--
 tests/test_dssc_kernels.py                   |  190 ++--
 tests/test_jungfrau_kernels.py               |  108 ++
 tests/test_lpd_kernels.py                    |   98 ++
 tests/test_pnccd_kernels.py                  |  110 +-
 tests/test_preview_utils.py                  |  160 +++
 35 files changed, 3559 insertions(+), 3779 deletions(-)
 create mode 100644 src/calng/kernels/common_gpu.cu
 create mode 100644 tests/common_setup.py
 create mode 100644 tests/test_jungfrau_kernels.py
 create mode 100644 tests/test_lpd_kernels.py
 create mode 100644 tests/test_preview_utils.py

diff --git a/DEPENDS b/DEPENDS
index 6dbf7d21..53d42989 100644
--- a/DEPENDS
+++ b/DEPENDS
@@ -1,4 +1,4 @@
 TrainMatcher, 2.4.5
 calibrationClient, 11.3.0
 calibration/geometryDevices, 0.0.2
-calibration/calngUtils, 0.0.3
+calibration/calngUtils, 0.0.4
diff --git a/docs/devices.md b/docs/devices.md
index bbdaa9b8..242caf3a 100644
--- a/docs/devices.md
+++ b/docs/devices.md
@@ -50,15 +50,13 @@ Therefore, settings related to how this is done should typically be managed via
 The parameters seen in the "Preview" box on the manager overview scene control:
 
 - How to select which frame to preview (for non-negative indices)
-    - See [Index selection mode](schemas/BaseCorrection.md#preview.selectionMode)
 	- Frame can be extracted directly from image ndarray (`frame` mode), by looking up corresponding frame in cell table (`cell` mode), or by looking up in pulse table (`pulse` mode)
     - For XTDF detectors, the mapping tables are typically `image.cellId` and `image.pulseId`.
 	  Which selection mode makes sense depends on detector, veto pattern, and experimental setup.
     - If the specified cell or pulse is not found in the respective table, a warning is issued and the first frame is sent instead.
 - Which frame or statistic to send
-    - See [Index for preview](schemas/BaseCorrection.md#preview.index)
     - If the index is non-negative, a single frame is sliced from the image data.
-	  How this frame is found depends on the the [index selection mode](schemas/BaseCorrection.md#preview.selectionMode).
+	  How this frame is found depends on the index selection mode (previous point).
     - If the index is negative, a statistic is computed across all frames for a summary preview.
       Note that this is done individually per pixel and that `NaN` values are ignored.
       Options are:
diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py
index 8de76153..4ae0ed05 100644
--- a/src/calng/CalibrationManager.py
+++ b/src/calng/CalibrationManager.py
@@ -18,7 +18,7 @@ import re
 from tornado.httpclient import AsyncHTTPClient, HTTPError
 from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
 
-from calngUtils import scene_utils
+from calngUtils.scene_utils import recursive_subschema_scene
 from karabo.middlelayer import (
     KaraboError, Device, DeviceClientBase, Descriptor, Hash, Configurable,
     Slot, Node, Type, Schema, ProxyFactory,
@@ -158,8 +158,8 @@ class PreviewLayerRow(Configurable):
         defaultValue='',
         accessMode=AccessMode.RECONFIGURABLE)
 
-    outputPipeline = String(
-        displayedName='Output pipeline',
+    previewName = String(
+        displayedName='Preview name',
         defaultValue='',
         accessMode=AccessMode.RECONFIGURABLE)
 
@@ -307,7 +307,7 @@ class CalibrationManager(DeviceClientBase, Device):
                 prefix = name[len('browse_schema:'):]
             else:
                 prefix = 'managedKeys'
-            scene_data = scene_utils.recursive_subschema_scene(
+            scene_data = recursive_subschema_scene(
                 self.deviceId,
                 self.getDeviceSchema(),
                 prefix,
@@ -382,12 +382,6 @@ class CalibrationManager(DeviceClientBase, Device):
         accessMode=AccessMode.RECONFIGURABLE,
         assignment=Assignment.MANDATORY)
 
-    previewLayers = VectorHash(
-        displayedName='Preview layers',
-        rows=PreviewLayerRow,
-        accessMode=AccessMode.RECONFIGURABLE,
-        assignment=Assignment.MANDATORY)
-
     @VectorHash(
         displayedName='Device servers',
         description='WARNING: It is strongly recommended to perform a full '
@@ -402,6 +396,21 @@ class CalibrationManager(DeviceClientBase, Device):
         # Switch to UNKNOWN state to suggest operator to restart pipeline.
         self.state = State.UNKNOWN
 
+    previewServer = String(
+        displayedName='Preview device server',
+        description='The server (from "Device servers" list) to use for '
+                    'preview assemblers',
+        accessMode=AccessMode.RECONFIGURABLE,
+        assignment=Assignment.MANDATORY)
+
+    previewLayers = VectorString(
+        displayedName='Preview layers',
+        description='List of previews (like raw, corrected, and anything else '
+                    'detector-specific) to show. Automatically populated based '
+                    'on correction device schema, just listed for debugging.',
+        defaultValue=[],
+        accessMode=AccessMode.READONLY)
+
     geometryDevice = String(
         displayedName='Geometry device',
         description='Device ID for a geometry device defining the detector '
@@ -598,6 +607,10 @@ class CalibrationManager(DeviceClientBase, Device):
         # Inject schema for configuration of managed devices.
         await self._inject_managed_keys()
 
+        # Populate preview layers from correction device schema
+        self.previewLayers = list(self._correction_device_schema
+                                  .hash["preview"].getKeys())
+
         # Populate the device ID sets with what's out there right now.
         await self._check_topology()
 
@@ -1262,12 +1275,9 @@ class CalibrationManager(DeviceClientBase, Device):
         # Servers by group and layer.
         server_by_group = {group: server for group, server, _, _, _,
                            in self.moduleGroups.value}
-        server_by_layer = {layer: server for layer, _, server
-                           in self.previewLayers.value}
-
-        all_req_servers = set(server_by_group.values()).union(
-            server_by_layer.values())
+        preview_server = self.previewServer.value
 
+        all_req_servers = set(server_by_group.values()) | {preview_server}
         if all_req_servers != up_servers:
             return self._set_error('One or more device servers are not '
                                    'listed in the device servers '
@@ -1352,9 +1362,9 @@ class CalibrationManager(DeviceClientBase, Device):
             ))
 
         # Instantiate preview layer assemblers.
-        for layer, output_pipeline, server in self.previewLayers.value:
+        for preview_name in self.previewLayers.value:
             sources = [Hash('select', True,
-                            'source', f'{device_id}:{output_pipeline}')
+                            'source', f'{device_id}:preview.{preview_name}.output')
                        for (_, device_id)
                        in correct_device_id_by_module.items()]
             config = Hash('sources', sources,
@@ -1367,9 +1377,10 @@ class CalibrationManager(DeviceClientBase, Device):
                     config[remote_key] = value.value
 
             awaitables.append(self._instantiate_device(
-                server,
+                preview_server,
                 self._class_ids['assembler'],
-                self._device_id_templates['assembler'].format(layer=layer),
+                self._device_id_templates['assembler'].format(
+                    layer=preview_name.upper()),
                 config
             ))
 
diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py
index 1f3fb09d..b4f6c06d 100644
--- a/src/calng/DetectorAssembler.py
+++ b/src/calng/DetectorAssembler.py
@@ -27,7 +27,8 @@ from karabo.bound import (
 )
 from TrainMatcher import TrainMatcher
 
-from . import preview_utils, scenes, schemas
+from . import scenes, schemas
+from .preview_utils import PreviewFriend, PreviewSpec
 from ._version import version as deviceVersion
 
 
@@ -73,8 +74,12 @@ def module_index_schema():
 @device_utils.with_unsafe_get
 @KARABO_CLASSINFO("DetectorAssembler", deviceVersion)
 class DetectorAssembler(TrainMatcher.TrainMatcher):
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec("assembled", frame_reduction=False, flip_ss=True, flip_fs=True)
+    ]
+
+    @classmethod
+    def expectedParameters(cls, expected):
         (
             OVERWRITE_ELEMENT(expected)
             .key("availableScenes")
@@ -178,18 +183,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             .reconfigurable()
             .commit(),
         )
-        preview_utils.PreviewFriend.add_schema(expected, "preview", create_node=True)
-        (
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.flipSS")
-            .setNewDefaultValue(True)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.flipFS")
-            .setNewDefaultValue(True)
-            .commit(),
-        )
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
 
     def __init__(self, conf):
         super().__init__(conf)
@@ -204,7 +198,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
     def initialization(self):
         super().initialization()
 
-        self._preview_friend = preview_utils.PreviewFriend(self)
+        self._preview_friend = PreviewFriend(self, self._preview_outputs, "preview")
         self._image_data_path = self.get("imageDataPath")
         self._image_mask_path = self.get("imageMaskPath")
         self._geometry = None
@@ -469,8 +463,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             self.zmq_output.update()
 
         self._preview_friend.write_outputs(
-            my_timestamp,
-            np.ma.masked_array(data=assembled_data, mask=assembled_mask),
+            np.ma.masked_array(data=assembled_data, mask=assembled_mask)
         )
 
         self._processing_time_tracker.update(default_timer() - ts_start)
@@ -495,7 +488,8 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
         ):
             self._need_to_update_source_index_mapping = True
 
-        self._preview_friend.reconfigure(conf)
+        if conf.has("preview"):
+            self._preview_friend.reconfigure(conf["preview"])
 
     def postReconfigure(self):
         if self._need_to_update_source_index_mapping:
diff --git a/src/calng/LpdminiSplitter.py b/src/calng/LpdminiSplitter.py
index 5c32d679..d2fee866 100644
--- a/src/calng/LpdminiSplitter.py
+++ b/src/calng/LpdminiSplitter.py
@@ -105,7 +105,7 @@ class LpdminiSplitter(PythonDevice):
             (
                 OUTPUT_CHANNEL(expected)
                 .key(channel_name)
-                .dataSchema(schemas.xtdf_output_schema())
+                .dataSchema(schemas.xtdf_output_schema(use_shmem_handles=True))
                 .commit(),
             )
 
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 414b3342..a11f9b96 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -1,7 +1,7 @@
 import concurrent.futures
-import enum
+import contextlib
 import functools
-import gc
+import itertools
 import math
 import pathlib
 import threading
@@ -13,17 +13,16 @@ import numpy as np
 from geometryDevices import utils as geom_utils
 from calngUtils import (
     device as device_utils,
-    misc,
     scene_utils,
     shmem_utils,
     timing,
     trackers,
 )
+from calngUtils.misc import ChainHash
 from karabo.bound import (
     BOOL_ELEMENT,
     DOUBLE_ELEMENT,
     INPUT_CHANNEL,
-    INT32_ELEMENT,
     KARABO_CLASSINFO,
     NODE_ELEMENT,
     OUTPUT_CHANNEL,
@@ -45,96 +44,33 @@ from karabo.bound import (
 )
 from karabo.common.api import KARABO_SCHEMA_DISPLAY_TYPE_SCENES as DT_SCENES
 
-from . import preview_utils, schemas, scenes, utils
+from . import scenes
+from .preview_utils import PreviewFriend
+from .utils import StateContext, WarningContextSystem, WarningLampType, subset_of_hash
 from ._version import version as deviceVersion
 
 PROCESSING_STATE_TIMEOUT = 10
 
 
-class FramefilterSpecType(enum.Enum):
-    NONE = "none"
-    RANGE = "range"
-    COMMASEPARATED = "commaseparated"
-
-
-class WarningLampType(enum.Enum):
-    FRAME_FILTER = enum.auto()
-    MEMORY_CELL_RANGE = enum.auto()
-    CONSTANT_OPERATING_PARAMETERS = enum.auto()
-    PREVIEW_SETTINGS = enum.auto()
-    CORRECTION_RUNNER = enum.auto()
-    OUTPUT_BUFFER = enum.auto()
-    GPU_MEMORY = enum.auto()
-    CALCAT_CONNECTION = enum.auto()
-    EMPTY_HASH = enum.auto()
-    MISC_INPUT_DATA = enum.auto()
-    TRAIN_ID = enum.auto()
-    TIMESERVER_CONNECTION = enum.auto()
+class TrainFromTheFutureException(BaseException):
+    pass
 
 
+@device_utils.with_config_overlay
 @device_utils.with_unsafe_get
 @KARABO_CLASSINFO("BaseCorrection", deviceVersion)
 class BaseCorrection(PythonDevice):
     _available_addons = []  # classes, filled by add_addon_nodes using entry_points
     _base_output_schema = None  # subclass must set constructor
     _constant_enum_class = None  # subclass must set
-    _correction_flag_class = None  # subclass must set (ex.: dssc_gpu.CorrectionFlags)
     _correction_steps = None  # subclass must set
     _kernel_runner_class = None  # subclass must set (ex.: dssc_gpu.DsscGpuRunner)
-    _kernel_runner_init_args = {}  # optional extra args for runner
     _image_data_path = "image.data"  # customize for *some* subclasses
     _cell_table_path = "image.cellId"
     _pulse_table_path = "image.pulseId"
     _warn_memory_cell_range = True  # can be disabled for some detectors
     _cuda_pin_buffers = False
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        """Subclass must define how to process constants into correction maps and store
-        into appropriate buffers in (GPU or main) memory."""
-        # note: aim to refactor this away as kernel runner handles loading logic
-        raise NotImplementedError()
-
-    def _successfully_loaded_constant_to_runner(self, constant):
-        for field_name in self._constant_to_correction_names[constant]:
-            key = f"corrections.{field_name}.available"
-            if not self.unsafe_get(key):
-                self.set(key, True)
-
-        self._update_correction_flags()
-        self.log_status_info(f"Done loading {constant.name} to runner")
-
-    @property
-    def input_data_shape(self):
-        """Subclass must define expected input data shape in terms of dataFormat.{
-        memoryCells,pixelsX,pixelsY} and any other axes."""
-        raise NotImplementedError()
-
-    @property
-    def output_data_shape(self):
-        """Shape of corrected image data sent on dataOutput. Depends on data format
-        parameters (optionally including frame filter)."""
-        axis_lengths = {
-            "x": self.unsafe_get("dataFormat.pixelsX"),
-            "y": self.unsafe_get("dataFormat.pixelsY"),
-            "f": self.unsafe_get("dataFormat.filteredFrames"),
-        }
-        return tuple(
-            axis_lengths[axis] for axis in self.unsafe_get("dataFormat.outputAxisOrder")
-        )
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-    ):
-        """Subclass must define data processing (presumably using the kernel runner).
-        Will be called by input_handler, which will take care of some common checks and
-        extracting the parameters given to process_data."""
-        raise NotImplementedError()
+    _use_shmem_handles = True
 
     @staticmethod
     def expectedParameters(expected):
@@ -171,57 +107,6 @@ class BaseCorrection(PythonDevice):
             .defaultValue([])
             .commit(),
 
-            NODE_ELEMENT(expected)
-            .key("frameFilter")
-            .displayedName("Frame filter")
-            .description(
-                "The frame filter - if set - slices the input data. Frames not in the "
-                "filter will be discarded before any processing happens and will not "
-                "get to dataOutput or preview. Note that this filter goes by frame "
-                "index rather than cell ID or pulse ID; set accordingly. Handle with "
-                "care - an invalid filter can prevent all processing. How the filter "
-                "is specified depends on frameFilter.type. See frameFilter.current to "
-                "inspect the currently set frame filter array (if any)."
-            )
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("frameFilter.type")
-            .tags("managed")
-            .displayedName("Type")
-            .description(
-                "Controls how frameFilter.spec is used. The default value of 'none' "
-                "means that no filter is set (regardless of frameFilter.spec). "
-                "'arange' allows between one and three integers separated by ',' which "
-                "are parsed and passed directly to numpy.arange. 'commaseparated' "
-                "reads a list of integers separated by commas."
-            )
-            .options(",".join(spectype.value for spectype in FramefilterSpecType))
-            .assignmentOptional()
-            .defaultValue("none")
-            .reconfigurable()
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("frameFilter.spec")
-            .tags("managed")
-            .displayedName("Specification")
-            .assignmentOptional()
-            .defaultValue("")
-            .reconfigurable()
-            .commit(),
-
-            VECTOR_UINT32_ELEMENT(expected)
-            .key("frameFilter.current")
-            .displayedName("Current filter")
-            .description(
-                "This read-only value is used to display the contents of the current "
-                "frame filter. An empty array means no filtering is done."
-            )
-            .readOnly()
-            .initialValue([])
-            .commit(),
-
             UINT32_ELEMENT(expected)
             .key("outputShmemBufferSize")
             .tags("managed")
@@ -285,14 +170,6 @@ class BaseCorrection(PythonDevice):
             .displayedName("Data format (in/out)")
             .commit(),
 
-            STRING_ELEMENT(expected)
-            .key("dataFormat.inputImageDtype")
-            .displayedName("Input image data dtype")
-            .description("The (numpy) dtype to expect for incoming image data.")
-            .readOnly()
-            .initialValue("uint16")
-            .commit(),
-
             STRING_ELEMENT(expected)
             .key("dataFormat.outputImageDtype")
             .tags("managed")
@@ -304,41 +181,16 @@ class BaseCorrection(PythonDevice):
                 "causes truncation rather than rounding."
             )
             # TODO: consider adding rounding / binning for integer output
-            .options("float16,float32,uint16")
+            .options("float32,float16,uint16")
             .assignmentOptional()
             .defaultValue("float32")
             .reconfigurable()
             .commit(),
 
-            # important: determines shape of data as going into correction
             UINT32_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .displayedName("Pixels x")
-            .description("Number of pixels of image data along X axis")
-            .assignmentOptional()
-            .defaultValue(512)
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .displayedName("Pixels y")
-            .description("Number of pixels of image data along Y axis")
-            .assignmentOptional()
-            .defaultValue(128)
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("dataFormat.frames")
+            .key("dataFormat.inputFrames")
             .displayedName("Frames")
-            .description("Number of image frames per train in incoming data")
-            .assignmentOptional()
-            .defaultValue(1)  # subclass will want to set a default value
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("dataFormat.filteredFrames")
-            .displayedName("Frames after filter")
-            .description("Number of frames left after applying frame filter")
+            .description("Number of frames on input")
             .readOnly()
             .initialValue(0)
             .commit(),
@@ -349,13 +201,15 @@ class BaseCorrection(PythonDevice):
             .displayedName("Output axis order")
             .description(
                 "Axes of main data output can be reordered after correction. Axis "
-                "order is specified as string consisting of 'x', 'y', and 'f', with "
-                "the latter indicating the image frame. The default value of 'fxy' "
-                "puts pixels on the fast axes."
+                "order is specified as string consisting of 'f' (frames), 'ss' (slow "
+                "scan), and 'ff' (fast scan). Anything but the default f-ss-fs order "
+                "implies reordering - this axis ordering is based on the data we get "
+                "from the detector (/ receiver), so how ss and fs maps to what may be "
+                "considered x and y is detector-dependent."
             )
-            .options("fxy,fyx,xfy,xyf,yfx,yxf")
+            .options("f-ss-fs,f-fs-ss,ss-f-fs,ss-fs-f,fs-f-ss,fs-ss-f")
             .assignmentOptional()
-            .defaultValue("fxy")
+            .defaultValue("f-ss-fs")
             .reconfigurable()
             .commit(),
 
@@ -363,9 +217,9 @@ class BaseCorrection(PythonDevice):
             .key("dataFormat.inputDataShape")
             .displayedName("Input data shape")
             .description(
-                "Image data shape in incoming data (from reader / DAQ). This value is "
-                "computed from pixelsX, pixelsY, and frames - this field just "
-                "shows what is currently expected."
+                "Image data shape in incoming data (from reader / DAQ). Updated based "
+                "on latest train processed. Note that axis order may look wonky due to "
+                "DAQ quirk."
             )
             .readOnly()
             .initialValue([])
@@ -375,9 +229,10 @@ class BaseCorrection(PythonDevice):
             .key("dataFormat.outputDataShape")
             .displayedName("Output data shape")
             .description(
-                "Image data shape for data output from this device. This value is "
-                "computed from pixelsX, pixelsY, and the size of the frame filter - "
-                "this field just shows what is currently expected."
+                "Image data shape for data output from this device. Takes into account "
+                "axis reordering, if applicable. Even without that, "
+                "workarounds.overrideInputAxisOrder likely makes this differ from "
+                "dataFormat.inputDataShape."
             )
             .readOnly()
             .initialValue([])
@@ -472,50 +327,6 @@ class BaseCorrection(PythonDevice):
             .commit(),
         )
 
-        (
-            NODE_ELEMENT(expected)
-            .key("preview")
-            .displayedName("Preview")
-            .commit(),
-
-            INT32_ELEMENT(expected)
-            .key("preview.index")
-            .tags("managed")
-            .displayedName("Index (or stat) for preview")
-            .description(
-                "If this value is ≥ 0, the corresponding index (frame, cell, or pulse) "
-                "will be sliced for the preview output. If this value is < 0, preview "
-                "will be one of the following stats: -1: max, -2: mean, -3: sum, -4: "
-                "stdev. These stats are computed across frames, ignoring NaN values."
-            )
-            .assignmentOptional()
-            .defaultValue(0)
-            .minInc(-4)
-            .reconfigurable()
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .tags("managed")
-            .displayedName("Index selection mode")
-            .description(
-                "The value of preview.index can be used in multiple ways, controlled "
-                "by this value. If this is set to 'frame', preview.index is sliced "
-                "directly from data. If 'cell' (or 'pulse') is selected, I will look "
-                "at cell (or pulse) table for the requested cell (or pulse ID). "
-                "Special (stat) index values <0 are not affected by this."
-            )
-            .options(
-                ",".join(
-                    spectype.value for spectype in utils.PreviewIndexSelectionMode
-                )
-            )
-            .assignmentOptional()
-            .defaultValue("frame")
-            .reconfigurable()
-            .commit(),
-        )
-
         # just measurements and counters to display
         (
             UINT64_ELEMENT(expected)
@@ -618,24 +429,8 @@ class BaseCorrection(PythonDevice):
             config["dataOutput.hostname"] = ib_ip
         super().__init__(config)
 
-        self.output_data_dtype = np.dtype(config["dataFormat.outputImageDtype"])
-
         self.sources = set(config.get("fastSources"))
 
-        self.kernel_runner = None  # must call _update_buffers to initialize
-        self._shmem_buffer = None  # ditto
-        self._use_shmem_handles = config.get("useShmemHandles")
-        self._preview_friend = None
-
-        self._correction_flag_enabled = self._correction_flag_class.NONE
-        self._correction_flag_preview = self._correction_flag_class.NONE
-        self._correction_applied_hash = Hash()
-        # note: does not handle one constant enabling multiple correction steps
-        # (do we need to generalize "decide if this correction step is available"?)
-        self._constant_to_correction_names = {}
-        for (name, _, constants) in self._correction_steps:
-            for constant in constants:
-                self._constant_to_correction_names.setdefault(constant, set()).add(name)
         self._buffer_lock = threading.Lock()
         self._last_processing_started = 0  # used for processing time and timeout
         self._last_train_id_processed = 0  # used to keep track (and as fallback)
@@ -657,43 +452,36 @@ class BaseCorrection(PythonDevice):
                         constant_data = getattr(self.calcat_friend, friend_fun)(
                             constant
                         )
-                        self._load_constant_to_runner(constant, constant_data)
+                        self.kernel_runner.load_constant(constant, constant_data)
                     except Exception as e:
                         warn(f"Failed to load {constant}: {e}")
-                    else:
-                        self._successfully_loaded_constant_to_runner(constant)
                 # note: consider if multi-constant buffers need special care
             self._lock_and_update(aux)
 
-        for constant in self._constant_enum_class:
-            self.KARABO_SLOT(
-                functools.partial(
-                    constant_override_fun,
-                    friend_fun="get_constant_version",
-                    constant=constant,
-                    preserve_fields=set(),
+        for constant, (friend_fun, preserved_fields, slot_name) in itertools.product(
+            self._constant_enum_class,
+            (
+                ("get_constant_version", set(), "loadMostRecent"),
+                (
+                    "get_constant_from_constant_version_id",
+                    {"constantVersionId"},
+                    "overrideConstantFromVersion",
                 ),
-                slotName=f"foundConstants_{constant.name}_loadMostRecent",
-                numArgs=0,
-            )
-            self.KARABO_SLOT(
-                functools.partial(
-                    constant_override_fun,
-                    friend_fun="get_constant_from_constant_version_id",
-                    constant=constant,
-                    preserve_fields={"constantVersionId"},
+                (
+                    "get_constant_from_file",
+                    {"dataFilePath", "dataSetName"},
+                    "overrideConstantFromFile",
                 ),
-                slotName=f"foundConstants_{constant.name}_overrideConstantFromVersion",
-                numArgs=0,
-            )
+            ),
+        ):
             self.KARABO_SLOT(
                 functools.partial(
                     constant_override_fun,
-                    friend_fun="get_constant_from_file",
-                    constant=constant,
-                    preserve_fields={"dataFilePath", "dataSetName"},
+                    friend_fun,
+                    constant,
+                    preserved_fields,
                 ),
-                slotName=f"foundConstants_{constant.name}_overrideConstantFromFile",
+                slotName=f"foundConstants_{constant.name}_{slot_name}",
                 numArgs=0,
             )
 
@@ -704,30 +492,27 @@ class BaseCorrection(PythonDevice):
 
     def _initialization(self):
         self.updateState(State.INIT)
-        self.warning_context = utils.WarningContextSystem(
+        self.warning_context = WarningContextSystem(
             self,
             on_success={
                 f"foundConstants.{constant.name}.state": "ON"
                 for constant in self._constant_enum_class
             },
         )
-        self._preview_friend = preview_utils.PreviewFriend(
-            self,
-            output_channels=self._preview_outputs,
+        self.log.DEBUG("Opening shmem buffer")
+        self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
+            self.get("outputShmemBufferSize") * 2**30,
+            # TODO: have it just allocate memory, then set_shape later per train
+            (1,),
+            np.float32,
+            self.getInstanceId() + ":dataOutput",
         )
-        self["availableScenes"] = self["availableScenes"] + [
-            f"preview:{channel}" for channel in self._preview_outputs
-        ]
-
-        self._geometry = None
-        if self.get("geometryDevice"):
-            self.signalSlotable.connect(
-                self.get("geometryDevice"),
-                "signalNewGeometry",
-                "",  # slot device ID (default: self)
-                "slotReceiveGeometry",
-            )
+        if self._cuda_pin_buffers:
+            self.log.DEBUG("Trying to pin the shmem buffer memory")
+            self._shmem_buffer.cuda_pin()
+        self._shmem_receiver = shmem_utils.ShmemCircularBufferReceiver()
 
+        # CalCat friend comes before kernel runner
         with self.warning_context(
             "deviceInternalsState", WarningLampType.CALCAT_CONNECTION
         ) as warn:
@@ -739,18 +524,21 @@ class BaseCorrection(PythonDevice):
                 warn(f"Failed to connect to CalCat: {e}")
                 # TODO: add raw fallback mode if CalCat fails (raw data still useful)
                 return
+        self.kernel_runner = self._kernel_runner_class(self)
+        self.preview_friend = PreviewFriend(self, self._preview_outputs)
+        # TODO: can be static OVERWRITE_ELEMENT in expectedParameters
+        self["availableScenes"] = self["availableScenes"] + [
+            f"preview:{spec.name}" for spec in self._preview_outputs
+        ]
 
-        with self.warning_context(
-            "processingState", WarningLampType.FRAME_FILTER
-        ) as warn:
-            try:
-                self._frame_filter = _parse_frame_filter(self._parameters)
-            except (ValueError, TypeError):
-                warn("Failed to parse initial frame filter, will not use")
-                self._frame_filter = None
-        # TODO: decide how persistent initial warning should be
-        self._update_correction_flags()
-        self._update_frame_filter()
+        self.geometry = None
+        if self.get("geometryDevice"):
+            self.signalSlotable.connect(
+                self.get("geometryDevice"),
+                "signalNewGeometry",
+                "",  # slot device ID (default: self)
+                "slotReceiveGeometry",
+            )
 
         self._buffered_status_update = Hash()
         self._processing_time_tracker = trackers.ExponentialMovingAverage(
@@ -766,7 +554,7 @@ class BaseCorrection(PythonDevice):
         )
         self._train_ratio_tracker = trackers.TrainRatioTracker()
 
-        self.KARABO_ON_INPUT("dataInput", self.input_handler)
+        self.KARABO_ON_DATA("dataInput", self.input_handler)
         self.KARABO_ON_EOS("dataInput", self.handle_eos)
 
         self._enabled_addons = [
@@ -776,9 +564,14 @@ class BaseCorrection(PythonDevice):
         ]
         for addon in self._enabled_addons:
             addon._device = self
-        if self._enabled_addons:
+        if (
+            (self.get("useShmemHandles") != self._use_shmem_handles)
+            or self._enabled_addons
+        ):
             schema_override = Schema()
-            output_schema_override = self._base_output_schema
+            output_schema_override = self._base_output_schema(
+                use_shmem_handles=self.get("useShmemHandles")
+            )
             for addon in self._enabled_addons:
                 addon.extend_output_schema(output_schema_override)
             (
@@ -794,7 +587,7 @@ class BaseCorrection(PythonDevice):
         self.updateState(State.ON)
 
     def __del__(self):
-        del self._shmem_buffer
+        del self._shmem_buffer
         super().__del__()
 
     def preReconfigure(self, config):
@@ -812,64 +605,31 @@ class BaseCorrection(PythonDevice):
                     timestamp = dateutil.parser.isoparse(ts_string)
                     config.set(ts_path, timestamp.isoformat())
 
-        if config.has("constantParameters.deviceMappingSnapshotAt"):
-            self.calcat_friend.flush_pdu_mapping()
+        with self.config_overlay(config):
+            if config.has("constantParameters.deviceMappingSnapshotAt"):
+                self.calcat_friend.flush_pdu_mapping()
 
-        # update device based on changes
-        if config.has("frameFilter"):
-            self._frame_filter = _parse_frame_filter(
-                misc.ChainHash(config, self._parameters)
-            )
-
-        self._prereconfigure_update_hash = config
-
-    def postReconfigure(self):
-        if not hasattr(self, "_prereconfigure_update_hash"):
-            self.log_status_warn("postReconfigure without knowing update hash")
-            return
-
-        update = self._prereconfigure_update_hash
-
-        if update.has("frameFilter"):
-            self._lock_and_update(self._update_frame_filter)
-        elif any(
-            update.has(shape_param)
-            for shape_param in (
-                "dataFormat.pixelsX",
-                "dataFormat.pixelsY",
-                "dataFormat.outputImageDtype",
-                "dataFormat.outputAxisOrder",
-                "dataFormat.frames",
-                "constantParameters.memoryCells",
-                "frameFilter",
-            )
-        ):
-            self._lock_and_update(self._update_buffers)
-
-        if any(
-            (
-                update.has(f"corrections.{field_name}.enable")
-                or update.has(f"corrections.{field_name}.preview")
-            )
-            for field_name, *_ in self._correction_steps
-        ):
-            self._update_correction_flags()
+            if (
+                runner_update := subset_of_hash(
+                    config, "corrections", "dataFormat", "constantParameters"
+                )
+            ):
+                self.kernel_runner.reconfigure(runner_update)
 
-        # TODO: only send subhash of reconfiguration (like with addons)
-        if update.has("preview"):
-            self._preview_friend.reconfigure(update)
+            if config.has("preview"):
+                self.preview_friend.reconfigure(config["preview"])
 
-        if update.has("addons"):
-            # note: can avoid iterating, but it's not that costly
-            for addon in self._enabled_addons:
-                full_path = f"addons.{addon.__class__.__name__}"
-                if update.has(full_path):
-                    addon.reconfigure(update[full_path])
+            if config.has("addons"):
+                # note: can avoid iterating, but it's not that costly
+                for addon in self._enabled_addons:
+                    full_path = f"addons.{addon.__class__.__name__}"
+                    if config.has(full_path):
+                        addon.reconfigure(config[full_path])
 
     def _lock_and_update(self, method, background=True):
         # TODO: securely handle errors (postReconfigure may succeed, device state not)
         def runner():
-            with self._buffer_lock, utils.StateContext(self, State.CHANGING):
+            with self._buffer_lock, StateContext(self, State.CHANGING):
                 method()
 
         if background:
@@ -896,11 +656,9 @@ class BaseCorrection(PythonDevice):
                     ) as warn:
                         try:
                             constant_data = future.result()
-                            self._load_constant_to_runner(constant, constant_data)
+                            self.kernel_runner.load_constant(constant, constant_data)
                         except Exception as e:
                             warn(f"Failed to load {constant}: {e}")
-                        else:
-                            self._successfully_loaded_constant_to_runner(constant)
         self._lock_and_update(aux)
 
     def flush_constants(self, constants=None, preserve_fields=None):
@@ -930,7 +688,6 @@ class BaseCorrection(PythonDevice):
                 self.set(f"corrections.{field_name}.available", False)
         self.calcat_friend.flush_constants(constants, preserve_fields)
         self._reload_constants_from_cache_to_runner(constants)
-        self._update_correction_flags()
 
     def log_status_info(self, msg):
         self.log.INFO(msg)
@@ -954,7 +711,7 @@ class BaseCorrection(PythonDevice):
             payload["data"] = scenes.correction_device_preview(
                 device_id=self.getInstanceId(),
                 schema=self.getFullSchema(),
-                preview_channel=channel_name,
+                name=channel_name,
             )
         elif name.startswith("browse_schema"):
             if ":" in name:
@@ -983,10 +740,35 @@ class BaseCorrection(PythonDevice):
     def slotReceiveGeometry(self, device_id, serialized_geometry):
         self.log.INFO(f"Received geometry from {device_id}")
         try:
-            self._geometry = geom_utils.deserialize_geometry(serialized_geometry)
+            self.geometry = geom_utils.deserialize_geometry(serialized_geometry)
         except Exception as e:
             self.log.WARN(f"Failed to deserialize geometry; {e}")
 
+    def _get_data_from_hash(self, data_hash):
+        """Will get image data, cell table, pulse table, and list of other arrays from
+        the data hash. Assumes XTDF (image.data, image.cellId, image.pulseId, and no
+        other data), non-XTDF subclass can override. Cell and pulse table can be
+        None."""
+
+        image_data = data_hash.get(self._image_data_path)
+        cell_table = data_hash[self._cell_table_path].ravel()
+        pulse_table = data_hash[self._pulse_table_path].ravel()
+        num_frames = cell_table.size
+        # DataAggregator typically tells us the wrong axis order
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            # TODO: check that frames are always axis 0
+            expected_shape = self.kernel_runner.expected_input_shape(
+                num_frames
+            )
+            if expected_shape != image_data.shape:
+                image_data.shape = expected_shape
+        return (
+            num_frames,
+            image_data,  # not ndarray (runner can do that)
+            cell_table,
+            pulse_table,
+        )
+
     def _write_output(self, data, old_metadata):
         """For dataOutput: reusing incoming data hash and setting source and timestamp
         to be same as input"""
@@ -994,103 +776,11 @@ class BaseCorrection(PythonDevice):
             old_metadata.get("source"),
             Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")),
         )
-        data["corrections"] = self._correction_applied_hash
 
         channel = self.signalSlotable.getOutputChannel("dataOutput")
         channel.write(data, metadata, copyAllData=False)
         channel.update(safeNDArray=True)
 
-    def _update_correction_flags(self):
-        """Based on constants loaded and settings, update bit mask flags for kernel"""
-        enabled = self._correction_flag_class.NONE
-        output = Hash()  # for informing downstream receivers what was applied
-        preview = self._correction_flag_class.NONE
-        for field_name, flag, constants in self._correction_steps:
-            output[field_name] = False
-            if self.get(f"corrections.{field_name}.available"):
-                if self.get(f"corrections.{field_name}.enable"):
-                    enabled |= flag
-                    output[field_name] = True
-                if self.get(f"corrections.{field_name}.preview"):
-                    preview |= flag
-        self._correction_flag_enabled = enabled
-        self._correction_flag_preview = preview
-        self._correction_applied_hash = output
-        self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}")
-        self.log.DEBUG(f"Corrections for preview: {str(preview)}")
-
-    def _update_frame_filter(self, update_buffers=True):
-        """Parse frameFilter string (if set) and update cached filter array. May update
-        dataFormat.filteredFrames - will therefore by default call _update_buffers
-        afterwards."""
-        # TODO: add some validation to preReconfigure
-        self.log.DEBUG("Updating frame filter")
-
-        if self._frame_filter is None:
-            self.set("dataFormat.filteredFrames", self.get("dataFormat.frames"))
-            self.set("frameFilter.current", [])
-        else:
-            self.set("dataFormat.filteredFrames", self._frame_filter.size)
-            self.set("frameFilter.current", list(map(int, self._frame_filter)))
-
-        with self.warning_context(
-            "deviceInternalsState", WarningLampType.FRAME_FILTER
-        ) as warn:
-            if self._frame_filter is not None and (
-                self._frame_filter.min() < 0
-                or self._frame_filter.max() >= self.get("dataFormat.frames")
-            ):
-                warn("Invalid frame filter set, expect exceptions!")
-                # note: dataFormat.frames could change from detector and
-                # that could make this warning invalid
-
-        if update_buffers:
-            self._update_buffers()
-
-    def _update_buffers(self):
-        """(Re)initialize buffers / kernel runner according to expected data shapes"""
-        self.log.INFO("Updating buffers according to data shapes")
-        # reflect the axis reordering in the expected output shape
-        self.set("dataFormat.inputDataShape", list(self.input_data_shape))
-        self.set("dataFormat.outputDataShape", list(self.output_data_shape))
-        self.log.INFO(f"Input shape: {self.input_data_shape}")
-        self.log.INFO(f"Output shape: {self.output_data_shape}")
-
-        if self._shmem_buffer is None:
-            shmem_buffer_name = self.getInstanceId() + ":dataOutput"
-            memory_budget = self.get("outputShmemBufferSize") * 2**30
-            self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}")
-            self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
-                memory_budget,
-                self.output_data_shape,
-                self.output_data_dtype,
-                shmem_buffer_name,
-            )
-            self._shmem_receiver = shmem_utils.ShmemCircularBufferReceiver()
-            if self._cuda_pin_buffers:
-                self.log.INFO("Trying to pin the shmem buffer memory")
-                self._shmem_buffer.cuda_pin()
-            self.log.INFO("Done, shmem buffer is ready")
-        else:
-            self._shmem_buffer.change_shape(self.output_data_shape)
-
-        # give CuPy a chance to at least start memory cleanup before this
-        if self.kernel_runner is not None:
-            del self.kernel_runner
-            self.kernel_runner = None
-            gc.collect()
-
-        self.kernel_runner = self._kernel_runner_class(
-            self.get("dataFormat.pixelsX"),
-            self.get("dataFormat.pixelsY"),
-            self.get("dataFormat.filteredFrames"),
-            int(self.get("constantParameters.memoryCells")),
-            output_data_dtype=self.output_data_dtype,
-            **self._kernel_runner_init_args,
-        )
-
-        self._reload_constants_from_cache_to_runner()
-
     def _reload_constants_from_cache_to_runner(self, constants=None):
         # TODO: gracefully handle change in constantParameters.memoryCells
         if constants is None:
@@ -1110,13 +800,11 @@ class BaseCorrection(PythonDevice):
                 ) as warn:
                     try:
                         self.log_status_info(f"Reload constant {constant.name}")
-                        self._load_constant_to_runner(constant, data)
+                        self.kernel_runner.load_constant(constant, data)
                     except Exception as e:
                         warn(f"Failed to reload {constant.name}: {e}")
-                    else:
-                        self._successfully_loaded_constant_to_runner(constant)
 
-    def input_handler(self, input_channel):
+    def input_handler(self, data_hash, metadata):
         """Main handler for data input: Do a few simple checks to determine whether to
         even try processing. If yes, will pass data and information to process_data
         method provided by subclass."""
@@ -1134,161 +822,179 @@ class BaseCorrection(PythonDevice):
                 warn("Received data before correction device was ready")
                 return
 
-        all_metadata = input_channel.getMetaData()
-        for input_index in range(input_channel.size()):
-            with self.warning_context(
-                "inputDataState", WarningLampType.MISC_INPUT_DATA
-            ) as warn:
-                data_hash = input_channel.read(input_index)
-                metadata = all_metadata[input_index]
-                source = metadata.get("source")
+        with self.warning_context(
+            "inputDataState", WarningLampType.MISC_INPUT_DATA
+        ) as warn:
+            source = metadata.get("source")
 
-                if source not in self.sources:
-                    continue
-                elif not data_hash.has(self._image_data_path):
-                    warn("Ignoring hash without image node")
-                    continue
+            if source not in self.sources:
+                return
+            elif not data_hash.has(self._image_data_path):
+                warn("Ignoring hash without image node")
+                return
 
-            with self.warning_context(
-                "inputDataState",
-                WarningLampType.EMPTY_HASH,
-            ) as warn:
-                self._shmem_receiver.dereference_shmem_handles(data_hash)
-                try:
-                    image_data = np.asarray(data_hash.get(self._image_data_path))
-                    cell_table = (
-                        np.array(
-                            # explicit copy to avoid mysterious segfault
-                            data_hash.get(self._cell_table_path), copy=True
-                        ).ravel()
-                        if self._cell_table_path is not None
-                        else None
-                    )
-                    pulse_table = (
-                        np.array(
-                            # explicit copy to avoid mysterious segfault
-                            data_hash.get(self._pulse_table_path), copy=True
-                        ).ravel()
-                        if self._pulse_table_path is not None
-                        else None
-                    )
-                except RuntimeError as err:
-                    warn(
-                        "Failed to load image data; "
-                        f"probably empty hash from DAQ: {err}"
-                    )
-                    self._maybe_update_cell_and_pulse_tables(None, None)
-                    continue
+        with self.warning_context(
+            "inputDataState",
+            WarningLampType.EMPTY_HASH,
+        ) as warn:
+            self._shmem_receiver.dereference_shmem_handles(data_hash)
+            try:
+                (
+                    num_frames,
+                    image_data,
+                    cell_table,
+                    pulse_table,
+                    *additional_data
+                ) = self._get_data_from_hash(data_hash)
+            except RuntimeError as err:
+                warn(
+                    "Failed to load data from hash, probably empty hash from DAQ. "
+                    f"(Is detector sending data?)\n{err}"
+                )
+                self._maybe_update_cell_and_pulse_tables(None, None)
+                return
 
-            # no more common reasons to skip input, so go to processing
-            self._last_processing_started = default_timer()
-            if state is State.ON:
-                self.updateState(State.PROCESSING)
-                self.log_status_info("Processing data")
+        # no more common reasons to skip input, so go to processing
+        self._last_processing_started = default_timer()
+        if state is State.ON:
+            self.updateState(State.PROCESSING)
+            self.log_status_info("Processing data")
 
-            self._maybe_update_cell_and_pulse_tables(cell_table, pulse_table)
-            timestamp = Timestamp.fromHashAttributes(
-                metadata.getAttributes("timestamp")
-            )
-            train_id = timestamp.getTrainId()
+        self._maybe_update_cell_and_pulse_tables(cell_table, pulse_table)
+        timestamp = Timestamp.fromHashAttributes(
+            metadata.getAttributes("timestamp")
+        )
 
-            # check time server connection
-            with self.warning_context(
-                "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION
-            ) as warn:
-                my_timestamp = self.getActualTimestamp()
-                my_train_id = my_timestamp.getTrainId()
-                self._input_delay_tracker.update(
-                    (my_timestamp.toTimestamp() - timestamp.toTimestamp()) * 1000
-                )
-                self._buffered_status_update.set(
-                    "performance.inputDelay", self._input_delay_tracker.get()
-                )
-                if my_train_id == 0:
-                    my_train_id = self._last_train_id_processed + 1
-                    warn(
-                        "Failed to get current train ID, using previously seen train "
-                        "ID for future train thresholding - if this persists, check "
-                        "connection to timeserver."
-                    )
+        self._check_train_id_and_time(timestamp)
+
+        if self._warn_memory_cell_range:
             with self.warning_context(
-                "inputDataState", WarningLampType.TRAIN_ID
+                "processingState", WarningLampType.MEMORY_CELL_RANGE
             ) as warn:
-                if train_id > (
-                    my_train_id
-                    + self.unsafe_get("workarounds.trainFromFutureThreshold")
+                if (
+                    self.unsafe_get("constantParameters.memoryCells")
+                    <= cell_table.max()
                 ):
-                    warn(
-                        f"Suspecting train from the future: 'now' is {my_train_id}, "
-                        f"received train ID {train_id}, dropping data"
-                    )
-                    continue
-                try:
-                    self._train_ratio_tracker.update(train_id)
-                    self._buffered_status_update.set(
-                        "performance.ratioOfRecentTrainsReceived",
-                        self._train_ratio_tracker.get(
-                            current_train=my_train_id,
-                            expected_delay=math.ceil(
-                                self.unsafe_get("performance.inputDelay") / 100
-                            ),
-                        )
-                        if my_train_id != 0
-                        else self._train_ratio_tracker.get(),
-                    )
-                except trackers.NonMonotonicTrainIdWarning as ex:
-                    warn(
-                        f"Train ratio tracker noticed issue with train ID: {ex}\n"
-                        f"For the record, I think now is: {my_train_id}"
-                    )
-                    self._train_ratio_tracker.reset()
-                    self._train_ratio_tracker.update(train_id)
+                    warn("Input cell IDs out of range of constants")
 
-            if self._warn_memory_cell_range:
-                with self.warning_context(
-                    "processingState", WarningLampType.MEMORY_CELL_RANGE
-                ) as warn:
-                    if (
-                        self.unsafe_get("constantParameters.memoryCells")
-                        <= cell_table.max()
-                    ):
-                        warn("Input cell IDs out of range of constants")
-
-            # TODO: avoid branching multiple times on cell_table
-            if cell_table is not None and cell_table.size != self.unsafe_get(
-                "dataFormat.frames"
-            ):
-                self.log_status_info(
-                    f"Updating new input shape to account for {cell_table.size} cells"
-                )
-                self.set("dataFormat.frames", cell_table.size)
-                self._lock_and_update(self._update_frame_filter, background=False)
-
-            # DataAggregator typically tells us the wrong axis order
-            if self.unsafe_get("workarounds.overrideInputAxisOrder"):
-                expected_shape = self.input_data_shape
-                if expected_shape != image_data.shape:
-                    image_data.shape = expected_shape
-
-            with self._buffer_lock:
-                self.process_data(
-                    data_hash,
-                    metadata,
-                    source,
-                    train_id,
-                    image_data,
-                    cell_table,
-                    pulse_table,
+        # TODO: change ShmemCircularBuffer API so we don't have to poke like this
+        if (
+            (output_shape := self.kernel_runner.expected_output_shape(num_frames))
+            != self._shmem_buffer._buffer_ary.shape[1:]
+            or self.kernel_runner._output_dtype != self._shmem_buffer._buffer_ary.dtype
+        ):
+            self.set("dataFormat.outputDataShape", list(output_shape))
+            self.log.DEBUG("Updating shmem buffer shape / dtype")
+            self._shmem_buffer.change_shape(
+                output_shape, self.kernel_runner._output_dtype
+            )
+        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
+        with self.warning_context(
+            "processingState", WarningLampType.CORRECTION_RUNNER
+        ):
+            corrections, processed_buffer, previews = self.kernel_runner.correct(
+                image_data, cell_table, *additional_data
+            )
+            data_hash["corrections"] = corrections
+            for addon in self._enabled_addons:
+                addon.post_correction(
+                    processed_buffer, cell_table, pulse_table, data_hash
                 )
-            self._last_train_id_processed = train_id
-            self._buffered_status_update.set("trainId", train_id)
-            self._processing_time_tracker.update(
-                default_timer() - self._last_processing_started
+            self.kernel_runner.reshape(processed_buffer, out=buffer_array)
+
+        with self.warning_context(
+            "processingState", WarningLampType.PREVIEW_SETTINGS
+        ) as warn:
+            self.preview_friend.write_outputs(
+                *previews,
+                timestamp=timestamp,
+                cell_table=cell_table,
+                pulse_table=pulse_table,
+                warn_fun=warn,
+            )
+
+        for addon in self._enabled_addons:
+            addon.post_reshape(
+                buffer_array, cell_table, pulse_table, data_hash
+            )
+
+        if self.unsafe_get("useShmemHandles"):
+            data_hash.set(self._image_data_path, buffer_handle)
+            data_hash.set("calngShmemPaths", [self._image_data_path])
+        else:
+            data_hash.set(self._image_data_path, buffer_array)
+            data_hash.set("calngShmemPaths", [])
+
+        self._write_output(data_hash, metadata)
+        self._processing_time_tracker.update(
+            default_timer() - self._last_processing_started
+        )
+        self._buffered_status_update.set(
+            "performance.processingTime", self._processing_time_tracker.get() * 1000
+        )
+        self._rate_tracker.update()
+
+    def _please_send_me_cached_constants(self, constants, callback):
+        for constant in constants:
+            if constant in self.calcat_friend.cached_constants:
+                callback(constant, self.calcat_friend.cached_constants[constant])
+
+    def _check_train_id_and_time(self, timestamp):
+        train_id = timestamp.getTrainId()
+        # check time server connection
+        with self.warning_context(
+            "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION
+        ) as warn:
+            my_timestamp = self.getActualTimestamp()
+            my_train_id = my_timestamp.getTrainId()
+            self._input_delay_tracker.update(
+                (my_timestamp.toTimestamp() - timestamp.toTimestamp()) * 1000
             )
             self._buffered_status_update.set(
-                "performance.processingTime", self._processing_time_tracker.get() * 1000
+                "performance.inputDelay", self._input_delay_tracker.get()
             )
-            self._rate_tracker.update()
+            if my_train_id == 0:
+                my_train_id = self._last_train_id_processed + 1
+                warn(
+                    "Failed to get current train ID, using previously seen train "
+                    "ID for future train thresholding - if this persists, check "
+                    "connection to timeserver."
+                )
+        with self.warning_context(
+            "inputDataState", WarningLampType.TRAIN_ID
+        ) as warn:
+            if train_id > (
+                my_train_id
+                + self.unsafe_get("workarounds.trainFromFutureThreshold")
+            ):
+                warn(
+                    f"Suspecting train from the future: 'now' is {my_train_id}, "
+                    f"received train ID {train_id}, dropping data"
+                )
+                raise TrainFromTheFutureException()
+            try:
+                self._train_ratio_tracker.update(train_id)
+                self._buffered_status_update.set(
+                    "performance.ratioOfRecentTrainsReceived",
+                    self._train_ratio_tracker.get(
+                        current_train=my_train_id,
+                        expected_delay=math.ceil(
+                            self.unsafe_get("performance.inputDelay") / 100
+                        ),
+                    )
+                    if my_train_id != 0
+                    else self._train_ratio_tracker.get(),
+                )
+            except trackers.NonMonotonicTrainIdWarning as ex:
+                warn(
+                    f"Train ratio tracker noticed issue with train ID: {ex}\n"
+                    f"For the record, I think now is: {my_train_id}"
+                )
+                self._train_ratio_tracker.reset()
+                self._train_ratio_tracker.update(train_id)
+
+        self._last_train_id_processed = train_id
+        self._buffered_status_update.set("trainId", train_id)
 
     def _update_rate_and_state(self):
         if self.get("state") is State.PROCESSING:
@@ -1329,119 +1035,6 @@ class BaseCorrection(PythonDevice):
         self.signalEndOfStream("dataOutput")
 
 
-def add_correction_step_schema(schema, field_flag_constants_mapping):
-    """Using the fields in the provided mapping, will add nodes to schema
-
-    field_flag_mapping is assumed to be iterable of pairs where first entry in each
-    pair is the name of a correction step as it will appear in device schema (second
-    entry - typically an enum field - is ignored). For correction step, a node and some
-    booleans are added to the schema (all tagged as managed). Subclass can customize /
-    add additional keys under node later.
-
-    This method should be called in expectedParameters of subclass after the same for
-    BaseCorrection has been called. Would be nice to include in BaseCorrection instead,
-    but that is tricky: static method of superclass will need _correction_steps
-    of subclass or device server gets mad. A nice solution with classmethods would be
-    welcome.
-    """
-
-    for field_name, _, used_constants in field_flag_constants_mapping:
-        node_name = f"corrections.{field_name}"
-        (
-            NODE_ELEMENT(schema).key(node_name).commit(),
-
-            BOOL_ELEMENT(schema)
-            .key(f"{node_name}.available")
-            .displayedName("Available")
-            .description(
-                "This boolean indicates whether the necessary constants have been "
-                "loaded for this correction step to be applied. Enabling the "
-                "correction will have no effect unless this is True."
-            )
-            .readOnly()
-            # some corrections available without constants
-            .initialValue(not used_constants or None in used_constants)
-            .commit(),
-
-            BOOL_ELEMENT(schema)
-            .key(f"{node_name}.enable")
-            .tags("managed")
-            .displayedName("Enable")
-            .description(
-                "Controls whether to apply this correction step for main data "
-                "output - subject to availability."
-            )
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(schema)
-            .key(f"{node_name}.preview")
-            .tags("managed")
-            .displayedName("Preview")
-            .description(
-                "Whether to apply this correction step for corrected preview "
-                "output - subject to availability."
-            )
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-        )
-
-
-def add_bad_pixel_config_node(schema, prefix="corrections.badPixels"):
-    (
-        STRING_ELEMENT(schema)
-        .key("corrections.badPixels.maskingValue")
-        .tags("managed")
-        .displayedName("Bad pixel masking value")
-        .description(
-            "Any pixels masked by the bad pixel mask will have their value replaced "
-            "with this. Note that this parameter is to be interpreted as a "
-            "numpy.float32; use 'nan' to get NaN value."
-        )
-        .assignmentOptional()
-        .defaultValue("nan")
-        .reconfigurable()
-        .commit(),
-
-        NODE_ELEMENT(schema)
-        .key("corrections.badPixels.subsetToUse")
-        .displayedName("Bad pixel flags to use")
-        .description(
-            "The booleans under this node allow for selecting a subset of bad pixel "
-            "types to take into account when doing bad pixel masking. Upon updating "
-            "these flags, the map used for bad pixel masking will be ANDed with this "
-            "selection. Turning disabled flags back on causes reloading of cached "
-            "constants."
-        )
-        .commit(),
-    )
-    for field in utils.BadPixelValues:
-        (
-            BOOL_ELEMENT(schema)
-            .key(f"corrections.badPixels.subsetToUse.{field.name}")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit()
-        )
-
-
-def add_preview_outputs(schema, channels):
-    for channel in channels:
-        (
-            OUTPUT_CHANNEL(schema)
-            .key(f"preview.{channel}")
-            .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True))
-            .commit(),
-        )
-    preview_utils.PreviewFriend.add_schema(schema, output_channels=channels)
-
-
 def add_addon_nodes(schema, device_class, prefix="addons"):
     det_name = device_class.__name__[:-len("Correction")].lower()
     device_class._available_addons = [
@@ -1467,28 +1060,3 @@ def add_addon_nodes(schema, device_class, prefix="addons"):
         addon_class.extend_device_schema(
             schema, f"{prefix}.{addon_class.__name__}"
         )
-
-
-def get_bad_pixel_field_selection(self):
-    selection = 0
-    for field in utils.BadPixelValues:
-        if self.get(f"corrections.badPixels.subsetToUse.{field.name}"):
-            selection |= field
-    return selection
-
-
-def _parse_frame_filter(config):
-    filter_type = FramefilterSpecType(config["frameFilter.type"])
-    filter_string = config["frameFilter.spec"]
-
-    if filter_type is FramefilterSpecType.NONE or filter_string.strip() == "":
-        return None
-    elif filter_type is FramefilterSpecType.RANGE:
-        # allow exceptions
-        numbers = tuple(int(part) for part in filter_string.split(","))
-        return np.arange(*numbers, dtype=np.uint16)
-    elif filter_type is FramefilterSpecType.COMMASEPARATED:
-        # np.fromstring is too lenient I think
-        return np.array([int(s) for s in filter_string.split(",")], dtype=np.uint16)
-    else:
-        raise TypeError(f"Unknown frame filter type {filter_type}")
diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py
index 5eaaf101..e04ad559 100644
--- a/src/calng/base_kernel_runner.py
+++ b/src/calng/base_kernel_runner.py
@@ -1,19 +1,163 @@
 import enum
 import functools
-import operator
+import itertools
 import pathlib
 
-from calngUtils import misc as utils
+from karabo.bound import (
+    BOOL_ELEMENT,
+    NODE_ELEMENT,
+    STRING_ELEMENT,
+    Hash,
+)
+from .utils import (
+    BadPixelValues,
+    maybe_get,
+    subset_of_hash,
+)
+from calngUtils import misc
 import jinja2
-import numpy as np
 
 
 class BaseKernelRunner:
-    _gpu_based = True
     _xp = None  # subclass sets numpy or cupy
+    num_pixels_ss = None  # subclass must set
+    num_pixels_fs = None  # subclass must set
+
+    _bad_pixel_constants = None  # should be set by subclasses using bad pixels
+    bad_pixel_subset = 0xFFFFFFFF  # will be set in reconfigure
+
+    @classmethod
+    def add_schema(cls, schema):
+        """Using the steps in cls._correction_steps, will add nodes to schema
+
+        cls._correction_steps is assumed to be an iterable of triples: the name of a
+        correction step as it will appear in the device schema, an enum flag (ignored
+        here), and the set of constants the step requires (used to set the initial
+        "available" value). For each correction step, a node and some booleans are
+        added to the schema (all tagged as managed); subclass can add keys under node.
+        """
+        (
+            NODE_ELEMENT(schema)
+            .key("corrections")
+            .commit(),
+        )
+
+        for field_name, _, used_constants in cls._correction_steps:
+            node_name = f"corrections.{field_name}"
+            (
+                NODE_ELEMENT(schema).key(node_name).commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_name}.available")
+                .displayedName("Available")
+                .description(
+                    "This boolean indicates whether the necessary constants have been "
+                    "loaded for this correction step to be applied. Enabling the "
+                    "correction will have no effect unless this is True."
+                )
+                .readOnly()
+                # some corrections available without constants
+                .initialValue(not used_constants or None in used_constants)
+                .commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_name}.enable")
+                .tags("managed")
+                .displayedName("Enable")
+                .description(
+                    "Controls whether to apply this correction step for main data "
+                    "output - subject to availability."
+                )
+                .assignmentOptional()
+                .defaultValue(True)
+                .reconfigurable()
+                .commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_name}.preview")
+                .tags("managed")
+                .displayedName("Preview")
+                .description(
+                    "Whether to apply this correction step for corrected preview "
+                    "output - subject to availability."
+                )
+                .assignmentOptional()
+                .defaultValue(True)
+                .reconfigurable()
+                .commit(),
+            )
+
+    @staticmethod
+    def add_bad_pixel_config(schema):
+        (
+            STRING_ELEMENT(schema)
+            .key("corrections.badPixels.maskingValue")
+            .tags("managed")
+            .displayedName("Bad pixel masking value")
+            .description(
+                "Any pixels masked by the bad pixel mask will have their value "
+                "replaced with this. Note that this parameter is to be interpreted as "
+                "a numpy.float32; use 'nan' to get NaN value."
+            )
+            .assignmentOptional()
+            .defaultValue("nan")
+            .reconfigurable()
+            .commit(),
+
+            NODE_ELEMENT(schema)
+            .key("corrections.badPixels.subsetToUse")
+            .displayedName("Bad pixel flags to use")
+            .description(
+                "The booleans under this node allow for selecting a subset of bad "
+                "pixel types to take into account when doing bad pixel masking. Upon "
+                "updating these flags, the map used for bad pixel masking will be "
+                "ANDed with this selection. Turning disabled flags back on causes "
+                "reloading of cached constants."
+            )
+            .commit(),
+        )
+        for field in BadPixelValues:
+            (
+                BOOL_ELEMENT(schema)
+                .key(f"corrections.badPixels.subsetToUse.{field.name}")
+                .tags("managed")
+                .assignmentOptional()
+                .defaultValue(True)
+                .reconfigurable()
+                .commit()
+            )
+
+    def __init__(self, device):
+        self._device = device
+        # for now, we depend on multiple keys from device hash, so get multiple nodes
+        config = subset_of_hash(
+            self._device._parameters,
+            "corrections",
+            "dataFormat",
+            "constantParameters",
+        )
+        self._constant_memory_cells = config["constantParameters.memoryCells"]
+        self._pre_init()
+        # note: does not handle one constant enabling multiple correction steps
+        # (do we need to generalize "decide if this correction step is available"?)
+        self._constant_to_correction_names = {}
+        for (name, _, constants) in self._correction_steps:
+            for constant in constants:
+                self._constant_to_correction_names.setdefault(constant, set()).add(name)
+        self._output_dtype = self._xp.float32
+        self._setup_constant_buffers()
+        self.reconfigure(config)
+        self._post_init()
 
     def _pre_init(self):
-        # can be used, for example, to import cupy and set as _xp at runtime
+        """Hook used to set up things which need to be in place before the rest of
+        __init__ runs. Note that __init__ will do reconfigure with full configuration, which
+        means creating constant buffers and such (__init__ already takes care to set
+        _constant_memory_cells so this can be used in _setup_constant_buffers).
+
+
+        See also _setup_constant_buffers
+        """
         pass
 
     def _post_init(self):
@@ -21,121 +165,230 @@ class BaseKernelRunner:
         # without overriding __init__
         pass
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-    ):
-        self._pre_init()
-        self.pixels_x = pixels_x
-        self.pixels_y = pixels_y
-        self.frames = frames
-        if constant_memory_cells == 0:
-            # if not set, guess same as input; may save one recompilation
-            self.constant_memory_cells = frames
-        else:
-            self.constant_memory_cells = constant_memory_cells
-        self.output_data_dtype = output_data_dtype
-        self._post_init()
+    def expected_input_shape(self, num_frames):
+        """Mostly used to inform the host device about what data shape should be.
+        Subclass should override as needed (ex. XTDF detectors with extra image/gain
+        axis)."""
+        return (num_frames, self.num_pixels_ss, self.num_pixels_fs)
 
-    @property
-    def preview_shape(self):
-        return (self.pixels_x, self.pixels_y)
-
-    def correct(self, flags):
-        """Correct (already loaded) image data according to flags
-
-        Detector-specific subclass must define this method. It should assume that image
-        data, cell table, and other data (including constants) has already been loaded.
-        It should probably run some {CPU,GPU} kernel and output should go into
-        self.processed_data{,_gpu}.
+    def expected_output_shape(self, num_frames):
+        """Used to inform the host device about what to expect (for reshaping the shmem
+        buffer). Takes into account transpose setting; override _expected_output_shape
+        for subclasses which do funky shape things."""
+        base_shape = self._expected_output_shape(num_frames)
+        if self._output_transpose is None:
+            return base_shape
+        else:
+            return tuple(base_shape[i] for i in self._output_transpose)
 
-        Keep in mind that user only gets output from compute_preview or reshape
-        (either of these should come after correct).
+    def _expected_output_shape(self, num_frames):
+        return (num_frames, self.num_pixels_ss, self.num_pixels_fs)
 
-        The submodules providing subclasses should have some IntFlag enums defining
-        which flags are available to pass along to the kernel. A zero flag should allow
-        the kernel to do no actual correction - but still copy the data between buffers
-        and cast it to desired output type.
+    def _correct(self, flags, image_data, cell_table, *rest):
+        """Subclass will implement the core correction routine with this name. Observe
+        the parameter list: the "rest" part will always contain at least one output
+        buffer, but may have more entries such as additional raw input buffers and
+        multiple output buffers. See how it is called by correct."""
 
-        """
         raise NotImplementedError()
 
-    @property
-    def preview_data_views(self):
+    def _preview_data_views(self, *buffers):
         """Should provide tuple of views of data for different preview layers - as a
         minimum, raw and corrected.  View is expected to have axis order frame, slow
-        scan, fast scan with the latter two matching preview_shape (which is X and which
-        is Y may differ between detectors)."""
+        scan, fast scan with the latter being the final preview shape (which is X and
+        which is Y may differ between detectors). The buffers passed vary from subclass
+        to subclass; will be all inputs and outputs used in _correct. By default, that
+        is just what gets returned."""
+
+        return buffers
+
+    def _load_constant(self, constant_type, constant_data):
+        """Should update the appropriate internal buffers for constant type (maps to
+        "calibration" type in CalCat) with the data in constant_data. This may include
+        sanitization; data passed here has the axis order found in the stored constant
+        files."""
+        raise NotImplementedError()
 
+    def _setup_constant_buffers(self):
         raise NotImplementedError()
 
+    def _make_output_buffers(self, num_frames, flags):
+        """Should, based on number of frames passed (may in the future be extended to
+        include other information), create the necessary buffers to store corrected
+        output(s). By default, will create a single processed data buffer of float32
+        with the same shape as input image data, skipping any extra axes (like the
+        image/gain axis for XTDF). Subclass should override if it wants more outputs.
+        First buffer must be for corrected image data (passed to addons)"""
+        return [
+            self._xp.empty(
+                (num_frames, self.num_pixels_ss, self.num_pixels_fs),
+                dtype=self._xp.float32,
+            )
+        ]
+
+    @functools.cached_property
+    def _all_relevant_constants(self):
+        return set(
+            itertools.chain.from_iterable(
+                constants for (_, _, constants) in self._correction_steps
+            )
+        ) - {None}
+
+    def reconfigure(self, config):
+        if config.has("constantParameters.memoryCells"):
+            self._constant_memory_cells = config.get("constantParameters.memoryCells")
+            self._setup_constant_buffers()
+            self._device._please_send_me_cached_constants(
+                self._all_relevant_constants, self.load_constant
+            )
+
+        if config.has("dataFormat.outputImageDtype"):
+            self._output_dtype = self._xp.dtype(
+                config.get("dataFormat.outputImageDtype")
+            )
+
+        if config.has("dataFormat.outputAxisOrder"):
+            order = config["dataFormat.outputAxisOrder"]
+            if order == "f-ss-fs":
+                self._output_transpose = None
+            else:
+                self._output_transpose = misc.transpose_order(
+                    ("f", "ss", "fs"), order.split("-")
+                )
+
+        if config.has("corrections.badPixels.maskingValue"):
+            self.bad_pixel_mask_value = self._xp.float32(
+                config["corrections.badPixels.maskingValue"]
+            )
+
+        if config.has("corrections.badPixels.subsetToUse"):
+            # note: now just always reloading from cache for convenience
+            self._update_bad_pixel_subset()
+            self.flush_buffers(self._bad_pixel_constants)
+            self._device._please_send_me_cached_constants(
+                self._bad_pixel_constants, self.load_constant
+            )
+
+        # TODO: only trigger flag update if changed (cheap though)
+        self._update_correction_flags()
+
+    def reshape(self, processed_data, out=None):
+        """Move axes to desired output order"""
+        if self._output_transpose is None:
+            return maybe_get(processed_data, out)
+        else:
+            return maybe_get(processed_data.transpose(self._output_transpose), out)
+
     def flush_buffers(self, constants=None):
         """Optional reset internal buffers (implement in subclasses which need this)"""
         pass
 
-    def compute_previews(self, preview_index):
-        """Generate single slice or reduction previews for raw and corrected data and
-        any other layers, determined by self.preview_data_views
+    def _update_correction_flags(self):
+        enabled = self._correction_flag_class.NONE
+        output = Hash()  # for informing downstream receivers what was applied
+        preview = self._correction_flag_class.NONE
+        for field_name, flag, constants in self._correction_steps:
+            output[field_name] = False
+            if self._get(f"corrections.{field_name}.available"):
+                if self._get(f"corrections.{field_name}.enable"):
+                    enabled |= flag
+                    output[field_name] = True
+                if self._get(f"corrections.{field_name}.preview"):
+                    preview |= flag
+        self._correction_flag_enabled = enabled
+        self._correction_flag_preview = preview
+        self._correction_applied_hash = output
+        self._device.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}")
+        self._device.log.DEBUG(f"Corrections for preview: {str(preview)}")
+
+    def _update_bad_pixel_subset(self):
+        selection = 0
+        for field in BadPixelValues:
+            if self._get(f"corrections.badPixels.subsetToUse.{field.name}"):
+                selection |= field
+        self.bad_pixel_subset = selection
 
-        Special values of preview_index are -1 for max, -2 for mean, -3 for sum, and
-        -4 for stdev (across cells).
+    def _get(self, key):
+        """Will look up key in host device's parameters.  Note that device calls
+        reconfigure during preReconfigure, but with new unmerged config overlaid
+        temporarily, so _get sees the "new" state of things."""
+        return self._device.unsafe_get(key)
 
-        Note that preview_index is taken from data without checking cell table.
-        Caller has to figure out which index along memory cell dimension they
-        actually want to preview in case it needs to be a specific cell / pulse.
+    def _set(self, key, value):
+        """Helper function to set the host device's configuration. Note that this does *not*
+        call reconfigure, so either do so afterwards or make sure to update auxiliary
+        properties correctly yourself."""
+        self._device.set(key, value)
 
-        Will typically reuse data from corrected output buffer. Therefore,
-        correct(...) must have been called with the appropriate flags before
-        compute_preview(...).
+    @property
+    def _map_shape(self):
+        return (self._constant_memory_cells, self.num_pixels_ss, self.num_pixels_fs)
+
+    def load_constant(self, constant, constant_data):
+        self._load_constant(constant, constant_data)
+        self._post_load_constant(constant)
+
+    def _post_load_constant(self, constant):
+        for field_name in self._constant_to_correction_names[constant]:
+            key = f"corrections.{field_name}.available"
+            if not self._get(key):
+                self._set(key, True)
+
+        self._update_correction_flags()
+        self._device.log_status_info(f"Done loading {constant.name} to runner")
+
+    def correct(self, image_data, cell_table, *additional_data):
+        """Correct image data, providing both full corrected data and previews
+
+        This method relies on the following auxiliary methods which may need to be
+        overridden / defined in subclass:
+
+        _make_output_buffers
+        _preview_data_views
+        _correct (must be implemented by subclass)
+
+        Look at signature for _correct; it is given image_data, cell_table,
+        *additional_data (anything the host device class' _get_data_from_hash gives us
+        in addition to image data, cell table, and pulse table), and *processed_buffers
+        which is the result of _make_output_buffers. So all these four methods must
+        agree on the relevant set of buffers for input and output.
+
+        The pulse_table is only used for preview index selection purposes as corrections
+        do not (yet, at any rate) take the pulse table into account.
         """
-        if preview_index < -4:
-            raise ValueError(f"No statistic with code {preview_index} defined")
-        elif preview_index >= self.frames:
-            raise ValueError(f"Memory cell index {preview_index} out of range")
-
-        if preview_index >= 0:
-            fun = operator.itemgetter(preview_index)
-        elif preview_index == -1:
-            # note: separate from next case because dtype not applicable here
-            fun = functools.partial(self._xp.nanmax, axis=0)
-        elif preview_index in (-2, -3, -4):
-            fun = functools.partial(
-                {
-                    -2: self._xp.nanmean,
-                    -3: self._xp.nansum,
-                    -4: self._xp.nanstd,
-                }[preview_index],
-                axis=0,
-                dtype=self._xp.float32,
+        # TODO: check if this is robust enough (also base_correction)
+        num_frames = image_data.shape[0]
+        processed_buffers = self._make_output_buffers(
+            num_frames, self._correction_flag_preview
+        )
+        image_data = self._xp.asarray(image_data)  # if cupy, will put on GPU
+        if cell_table is not None:
+            cell_table = self._xp.asarray(cell_table)
+        additional_data = [self._xp.asarray(data) for data in additional_data]
+        self._correct(
+            self._correction_flag_preview,
+            image_data,
+            cell_table,
+            *additional_data,
+            *processed_buffers,
+        )
+
+        preview_buffers = self._preview_data_views(
+            image_data, *additional_data, *processed_buffers
+        )
+
+        if self._correction_flag_preview != self._correction_flag_enabled:
+            processed_buffers = self._make_output_buffers(
+                num_frames, self._correction_flag_enabled
             )
-        # TODO: reuse output buffers
-        # TODO: integrate multithreading
-        res = (fun(in_buffer_view) for in_buffer_view in self.preview_data_views)
-        if self._gpu_based:
-            res = (buf.get() for buf in res)
-        return res
-
-    def reshape(self, output_order, out=None):
-        """Move axes to desired output order"""
-        # TODO: avoid copy
-        if output_order == self._corrected_axis_order:
-            self.reshaped_data = self.processed_data
-        else:
-            self.reshaped_data = self.processed_data.transpose(
-                utils.transpose_order(self._corrected_axis_order, output_order)
+            self._correct(
+                self._correction_flag_enabled,
+                image_data,
+                cell_table,
+                *additional_data,
+                *processed_buffers,
             )
-
-        if self._gpu_based:
-            return self.reshaped_data.get(out=out)
-        else:
-            if out is None:
-                return self.reshaped_data
-            else:
-                out[:] = self.reshaped_data
+        return self._correction_applied_hash, processed_buffers[0], preview_buffers
 
 
 kernel_dir = pathlib.Path(__file__).absolute().parent / "kernels"
diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py
index d97f1fc4..6632b6ef 100644
--- a/src/calng/corrections/AgipdCorrection.py
+++ b/src/calng/corrections/AgipdCorrection.py
@@ -9,30 +9,36 @@ from karabo.bound import (
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
+    UINT32_ELEMENT,
 )
 
 from .. import (
     base_calcat,
     base_correction,
     base_kernel_runner,
+
     schemas,
     utils,
 )
+from ..preview_utils import FrameSelectionMode, PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 
 
 class Constants(enum.Enum):
     ThresholdsDark = enum.auto()
     Offset = enum.auto()
+    SlopesCS = enum.auto()
     SlopesPC = enum.auto()
     SlopesFF = enum.auto()
     BadPixelsDark = enum.auto()
+    BadPixelsCS = enum.auto()
     BadPixelsPC = enum.auto()
     BadPixelsFF = enum.auto()
 
 
 bad_pixel_constants = {
     Constants.BadPixelsDark,
+    Constants.BadPixelsCS,
     Constants.BadPixelsPC,
     Constants.BadPixelsFF,
 }
@@ -51,146 +57,389 @@ class CorrectionFlags(enum.IntFlag):
     NONE = 0
     THRESHOLD = 1
     OFFSET = 2
-    BLSHIFT = 4
-    REL_GAIN_PC = 8
-    GAIN_XRAY = 16
-    BPMASK = 32
-    FORCE_MG_IF_BELOW = 64
-    FORCE_HG_IF_BELOW = 128
+    REL_GAIN_PC = 4
+    GAIN_XRAY = 8
+    BPMASK = 16
+    FORCE_MG_IF_BELOW = 32
+    FORCE_HG_IF_BELOW = 64
+    COMMON_MODE = 128
+
+
+correction_steps = (
+    # step name (used in schema), flag to enable for kernel, constants required
+    ("thresholding", CorrectionFlags.THRESHOLD, {Constants.ThresholdsDark}),
+    ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
+    (  # TODO: specify that /both/ constants are needed, not just one or the other
+        "forceMgIfBelow",
+        CorrectionFlags.FORCE_MG_IF_BELOW,
+        {Constants.ThresholdsDark, Constants.Offset},
+    ),
+    (
+        "forceHgIfBelow",
+        CorrectionFlags.FORCE_HG_IF_BELOW,
+        {Constants.ThresholdsDark, Constants.Offset},
+    ),
+    ("relGain", CorrectionFlags.REL_GAIN_PC, {Constants.SlopesCS, Constants.SlopesPC}),
+    ("gainXray", CorrectionFlags.GAIN_XRAY, {Constants.SlopesFF}),
+    (
+        "badPixels",
+        CorrectionFlags.BPMASK,
+        bad_pixel_constants | {
+            None,  # means stay available even without constants loaded
+        },
+    ),
+    ("commonMode", CorrectionFlags.COMMON_MODE, {None}),
+)
 
 
 class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "fxy"
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 512
+    num_pixels_fs = 128
 
-    @property
-    def input_shape(self):
-        return (self.frames, 2, self.pixels_x, self.pixels_y)
+    @classmethod
+    def add_schema(cls, schema):
+        # will add node, enable / preview toggles
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
 
-    @property
-    def processed_shape(self):
-        return (self.frames, self.pixels_x, self.pixels_y)
+        # additional settings specific to AGIPD correction steps
+        # turn off some advanced corrections by default
+        for flag in ("enable", "preview"):
+            (
+                OVERWRITE_ELEMENT(schema)
+                .key(f"corrections.commonMode.{flag}")
+                .setNewDefaultValue(False)
+                .commit(),
+            )
+            for step in ("forceMgIfBelow", "forceHgIfBelow"):
+                (
+                    OVERWRITE_ELEMENT(schema)
+                    .key(f"corrections.{step}.{flag}")
+                    .setNewDefaultValue(False)
+                    .commit(),
+                )
+        (
+            FLOAT_ELEMENT(schema)
+            .key("corrections.forceMgIfBelow.hardThreshold")
+            .tags("managed")
+            .description(
+                "If enabled, any pixels assigned to low gain stage which would be "
+                "below this threshold after having medium gain offset subtracted will "
+                "be reassigned to medium gain."
+            )
+            .assignmentOptional()
+            .defaultValue(1000)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def map_shape(self):
-        return (self.constant_memory_cells, self.pixels_x, self.pixels_y)
+            FLOAT_ELEMENT(schema)
+            .key("corrections.forceHgIfBelow.hardThreshold")
+            .tags("managed")
+            .description(
+                "Like forceMgIfBelow, but potentially reassigning from medium gain to "
+                "high gain based on threshold and pixel value minus low gain offset. "
+                "Applied after forceMgIfBelow, pixels can theoretically be reassigned "
+                "twice, from LG to MG and to HG."
+            )
+            .assignmentOptional()
+            .defaultValue(1000)
+            .reconfigurable()
+            .commit(),
+
+            STRING_ELEMENT(schema)
+            .key("corrections.relGain.sourceOfSlopes")
+            .tags("managed")
+            .displayedName("Kind of slopes")
+            .description(
+                "Slopes for relative gain correction can be derived from pulse "
+                "capacitor (PC) scans or current source scans (CS). Choose one!"
+            )
+            .assignmentOptional()
+            .defaultValue("PC")
+            .options("PC,CS")
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.relGain.adjustMgBaseline")
+            .tags("managed")
+            .displayedName("Adjust MG baseline")
+            .description(
+                "If set, an additional offset is applied to pixels in medium gain. "
+                "The value of the offset is either computed per pixel (and memory "
+                "cell) based on currently loaded relative gain slopes (PC or CS) or "
+                "it is set to a specific value when overrideMdAdditionalOffset is set."
+            )
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.relGain.overrideMdAdditionalOffset")
+            .tags("managed")
+            .displayedName("Override md_additional_offset")
+            .description(
+                "If set, the additional offset applied to medium gain pixels "
+                "(assuming adjustMgBaseLine is set) will use the value from "
+                "mdAdditionalOffset globally. Otherwise, the additional offset "
+                "is derived from the relevant constants."
+            )
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.relGain.mdAdditionalOffset")
+            .tags("managed")
+            .displayedName("Value for md_additional_offset (if overriding)")
+            .description(
+                "Normally, md_additional_offset (part of relative gain correction) is "
+                "computed when loading SlopesPC. In case you want to use a different "
+                "value (global for all medium gain pixels), you can specify it here "
+                "and set corrections.overrideMdAdditionalOffset to True."
+            )
+            .assignmentOptional()
+            .defaultValue(0)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.gainXray.gGainValue")
+            .tags("managed")
+            .displayedName("G_gain_value")
+            .description(
+                "Newer X-ray gain correction constants are absolute. The default "
+                "G_gain_value of 1 means that output is expected to be in keV. If "
+                "this is not desired, one can here specify the mean X-ray gain value "
+                "over all modules to get ADU values out - operator must manually "
+                "find this mean value."
+            )
+            .assignmentOptional()
+            .defaultValue(1)
+            .reconfigurable()
+            .commit(),
+
+            UINT32_ELEMENT(schema)
+            .key("corrections.commonMode.iterations")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(4)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.commonMode.minFrac")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(0.15)
+            .reconfigurable()
+            .commit(),
+
+            FLOAT_ELEMENT(schema)
+            .key("corrections.commonMode.noisePeakRange")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(35)
+            .reconfigurable()
+            .commit(),
+        )
+
+    def reconfigure(self, config):
+        super().reconfigure(config)
+
+        if config.has("corrections.gainXray.gGainValue"):
+            self.g_gain_value = self._xp.float32(
+                config["corrections.gainXray.gGainValue"]
+            )
+
+        if config.has("corrections.forceHgIfBelow.hardThreshold"):
+            self.hg_hard_threshold = self._xp.float32(
+                config["corrections.forceHgIfBelow.hardThreshold"]
+            )
+
+        if config.has("corrections.forceMgIfBelow.hardThreshold"):
+            self.mg_hard_threshold = self._xp.float32(
+                config["corrections.forceMgIfBelow.hardThreshold"]
+            )
+
+        if (
+            config.has("corrections.relGain.overrideMdAdditionalOffset")
+            or config.has("corrections.relGain.mdAdditionalOffset")
+        ):
+            if (
+                self._get("corrections.relGain.overrideMdAdditionalOffset")
+                and self._get("corrections.relGain.adjustMgBaseline")
+            ):
+                self.md_additional_offset.fill(
+                    self._xp.float32(
+                        self._get("corrections.relGain.mdAdditionalOffset")
+                    )
+                )
+            else:
+                self._device._please_send_me_cached_constants(
+                    {
+                        Constants.SlopesPC
+                        if self._get("corrections.relGain.sourceOfSlopes") == "PC"
+                        else Constants.SlopesCS
+                    },
+                    self.load_constant,
+                )
+
+        if config.has("corrections.relGain.sourceOfSlopes"):
+            self.flush_buffers(
+                {
+                    Constants.SlopesPC,
+                    Constants.SlopesCS,
+                    Constants.BadPixelsCS,
+                    Constants.BadPixelsPC,
+                }
+            )
+
+            # note: will try to reload all bad pixels, but
+            # _load_constant will reject PC/CS based on which we want
+            if self._get("corrections.relGain.sourceOfSlopes") == "PC":
+                to_get = {Constants.SlopesPC} | bad_pixel_constants
+            else:
+                to_get = {Constants.SlopesCS} | bad_pixel_constants
+
+            self._device._please_send_me_cached_constants(to_get, self.load_constant)
+
+        if config.has("corrections.commonMode"):
+            self.cm_min_frac = self._xp.float32(
+                self._get("corrections.commonMode.minFrac")
+            )
+            self.cm_noise_peak = self._xp.float32(
+                self._get("corrections.commonMode.noisePeakRange")
+            )
+            self.cm_iter = self._xp.uint16(
+                self._get("corrections.commonMode.iterations")
+            )
+
+        if config.has("constantParameters.gainMode"):
+            gain_mode = GainModes[config["constantParameters.gainMode"]]
+            if gain_mode is GainModes.ADAPTIVE_GAIN:
+                self.default_gain = self._xp.uint8(gain_mode)
+            else:
+                self.default_gain = self._xp.uint8(gain_mode - 1)
+
+    def expected_input_shape(self, num_frames):
+        return (
+            num_frames,
+            2,
+            self.num_pixels_ss,
+            self.num_pixels_fs,
+        )
 
     @property
-    def gm_map_shape(self):
-        return self.map_shape + (3,)  # for gain-mapped constants
+    def _gm_map_shape(self):
+        return self._map_shape + (3,)  # for gain-mapped constants
 
     @property
     def threshold_map_shape(self):
-        return self.map_shape + (2,)
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-        gain_mode=GainModes.ADAPTIVE_GAIN,
-        g_gain_value=1,
-        mg_hard_threshold=2000,
-        hg_hard_threshold=2000,
-        override_md_additional_offset=None,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.gain_mode = gain_mode
-        if self.gain_mode is GainModes.ADAPTIVE_GAIN:
-            self.default_gain = self._xp.uint8(gain_mode)
-        else:
-            self.default_gain = self._xp.uint8(gain_mode - 1)
-        self.gain_map = self._xp.empty(self.processed_shape, dtype=np.float32)
+        return self._map_shape + (2,)
 
+    def _setup_constant_buffers(self):
         # constants
         self.gain_thresholds = self._xp.empty(
             self.threshold_map_shape, dtype=np.float32
         )
-        self.offset_map = self._xp.empty(self.gm_map_shape, dtype=np.float32)
-        self.rel_gain_pc_map = self._xp.empty(self.gm_map_shape, dtype=np.float32)
-        # not gm_map_shape because it only applies to medium gain pixels
-        self.md_additional_offset = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_xray_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = self._xp.empty(self.gm_map_shape, dtype=np.uint32)
-        self.override_md_additional_offset(override_md_additional_offset)
-        self.set_bad_pixel_mask_value(bad_pixel_mask_value)
-        self.set_mg_hard_threshold(mg_hard_threshold)
-        self.set_hg_hard_threshold(hg_hard_threshold)
-        self.set_g_gain_value(g_gain_value)
+        self.offset_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.rel_gain_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        # not _gm_map_shape because it only applies to medium gain pixels
+        self.md_additional_offset = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.xray_gain_map = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = self._xp.empty(self._gm_map_shape, dtype=np.uint32)
         self.flush_buffers(set(Constants))
-        self.processed_data = self._xp.empty(self.processed_shape, output_data_dtype)
 
-    @property
-    def preview_data_views(self):
+    def _make_output_buffers(self, num_frames, flags):
+        output_shape = (num_frames, self.num_pixels_ss, self.num_pixels_fs)
+        return [
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # image
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # gain
+        ]
+
+    def _preview_data_views(self, raw_data, processed_data, gain_map):
         return (
-            self.input_data[:, 0],  # raw
-            self.processed_data,  # corrected
-            self.input_data[:, 1],  # raw gain
-            self.gain_map,  # digitized gain
+            raw_data[:, 0],  # raw
+            processed_data,  # corrected
+            raw_data[:, 1],  # raw gain
+            gain_map,  # digitized gain
         )
 
-    def load_constant(self, constant, constant_data):
+    def _load_constant(self, constant, constant_data):
         if constant is Constants.ThresholdsDark:
             # shape: y, x, memory cell, thresholds and gain values
             # note: the gain values are something like means used to derive thresholds
             self.gain_thresholds[:] = self._xp.asarray(
                 constant_data[..., :2], dtype=np.float32
-            ).transpose((2, 1, 0, 3))[:self.constant_memory_cells]
+            ).transpose((2, 1, 0, 3))[:self._constant_memory_cells]
         elif constant is Constants.Offset:
             # shape: y, x, memory cell, gain stage
             self.offset_map[:] = self._xp.asarray(
                 constant_data, dtype=np.float32
-            ).transpose((2, 1, 0, 3))[:self.constant_memory_cells]
+            ).transpose((2, 1, 0, 3))[:self._constant_memory_cells]
+        elif constant is Constants.SlopesCS:
+            if self._get("corrections.relGain.sourceOfSlopes") == "PC":
+                return
+            self.rel_gain_map.fill(1)
+            self.rel_gain_map[..., 1] = self.rel_gain_map[..., 0] * self._xp.asarray(
+                constant_data[..., :self._constant_memory_cells, 6]
+            ).T
+            self.rel_gain_map[..., 2] = self.rel_gain_map[..., 1] * self._xp.asarray(
+                constant_data[..., :self._constant_memory_cells, 7]
+            ).T
+            if self._get("corrections.relGain.adjustMgBaseline"):
+                self.md_additional_offset.fill(
+                    self._get("corrections.relGain.mdAdditionalOffset")
+                )
+            else:
+                self.md_additional_offset.fill(0)
         elif constant is Constants.SlopesPC:
-            # pc has funny shape (11, 352, 128, 512) from file
-            # this is (fi, memory cell, y, x)
-            # the following may contain NaNs, though...
+            if self._get("corrections.relGain.sourceOfSlopes") == "CS":
+                return
+            # pc has funny shape (11, 352, 128, 512) from file (fi, memory cell, y, x)
+            # assume newer variant (TODO: check) which is pre-sanitized
             hg_slope = constant_data[0]
             hg_intercept = constant_data[1]
             mg_slope = constant_data[3]
             mg_intercept = constant_data[4]
-            # TODO: remove sanitization (should happen in constant preparation notebook)
-            # from agipdlib.py: replace NaN with median (per memory cell)
-            # note: suffixes in agipdlib are "_m" and "_l", should probably be "_I"
-            for naughty_array in (hg_slope, hg_intercept, mg_slope, mg_intercept):
-                medians = np.nanmedian(naughty_array, axis=(1, 2))
-                nan_bool = np.isnan(naughty_array)
-                nan_cell, _, _ = np.where(nan_bool)
-                naughty_array[nan_bool] = medians[nan_cell]
-
-                too_low_bool = naughty_array < 0.8 * medians[:, np.newaxis, np.newaxis]
-                too_low_cell, _, _ = np.where(too_low_bool)
-                naughty_array[too_low_bool] = medians[too_low_cell]
-
-                too_high_bool = naughty_array > 1.2 * medians[:, np.newaxis, np.newaxis]
-                too_high_cell, _, _ = np.where(too_high_bool)
-                naughty_array[too_high_bool] = medians[too_high_cell]
+
+            hg_slope_median = np.nanmedian(hg_slope, axis=(1, 2))
+            mg_slope_median = np.nanmedian(mg_slope, axis=(1, 2))
 
             frac_hg_mg = hg_slope / mg_slope
             rel_gain_map = np.ones(
-                (3, self.constant_memory_cells, self.pixels_y, self.pixels_x),
+                (
+                    3,
+                    self._constant_memory_cells,
+                    self.num_pixels_fs,
+                    self.num_pixels_ss,
+                ),
                 dtype=np.float32,
             )
             rel_gain_map[1] = rel_gain_map[0] * frac_hg_mg
             rel_gain_map[2] = rel_gain_map[1] * 4.48
-            self.rel_gain_pc_map[:] = self._xp.asarray(
+            self.rel_gain_map[:] = self._xp.asarray(
                 rel_gain_map.transpose((1, 3, 2, 0)), dtype=np.float32
-            )[:self.constant_memory_cells]
-            if self._md_additional_offset_value is None:
-                md_additional_offset = (
-                    hg_intercept - mg_intercept * frac_hg_mg
-                ).astype(np.float32)
-                self.md_additional_offset[:] = self._xp.asarray(
-                    md_additional_offset.transpose((0, 2, 1)), dtype=np.float32
-                )[:self.constant_memory_cells]
+            )[:self._constant_memory_cells]
+            if self._get("corrections.relGain.adjustMgBaseline"):
+                if self._get("corrections.relGain.overrideMdAdditionalOffset"):
+                    self.md_additional_offset.fill(
+                        self._get("corrections.relGain.mdAdditionalOffset")
+                    )
+                else:
+                    self.md_additional_offset[:] = self._xp.asarray(
+                        (
+                            hg_intercept - mg_intercept * frac_hg_mg
+                        ).astype(np.float32).transpose((0, 2, 1)), dtype=np.float32
+                    )[:self._constant_memory_cells]
+            else:
+                self.md_additional_offset.fill(0)
         elif constant is Constants.SlopesFF:
             # constant shape: y, x, memory cell
             if constant_data.shape[2] == 2:
@@ -199,35 +448,46 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
                 # note: we should not support this in online
                 constant_data = self._xp.broadcast_to(
                     constant_data[..., 0][..., np.newaxis],
-                    (self.pixels_y, self.pixels_x, self.constant_memory_cells),
+                    (
+                        self.num_pixels_fs,
+                        self.num_pixels_ss,
+                        self._constant_memory_cells,
+                    ),
                 )
-            self.rel_gain_xray_map[:] = self._xp.asarray(
+            self.xray_gain_map[:] = self._xp.asarray(
                 constant_data.transpose(), dtype=np.float32
-            )[:self.constant_memory_cells]
+            )[:self._constant_memory_cells]
         else:
             assert constant in bad_pixel_constants
+            if (
+                constant is Constants.BadPixelsCS
+                and self._get("corrections.relGain.sourceOfSlopes") == "PC"
+                or constant is Constants.BadPixelsPC
+                and self._get("corrections.relGain.sourceOfSlopes") == "CS"
+            ):
+                return
             # will simply OR with already loaded, does not take into account which ones
             constant_data = self._xp.asarray(constant_data, dtype=np.uint32)
             if len(constant_data.shape) == 3:
                 if constant_data.shape == (
-                    self.pixels_y,
-                    self.pixels_x,
-                    self.constant_memory_cells,
+                    self.num_pixels_fs,
+                    self.num_pixels_ss,
+                    self._constant_memory_cells,
                 ):
                     # BadPixelsFF is not per gain stage - broadcasting along gain
                     constant_data = self._xp.broadcast_to(
                         constant_data.transpose()[..., np.newaxis],
-                        self.gm_map_shape,
+                        self._gm_map_shape,
                     )
                 elif constant_data.shape == (
-                    self.constant_memory_cells,
-                    self.pixels_y,
-                    self.pixels_x,
+                    self._constant_memory_cells,
+                    self.num_pixels_fs,
+                    self.num_pixels_ss,
                 ):
                     # old BadPixelsPC have different axis order
                     constant_data = self._xp.broadcast_to(
                         constant_data.transpose((0, 2, 1))[..., np.newaxis],
-                        self.gm_map_shape,
+                        self._gm_map_shape,
                     )
                 else:
                     raise ValueError(
@@ -236,142 +496,126 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
             else:
                 # gain mapped constants seem consistent
                 constant_data = constant_data.transpose((2, 1, 0, 3))
-            self.bad_pixel_map |= constant_data[:self.constant_memory_cells]
-
-    def override_md_additional_offset(self, override_value):
-        self._md_additional_offset_value = override_value
-        if override_value is not None:
-            self.md_additional_offset.fill(override_value)
-
-    def set_g_gain_value(self, override_value):
-        self.g_gain_value = self._xp.float32(override_value)
-
-    def set_bad_pixel_mask_value(self, mask_value):
-        self.bad_pixel_mask_value = self._xp.float32(mask_value)
-
-    def set_mg_hard_threshold(self, value):
-        self.mg_hard_threshold = self._xp.float32(value)
-
-    def set_hg_hard_threshold(self, value):
-        self.hg_hard_threshold = self._xp.float32(value)
+            constant_data = constant_data & self._xp.uint32(self.bad_pixel_subset)
+            self.bad_pixel_map |= constant_data[:self._constant_memory_cells]
 
     def flush_buffers(self, constants):
         if Constants.Offset in constants:
             self.offset_map.fill(0)
-        if Constants.SlopesPC in constants:
-            self.rel_gain_pc_map.fill(1)
+        if constants & {Constants.SlopesCS, Constants.SlopesPC}:
+            self.rel_gain_map.fill(1)
             self.md_additional_offset.fill(0)
         if Constants.SlopesFF:
-            self.rel_gain_xray_map.fill(1)
+            self.xray_gain_map.fill(1)
         if constants & bad_pixel_constants:
             self.bad_pixel_map.fill(0)
-            self.bad_pixel_map[
-                :, 64:512:64
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
-            self.bad_pixel_map[
-                :, 63:511:64
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+            if self.bad_pixel_subset & utils.BadPixelValues.NON_STANDARD_SIZE:
+                self._mask_asic_seams()
+
+    def _mask_asic_seams(self):
+        self.bad_pixel_map[:, 64:512:64] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+        self.bad_pixel_map[:, 63:511:64] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
 
 
 class AgipdCpuRunner(AgipdBaseRunner):
     _xp = np
-    _gpu_based = False
 
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        if flags & CorrectionFlags.COMMON_MODE:
+            raise NotImplementedError("Common mode not available on CPU yet")
         self.correction_kernel(
-            self.input_data,
-            self.cell_table,
+            image_data,
+            cell_table,
             flags,
             self.default_gain,
             self.gain_thresholds,
             self.mg_hard_threshold,
             self.hg_hard_threshold,
             self.offset_map,
-            self.rel_gain_pc_map,
+            self.rel_gain_map,
             self.md_additional_offset,
-            self.rel_gain_xray_map,
+            self.xray_gain_map,
             self.g_gain_value,
             self.bad_pixel_map,
             self.bad_pixel_mask_value,
-            self.processed_data,
+            processed_data,
         )
 
-    def _post_init(self):
-        self.input_data = None
-        self.cell_table = None
-        # NOTE: CPU kernel does not yet support anything other than float32
-        self.processed_data = self._xp.empty(
-            self.processed_shape, dtype=self.output_data_dtype
-        )
+    def _pre_init(self):
         from ..kernels import agipd_cython
         self.correction_kernel = agipd_cython.correct
 
-    def load_data(self, image_data, cell_table):
-        self.input_data = image_data
-        self.cell_table = cell_table
-
 
 class AgipdGpuRunner(AgipdBaseRunner):
-    _gpu_based = True
-
     def _pre_init(self):
-        import cupy as cp
-
-        self._xp = cp
+        import cupy
+        self._xp = cupy
 
     def _post_init(self):
-        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = self._xp.empty(self.frames, dtype=np.uint16)
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.processed_shape, self.block_shape
-        )
-        self.correction_kernel = self._xp.RawModule(
+        self.correction_kernel = self._xp.RawKernel(
             code=base_kernel_runner.get_kernel_template("agipd_gpu.cu").render(
-                {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
-                    "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
-                    ),
-                    "corr_enum": utils.enum_to_c_template(CorrectionFlags),
-                }
-            )
-        ).get_function("correct")
-
-    def load_data(self, image_data, cell_table):
-        self.input_data.set(image_data)
-        self.cell_table.set(cell_table)
+                pixels_x=self.num_pixels_ss,
+                pixels_y=self.num_pixels_fs,
+                output_data_dtype=utils.np_dtype_to_c_type(self._output_dtype),
+                corr_enum=utils.enum_to_c_template(CorrectionFlags),
+            ),
+            name="correct",
+        )
+        self.cm_kernel = self._xp.RawKernel(
+            code=base_kernel_runner.get_kernel_template("common_gpu.cu").render(
+                ss_dim=self.num_pixels_ss,
+                fs_dim=self.num_pixels_fs,
+                asic_dim=64,
+            ),
+            name="common_mode_asic",
+        )
 
-    def correct(self, flags):
-        if flags & CorrectionFlags.BLSHIFT:
-            raise NotImplementedError("Baseline shift not implemented yet")
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        num_frames = self._xp.uint16(image_data.shape[0])
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(
+            processed_data.shape, block_shape
+        )
 
         self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
+            grid_shape,
+            block_shape,
             (
-                self.input_data,
-                self.cell_table,
-                self._xp.uint8(flags),
+                image_data,
+                cell_table,
+                flags,
+                num_frames,
+                self._constant_memory_cells,
                 self.default_gain,
                 self.gain_thresholds,
                 self.mg_hard_threshold,
                 self.hg_hard_threshold,
                 self.offset_map,
-                self.rel_gain_pc_map,
+                self.rel_gain_map,
                 self.md_additional_offset,
-                self.rel_gain_xray_map,
+                self.xray_gain_map,
                 self.g_gain_value,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self.gain_map,
-                self.processed_data,
+                gain_map,
+                processed_data,
             ),
         )
 
+        if flags & CorrectionFlags.COMMON_MODE:
+            # grid: one block per ss asic, fs asic, and frame
+            self.cm_kernel(
+                (2, 8, processed_data.shape[0]),
+                (64, 1, 1),
+                (
+                    processed_data,
+                    num_frames,
+                    self.cm_iter,
+                    self.cm_min_frac,
+                    self.cm_noise_peak,
+                ),
+            )
+
 
 class AgipdCalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
@@ -381,9 +625,11 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend):
         return {
             Constants.ThresholdsDark: self.dark_condition,
             Constants.Offset: self.dark_condition,
+            Constants.SlopesCS: self.dark_condition,
             Constants.SlopesPC: self.dark_condition,
             Constants.SlopesFF: self.illuminated_condition,
             Constants.BadPixelsDark: self.dark_condition,
+            Constants.BadPixelsCS: self.dark_condition,
             Constants.BadPixelsPC: self.dark_condition,
             Constants.BadPixelsFF: self.illuminated_condition,
         }
@@ -489,69 +735,34 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend):
 @KARABO_CLASSINFO("AgipdCorrection", deviceVersion)
 class AgipdCorrection(base_correction.BaseCorrection):
     # subclass *must* set these attributes
-    _base_output_schema = schemas.xtdf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        # step name (used in schema), flag to enable for kernel, constants required
-        ("thresholding", CorrectionFlags.THRESHOLD, {Constants.ThresholdsDark}),
-        ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
-        (  # TODO: specify that /both/ constants are needed, not just one or the other
-            "forceMgIfBelow",
-            CorrectionFlags.FORCE_MG_IF_BELOW,
-            {Constants.ThresholdsDark, Constants.Offset},
-        ),
-        (
-            "forceHgIfBelow",
-            CorrectionFlags.FORCE_HG_IF_BELOW,
-            {Constants.ThresholdsDark, Constants.Offset},
-        ),
-        ("relGainPc", CorrectionFlags.REL_GAIN_PC, {Constants.SlopesPC}),
-        ("gainXray", CorrectionFlags.GAIN_XRAY, {Constants.SlopesFF}),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            bad_pixel_constants | {
-                None,  # means stay available even without constants loaded
-            },
-        ),
-    )
+    _base_output_schema = schemas.xtdf_output_schema
     _calcat_friend_class = AgipdCalcatFriend
     _constant_enum_class = Constants
+    _correction_steps = correction_steps
+
     _preview_outputs = [
-        "outputRaw",
-        "outputCorrected",
-        "outputRawGain",
-        "outputGainMap",
+        PreviewSpec(name, frame_reduction_selection_mode=FrameSelectionMode.CELL)
+        for name in ("raw", "corrected", "rawGain", "gainMap")
     ]
 
-    @staticmethod
-    def expectedParameters(expected):
+    @classmethod
+    def expectedParameters(cls, expected):
+        # this is not automatically done by superclass for complicated class reasons
+        cls._calcat_friend_class.add_schema(expected)
+        AgipdBaseRunner.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(AgipdCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(352)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("cell")
+            .dataSchema(
+                AgipdCorrection._base_output_schema(
+                    use_shmem_handles=cls._use_shmem_handles
+                )
+            )
             .commit(),
         )
 
-        # this is not automatically done by superclass for complicated class reasons
-        AgipdCalcatFriend.add_schema(expected)
-        base_correction.add_addon_nodes(expected, AgipdCorrection)
-        base_correction.add_correction_step_schema(
-            expected,
-            AgipdCorrection._correction_steps,
-        )
-        base_correction.add_bad_pixel_config_node(expected)
-        base_correction.add_preview_outputs(expected, AgipdCorrection._preview_outputs)
         (
             # support both CPU and GPU kernels
             STRING_ELEMENT(expected)
@@ -569,130 +780,6 @@ class AgipdCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-        # turn off the force MG / HG steps by default
-        for step in ("forceMgIfBelow", "forceHgIfBelow"):
-            for flag in ("enable", "preview"):
-                (
-                    OVERWRITE_ELEMENT(expected)
-                    .key(f"corrections.{step}.{flag}")
-                    .setNewDefaultValue(False)
-                    .commit(),
-                )
-        # additional settings specific to AGIPD correction steps
-        (
-            FLOAT_ELEMENT(expected)
-            .key("corrections.forceMgIfBelow.hardThreshold")
-            .tags("managed")
-            .description(
-                "If enabled, any pixels assigned to low gain stage which would be "
-                "below this threshold after having medium gain offset subtracted will "
-                "be reassigned to medium gain."
-            )
-            .assignmentOptional()
-            .defaultValue(1000)
-            .reconfigurable()
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("corrections.forceHgIfBelow.hardThreshold")
-            .tags("managed")
-            .description(
-                "Like forceMgIfBelow, but potentially reassigning from medium gain to "
-                "high gain based on threshold and pixel value minus low gain offset. "
-                "Applied after forceMgIfBelow, pixels can theoretically be reassigned "
-                "twice, from LG to MG and to HG."
-            )
-            .assignmentOptional()
-            .defaultValue(1000)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.relGainPc.overrideMdAdditionalOffset")
-            .tags("managed")
-            .displayedName("Override md_additional_offset")
-            .description(
-                "Toggling this on will use the value in the next field globally for "
-                "md_additional_offset. Note that the correction map on GPU gets "
-                "overwritten as long as this boolean is True, so reload constants "
-                "after turning off."
-            )
-            .assignmentOptional()
-            .defaultValue(False)
-            .reconfigurable()
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("corrections.relGainPc.mdAdditionalOffset")
-            .tags("managed")
-            .displayedName("Value for md_additional_offset (if overriding)")
-            .description(
-                "Normally, md_additional_offset (part of relative gain correction) is "
-                "computed when loading SlopesPC. In case you want to use a different "
-                "value (global for all medium gain pixels), you can specify it here "
-                "and set corrections.overrideMdAdditionalOffset to True."
-            )
-            .assignmentOptional()
-            .defaultValue(0)
-            .reconfigurable()
-            .commit(),
-
-            FLOAT_ELEMENT(expected)
-            .key("corrections.gainXray.gGainValue")
-            .tags("managed")
-            .displayedName("G_gain_value")
-            .description(
-                "Newer X-ray gain correction constants are absolute. The default "
-                "G_gain_value of 1 means that output is expected to be in keV. If "
-                "this is not desired, one can here specify the mean X-ray gain value "
-                "over all modules to get ADU values out - operator must manually "
-                "find this mean value."
-            )
-            .assignmentOptional()
-            .defaultValue(1)
-            .reconfigurable()
-            .commit(),
-        )
-
-    @property
-    def input_data_shape(self):
-        return (
-            self.unsafe_get("dataFormat.frames"),
-            2,
-            self.unsafe_get("dataFormat.pixelsX"),
-            self.unsafe_get("dataFormat.pixelsY"),
-        )
-
-    def __init__(self, config):
-        super().__init__(config)
-        # note: gain mode single sourced from constant retrieval node
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-        self._has_updated_bad_pixel_selection = False
-
-    @property
-    def bad_pixel_mask_value(self):
-        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
-
-    _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection)
-
-    @property
-    def _kernel_runner_init_args(self):
-        return {
-            "gain_mode": self.gain_mode,
-            "bad_pixel_mask_value": self.bad_pixel_mask_value,
-            "g_gain_value": self.unsafe_get("corrections.gainXray.gGainValue"),
-            "mg_hard_threshold": self.unsafe_get(
-                "corrections.forceMgIfBelow.hardThreshold"
-            ),
-            "hg_hard_threshold": self.unsafe_get(
-                "corrections.forceHgIfBelow.hardThreshold"
-            ),
-        }
-
     @property
     def _kernel_runner_class(self):
         kernel_type = base_kernel_runner.KernelRunnerTypes[
@@ -702,159 +789,3 @@ class AgipdCorrection(base_correction.BaseCorrection):
             return AgipdCpuRunner
         else:
             return AgipdGpuRunner
-
-    @property
-    def gain_mode(self):
-        return GainModes[self.unsafe_get("constantParameters.gainMode")]
-
-    @property
-    def _override_md_additional_offset(self):
-        if self.unsafe_get("corrections.relGainPc.overrideMdAdditionalOffset"):
-            return self.unsafe_get("corrections.relGainPc.mdAdditionalOffset")
-        else:
-            return None
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        """Called by input_handler for each data hash. Should correct data, optionally
-        compute preview, write data output, and optionally write preview outputs."""
-        # original shape: frame, data/raw_gain, x, y
-        if self._frame_filter is not None:
-            try:
-                cell_table = cell_table[self._frame_filter]
-                pulse_table = pulse_table[self._frame_filter]
-                image_data = image_data[self._frame_filter]
-            except IndexError:
-                self.log_status_warn(
-                    "Failed to apply frame filter, please check that it is valid!"
-                )
-                return
-
-        try:
-            self.kernel_runner.load_data(image_data, cell_table.ravel())
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            self.log_status_warn(
-                f"For debugging - cell table: {type(cell_table)}, {cell_table}"
-            )
-            self.log_status_warn(
-                f"For debugging - data: {type(image_data)}, {image_data.shape}"
-            )
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        # first prepare previews (not affected by addons)
-        self.kernel_runner.correct(self._correction_flag_preview)
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-        (
-            preview_raw,
-            preview_corrected,
-            preview_raw_gain,
-            preview_gain_map,
-        ) = self.kernel_runner.compute_previews(preview_slice_index)
-        # TODO: start writing out previews asynchronously in the background
-
-        # then prepare full corrected data
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        if self._correction_flag_enabled != self._correction_flag_preview:
-            self.kernel_runner.correct(self._correction_flag_enabled)
-        for addon in self._enabled_addons:
-            addon.post_correction(
-                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
-            )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        for addon in self._enabled_addons:
-            addon.post_reshape(
-                buffer_array, cell_table, pulse_table, data_hash
-            )
-
-        # reusing input data hash for sending
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        data_hash.set(self._cell_table_path, cell_table[:, np.newaxis])
-        data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis])
-
-        self._write_output(data_hash, metadata)
-        self._preview_friend.write_outputs(
-            metadata, preview_raw, preview_corrected, preview_raw_gain, preview_gain_map
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        if constant in bad_pixel_constants:
-            constant_data &= self._override_bad_pixel_flags
-        self.kernel_runner.load_constant(constant, constant_data)
-
-    def preReconfigure(self, config):
-        super().preReconfigure(config)
-        if config.has("corrections.badPixels.maskingValue"):
-            # only check if it is valid; postReconfigure will use it
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-
-    def postReconfigure(self):
-        super().postReconfigure()
-
-        update = self._prereconfigure_update_hash
-
-        if update.has("constantParameters.gainMode"):
-            self.flush_constants()
-            self._update_buffers()
-
-        if update.has("corrections.forceMgIfBelow.hardThreshold"):
-            self.kernel_runner.set_mg_hard_threshold(
-                self.get("corrections.forceMgIfBelow.hardThreshold")
-            )
-
-        if update.has("corrections.forceHgIfBelow.hardThreshold"):
-            self.kernel_runner.set_hg_hard_threshold(
-                self.get("corrections.forceHgIfBelow.hardThreshold")
-            )
-
-        if update.has("corrections.gainXray.gGainValue"):
-            self.kernel_runner.set_g_gain_value(
-                self.get("corrections.gainXray.gGainValue")
-            )
-
-        if update.has("corrections.badPixels.maskingValue"):
-            self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value)
-
-        if update.has("corrections.badPixels.subsetToUse"):
-            self.log_status_info("Updating bad pixel maps based on subset specified")
-            # note: now just always reloading from cache for convenience
-            with self.calcat_friend.cached_constants_lock:
-                self.kernel_runner.flush_buffers(bad_pixel_constants)
-                for constant in bad_pixel_constants:
-                    if constant in self.calcat_friend.cached_constants:
-                        self._load_constant_to_runner(
-                            constant, self.calcat_friend.cached_constants[constant]
-                        )
diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py
index 9a6cc613..3611b267 100644
--- a/src/calng/corrections/DsscCorrection.py
+++ b/src/calng/corrections/DsscCorrection.py
@@ -8,7 +8,14 @@ from karabo.bound import (
     STRING_ELEMENT,
 )
 
-from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils
+from .. import (
+    base_calcat,
+    base_correction,
+    base_kernel_runner,
+    schemas,
+    utils,
+)
+from ..preview_utils import FrameSelectionMode, PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 
 
@@ -21,131 +28,98 @@ class Constants(enum.Enum):
     Offset = enum.auto()
 
 
-class DsscBaseRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "fyx"
-    _xp = None
+correction_steps = (("offset", CorrectionFlags.OFFSET, {Constants.Offset}),)
 
-    @property
-    def map_shape(self):
-        return (self.constant_memory_cells, self.pixels_y, self.pixels_x)
 
-    @property
-    def input_shape(self):
-        return (self.frames, self.pixels_y, self.pixels_x)
+class DsscBaseRunner(base_kernel_runner.BaseKernelRunner):
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 128
+    num_pixels_fs = 512
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.offset_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.processed_data = self._xp.empty(
-            self.processed_shape, dtype=output_data_dtype
+    def expected_input_shape(self, num_frames):
+        return (
+            num_frames,
+            1,
+            self.num_pixels_ss,
+            self.num_pixels_fs,
         )
 
+    def _setup_constant_buffers(self):
+        self.offset_map = self._xp.empty(self._map_shape, dtype=np.float32)
+
     def flush_buffers(self, constants):
         if Constants.Offset in constants:
             self.offset_map.fill(0)
 
-    @property
-    def preview_data_views(self):
-        return (
-            self.input_data,
-            self.processed_data,
-        )
-
-    def load_offset_map(self, offset_map):
+    def _load_constant(self, constant_type, data):
+        assert constant_type is Constants.Offset
         # can have an extra dimension for some reason
-        if len(offset_map.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
-            offset_map = offset_map[..., 0]
+        if len(data.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
+            data = data[..., 0]
         # shape (now): x, y, memory cell
-        offset_map = self._xp.asarray(offset_map, dtype=np.float32).transpose()
-        self.offset_map[:] = offset_map[:self.constant_memory_cells]
+        data = self._xp.asarray(data, dtype=np.float32).transpose()
+        self.offset_map[:] = data[:self._constant_memory_cells]
+
+    def _preview_data_views(self, raw_data, processed_data):
+        return (
+            raw_data[:, 0],
+            processed_data,
+        )
 
 
 class DsscCpuRunner(DsscBaseRunner):
-    _gpu_based = False
     _xp = np
 
     def _post_init(self):
-        self.input_data = None
-        self.cell_table = None
         from ..kernels import dssc_cython
-
         self.correction_kernel = dssc_cython.correct
 
-    def load_data(self, image_data, cell_table):
-        self.input_data = image_data.astype(np.uint16, copy=False)
-        self.cell_table = cell_table.astype(np.uint16, copy=False)
-
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, output):
         self.correction_kernel(
-            self.input_data,
-            self.cell_table,
+            image_data,
+            cell_table,
             flags,
             self.offset_map,
-            self.processed_data,
+            output,
         )
 
 
 class DsscGpuRunner(DsscBaseRunner):
-    _gpu_based = True
-
     def _pre_init(self):
-        import cupy as cp
-
-        self._xp = cp
+        import cupy
+        self._xp = cupy
 
     def _post_init(self):
-        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = self._xp.empty(self.frames, dtype=np.uint16)
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.input_shape, self.block_shape
-        )
         self.correction_kernel = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("dssc_gpu.cu").render(
                 {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
+                    "pixels_x": self.num_pixels_fs,
+                    "pixels_y": self.num_pixels_ss,
                     "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
+                        self._output_dtype
                     ),
                     "corr_enum": utils.enum_to_c_template(CorrectionFlags),
                 }
             )
         ).get_function("correct")
 
-    def load_data(self, image_data, cell_table):
-        self.input_data.set(image_data)
-        self.cell_table.set(cell_table)
-
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, output):
+        num_frames = self._xp.uint16(image_data.shape[0])
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(output.shape, block_shape)
+        image_data = image_data[:, 0]
         self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
+            grid_shape,
+            block_shape,
             (
-                self.input_data,
-                self.cell_table,
-                np.uint8(flags),
+                image_data,
+                cell_table,
+                num_frames,
+                self._constant_memory_cells,
+                flags,
                 self.offset_map,
-                self.processed_data,
+                output,
             ),
         )
 
@@ -185,44 +159,30 @@ class DsscCalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("DsscCorrection", deviceVersion)
 class DsscCorrection(base_correction.BaseCorrection):
-    # subclass *must* set these attributes
-    _base_output_schema = schemas.xtdf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (("offset", CorrectionFlags.OFFSET, {Constants.Offset}),)
+    _base_output_schema = schemas.xtdf_output_schema
+    _correction_steps = correction_steps
     _calcat_friend_class = DsscCalcatFriend
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected"]
 
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec(name, frame_reduction_selection_mode=FrameSelectionMode.PULSE)
+        for name in ("raw", "corrected")
+    ]
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        DsscBaseRunner.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(DsscCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(400)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.outputAxisOrder")
-            .setNewDefaultValue("fyx")
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("pulse")
+            .dataSchema(
+                DsscCorrection._base_output_schema(
+                    use_shmem_handles=cls._use_shmem_handles
+                )
+            )
             .commit(),
         )
-        DsscCalcatFriend.add_schema(expected)
-        base_correction.add_addon_nodes(expected, DsscCorrection)
-        base_correction.add_correction_step_schema(
-            expected,
-            DsscCorrection._correction_steps,
-        )
-        base_correction.add_preview_outputs(expected, DsscCorrection._preview_outputs)
         (
             # support both CPU and GPU kernels
             STRING_ELEMENT(expected)
@@ -240,15 +200,6 @@ class DsscCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-    @property
-    def input_data_shape(self):
-        return (
-            self.get("dataFormat.frames"),
-            1,
-            self.get("dataFormat.pixelsY"),
-            self.get("dataFormat.pixelsX"),
-        )
-
     @property
     def _kernel_runner_class(self):
         kernel_type = base_kernel_runner.KernelRunnerTypes[
@@ -258,80 +209,3 @@ class DsscCorrection(base_correction.BaseCorrection):
             return DsscCpuRunner
         else:
             return DsscGpuRunner
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        if self._frame_filter is not None:
-            try:
-                cell_table = cell_table[self._frame_filter]
-                pulse_table = pulse_table[self._frame_filter]
-                image_data = image_data[self._frame_filter]
-            except IndexError:
-                self.log_status_warn(
-                    "Failed to apply frame filter, please check that it is valid!"
-                )
-                return
-
-        try:
-            # drop the singleton dimension for kernel runner
-            self.kernel_runner.load_data(image_data[:, 0], cell_table.ravel())
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled)
-        for addon in self._enabled_addons:
-            addon.post_correction(
-                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
-            )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-            preview_raw, preview_corrected = self.kernel_runner.compute_previews(
-                preview_slice_index,
-            )
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        data_hash.set(self._cell_table_path, cell_table[:, np.newaxis])
-        data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis])
-
-        self._write_output(data_hash, metadata)
-        self._preview_friend.write_outputs(metadata, preview_raw, preview_corrected)
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_offset_map(constant_data)
diff --git a/src/calng/corrections/Epix100Correction.py b/src/calng/corrections/Epix100Correction.py
index 5bd981bb..b541a72a 100644
--- a/src/calng/corrections/Epix100Correction.py
+++ b/src/calng/corrections/Epix100Correction.py
@@ -1,6 +1,5 @@
 import concurrent.futures
 import enum
-import functools
 
 import numpy as np
 from karabo.bound import (
@@ -19,6 +18,7 @@ from .. import (
     utils,
 )
 from .._version import version as deviceVersion
+from ..preview_utils import PreviewFriend, PreviewSpec
 from ..kernels import common_cython
 
 
@@ -37,6 +37,14 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 2**4
 
 
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.OffsetEPix100}),
+    ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainEPix100}),
+    ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseEPix100}),
+    ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkEPix100}),
+)
+
+
 class Epix100CalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
 
@@ -138,84 +146,119 @@ class Epix100CalcatFriend(base_calcat.BaseCalcatFriend):
 
 
 class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "xy"
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
+    _bad_pixel_constants = {Constants.BadPixelsDarkEPix100}
     _xp = np
-    _gpu_based = False
 
-    @property
-    def input_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    num_pixels_ss = 708
+    num_pixels_fs = 768
 
-    @property
-    def preview_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
+        (
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.noiseSigma")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(5)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.minFrac")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(0.25)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def map_shape(self):
-        return (self.pixels_x, self.pixels_y)
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableRow")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,  # will be 1, will be ignored
-        constant_memory_cells,
-        input_data_dtype=np.uint16,
-        output_data_dtype=np.float32,
-    ):
-        assert (
-            output_data_dtype == np.float32
-        ), "Alternative output types not supported yet"
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableCol")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableBlock")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
 
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
+    def reconfigure(self, config):
+        super().reconfigure(config)
+        if config.has("corrections.commonMode.noiseSigma"):
+            self._cm_sigma = np.float32(config["corrections.commonMode.noiseSigma"])
+        if config.has("corrections.commonMode.minFrac"):
+            self._cm_minfrac = np.float32(config["corrections.commonMode.minFrac"])
+        if config.has("corrections.commonMode.enableRow"):
+            self._cm_row = config["corrections.commonMode.enableRow"]
+        if config.has("corrections.commonMode.enableCol"):
+            self._cm_col = config["corrections.commonMode.enableCol"]
+        if config.has("corrections.commonMode.enableBlock"):
+            self._cm_block = config["corrections.commonMode.enableBlock"]
+
+    def expected_input_shape(self, num_frames):
+        assert num_frames == 1
+        return (
+            self.num_pixels_ss,
+            self.num_pixels_fs,
         )
 
-        self.input_data = None
-        self.processed_data = np.empty(self.processed_shape, dtype=np.float32)
+    def _expected_output_shape(self, num_frames):
+        return (self.num_pixels_ss, self.num_pixels_fs)
 
-        self.offset_map = np.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32)
-        self.noise_map = np.empty(self.map_shape, dtype=np.float32)
+    @property
+    def _map_shape(self):
+        return (self.num_pixels_ss, self.num_pixels_fs)
+
+    def _make_output_buffers(self, num_frames, flags):
+        # ignore parameters
+        return [
+            self._xp.empty(
+                (self.num_pixels_ss, self.num_pixels_fs),
+                dtype=self._xp.float32,
+            )
+        ]
+
+    def _setup_constant_buffers(self):
+        assert (
+            self._output_dtype == np.float32
+        ), "Alternative output types not supported yet"
+
+        self.offset_map = np.empty(self._map_shape, dtype=np.float32)
+        self.rel_gain_map = np.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = np.empty(self._map_shape, dtype=np.uint32)
+        self.noise_map = np.empty(self._map_shape, dtype=np.float32)
         # will do everything by quadrant
         self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
-        self._q_input_data = None
-        self._q_processed_data = utils.quadrant_views(self.processed_data)
         self._q_offset_map = utils.quadrant_views(self.offset_map)
         self._q_rel_gain_map = utils.quadrant_views(self.rel_gain_map)
         self._q_bad_pixel_map = utils.quadrant_views(self.bad_pixel_map)
         self._q_noise_map = utils.quadrant_views(self.noise_map)
 
-    def __del__(self):
-        self.thread_pool.shutdown()
-
-    @property
-    def preview_data_views(self):
-        # NOTE: apparently there are "calibration rows"
-        # geometry assumes these are cut out already
-        return (self.input_data[2:-2], self.processed_data[2:-2])
-
-    def load_data(self, image_data):
-        # should almost never squeeze, but ePix doesn't do burst mode, right?
-        self.input_data = image_data.astype(np.uint16, copy=False).squeeze()
-        self._q_input_data = utils.quadrant_views(self.input_data)
-
-    def load_constant(self, constant_type, data):
+    def _load_constant(self, constant_type, data):
         if constant_type is Constants.OffsetEPix100:
             self.offset_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.RelativeGainEPix100:
             self.rel_gain_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.BadPixelsDarkEPix100:
-            self.bad_pixel_map[:] = data.squeeze()
+            self.bad_pixel_map[:] = (data.squeeze() & self.bad_pixel_subset)
         elif constant_type is Constants.NoiseEPix100:
             self.noise_map[:] = data.squeeze()
         else:
@@ -231,127 +274,115 @@ class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner):
         if Constants.NoiseEPix100 in constants:
             self.noise_map.fill(np.inf)
 
+    def _preview_data_views(self, raw_data, processed_data):
+        # NOTE: apparently there are "calibration rows"
+        # geometry assumes these are cut out already
+        return (raw_data[2:-2], processed_data[2:-2])
+
     def _correct_quadrant(
         self,
-        q,
         flags,
-        bad_pixel_mask_value,
-        cm_noise_sigma,
-        cm_min_frac,
-        cm_row,
-        cm_col,
-        cm_block,
+        input_data,
+        offset_map,
+        noise_map,
+        bad_pixel_map,
+        rel_gain_map,
+        output,
     ):
-        output = self._q_processed_data[q]
-        output[:] = self._q_input_data[q].astype(np.float32)
+        output[:] = input_data.astype(np.float32)
         if flags & CorrectionFlags.OFFSET:
-            output -= self._q_offset_map[q]
+            output -= offset_map
 
         if flags & CorrectionFlags.COMMONMODE:
             # per rectangular block that looks like something is going on
             cm_mask = (
-                (self._q_bad_pixel_map[q] != 0)
-                | (output > self._q_noise_map[q] * cm_noise_sigma)
+                (bad_pixel_map != 0)
+                | (output > noise_map * self._cm_sigma)
             ).astype(np.uint8, copy=False)
             masked = np.ma.masked_array(data=output, mask=cm_mask)
-            if cm_block:
+            if self._cm_block:
                 for block in np.hsplit(masked, 4):
-                    if block.count() < block.size * cm_min_frac:
+                    if block.count() < block.size * self._cm_minfrac:
                         continue
                     block.data[:] -= np.ma.median(block)
 
-            if cm_row:
-                common_cython.cm_fs(output, cm_mask, cm_noise_sigma, cm_min_frac)
+            if self._cm_row:
+                common_cython.cm_fs(output, cm_mask, self._cm_sigma, self._cm_minfrac)
 
-            if cm_col:
-                common_cython.cm_ss(output, cm_mask, cm_noise_sigma, cm_min_frac)
+            if self._cm_col:
+                common_cython.cm_ss(output, cm_mask, self._cm_sigma, self._cm_minfrac)
 
         if flags & CorrectionFlags.RELGAIN:
-            output *= self._q_rel_gain_map[q]
+            output *= rel_gain_map
 
         if flags & CorrectionFlags.BPMASK:
-            output[self._q_bad_pixel_map[q] != 0] = bad_pixel_mask_value
+            output[bad_pixel_map != 0] = self.bad_pixel_mask_value
+
+    def _correct(self, flags, image_data, cell_table, output):
+        # ignore cell table None
+        concurrent.futures.wait(
+            [
+                self.thread_pool.submit(
+                    self._correct_quadrant, flags, *parts
+                )
+                for parts in zip(
+                        utils.quadrant_views(image_data),
+                        self._q_offset_map,
+                        self._q_noise_map,
+                        self._q_bad_pixel_map,
+                        self._q_rel_gain_map,
+                        utils.quadrant_views(output),
+                )
+            ]
+        )
 
-    def correct(
-        self,
-        flags,
-        bad_pixel_mask_value=np.nan,
-        cm_noise_sigma=5,
-        cm_min_frac=0.25,
-        cm_row=True,
-        cm_col=True,
-        cm_block=True,
-    ):
-        # NOTE: how to best clean up all these duplicated parameters?
-        for result in self.thread_pool.map(
-            functools.partial(
-                self._correct_quadrant,
-                flags=flags,
-                bad_pixel_mask_value=bad_pixel_mask_value,
-                cm_noise_sigma=cm_noise_sigma,
-                cm_min_frac=cm_min_frac,
-                cm_row=cm_row,
-                cm_col=cm_col,
-                cm_block=cm_block,
-            ),
-            range(4),
-        ):
-            pass  # just run through to await map
+    def __del__(self):
+        self.thread_pool.shutdown()
 
 
 @KARABO_CLASSINFO("Epix100Correction", deviceVersion)
 class Epix100Correction(base_correction.BaseCorrection):
-    _base_output_schema = schemas.jf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.OffsetEPix100}),
-        ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainEPix100}),
-        ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseEPix100}),
-        ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkEPix100}),
-    )
-    _image_data_path = "data.image.pixels"
+    _base_output_schema = schemas.jf_output_schema
     _kernel_runner_class = Epix100CpuRunner
-    _calcat_friend_class = Epix100CalcatFriend
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected"]
+    _calcat_friend_class = Epix100CalcatFriend
+
+    _image_data_path = "data.image.pixels"
     _cell_table_path = None
     _pulse_table_path = None
-    _warn_memory_cell_range = False
 
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec(
+            name,
+            dimensions=2,
+            frame_reduction=False,
+        )
+        for name in ("raw", "corrected")
+    ]
+    _warn_memory_cell_range = False
+    _use_shmem_handles = False
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        cls._kernel_runner_class.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(Epix100Correction._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(708)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(768)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.outputAxisOrder")
-            .setNewOptions("xy,yx")
-            .setNewDefaultValue("xy")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             # TODO: disable preview selection mode
 
             OVERWRITE_ELEMENT(expected)
             .key("useShmemHandles")
-            .setNewDefaultValue(False)
+            .setNewDefaultValue(cls._use_shmem_handles)
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
@@ -360,112 +391,11 @@ class Epix100Correction(base_correction.BaseCorrection):
             .commit(),
         )
 
-        base_correction.add_preview_outputs(
-            expected, Epix100Correction._preview_outputs
-        )
-        base_correction.add_correction_step_schema(
-            expected,
-            Epix100Correction._correction_steps,
-        )
-        (
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.noiseSigma")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(5)
-            .reconfigurable()
-            .commit(),
-
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.minFrac")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(0.25)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableRow")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableCol")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableBlock")
-            .tags("managed")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-        )
-        Epix100CalcatFriend.add_schema(expected)
-        # TODO: bad pixel node?
-
-    @property
-    def input_data_shape(self):
-        # TODO: check
+    def _get_data_from_hash(self, data_hash):
+        image_data = data_hash.get(self._image_data_path)
         return (
-            self.unsafe_get("dataFormat.pixelsX"),
-            self.unsafe_get("dataFormat.pixelsY"),
-        )
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,  # will be None
-        pulse_table,  # ditto
-    ):
-        self.kernel_runner.load_data(image_data)
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        args_which_should_be_cached = dict(
-            cm_noise_sigma=self.unsafe_get("corrections.commonMode.noiseSigma"),
-            cm_min_frac=self.unsafe_get("corrections.commonMode.minFrac"),
-            cm_row=self.unsafe_get("corrections.commonMode.enableRow"),
-            cm_col=self.unsafe_get("corrections.commonMode.enableCol"),
-            cm_block=self.unsafe_get("corrections.commonMode.enableBlock"),
+            1,
+            image_data,
+            None,
+            None,
         )
-        self.kernel_runner.correct(
-            flags=self._correction_flag_enabled, **args_which_should_be_cached
-        )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        if self._correction_flag_enabled != self._correction_flag_preview:
-            self.kernel_runner.correct(
-                flags=self._correction_flag_preview,
-                **args_which_should_be_cached,
-            )
-
-        if self._use_shmem_handles:
-            # TODO: consider better name for key...
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        # note: base class preview machinery assumes burst mode, shortcut it
-        self._preview_friend.write_outputs(
-            metadata, *self.kernel_runner.preview_data_views
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
diff --git a/src/calng/corrections/Gotthard2Correction.py b/src/calng/corrections/Gotthard2Correction.py
index aca011cc..f8fa85bb 100644
--- a/src/calng/corrections/Gotthard2Correction.py
+++ b/src/calng/corrections/Gotthard2Correction.py
@@ -6,17 +6,18 @@ from karabo.bound import (
     KARABO_CLASSINFO,
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
-    Hash,
-    Timestamp,
 )
 
-from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils
+from .. import (
+    base_calcat,
+    base_correction,
+    base_kernel_runner,
+    schemas,
+)
+from ..preview_utils import PreviewFriend, FrameSelectionMode, PreviewSpec
 from .._version import version as deviceVersion
 
 
-_pretend_pulse_table = np.arange(2720, dtype=np.uint8)
-
-
 class Constants(enum.Enum):
     LUTGotthard2 = enum.auto()
     OffsetGotthard2 = enum.auto()
@@ -25,7 +26,7 @@ class Constants(enum.Enum):
     BadPixelsFFGotthard2 = enum.auto()
 
 
-bp_constant_types = {
+bad_pixel_constants = {
     Constants.BadPixelsDarkGotthard2,
     Constants.BadPixelsFFGotthard2,
 }
@@ -39,106 +40,130 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 8
 
 
-class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
-    _gpu_based = False
-    _xp = np
+correction_steps = (
+    (
+        "lut",
+        CorrectionFlags.LUT,
+        {Constants.LUTGotthard2},
+    ),
+    (
+        "offset",
+        CorrectionFlags.OFFSET,
+        {Constants.OffsetGotthard2},
+    ),
+    (
+        "gain",
+        CorrectionFlags.GAIN,
+        {Constants.RelativeGainGotthard2},
+    ),
+    (
+        "badPixels",
+        CorrectionFlags.BPMASK,
+        {
+            Constants.BadPixelsDarkGotthard2,
+            Constants.BadPixelsFFGotthard2,
+        }
+    ),
+)
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-        bad_pixel_subset=None,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
 
-        from ..kernels import gotthard2_cython
+def framesums_reduction_fun(data, cell_table, pulse_table, warn_fun):
+    return np.nansum(data, axis=1)
 
-        self.correction_kernel = gotthard2_cython.correct
-        self.input_shape = (frames, pixels_x)
-        self.processed_shape = self.input_shape
+
+class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
+    _xp = np
+
+    num_pixels_ss = None  # 1d detector, should never access this attribute
+    num_pixels_fs = 1280
+
+    @classmethod
+    def add_schema(cls, schema):
+        super(Gotthard2CpuRunner, cls).add_schema(schema)
+        super(Gotthard2CpuRunner, cls).add_bad_pixel_config(schema)
+
+    def expected_input_shape(self, num_frames):
+        return (num_frames, self.num_pixels_fs)
+
+    def _expected_output_shape(self, num_frames):
+        return (num_frames, self.num_pixels_fs)
+
+    def _make_output_buffers(self, num_frames, flags):
+        return [
+            self._xp.empty(
+                (num_frames, self.num_pixels_fs), dtype=np.float32)
+           ]
+
+    def _preview_data_views(self, raw_data, raw_gain, processed_data):
+        return [
+            raw_data,  # raw 1d
+            raw_gain,  # gain 1d
+            processed_data,  # corrected 1d
+            processed_data,  # frame sums
+            raw_data,  # raw streak (2d)
+            raw_gain,  # gain streak
+            processed_data,  # corrected streak
+        ]
+
+    def _setup_constant_buffers(self):
         # model: 2 buffers (corresponding to actual memory cells), 2720 frames
         # lut maps from uint12 to uint10 values
-        self.lut_shape = (2, 4096, pixels_x)
-        self.map_shape = (3, self.constant_memory_cells, self.pixels_x)
+        # TODO: override superclass properties instead
+        self.lut_shape = (2, 4096, self.num_pixels_fs)
+        self.map_shape = (3, self._constant_memory_cells, self.num_pixels_fs)
         self.lut = np.empty(self.lut_shape, dtype=np.uint16)
         self.offset_map = np.empty(self.map_shape, dtype=np.float32)
         self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
         self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32)
-        self.bad_pixel_mask_value = bad_pixel_mask_value
-        self.flush_buffers()
-        self._bad_pixel_subset = bad_pixel_subset
+        self.flush_buffers(set(Constants))
 
-        self.input_data = None  # will just point to data coming in
-        self.input_gain_stage = None  # will just point to data coming in
-        self.processed_data = None  # will just point to buffer we're given
-
-    @property
-    def preview_data_views(self):
-        return (self.input_data, self.input_gain_stage, self.processed_data)
+    def _post_init(self):
+        from ..kernels import gotthard2_cython
+        self.correction_kernel = gotthard2_cython.correct
 
-    def load_constant(self, constant_type, data):
+    def _load_constant(self, constant_type, data):
         if constant_type is Constants.LUTGotthard2:
             self.lut[:] = np.transpose(data.astype(np.uint16, copy=False), (1, 2, 0))
         elif constant_type is Constants.OffsetGotthard2:
             self.offset_map[:] = np.transpose(data.astype(np.float32, copy=False))
         elif constant_type is Constants.RelativeGainGotthard2:
             self.rel_gain_map[:] = np.transpose(data.astype(np.float32, copy=False))
-        elif constant_type in bp_constant_types:
+        elif constant_type in bad_pixel_constants:
             # TODO: add the regular bad pixel subset configuration
-            data = data.astype(np.uint32, copy=False)
-            if self._bad_pixel_subset is not None:
-                data = data & self._bad_pixel_subset
-            self.bad_pixel_map |= np.transpose(data)
+            self.bad_pixel_map |= (np.transpose(data) & self.bad_pixel_subset)
         else:
             raise ValueError(f"What is this constant '{constant_type}'?")
 
-    def set_bad_pixel_subset(self, subset, apply_now=True):
-        self._bad_pixel_subset = subset
-        if apply_now:
-            self.bad_pixel_map &= self._bad_pixel_subset
-
-    def load_data(self, image_data, input_gain_stage):
-        """Experiment: loading both in one function as they are tied"""
-        self.input_data = image_data.astype(np.uint16, copy=False)
-        self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False)
-
-    def flush_buffers(self, constants=None):
-        # for now, always flushing all here...
-        default_lut = (
-            np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12
-        ).astype(np.uint16)
-        self.lut[:] = np.stack([np.stack([default_lut] * 2)] * self.pixels_x, axis=2)
-        self.offset_map.fill(0)
-        self.rel_gain_map.fill(1)
-        self.bad_pixel_map.fill(0)
-
-    def correct(self, flags, out=None):
-        if out is None:
-            out = np.empty(self.processed_shape, dtype=np.float32)
-
+    def flush_buffers(self, constants):
+        if Constants.LUTGotthard2 in constants:
+            default_lut = (
+                np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12
+            ).astype(np.uint16)
+            self.lut[:] = np.stack(
+                [np.stack([default_lut] * 2)] * self.num_pixels_fs, axis=2
+            )
+        if Constants.OffsetGotthard2 in constants:
+            self.offset_map.fill(0)
+        if Constants.RelativeGainGotthard2 in constants:
+            self.rel_gain_map.fill(1)
+        if bad_pixel_constants & constants:
+            self.bad_pixel_map.fill(0)
+
+    def _correct(self, flags, raw_data, cell_table, gain_stages, processed_data):
         self.correction_kernel(
-            self.input_data,
-            self.input_gain_stage,
-            np.uint8(flags),
+            raw_data,
+            gain_stages,
+            flags,
             self.lut,
             self.offset_map,
             self.rel_gain_map,
             self.bad_pixel_map,
             self.bad_pixel_mask_value,
-            out,
+            processed_data,
         )
-        self.processed_data = out
-        return out
 
 
 class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend):
@@ -227,236 +252,109 @@ class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("Gotthard2Correction", deviceVersion)
 class Gotthard2Correction(base_correction.BaseCorrection):
-    _base_output_schema = schemas.jf_output_schema(use_shmem_handle=False)
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        (
-            "lut",
-            CorrectionFlags.LUT,
-            {Constants.LUTGotthard2},
-        ),
-        (
-            "offset",
-            CorrectionFlags.OFFSET,
-            {Constants.OffsetGotthard2},
-        ),
-        (
-            "gain",
-            CorrectionFlags.GAIN,
-            {Constants.RelativeGainGotthard2},
-        ),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            {
-                Constants.BadPixelsDarkGotthard2,
-                Constants.BadPixelsFFGotthard2,
-            }
-        ),
-    )
+    _base_output_schema = schemas.jf_output_schema
     _kernel_runner_class = Gotthard2CpuRunner
     _calcat_friend_class = Gotthard2CalcatFriend
+    _correction_steps = correction_steps
     _constant_enum_class = Constants
+
     _image_data_path = "data.adc"
-    _cell_table_path = "data.memoryCell"
+    _cell_table_path = None
     _pulse_table_path = None
+
     _warn_memory_cell_range = False  # for now, receiver always writes 255
-    _preview_outputs = ["outputStreak", "outputGainStreak"]
     _cuda_pin_buffers = False
+    _use_shmem_handles = False
+
+    _preview_outputs = [
+        PreviewSpec(
+            "raw",
+            dimensions=1,
+            frame_reduction=True,
+            wrap_in_imagedata=False,
+            valid_frame_selection_modes=[FrameSelectionMode.FRAME],
+        ),
+        PreviewSpec(
+            "gain",
+            dimensions=1,
+            frame_reduction=True,
+            wrap_in_imagedata=False,
+            valid_frame_selection_modes=[FrameSelectionMode.FRAME],
+        ),
+        PreviewSpec(
+            "corrected",
+            dimensions=1,
+            frame_reduction=True,
+            wrap_in_imagedata=False,
+            valid_frame_selection_modes=[FrameSelectionMode.FRAME],
+        ),
+        PreviewSpec(
+            "framesums",
+            dimensions=1,
+            frame_reduction=False,
+            frame_reduction_fun=framesums_reduction_fun,
+            wrap_in_imagedata=False,
+        ),
+        PreviewSpec(
+            "rawStreak",
+            dimensions=2,
+            frame_reduction=False,
+            wrap_in_imagedata=True,
+        ),
+        PreviewSpec(
+            "gainStreak",
+            dimensions=2,
+            frame_reduction=False,
+            wrap_in_imagedata=True,
+        ),
+        PreviewSpec(
+            "correctedStreak",
+            dimensions=2,
+            frame_reduction=False,
+            wrap_in_imagedata=True,
+        ),
+    ]
 
-    @staticmethod
-    def expectedParameters(expected):
-        super(Gotthard2Correction, Gotthard2Correction).expectedParameters(expected)
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        cls._kernel_runner_class.add_schema(expected)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(Gotthard2Correction._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(1280)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(2720)  # note: actually just frames...
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("frame")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
             .key("outputShmemBufferSize")
             .setNewDefaultValue(2)
             .commit(),
-        )
-
-        base_correction.add_preview_outputs(
-            expected, Gotthard2Correction._preview_outputs
-        )
-        for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"):
-            # add these "manually" as the automated bits wrap ImageData
-            (
-                OUTPUT_CHANNEL(expected)
-                .key(f"preview.{channel}")
-                .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=False))
-                .commit(),
-            )
-        base_correction.add_correction_step_schema(
-            expected,
-            Gotthard2Correction._correction_steps,
-        )
-        base_correction.add_bad_pixel_config_node(expected)
-        Gotthard2CalcatFriend.add_schema(expected)
 
-    @property
-    def input_data_shape(self):
-        return (
-            self.unsafe_get("dataFormat.frames"),
-            self.unsafe_get("dataFormat.pixelsX"),
+            OVERWRITE_ELEMENT(expected)
+            .key("useShmemHandles")
+            .setNewDefaultValue(cls._use_shmem_handles)
+            .commit(),
         )
 
-    @property
-    def output_data_shape(self):
+    def _get_data_from_hash(self, data_hash):
+        image_data = np.asarray(
+            data_hash.get(self._image_data_path)
+        ).astype(np.uint16, copy=False)
+        gain_data = np.asarray(
+            data_hash.get("data.gain")
+        ).astype(np.uint8, copy=False)
+        num_frames = image_data.shape[0]
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            expected_shape = self.kernel_runner.expected_input_shape(num_frames)
+            image_data.shape = expected_shape
+            gain_data.shape = expected_shape
         return (
-            self.unsafe_get("dataFormat.frames"),
-            self.unsafe_get("dataFormat.pixelsX"),
+            num_frames,
+            image_data,
+            None,  # no cell table
+            None,  # no pulse table
+            gain_data,
         )
-
-    @property
-    def _kernel_runner_init_args(self):
-        return {
-            "bad_pixel_mask_value": self._bad_pixel_mask_value,
-            "bad_pixel_subset": self._bad_pixel_subset,
-        }
-
-    @property
-    def _bad_pixel_mask_value(self):
-        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
-
-    _bad_pixel_subset = property(base_correction.get_bad_pixel_field_selection)
-
-    def postReconfigure(self):
-        super().postReconfigure()
-        update = self._prereconfigure_update_hash
-        if update.has("corrections.badPixels.subsetToUse"):
-            if any(
-                update.get(
-                    f"corrections.badPixels.subsetToUse.{field.name}", default=False
-                )
-                for field in utils.BadPixelValues
-            ):
-                self.log_status_info(
-                    "Some fields reenabled, reloading cached bad pixel constants"
-                )
-                self.kernel_runner.set_bad_pixel_subset(
-                    self._bad_pixel_subset, apply_now=False
-                )
-                with self.calcat_friend.cached_constants_lock:
-                    self.kernel_runner.flush_buffers(bp_constant_types)
-                    for constant_type in (
-                        bp_constant_types & self.calcat_friend.cached_constants.keys()
-                    ):
-                        self._load_constant_to_runner(
-                            constant_type,
-                            self.calcat_friend.cached_constants[constant_type],
-                        )
-            else:
-                # just narrowing the subset - no reload, just AND
-                self.kernel_runner.set_bad_pixel_subset(
-                    self._bad_pixel_subset, apply_now=True
-                )
-        if update.has("corrections.badPixels.maskingvalue"):
-            self.kernel_runner.bad_pixel_mask_value = self._bad_pixel_mask_value
-
-    def __init__(self, config):
-        super().__init__(config)
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        # cell table currently not used for GOTTHARD2 (assume alternating)
-        gain_map = np.asarray(data_hash.get("data.gain"))
-        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
-            gain_map.shape = self.input_data_shape
-        try:
-            self.kernel_runner.load_data(image_data, gain_map)
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled, out=buffer_array)
-
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                _pretend_pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-        (
-            preview_raw,
-            preview_gain,
-            preview_corrected,
-        ) = self.kernel_runner.compute_previews(preview_slice_index)
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        frame_sums = np.nansum(buffer_array, axis=1)
-        timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp"))
-        for channel, data in (
-            ("outputRaw", preview_raw),
-            ("outputGain", preview_gain),
-            ("outputCorrected", preview_corrected),
-            ("outputFrameSums", frame_sums),
-        ):
-            self.writeChannel(
-                f"preview.{channel}",
-                Hash(
-                    "image.data",
-                    data,
-                    "image.mask",
-                    (~np.isfinite(data)).astype(np.uint8),
-                ),
-                timestamp=timestamp,
-            )
-        self._preview_friend.write_outputs(metadata, buffer_array, gain_map)
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py
index ab5af5f6..c51cbe08 100644
--- a/src/calng/corrections/JungfrauCorrection.py
+++ b/src/calng/corrections/JungfrauCorrection.py
@@ -9,7 +9,6 @@ from karabo.bound import (
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
-    Schema,
 )
 
 from .. import (
@@ -19,12 +18,10 @@ from .. import (
     schemas,
     utils,
 )
+from ..preview_utils import PreviewFriend, FrameSelectionMode, PreviewSpec
 from .._version import version as deviceVersion
 
 
-_pretend_pulse_table = np.arange(16, dtype=np.uint8)
-
-
 class Constants(enum.Enum):
     Offset10Hz = enum.auto()
     BadPixelsDark10Hz = enum.auto()
@@ -57,66 +54,95 @@ class CorrectionFlags(enum.IntFlag):
     STRIXEL = 8
 
 
-class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "fyx"
-    _xp = None  # subclass sets CuPy (import at runtime) or numpy
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.Offset10Hz}),
+    ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain10Hz}),
+    (
+        "badPixels",
+        CorrectionFlags.BPMASK,
+        {
+            Constants.BadPixelsDark10Hz,
+            Constants.BadPixelsFF10Hz,
+            None,
+        },
+    ),
+    ("strixel", CorrectionFlags.STRIXEL, set()),
+)
 
-    @property
-    def input_shape(self):
-        return self.frames, self.pixels_y, self.pixels_x
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
+class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 512
+    num_pixels_fs = 1024
 
     @property
-    def map_shape(self):
-        return (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3)
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        config,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-    ):
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
+    def _map_shape(self):
+        return (self._constant_memory_cells, self.num_pixels_ss, self.num_pixels_fs, 3)
+
+    def _setup_constant_buffers(self):
         # note: superclass creates cell table with wrong dtype
-        self.input_gain_stage = self._xp.empty(self.input_shape, dtype=np.uint8)
-        self.offset_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_map = self._xp.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = self._xp.empty(self.map_shape, dtype=np.uint32)
-        self.bad_pixel_mask_value = bad_pixel_mask_value
-
-        self.output_dtype = output_data_dtype
-        self._processed_data_regular = self._xp.empty(
-            self.processed_shape, dtype=output_data_dtype
-        )
-        self._processed_data_strixel = None
+        self.offset_map = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.rel_gain_map = self._xp.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = self._xp.empty(self._map_shape, dtype=np.uint32)
         self.flush_buffers(set(Constants))
-        self.correction_kernel_strixel = None
-        self.reconfigure(config)
+
+    def _make_output_buffers(self, num_frames, flags):
+        if flags & CorrectionFlags.STRIXEL:
+            return [
+                self._xp.empty(
+                    (num_frames,) + self._strixel_out_shape, dtype=self._output_dtype
+                )
+            ]
+        else:
+            return [
+                self._xp.empty(
+                    (num_frames, self.num_pixels_ss, self.num_pixels_fs),
+                    dtype=self._output_dtype,
+                )
+            ]
+
+    def _expected_output_shape(self, num_frames):
+        # TODO: think of how to unify with _make_output_buffers
+        if self._correction_flag_enabled & CorrectionFlags.STRIXEL:
+            return (num_frames,) + self._strixel_out_shape
+        else:
+            return (num_frames, self.num_pixels_ss, self.num_pixels_fs)
+
+    @classmethod
+    def add_schema(cls, schema):
+        super(cls, cls).add_schema(schema)
+        super(cls, cls).add_bad_pixel_config(schema)
+        (
+            OVERWRITE_ELEMENT(schema)
+            .key("corrections.strixel.enable")
+            .setNewDefaultValue(False)
+            .commit(),
+
+            OVERWRITE_ELEMENT(schema)
+            .key("corrections.strixel.preview")
+            .setNewDefaultValue(False)
+            .commit(),
+
+            STRING_ELEMENT(schema)
+            .key("corrections.strixel.type")
+            .description(
+                "Which kind of strixel layout is used for this module? cols_A0123 is "
+                "the first strixel layout deployed at HED and rows_A1256 is the one "
+                "later deployed at SCS."
+            )
+            .assignmentOptional()
+            .defaultValue("cols_A0123(HED-type)")
+            .options("cols_A0123(HED-type),rows_A1256(SCS-type)")
+            .reconfigurable()
+            .commit(),
+        )
 
     def reconfigure(self, config):
-        # note: regular bad pixel masking uses device property (TODO: unify)
-        if (mask_value := config.get("badPixels.maskingValue")) is not None:
-            self._bad_pixel_mask_value = self._xp.float32(mask_value)
-            # this is a functools.partial, can just update the captured parameter
-            if self.correction_kernel_strixel is not None:
-                self.correction_kernel_strixel.keywords[
-                    "missing"
-                ] = self._bad_pixel_mask_value
-
-        if (strixel_type := config.get("strixel.type")) is not None:
+        super().reconfigure(config)
+
+        if (strixel_type := config.get("corrections.strixel.type")) is not None:
             # drop the friendly parenthesized name
             strixel_type = strixel_type.partition("(")[0]
             strixel_package = np.load(
@@ -124,18 +150,17 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
                 / f"strixel_{strixel_type}-lut_mask.npz"
             )
             self._strixel_out_shape = tuple(strixel_package["frame_shape"])
-            self._processed_data_strixel = None
             # TODO: use bad pixel masking config here
             self.correction_kernel_strixel = functools.partial(
                 utils.apply_partial_lut,
                 lut=self._xp.asarray(strixel_package["lut"]),
                 mask=self._xp.asarray(strixel_package["mask"]),
-                missing=self._bad_pixel_mask_value,
+                missing=self.bad_pixel_mask_value,
             )
             # note: always masking unmapped pixels (not respecting NON_STANDARD_SIZE)
 
-    def load_constant(self, constant, constant_data):
-        if constant_data.shape[0] == self.pixels_x:
+    def _load_constant(self, constant, constant_data):
+        if constant_data.shape[0] == self.num_pixels_fs:
             constant_data = np.transpose(constant_data, (2, 1, 0, 3))
         else:
             constant_data = np.transpose(constant_data, (2, 0, 1, 3))
@@ -145,17 +170,14 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
         elif constant is Constants.RelativeGain10Hz:
             self.rel_gain_map[:] = self._xp.asarray(constant_data, dtype=np.float32)
         elif constant in bad_pixel_constants:
-            self.bad_pixel_map |= self._xp.asarray(constant_data, dtype=np.uint32)
+            self.bad_pixel_map |= self._xp.asarray(
+                constant_data, dtype=np.uint32
+            ) & self._xp.uint32(self.bad_pixel_subset)
         else:
             raise ValueError(f"Unexpected constant type {constant}")
 
-    @property
-    def preview_data_views(self):
-        return (self.input_data, self.processed_data, self.input_gain_stage)
-
-    @property
-    def burst_mode(self):
-        return self.frames > 1
+    def _preview_data_views(self, raw_data, gain_map, processed_data):
+        return (raw_data, gain_map, processed_data)
 
     def flush_buffers(self, constants):
         if Constants.Offset10Hz in constants:
@@ -167,106 +189,80 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner):
             Constants.BadPixelsFF10Hz,
         }:
             self.bad_pixel_map.fill(0)
-            self.bad_pixel_map[
-                :, :, 255:1023:256
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
-            self.bad_pixel_map[
-                :, :, 256:1024:256
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
-            self.bad_pixel_map[
-                :, [255, 256]
-            ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+            if self.bad_pixel_subset & utils.BadPixelValues.NON_STANDARD_SIZE:
+                self._mask_asic_seams()
 
-    def override_bad_pixel_flags_to_use(self, override_value):
-        self.bad_pixel_map &= self._xp.uint32(override_value)
+    def _mask_asic_seams(self):
+        self.bad_pixel_map[
+            :, :, 255:1023:256
+        ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+        self.bad_pixel_map[
+            :, :, 256:1024:256
+        ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
+        self.bad_pixel_map[
+            :, [255, 256]
+        ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value
 
 
 class JungfrauGpuRunner(JungfrauBaseRunner):
-    _gpu_based = True
-
     def _pre_init(self):
         import cupy as cp
 
         self._xp = cp
 
     def _post_init(self):
-        self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16)
-        self.cell_table = self._xp.empty(self.frames, dtype=np.uint8)
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.input_shape, self.block_shape
-        )
-
         source_module = self._xp.RawModule(
             code=base_kernel_runner.get_kernel_template("jungfrau_gpu.cu").render(
                 {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
                     "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
+                        self._output_dtype
                     ),
                     "corr_enum": utils.enum_to_c_template(CorrectionFlags),
-                    "burst_mode": self.burst_mode,
                 }
             )
         )
         self.correction_kernel = source_module.get_function("correct")
 
-    def load_data(self, image_data, input_gain_stage, cell_table):
-        self.input_data.set(image_data)
-        self.input_gain_stage.set(input_gain_stage)
-        if self.burst_mode:
-            self.cell_table.set(cell_table)
-
-    def correct(self, flags):
+    def _correct(self, flags, image_data, cell_table, gain_map, output):
+        num_frames = image_data.shape[0]
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(
+            output.shape, block_shape
+        )
+        if flags & CorrectionFlags.STRIXEL:
+            # do "regular" correction into "regular" buffer, strixel transform to output
+            final_output = output
+            output = self._xp.empty_like(image_data, dtype=np.float32)
         self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
+            grid_shape,
+            block_shape,
             (
-                self.input_data,
-                self.input_gain_stage,
-                self.cell_table,
+                image_data,
+                gain_map,
+                cell_table,
                 self._xp.uint8(flags),
+                self._xp.uint16(num_frames),
+                self._xp.uint16(self._constant_memory_cells),
+                self._xp.uint8(num_frames > 1),
                 self.offset_map,
                 self.rel_gain_map,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self._processed_data_regular,
+                output,
             ),
         )
         if flags & CorrectionFlags.STRIXEL:
-            if (
-                self._processed_data_strixel is None
-                or self.frames != self._processed_data_strixel.shape[0]
-            ):
-                self._processed_data_strixel = self._xp.empty(
-                    (self.frames,) + self._strixel_out_shape,
-                    dtype=self.output_dtype,
-                )
-            for pixel_frame, strixel_frame in zip(
-                self._processed_data_regular, self._processed_data_strixel
-            ):
+            for pixel_frame, strixel_frame in zip(output, final_output):
                 self.correction_kernel_strixel(
                     data=pixel_frame,
                     out=strixel_frame,
                 )
-            self.processed_data = self._processed_data_strixel
-        else:
-            self.processed_data = self._processed_data_regular
 
 
 class JungfrauCpuRunner(JungfrauBaseRunner):
-    _gpu_based = False
     _xp = np
 
-    def _post_init(self):
-        # not actually allocating, will just point to incoming data
-        self.input_data = None
-        self.input_gain_stage = None
-        self.input_cell_table = None
-
+    def _pre_init(self):
         # for computing previews faster
         self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
 
@@ -278,45 +274,39 @@ class JungfrauCpuRunner(JungfrauBaseRunner):
     def __del__(self):
         self.thread_pool.shutdown()
 
-    def load_data(self, image_data, input_gain_stage, cell_table):
-        self.input_data = image_data.astype(np.uint16, copy=False)
-        self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False)
-        self.input_cell_table = cell_table.astype(np.uint8, copy=False)
+    def _correct(self, flags, image_data, cell_table, gain_map, output):
+        if flags & CorrectionFlags.STRIXEL:
+            # do "regular" correction into "regular" buffer, strixel transform to output
+            final_output = output
+            output = self._xp.empty_like(image_data, dtype=np.float32)
 
-    def correct(self, flags):
-        if self.burst_mode:
+        # burst mode
+        if image_data.shape[0] > 1:
             self.correction_kernel_burst(
-                self.input_data,
-                self.input_gain_stage,
-                self.input_cell_table,
+                image_data,
+                gain_map,
+                cell_table,
                 flags,
                 self.offset_map,
                 self.rel_gain_map,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self._processed_data_regular,
+                output,
             )
         else:
             self.correction_kernel_single(
-                self.input_data,
-                self.input_gain_stage,
+                image_data,
+                gain_map,
                 flags,
                 self.offset_map,
                 self.rel_gain_map,
                 self.bad_pixel_map,
                 self.bad_pixel_mask_value,
-                self._processed_data_regular,
+                output,
             )
 
         if flags & CorrectionFlags.STRIXEL:
-            if (
-                self._processed_data_strixel is None
-                or self.frames != self._processed_data_strixel.shape[0]
-            ):
-                self._processed_data_strixel = self._xp.empty(
-                    (self.frames,) + self._strixel_out_shape,
-                    dtype=self.output_dtype,
-                )
+            # now do strixel transformation into actual output buffer
             concurrent.futures.wait(
                 [
                     self.thread_pool.submit(
@@ -324,40 +314,9 @@ class JungfrauCpuRunner(JungfrauBaseRunner):
                         data=pixel_frame,
                         out=strixel_frame,
                     )
-                    for pixel_frame, strixel_frame in zip(
-                        self._processed_data_regular, self._processed_data_strixel
-                    )
+                    for pixel_frame, strixel_frame in zip(output, final_output)
                 ]
             )
-            self.processed_data = self._processed_data_strixel
-        else:
-            self.processed_data = self._processed_data_regular
-
-    def compute_previews(self, preview_index):
-        if preview_index < -4:
-            raise ValueError(f"No statistic with code {preview_index} defined")
-        elif preview_index >= self.frames:
-            raise ValueError(f"Memory cell index {preview_index} out of range")
-
-        if preview_index >= 0:
-
-            def fun(a):
-                return a[preview_index]
-
-        elif preview_index == -1:
-            # note: separate from next case because dtype not applicable here
-            fun = functools.partial(np.nanmax, axis=0)
-        elif preview_index in (-2, -3, -4):
-            fun = functools.partial(
-                {
-                    -2: np.nanmean,
-                    -3: np.nansum,
-                    -4: np.nanstd,
-                }[preview_index],
-                axis=0,
-                dtype=np.float32,
-            )
-        return self.thread_pool.map(fun, self.preview_data_views)
 
 
 class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend):
@@ -475,59 +434,40 @@ class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("JungfrauCorrection", deviceVersion)
 class JungfrauCorrection(base_correction.BaseCorrection):
-    _base_output_schema = schemas.jf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.Offset10Hz}),
-        ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain10Hz}),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            {
-                Constants.BadPixelsDark10Hz,
-                Constants.BadPixelsFF10Hz,
-                None,
-            },
-        ),
-        ("strixel", CorrectionFlags.STRIXEL, set()),
-    )
+    _base_output_schema = schemas.jf_output_schema
+    _correction_steps = correction_steps
     _calcat_friend_class = JungfrauCalcatFriend
     _constant_enum_class = Constants
-    _preview_outputs = [
-        "outputRaw",
-        "outputCorrected",
-        "outputGainMap",
-    ]
+
     _image_data_path = "data.adc"
     _cell_table_path = "data.memoryCell"
     _pulse_table_path = None
 
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec(
+            name,
+            dimensions=2,
+            frame_reduction=True,
+            valid_frame_selection_modes=(
+                FrameSelectionMode.FRAME,
+                FrameSelectionMode.CELL,
+            ),
+        ) for name in ("raw", "gain", "corrected")
+    ]
+    _warn_memory_cell_range = False
+    _use_shmem_handles = False
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        JungfrauBaseRunner.add_schema(expected)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(JungfrauCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(1024)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(512)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("frame")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             # JUNGFRAU data is small, can fit plenty of trains in here
@@ -535,6 +475,11 @@ class JungfrauCorrection(base_correction.BaseCorrection):
             .key("outputShmemBufferSize")
             .setNewDefaultValue(2)
             .commit(),
+
+            OVERWRITE_ELEMENT(expected)
+            .key("useShmemHandles")
+            .setNewDefaultValue(cls._use_shmem_handles)
+            .commit(),
         )
         (
             # support both CPU and GPU kernels
@@ -553,53 +498,30 @@ class JungfrauCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-        base_correction.add_preview_outputs(
-            expected, JungfrauCorrection._preview_outputs
-        )
-        base_correction.add_correction_step_schema(
-            expected,
-            JungfrauCorrection._correction_steps,
-        )
-        (
-            OVERWRITE_ELEMENT(expected)
-            .key("corrections.strixel.enable")
-            .setNewDefaultValue(False)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("corrections.strixel.preview")
-            .setNewDefaultValue(False)
-            .commit(),
-
-            STRING_ELEMENT(expected)
-            .key("corrections.strixel.type")
-            .description(
-                "Which kind of strixel layout is used for this module? cols_A0123 is "
-                "the first strixel layout deployed at HED and rows_A1256 is the one "
-                "later deployed at SCS."
+    def _get_data_from_hash(self, data_hash):
+        image_data = data_hash.get(self._image_data_path)
+        gain_data = np.asarray(data_hash.get("data.gain")).astype(np.uint8, copy=False)
+        # explicit np.asarray because types are special here
+        cell_table = np.asarray(
+            data_hash[self._cell_table_path]
+        ).astype(np.uint8, copy=False)
+        num_frames = cell_table.size
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            expected_shape = self.kernel_runner.expected_input_shape(
+                num_frames
             )
-            .assignmentOptional()
-            .defaultValue("cols_A0123(HED-type)")
-            .options("cols_A0123(HED-type),rows_A1256(SCS-type)")
-            .reconfigurable()
-            .commit(),
-        )
-        base_correction.add_bad_pixel_config_node(expected)
-        JungfrauCalcatFriend.add_schema(expected)
+            if expected_shape != image_data.shape:
+                image_data.shape = expected_shape
+                gain_data.shape = expected_shape
 
-    @property
-    def input_data_shape(self):
         return (
-            self.unsafe_get("dataFormat.frames"),
-            self.unsafe_get("dataFormat.pixelsY"),
-            self.unsafe_get("dataFormat.pixelsX"),
+            num_frames,
+            image_data,  # not as ndarray as runner can do that
+            cell_table,
+            None,
+            gain_data,
         )
 
-    @property
-    def _warn_memory_cell_range(self):
-        # disable warning in normal operation as cell 15 is expected
-        return self.unsafe_get("dataFormat.frames") > 1
-
     @property
     def output_data_shape(self):
         if self._correction_flag_enabled & CorrectionFlags.STRIXEL:
@@ -623,134 +545,3 @@ class JungfrauCorrection(base_correction.BaseCorrection):
             return JungfrauCpuRunner
         else:
             return JungfrauGpuRunner
-
-    @property
-    def _kernel_runner_init_args(self):
-        return {
-            "bad_pixel_mask_value": self.bad_pixel_mask_value,
-            # temporary: will refactor base class to always pass config node
-            "config": self.get("corrections"),
-        }
-
-    @property
-    def bad_pixel_mask_value(self):
-        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
-
-    _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection)
-
-    def __init__(self, config):
-        super().__init__(config)
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue", default="nan"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-        if config.get("useShmemHandles", default=True):
-            schema_override = Schema()
-            (
-                OUTPUT_CHANNEL(schema_override)
-                .key("dataOutput")
-                .dataSchema(schemas.jf_output_schema(use_shmem_handle=False))
-                .commit(),
-            )
-            self.updateSchema(schema_override)
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        if len(cell_table.shape) == 0:
-            cell_table = cell_table[np.newaxis]
-        try:
-            gain_map = data_hash["data.gain"]
-            if self.unsafe_get("workarounds.overrideInputAxisOrder"):
-                gain_map.shape = self.input_data_shape
-            self.kernel_runner.load_data(image_data, gain_map, cell_table)
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled)
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                _pretend_pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-        (
-            preview_raw,
-            preview_corrected,
-            preview_gain_map,
-        ) = self.kernel_runner.compute_previews(preview_slice_index)
-
-        # reusing input data hash for sending
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            # TODO: use shmem for data.gain, too
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        self._preview_friend.write_outputs(
-            metadata, preview_raw, preview_corrected, preview_gain_map
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        if constant in bad_pixel_constants:
-            constant_data &= self._override_bad_pixel_flags
-        self.kernel_runner.load_constant(constant, constant_data)
-
-    def preReconfigure(self, config):
-        super().preReconfigure(config)
-        if config.has("corrections"):
-            self.kernel_runner.reconfigure(config["corrections"])
-
-    def postReconfigure(self):
-        super().postReconfigure()
-
-        if not hasattr(self, "_prereconfigure_update_hash"):
-            return
-
-        update = self._prereconfigure_update_hash
-
-        if update.has("corrections.strixel.enable"):
-            self._lock_and_update(self._update_frame_filter)
-
-        if update.has("corrections.badPixels.subsetToUse"):
-            self.log_status_info("Updating bad pixel maps based on subset specified")
-            # note: now just always reloading from cache for convenience
-            with self.calcat_friend.cached_constants_lock:
-                self.kernel_runner.flush_buffers(bad_pixel_constants)
-                for constant in bad_pixel_constants:
-                    if constant in self.calcat_friend.cached_constants:
-                        self._load_constant_to_runner(
-                            constant, self.calcat_friend.cached_constants[constant]
-                        )
diff --git a/src/calng/corrections/LpdCorrection.py b/src/calng/corrections/LpdCorrection.py
index 98e81e1e..325f5b21 100644
--- a/src/calng/corrections/LpdCorrection.py
+++ b/src/calng/corrections/LpdCorrection.py
@@ -8,7 +8,6 @@ from karabo.bound import (
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
-    Schema,
 )
 
 from .. import (
@@ -18,6 +17,7 @@ from .. import (
     schemas,
     utils,
 )
+from ..preview_utils import PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 
 
@@ -30,6 +30,12 @@ class Constants(enum.Enum):
     BadPixelsFF = enum.auto()
 
 
+bad_pixel_constants = {
+    Constants.BadPixelsDark,
+    Constants.BadPixelsFF,
+}
+
+
 class CorrectionFlags(enum.IntFlag):
     NONE = 0
     OFFSET = 1
@@ -39,130 +45,85 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 16
 
 
-class LpdGpuRunner(base_kernel_runner.BaseKernelRunner):
-    _gpu_based = True
-    _corrected_axis_order = "fyx"
-
-    @property
-    def input_shape(self):
-        return (self.frames, 1, self.pixels_y, self.pixels_x)
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
+    ("gainAmp", CorrectionFlags.GAIN_AMP, {Constants.GainAmpMap}),
+    ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain}),
+    ("flatfield", CorrectionFlags.FF_CORR, {Constants.FFMap}),
+    ("badPixels", CorrectionFlags.BPMASK, bad_pixel_constants),
+)
 
-    @property
-    def processed_shape(self):
-        return (self.frames, self.pixels_y, self.pixels_x)
-
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,
-        constant_memory_cells,
-        output_data_dtype=np.float32,
-        bad_pixel_mask_value=np.nan,
-    ):
-        import cupy as cp
-        self._xp = cp
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-        self.gain_map = cp.empty(self.processed_shape, dtype=np.float32)
-        self.input_data = cp.empty(self.input_shape, dtype=np.uint16)
-        self.processed_data = cp.empty(self.processed_shape, dtype=output_data_dtype)
-        self.cell_table = cp.empty(frames, dtype=np.uint16)
-
-        self.map_shape = (constant_memory_cells, pixels_y, pixels_x, 3)
-        self.offset_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.gain_amp_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_slopes_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.flatfield_map = cp.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = cp.empty(self.map_shape, dtype=np.uint32)
-        self.bad_pixel_mask_value = bad_pixel_mask_value
-        self.flush_buffers(set(Constants))
 
-        self.correction_kernel = cp.RawModule(
-            code=base_kernel_runner.get_kernel_template("lpd_gpu.cu").render(
-                {
-                    "pixels_x": self.pixels_x,
-                    "pixels_y": self.pixels_y,
-                    "frames": self.frames,
-                    "constant_memory_cells": self.constant_memory_cells,
-                    "output_data_dtype": utils.np_dtype_to_c_type(
-                        self.output_data_dtype
-                    ),
-                    "corr_enum": utils.enum_to_c_template(CorrectionFlags),
-                }
-            )
-        ).get_function("correct")
+class LpdBaseRunner(base_kernel_runner.BaseKernelRunner):
+    _bad_pixel_constants = bad_pixel_constants
+    _correction_flag_class = CorrectionFlags
+    _correction_steps = correction_steps
+    num_pixels_ss = 256
+    num_pixels_fs = 256
 
-        self.block_shape = (1, 1, 64)
-        self.grid_shape = utils.grid_to_cover_shape_with_blocks(
-            self.processed_shape, self.block_shape
-        )
+    @classmethod
+    def add_schema(cls, schema):
+        super().add_schema(schema)
+        super().add_bad_pixel_config(schema)
 
-    def load_data(self, image_data, cell_table):
-        self.input_data.set(image_data)
-        self.cell_table.set(cell_table)
+    def expected_input_shape(self, num_frames):
+        return (num_frames, 1, self.num_pixels_ss, self.num_pixels_fs)
 
     @property
-    def preview_data_views(self):
+    def _gm_map_shape(self):
+        return self._map_shape + (3,)  # for gain-mapped constants
+
+    def _setup_constant_buffers(self):
+        self.offset_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.gain_amp_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.rel_gain_slopes_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.flatfield_map = self._xp.empty(self._gm_map_shape, dtype=np.float32)
+        self.bad_pixel_map = self._xp.empty(self._gm_map_shape, dtype=np.uint32)
+        self.flush_buffers(set(Constants))
+
+    def _make_output_buffers(self, num_frames, flags):
+        output_shape = (num_frames, self.num_pixels_ss, self.num_pixels_fs)
+        return [
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # image
+            self._xp.empty(output_shape, dtype=self._xp.float32),  # gain
+        ]
+
+    def _preview_data_views(self, raw_data, processed_data, gain_map):
         # TODO: always split off gain from raw to avoid messing up preview?
         return (
-            self.input_data[:, 0],  # raw
-            self.processed_data,  # corrected
-            self.gain_map,  # gain (split from raw)
+            raw_data[:, 0],  # raw
+            processed_data,  # corrected
+            gain_map,  # gain (split from raw)
         )
 
-    def correct(self, flags):
-        self.correction_kernel(
-            self.grid_shape,
-            self.block_shape,
-            (
-                self.input_data,
-                self.cell_table,
-                self._xp.uint8(flags),
-                self.offset_map,
-                self.gain_amp_map,
-                self.rel_gain_slopes_map,
-                self.flatfield_map,
-                self.bad_pixel_map,
-                self.bad_pixel_mask_value,
-                self.gain_map,
-                self.processed_data,
-            ),
-        )
-
-    def load_constant(self, constant_type, constant_data):
-        # constant type → transpose order
-        bad_pixel_loading = {
-            Constants.BadPixelsDark: (2, 1, 0, 3),
-            Constants.BadPixelsFF: (2, 0, 1, 3),
-        }
+    def _load_constant(self, constant_type, constant_data):
         # constant type → transpose order, GPU buffer
-        other_constant_loading = {
+        transpose_order, my_buffer = {
+            Constants.BadPixelsDark: ((2, 1, 0, 3), self.bad_pixel_map),
+            Constants.BadPixelsFF: ((2, 0, 1, 3), self.bad_pixel_map),
             Constants.Offset: ((2, 1, 0, 3), self.offset_map),
             Constants.GainAmpMap: ((2, 0, 1, 3), self.gain_amp_map),
             Constants.FFMap: ((2, 0, 1, 3), self.flatfield_map),
             Constants.RelativeGain: ((2, 0, 1, 3), self.rel_gain_slopes_map),
-        }
-        if constant_type in bad_pixel_loading:
-            self.bad_pixel_map |= self._xp.asarray(
-                constant_data.transpose(bad_pixel_loading[constant_type]),
-                dtype=np.uint32,
-            )[:self.constant_memory_cells]
-        elif constant_type in other_constant_loading:
-            transpose_order, gpu_buffer = other_constant_loading[constant_type]
-            gpu_buffer.set(
-                np.transpose(
-                    constant_data.astype(np.float32),
-                    transpose_order,
-                )[:self.constant_memory_cells]
+        }[constant_type]
+
+        if constant_type in bad_pixel_constants:
+            my_buffer |= (
+                self._xp.asarray(
+                    constant_data.transpose(transpose_order),
+                    dtype=np.uint32,
+                )[:self._constant_memory_cells]
+                & self._xp.uint32(self.bad_pixel_subset)
             )
         else:
-            raise ValueError(f"Unhandled constant type {constant_type}")
+            my_buffer[:] = (
+                self._xp.asarray(
+                    np.transpose(
+                        constant_data.astype(np.float32),
+                        transpose_order,
+                    )[:self._constant_memory_cells]
+                )
+            )
 
     def flush_buffers(self, constants):
         if Constants.Offset in constants:
@@ -173,10 +134,79 @@ class LpdGpuRunner(base_kernel_runner.BaseKernelRunner):
             self.rel_gain_slopes_map.fill(1)
         if Constants.FFMap in constants:
             self.flatfield_map.fill(1)
-        if constants & {Constants.BadPixelsDark, Constants.BadPixelsFF}:
+        if constants & bad_pixel_constants:
             self.bad_pixel_map.fill(0)
 
 
+class LpdGpuRunner(LpdBaseRunner):
+    def _pre_init(self):
+        import cupy
+        self._xp = cupy
+
+    def _post_init(self):
+        self.correction_kernel = self._xp.RawModule(
+            code=base_kernel_runner.get_kernel_template("lpd_gpu.cu").render(
+                {
+                    "ss_dim": self.num_pixels_ss,
+                    "fs_dim": self.num_pixels_fs,
+                    "output_data_dtype": utils.np_dtype_to_c_type(
+                        self._output_dtype
+                    ),
+                    "corr_enum": utils.enum_to_c_template(self._correction_flag_class),
+                }
+            )
+        ).get_function("correct")
+
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        num_frames = self._xp.uint16(image_data.shape[0])
+        block_shape = (1, 1, 64)
+        grid_shape = utils.grid_to_cover_shape_with_blocks(
+            processed_data.shape, block_shape
+        )
+        self.correction_kernel(
+            grid_shape,
+            block_shape,
+            (
+                image_data,
+                cell_table,
+                flags,
+                num_frames,
+                self._constant_memory_cells,
+                self.offset_map,
+                self.gain_amp_map,
+                self.rel_gain_slopes_map,
+                self.flatfield_map,
+                self.bad_pixel_map,
+                self.bad_pixel_mask_value,
+                gain_map,
+                processed_data,
+            ),
+        )
+
+
+class LpdCpuRunner(LpdBaseRunner):
+    _xp = np
+
+    def _post_init(self):
+        from ..kernels import lpd_cython
+        self.correction_kernel = lpd_cython.correct
+
+    def _correct(self, flags, image_data, cell_table, processed_data, gain_map):
+        self.correction_kernel(
+            image_data,
+            cell_table,
+            flags,
+            self.offset_map,
+            self.gain_amp_map,
+            self.rel_gain_slopes_map,
+            self.flatfield_map,
+            self.bad_pixel_map,
+            self.bad_pixel_mask_value,
+            gain_map,
+            processed_data,
+        )
+
+
 class LpdCalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
 
@@ -294,216 +324,25 @@ class LpdCalcatFriend(base_calcat.BaseCalcatFriend):
 
 @KARABO_CLASSINFO("LpdCorrection", deviceVersion)
 class LpdCorrection(base_correction.BaseCorrection):
-    _base_output_schema = schemas.xtdf_output_schema()
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.Offset}),
-        ("gainAmp", CorrectionFlags.GAIN_AMP, {Constants.GainAmpMap}),
-        ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain}),
-        ("flatfield", CorrectionFlags.FF_CORR, {Constants.FFMap}),
-        (
-            "badPixels",
-            CorrectionFlags.BPMASK,
-            {
-                Constants.BadPixelsDark,
-                Constants.BadPixelsFF,
-            }
-        ),
-    )
+    _base_output_schema = schemas.xtdf_output_schema
+    _correction_steps = correction_steps
     _kernel_runner_class = LpdGpuRunner
     _calcat_friend_class = LpdCalcatFriend
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected", "outputGainMap"]
 
-    @staticmethod
-    def expectedParameters(expected):
-        (
-            OUTPUT_CHANNEL(expected)
-            .key("dataOutput")
-            .dataSchema(LpdCorrection._base_output_schema)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(256)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(256)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(512)
-            .commit(),
-
-            # TODO: determine ideal default
-            OVERWRITE_ELEMENT(expected)
-            .key("preview.selectionMode")
-            .setNewDefaultValue("frame")
-            .commit(),
-        )
+    _preview_outputs = [PreviewSpec(name) for name in ("raw", "corrected", "gainMap")]
 
-        # TODO: add bad pixel config node
+    @classmethod
+    def expectedParameters(cls, expected):
         LpdCalcatFriend.add_schema(expected)
-        base_correction.add_addon_nodes(expected, LpdCorrection)
-        base_correction.add_correction_step_schema(
-            expected, LpdCorrection._correction_steps
-        )
-        base_correction.add_preview_outputs(expected, LpdCorrection._preview_outputs)
-
-        # additional settings for correction steps
+        LpdBaseRunner.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
-            STRING_ELEMENT(expected)
-            .key("corrections.badPixels.maskingValue")
-            .tags("managed")
-            .displayedName("Bad pixel masking value")
-            .description(
-                "Any pixels masked by the bad pixel mask will have their value "
-                "replaced with this. Note that this parameter is to be interpreted as "
-                "a numpy.float32; use 'nan' to get NaN value."
+            OUTPUT_CHANNEL(expected)
+            .key("dataOutput")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
             )
-            .assignmentOptional()
-            .defaultValue("nan")
-            .reconfigurable()
             .commit(),
         )
-
-    @property
-    def input_data_shape(self):
-        return (
-            self.unsafe_get("dataFormat.frames"),
-            1,
-            self.unsafe_get("dataFormat.pixelsY"),
-            self.unsafe_get("dataFormat.pixelsX"),
-        )
-
-    def __init__(self, config):
-        super().__init__(config)
-        try:
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-        except ValueError:
-            config["corrections.badPixels.maskingValue"] = "nan"
-
-        if config.get("useShmemHandles", default=False):
-            def aux():
-                schema_override = Schema()
-                (
-                    OUTPUT_CHANNEL(schema_override)
-                    .key("dataOutput")
-                    .dataSchema(schemas.xtdf_output_schema(use_shmem_handle=False))
-                    .commit(),
-                )
-
-            self.registerInitialFunction(aux)
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,
-        pulse_table,
-    ):
-        with self.warning_context(
-            "processingState",
-            base_correction.WarningLampType.CONSTANT_OPERATING_PARAMETERS,
-        ) as warn:
-            if (
-                cell_table_string := utils.cell_table_to_string(cell_table)
-            ) != self.unsafe_get(
-                "constantParameters.memoryCellOrder"
-            ) and self.unsafe_get(
-                "constantParameters.useMemoryCellOrder"
-            ):
-                warn(f"Cell order does not match input; input: {cell_table_string}")
-        if self._frame_filter is not None:
-            try:
-                cell_table = cell_table[self._frame_filter]
-                pulse_table = pulse_table[self._frame_filter]
-                image_data = image_data[self._frame_filter]
-            except IndexError:
-                self.log_status_warn(
-                    "Failed to apply frame filter, please check that it is valid!"
-                )
-                return
-
-        try:
-            # drop the singleton dimension for kernel runner
-            self.kernel_runner.load_data(image_data, cell_table)
-        except ValueError as e:
-            self.log_status_warn(f"Failed to load data: {e}")
-            return
-        except Exception as e:
-            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        self.kernel_runner.correct(self._correction_flag_enabled)
-        for addon in self._enabled_addons:
-            addon.post_correction(
-                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
-            )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        with self.warning_context(
-            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
-        ) as warn:
-            if self._correction_flag_enabled != self._correction_flag_preview:
-                self.kernel_runner.correct(self._correction_flag_preview)
-            (
-                preview_slice_index,
-                preview_cell,
-                preview_pulse,
-            ), preview_warning = utils.pick_frame_index(
-                self.unsafe_get("preview.selectionMode"),
-                self.unsafe_get("preview.index"),
-                cell_table,
-                pulse_table,
-            )
-            if preview_warning is not None:
-                warn(preview_warning)
-            (
-                preview_raw,
-                preview_corrected,
-                preview_gain_map,
-            ) = self.kernel_runner.compute_previews(preview_slice_index)
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        data_hash.set(self._cell_table_path, cell_table[:, np.newaxis])
-        data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis])
-
-        self._write_output(data_hash, metadata)
-        self._preview_friend.write_outputs(
-            metadata, preview_raw, preview_corrected, preview_gain_map
-        )
-
-    def preReconfigure(self, config):
-        # TODO: DRY (taken from AGIPD device)
-        super().preReconfigure(config)
-        if config.has("corrections.badPixels.maskingValue"):
-            # only check if it is valid (let raise exception)
-            # if valid, postReconfigure will use it
-            np.float32(config.get("corrections.badPixels.maskingValue"))
-
-    def postReconfigure(self):
-        super().postReconfigure()
-        if not hasattr(self, "_prereconfigure_update_hash"):
-            return
-
-        update = self._prereconfigure_update_hash
-        if update.has("corrections.badPixels.maskingValue"):
-            self.kernel_runner.bad_pixel_mask_value = self.bad_pixel_mask_value
diff --git a/src/calng/corrections/LpdminiCorrection.py b/src/calng/corrections/LpdminiCorrection.py
index f43650ad..4a5d5c24 100644
--- a/src/calng/corrections/LpdminiCorrection.py
+++ b/src/calng/corrections/LpdminiCorrection.py
@@ -6,13 +6,13 @@ from karabo.bound import (
 )
 
 from .._version import version as deviceVersion
-from ..base_correction import add_correction_step_schema
 from . import LpdCorrection
 
 
 class LpdminiGpuRunner(LpdCorrection.LpdGpuRunner):
+    num_pixels_ss = 32
+
     def load_constant(self, constant_type, constant_data):
-        print(f"Given: {constant_type} with shape {constant_data.shape}")
         # constant type → transpose order
         constant_buffer_map = {
             LpdCorrection.Constants.Offset: self.offset_map,
@@ -74,20 +74,3 @@ class LpdminiCalcatFriend(LpdCorrection.LpdCalcatFriend):
 class LpdminiCorrection(LpdCorrection.LpdCorrection):
     _calcat_friend_class = LpdminiCalcatFriend
     _kernel_runner_class = LpdminiGpuRunner
-
-    @classmethod
-    def expectedParameters(cls, expected):
-        (
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(32)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(512)
-            .commit(),
-        )
-        cls._calcat_friend_class.add_schema(expected)
-        # warning: this is redundant, but needed for now to get managed keys working
-        add_correction_step_schema(expected, cls._correction_steps)
diff --git a/src/calng/corrections/PnccdCorrection.py b/src/calng/corrections/PnccdCorrection.py
index 0b941e2e..0156e200 100644
--- a/src/calng/corrections/PnccdCorrection.py
+++ b/src/calng/corrections/PnccdCorrection.py
@@ -1,6 +1,5 @@
 import concurrent.futures
 import enum
-import functools
 
 import numpy as np
 from karabo.bound import (
@@ -9,7 +8,6 @@ from karabo.bound import (
     KARABO_CLASSINFO,
     OUTPUT_CHANNEL,
     OVERWRITE_ELEMENT,
-    Schema,
 )
 
 from .. import (
@@ -19,6 +17,7 @@ from .. import (
     schemas,
     utils,
 )
+from ..preview_utils import PreviewFriend, PreviewSpec
 from .._version import version as deviceVersion
 from ..kernels import common_cython
 
@@ -39,6 +38,14 @@ class CorrectionFlags(enum.IntFlag):
     BPMASK = 2**4
 
 
+correction_steps = (
+    ("offset", CorrectionFlags.OFFSET, {Constants.OffsetCCD}),
+    ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainCCD}),
+    ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseCCD}),
+    ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkCCD}),
+)
+
+
 class PnccdCalcatFriend(base_calcat.BaseCalcatFriend):
     _constant_enum_class = Constants
 
@@ -51,30 +58,30 @@ class PnccdCalcatFriend(base_calcat.BaseCalcatFriend):
             Constants.RelativeGainCCD: self.illuminated_condition,
         }
 
-    @staticmethod
-    def add_schema(schema):
-        super(PnccdCalcatFriend, PnccdCalcatFriend).add_schema(schema, "pnCCD-Type")
+    @classmethod
+    def add_schema(cls, schema):
+        super().add_schema(schema, "pnCCD-Type")
 
         # set some defaults for common parameters
         (
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.memoryCells")
-            .setNewDefaultValue(1)
+            .key("constantParameters.pixelsX")
+            .setNewDefaultValue(1024)
             .commit(),
 
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.biasVoltage")
-            .setNewDefaultValue(300)
+            .key("constantParameters.pixelsY")
+            .setNewDefaultValue(1024)
             .commit(),
 
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.pixelsX")
-            .setNewDefaultValue(1024)
+            .key("constantParameters.memoryCells")
+            .setNewDefaultValue(1)
             .commit(),
 
             OVERWRITE_ELEMENT(schema)
-            .key("constantParameters.pixelsY")
-            .setNewDefaultValue(1024)
+            .key("constantParameters.biasVoltage")
+            .setNewDefaultValue(300)
             .commit(),
         )
 
@@ -148,82 +155,104 @@ class PnccdCalcatFriend(base_calcat.BaseCalcatFriend):
 
 
 class PnccdCpuRunner(base_kernel_runner.BaseKernelRunner):
-    _corrected_axis_order = "xy"
+    _correction_steps = correction_steps
+    _correction_flag_class = CorrectionFlags
+    _bad_pixel_constants = {Constants.BadPixelsDarkCCD}
     _xp = np
-    _gpu_based = False
 
-    @property
-    def input_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    num_pixels_ss = 1024
+    num_pixels_fs = 1024
 
-    @property
-    def preview_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    @classmethod
+    def add_schema(cls, schema):
+        super().add_schema(schema)
+        super().add_bad_pixel_config(schema)
+        (
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.noiseSigma")
+            .assignmentOptional()
+            .defaultValue(5)
+            .reconfigurable()
+            .commit(),
 
-    @property
-    def processed_shape(self):
-        return self.input_shape
+            DOUBLE_ELEMENT(schema)
+            .key("corrections.commonMode.minFrac")
+            .assignmentOptional()
+            .defaultValue(0.25)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableRow")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+
+            BOOL_ELEMENT(schema)
+            .key("corrections.commonMode.enableCol")
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
+        )
+
+    def reconfigure(self, config):
+        super().reconfigure(config)
+        if config.has("corrections.commonMode.noiseSigma"):
+            self._cm_sigma = np.float32(config["corrections.commonMode.noiseSigma"])
+        if config.has("corrections.commonMode.minFrac"):
+            self._cm_minfrac = np.float32(config["corrections.commonMode.minFrac"])
+        if config.has("corrections.commonMode.enableRow"):
+            self._cm_row = config["corrections.commonMode.enableRow"]
+        if config.has("corrections.commonMode.enableCol"):
+            self._cm_col = config["corrections.commonMode.enableCol"]
+
+    def expected_input_shape(self, num_frames):
+        assert num_frames == 1, "pnCCD not expected to have multiple frames"
+        return (self.num_pixels_ss, self.num_pixels_fs)
+
+    def _expected_output_shape(self, num_frames):
+        return (self.num_pixels_ss, self.num_pixels_fs)
 
     @property
-    def map_shape(self):
-        return (self.pixels_x, self.pixels_y)
+    def _map_shape(self):
+        return (self.num_pixels_ss, self.num_pixels_fs)
+
+    def _make_output_buffers(self, num_frames, flags):
+        # ignore parameters
+        return [
+            self._xp.empty(
+                (self.num_pixels_ss, self.num_pixels_fs),
+                dtype=self._xp.float32,
+            )
+        ]
 
-    def __init__(
-        self,
-        pixels_x,
-        pixels_y,
-        frames,  # will be 1, will be ignored
-        constant_memory_cells,
-        input_data_dtype=np.uint16,
-        output_data_dtype=np.float32,
-    ):
+    def _post_init(self):
+        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
+
+    def _setup_constant_buffers(self):
         assert (
-            output_data_dtype == np.float32
+            self._output_dtype == np.float32
         ), "Alternative output types not supported yet"
 
-        super().__init__(
-            pixels_x,
-            pixels_y,
-            frames,
-            constant_memory_cells,
-            output_data_dtype,
-        )
-
-        self.input_data = None
-        self.processed_data = np.empty(self.processed_shape, dtype=np.float32)
-
-        self.offset_map = np.empty(self.map_shape, dtype=np.float32)
-        self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
-        self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32)
-        self.noise_map = np.empty(self.map_shape, dtype=np.float32)
+        self.offset_map = np.empty(self._map_shape, dtype=np.float32)
+        self.rel_gain_map = np.empty(self._map_shape, dtype=np.float32)
+        self.bad_pixel_map = np.empty(self._map_shape, dtype=np.uint32)
+        self.noise_map = np.empty(self._map_shape, dtype=np.float32)
         # will do everything by quadrant
-        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
-        self._q_input_data = None
-        self._q_processed_data = utils.quadrant_views(self.processed_data)
         self._q_offset_map = utils.quadrant_views(self.offset_map)
         self._q_rel_gain_map = utils.quadrant_views(self.rel_gain_map)
         self._q_bad_pixel_map = utils.quadrant_views(self.bad_pixel_map)
         self._q_noise_map = utils.quadrant_views(self.noise_map)
 
-    def __del__(self):
-        self.thread_pool.shutdown()
-
-    @property
-    def preview_data_views(self):
-        return (self.input_data, self.processed_data)
-
-    def load_data(self, image_data):
-        # should almost never squeeze, but ePix doesn't do burst mode, right?
-        self.input_data = image_data.astype(np.uint16, copy=False).squeeze()
-        self._q_input_data = utils.quadrant_views(self.input_data)
-
-    def load_constant(self, constant_type, data):
+    def _load_constant(self, constant_type, data):
         if constant_type is Constants.OffsetCCD:
             self.offset_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.RelativeGainCCD:
             self.rel_gain_map[:] = data.squeeze().astype(np.float32)
         elif constant_type is Constants.BadPixelsDarkCCD:
-            self.bad_pixel_map[:] = data.squeeze()
+            self.bad_pixel_map[:] = (data.squeeze() & self.bad_pixel_subset)
         elif constant_type is Constants.NoiseCCD:
             self.noise_map[:] = data.squeeze()
         else:
@@ -241,123 +270,116 @@ class PnccdCpuRunner(base_kernel_runner.BaseKernelRunner):
 
     def _correct_quadrant(
         self,
-        q,
         flags,
-        bad_pixel_mask_value,
-        cm_noise_sigma,
-        cm_min_frac,
-        cm_row,
-        cm_col,
+        input_data,
+        offset_map,
+        noise_map,
+        bad_pixel_map,
+        rel_gain_map,
+        output,
     ):
-        output = self._q_processed_data[q]
-        output[:] = self._q_input_data[q].astype(np.float32)
+        output[:] = input_data.astype(np.float32)
         if flags & CorrectionFlags.OFFSET:
-            output -= self._q_offset_map[q]
+            output -= offset_map
 
         if flags & CorrectionFlags.COMMONMODE:
             cm_mask = (
-                (self._q_bad_pixel_map[q] != 0)
-                | (output > self._q_noise_map[q] * cm_noise_sigma)
+                (bad_pixel_map != 0)
+                | (output > noise_map * self._cm_sigma)
             ).astype(np.uint8, copy=False)
-            if cm_row:
+            if self._cm_row:
                 common_cython.cm_fs(
                     output,
                     cm_mask,
-                    cm_noise_sigma,
-                    cm_min_frac,
+                    self._cm_sigma,
+                    self._cm_minfrac,
                 )
 
-            if cm_col:
+            if self._cm_col:
                 common_cython.cm_ss(
                     output,
                     cm_mask,
-                    cm_noise_sigma,
-                    cm_min_frac,
+                    self._cm_sigma,
+                    self._cm_minfrac,
                 )
 
         if flags & CorrectionFlags.RELGAIN:
-            output *= self._q_rel_gain_map[q]
+            output *= rel_gain_map
 
         if flags & CorrectionFlags.BPMASK:
-            output[self._q_bad_pixel_map[q] != 0] = bad_pixel_mask_value
+            output[bad_pixel_map != 0] = self.bad_pixel_mask_value
+
+    def _correct(self, flags, image_data, cell_table, output):
+        # cell_table is going to be None for now, just ignore
+        concurrent.futures.wait(
+            [
+                self.thread_pool.submit(
+                    self._correct_quadrant, flags, *parts
+                )
+                for parts in zip(
+                        utils.quadrant_views(image_data),
+                        self._q_offset_map,
+                        self._q_noise_map,
+                        self._q_bad_pixel_map,
+                        self._q_rel_gain_map,
+                        utils.quadrant_views(output),
+                )
+            ]
+        )
 
-    def correct(
-        self,
-        flags,
-        bad_pixel_mask_value=np.nan,
-        cm_noise_sigma=5,
-        cm_min_frac=0.25,
-        cm_row=True,
-        cm_col=True,
-    ):
-        # NOTE: how to best clean up all these duplicated parameters?
-        for result in self.thread_pool.map(
-            functools.partial(
-                self._correct_quadrant,
-                flags=flags,
-                bad_pixel_mask_value=bad_pixel_mask_value,
-                cm_noise_sigma=cm_noise_sigma,
-                cm_min_frac=cm_min_frac,
-                cm_row=cm_row,
-                cm_col=cm_col,
-            ),
-            range(4),
-        ):
-            pass  # just run through to await map
+    def __del__(self):
+        self.thread_pool.shutdown()
+
+    def _preview_data_views(self, raw_data, processed_data):
+        return [
+            raw_data[np.newaxis],
+            processed_data[np.newaxis],
+        ]
 
 
 @KARABO_CLASSINFO("PnccdCorrection", deviceVersion)
 class PnccdCorrection(base_correction.BaseCorrection):
-    _correction_flag_class = CorrectionFlags
-    _correction_steps = (
-        ("offset", CorrectionFlags.OFFSET, {Constants.OffsetCCD}),
-        ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainCCD}),
-        ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseCCD}),
-        ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkCCD}),
-    )
-    _image_data_path = "data.image"
+    _base_output_schema = schemas.pnccd_output_schema
     _kernel_runner_class = PnccdCpuRunner
     _calcat_friend_class = PnccdCalcatFriend
+    _correction_steps = correction_steps
     _constant_enum_class = Constants
-    _preview_outputs = ["outputRaw", "outputCorrected"]
+
+    _image_data_path = "data.image"
     _cell_table_path = None
     _pulse_table_path = None
-    _warn_memory_cell_range = False
 
-    @staticmethod
-    def expectedParameters(expected):
+    _preview_outputs = [
+        PreviewSpec(
+            name,
+            dimensions=2,
+            frame_reduction=False,
+        )
+        for name in ("raw", "corrected")
+    ]
+    _warn_memory_cell_range = False
+    _cuda_pin_buffers = False
+    _use_shmem_handles = False
+
+    @classmethod
+    def expectedParameters(cls, expected):
+        cls._calcat_friend_class.add_schema(expected)
+        cls._kernel_runner_class.add_schema(expected)
+        base_correction.add_addon_nodes(expected, cls)
+        PreviewFriend.add_schema(expected, cls._preview_outputs)
         (
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
-            .dataSchema(schemas.pnccd_output_schema(use_shmem_handle=False))
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsX")
-            .setNewDefaultValue(1024)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.pixelsY")
-            .setNewDefaultValue(1024)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.frames")
-            .setNewDefaultValue(1)
-            .commit(),
-
-            OVERWRITE_ELEMENT(expected)
-            .key("dataFormat.outputAxisOrder")
-            .setNewOptions("xy,yx")
-            .setNewDefaultValue("xy")
+            .dataSchema(
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
+            )
             .commit(),
 
             # TODO: disable preview selection mode
 
             OVERWRITE_ELEMENT(expected)
             .key("useShmemHandles")
-            .setNewDefaultValue(False)
+            .setNewDefaultValue(cls._use_shmem_handles)
             .commit(),
 
             OVERWRITE_ELEMENT(expected)
@@ -366,113 +388,11 @@ class PnccdCorrection(base_correction.BaseCorrection):
             .commit(),
         )
 
-        base_correction.add_preview_outputs(
-            expected, PnccdCorrection._preview_outputs
-        )
-        base_correction.add_correction_step_schema(
-            expected,
-            PnccdCorrection._correction_steps,
-        )
-        (
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.noiseSigma")
-            .assignmentOptional()
-            .defaultValue(5)
-            .reconfigurable()
-            .commit(),
-
-            DOUBLE_ELEMENT(expected)
-            .key("corrections.commonMode.minFrac")
-            .assignmentOptional()
-            .defaultValue(0.25)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableRow")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("corrections.commonMode.enableCol")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
-        )
-        PnccdCalcatFriend.add_schema(expected)
-        # TODO: bad pixel node?
-
-    @property
-    def input_data_shape(self):
-        # TODO: check
+    def _get_data_from_hash(self, data_hash):
+        image_data = data_hash.get(self._image_data_path)
         return (
-            self.unsafe_get("dataFormat.pixelsX"),
-            self.unsafe_get("dataFormat.pixelsY"),
-        )
-
-    def __init__(self, config):
-        super().__init__(config)
-        if config.get("useShmemHandles", default=False):
-            def aux():
-                schema_override = Schema()
-                (
-                    OUTPUT_CHANNEL(schema_override)
-                    .key("dataOutput")
-                    .dataSchema(schemas.pnccd_output_schema(use_shmem_handle=False))
-                    .commit(),
-                )
-                self.updateSchema(schema_override)
-
-            self.registerInitialFunction(aux)
-
-    def process_data(
-        self,
-        data_hash,
-        metadata,
-        source,
-        train_id,
-        image_data,
-        cell_table,  # will be None
-        pulse_table,  # ditto
-    ):
-        self.kernel_runner.load_data(image_data)
-
-        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-        args_which_should_be_cached = dict(
-            cm_noise_sigma=self.unsafe_get("corrections.commonMode.noiseSigma"),
-            cm_min_frac=self.unsafe_get("corrections.commonMode.minFrac"),
-            cm_row=self.unsafe_get("corrections.commonMode.enableRow"),
-            cm_col=self.unsafe_get("corrections.commonMode.enableCol"),
+            1,
+            image_data,
+            None,
+            None,
         )
-        self.kernel_runner.correct(
-            flags=self._correction_flag_enabled, **args_which_should_be_cached
-        )
-        self.kernel_runner.reshape(
-            output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
-            out=buffer_array,
-        )
-        if self._correction_flag_enabled != self._correction_flag_preview:
-            self.kernel_runner.correct(
-                flags=self._correction_flag_preview,
-                **args_which_should_be_cached,
-            )
-
-        if self._use_shmem_handles:
-            data_hash.set(self._image_data_path, buffer_handle)
-            data_hash.set("calngShmemPaths", [self._image_data_path])
-        else:
-            data_hash.set(self._image_data_path, buffer_array)
-            data_hash.set("calngShmemPaths", [])
-
-        self._write_output(data_hash, metadata)
-
-        # note: base class preview machinery assumes burst mode, shortcut it
-        self._preview_friend.write_outputs(
-            metadata, *self.kernel_runner.preview_data_views
-        )
-
-    def _load_constant_to_runner(self, constant, constant_data):
-        self.kernel_runner.load_constant(constant, constant_data)
diff --git a/src/calng/kernels/agipd_gpu.cu b/src/calng/kernels/agipd_gpu.cu
index ac5fc743..f638af17 100644
--- a/src/calng/kernels/agipd_gpu.cu
+++ b/src/calng/kernels/agipd_gpu.cu
@@ -13,6 +13,8 @@ extern "C" {
 	__global__ void correct(const unsigned short* data,
 	                        const unsigned short* cell_table,
 	                        const unsigned char corr_flags,
+	                        const unsigned short input_frames,
+	                        const unsigned short map_cells,
 	                        // default_gain can be 0, 1, or 2, and is relevant for fixed gain mode (no THRESHOLD)
 	                        const unsigned char default_gain,
 	                        const float* threshold_map,
@@ -27,71 +29,69 @@ extern "C" {
 	                        const float bad_pixel_mask_value,
 	                        float* gain_map, // TODO: more compact yet plottable representation
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_cells = {{constant_memory_cells}};
+		const size_t ss_dim = 512;
+		const size_t fs_dim = 128;
 
-		const size_t cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t x = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t y = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t frame = blockIdx.x * blockDim.x + threadIdx.x;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (cell >= input_frames || y >= Y || x >= X) {
+		if (frame >= input_frames || fs >= fs_dim || ss >= ss_dim) {
 			return;
 		}
 
-		// data shape: memory cell, data/raw_gain (dim size 2), x, y
-		const size_t data_stride_y = 1;
-		const size_t data_stride_x = Y * data_stride_y;
-		const size_t data_stride_raw_gain = X * data_stride_x;
-		const size_t data_stride_cell = 2 * data_stride_raw_gain;
-		const size_t data_index = cell * data_stride_cell +
+		// data shape: frame, data/raw_gain (dim size 2), ss, fs
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_raw_gain = ss_dim * data_stride_ss;
+		const size_t data_stride_frame = 2 * data_stride_raw_gain;
+		const size_t data_index = frame * data_stride_frame +
 			0 * data_stride_raw_gain +
-			y * data_stride_y +
-			x * data_stride_x;
-		const size_t raw_gain_index = cell * data_stride_cell +
+			fs * data_stride_fs +
+			ss * data_stride_ss;
+		const size_t raw_gain_index = frame * data_stride_frame +
 			1 * data_stride_raw_gain +
-			y * data_stride_y +
-			x * data_stride_x;
+			fs * data_stride_fs +
+			ss * data_stride_ss;
 		float res = (float)data[data_index];
 		const float raw_gain_val = (float)data[raw_gain_index];
 
-		const size_t output_stride_y = 1;
-		const size_t output_stride_x = output_stride_y * Y;
-		const size_t output_stride_cell = output_stride_x * X;
-		const size_t output_index = cell * output_stride_cell + x * output_stride_x + y * output_stride_y;
+		const size_t output_stride_fs = 1;
+		const size_t output_stride_ss = output_stride_fs * fs_dim;
+		const size_t output_stride_frame = output_stride_ss * ss_dim;
+		const size_t output_index = frame * output_stride_frame + ss * output_stride_ss + fs * output_stride_fs;
 
-		// per-pixel only constant: cell, x, y
-		const size_t map_stride_y = 1;
-		const size_t map_stride_x = Y * map_stride_y;
-		const size_t map_stride_cell = X * map_stride_x;
+		// per-pixel only constant: cell, ss, fs
+		const size_t map_stride_fs = 1;
+		const size_t map_stride_ss = fs_dim * map_stride_fs;
+		const size_t map_stride_cell = ss_dim * map_stride_ss;
 
-		// threshold constant shape: cell, x, y, threshold (dim size 2)
+		// threshold constant shape: cell, ss, fs, threshold (dim size 2)
 		const size_t threshold_map_stride_threshold = 1;
-		const size_t threshold_map_stride_y = 2 * threshold_map_stride_threshold;
-		const size_t threshold_map_stride_x = Y * threshold_map_stride_y;
-		const size_t threshold_map_stride_cell = X * threshold_map_stride_x;
+		const size_t threshold_map_stride_fs = 2 * threshold_map_stride_threshold;
+		const size_t threshold_map_stride_ss = fs_dim * threshold_map_stride_fs;
+		const size_t threshold_map_stride_cell = ss_dim * threshold_map_stride_ss;
 
-		// gain mapped constant shape: cell, x, y, gain_level (dim size 3)
+		// gain mapped constant shape: cell, ss, fs, gain_level (dim size 3)
 		const size_t gm_map_stride_gain = 1;
-		const size_t gm_map_stride_y = 3 * gm_map_stride_gain;
-		const size_t gm_map_stride_x = Y * gm_map_stride_y;
-		const size_t gm_map_stride_cell = X * gm_map_stride_x;
-		// note: assuming all maps have same shape (in terms of cells / x / y)
+		const size_t gm_map_stride_fs = 3 * gm_map_stride_gain;
+		const size_t gm_map_stride_ss = fs_dim * gm_map_stride_fs;
+		const size_t gm_map_stride_cell = ss_dim * gm_map_stride_ss;
+		// note: assuming all maps have same shape (in terms of cells / ss / fs)
 
-		const size_t map_cell = cell_table[cell];
+		const size_t map_cell = cell_table[frame];
 
 		if (map_cell < map_cells) {
 			unsigned char gain = default_gain;
 			if (corr_flags & THRESHOLD) {
 				const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold +
 				                                        map_cell * threshold_map_stride_cell +
-				                                        y * threshold_map_stride_y +
-				                                        x * threshold_map_stride_x];
+				                                        fs * threshold_map_stride_fs +
+				                                        ss * threshold_map_stride_ss];
 				const float threshold_1 = threshold_map[1 * threshold_map_stride_threshold +
 				                                        map_cell * threshold_map_stride_cell +
-				                                        y * threshold_map_stride_y +
-				                                        x * threshold_map_stride_x];
+				                                        fs * threshold_map_stride_fs +
+				                                        ss * threshold_map_stride_ss];
 				// could consider making this const using ternaries / tiny function
 				if (raw_gain_val <= threshold_0) {
 					gain = 0;
@@ -103,8 +103,8 @@ extern "C" {
 			}
 
 			const size_t gm_map_index_without_gain = map_cell * gm_map_stride_cell +
-				y * gm_map_stride_y +
-				x * gm_map_stride_x;
+				fs * gm_map_stride_fs +
+				ss * gm_map_stride_ss;
 			if ((corr_flags & FORCE_MG_IF_BELOW) && (gain == 2) &&
 			    (res - offset_map[gm_map_index_without_gain + 1 * gm_map_stride_gain] < mg_hard_threshold)) {
 				gain = 1;
@@ -116,8 +116,8 @@ extern "C" {
 			gain_map[output_index] = (float)gain;
 
 			const size_t map_index = map_cell * map_stride_cell +
-				y * map_stride_y +
-				x * map_stride_x;
+				fs * map_stride_fs +
+				ss * map_stride_ss;
 
 			const size_t gm_map_index = gm_map_index_without_gain + gain * gm_map_stride_gain;
 
diff --git a/src/calng/kernels/common_gpu.cu b/src/calng/kernels/common_gpu.cu
new file mode 100644
index 00000000..631ab7dd
--- /dev/null
+++ b/src/calng/kernels/common_gpu.cu
@@ -0,0 +1,60 @@
+// https://excalidraw.com/#json=mgBWynet5WUbpgfSd0pzC,TMhzt6L-x5yShNCHRGKLvg
+extern "C" __global__ void common_mode_asic(
+    float* const data,
+    const unsigned short n_frames,
+    const unsigned short num_iter,
+    const float min_dark_fraction,
+    const float noise_peak_range
+) {
+    const int frame = blockIdx.z;
+    const int ss_asic = blockIdx.y;
+    const int fs_asic = blockIdx.x;
+    const int thread_in_asic = threadIdx.x;
+    assert(blockDim.x == {{asic_dim}});
+
+    __shared__ float shared_sum[{{asic_dim}}];
+    __shared__ unsigned int shared_count[{{asic_dim}}];
+
+    const unsigned int min_dark_pixels = static_cast<int>(min_dark_fraction * static_cast<float>({{asic_dim}} * {{asic_dim}}));
+    float* const asic_start = data + frame * {{ss_dim}} * {{fs_dim}} + ss_asic * {{asic_dim}} * {{fs_dim}} + fs_asic * {{asic_dim}}; 
+
+    for (int iter=0; iter<num_iter; ++iter) {
+        // accumulate (over rows) to shared
+        unsigned int my_count = 0;
+        float my_sum = 0;
+        for (int row=0; row<{{asic_dim}}; ++row) {
+            const float pixel = asic_start[row * {{fs_dim}} + thread_in_asic];
+            if (!isfinite(pixel) || fabs(pixel) > noise_peak_range) {
+                continue;
+            }
+            my_sum += pixel;
+            ++my_count;
+        }
+        shared_sum[thread_in_asic] = my_sum;
+        shared_count[thread_in_asic] = my_count;
+        __syncthreads();
+
+        // reduce (per ASIC) in shared
+        int active_threads = {{asic_dim}};
+        while (active_threads > 1) {
+            active_threads /= 2;
+            const int other = thread_in_asic + active_threads;
+            if (thread_in_asic < active_threads) {
+                shared_sum[thread_in_asic] += shared_sum[other];
+                shared_count[thread_in_asic] += shared_count[other];
+            }
+            __syncthreads();
+        }
+
+        // now, each ASIC result should be in first index of shared by ASIC
+        const unsigned int dark_count = shared_count[0];
+        if (dark_count < min_dark_pixels) {
+            return;
+        }
+        const float dark_sum = shared_sum[0];
+        const float dark_mean = dark_sum / static_cast<float>(dark_count);
+        for (int row=0; row<{{asic_dim}}; ++row) {
+	        asic_start[row * {{fs_dim}} + thread_in_asic] -= dark_mean;
+        } __syncthreads();  // barrier: next iteration rewrites shared_sum/shared_count
+    }
+}
diff --git a/src/calng/kernels/dssc_cpu.pyx b/src/calng/kernels/dssc_cpu.pyx
index 04803d3f..613ef307 100644
--- a/src/calng/kernels/dssc_cpu.pyx
+++ b/src/calng/kernels/dssc_cpu.pyx
@@ -8,7 +8,7 @@ cdef unsigned char OFFSET = 1
 from cython.parallel import prange
 
 def correct(
-    unsigned short[:, :, :] image_data,
+    unsigned short[:, :, :, :] image_data,
     unsigned short[:] cell_table,
     unsigned char flags,
     float[:, :, :] offset_map,
@@ -19,13 +19,13 @@ def correct(
     for frame in prange(image_data.shape[0], nogil=True):
         map_cell = cell_table[frame]
         if map_cell >= offset_map.shape[0]:
-            for x in range(image_data.shape[1]):
-                for y in range(image_data.shape[2]):
-                    output[frame, x, y] = <float>image_data[frame, x, y]
+            for x in range(image_data.shape[2]):
+                for y in range(image_data.shape[3]):
+                    output[frame, x, y] = <float>image_data[frame, 0, x, y]
             continue
-        for x in range(image_data.shape[1]):
-            for y in range(image_data.shape[2]):
-                res = image_data[frame, x, y]
+        for x in range(image_data.shape[2]):
+            for y in range(image_data.shape[3]):
+                res = image_data[frame, 0, x, y]
                 if flags & OFFSET:
                     res = res - offset_map[map_cell, x, y]
                 output[frame, x, y] = res
diff --git a/src/calng/kernels/dssc_gpu.cu b/src/calng/kernels/dssc_gpu.cu
index 4e21b02b..6ebe7b9a 100644
--- a/src/calng/kernels/dssc_gpu.cu
+++ b/src/calng/kernels/dssc_gpu.cu
@@ -8,41 +8,41 @@ extern "C" {
 	  Take cell_table into account when getting correction values
 	  Converting to float while correcting
 	  Converting to output dtype at the end
-	  Shape of input data: memory cell, 1, y, x
-	  Shape of offset constant: x, y, memory cell
+	  Shape of input data: memory cell, 1, ss, fs
+	  Shape of offset constant: fs, ss, memory cell
 	*/
-	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, y, x
+	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, ss, fs
 	                        const unsigned short* cell_table,
+	                        const unsigned short input_frames,
+	                        const unsigned short map_memory_cells,
 	                        const unsigned char corr_flags,
 	                        const float* offset_map,
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_memory_cells = {{constant_memory_cells}};
+		const size_t ss_dim = 128;
+		const size_t fs_dim = 512;
 
 		const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (memory_cell >= input_frames || y >= Y || x >= X) {
+		if (memory_cell >= input_frames || ss >= ss_dim || fs >= fs_dim) {
 			return;
 		}
 
 		// note: strides differ from numpy strides because unit here is sizeof(...), not byte
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_cell = Y * data_stride_y;
-		const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x;
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_cell = ss_dim * data_stride_ss;
+		const size_t data_index = memory_cell * data_stride_cell + ss * data_stride_ss + fs * data_stride_fs;
 		float res = (float)data[data_index];
 
-		const size_t map_stride_x = 1;
-		const size_t map_stride_y = X * map_stride_x;
-		const size_t map_stride_cell = Y * map_stride_y;
+		const size_t map_stride_fs = 1;
+		const size_t map_stride_ss = fs_dim * map_stride_fs;
+		const size_t map_stride_cell = ss_dim * map_stride_ss;
 		const size_t map_cell = cell_table[memory_cell];
 
 		if (map_cell < map_memory_cells) {
-			const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x;
+			const size_t map_index = map_cell * map_stride_cell + ss * map_stride_ss + fs * map_stride_fs;
 			if (corr_flags & OFFSET) {
 				res -= offset_map[map_index];
 			}
diff --git a/src/calng/kernels/jungfrau_cpu.pyx b/src/calng/kernels/jungfrau_cpu.pyx
index 2a200063..5ff3148c 100644
--- a/src/calng/kernels/jungfrau_cpu.pyx
+++ b/src/calng/kernels/jungfrau_cpu.pyx
@@ -24,20 +24,20 @@ def correct_burst(
     float badpixel_fill_value,
     float[:, :, :] output,
 ):
-    cdef int frame, map_cell, x, y
+    cdef int frame, map_cell, fs, ss
     cdef unsigned char gain
     cdef float corrected
     for frame in prange(image_data.shape[0], nogil=True):
         map_cell = cell_table[frame]
         if map_cell >= offset_map.shape[0]:
-            for y in range(image_data.shape[1]):
-                for x in range(image_data.shape[2]):
-                    output[frame, y, x] = <float>image_data[frame, y, x]
+            for ss in range(image_data.shape[1]):
+                for fs in range(image_data.shape[2]):
+                    output[frame, ss, fs] = <float>image_data[frame, ss, fs]
             continue
-        for y in range(image_data.shape[1]):
-            for x in range(image_data.shape[2]):
-                corrected = image_data[frame, y, x]
-                gain = gain_stage[frame, y, x]
+        for ss in range(image_data.shape[1]):
+            for fs in range(image_data.shape[2]):
+                corrected = image_data[frame, ss, fs]
+                gain = gain_stage[frame, ss, fs]
                 # legal values: 0, 1, or 3
                 if gain == 2:
                     corrected = badpixel_fill_value
@@ -45,14 +45,14 @@ def correct_burst(
                     if gain == 3:
                         gain = 2
 
-                    if (flags & BPMASK) and badpixel_mask[map_cell, y, x, gain] != 0:
+                    if (flags & BPMASK) and badpixel_mask[map_cell, ss, fs, gain] != 0:
                         corrected = badpixel_fill_value
                     else:
                         if (flags & OFFSET):
-                            corrected = corrected - offset_map[map_cell, y, x, gain]
+                            corrected = corrected - offset_map[map_cell, ss, fs, gain]
                         if (flags & REL_GAIN):
-                            corrected = corrected / relgain_map[map_cell, y, x, gain]
-                output[frame, y, x] = corrected
+                            corrected = corrected / relgain_map[map_cell, ss, fs, gain]
+                output[frame, ss, fs] = corrected
 
 
 def correct_single(
@@ -66,13 +66,13 @@ def correct_single(
     float[:, :, :] output,
 ):
     # ignore "cell table", constant pretends cell 0
-    cdef int x, y
+    cdef int fs, ss
     cdef unsigned char gain
     cdef float corrected
-    for y in range(image_data.shape[1]):
-        for x in range(image_data.shape[2]):
-            corrected = image_data[0, y, x]
-            gain = gain_stage[0, y, x]
+    for ss in range(image_data.shape[1]):
+        for fs in range(image_data.shape[2]):
+            corrected = image_data[0, ss, fs]
+            gain = gain_stage[0, ss, fs]
             # legal values: 0, 1, or 3
             if gain == 2:
                 corrected = badpixel_fill_value
@@ -80,11 +80,11 @@ def correct_single(
                 if gain == 3:
                     gain = 2
 
-                if (flags & BPMASK) and badpixel_mask[0, y, x, gain] != 0:
+                if (flags & BPMASK) and badpixel_mask[0, ss, fs, gain] != 0:
                     corrected = badpixel_fill_value
                 else:
                     if (flags & OFFSET):
-                        corrected = corrected - offset_map[0, y, x, gain]
+                        corrected = corrected - offset_map[0, ss, fs, gain]
                     if (flags & REL_GAIN):
-                        corrected = corrected / relgain_map[0, y, x, gain]
-            output[0, y, x] = corrected
+                        corrected = corrected / relgain_map[0, ss, fs, gain]
+            output[0, ss, fs] = corrected
diff --git a/src/calng/kernels/jungfrau_gpu.cu b/src/calng/kernels/jungfrau_gpu.cu
index a4f148f1..57a2d2a2 100644
--- a/src/calng/kernels/jungfrau_gpu.cu
+++ b/src/calng/kernels/jungfrau_gpu.cu
@@ -7,50 +7,45 @@ extern "C" {
 	                        const unsigned char* gain_stage, // same shape
 	                        const unsigned char* cell_table,
 	                        const unsigned char corr_flags,
+	                        const unsigned short input_frames,
+	                        const unsigned short map_memory_cells,
+	                        const unsigned char burst_mode,
 	                        const float* offset_map,
 	                        const float* rel_gain_map,
 	                        const unsigned int* bad_pixel_map,
 	                        const float bad_pixel_mask_value,
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_memory_cells = {{constant_memory_cells}};
+		const size_t ss_dim = 512;
+		const size_t fs_dim = 1024;
 
 		const size_t current_frame = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (current_frame >= input_frames || y >= Y || x >= X) {
+		if (current_frame >= input_frames || ss >= ss_dim || fs >= fs_dim) {
 			return;
 		}
 
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_frame = Y * data_stride_y;
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_frame = ss_dim * data_stride_ss;
 		const size_t data_index = current_frame * data_stride_frame +
-			y * data_stride_y +
-			x * data_stride_x;
+			ss * data_stride_ss +
+			fs * data_stride_fs;
 		float res = (float)data[data_index];
 
 		// gain mapped constant shape: cell, y, x, gain_level (dim size 3)
 		// note: in fixed gain mode, constant still provides data for three stages
 		const size_t map_stride_gain = 1;
-		const size_t map_stride_x = 3 * map_stride_gain;
-		const size_t map_stride_y = X * map_stride_x;
-		const size_t map_stride_cell = Y * map_stride_y;
+		const size_t map_stride_fs = 3 * map_stride_gain;
+		const size_t map_stride_ss = fs_dim * map_stride_fs;
+		const size_t map_stride_cell = ss_dim * map_stride_ss;
 
 		// TODO: warn user about cell_table value of 255 in either mode
 		// note: cell table may contain 255 if data didn't arrive
-		{% if burst_mode %}
-		// burst mode: "cell 255" will get copied
-		// TODO: consider masking "cell 255"
-		const size_t map_cell = cell_table[current_frame];
-		{% else %}
+		// burst mode: "cell 255" will get copied (not corrected)
 		// single cell: "cell 255" will get "corrected"
-		const size_t map_cell = 0;
-		{% endif %}
-
+		const size_t map_cell = burst_mode ? cell_table[current_frame] : 0;
 		if (map_cell < map_memory_cells) {
 			unsigned char gain = gain_stage[data_index];
 			if (gain == 2) {
@@ -58,20 +53,24 @@ extern "C" {
 				res = bad_pixel_mask_value;
 			} else {
 				if (gain == 3) {
+					// but 3 means 2 in the constants
 					gain = 2;
 				}
-				const size_t map_index = map_cell * map_stride_cell +
-					y * map_stride_y +
-					x * map_stride_x +
-					gain * map_stride_gain;
-				if ((corr_flags & BPMASK) && bad_pixel_map[map_index]) {
-					res = bad_pixel_mask_value;
-				} else {
-					if (corr_flags & OFFSET) {
-						res -= offset_map[map_index];
-					}
-					if (corr_flags & REL_GAIN) {
-						res /= rel_gain_map[map_index];
+				if (map_cell < map_memory_cells) {
+					// only correct if in range of constant data
+					const size_t map_index = map_cell * map_stride_cell +
+						ss * map_stride_ss +
+						fs * map_stride_fs +
+						gain * map_stride_gain;
+					if ((corr_flags & BPMASK) && bad_pixel_map[map_index]) {
+						res = bad_pixel_mask_value;
+					} else {
+						if (corr_flags & OFFSET) {
+							res -= offset_map[map_index];
+						}
+						if (corr_flags & REL_GAIN) {
+							res /= rel_gain_map[map_index];
+						}
 					}
 				}
 			}
diff --git a/src/calng/kernels/lpd_cpu.pyx b/src/calng/kernels/lpd_cpu.pyx
index 0c3eb82d..3cd878fb 100644
--- a/src/calng/kernels/lpd_cpu.pyx
+++ b/src/calng/kernels/lpd_cpu.pyx
@@ -12,6 +12,7 @@ cdef unsigned char BPMASK = 16
 from cython.parallel import prange
 from libc.math cimport isinf, isnan
 
+
 def correct(
     unsigned short[:, :, :, :] image_data,
     unsigned short[:] cell_table,
@@ -22,7 +23,7 @@ def correct(
     float[:, :, :, :] flatfield_map,
     unsigned[:, :, :, :] bad_pixel_map,
     float bad_pixel_mask_value,
-    # TODO: support spitting out gain map for preview purposes
+    float[:, :, :] gain_map,
     float[:, :, :] output
 ):
     cdef int frame, map_cell, ss, fs
@@ -31,25 +32,41 @@ def correct(
     cdef unsigned short raw_data_value
     for frame in prange(image_data.shape[0], nogil=True):
         map_cell = cell_table[frame]
-        if map_cell >= offset_map.shape[0]:
-            for ss in range(image_data.shape[1]):
-                for fs in range(image_data.shape[2]):
-                    output[frame, ss, fs] = <float>image_data[frame, 0, ss, fs]
-            continue
-        for ss in range(image_data.shape[1]):
-            for fs in range(image_data.shape[3]):
-                raw_data_value = image_data[frame, 0, ss, fs]
-        gain_stage = (raw_data_value >> 12) & 0x0003
-        res = <float>(raw_data_value & 0x0fff)
-        if gain_stage > 2 or (flags & BPMASK and bad_pixel_map[map_cell, ss, fs, gain_stage] != 0):
-            res = bad_pixel_mask_value
+        if map_cell < offset_map.shape[0]:
+            for ss in range(image_data.shape[2]):
+                for fs in range(image_data.shape[3]):
+                    raw_data_value = image_data[frame, 0, ss, fs]
+                    gain_stage = (raw_data_value >> 12) & 0x0003
+                    res = <float>(raw_data_value & 0x0fff)
+                    if (
+                        gain_stage > 2
+                        or (
+                            flags & BPMASK
+                            and bad_pixel_map[map_cell, ss, fs, gain_stage] != 0
+                        )
+                    ):
+                        res = bad_pixel_mask_value
+                    else:
+                        if flags & OFFSET:
+                            res = res - offset_map[map_cell, ss, fs, gain_stage]
+                        if flags & GAIN_AMP:
+                            res = res * gain_amp_map[map_cell, ss, fs, gain_stage]
+                        if flags & FF_CORR:
+                            res = res * rel_gain_slopes_map[map_cell, ss, fs, gain_stage]
+                        if res < -1e7 or res > 1e7 or isnan(res) or isinf(res):
+                            res = bad_pixel_mask_value
+                    output[frame, ss, fs] = res
+                    gain_map[frame, ss, fs] = <float>gain_stage
         else:
-            if flags & OFFSET:
-                res = res - offset_map[map_cell, ss, fs, gain_stage]
-            if flags & GAIN_AMP:
-                res = res * gain_amp_map[map_cell, ss, fs, gain_stage]
-            if flags & FF_CORR:
-                res = res * rel_gain_slopes_map[map_cell, ss, fs, gain_stage]
-            if res < 1e-7 or res > 1e7 or isnan(res) or isinf(res):
-                res = bad_pixel_mask_value
-        output[frame, ss, fs] = res
+            for ss in range(image_data.shape[2]):
+                for fs in range(image_data.shape[3]):
+                    raw_data_value = image_data[frame, 0, ss, fs]
+                    gain_stage = (raw_data_value >> 12) & 0x0003
+                    res = <float>(raw_data_value & 0x0fff)
+                    if (
+                        gain_stage > 2
+                        or res < -1e7 or res > 1e7
+                    ):
+                        res = bad_pixel_mask_value
+                    output[frame, ss, fs] = res
+                    gain_map[frame, ss, fs] = <float>gain_stage
diff --git a/src/calng/kernels/lpd_gpu.cu b/src/calng/kernels/lpd_gpu.cu
index 4e98cc85..7e935a38 100644
--- a/src/calng/kernels/lpd_gpu.cu
+++ b/src/calng/kernels/lpd_gpu.cu
@@ -3,10 +3,12 @@
 {{corr_enum}}
 
 extern "C" {
-	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, y, x (I think)
+	__global__ void correct(const unsigned short* data, // shape: memory cell, 1, ss, fs
 	                        const unsigned short* cell_table,
 	                        const unsigned char corr_flags,
-	                        const float* offset_map, // shape: cell, y, x, gain
+	                        const unsigned short input_frames,
+	                        const unsigned short map_memory_cells,
+	                        const float* offset_map, // shape: cell, ss, fs, gain
 	                        const float* gain_amp_map,
 	                        const float* rel_gain_slopes_map,
 	                        const float* flatfield_map,
@@ -14,39 +16,37 @@ extern "C" {
 	                        const float bad_pixel_mask_value,
 	                        float* gain_map, // similar to preview for AGIPD
 	                        {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t input_frames = {{frames}};
-		const size_t map_memory_cells = {{constant_memory_cells}};
+		const size_t fs_dim = {{fs_dim}};
+		const size_t ss_dim = {{ss_dim}};
 
-		const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
+		const size_t frame = blockIdx.x * blockDim.x + threadIdx.x;
+		const size_t ss = blockIdx.y * blockDim.y + threadIdx.y;
+		const size_t fs = blockIdx.z * blockDim.z + threadIdx.z;
 
-		if (memory_cell >= input_frames || y >= Y || x >= X) {
+		if (frame >= input_frames || ss >= ss_dim || fs >= fs_dim) {
 			return;
 		}
 
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_cell = Y * data_stride_y;
-		const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x;
+		const size_t data_stride_fs = 1;
+		const size_t data_stride_ss = fs_dim * data_stride_fs;
+		const size_t data_stride_frame = ss_dim * data_stride_ss;
+		const size_t data_index = frame * data_stride_frame + ss * data_stride_ss + fs * data_stride_fs;
 		const unsigned short raw_data_value = data[data_index];
 		const unsigned char gain = (raw_data_value >> 12) & 0x0003;
 		float corrected = (float)(raw_data_value & 0x0fff);
 		float gain_for_preview = (float)gain;
 
 		const size_t gm_map_stride_gain = 1;
-		const size_t gm_map_stride_x = 3 * gm_map_stride_gain;
-		const size_t gm_map_stride_y = X * gm_map_stride_x;
-		const size_t gm_map_stride_cell = Y * gm_map_stride_y;
+		const size_t gm_map_stride_fs = 3 * gm_map_stride_gain;
+		const size_t gm_map_stride_ss = fs_dim * gm_map_stride_fs;
+		const size_t gm_map_stride_cell = ss_dim * gm_map_stride_ss;
 
-		const size_t map_cell = cell_table[memory_cell];
+		const size_t map_cell = cell_table[frame];
 		if (map_cell < map_memory_cells) {
 			const size_t gm_map_index = gain * gm_map_stride_gain +
 				map_cell * gm_map_stride_cell +
-				y * gm_map_stride_y +
-				x * gm_map_stride_x;
+				ss * gm_map_stride_ss +
+				fs * gm_map_stride_fs;
 
 			if (gain > 2 || ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index])) {
 				// now also checking for illegal gain value
@@ -70,7 +70,13 @@ extern "C" {
 					corrected = bad_pixel_mask_value;
 				}
 			}
+		} else {
+			if (gain > 2 || corrected < -1e7 || corrected > 1e7) {
+				corrected = bad_pixel_mask_value;
+				gain_for_preview = bad_pixel_mask_value;
+			}
 		}
+
 		gain_map[data_index] = gain_for_preview;
 		{% if output_data_dtype == "half" %}
 		output[data_index] = __float2half(corrected);
diff --git a/src/calng/preview_utils.py b/src/calng/preview_utils.py
index 55e90742..12cb7e18 100644
--- a/src/calng/preview_utils.py
+++ b/src/calng/preview_utils.py
@@ -1,81 +1,148 @@
+import enum
+import functools
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
 from karabo.bound import (
     BOOL_ELEMENT,
     FLOAT_ELEMENT,
+    INT32_ELEMENT,
     NODE_ELEMENT,
     OUTPUT_CHANNEL,
     STRING_ELEMENT,
     UINT32_ELEMENT,
-    ChannelMetaData,
     Dims,
     Encoding,
     Hash,
     ImageData,
+    Schema,
     Timestamp,
 )
 
 import numpy as np
 
-from . import schemas, utils
+from . import schemas
+from .utils import downsample_1d, downsample_2d, maybe_get
+
+
+class FrameSelectionMode(enum.Enum):
+    FRAME = "frame"
+    CELL = "cell"
+    PULSE = "pulse"
+
+
+@dataclass
+class PreviewSpec:
+    # used for output schema
+    name: str
+    channel_name: Optional[str] = None  # set after init
+    dimensions: int = 2
+    wrap_in_imagedata: bool = True
+
+    # reconfigurables in node schema
+    swap_axes: bool = False  # only appears in schema if dimensions > 1
+    flip_ss: bool = False  # only appears in schema if dimensions > 1
+    flip_fs: bool = False
+    nan_replacement: float = 0
+    downsampling_factor: int = 1
+    downsampling_function: Callable = np.nanmax
+    # if frame_reduction, frame_reduction_fun is made automatically from index and mode
+    # otherwise, leave None or provide custom
+    frame_reduction: bool = True
+    frame_reduction_selection_mode: FrameSelectionMode = FrameSelectionMode.FRAME
+    # not all detectors provide cell and / or pulse IDs
+    valid_frame_selection_modes: List[FrameSelectionMode] = tuple(FrameSelectionMode)
+    frame_reduction_index: int = 0
+    # either set explicitly or automatically with frame selection
+    frame_reduction_fun: Optional[Callable] = None
 
 
 class PreviewFriend:
     @staticmethod
     def add_schema(
-        schema, node_path="preview", output_channels=None, create_node=False
+        schema: Schema,
+        outputs: List[PreviewSpec],
+        node_name: str = "preview"
     ):
-        if output_channels is None:
-            output_channels = ["output"]
+        """Add the prerequisite schema configurables for all the outputs."""
+        (
+            NODE_ELEMENT(schema)
+            .key(node_name)
+            .displayedName("Preview")
+            .description(
+                "Output specifically intended for preview in Karabo GUI. Includes "
+                "some options for throttling and adjustments of the output data."
+            )
+            .commit(),
+        )
+
+        for spec in outputs:
+            PreviewFriend._add_subschema(schema, f"{node_name}.{spec.name}", spec)
+
+    @staticmethod
+    def _add_subschema(schema, node_path, spec):
+        # each preview gets its own node with config and channel
+        (
+            NODE_ELEMENT(schema)
+            .key(node_path)
+            .commit(),
+        )
 
-        if create_node:
+        if spec.dimensions > 1:
             (
-                NODE_ELEMENT(schema)
-                .key(node_path)
-                .displayedName("Preview")
+                BOOL_ELEMENT(schema)
+                .key(f"{node_path}.swapAxes")
+                .displayedName("Swap slow / fast scan")
                 .description(
-                    "Output specifically intended for preview in Karabo GUI. Includes "
-                    "some options for throttling and adjustments of the output data."
+                    "Swaps the two pixel axes around. Can be combined with "
+                    "flipping to rotate the image"
                 )
+                .assignmentOptional()
+                .defaultValue(spec.swap_axes)
+                .reconfigurable()
+                .commit(),
+
+                BOOL_ELEMENT(schema)
+                .key(f"{node_path}.flipSS")
+                .displayedName("Flip SS")
+                .description("Flip image data along slow scan axis.")
+                .assignmentOptional()
+                .defaultValue(spec.flip_ss)
+                .reconfigurable()
                 .commit(),
             )
-        (
-            BOOL_ELEMENT(schema)
-            .key(f"{node_path}.flipSS")
-            .displayedName("Flip SS")
-            .description("Flip image data along slow scan axis.")
-            .assignmentOptional()
-            .defaultValue(False)
-            .reconfigurable()
-            .commit(),
 
+        # configurable bits; these correspond to spec parts
+        (
             BOOL_ELEMENT(schema)
             .key(f"{node_path}.flipFS")
             .displayedName("Flip FS")
             .description("Flip image data along fast scan axis.")
             .assignmentOptional()
-            .defaultValue(False)
+            .defaultValue(spec.flip_fs)
             .reconfigurable()
             .commit(),
 
             UINT32_ELEMENT(schema)
             .key(f"{node_path}.downsamplingFactor")
-            .displayedName("Factor")
+            .displayedName("Downsampling factor")
             .description(
                 "If greater than 1, the preview image will be downsampled by this "
                 "factor before sending. This is mostly to save bandwidth in case GUI "
                 "updates start lagging."
             )
             .assignmentOptional()
-            .defaultValue(1)
+            .defaultValue(spec.downsampling_factor)
             .options("1,2,4,8")
             .reconfigurable()
             .commit(),
 
             STRING_ELEMENT(schema)
             .key(f"{node_path}.downsamplingFunction")
-            .displayedName("Function")
+            .displayedName("Downsampling function")
             .description("Reduction function used during downsampling.")
             .assignmentOptional()
-            .defaultValue("nanmax")
+            .defaultValue(spec.downsampling_function.__name__)
             .options("nanmax,nanmean,nanmin,nanmedian")
             .reconfigurable()
             .commit(),
@@ -91,97 +158,255 @@ class PreviewFriend:
                 "from the image data you want to see."
             )
             .assignmentOptional()
-            .defaultValue(0)
+            .defaultValue(spec.nan_replacement)
             .reconfigurable()
             .commit(),
         )
-        for channel in output_channels:
+
+        if spec.frame_reduction:
             (
-                OUTPUT_CHANNEL(schema)
-                .key(f"{node_path}.{channel}")
-                .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True))
-                .description("See description of parent node, 'preview'.")
+                STRING_ELEMENT(schema)
+                .key(f"{node_path}.selectionMode")
+                .tags("managed")
+                .displayedName("Index selection mode")
+                .description(
+                    "The value of preview.index can be used in multiple ways, "
+                    "controlled by this value. If this is set to 'frame', "
+                    "preview.index is sliced directly from data. If 'cell' "
+                    "(or 'pulse') is selected, I will look at cell (or pulse) table "
+                    "for the requested cell (or pulse ID). Special (stat) index values "
+                    "<0 are not affected by this."
+                )
+                .options(
+                    ",".join(
+                        selectionmode.value
+                        for selectionmode in spec.valid_frame_selection_modes
+                    )
+                )
+                .assignmentOptional()
+                .defaultValue("frame")
+                .reconfigurable()
+                .commit(),
+                INT32_ELEMENT(schema)
+                .key(f"{node_path}.index")
+                .tags("managed")
+                .displayedName("Index (or stat) for preview")
+                .description(
+                    "If this value is ≥ 0, the corresponding index (frame, cell, or "
+                    "pulse) will be sliced for the preview output. If this value is "
+                    "<0, preview will be one of the following stats: -1: max, "
+                    "-2: mean, -3: sum, -4: stdev. These stats are computed across "
+                    "frames, ignoring NaN values."
+                )
+                .assignmentOptional()
+                .defaultValue(spec.frame_reduction_index)
+                .minInc(-4)
+                .reconfigurable()
                 .commit(),
             )
 
-    def __init__(self, device, node_name="preview", output_channels=None):
-        if output_channels is None:
-            output_channels = ["output"]
-        self.output_channels = output_channels
-        self.device = device
-        self.dev_id = self.device.getInstanceId()
-        self.node_name = node_name
-        self.outputs = [
-            self.device.signalSlotable.getOutputChannel(f"{self.node_name}.{channel}")
-            for channel in self.output_channels
-        ]
-        self.reconfigure(device._parameters)
-
-    def write_outputs(self, timestamp, *datas, inplace=True):
+        (
+            OUTPUT_CHANNEL(schema)
+            .key(f"{node_path}.output")
+            .dataSchema(
+                schemas.preview_schema(
+                    wrap_image_in_imagedata=spec.wrap_in_imagedata
+                )
+            )
+            .description("See description of grandparent node, 'preview'.")
+            .commit(),
+        )
+
+    def __init__(self, device, outputs, node_name="preview"):
+        self._device = device
+        self.output_specs = outputs
+        for spec in self.output_specs:
+            spec.channel_name = f"{node_name}.{spec.name}.output"
+        self.reconfigure(self._device.get(node_name))
+
+    def write_outputs(
+        self,
+        *datas,
+        timestamp=None,
+        inplace=True,
+        cell_table=None,
+        pulse_table=None,
+        warn_fun=None,
+    ):
         """Applies GUI-friendly preview settings (replace NaN, downsample, wrap as
         ImageData) and writes to output channels. Make sure datas length matches number
-        of channels!"""
+        of channels! If inplace, masking and such may write to provided buffers (will
+        otherwise copy)."""
         if isinstance(timestamp, Hash):
             timestamp = Timestamp.fromHashAttributes(
                 timestamp.getAttributes("timestamp")
             )
-        for data, output, channel_name in zip(
-            datas, self.outputs, self.output_channels
-        ):
-            if self.downsampling_factor > 1:
-                data = utils.downsample_2d(
-                    data,
-                    self.downsampling_factor,
-                    reduction_fun=self.downsampling_function,
-                )
+        if warn_fun is None:
+            warn_fun = self._device.log.WARN
+        for data, spec in zip(datas, self.output_specs):
+            if spec.frame_reduction_fun is not None:
+                data = spec.frame_reduction_fun(data, cell_table, pulse_table, warn_fun)
+            if spec.downsampling_factor > 1:
+                if spec.dimensions == 1:
+                    data = downsample_1d(
+                        data,
+                        spec.downsampling_factor,
+                        reduction_fun=spec.downsampling_function,
+                    )
+                else:
+                    data = downsample_2d(
+                        data,
+                        spec.downsampling_factor,
+                        reduction_fun=spec.downsampling_function,
+                    )
             elif not inplace:
                 data = data.copy()
-            if self.flip_ss:
-                data = np.flip(data, 0)
-            if self.flip_fs:
-                data = np.flip(data, 1)
+            if spec.flip_ss:
+                data = np.flip(data, -2)
+            if spec.flip_fs:
+                data = np.flip(data, -1)
             if isinstance(data, np.ma.MaskedArray):
                 data, mask = data.data, data.mask | ~np.isfinite(data)
             else:
                 mask = ~np.isfinite(data)
-            data[mask] = self.nan_replacement
+            data[mask] = spec.nan_replacement
             mask = mask.astype(np.uint8)
-            output_hash = Hash(
-                "image.data",
-                ImageData(
-                    data,
-                    Dims(*data.shape),
-                    Encoding.GRAY,
-                    bitsPerPixel=32,
-                ),
-                "image.mask",
-                ImageData(
-                    mask,
-                    Dims(*mask.shape),
-                    Encoding.GRAY,
-                    bitsPerPixel=8,
-                ),
-            )
-            output.write(
-                output_hash,
-                ChannelMetaData(
-                    f"{self.dev_id}:{self.node_name}.{channel_name}",
-                    timestamp,
-                ),
-                copyAllData=False,
+            if spec.wrap_in_imagedata:
+                output_hash = Hash(
+                    "image.data",
+                    ImageData(
+                        maybe_get(data),
+                        Dims(*data.shape),
+                        Encoding.GRAY,
+                        bitsPerPixel=32,
+                    ),
+                    "image.mask",
+                    ImageData(
+                        maybe_get(mask),
+                        Dims(*mask.shape),
+                        Encoding.GRAY,
+                        bitsPerPixel=8,
+                    ),
+                )
+            else:
+                output_hash = Hash(
+                    "image.data", maybe_get(data), "image.mask", maybe_get(mask)
+                )
+            self._device.writeChannel(
+                spec.channel_name, output_hash, timestamp=timestamp, safeNDArray=True
             )
-            output.update(safeNDArray=True)
 
     def reconfigure(self, conf):
-        if conf.has(f"{self.node_name}.downsamplingFunction"):
-            self.downsampling_function = getattr(
-                np, conf[f"{self.node_name}.downsamplingFunction"]
+        for spec in self.output_specs:
+            if not conf.has(spec.name):
+                continue
+            subconf = conf.get(spec.name)
+            if spec.frame_reduction and (
+                subconf.has("selectionMode") or subconf.has("index")
+            ):
+                if subconf.has("selectionMode"):
+                    spec.frame_reduction_selection_mode = FrameSelectionMode(
+                        subconf["selectionMode"]
+                    )
+                if subconf.has("index"):
+                    spec.frame_reduction_index = subconf["index"]
+                spec.frame_reduction_fun = _make_frame_reduction_fun(spec)
+            if subconf.has("downsamplingFunction"):
+                spec.downsampling_function = getattr(
+                    np, subconf["downsamplingFunction"]
+                )
+            if subconf.has("downsamplingFactor"):
+                spec.downsampling_factor = subconf["downsamplingFactor"]
+            if subconf.has("flipSS"):
+                spec.flip_ss = subconf["flipSS"]
+            if subconf.has("flipFS"):
+                spec.flip_fs = subconf["flipFS"]
+            if subconf.has("swapAxes"):
+                spec.swap_axes = subconf["swapAxes"]
+            if subconf.has("replaceNanWith"):
+                spec.nan_replacement = subconf["replaceNanWith"]
+
+
+def _make_frame_reduction_fun(spec) -> Optional[Callable]:
+    """Will be called during configuration in case frame_reduction is True. Will
+    use frame_reduction_selection_mode and frame_reduction_index to create a
+    function to either slice or reduce input data."""
+    # note: in case of frame picking, signature includes cell and pulse table
+    if spec.frame_reduction_index < 0:
+        # so while those tables are irrelevant here, still include in wrapper
+        if spec.frame_reduction_index == -1:
+            # note: nanmax takes no dtype argument, hence handled separately
+            reduction_fun = functools.partial(np.nanmax, axis=0)
+        elif spec.frame_reduction_index == -2:
+            # mean across frames, accumulated as float32
+            reduction_fun = functools.partial(np.nanmean, axis=0, dtype=np.float32)
+        elif spec.frame_reduction_index == -3:
+            # sum across frames, accumulated as float32
+            reduction_fun = functools.partial(np.nansum, axis=0, dtype=np.float32)
+        elif spec.frame_reduction_index == -4:
+            # standard deviation across frames, computed as float32
+            reduction_fun = functools.partial(np.nanstd, axis=0, dtype=np.float32)
+
+        def fun(data, cell_table, pulse_table, warn_fun):
+            try:
+                return reduction_fun(data)
+            except Exception as ex:
+                warn_fun(f"Frame reduction error: {ex}")
+
+        return fun
+    else:
+        # here, we pick out a frame
+        if spec.frame_reduction_selection_mode is FrameSelectionMode.FRAME:
+            return _pick_preview_index_frame(spec.frame_reduction_index)
+        elif spec.frame_reduction_selection_mode is FrameSelectionMode.CELL:
+            return _pick_preview_index_cell(spec.frame_reduction_index)
+        else:
+            return _pick_preview_index_pulse(spec.frame_reduction_index)
+
+
+def _pick_preview_index_frame(index):
+    def aux(data, cell_table, pulse_table, warn_fun):
+        if index >= data.shape[0]:
+            warn_fun(
+                f"Frame index {index} out of bounds for data with "
+                f"{data.shape[0]} frames, previewing frame 0"
+            )
+            return data[0]
+        return data[index]
+    return aux
+
+
+def _pick_preview_index_cell(index):
+    def aux(data, cell_table, pulse_table, warn_fun):
+        found = np.nonzero(cell_table == index)[0]
+        if found.size == 0:
+            warn_fun(
+                f"Cell ID {index} not found in cell mapping, "
+                "previewing frame 0 instead"
+            )
+            return data[0]
+        if found.size > 1:
+            warn_fun(
+                f"Cell ID {index} not unique in cell mapping, "
+                f"previewing first occurrence (frame {found[0]})"
+            )
+        return data[found[0]]
+    return aux
+
+
+def _pick_preview_index_pulse(index):
+    def aux(data, cell_table, pulse_table, warn_fun):
+        found = np.nonzero(pulse_table == index)[0]
+        if found.size == 0:
+            warn_fun(
+                f"Pulse ID {index} not found in pulse table, "
+                "previewing frame 0 instead"
+            )
+            return data[0]
+        if found.size > 1:
+            warn_fun(
+                f"Pulse ID {index} not unique in pulse table, "
+                f"previewing first occurrence (frame {found[0]})"
             )
-        if conf.has(f"{self.node_name}.downsamplingFactor"):
-            self.downsampling_factor = conf[f"{self.node_name}.downsamplingFactor"]
-        if conf.has(f"{self.node_name}.flipSS"):
-            self.flip_ss = conf[f"{self.node_name}.flipSS"]
-        if conf.has(f"{self.node_name}.flipFS"):
-            self.flip_fs = conf[f"{self.node_name}.flipFS"]
-        if conf.has(f"{self.node_name}.replaceNanWith"):
-            self.nan_replacement = conf[f"{self.node_name}.replaceNanWith"]
+        return data[found[0]]
+    return aux
diff --git a/src/calng/scenes.py b/src/calng/scenes.py
index 56ae3bfd..3eb66b0f 100644
--- a/src/calng/scenes.py
+++ b/src/calng/scenes.py
@@ -34,6 +34,8 @@ from karabo.common.scenemodel.api import (
     WebLinkModel,
 )
 
+docs_url = "https://rtd.xfel.eu/docs/calng/en/latest"
+
 
 @titled("Found constants", width=6 * NARROW_INC)
 @boxed
@@ -140,7 +142,7 @@ class ManagerDeviceStatus(VerticalLayout):
             text="Documentation",
             width=7 * BASE_INC,
             height=BASE_INC,
-            target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#calibration-manager",
+            target=f"{docs_url}/devices/#calibration-manager",
         )
         self.children.extend(
             [
@@ -446,63 +448,77 @@ class RoiBox(VerticalLayout):
 @titled("Preview settings")
 @boxed
 class PreviewSettings(HorizontalLayout):
-    def __init__(self, device_id, schema_hash, node_name="preview", extras=None):
+    def __init__(self, device_id, schema_hash, node_name):
         super().__init__()
-        self.children.extend(
-            [
-                VerticalLayout(
-                    EditableRow(
-                        device_id,
-                        schema_hash,
-                        f"{node_name}.replaceNanWith",
-                        6,
-                        4,
-                    ),
-                    HorizontalLayout(
-                        EditableRow(
-                            device_id,
-                            schema_hash,
-                            f"{node_name}.flipSS",
-                            3,
-                            2,
-                        ),
-                        EditableRow(
-                            device_id,
-                            schema_hash,
-                            f"{node_name}.flipFS",
-                            3,
-                            2,
-                        ),
-                        padding=0,
-                    ),
-                ),
-                Vline(height=3 * BASE_INC),
-                VerticalLayout(
-                    LabelModel(
-                        text="Image downsampling",
-                        width=8 * BASE_INC,
-                        height=BASE_INC,
-                    ),
+        tweaks = VerticalLayout(
+            EditableRow(
+                device_id,
+                schema_hash,
+                f"{node_name}.replaceNanWith",
+                6,
+                4,
+            ),
+        )
+        if schema_hash.has(f"{node_name}.flipSS"):
+            tweaks.children.append(
+                HorizontalLayout(
                     EditableRow(
                         device_id,
                         schema_hash,
-                        f"{node_name}.downsamplingFactor",
-                        5,
-                        5,
+                        f"{node_name}.flipSS",
+                        3,
+                        2,
                     ),
                     EditableRow(
                         device_id,
                         schema_hash,
-                        f"{node_name}.downsamplingFunction",
-                        5,
-                        5,
+                        f"{node_name}.flipFS",
+                        3,
+                        2,
                     ),
-                ),
+                    padding=0,
+                )
+            )
+        downsampling = VerticalLayout(
+            LabelModel(
+                text="Downsampling",
+                width=8 * BASE_INC,
+                height=BASE_INC,
+            ),
+            EditableRow(
+                device_id,
+                schema_hash,
+                f"{node_name}.downsamplingFactor",
+                5,
+                5,
+            ),
+            EditableRow(
+                device_id,
+                schema_hash,
+                f"{node_name}.downsamplingFunction",
+                5,
+                5,
+            ),
+        )
+        self.children.extend(
+            [
+                tweaks,
+                Vline(height=3 * BASE_INC),
+                downsampling,
             ]
         )
-        if extras is not None:
-            self.children.append(Vline(height=3 * BASE_INC))
-            self.children.extend(extras)
+        if schema_hash.has(f"{node_name}.index"):
+            self.children.extend(
+                [
+                    Vline(height=3 * BASE_INC),
+                    VerticalLayout(
+                        EditableRow(device_id, schema_hash, f"{node_name}.index", 8, 4),
+                        EditableRow(
+                            device_id, schema_hash, f"{node_name}.selectionMode", 8, 4
+                        ),
+                    ),
+                ]
+            )
 
 
 @titled("Histogram settings")
@@ -633,26 +649,23 @@ def correction_device_overview(device_id, schema):
     )
     return VerticalLayout(
         main_overview,
-        LabelModel(
-            text="Preview (corrected):",
-            width=20 * BASE_INC,
-            height=BASE_INC,
-        ),
-        PreviewDisplayArea(device_id, schema_hash, "preview.outputCorrected"),
         *(
             DeviceSceneLinkModel(
-                text=f"Preview: {channel}",
+                text=f"Preview: {preview_name}",
                 keys=[f"{device_id}.availableScenes"],
-                target=f"preview:preview.{channel}",
+                target=f"preview:{preview_name}",
                 target_window=SceneTargetWindow.Dialog,
                 width=16 * BASE_INC,
                 height=BASE_INC,
             )
-            for channel in schema_hash.get("preview").getKeys()
-            if schema_hash.hasAttribute(f"preview.{channel}", "classId")
-            and schema_hash.getAttribute(f"preview.{channel}", "classId")
-            == "OutputChannel"
+            for preview_name in schema_hash.get("preview").getKeys()
+        ),
+        LabelModel(
+            text="Preview (corrected):",
+            width=20 * BASE_INC,
+            height=BASE_INC,
         ),
+        PreviewDisplayArea(device_id, schema_hash, "preview.corrected.output"),
     )
 
 
@@ -666,25 +679,20 @@ def lpdmini_splitter_overview(device_id, schema):
 
 
 @scene_generator
-def correction_device_preview(device_id, schema, preview_channel):
+def correction_device_preview(device_id, schema, name):
     schema_hash = schema_to_hash(schema)
     return VerticalLayout(
         LabelModel(
-            text=f"Preview: {preview_channel}",
+            text=f"Preview: {name}",
             width=10 * BASE_INC,
             height=BASE_INC,
         ),
         PreviewSettings(
             device_id,
             schema_hash,
-            extras=[
-                VerticalLayout(
-                    EditableRow(device_id, schema_hash, "preview.index", 8, 4),
-                    EditableRow(device_id, schema_hash, "preview.selectionMode", 8, 4),
-                )
-            ],
+            f"preview.{name}",
         ),
-        PreviewDisplayArea(device_id, schema_hash, preview_channel),
+        PreviewDisplayArea(device_id, schema_hash, f"preview.{name}.output"),
     )
 
 
@@ -758,7 +766,8 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan
                     ),
                     DisplayCommandModel(
                         keys=[
-                            f"{device_id}.{prefix}.{constant}.overrideConstantFromVersion"
+                            f"{device_id}.{prefix}.{constant}"
+                            ".overrideConstantFromVersion"
                         ],
                         width=8 * BASE_INC,
                         height=BASE_INC,
@@ -829,67 +838,47 @@ def manager_device_overview(
     mds_hash = schema_to_hash(manager_device_schema)
     cds_hash = schema_to_hash(correction_device_schema)
 
-    data_throttling_children = [
-        LabelModel(
-            text="Frame filter",
-            width=11 * BASE_INC,
-            height=BASE_INC,
-        ),
-        EditableRow(
+    config_column = [
+        recursive_editable(
             manager_device_id,
             mds_hash,
-            "managedKeys.frameFilter.type",
-            7,
-            4,
+            "managedKeys.constantParameters",
         ),
-        EditableRow(
-            manager_device_id,
-            mds_hash,
-            "managedKeys.frameFilter.spec",
-            7,
-            4,
+        DisplayCommandModel(
+            keys=[f"{manager_device_id}.managedKeys.loadMostRecentConstants"],
+            width=10 * BASE_INC,
+            height=BASE_INC,
         ),
     ]
 
-    if "managedKeys.daqTrainStride" in mds_hash:
-        # Only add DAQ train stride if present on the schema, may be
-        # disabled on the manager.
-        data_throttling_children.insert(
-            0,
-            EditableRow(
+    if "managedKeys.preview" in mds_hash:
+        config_column.append(
+            recursive_editable(
                 manager_device_id,
                 mds_hash,
-                "managedKeys.daqTrainStride",
-                7,
-                4,
+                "managedKeys.preview",
+                max_depth=2,
             ),
         )
 
-    return VerticalLayout(
-        HorizontalLayout(
-            ManagerDeviceStatus(manager_device_id),
-            VerticalLayout(
-                recursive_editable(
-                    manager_device_id,
-                    mds_hash,
-                    "managedKeys.constantParameters",
-                ),
-                DisplayCommandModel(
-                    keys=[f"{manager_device_id}.managedKeys.loadMostRecentConstants"],
-                    width=10 * BASE_INC,
-                    height=BASE_INC,
-                ),
-                recursive_editable(
+    if "managedKeys.daqTrainStride" in mds_hash:
+        config_column.append(
+            titled("Data throttling")(boxed(VerticalLayout))(
+                EditableRow(
                     manager_device_id,
                     mds_hash,
-                    "managedKeys.preview",
-                    max_depth=2,
-                ),
-                titled("Data throttling")(boxed(VerticalLayout))(
-                    children=data_throttling_children,
-                    padding=0,
+                    "managedKeys.daqTrainStride",
+                    7,
+                    4,
                 ),
+                padding=0,
             ),
+        )
+
+    return VerticalLayout(
+        HorizontalLayout(
+            ManagerDeviceStatus(manager_device_id),
+            VerticalLayout(*config_column),
             recursive_editable(
                 manager_device_id,
                 mds_hash,
@@ -932,7 +921,7 @@ def manager_device_overview(
                             text="Documentation",
                             width=6 * BASE_INC,
                             height=BASE_INC,
-                            target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#correction-devices",
+                            target=f"{docs_url}/devices/#correction-devices",
                         ),
                         padding=0,
                     )
@@ -1043,12 +1032,12 @@ def detector_assembler_overview(device_id, schema):
     return VerticalLayout(
         HorizontalLayout(
             AssemblerDeviceStatus(device_id),
-            PreviewSettings(device_id, schema_hash),
+            PreviewSettings(device_id, schema_hash, "preview.assembled"),
         ),
         PreviewDisplayArea(
             device_id,
             schema_hash,
-            "preview.output",
+            "preview.assembled.output",
             data_width=40 * BASE_INC,
             data_height=40 * BASE_INC,
             mask_width=40 * BASE_INC,
diff --git a/src/calng/schemas.py b/src/calng/schemas.py
index 79524daa..3b4820b0 100644
--- a/src/calng/schemas.py
+++ b/src/calng/schemas.py
@@ -44,7 +44,7 @@ def preview_schema(wrap_image_in_imagedata=False):
     return res
 
 
-def xtdf_output_schema(use_shmem_handle=True):
+def xtdf_output_schema(use_shmem_handles):
     # TODO: trim output schema / adapt to specific detectors
     # currently: based on snapshot of actual output reusing AGIPD hash
     res = Schema()
@@ -178,7 +178,7 @@ def xtdf_output_schema(use_shmem_handle=True):
         .commit(),
     )
 
-    if use_shmem_handle:
+    if use_shmem_handles:
         (
             STRING_ELEMENT(res)
             .key("image.data")
@@ -194,7 +194,7 @@ def xtdf_output_schema(use_shmem_handle=True):
     return res
 
 
-def jf_output_schema(use_shmem_handle=True):
+def jf_output_schema(use_shmem_handles):
     res = Schema()
     (
         NODE_ELEMENT(res)
@@ -227,7 +227,7 @@ def jf_output_schema(use_shmem_handle=True):
         .readOnly()
         .commit(),
     )
-    if use_shmem_handle:
+    if use_shmem_handles:
         (
             STRING_ELEMENT(res)
             .key("data.adc")
@@ -243,7 +243,7 @@ def jf_output_schema(use_shmem_handle=True):
     return res
 
 
-def pnccd_output_schema(use_shmem_handle=True):
+def pnccd_output_schema(use_shmem_handles):
     res = Schema()
     (
         NODE_ELEMENT(res)
@@ -255,7 +255,7 @@ def pnccd_output_schema(use_shmem_handle=True):
         .readOnly()
         .commit(),
     )
-    if use_shmem_handle:
+    if use_shmem_handles:
         (
             STRING_ELEMENT(res)
             .key("data.image")
diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py
index 4d6e7008..8c81d1fe 100644
--- a/src/calng/stacking_utils.py
+++ b/src/calng/stacking_utils.py
@@ -166,8 +166,6 @@ class StackingFriend:
         self.reconfigure(source_config, merge_config)
 
     def reconfigure(self, source_config, merge_config):
-        print("merge_config", type(merge_config))
-        print("source_config", type(source_config))
         if source_config is not None:
             self._source_config = source_config
         if merge_config is not None:
diff --git a/src/calng/utils.py b/src/calng/utils.py
index fd7c6c28..c955d80b 100644
--- a/src/calng/utils.py
+++ b/src/calng/utils.py
@@ -8,6 +8,21 @@ import numpy as np
 from calngUtils import misc
 
 
+class WarningLampType(enum.Enum):
+    FRAME_FILTER = enum.auto()
+    MEMORY_CELL_RANGE = enum.auto()
+    CONSTANT_OPERATING_PARAMETERS = enum.auto()
+    PREVIEW_SETTINGS = enum.auto()
+    CORRECTION_RUNNER = enum.auto()
+    OUTPUT_BUFFER = enum.auto()
+    GPU_MEMORY = enum.auto()
+    CALCAT_CONNECTION = enum.auto()
+    EMPTY_HASH = enum.auto()
+    MISC_INPUT_DATA = enum.auto()
+    TRAIN_ID = enum.auto()
+    TIMESERVER_CONNECTION = enum.auto()
+
+
 class WarningContextSystem:
     """Helper object to trigger warning lamps based on different warning types in
     contexts. Intended use: something that is checked multiple times and is good or bad
@@ -90,63 +105,6 @@ class WarningContextSystem:
         self.device.set("status", message)
 
 
-class PreviewIndexSelectionMode(enum.Enum):
-    FRAME = "frame"
-    CELL = "cell"
-    PULSE = "pulse"
-
-
-def pick_frame_index(selection_mode, index, cell_table, pulse_table):
-    """When selecting a single frame to preview, an obvious question is whether the
-    number the operator provides is a frame index, a cell ID, or a pulse ID. This
-    function allows any of the three, translating into frame index.
-
-    Indices below zero are special values and thus returned directly.
-
-    Returns: (frame index, cell ID, pulse ID), optional warning"""
-
-    if index < 0:
-        return (index, index, index), None
-
-    warning = None
-    selection_mode = PreviewIndexSelectionMode(selection_mode)
-
-    if selection_mode is PreviewIndexSelectionMode.FRAME:
-        if index < cell_table.size:
-            frame_index = index
-        else:
-            warning = (
-                f"Frame index {index} out of range for cell table of length "
-                f"{len(cell_table)}. Will use frame index 0 instead."
-            )
-            frame_index = 0
-    else:
-        if selection_mode is PreviewIndexSelectionMode.CELL:
-            index_table = cell_table
-        else:
-            index_table = pulse_table
-        found = np.where(index_table == index)[0]
-        if len(found) == 1:
-            frame_index = found[0]
-        elif len(found) == 0:
-            frame_index = 0
-            warning = (
-                f"{selection_mode.value.capitalize()} ID {index} not found in "
-                f"{selection_mode.name} table. Will use frame {frame_index}, which "
-                f"corresponds to {selection_mode.name} {index_table[0]} instead."
-            )
-        elif len(found) > 1:
-            warning = (
-                f"{selection_mode.value.capitalize()} ID {index} was not unique in "
-                f"{selection_mode} table. Will use first occurrence out of "
-                f"{len(found)} occurrences (frame {frame_index})."
-            )
-
-    cell_id = cell_table[frame_index]
-    pulse_id = pulse_table[frame_index]
-    return (frame_index, cell_id, pulse_id), warning
-
-
 _np_typechar_to_c_typestring = {
     "?": "bool",
     "B": "unsigned char",
@@ -234,22 +192,37 @@ class BadPixelValues(enum.IntFlag):
 def downsample_2d(arr, factor, reduction_fun=np.nanmax):
     """Generalization of downsampling from FemDataAssembler
 
-    Expects first two dimensions of arr to be multiple of 2 ** factor
-    Useful if you're sitting at home and ssh connection is slow to get full-resolution
-    previews."""
+    Expects last two dimensions of arr to be multiple of 2 ** factor.  Useful for
+    looking at detector previews over a slow connection (say, ssh)."""
 
     for i in range(factor // 2):
+        # downsample slow scan
         arr = reduction_fun(
             (
-                arr[:-1:2],
-                arr[1::2],
+                arr[..., :-1:2, :],
+                arr[..., 1::2, :],
             ),
             axis=0,
         )
+        # downsample fast scan
         arr = reduction_fun(
             (
-                arr[:, :-1:2],
-                arr[:, 1::2],
+                arr[..., :-1:2],
+                arr[..., 1::2],
+            ),
+            axis=0,
+        )
+    return arr
+
+
+def downsample_1d(arr, factor, reduction_fun=np.nanmax):
+    """Same as downsample_2d, but only 1d (applied to last axis of input)"""
+
+    for i in range(factor // 2):
+        arr = reduction_fun(
+            (
+                arr[..., :-1:2],
+                arr[..., 1::2],
             ),
             axis=0,
         )
@@ -289,3 +262,25 @@ def apply_partial_lut(data, lut, mask, out, missing=np.nan):
     tmp = out.ravel()
     tmp[~mask] = data.ravel()[lut]
     tmp[mask] = missing
+
+
+def subset_of_hash(input_hash, *keys):
+    # don't bother importing here, just construct empty
+    res = input_hash.__class__()
+    for key in keys:
+        if input_hash.has(key):
+            res.set(key, input_hash.get(key))
+    return res
+
+
+def maybe_get(a, out=None):
+    """For getting CuPy ndarray data to system memory - tries to behave like
+    cupy.ndarray.get with respect to the 'out' parameter"""
+    if hasattr(a, "get"):
+        return a.get(out=out)
+    else:
+        if out is None:
+            return a
+        else:
+            np.copyto(out, a)
+            return out
diff --git a/tests/common_setup.py b/tests/common_setup.py
new file mode 100644
index 00000000..fafe2cb7
--- /dev/null
+++ b/tests/common_setup.py
@@ -0,0 +1,100 @@
+import contextlib
+import pathlib
+import threading
+
+import h5py
+import numpy as np
+from calngUtils import device as device_utils
+from calngUtils.misc import ChainHash
+from karabo.bound import Hash, Schema
+
+
+def maybe_get(b):
+    if hasattr(b, "get"):
+        return b.get()
+    else:
+        return b
+
+
+class DummyLogger:
+    DEBUG = print
+    INFO = print
+    WARN = print
+
+
+@device_utils.with_config_overlay
+@device_utils.with_unsafe_get
+class DummyCorrectionDevice:
+    log = DummyLogger()
+
+    def log_status_info(self, msg):
+        self.log.INFO(msg)
+
+    def log_status_warn(self, msg):
+        self.log.WARN(msg)
+
+    def get(self, key):
+        return self._parameters.get(key)
+
+    def set(self, key, value):
+        self._parameters[key] = value
+
+    def reconfigure(self, config):
+        with self.new_config_context(config):
+            self.runner.reconfigure(config)
+        self._parameters.merge(config)
+
+    def __init__(self, *friends):
+        self._parameters = Hash()
+        s = Schema()
+        self._temporary_config = []
+        self._temporary_config_lock = threading.RLock()
+        for friend in friends:
+            if isinstance(friend, dict):
+                friend_class = friend.pop("friend_class")
+                extra_args = friend
+            else:
+                friend_class = friend
+                extra_args = {}
+            friend_class.add_schema(s, **extra_args)
+        h = s.getParameterHash()
+        for path in h.getPaths():
+            if h.hasAttribute(path, "defaultValue"):
+                if "output" in path.split("."):
+                    continue
+                self._parameters[path] = h.getAttribute(path, "defaultValue")
+
+        self.warnings = {}
+
+    def _please_send_me_cached_constants(self, constants, callback):
+        pass
+
+    @contextlib.contextmanager
+    def warning_context(self, key, warn_type):
+        def aux(message):
+            print(message)
+            self.warnings[(key, warn_type)] = message
+
+        yield aux
+
+    @contextlib.contextmanager
+    def new_config_context(self, config):
+        with self._temporary_config_lock:
+            old_config = self._parameters
+            if isinstance(self._parameters, ChainHash):
+                # in case we need to support recursive case / configuration stacking
+                self._parameters = ChainHash(config, *old_config._hashes)
+            else:
+                self._parameters = ChainHash(config, old_config)
+            yield
+            self._parameters = old_config
+
+
+caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store")
+
+
+def get_constant_from_file(path_to_file, file_name, data_set_name):
+    """Intended to work with copy-pasting paths from CalCat. Specifically, the "Path to
+    File", "File Name", and "Data Set Name" fields displayed for a given CCV."""
+    with h5py.File(caldb_store / path_to_file / file_name, "r") as fd:
+        return np.array(fd[f"{data_set_name}/data"])
diff --git a/tests/test_agipd_kernels.py b/tests/test_agipd_kernels.py
index afbb2ca0..25021ae9 100644
--- a/tests/test_agipd_kernels.py
+++ b/tests/test_agipd_kernels.py
@@ -1,19 +1,23 @@
-import h5py
 import numpy as np
-import pathlib
-
+import pytest
 from calng.corrections.AgipdCorrection import (
+    AgipdBaseRunner,
+    AgipdCalcatFriend,
     AgipdCpuRunner,
     AgipdGpuRunner,
     Constants,
     CorrectionFlags,
 )
+from karabo.bound import Hash
+
+from common_setup import DummyCorrectionDevice, get_constant_from_file, maybe_get
 
 output_dtype = np.float32
 corr_dtype = np.float32
 pixels_x = 512
 pixels_y = 128
 memory_cells = 352
+num_frames = memory_cells
 
 raw_data = np.random.randint(
     low=0, high=2000, size=(memory_cells, 2, pixels_x, pixels_y), dtype=np.uint16
@@ -23,29 +27,26 @@ raw_gain = raw_data[:, 1]
 cell_table = np.arange(memory_cells, dtype=np.uint16)
 np.random.shuffle(cell_table)
 
-caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store/xfel/cal")
-caldb_prefix = caldb_store / "agipd-type/agipd_siv1_agipdv11_m305"
-
-with h5py.File(caldb_prefix / "cal.1619543695.4679213.h5", "r") as fd:
-    thresholds = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0/data"])
-with h5py.File(caldb_prefix / "cal.1619543664.1545036.h5", "r") as fd:
-    offset_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/Offset/0/data"])
-with h5py.File(caldb_prefix / "cal.1615377705.8904035.h5", "r") as fd:
-    slopes_pc_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0/data"])
-
-kernel_runners = [
-    runner_class(
-        pixels_x,
-        pixels_y,
-        memory_cells,
-        constant_memory_cells=memory_cells,
-        output_data_dtype=output_dtype,
-    )
-    for runner_class in (AgipdCpuRunner, AgipdGpuRunner)
-]
+thresholds = get_constant_from_file(
+    "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305",
+    "cal.1619543695.4679213.h5",
+    "/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0",
+)
+offset_map = get_constant_from_file(
+    "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305",
+    "cal.1619543664.1545036.h5",
+    "/AGIPD_SIV1_AGIPDV11_M305/Offset/0",
+)
+slopes_pc_map = get_constant_from_file(
+    "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305",
+    "cal.1615377705.8904035.h5",
+    "/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0",
+)
 
+# first compute things naively
 
-def thresholding_cpu(data, cell_table, thresholds):
+
+def thresholding_naive(data, cell_table, thresholds):
     # get to memory_cell, x, y
     raw_gain = data[:, 1, ...].astype(corr_dtype, copy=False)
     # get to threshold, memory_cell, x, y
@@ -56,16 +57,13 @@ def thresholding_cpu(data, cell_table, thresholds):
     return res
 
 
-gain_map_cpu = thresholding_cpu(raw_data, cell_table, thresholds)
-
-
-def corr_offset_cpu(data, cell_table, gain_map, offset):
+def corr_offset_naive(data, cell_table, gain_map, offset):
     image_data = data[:, 0].astype(corr_dtype, copy=False)
     offset = np.transpose(offset)[:, cell_table]
     return (image_data - np.choose(gain_map, offset)).astype(output_dtype)
 
 
-def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
+def corr_rel_gain_pc_naive(data, cell_table, gain_map, slopes_pc):
     slopes_pc = slopes_pc.astype(np.float32, copy=False)
     pc_high_m = slopes_pc[0]
     pc_high_I = slopes_pc[1]
@@ -77,7 +75,7 @@ def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
     rel_gain_map[0] = 1  # rel xray gain can come after
     rel_gain_map[1] = rel_gain_map[0] * np.transpose(frac_high_med)
     rel_gain_map[2] = rel_gain_map[1] * 4.48
-    res = data[:, 0].astype(corr_dtype, copy=True)
+    res = data.astype(corr_dtype, copy=True)
     res *= np.choose(gain_map, np.transpose(rel_gain_map, (0, 3, 1, 2)))
     pixels_in_medium_gain = gain_map == 1
     res[pixels_in_medium_gain] += np.transpose(md_additional_offset, (0, 2, 1))[
@@ -86,35 +84,90 @@ def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
     return res
 
 
-def test_thresholding():
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.load_constant(Constants.ThresholdsDark, thresholds)
-        runner.correct(np.uint8(CorrectionFlags.THRESHOLD))
-    assert np.allclose(kernel_runners[0].gain_map, gain_map_cpu)
-    assert np.allclose(kernel_runners[1].gain_map.get(), gain_map_cpu)
-
-
-def test_offset():
-    reference = corr_offset_cpu(raw_data, cell_table, gain_map_cpu, offset_map)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.load_constant(Constants.ThresholdsDark, thresholds)
-        runner.load_constant(Constants.Offset, offset_map)
-        # have to do thresholding, otherwise all is treated as high gain
-        runner.correct(np.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.OFFSET))
-        res = runner.reshape(runner._corrected_axis_order)
-        assert np.allclose(res, reference), f"{runner.__class__}"
-
-
-def test_rel_gain_pc():
-    reference = corr_rel_gain_pc_cpu(raw_data, cell_table, gain_map_cpu, slopes_pc_map)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.load_constant(Constants.ThresholdsDark, thresholds)
-        runner.load_constant(Constants.SlopesPC, slopes_pc_map)
-        runner.correct(
-            np.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.REL_GAIN_PC)
+gain_map_naive = thresholding_naive(raw_data, cell_table, thresholds)
+offset_corrected_naive = corr_offset_naive(
+    raw_data, cell_table, gain_map_naive, offset_map
+)
+gain_corrected_naive = corr_rel_gain_pc_naive(
+    offset_corrected_naive, cell_table, gain_map_naive, slopes_pc_map
+)
+
+
+@pytest.fixture(params=[AgipdCpuRunner, AgipdGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(AgipdBaseRunner, AgipdCalcatFriend)
+    runner = request.param(device)
+    # TODO: also test ASIC seam masking
+    runner.reconfigure(
+        Hash(
+            "corrections.badPixels.subsetToUse.NON_STANDARD_SIZE",
+            False,
         )
-        res = runner.reshape(runner._corrected_axis_order)
-        assert np.allclose(res, reference)
+    )
+    return runner
+
+
+# then start testing individually
+
+
+def test_thresholding(kernel_runner):
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    # AgipdBaseRunner's buffers don't depend on flags
+    output, output_gain = kernel_runner._make_output_buffers(num_frames, None)
+    kernel_runner._correct(
+        kernel_runner._xp.uint8(CorrectionFlags.THRESHOLD),
+        kernel_runner._xp.asarray(raw_data),
+        kernel_runner._xp.asarray(cell_table),
+        output,
+        output_gain,
+    )
+    assert np.allclose(maybe_get(output_gain), gain_map_naive.astype(np.float32))
+
+
+def test_offset(kernel_runner):
+    # have to also do thresholding, otherwise all is treated as high gain
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    output, output_gain = kernel_runner._make_output_buffers(num_frames, None)
+    kernel_runner._correct(
+        kernel_runner._xp.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.OFFSET),
+        kernel_runner._xp.asarray(raw_data),
+        kernel_runner._xp.asarray(cell_table),
+        output,
+        output_gain,
+    )
+    assert np.allclose(maybe_get(output), offset_corrected_naive)
+
+
+def test_rel_gain_pc(kernel_runner):
+    # similarly, do previous steps first
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    kernel_runner.load_constant(Constants.SlopesPC, slopes_pc_map)
+    output, output_gain = kernel_runner._make_output_buffers(num_frames, None)
+    kernel_runner._correct(
+        kernel_runner._xp.uint8(
+            CorrectionFlags.THRESHOLD
+            | CorrectionFlags.OFFSET
+            | CorrectionFlags.REL_GAIN_PC
+        ),
+        kernel_runner._xp.asarray(raw_data),
+        kernel_runner._xp.asarray(cell_table),
+        output,
+        output_gain,
+    )
+    assert np.allclose(maybe_get(output), gain_corrected_naive)
+
+
+def test_with_preview(kernel_runner):
+    kernel_runner.load_constant(Constants.ThresholdsDark, thresholds)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    (
+        _,
+        processed_data,
+        (preview_raw, preview_corrected, preview_raw_gain, preview_gain_stage),
+    ) = kernel_runner.correct(raw_data, cell_table)
+    assert np.allclose(maybe_get(processed_data), offset_corrected_naive)
+    assert np.allclose(raw_data[:, 0], maybe_get(preview_raw))
+    assert np.allclose(offset_corrected_naive, maybe_get(preview_corrected))
+    assert np.allclose(raw_data[:, 1], maybe_get(preview_raw_gain))
diff --git a/tests/test_dssc_kernels.py b/tests/test_dssc_kernels.py
index 6ab37a8c..f7df8ef9 100644
--- a/tests/test_dssc_kernels.py
+++ b/tests/test_dssc_kernels.py
@@ -1,29 +1,34 @@
 import numpy as np
 import pytest
-
 from calng.corrections.DsscCorrection import (
-    CorrectionFlags,
+    Constants,
+    DsscBaseRunner,
+    DsscCalcatFriend,
     DsscCpuRunner,
     DsscGpuRunner,
 )
+from karabo.bound import Hash
+
+from common_setup import DummyCorrectionDevice, maybe_get
 
 output_dtype = np.float32
 corr_dtype = np.float32
 pixels_x = 512
 pixels_y = 128
 memory_cells = 400
+num_frames = memory_cells
 offset_map = (
     np.random.random(size=(pixels_x, pixels_y, memory_cells)).astype(corr_dtype) * 20
 )
 cell_table = np.arange(memory_cells, dtype=np.uint16)
 np.random.shuffle(cell_table)
 raw_data = np.random.randint(
-    low=0, high=2000, size=(memory_cells, pixels_y, pixels_x), dtype=np.uint16
+    low=0, high=2000, size=(memory_cells, 1, pixels_y, pixels_x), dtype=np.uint16
 )
 
 
 # TODO: gather CPU implementations elsewhere
-def correct_cpu(data, cell_table, offset_map):
+def correct_naive(data, cell_table, offset_map):
     corr = np.squeeze(data).astype(corr_dtype, copy=True)
     safe_cell_bool = cell_table < offset_map.shape[-1]
     safe_cell_index = cell_table[safe_cell_bool]
@@ -31,127 +36,76 @@ def correct_cpu(data, cell_table, offset_map):
     return corr.astype(output_dtype, copy=False)
 
 
-corrected_data = correct_cpu(raw_data, cell_table, offset_map)
+corrected_data = correct_naive(raw_data, cell_table, offset_map)
 only_cast_data = np.squeeze(raw_data).astype(output_dtype)
 
 
-kernel_runners = [
-    runner_class(
-        pixels_x,
-        pixels_y,
-        memory_cells,
-        constant_memory_cells=memory_cells,
-        output_data_dtype=output_dtype,
-    )
-    for runner_class in (DsscCpuRunner, DsscGpuRunner)
-]
-
-
-def test_only_cast():
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        runner.correct(CorrectionFlags.NONE)
-        assert np.allclose(runner.reshape(runner._corrected_axis_order), only_cast_data)
-
-
-def test_correct():
-    for runner in kernel_runners:
-        runner.load_offset_map(offset_map)
-        runner.load_data(raw_data, cell_table)
-        runner.correct(CorrectionFlags.OFFSET)
-        assert np.allclose(runner.reshape(runner._corrected_axis_order), corrected_data)
-
-
-def test_correct_oob_cells():
-    wild_cell_table = cell_table * 2
-    reference = correct_cpu(raw_data, wild_cell_table, offset_map)
-    for runner in kernel_runners:
-        runner.load_offset_map(offset_map)
-        # here, half the cell IDs will be out of bounds
-        runner.load_data(raw_data, wild_cell_table)
-        # should not crash
-        runner.correct(CorrectionFlags.OFFSET)
-        # should correct as much as possible
-        assert np.allclose(runner.reshape(runner._corrected_axis_order), reference)
-
-
-def test_reshape():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        assert np.allclose(
-            runner.reshape(output_order="xyf"), corrected_data.transpose()
-        )
-
-
-def test_preview_slice():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(42)
-        assert np.allclose(
-            preview_raw,
-            raw_data[42].astype(np.float32),
-        )
-        assert np.allclose(
-            preview_corrected,
-            corrected_data[42].astype(np.float32),
-        )
+@pytest.fixture(params=[DsscCpuRunner, DsscGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(DsscBaseRunner, DsscCalcatFriend)
+    runner = request.param(device)
+    device.runner = runner
+    return runner
 
 
-def test_preview_max():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        # note: in case correction failed, still test this separately
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-1)
-        assert np.allclose(preview_raw, np.max(raw_data, axis=0).astype(np.float32))
-        assert np.allclose(
-            preview_corrected, np.max(corrected_data, axis=0).astype(np.float32)
-        )
-
-
-def test_preview_mean():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-2)
-        assert np.allclose(preview_raw, np.nanmean(raw_data, axis=0, dtype=np.float32))
-        assert np.allclose(
-            preview_corrected, np.nanmean(corrected_data, axis=0, dtype=np.float32)
-        )
-
-
-def test_preview_sum():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-3)
-        assert np.allclose(preview_raw, np.nansum(raw_data, axis=0, dtype=np.float32))
-        assert np.allclose(
-            preview_corrected, np.nansum(corrected_data, axis=0, dtype=np.float32)
+def test_only_cast(kernel_runner):
+    assert raw_data.shape == kernel_runner.expected_input_shape(num_frames)
+    # note: previews are just buffers, not yet reduced
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), only_cast_data)
+    assert np.allclose(maybe_get(preview_raw), only_cast_data)
+    assert np.allclose(maybe_get(preview_corrected), only_cast_data)
+
+
+def test_correct(kernel_runner):
+    for preview in ("raw", "corrected"):
+        kernel_runner._device.reconfigure(
+            Hash(
+                f"preview.{preview}.index",
+                0,
+                f"preview.{preview}.selectionMode",
+                "frame",
+            )
         )
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), corrected_data)
+    assert np.allclose(maybe_get(preview_raw), raw_data[:, 0])
+    assert np.allclose(maybe_get(preview_corrected), corrected_data)
 
-
-def test_preview_std():
-    kernel_runners[0].processed_data = corrected_data
-    kernel_runners[1].processed_data.set(corrected_data)
-    for runner in kernel_runners:
-        runner.load_data(raw_data, cell_table)
-        preview_raw, preview_corrected = runner.compute_previews(-4)
-        assert np.allclose(preview_raw, np.nanstd(raw_data, axis=0, dtype=np.float32))
-        assert np.allclose(
-            preview_corrected, np.nanstd(corrected_data, axis=0, dtype=np.float32)
+    kernel_runner._device.reconfigure(Hash("corrections.offset.enable", False))
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), np.squeeze(raw_data))
+    assert np.allclose(maybe_get(preview_raw), np.squeeze(raw_data))
+    assert np.allclose(maybe_get(preview_corrected), corrected_data)
+
+    kernel_runner._device.reconfigure(
+        Hash(
+            "corrections.offset.enable",
+            True,
+            "corrections.offset.preview",
+            False,
         )
+    )
+    _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct(
+        raw_data, cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), corrected_data)
+    assert np.allclose(maybe_get(preview_raw), np.squeeze(raw_data))
+    assert np.allclose(maybe_get(preview_corrected), np.squeeze(raw_data))
 
 
-def test_preview_valid_index():
-    for runner in kernel_runners:
-        with pytest.raises(ValueError):
-            runner.compute_previews(-5)
-        with pytest.raises(ValueError):
-            runner.compute_previews(memory_cells)
+def test_correct_oob_cells(kernel_runner):
+    wild_cell_table = cell_table * 2
+    reference = correct_naive(raw_data, wild_cell_table, offset_map)
+    kernel_runner.load_constant(Constants.Offset, offset_map)
+    _, processed_data, previews = kernel_runner.correct(
+        raw_data, wild_cell_table
+    )
+    assert np.allclose(maybe_get(processed_data), reference)
diff --git a/tests/test_jungfrau_kernels.py b/tests/test_jungfrau_kernels.py
new file mode 100644
index 00000000..94f18794
--- /dev/null
+++ b/tests/test_jungfrau_kernels.py
@@ -0,0 +1,108 @@
+import numpy as np
+import pytest
+from calng.corrections.JungfrauCorrection import (
+    Constants,
+    JungfrauBaseRunner,
+    JungfrauCalcatFriend,
+    JungfrauCpuRunner,
+    JungfrauGpuRunner,
+)
+from calng.utils import BadPixelValues
+from karabo.bound import Hash
+
+from common_setup import DummyCorrectionDevice, get_constant_from_file, maybe_get
+
+
+raw_data = np.rint(np.random.random((16, 512, 1024)) * 5000).astype(np.uint16)
+gain = np.random.choice(
+    np.uint8([0, 1, 3]), size=raw_data.shape
+)  # not testing gain value 2 (wrong) for now
+gain_as_index = gain.copy()
+gain_as_index[gain_as_index == 3] = 2
+constants = {
+    Constants.Offset10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1714364366.1478188.h5",
+        "/Jungfrau_M572/Offset10Hz/0",
+    ),
+    Constants.BadPixelsDark10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1714364369.7054155.h5",
+        "/Jungfrau_M572/BadPixelsDark10Hz/0",
+    ),
+    Constants.BadPixelsFF10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1711037632.1430287.h5",
+        "/Jungfrau_M572/BadPixelsFF10Hz/0",
+    ),
+    Constants.RelativeGain10Hz: get_constant_from_file(
+        "xfel/cal/jungfrau-type/jungfrau_m572/",
+        "cal.1711036758.5906427.h5",
+        "/Jungfrau_M572/RelativeGain10Hz/0",
+    ),
+}
+cell_table = np.arange(16, dtype=np.uint8)
+np.random.shuffle(cell_table)
+cast_data = raw_data.astype(np.float32)
+offset_corrected = cast_data - np.choose(
+    gain_as_index,
+    np.transpose(constants[Constants.Offset10Hz][:, :, cell_table], (3, 2, 0, 1)),
+)
+gain_corrected = offset_corrected / np.choose(
+    gain_as_index, constants[Constants.RelativeGain10Hz][:, :, cell_table].T
+)
+masked_dark = gain_corrected.copy()
+masked_dark[
+    np.choose(
+        gain_as_index,
+        np.transpose(constants[Constants.BadPixelsDark10Hz], (3, 2, 0, 1))[
+            :, cell_table
+        ],
+    )
+    != 0
+] = np.nan
+masked_ff = masked_dark.copy()
+masked_ff[
+    np.choose(
+        gain_as_index,
+        np.transpose(constants[Constants.BadPixelsFF10Hz], (3, 2, 1, 0))[:, cell_table],
+    )
+    != 0
+] = np.nan
+
+
+@pytest.fixture(params=[JungfrauCpuRunner, JungfrauGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(JungfrauBaseRunner, JungfrauCalcatFriend)
+    runner = request.param(device)
+    device.runner = runner
+    # TODO: test masking around ASIC seams
+    # TODO: test single-frame mode
+    device.reconfigure(
+        Hash(
+            "constantParameters.memoryCells",
+            16,
+        )
+    )
+    device.reconfigure(
+        Hash(
+            "corrections.badPixels.subsetToUse.NON_STANDARD_SIZE",
+            False,
+        )
+    )
+    return runner
+
+
+def test_cast(kernel_runner):
+    _, result, previews = kernel_runner.correct(raw_data, cell_table, gain)
+    result = maybe_get(result)
+    assert np.allclose(result, cast_data, equal_nan=True)
+
+
+def test_correct(kernel_runner):
+    for constant, constant_data in constants.items():
+        kernel_runner.load_constant(constant, constant_data)
+
+    _, result, previews = kernel_runner.correct(raw_data, cell_table, gain)
+    result = maybe_get(result)
+    assert np.allclose(result, masked_ff, equal_nan=True)
diff --git a/tests/test_lpd_kernels.py b/tests/test_lpd_kernels.py
new file mode 100644
index 00000000..cf51fb20
--- /dev/null
+++ b/tests/test_lpd_kernels.py
@@ -0,0 +1,98 @@
+import numpy as np
+import pytest
+from calng.corrections.LpdCorrection import (
+    Constants,
+    LpdBaseRunner,
+    LpdCalcatFriend,
+    LpdCpuRunner,
+    LpdGpuRunner,
+)
+
+from common_setup import DummyCorrectionDevice, maybe_get
+
+
+@pytest.fixture(params=[LpdCpuRunner, LpdGpuRunner])
+def kernel_runner(request):
+    device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend)
+    return request.param(device)
+
+
+@pytest.fixture
+def cpu_runner():
+    device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend)
+    return LpdCpuRunner(device)
+
+
+@pytest.fixture
+def gpu_runner():
+    device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend)
+    return LpdGpuRunner(device)
+
+
+# generate some "raw data"
+raw_data_shape = (100, 1, 256, 256)
+# the image data part
+raw_image = (np.arange(np.product(raw_data_shape)).reshape(raw_data_shape)).astype(
+    np.uint16
+)
+# keeping in mind 12-bit ADC
+raw_image = np.bitwise_and(raw_image, 0x0FFF)
+raw_image_32 = raw_image.astype(np.float32)
+# the gain values (TODO: test with some being illegal)
+gain_data = (np.random.random(raw_data_shape) * 3).astype(np.uint16)
+assert np.array_equal(gain_data, np.bitwise_and(gain_data, 0x0003))
+gain_data_32 = gain_data.astype(np.float32)
+raw_data = np.bitwise_or(raw_image, np.left_shift(gain_data, 12))
+# just checking that we constructed something reasonable
+assert np.array_equal(gain_data, np.bitwise_and(np.right_shift(raw_data, 12), 0x0003))
+assert np.array_equal(raw_image, np.bitwise_and(raw_data, 0x0FFF))
+
+# generate some defects
+wrong_gain_value_indices = (np.random.random(raw_data_shape) < 0.005).astype(np.bool_)
+raw_data_with_some_wrong_gain = raw_data.copy()
+raw_data_with_some_wrong_gain[wrong_gain_value_indices] = np.bitwise_or(
+    raw_data_with_some_wrong_gain[wrong_gain_value_indices], 0x3000
+)
+
+# generate some constants
+dark_constant_shape = (256, 256, 512, 3)
+# okay, as slow and fast scan have same size, can't tell the difference here...
+funky_constant_shape = (256, 256, 512, 3)
+offset_constant = np.random.random(dark_constant_shape).astype(np.float32)
+gain_amp_constant = np.random.random(funky_constant_shape).astype(np.float32)
+bp_dark_constant = (np.random.random(dark_constant_shape) < 0.01).astype(np.uint32)
+
+cell_table = np.linspace(0, 512, 100).astype(np.uint16)
+np.random.shuffle(cell_table)
+
+
+def test_only_cast(kernel_runner):
+    _, processed, previews = kernel_runner.correct(raw_data, cell_table)
+    assert np.allclose(maybe_get(processed), raw_image_32[:, 0])
+    # TODO: make raw preview show image data (remove gain bits)
+    assert np.allclose(previews[0], raw_data[:, 0])  # raw: raw
+    assert np.allclose(previews[1], raw_image_32[:, 0])  # processed: raw
+    assert np.allclose(previews[2], gain_data_32[:, 0])  # gain: unchanged
+
+
+def test_only_cast_with_some_wrong_gain(kernel_runner):
+    (
+        _,
+        processed,
+        _,
+    ) = kernel_runner.correct(raw_data_with_some_wrong_gain, cell_table)
+    assert np.all(np.isnan(processed[wrong_gain_value_indices[:, 0]]))
+
+
+def test_correct(cpu_runner, gpu_runner):
+    # naive numpy version way too slow, just compare the two runners
+    # the CPU one uses kernel almost identical to pycalibration
+    cpu_runner.load_constant(Constants.Offset, offset_constant)
+    cpu_runner.load_constant(Constants.GainAmpMap, gain_amp_constant)
+    cpu_runner.load_constant(Constants.BadPixelsDark, bp_dark_constant)
+    gpu_runner.load_constant(Constants.Offset, offset_constant)
+    gpu_runner.load_constant(Constants.GainAmpMap, gain_amp_constant)
+    gpu_runner.load_constant(Constants.BadPixelsDark, bp_dark_constant)
+    _, processed_cpu, _ = cpu_runner.correct(raw_data_with_some_wrong_gain, cell_table)
+    _, processed_gpu, _ = gpu_runner.correct(raw_data_with_some_wrong_gain, cell_table)
+    assert np.allclose(processed_cpu, processed_gpu.get(), equal_nan=True)
diff --git a/tests/test_pnccd_kernels.py b/tests/test_pnccd_kernels.py
index 18f438fa..35da5f92 100644
--- a/tests/test_pnccd_kernels.py
+++ b/tests/test_pnccd_kernels.py
@@ -1,50 +1,78 @@
+import itertools
+
 import numpy as np
+import pytest
 from calng import utils
-from calng.corrections.PnccdCorrection import Constants, CorrectionFlags, PnccdCpuRunner
+from calng.corrections.PnccdCorrection import (
+    Constants,
+    PnccdCalcatFriend,
+    PnccdCpuRunner,
+)
+from karabo.bound import Hash
 
+from common_setup import DummyCorrectionDevice
+
+
+@pytest.fixture
+def kernel_runner():
+    device = DummyCorrectionDevice(PnccdCpuRunner, PnccdCalcatFriend)
+    runner = PnccdCpuRunner(device)
+    device.runner = runner
+    return runner
 
-def test_common_mode():
-    def slow_common_mode(
-        data, bad_pixel_map, noise_map, cm_min_frac, cm_noise_sigma, cm_ss, cm_fs
-    ):
-        masked = np.ma.masked_array(
-            data=data,
-            mask=(bad_pixel_map != 0) | (data > noise_map * cm_noise_sigma),
-        )
 
-        if cm_fs:
-            subset_fs = masked.count(axis=1) >= masked.shape[1] * cm_min_frac
-            data[subset_fs] -= np.ma.median(masked[subset_fs], axis=1, keepdims=True)
+def slow_common_mode(
+    data, bad_pixel_map, noise_map, cm_min_frac, cm_noise_sigma, cm_ss, cm_fs
+):
+    masked = np.ma.masked_array(
+        data=data,
+        mask=(bad_pixel_map != 0) | (data > noise_map * cm_noise_sigma),
+    )
 
-        if cm_ss:
-            subset_ss = masked.count(axis=0) >= masked.shape[0] * cm_min_frac
-            data[:, subset_ss] -= np.ma.median(masked[:, subset_ss], axis=0)
+    if cm_fs:
+        subset_fs = masked.count(axis=1) >= masked.shape[1] * cm_min_frac
+        data[subset_fs] -= np.ma.median(masked[subset_fs], axis=1, keepdims=True)
 
+    if cm_ss:
+        subset_ss = masked.count(axis=0) >= masked.shape[0] * cm_min_frac
+        data[:, subset_ss] -= np.ma.median(masked[:, subset_ss], axis=0)
+
+
+def test_common_mode(kernel_runner):
     data_shape = (1024, 1024)
     data = (np.random.random(data_shape) * 1000).astype(np.uint16)
-    runner = PnccdCpuRunner(1024, 1024, 1, 1)
-    for noise_level in (0, 5, 100, 1000):
-        noise = (np.random.random(data_shape) * noise_level).astype(np.float32)
-        runner.load_constant(Constants.NoiseCCD, noise)
-        for bad_pixel_percentage in (0, 1, 20, 80, 100):
-            bad_pixels = (
-                np.random.random(data_shape) < (bad_pixel_percentage / 100)
-            ).astype(np.uint32)
-            runner.load_constant(Constants.BadPixelsDarkCCD, bad_pixels)
-            runner.load_data(data.copy())
-            runner.correct(
-                CorrectionFlags.COMMONMODE,
-                cm_min_frac=0.25,
-                cm_noise_sigma=10,
-                cm_row=True,
-                cm_col=True,
-            )
-
-            slow_version = data.astype(np.float32, copy=True)
-            for q_data, q_noise, q_bad_pixels in zip(
-                utils.quadrant_views(slow_version),
-                utils.quadrant_views(noise),
-                utils.quadrant_views(bad_pixels),
-            ):
-                slow_common_mode(q_data, q_bad_pixels, q_noise, 0.25, 10, True, True)
-            assert np.allclose(runner.processed_data, slow_version, equal_nan=True)
+    noise_maps = [
+        (np.random.random(data_shape) * noise_level).astype(np.float32)
+        for noise_level in (0, 5, 100, 1000)
+    ]
+    bad_pixel_maps = [
+        (np.random.random(data_shape) < (bad_pixel_percentage / 100)).astype(np.uint32)
+        for bad_pixel_percentage in (0, 1, 20, 80, 100)
+    ]
+    kernel_runner._device.reconfigure(
+        Hash(
+            "corrections.commonMode.noiseSigma",
+            10,
+            "corrections.commonMode.minFrac",
+            0.25,
+            "corrections.commonMode.enableRow",
+            True,
+            "corrections.commonMode.enableCol",
+            True,
+            "corrections.badPixels.enable",
+            False,
+        )
+    )
+    for noise_map, bad_pixel_map in itertools.product(noise_maps, bad_pixel_maps):
+        kernel_runner.load_constant(Constants.NoiseCCD, noise_map)
+        kernel_runner.load_constant(Constants.BadPixelsDarkCCD, bad_pixel_map)
+        _, processed_data, _ = kernel_runner.correct(data, None)
+
+        slow_version = data.astype(np.float32, copy=True)
+        for q_data, q_noise, q_bad_pixels in zip(
+            utils.quadrant_views(slow_version),
+            utils.quadrant_views(noise_map),
+            utils.quadrant_views(bad_pixel_map),
+        ):
+            slow_common_mode(q_data, q_bad_pixels, q_noise, 0.25, 10, True, True)
+        assert np.allclose(processed_data, slow_version, equal_nan=True)
diff --git a/tests/test_preview_utils.py b/tests/test_preview_utils.py
new file mode 100644
index 00000000..043d7ac6
--- /dev/null
+++ b/tests/test_preview_utils.py
@@ -0,0 +1,160 @@
+import numpy as np
+import pytest
+from karabo.bound import Hash
+from calng.preview_utils import PreviewFriend, PreviewSpec
+
+from common_setup import DummyCorrectionDevice
+
+
+cell_table = np.array([0, 2, 1], dtype=np.uint16)
+pulse_table = np.array([1, 2, 0], dtype=np.uint16)
+raw_data = np.array(
+    [
+        [1.0, 2.0, 3.0],
+        [4.0, np.nan, 5.0],
+        [6.0, 7.0, 8.0],
+    ],
+    dtype=np.float32,
+)
+raw_without_nan = raw_data.copy()
+raw_without_nan[~np.isfinite(raw_data)] = 0
+
+
+@pytest.fixture
+def device():
+    output_specs = [
+        PreviewSpec(
+            "dummyPreview",
+            dimensions=1,
+            wrap_in_imagedata=False,
+            frame_reduction=True,
+        )
+    ]
+    device = DummyCorrectionDevice(
+        {"friend_class": PreviewFriend, "outputs": output_specs}
+    )
+
+    def writeChannel(name, hsh, **kwargs):
+        if not hasattr(device, "_written"):
+            device._written = []
+        device._written.append(("name", hsh))
+
+    device.writeChannel = writeChannel
+    device.preview_friend = PreviewFriend(device, output_specs)
+
+    return device
+
+
+def test_preview_frame(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            1,
+            "dummyPreview.selectionMode",
+            "frame",
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    assert hasattr(device, "_written")
+    assert len(device._written) == 1
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], raw_without_nan[1])
+
+
+def test_preview_cell(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            1,
+            "dummyPreview.selectionMode",
+            "cell",
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], raw_without_nan[2])
+
+
+def test_preview_pulse(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            2,
+            "dummyPreview.selectionMode",
+            "pulse",
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], raw_without_nan[1])
+
+
+def test_preview_max(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -1,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(preview["image.data"], np.nanmax(raw_data, axis=0))
+
+
+def test_preview_mean(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -2,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(
+        preview["image.data"], np.nanmean(raw_data, axis=0, dtype=np.float32)
+    )
+
+
+def test_preview_sum(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -3,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(
+        preview["image.data"], np.nansum(raw_data, axis=0, dtype=np.float32)
+    )
+
+
+def test_preview_std(device):
+    device.preview_friend.reconfigure(
+        Hash(
+            "dummyPreview.index",
+            -4,
+        )
+    )
+    device.preview_friend.write_outputs(
+        raw_data, cell_table=cell_table, pulse_table=pulse_table
+    )
+    preview = device._written[0][1]
+    assert np.allclose(
+        preview["image.data"], np.nanstd(raw_data, axis=0, dtype=np.float32)
+    )
+
+
+# TODO: also test alternative preview index selection modes
-- 
GitLab