From 347dac7c627265dcf26f2d38335bb24b1ea7280f Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 15 Sep 2021 21:31:42 +0200
Subject: [PATCH] Improve correction config, constant loading, abstractions,
 start DRY

Still a work in progress: remaining work is marked with TODOs, and the
changes to the manager are provisional.
---
 src/calng/AgipdCorrection.py    | 255 +++++++++++++-----------------
 src/calng/CalibrationManager.py |  11 ++
 src/calng/DsscCorrection.py     | 264 +++++++++-----------------------
 src/calng/agipd_gpu.py          |  27 ++--
 src/calng/base_correction.py    | 185 +++++++++++++++++++---
 src/calng/base_gpu.py           |  84 +++++-----
 src/calng/dssc_gpu.py           |  43 +++---
 src/calng/dssc_gpu_kernels.cpp  |  44 ++----
 8 files changed, 436 insertions(+), 477 deletions(-)

diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index a2a07797..50ed129d 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -1,11 +1,10 @@
 import timeit
 
-import calibrationBase
 import numpy as np
-from karabo.bound import KARABO_CLASSINFO
+from karabo.bound import BOOL_ELEMENT, KARABO_CLASSINFO
 from karabo.common.states import State
 
-from . import shmem_utils, utils
+from . import utils
 from ._version import version as deviceVersion
 from .base_correction import BaseCorrection
 from .agipd_gpu import AgipdGpuRunner, CorrectionFlags
@@ -13,6 +12,17 @@ from .agipd_gpu import AgipdGpuRunner, CorrectionFlags
 
 @KARABO_CLASSINFO("AgipdCorrection", deviceVersion)
 class AgipdCorrection(BaseCorrection):
+    # subclass *must* set these attributes
+    _correction_flag_class = CorrectionFlags
+    _correction_slot_names = (
+        ("thresholding", CorrectionFlags.THRESHOLD),
+        ("offset", CorrectionFlags.OFFSET),
+        ("relGainPc", CorrectionFlags.REL_GAIN_PC),
+        ("relGainXray", CorrectionFlags.REL_GAIN_XRAY),
+        ("badpixels", CorrectionFlags.BPMASK),
+    )
+    _gpu_runner_class = AgipdGpuRunner
+
     @staticmethod
     def expectedParameters(expected):
         AgipdCorrection.addConstant(
@@ -49,6 +59,46 @@ class AgipdCorrection(BaseCorrection):
             mandatoryForIteration=True,
         )
         super(AgipdCorrection, AgipdCorrection).expectedParameters(expected)
+        for slot_name, _ in AgipdCorrection._correction_slot_names:
+            (
+                BOOL_ELEMENT(expected)
+                .key(f"corrections.available.{slot_name}")
+                .readOnly()
+                .initialValue(False)
+                .commit(),
+                BOOL_ELEMENT(expected)
+                .key(f"corrections.enabled.{slot_name}")
+                .assignmentOptional()
+                .defaultValue(False)
+                .reconfigurable()
+                .commit(),
+                BOOL_ELEMENT(expected)
+                .key(f"corrections.preview.{slot_name}")
+                .assignmentOptional()
+                .defaultValue(False)
+                .reconfigurable()
+                .commit(),
+            )
+
+    @property
+    def input_data_shape(self):
+        return (
+            self.get("dataFormat.memoryCells"),
+            2,
+            self.get("dataFormat.pixelsX"),
+            self.get("dataFormat.pixelsY"),
+        )
+
+    @property
+    def output_data_shape(self):
+        return utils.shape_after_transpose(
+            (
+                self.get("dataFormat.memoryCells"),
+                self.get("dataFormat.pixelsX"),
+                self.get("dataFormat.pixelsY"),
+            ),
+            self._output_transpose,
+        )
 
     def __init__(self, config):
         super().__init__(config)
@@ -59,17 +109,7 @@ class AgipdCorrection(BaseCorrection):
             self._output_transpose = (2, 1, 0)
         else:
             self._output_transpose = None
-        self._offset_map = None
-        self._update_pulse_filter(config.get("pulseFilter"))
-        self._update_shapes(
-            config.get("dataFormat.pixelsX"),
-            config.get("dataFormat.pixelsY"),
-            config.get("dataFormat.memoryCells"),
-            self.pulse_filter,
-            self._output_transpose,
-        )
-        self._cached_constants = {}
-
+        self._update_shapes()
         self.updateState(State.ON)
 
     def process_input(self, data, metadata):
@@ -102,7 +142,6 @@ class AgipdCorrection(BaseCorrection):
             self.log.WARN(msg)
             return
         # original shape: memory_cell, data/raw_gain, x, y
-        # TODO: consider making paths configurable
         image_data = data.get("image.data")
         if image_data.shape[0] != self.get("dataFormat.memoryCells"):
             self.set(
@@ -111,23 +150,8 @@ class AgipdCorrection(BaseCorrection):
             # TODO: truncate if > 800
             self.set("dataFormat.memoryCells", image_data.shape[0])
             with self._buffer_lock:
-                self._update_pulse_filter(self.get("pulseFilter"))
-                self._update_shapes(
-                    self.get("dataFormat.pixelsX"),
-                    self.get("dataFormat.pixelsY"),
-                    self.get("dataFormat.memoryCells"),
-                    self.pulse_filter,
-                    self._output_transpose,
-                )
-        # TODO: check shape (DAQ fake data and RunToPipe don't agree)
-        # TODO: consider just updating shapes based on whatever comes in
-
-        correction_cell_num = self.get("dataFormat.memoryCellsCorrection")
-        do_generate_preview = train_id % self.get(
-            "preview.trainIdModulo"
-        ) == 0 and self.get("preview.enable")
-        can_apply_correction = True
-        do_apply_correction = self.get("applyCorrection")
+                # TODO: pulse filter update after reimplementation
+                self._update_shapes()
 
         if not self.get("state") is State.PROCESSING:
             self.updateState(State.PROCESSING)
@@ -140,43 +164,38 @@ class AgipdCorrection(BaseCorrection):
         else:
             self._state_reset_timer.set_timeout(self.get("processingStateTimeout"))
 
-        with self._buffer_lock:
-            cell_table = cell_table[self.pulse_filter]
-            pulse_table = np.squeeze(data.get("image.pulseId"))[self.pulse_filter]
+        correction_cell_num = self.get("dataFormat.constantMemoryCells")
+        do_generate_preview = train_id % self.get(
+            "preview.trainIdModulo"
+        ) == 0 and self.get("preview.enable")
 
+        with self._buffer_lock:
+            # cell_table = cell_table[self.pulse_filter]
+            pulse_table = np.squeeze(data.get("image.pulseId"))  # [self.pulse_filter]
             cell_table_max = np.max(cell_table)
-            if do_apply_correction:
-                if not can_apply_correction:
-                    msg = "No constant loaded, correction will not be applied."
-                    self.log.WARN(msg)
-                    self.set("status", msg)
-                    do_apply_correction = False
-                elif cell_table_max >= correction_cell_num:
-                    msg = (
-                        f"Max cell ID ({cell_table_max}) exceeds range for loaded "
-                        f"constant (has {correction_cell_num} cells). Some frames "
-                        "will not be corrected."
-                    )
-                    self.log.WARN(msg)
-                    self.set("status", msg)
+            # TODO: all this checking and warning can go into GPU runner
+            if cell_table_max >= correction_cell_num:
+                msg = (
+                    f"Max cell ID ({cell_table_max}) exceeds range for loaded "
+                    f"constants ({correction_cell_num} cells). Some frames will not be "
+                    "corrected."
+                )
+                self.log.WARN(msg)
+                self.set("status", msg)
 
             self.gpu_runner.load_data(image_data)
             buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-            if do_apply_correction:
-                self.gpu_runner.load_cell_table(cell_table)
-                self.gpu_runner.correct(
-                    CorrectionFlags.THRESHOLD
-                    | CorrectionFlags.OFFSET
-                    | CorrectionFlags.REL_GAIN_PC
-                    | CorrectionFlags.REL_GAIN_XRAY
-                )
-            else:
-                self.gpu_runner.only_cast()
+            self.gpu_runner.load_cell_table(cell_table)
+            self.gpu_runner.correct(self._correction_flag_enabled)
             self.gpu_runner.reshape(out=buffer_array)
+            # after reshape, data for dataOutput is now safe in its own buffer
             if do_generate_preview:
+                if self._correction_flag_enabled != self._correction_flag_preview:
+                    self.gpu_runner.correct(self._correction_flag_preview)
                 preview_slice_index = self.get("preview.pulse")
                 if preview_slice_index >= 0:
                     # look at pulse_table to find which index this pulse ID is in
+                    # TODO: move this to GPU
                     pulse_id_found = np.where(pulse_table == preview_slice_index)[0]
                     if len(pulse_id_found) == 0:
                         pulse_found_instead = pulse_table[0]
@@ -185,24 +204,13 @@ class AgipdCorrection(BaseCorrection):
                             f"image.pulseId, arbitrary pulse "
                             f"{pulse_found_instead} will be shown."
                         )
