From ae339b47fb088715b05f8e688af9f3becdee8481 Mon Sep 17 00:00:00 2001
From: David Hammer <david.hammer@xfel.eu>
Date: Fri, 2 Aug 2024 06:52:28 +0200
Subject: [PATCH] AGIPD: handle bad pixel constants with three dimensions with
 more cells than requested

---
 src/calng/base_correction.py             |  2 +-
 src/calng/corrections/AgipdCorrection.py | 18 +++++++++---------
 src/calng/corrections/DsscCorrection.py  |  2 +-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index a11f9b96..caa4ff6d 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -569,7 +569,7 @@ class BaseCorrection(PythonDevice):
             or self._enabled_addons
         ):
             schema_override = Schema()
-            output_schema_override = self._base_output_schema(
+            output_schema_override = self.__class__._base_output_schema(
                 use_shmem_handles=self.get("useShmemHandles")
             )
             for addon in self._enabled_addons:
diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py
index 6632b6ef..b6e60d9c 100644
--- a/src/calng/corrections/AgipdCorrection.py
+++ b/src/calng/corrections/AgipdCorrection.py
@@ -469,24 +469,26 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner):
             # will simply OR with already loaded, does not take into account which ones
             constant_data = self._xp.asarray(constant_data, dtype=np.uint32)
             if len(constant_data.shape) == 3:
-                if constant_data.shape == (
+                if constant_data.shape[:2] == (
                     self.num_pixels_fs,
                     self.num_pixels_ss,
-                    self._constant_memory_cells,
                 ):
                     # BadPixelsFF is not per gain stage - broadcasting along gain
                     constant_data = self._xp.broadcast_to(
-                        constant_data.transpose()[..., np.newaxis],
+                        constant_data.transpose()[
+                            :self._constant_memory_cells, ..., np.newaxis
+                        ],
                         self._gm_map_shape,
                     )
-                elif constant_data.shape == (
-                    self._constant_memory_cells,
+                elif constant_data.shape[1:] == (
                     self.num_pixels_fs,
                     self.num_pixels_ss,
                 ):
                     # old BadPixelsPC have different axis order
                     constant_data = self._xp.broadcast_to(
-                        constant_data.transpose((0, 2, 1))[..., np.newaxis],
+                        constant_data.transpose((0, 2, 1))[
+                            :self._constant_memory_cells, ..., np.newaxis
+                        ],
                         self._gm_map_shape,
                     )
                 else:
@@ -756,9 +758,7 @@ class AgipdCorrection(base_correction.BaseCorrection):
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
             .dataSchema(
-                AgipdCorrection._base_output_schema(
-                    use_shmem_handles=cls._use_shmem_handles
-                )
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
             )
             .commit(),
         )
diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py
index 3611b267..b4a65c76 100644
--- a/src/calng/corrections/DsscCorrection.py
+++ b/src/calng/corrections/DsscCorrection.py
@@ -179,7 +179,7 @@ class DsscCorrection(base_correction.BaseCorrection):
             OUTPUT_CHANNEL(expected)
             .key("dataOutput")
             .dataSchema(
-                DsscCorrection._base_output_schema(cls._use_shmem_handles)
+                cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles)
             )
             .commit(),
         )
-- 
GitLab