From 3d9eebd1acc8cc7f1ebdf13a886a4eafa6d12eb4 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Thu, 1 Aug 2024 14:03:25 +0200
Subject: [PATCH] Use workarounds.overrideInputAxisOrder in slow detectors

---
 src/calng/corrections/Epix100Correction.py | 13 +++++++------
 src/calng/corrections/PnccdCorrection.py   |  4 ++++
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/calng/corrections/Epix100Correction.py b/src/calng/corrections/Epix100Correction.py
index b541a72a..a469c0ee 100644
--- a/src/calng/corrections/Epix100Correction.py
+++ b/src/calng/corrections/Epix100Correction.py
@@ -213,12 +213,9 @@ class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner):
         if config.has("corrections.commonMode.enableBlock"):
             self._cm_block = config["corrections.commonMode.enableBlock"]
 
-    def expected_input_data_shape(self, num_frames):
-        assert num_frames == 1
-        return (
-            self.num_pixels_ss,
-            self.num_pixels_fs,
-        )
+    def expected_input_shape(self, num_frames):
+        assert num_frames == 1, "ePix not expected to have multiple frames"
+        return (self.num_pixels_ss, self.num_pixels_fs)
 
     def _expected_output_shape(self, num_frames):
         return (self.num_pixels_ss, self.num_pixels_fs)
@@ -393,6 +390,10 @@ class Epix100Correction(base_correction.BaseCorrection):
 
     def _get_data_from_hash(self, data_hash):
         image_data = data_hash.get(self._image_data_path)
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            expected_shape = self.kernel_runner.expected_input_shape(1)
+            if expected_shape != image_data.shape:
+                image_data.shape = expected_shape
         return (
             1,
             image_data,
diff --git a/src/calng/corrections/PnccdCorrection.py b/src/calng/corrections/PnccdCorrection.py
index 0156e200..12353342 100644
--- a/src/calng/corrections/PnccdCorrection.py
+++ b/src/calng/corrections/PnccdCorrection.py
@@ -390,6 +390,10 @@ class PnccdCorrection(base_correction.BaseCorrection):
 
     def _get_data_from_hash(self, data_hash):
         image_data = data_hash.get(self._image_data_path)
+        if self.unsafe_get("workarounds.overrideInputAxisOrder"):
+            expected_shape = self.kernel_runner.expected_input_shape(1)
+            if expected_shape != image_data.shape:
+                image_data.shape = expected_shape
         return (
             1,
             image_data,
-- 
GitLab