-                        preview_slice_index = 0
                         self.log.WARN(msg)
                         self.set("status", msg)
+                        preview_slice_index = 0
                     else:
                         preview_slice_index = pulse_id_found[0]
-                if not do_apply_correction:
-                    if can_apply_correction:
-                        # in this case, cell table has not been loaded, but needs to be now
-                        self.gpu_runner.load_cell_table(cell_table)
-                    else:
-                        # in this case, there will be no corrected preview
-                        self.log.WARN(
-                            "Corrected preview will not actually be corrected."
-                        )
                 preview_raw, preview_corrected = self.gpu_runner.compute_preview(
-                    preview_slice_index,
-                    have_corrected=do_apply_correction,
-                    can_correct=can_apply_correction,
+                    preview_slice_index
                 )
 
         data.set("image.data", buffer_handle)
@@ -225,21 +233,36 @@ class AgipdCorrection(BaseCorrection):
         if self.get("performance.rateUpdateOnEachInput"):
             self._update_actual_rate()
 
-    def requestConstant(self, name, mostRecent=False, tryRemote=True):
-        """constantLoaded hook would have gotten called without naming which constant,
-        so here we go. Ugly hooking it."""
-        # TODO: clear from device, too
-        # TODO: update correction capability flag
-        if name in self._cached_constants:
-            del self._cached_constants[name]
-        super().requestConstant(name, mostRecent, tryRemote)
-        constant = self.getConstant(name)
-        if constant is not None:
-            self._cached_constants[name] = constant
-            if name == "ThresholdsDark":
-                self.gpu_runner.load_thresholds(constant)
-            elif name == "Offset":
-                self.gpu_runner.load_offset_map(constant)
+    def _load_constant_to_gpu(self, constant_name, constant_data):
+        # TODO: also hook flushConstants or whatever it is called
+        if constant_name == "ThresholdsDark":
+            self.gpu_runner.load_thresholds(constant_data)
+            # TODO: encode correction / constant dependencies in a clever way
+            if not self.get("corrections.available.thresholding"):
+                self.set("corrections.available.thresholding", True)
+                self.set("corrections.enabled.thresholding", True)
+                self.set("corrections.preview.thresholding", True)
+        elif constant_name == "Offset":
+            self.gpu_runner.load_offset_map(constant_data)
+            if not self.get("corrections.available.offset"):
+                self.set("corrections.available.offset", True)
+                self.set("corrections.enabled.offset", True)
+                self.set("corrections.preview.offset", True)
+        elif constant_name == "SlopesPC":
+            self.gpu_runner.load_rel_gain_pc_map(constant_data)
+            if not self.get("corrections.available.relGainPc"):
+                self.set("corrections.available.relGainPc", True)
+                self.set("corrections.enabled.relGainPc", True)
+                self.set("corrections.preview.relGainPc", True)
+        elif constant_name == "SlopesFF":
+            self.gpu_runner.load_rel_gain_ff_map(constant_data)
+            if not self.get("corrections.available.relGainXray"):
+                self.set("corrections.available.relGainXray", True)
+                self.set("corrections.enabled.relGainXray", True)
+                self.set("corrections.preview.relGainXray", True)
+        elif "BadPixels" in constant_name:
+            # TODO: implement loading bad pixels
+            ...
 
     def _update_pulse_filter(self, filter_string):
         """Called whenever the pulse filter changes, typically followed by
@@ -251,61 +274,3 @@ class AgipdCorrection(BaseCorrection):
             new_filter = np.array(eval(filter_string), dtype=np.uint16)
         assert np.max(new_filter) < self.get("dataFormat.memoryCells")
         self.pulse_filter = new_filter
-
-    def _update_shapes(
-        self, pixels_x, pixels_y, memory_cells, pulse_filter, output_transpose
-    ):
-        """(Re)initialize (GPU) buffers according to expected data shapes"""
-
-        # TODO: report "actual" input shape (incl. raw gain)
-        input_data_shape = (memory_cells, pixels_x, pixels_y)
-        # reflect the axis reordering in the expected output shape
-        output_data_shape = utils.shape_after_transpose(
-            input_data_shape, output_transpose
-        )
-        self.set("dataFormat.inputDataShape", list(input_data_shape))
-        self.set("dataFormat.outputDataShape", list(output_data_shape))
-
-        if self._shmem_buffer is None:
-            shmem_buffer_name = self.getInstanceId() + ":dataOutput"
-            memory_budget = self.get("outputShmemBufferSize") * 2 ** 30
-            self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}")
-            self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
-                memory_budget,
-                output_data_shape,
-                self.output_data_dtype,
-                shmem_buffer_name,
-            )
-        else:
-            self._shmem_buffer.change_shape(output_data_shape)
-
-        self.gpu_runner = AgipdGpuRunner(
-            pixels_x,
-            pixels_y,
-            memory_cells,
-            constant_memory_cells=250,
-            output_transpose=output_transpose,
-            input_data_dtype=self.input_data_dtype,
-            output_data_dtype=self.output_data_dtype,
-        )
-
-        self._update_maps_on_gpu()
-
-    def _update_maps_on_gpu(self):
-        """Updates the correction maps stored on GPU based on constants known
-
-        This only does something useful if constants have been retrieved from
-        CalCat.  Should be called automatically upon retrieval and after
-        changing the data shape.
-
-        """
-
-        self.set("status", "Updating constants on GPU using known constants")
-        self.updateState(State.CHANGING)
-        if self._offset_map is not None:
-            self.gpu_runner.load_constants(self._offset_map)
-            msg = "Done transferring known constant(s) to GPU"
-            self.log.INFO(msg)
-            self.set("status", msg)
-
-        self.updateState(State.ON)
diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py
index 39ae1482..c5ab67cd 100644
--- a/src/calng/CalibrationManager.py
+++ b/src/calng/CalibrationManager.py
@@ -292,6 +292,7 @@ class ManagedKeysCloneFactory(ProxyFactory):
 
 class CalibrationManager(DeviceClientBase, Device):
     __version__ = deviceVersion
+    _conditions_to_set = {}
 
     interfaces = VectorString(
         displayedName='Device interfaces',
@@ -1082,6 +1083,16 @@ class CalibrationManager(DeviceClientBase, Device):
             ):
                 return
 
+        # fetch the correction device class schema directly from its server
+        self.logger.debug(f'Trying to figure out which constant parameters to set later')
+        # NOTE: this lookup should happen earlier and be done in a more general way
+        correction_device_schema, _, _ = await call(server_by_group[group], "slotGetClassSchema", class_ids["correction"])
+        self._conditions_to_set = {
+            constant_name: constant_schema.getAttribute('detector_condition', 'defaultValue')
+            for constant_name, constant_schema in correction_device_schema.hash['constants'].items()
+        }
+        self.logger.debug(str(self._conditions_to_set))
+
         # Instantiate group matchers and bridges.
         for row in self.moduleGroups.value:
             group, server, with_matcher, with_bridge, bridge_port, \
diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py
index fff86e9d..7a2cc5ec 100644
--- a/src/calng/DsscCorrection.py
+++ b/src/calng/DsscCorrection.py
@@ -1,24 +1,61 @@
 import timeit
 
-import calibrationBase
 import numpy as np
-from karabo.bound import KARABO_CLASSINFO
+from karabo.bound import BOOL_ELEMENT, KARABO_CLASSINFO
 from karabo.common.states import State
 
-from . import shmem_utils, utils
+from . import utils
 from ._version import version as deviceVersion
 from .base_correction import BaseCorrection
-from .dssc_gpu import DsscGpuRunner
+from .dssc_gpu import DsscGpuRunner, CorrectionFlags
 
 
 @KARABO_CLASSINFO("DsscCorrection", deviceVersion)
 class DsscCorrection(BaseCorrection):
+    # subclass *must* set these attributes
+    _correction_flag_class = CorrectionFlags
+    _correction_slot_names = (("offset", CorrectionFlags.OFFSET),)
+    _gpu_runner_class = DsscGpuRunner
+
     @staticmethod
     def expectedParameters(expected):
         DsscCorrection.addConstant(
             "Offset", "Dark", expected, optional=True, mandatoryForIteration=True
         )
         super(DsscCorrection, DsscCorrection).expectedParameters(expected)
+        for slot_name, _ in DsscCorrection._correction_slot_names:
+            (
+                BOOL_ELEMENT(expected)
+                .key(f"corrections.available.{slot_name}")
+                .readOnly()
+                .initialValue(False)
+                .commit(),
+                BOOL_ELEMENT(expected)
+                .key(f"corrections.enabled.{slot_name}")
+                .assignmentOptional()
+                .defaultValue(False)
+                .reconfigurable()
+                .commit(),
+                BOOL_ELEMENT(expected)
+                .key(f"corrections.preview.{slot_name}")
+                .assignmentOptional()
+                .defaultValue(False)
+                .reconfigurable()
+                .commit(),
+            )
+
+    @property
+    def input_data_shape(self):
+        return (
+            self.get("dataFormat.memoryCells"),
+            1,
+            self.get("dataFormat.pixelsY"),
+            self.get("dataFormat.pixelsX"),
+        )
+
+    @property
+    def output_data_shape(self):
+        return utils.shape_after_transpose(self.input_data_shape, self._output_transpose)
 
     def __init__(self, config):
         super().__init__(config)
@@ -29,39 +66,11 @@ class DsscCorrection(BaseCorrection):
             self._output_transpose = (2, 1, 0)
         else:
             self._output_transpose = None
-        self._offset_map = None
-        self._update_pulse_filter(config.get("pulseFilter"))
-        self._update_shapes(
-            config.get("dataFormat.pixelsX"),
-            config.get("dataFormat.pixelsY"),
-            config.get("dataFormat.memoryCells"),
-            self.pulse_filter,
-            self._output_transpose,
-        )
-
+        self._update_shapes()
         self.updateState(State.ON)
 
-    def preReconfigure(self, config):
-        super().preReconfigure(config)
-        if config.has("pulseFilter"):
-            with self._buffer_lock:
-                # apply new pulse filter
-                self._update_pulse_filter(config.get("pulseFilter"))
-                # but existing shapes (not reconfigurable)
-                # TODO: avoid double compilation here if constants are loaded
-                self._update_shapes(
-                    self.get("dataFormat.pixelsX"),
-                    self.get("dataFormat.pixelsY"),
-                    self.get("dataFormat.memoryCells"),
-                    self.pulse_filter,
-                )
-
     def process_input(self, data, metadata):
-        """Registered for dataInput, handles all processing and sending
-
-        Comparable to StreamBase.onInput but hopefully faster
-
-        """
+        """Registered for dataInput, handles all processing and sending"""
 
         if not self.get("doAnything"):
             if self.get("state") is State.PROCESSING:
@@ -75,7 +84,6 @@ class DsscCorrection(BaseCorrection):
             self.log.INFO(f"Ignoring unknown source {source}")
             return
 
-        # TODO: what are these empty things for?
         if not data.has("image"):
             self.log.INFO("Ignoring hash without image node")
             return
@@ -91,7 +99,6 @@ class DsscCorrection(BaseCorrection):
             self.log.WARN(msg)
             return
         # original shape: 400, 1, 128, 512 (memory cells, something, y, x)
-        # TODO: consider making paths configurable
         image_data = data.get("image.data")
         if image_data.shape[0] != self.get("dataFormat.memoryCells"):
             self.set(
@@ -100,23 +107,8 @@ class DsscCorrection(BaseCorrection):
             # TODO: truncate if > 800
             self.set("dataFormat.memoryCells", image_data.shape[0])
             with self._buffer_lock:
-                self._update_pulse_filter(self.get("pulseFilter"))
-                self._update_shapes(
-                    self.get("dataFormat.pixelsX"),
-                    self.get("dataFormat.pixelsY"),
-                    self.get("dataFormat.memoryCells"),
-                    self.pulse_filter,
-                    self._output_transpose,
-                )
-        # TODO: check shape (DAQ fake data and RunToPipe don't agree)
-        # TODO: consider just updating shapes based on whatever comes in
-
-        correction_cell_num = self.get("dataFormat.memoryCellsCorrection")
-        do_generate_preview = train_id % self.get(
-            "preview.trainIdModulo"
-        ) == 0 and self.get("preview.enable")
-        can_apply_correction = correction_cell_num > 0
-        do_apply_correction = self.get("applyCorrection")
+                # TODO: pulse filter update after reimplementation
+                self._update_shapes()
 
         if not self.get("state") is State.PROCESSING:
             self.updateState(State.PROCESSING)
@@ -129,35 +121,33 @@ class DsscCorrection(BaseCorrection):
         else:
             self._state_reset_timer.set_timeout(self.get("processingStateTimeout"))
 
+        correction_cell_num = self.get("dataFormat.constantMemoryCells")
+        do_generate_preview = train_id % self.get(
+            "preview.trainIdModulo"
+        ) == 0 and self.get("preview.enable")
+
         with self._buffer_lock:
-            cell_table = cell_table[self.pulse_filter]
-            pulse_table = np.squeeze(data.get("image.pulseId"))[self.pulse_filter]
+            # cell_table = cell_table[self.pulse_filter]
+            pulse_table = np.squeeze(data.get("image.pulseId"))  # [self.pulse_filter]
 
             cell_table_max = np.max(cell_table)
-            if do_apply_correction:
-                if not can_apply_correction:
-                    msg = "No constant loaded, correction will not be applied."
-                    self.log.WARN(msg)
-                    self.set("status", msg)
-                    do_apply_correction = False
-                elif cell_table_max >= correction_cell_num:
-                    msg = (
-                        f"Max cell ID ({cell_table_max}) exceeds range for loaded "
-                        f"constant (has {correction_cell_num} cells). Some frames "
-                        "will not be corrected."
-                    )
-                    self.log.WARN(msg)
-                    self.set("status", msg)
+            if cell_table_max >= correction_cell_num:
+                msg = (
+                    f"Max cell ID ({cell_table_max}) exceeds range for loaded "
+                    f"constant ({correction_cell_num} cells). Some frames will not be "
+                    "corrected."
+                )
+                self.log.WARN(msg)
+                self.set("status", msg)
 
             self.gpu_runner.load_data(image_data)
             buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-            if do_apply_correction:
-                self.gpu_runner.load_cell_table(cell_table)
-                self.gpu_runner.correct()
-            else:
-                self.gpu_runner.only_cast()
+            self.gpu_runner.load_cell_table(cell_table)
+            self.gpu_runner.correct(self._correction_flag_enabled)
             self.gpu_runner.reshape(out=buffer_array)
             if do_generate_preview:
+                if self._correction_flag_enabled != self._correction_flag_preview:
+                    self.gpu_runner.correct(self._correction_flag_preview)
                 preview_slice_index = self.get("preview.pulse")
                 if preview_slice_index >= 0:
                     # look at pulse_table to find which index this pulse ID is in
@@ -169,24 +159,13 @@ class DsscCorrection(BaseCorrection):
                             f"image.pulseId, arbitrary pulse "
                             f"{pulse_found_instead} will be shown."
                         )
-                        preview_slice_index = 0
                         self.log.WARN(msg)
                         self.set("status", msg)
+                        preview_slice_index = 0
                     else:
                         preview_slice_index = pulse_id_found[0]
-                if not do_apply_correction:
-                    if can_apply_correction:
-                        # in this case, cell table has not been loaded, but needs to be now
-                        self.gpu_runner.load_cell_table(cell_table)
-                    else:
-                        # in this case, there will be no corrected preview
-                        self.log.WARN(
-                            "Corrected preview will not actually be corrected."
-                        )
                 preview_raw, preview_corrected = self.gpu_runner.compute_preview(
                     preview_slice_index,
-                    have_corrected=do_apply_correction,
-                    can_correct=can_apply_correction,
                 )
 
         data.set("image.data", buffer_handle)
@@ -209,57 +188,6 @@ class DsscCorrection(BaseCorrection):
         if self.get("performance.rateUpdateOnEachInput"):
             self._update_actual_rate()
 
-    def constantLoaded(self):
-        """Hook from CalibrationReceiverBaseDevice called after each getConstant
-
-        Here, used to load the received constants (or correction maps derived
-        fromt them) onto GPU.
-
-        TODO: call after receiving *all* constants instead of calling once per
-        new constant (will cause some overhead for bigger devices)
-
-        """
-
-        offset_map = self.getConstant("Offset")
-        input_memory_cells = self.get("dataFormat.memoryCells")
-        if offset_map is None:
-            msg = (
-                "Warning: Did not find offset constant, offset correction "
-                "will not be applied"
-            )
-            self.set("status", msg)
-            self.log.WARN(msg)
-            self._offset_map = None
-        elif len(offset_map.shape) not in (3, 4):
-            msg = (
-                f"Offset map had unexpected shape {offset_map.shape}, "
-                "offset correction will not be applied"
-            )
-            self.set("status", msg)
-            self.log.WARN(msg)
-        else:
-            self.log.INFO(f"Offset map loaded has shape {offset_map.shape}")
-            if len(offset_map.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
-                offset_map = offset_map[..., 0]
-            constant_memory_cells = offset_map.shape[-1]
-            if input_memory_cells > constant_memory_cells:
-                msg = (
-                    f"Warning: Memory cells in input {input_memory_cells} > "
-                    f"memory cells in constant {constant_memory_cells}, some "
-                    "frames may not get correction applied."
-                )
-                self.set("status", msg)
-                self.log.WARN(msg)
-            self._offset_map = offset_map.astype(np.float32)
-            msg = f"Offset map with shape {self._offset_map.shape} ready to load to GPU"
-            self.set("status", msg)
-            self.log.INFO(msg)
-            if constant_memory_cells != self.get("dataFormat.memoryCellsCorrection"):
-                self.log.INFO("Will first have to update buffers on GPU")
-                self.set("dataFormat.memoryCellsCorrection", constant_memory_cells)
-
-        self._update_maps_on_gpu()
-
     def _update_pulse_filter(self, filter_string):
         """Called whenever the pulse filter changes, typically followed by
         _update_shapes"""
@@ -271,58 +199,10 @@ class DsscCorrection(BaseCorrection):
         assert np.max(new_filter) < self.get("dataFormat.memoryCells")
         self.pulse_filter = new_filter
 
-    def _update_shapes(
-        self, pixels_x, pixels_y, memory_cells, pulse_filter, output_transpose
-    ):
-        """(Re)initialize (GPU) buffers according to expected data shapes"""
-
-        input_data_shape = (memory_cells, 1, pixels_y, pixels_x)
-        # reflect the axis reordering in the expected output shape
-        output_data_shape = utils.shape_after_transpose(
-            input_data_shape, output_transpose
-        )
-        self.set("dataFormat.inputDataShape", list(input_data_shape))
-        self.set("dataFormat.outputDataShape", list(output_data_shape))
-
-        if self._shmem_buffer is None:
-            shmem_buffer_name = self.getInstanceId() + ":dataOutput"
-            memory_budget = self.get("outputShmemBufferSize") * 2 ** 30
-            self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}")
-            self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
-                memory_budget,
-                output_data_shape,
-                self.output_data_dtype,
-                shmem_buffer_name,
-            )
-        else:
-            self._shmem_buffer.change_shape(output_data_shape)
-
-        self.gpu_runner = DsscGpuRunner(
-            pixels_x,
-            pixels_y,
-            memory_cells,
-            output_transpose=output_transpose,
-            input_data_dtype=self.input_data_dtype,
-            output_data_dtype=self.output_data_dtype,
-        )
-
-        self._update_maps_on_gpu()
-
-    def _update_maps_on_gpu(self):
-        """Updates the correction maps stored on GPU based on constants known
-
-        This only does something useful if constants have been retrieved from
-        CalCat.  Should be called automatically upon retrieval and after
-        changing the data shape.
-
-        """
-
-        self.set("status", "Updating constants on GPU using known constants")
-        self.updateState(State.CHANGING)
-        if self._offset_map is not None:
-            self.gpu_runner.load_constants(self._offset_map)
-            msg = "Done transferring known constant(s) to GPU"
-            self.log.INFO(msg)
-            self.set("status", msg)
-
-        self.updateState(State.ON)
+    def _load_constant_to_gpu(self, constant_name, constant_data):
+        assert constant_name == "Offset"
+        self.gpu_runner.load_offset_map(constant_data)
+        if not self.get("corrections.available.offset"):
+            self.set("corrections.available.offset", True)
+            self.set("corrections.enabled.offset", True)
+            self.set("corrections.preview.offset", True)
diff --git a/src/calng/agipd_gpu.py b/src/calng/agipd_gpu.py
index c445ae6a..56a36ac1 100644
--- a/src/calng/agipd_gpu.py
+++ b/src/calng/agipd_gpu.py
@@ -7,6 +7,7 @@ from . import base_gpu, utils
 
 
 class CorrectionFlags(enum.IntFlag):
+    NONE = 0
     THRESHOLD = 1
     OFFSET = 2
     BLSHIFT = 4
@@ -23,8 +24,8 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
         pixels_x,
         pixels_y,
         memory_cells,
+        constant_memory_cells,
         output_transpose=(1, 2, 0),  # default: memorycells-fast
-        constant_memory_cells=None,
         input_data_dtype=np.uint16,
         output_data_dtype=np.float32,
     ):
@@ -34,8 +35,8 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
             pixels_x,
             pixels_y,
             memory_cells,
-            output_transpose,
             constant_memory_cells,
+            output_transpose,
             input_data_dtype,
             output_data_dtype,
         )
@@ -84,10 +85,14 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
         pc_med_m = slopes_pc_map[3]
         pc_med_I = slopes_pc_map[4]
         frac_high_med = pc_high_m / pc_med_m
+        # TODO: handle NaN somehow?
+        if np.isnan(frac_high_med).any():
+            ...
         # TODO: verify formula
         md_additional_offset = (pc_high_I - pc_med_I * frac_high_med).astype(np.float32)
         rel_gain_map = np.ones(
-            (3, self.constant_memory_cells, self.pixels_y, self.pixels_x), dtype=np.float32
+            (3, self.constant_memory_cells, self.pixels_y, self.pixels_x),
+            dtype=np.float32,
         )
         rel_gain_map[1] = rel_gain_map[0] * frac_high_med
         rel_gain_map[2] = rel_gain_map[1] * 4.48
@@ -100,22 +105,9 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
         if slopes_ff_map.shape[2] == 2:
             # old format, is per pixel only
             raise NotImplementedError("Old slopes FF map format")
-        self.rel_gain_xray_map_gpu.set(
-            np.transpose(slopes_ff_map).astype(np.float32)
-        )
+        self.rel_gain_xray_map_gpu.set(np.transpose(slopes_ff_map).astype(np.float32))
 
     def correct(self, flags):
-        """Apply corrections to data (must load constant, data, and cell_table first)
-
-        Applies corrections to input data and casts to desired output dtype.
-        Parameter cell_table allows out of order or non-contiguous memory cells
-        in input data.  Both input ndarrays are assumed to be on GPU already,
-        preferably wrapped in GPU arrays (cupy array).
-
-        Will return string encoded handle to shared memory output buffer and
-        (view of) said buffer as an ndarray.  Keep in mind that the output
-        buffers will get overwritten eventually (circular buffer).
-        """
         if flags & CorrectionFlags.BLSHIFT:
             raise NotImplementedError("Baseline shift not implemented yet")
         self.correction_kernel(
@@ -150,4 +142,3 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
         )
         self.source_module = cupy.RawModule(code=kernel_source)
         self.correction_kernel = self.source_module.get_function("correct")
-        self.casting_kernel = self.source_module.get_function("only_cast")
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index dfdc83e6..6bdc4d37 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -31,11 +31,11 @@ from . import shmem_utils, utils
 
 
 class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
+    _correction_flag_class = None  # subclass must override this with some enum class
     _dict_cache_slots = {
-        "applyCorrection",
         "doAnything",
         "dataFormat.memoryCells",
-        "dataFormat.memoryCellsCorrection",
+        "dataFormat.constantMemoryCells",
         "dataFormat.pixelsX",
         "dataFormat.pixelsY",
         "preview.enable",
@@ -46,6 +46,17 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         "state",
     }
 
+    def _load_constant_to_gpu(self, constant_name, constant_data):
+        raise NotImplementedError()  # subclass hook: push named constant to GPU
+
+    @property
+    def input_data_shape(self):
+        raise NotImplementedError()  # subclass must report expected input shape
+
+    @property
+    def output_data_shape(self):
+        raise NotImplementedError()  # subclass must report expected output shape
+
     @staticmethod
     def expectedParameters(expected):
         (
@@ -61,20 +72,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .defaultValue(True)
             .reconfigurable()
             .commit(),
-            BOOL_ELEMENT(expected)
-            .key("applyCorrection")
-            .displayedName("Enable correction(s)")
-            .description(
-                "Toggle whether not correction(s) are applied to image data. If "
-                "false, this device still reshapes data to output shape, applies the "
-                "pulse filter, and casts to output dtype. Useful if constants are "
-                "missing / bad, or if data is sent to application doing its own "
-                "correction."
-            )
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
-            .commit(),
             INPUT_CHANNEL(expected).key("dataInput").commit(),
             # note: output schema not set, will be updated to match data later
             OUTPUT_CHANNEL(expected).key("dataOutput").commit(),
@@ -169,15 +166,18 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .defaultValue("pixels-fast")
             .commit(),
             UINT32_ELEMENT(expected)
-            .key("dataFormat.memoryCellsCorrection")
-            .displayedName("(Debug) Memory cells in correction map")
+            .key("dataFormat.constantMemoryCells")
+            .displayedName("Memory cells in correction map")
             .description(
-                "Full number of memory cells in currently loaded correction map. "
+                "Number of memory cells in loaded or expected constants. "
                 "May exceed memory cell number in input if veto is on. "
-                "This value just displayed for debugging."
+                "This value should be updated (will be done by manager) before "
+                "requesting constants with different number of cells. "
+                "Will in future versions get better integrated into constant loading."
             )
-            .readOnly()
-            .initialValue(0)
+            .assignmentOptional()
+            .defaultValue(0)
+            .reconfigurable()
             .commit(),
             VECTOR_UINT32_ELEMENT(expected)
             .key("dataFormat.inputDataShape")
@@ -349,6 +349,48 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .commit(),
         )
 
+        (
+            NODE_ELEMENT(expected)
+            .key("corrections")
+            .displayedName("Corrections to apply")
+            .commit(),
+            NODE_ELEMENT(expected)
+            .key("corrections.available")
+            .displayedName("Available given loaded constants")
+            .description(
+                "Corrections typically require some constants to be loaded. These "
+                "flags indicate which corrections can be done, given the constants "
+                "found and loaded so far. Corrections - for dataOutput or preview - "
+                "only happen if selected AND available."
+            )
+            .commit(),
+            NODE_ELEMENT(expected)
+            .key("corrections.enabled")
+            .displayedName("Enabled (if available)")
+            .description("Corrections applied to data for dataOutput (main output)")
+            .commit(),
+            NODE_ELEMENT(expected)
+            .key("corrections.preview")
+            .displayedName("Preview (if available)")
+            .description("Corrections applied for corrected preview output")
+            .commit(),
+            BOOL_ELEMENT(expected)
+            .key("corrections.disableAll")
+            .displayedName("Disable corrections for dataOutput")
+            .description(
+                "Disable all correction(s) applied to image data for dataOutput. If "
+                "true, this device still reshapes data to output shape, applies the "
+                "pulse filter, and casts to output dtype. Useful if constants are "
+                "missing / bad, or if data is sent to application doing its own "
+                "correction.  Preview is still corrected based on selection of "
+                "corrections independently of this."
+            )
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
+        )
+
     def __init__(self, config):
         self._dict_cache = {k: config.get(k) for k in self._dict_cache_slots}
         super().__init__(config)
@@ -361,8 +403,13 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype"))
         self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype"))
 
+        self._correction_flag_enabled = self._correction_flag_class.NONE
+        self._correction_flag_preview = self._correction_flag_class.NONE
+        self._cached_constants = {}
+
         self._shmem_buffer = None
         self._has_set_output_schema = False
+        self._has_updated_shapes = False
         self._rate_tracker = calibrationBase.utils.UpdateRate(
             interval=config.get("performance.rateBufferSpan")
         )
@@ -401,6 +448,26 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             if path in self._dict_cache_slots:
                 self._dict_cache[path] = config.get(path)
 
+        # TODO: pulse filter (after reimplementing)
+        if any(
+            config.has(shape_param)
+            for shape_param in (
+                "dataFormat.pixelsX",
+                "dataFormat.pixelsY",
+                "dataFormat.memoryCells",
+                "dataFormat.constantMemoryCells",
+            )
+        ):
+            # will make postReconfigure handle shape update after merging schema
+            self._has_updated_shapes = False
+
+    def postReconfigure(self):
+        self.log.INFO("postReconfigure")
+        if not self._has_updated_shapes:  # cleared when a shape parameter changed
+            self._update_shapes()
+        # TODO: only call this if correction flags are changed (is cheap, though)
+        self._update_correction_flags()
+
     def get(self, key):
         if key in self._dict_cache_slots:
             return self._dict_cache.get(key)
@@ -414,6 +481,23 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
                 self._dict_cache[key] = value
         super().set(*args)
 
+    def requestConstant(self, name, mostRecent=False, tryRemote=True):
+        """Request constant, then cache it and load it to GPU (if it was found). Hooks
+        in here as the constantLoaded hook is called without naming the constant."""
+        # TODO: clear from device, too
+        # TODO: update correction capability flag
+        if name in self._cached_constants:
+            del self._cached_constants[name]
+        super().requestConstant(name, mostRecent, tryRemote)
+        constant = self.getConstant(name)
+        if constant is None:  # no constant obtained, nothing to load
+            return
+
+        # TODO: remaining constants, DRY
+        self._cached_constants[name] = constant
+        self._load_constant_to_gpu(name, constant)
+        self._update_correction_flags()
+
     def _write_output(self, data, old_metadata):
         metadata = ChannelMetaData(
             old_metadata.get("source"),
@@ -452,6 +536,25 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             channel.write(preview_hash, metadata, False)
             channel.update()
 
+    def _update_correction_flags(self):  # rebuild flag masks from corrections.*
+        available = self._correction_flag_class.NONE
+        enabled = self._correction_flag_class.NONE
+        preview = self._correction_flag_class.NONE
+        for slot_name, flag in self._correction_slot_names:
+            if self.get(f"corrections.available.{slot_name}"):
+                available |= flag
+            if self.get(f"corrections.enabled.{slot_name}"):
+                enabled |= flag
+            if self.get(f"corrections.preview.{slot_name}"):
+                preview |= flag
+        enabled &= available  # selected AND available, see corrections.available
+        preview &= available
+        if self.get("corrections.disableAll"):
+            enabled &= self._correction_flag_class.NONE  # clears all (not preview)
+        self._correction_flag_enabled = enabled
+        self._correction_flag_preview = preview
+        self.log.INFO(str(enabled))
+
     def _update_output_schema(self, data):
         """Updates the schema of dataOutput based on parameter data (a Hash)
 
