diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index adf8866e7d8e28198ad99711d400d23b61ce70b5..74d8c05880b81c155c9e13430460a49caf23cd81 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -35,16 +35,26 @@ class AgipdCorrection(BaseCorrection):
     _gpu_runner_class = AgipdGpuRunner
     _calcat_friend_class = AgipdCalcatFriend
     _constant_enum_class = AgipdConstants
-    _managed_keys = BaseCorrection._managed_keys[:]
+    _managed_keys = BaseCorrection._managed_keys[:] + ["overrideInputAxisOrder"]
 
     # this is just extending (not mandatory)
-    _schema_cache_fields = BaseCorrection._schema_cache_fields | {"sendGainMap"}
+    _schema_cache_fields = BaseCorrection._schema_cache_fields | {
+        "sendGainMap",
+        "overrideInputAxisOrder",
+    }
 
     @staticmethod
     def expectedParameters(expected):
         super(AgipdCorrection, AgipdCorrection).expectedParameters(expected)
         expected.setDefaultValue("dataFormat.memoryCells", 352)
         (
+            BOOL_ELEMENT(expected)
+            .key("overrideInputAxisOrder")
+            .displayedName("Override input axis order")
+            .assignmentOptional()
+            .defaultValue(False)
+            .reconfigurable()
+            .commit(),
             STRING_ELEMENT(expected)
             .key("gainMode")
             .displayedName("Gain mode")
@@ -188,10 +198,10 @@ class AgipdCorrection(BaseCorrection):
     @property
     def input_data_shape(self):
         return (
-            self.get("dataFormat.memoryCells"),
+            self._schema_cache["dataFormat.memoryCells"],
             2,
-            self.get("dataFormat.pixelsX"),
-            self.get("dataFormat.pixelsY"),
+            self._schema_cache["dataFormat.pixelsX"],
+            self._schema_cache["dataFormat.pixelsY"],
         )
 
     def __init__(self, config):
@@ -269,6 +279,14 @@ class AgipdCorrection(BaseCorrection):
             and self._schema_cache["preview.enable"]
         )
 
+        if self._schema_cache["overrideInputAxisOrder"]:
+            expected_shape = self.input_data_shape
+            if expected_shape != image_data.shape:
+                self.log.INFO(
+                    f"Overriding input data order from {image_data.shape} to {expected_shape}"
+                )
+                image_data.shape = expected_shape
+
         with self._buffer_lock:
             # cell_table = cell_table[self.pulse_filter]
             pulse_table = np.squeeze(data.get("image.pulseId"))  # [self.pulse_filter]
@@ -280,7 +298,12 @@ class AgipdCorrection(BaseCorrection):
                     "corrected."
                 )
 
-            self.gpu_runner.load_data(image_data)
+            try:
+                self.gpu_runner.load_data(image_data)
+            except ValueError as e:
+                self.log_status_warn(f"Failed to load data: {e}")
+                return
+
             buffer_handle, buffer_array = self._shmem_buffer.next_slot()
             self.gpu_runner.load_cell_table(cell_table)
             self.gpu_runner.correct(self._correction_flag_enabled)
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 3f3d46860d4212c96d607ae37c3733b59a1d2503..11f68b429a7904dcc1e4797e6db1ed03cc853a54 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -57,8 +57,8 @@ class BaseCorrection(PythonDevice):
     ]  # subclass must extend this and put it in schema
     _schema_cache_fields = {
         "doAnything",
-        "dataFormat.memoryCells",
         "constantParameters.memoryCells",
+        "dataFormat.memoryCells",
         "dataFormat.pixelsX",
         "dataFormat.pixelsY",
         "dataFormat.outputAxisOrder",