@@ -469,9 +572,45 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .dataSchema(data_schema)
             .commit()
         )
-        self.updateSchema(my_schema_update)
+        self.appendSchema(my_schema_update)
         self._has_set_output_schema = True
 
+    def _update_shapes(self):
+        """(Re)initialize shmem buffer and GPU runner for the expected data shapes"""
+        self.log.INFO("Updating shapes")
+
+        # publish expected shapes (output shape reflects the axis reordering)
+        self.set("dataFormat.inputDataShape", list(self.input_data_shape))
+        self.set("dataFormat.outputDataShape", list(self.output_data_shape))
+
+        if self._shmem_buffer is None:
+            shmem_buffer_name = self.getInstanceId() + ":dataOutput"
+            memory_budget = self.get("outputShmemBufferSize") * 2 ** 30
+            self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}")
+            self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
+                memory_budget,
+                self.output_data_shape,
+                self.output_data_dtype,
+                shmem_buffer_name,
+            )
+        else:
+            self._shmem_buffer.change_shape(self.output_data_shape)
+
+        self.gpu_runner = self._gpu_runner_class(  # fresh runner for new shapes
+            self.get("dataFormat.pixelsX"),
+            self.get("dataFormat.pixelsY"),
+            self.get("dataFormat.memoryCells"),
+            self.get("dataFormat.constantMemoryCells"),
+            output_transpose=self._output_transpose,
+            input_data_dtype=self.input_data_dtype,
+            output_data_dtype=self.output_data_dtype,
+        )
+
+        for constant_name, constant_data in self._cached_constants.items():
+            self._load_constant_to_gpu(constant_name, constant_data)  # onto new runner
+
+        self._has_updated_shapes = True  # checked in postReconfigure
+
     def _reset_state_from_processing(self):
         if self.get("state") is State.PROCESSING:
             self.updateState(State.ON)
diff --git a/src/calng/base_gpu.py b/src/calng/base_gpu.py
index 48604349..5347fc2a 100644
--- a/src/calng/base_gpu.py
+++ b/src/calng/base_gpu.py
@@ -9,13 +9,14 @@ from . import utils
 
 
 class BaseGpuRunner:
-    """Class to handle instantiation and execution of CUDA kernels on trains
+    """Class to handle GPU buffers and execution of CUDA kernels on image data
 
-    All GPU buffers are kept within this class. This generally means that you will
-    want to load data into it and then do something. Typical usage in correct order:
+    All GPU buffers are kept within this class and it is intentionally very stateful.
+    This generally means that you will want to load data into it and then do something.
+    Typical usage in correct order:
 
     1. instantiate
-    2. load_constants
+    2. load constants
     3. load_data
     4. load_cell_table
     5. correct
@@ -24,9 +25,9 @@ class BaseGpuRunner:
 
     repeat from 2. or 3.
 
-    In case no constants are available / correction is not desired, can skip 3 and 4
-    and use only_cast in step 5 instead of correct (taking care to call
-    compute_preview with parameters set accordingly).
+    In case no constants are available / correction is not desired, can skip 3 and 4 and
+    pass CorrectionFlags.NONE to correct(...). Generally, user must handle which
+    correction steps are appropriate given the constants loaded so far.
     """
 
     def __init__(
@@ -34,8 +35,8 @@ class BaseGpuRunner:
         pixels_x,
         pixels_y,
         memory_cells,
+        constant_memory_cells,
         output_transpose=(2, 1, 0),  # default: memorycells-fast
-        constant_memory_cells=None,
         input_data_dtype=np.uint16,
         output_data_dtype=np.float32,
     ):
@@ -48,7 +49,8 @@ class BaseGpuRunner:
         self.pixels_y = pixels_y
         self.memory_cells = memory_cells
         self.output_transpose = output_transpose
-        if constant_memory_cells is None:
+        if constant_memory_cells == 0:
+            # if not set, guess same as input; may save one recompilation
             self.constant_memory_cells = memory_cells
         else:
             self.constant_memory_cells = constant_memory_cells
@@ -62,10 +64,12 @@ class BaseGpuRunner:
 
         self._init_kernels()
 
-        # reuse output arrays
+        # reuse buffers for input / output
         self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=np.uint16)
         self.input_data_gpu = cupy.empty(self.input_shape, dtype=input_data_dtype)
-        self.processed_data_gpu = cupy.empty(self.processed_shape, dtype=output_data_dtype)
+        self.processed_data_gpu = cupy.empty(
+            self.processed_shape, dtype=output_data_dtype
+        )
         self.reshaped_data_gpu = cupy.empty(self.output_shape, dtype=output_data_dtype)
         self.preview_raw = cupyx.empty_pinned(self.preview_shape, dtype=np.float32)
         self.preview_corrected = cupyx.empty_pinned(
@@ -73,25 +77,32 @@ class BaseGpuRunner:
         )
 
     # functions to get data from respective buffers to cell, x, y shape for preview computation
+    # TODO: handle shape juggling programmatically, removing need for these two helpers
     def _preview_preprocess_raw():
+        """Should return view of self.input_data_gpu with shape (cell, x, y)"""
         raise NotImplementedError()
 
     def _preview_preprocess_corr():
+        """Should return view of self.processed_data_gpu with shape (cell, x, y)"""
         raise NotImplementedError()
 
-    def only_cast(self):
-        """Like correct without the correction
+    def correct(self, flags):
+        """Correct (already loaded) image data according to flags
+
+        Subclass must define this method. It should assume that image data, cell table,
+        and other data (including constants) has already been loaded. It should
+        probably run some GPU kernel and output should go into self.processed_data_gpu.
+
+        Keep in mind that user only gets output from compute_preview or reshape
+        (either of these should come after correct).
+
+        The submodules providing subclasses should have some IntFlag enums defining
+        which flags are available to pass along to the kernel. A zero flag should allow
+        the kernel to do no actual correction - but still copy the data between buffers
+        and cast it to desired output type.
 
-        This currently means just casting to output dtype.
         """
-        self.casting_kernel(
-            self.full_grid,
-            self.full_block,
-            (
-                self.input_data_gpu,
-                self.processed_data_gpu,
-            ),
-        )
+        raise NotImplementedError()
 
     def reshape(self, out=None):
         """Move axes to desired output order
@@ -119,19 +130,18 @@ class BaseGpuRunner:
     def load_cell_table(self, cell_table):
         self.cell_table_gpu.set(cell_table)
 
-    def compute_preview(self, preview_index, have_corrected=True, can_correct=True):
+    def compute_preview(self, preview_index):
         """Generate single slice or reduction preview of raw and corrected data
 
-        Special values of preview_index are -1 for max, -2 for mean, -3 for
-        sum, and -4 for stdev (across cells).
+        Special values of preview_index are -1 for max (select max integrated intensity
+        frame), -2 for mean, -3 for sum, and -4 for stdev (across cells).
 
         Note that preview_index is taken from data without checking cell table.
         Caller has to figure out which index along memory cell dimension they
-        actually want to preview.
+        actually want to preview in case it needs to be a specific pulse.
 
-        Can reuse data from corrected output buffer with have_corrected parameter.
-        Note that preview requires relevant data to be loaded (raw data for raw
-        preview, correction map and cell table in addition for corrected preview).
+        Will reuse data from corrected output buffer. Therefore, correct(...) must have
+        been called with the appropriate flags before compute_preview(...).
         """
 
         if preview_index < -4:
@@ -139,10 +149,6 @@ class BaseGpuRunner:
         elif preview_index >= self.memory_cells:
             raise ValueError(f"Memory cell index {preview_index} out of range")
 
-        if (not have_corrected) and can_correct:
-            self.correct()
-            # if not have_corrected and not can_correct, assume only_cast already done
-
         # TODO: enum around reduction type
         for (preprocces, output_buffer) in (
             (self._preview_preprocess_raw, self.preview_raw),
@@ -151,23 +157,17 @@ class BaseGpuRunner:
             image_data = preprocces()
             if preview_index >= 0:
                 # TODO: change axis order when moving reshape to after correction
-                image_data[preview_index].astype(np.float32).get(
-                    out=output_buffer
-                )
+                image_data[preview_index].astype(np.float32).get(out=output_buffer)
             elif preview_index == -1:
                 # TODO: select argmax independently for raw and corrected?
                 # TODO: send frame sums somewhere to compute global max frame
                 max_index = cupy.argmax(
                     cupy.sum(image_data, axis=(1, 2), dtype=cupy.float32)
                 )
-                image_data[max_index].astype(np.float32).get(
-                    out=output_buffer
-                )
+                image_data[max_index].astype(np.float32).get(out=output_buffer)
             elif preview_index in (-2, -3, -4):
                 stat_fun = {-2: cupy.mean, -3: cupy.sum, -4: cupy.std}[preview_index]
-                stat_fun(image_data, axis=0, dtype=cupy.float32).get(
-                    out=output_buffer
-                )
+                stat_fun(image_data, axis=0, dtype=cupy.float32).get(out=output_buffer)
         return self.preview_raw, self.preview_corrected
 
     def update_block_size(self, full_block, target_shape=None):
diff --git a/src/calng/dssc_gpu.py b/src/calng/dssc_gpu.py
index cd7e2d80..8c1995e3 100644
--- a/src/calng/dssc_gpu.py
+++ b/src/calng/dssc_gpu.py
@@ -1,9 +1,16 @@
+import enum
+
 import cupy
 import numpy as np
 
 from . import base_gpu, utils
 
 
+class CorrectionFlags(enum.IntFlag):
+    NONE = 0  # no correction step selected
+    OFFSET = 1  # subtract the offset map in the correct kernel
+
+
 class DsscGpuRunner(base_gpu.BaseGpuRunner):
     _kernel_source_filename = "dssc_gpu_kernels.cpp"
 
@@ -12,8 +19,8 @@ class DsscGpuRunner(base_gpu.BaseGpuRunner):
         pixels_x,
         pixels_y,
         memory_cells,
+        constant_memory_cells,
         output_transpose=(2, 1, 0),  # default: memorycells-fast
-        constant_memory_cells=None,
         input_data_dtype=np.uint16,
         output_data_dtype=np.float32,
     ):
@@ -23,12 +30,13 @@ class DsscGpuRunner(base_gpu.BaseGpuRunner):
             pixels_x,
             pixels_y,
             memory_cells,
-            output_transpose,
             constant_memory_cells,
+            output_transpose,
             input_data_dtype,
             output_data_dtype,
         )
-        self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
+
+        self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x)
         self.offset_map_gpu = cupy.empty(self.map_shape, dtype=np.float32)
 
         self._init_kernels()
@@ -43,33 +51,22 @@ class DsscGpuRunner(base_gpu.BaseGpuRunner):
     def _preview_preprocess_corr(self):
         return cupy.transpose(self.processed_data_gpu, (0, 2, 1))
 
-    def load_constants(self, offset_map):
-        constant_memory_cells = offset_map.shape[-1]
-        if constant_memory_cells != self.constant_memory_cells:
-            self.constant_memory_cells = constant_memory_cells
-            self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
-            self.offset_map_gpu = cupy.empty(self.map_shape, dtype=np.float32)
-            self._init_kernels()
+    def load_offset_map(self, offset_map):
+        # can have an extra dimension for some reason
+        if len(offset_map.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
+            offset_map = offset_map[..., 0]
+        # shape (now): x, y, memory cell
+        offset_map = np.transpose(offset_map).astype(np.float32)
         self.offset_map_gpu.set(offset_map)
 
-    def correct(self):
-        """Apply corrections to data (must load constant, data, and cell_table first)
-
-        Applies corrections to input data and casts to desired output dtype.
-        Parameter cell_table allows out of order or non-contiguous memory cells
-        in input data.  Both input ndarrays are assumed to be on GPU already,
-        preferably wrapped in GPU arrays (cupy array).
-
-        Will return string encoded handle to shared memory output buffer and
-        (view of) said buffer as an ndarray.  Keep in mind that the output
-        buffers will get overwritten eventually (circular buffer).
-        """
+    def correct(self, flags):
         self.correction_kernel(
             self.full_grid,
             self.full_block,
             (
                 self.input_data_gpu,
                 self.cell_table_gpu,
+                np.uint8(flags),
                 self.offset_map_gpu,
                 self.processed_data_gpu,
             ),
@@ -84,8 +81,8 @@ class DsscGpuRunner(base_gpu.BaseGpuRunner):
                 "constant_memory_cells": self.constant_memory_cells,
                 "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype),
                 "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype),
+                "corr_enum": utils.enum_to_c_template(CorrectionFlags),
             }
         )
         self.source_module = cupy.RawModule(code=kernel_source)
         self.correction_kernel = self.source_module.get_function("correct")
-        self.casting_kernel = self.source_module.get_function("only_cast")
diff --git a/src/calng/dssc_gpu_kernels.cpp b/src/calng/dssc_gpu_kernels.cpp
index 2412a86a..b82159c0 100644
--- a/src/calng/dssc_gpu_kernels.cpp
+++ b/src/calng/dssc_gpu_kernels.cpp
@@ -1,5 +1,7 @@
 #include <cuda_fp16.h>
 
+{{corr_enum}}
+
 extern "C" {
 	/*
 	  Perform correction: offset
@@ -9,6 +11,7 @@ extern "C" {
 	*/
 	__global__ void correct(const {{input_data_dtype}}* data,
 							const unsigned short* cell_table,
+	                        const unsigned char corr_flags,
 							const float* offset_map,
 							{{output_data_dtype}}* output) {
 		const size_t X = {{pixels_x}};
@@ -31,13 +34,16 @@ extern "C" {
 		const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x;
 		const float raw = (float)data[data_index];
 
-		const size_t map_stride_cell = 1;
-		const size_t map_stride_y = map_memory_cells * map_stride_cell;
-		const size_t map_stride_x = Y * map_stride_y;
+		const size_t map_stride_x = 1;
+		const size_t map_stride_y = X * map_stride_x;
+		const size_t map_stride_cell = Y * map_stride_y;
 		const size_t map_cell = cell_table[memory_cell];
 		if (map_cell < map_memory_cells) {
 			const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x;
-			const float corrected = raw - offset_map[map_index];
+			float corrected = raw;
+			if (corr_flags & OFFSET) {
+				corrected -= offset_map[map_index];
+			}
 			{% if output_data_dtype == "half" %}
 			output[data_index] = __float2half(corrected);
 			{% else %}
@@ -51,34 +57,4 @@ extern "C" {
 			{% endif %}
 		}
 	}
-
-	/*
-	  Same as correction, except don't do any correction
-	*/
-	__global__ void only_cast(const {{input_data_dtype}}* data,
-							  {{output_data_dtype}}* output) {
-		const size_t X = {{pixels_x}};
-		const size_t Y = {{pixels_y}};
-		const size_t memory_cells = {{data_memory_cells}};
-
-		const size_t data_stride_x = 1;
-		const size_t data_stride_y = X * data_stride_x;
-		const size_t data_stride_cell = Y * data_stride_y;
-
-		const size_t cell = blockIdx.x * blockDim.x + threadIdx.x;
-		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
-		const size_t x = blockIdx.z * blockDim.z + threadIdx.z;
-
-		if (cell >= memory_cells || y >= Y || x >= X) {
-			return;
-		}
-
-		const size_t data_index = cell * data_stride_cell + y * data_stride_y + x * data_stride_x;
-		const float raw = (float)data[data_index];
-		{% if output_data_dtype == "half" %}
-		output[data_index] = __float2half(raw);
-		{% else %}
-		output[data_index] = ({{output_data_dtype}})raw;
-		{% endif %}
-	}
 }
-- 
GitLab