From ffec3d043dcd4c8c0677a96d944dafe7e98c2aa7 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 1 Sep 2021 11:14:28 +0200
Subject: [PATCH] Add configurability to order of output axes

---
 src/calng/CalibrationManager.py | 11 +++++++++
 src/calng/DsscCorrection.py     | 33 +++++++++++++++++++++++---
 src/calng/dssc_gpu.py           | 41 ++++++++++++++++++++++-----------
 src/utils.py                    |  8 +++++++
 4 files changed, 76 insertions(+), 17 deletions(-)

diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py
index 84996e46..22332af9 100644
--- a/src/calng/CalibrationManager.py
+++ b/src/calng/CalibrationManager.py
@@ -424,6 +424,16 @@ class CalibrationManager(DeviceClientBase, Device):
         self.state = State.CHANGING
         background(self._instantiate_pipeline())
 
+    outputAxisOrder = String(
+        displayedName='Output axis order',
+        options=('pixels-fast','memorycells-fast','no-reshape'),
+        defaultValue='pixels-fast',
+        description='Axes of main data output can be reordered after correction. '
+                    'Choose between "pixels-fast" (memory_cell, x, y), '
+                    '"memorycells-fast" (x, y, memory_cell), and "no-reshape" '
+                    '(memory_cell, y, x)'
+    )
+
     # TODO: Inject at runtime by scanning correction device schema.
     runtimeParameters = Node(
         RuntimeParametersNode,
@@ -881,6 +891,7 @@ class CalibrationManager(DeviceClientBase, Device):
             config['dataInput.connectedOutputChannels'] = [input_channel]
             config['fastSources'] = [input_source]
             config['dataFormat.outputImageDtype'] = 'float16'
+            config['dataFormat.outputAxisOrder'] = self.outputAxisOrder
             config['dataFormat.pixelsX'] = 512
             config['dataFormat.pixelsY'] = 128
             config['dataFormat.memoryCells'] = 400
diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py
index 995b8689..6468a769 100644
--- a/src/calng/DsscCorrection.py
+++ b/src/calng/DsscCorrection.py
@@ -179,6 +179,18 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             .description("Full number of memory cells in incoming data")
             .assignmentMandatory()
             .commit(),
+            STRING_ELEMENT(expected)
+            .key("dataFormat.outputAxisOrder")
+            .displayedName("Output axis order")
+            .description(
+                "Axes of main data output can be reordered after correction. Choose "
+                "between 'pixels-fast' (memory_cell, x, y), 'memorycells-fast' "
+                "(x, y, memory_cell), and 'no-reshape' (memory_cell, y, x)"
+            )
+            .options("pixels-fast,memorycells-fast,no-reshape")
+            .assignmentOptional()
+            .defaultValue("pixels-fast")
+            .commit(),
             UINT32_ELEMENT(expected)
             .key("dataFormat.memoryCellsCorrection")
             .displayedName("(Debug) Memory cells in correction map")
@@ -374,6 +386,13 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice):
 
         self.input_data_dtype = getattr(np, config.get("dataFormat.inputImageDtype"))
         self.output_data_dtype = getattr(np, config.get("dataFormat.outputImageDtype"))
+        output_axis_order = config.get("dataFormat.outputAxisOrder")
+        if output_axis_order == "pixels-fast":
+            self._output_transpose = (0, 2, 1)
+        elif output_axis_order == "memorycells-fast":
+            self._output_transpose = (2, 1, 0)
+        else:
+            self._output_transpose = None
         self._offset_map = None
         self._update_pulse_filter(config.get("pulseFilter"))
         self._shmem_buffer = None
@@ -382,6 +401,7 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             config.get("dataFormat.pixelsY"),
             config.get("dataFormat.memoryCells"),
             self.pulse_filter,
+            self._output_transpose,
         )
         self._has_set_output_schema = False
         self._has_set_preview_output_schema = False
@@ -501,6 +521,7 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice):
                     self.get("dataFormat.pixelsY"),
                     self.get("dataFormat.memoryCells"),
                     self.pulse_filter,
+                    self._output_transpose,
                 )
         # TODO: check shape (DAQ fake data and RunToPipe don't agree)
         # TODO: consider just updating shapes based on whatever comes in
@@ -768,16 +789,21 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         assert np.max(new_filter) < self.get("dataFormat.memoryCells")
         self.pulse_filter = new_filter
 
-    def _update_shapes(self, pixels_x, pixels_y, memory_cells, pulse_filter):
+    def _update_shapes(
+        self, pixels_x, pixels_y, memory_cells, pulse_filter, output_transpose
+    ):
         """(Re)initialize (GPU) buffers according to expected data shapes"""
 
         input_data_shape = (memory_cells, 1, pixels_y, pixels_x)
-        output_data_shape = (pixels_x, pixels_y, pulse_filter.size)
+        # reflect the axis reordering in the expected output shape
+        output_data_shape = utils.shape_after_transpose(
+            input_data_shape, output_transpose
+        )
         self.set("dataFormat.inputDataShape", list(input_data_shape))
         self.set("dataFormat.outputDataShape", list(output_data_shape))
 
         if self._shmem_buffer is None:
-            shmem_buffer_name = self.getInstanceId() + f":dataOutput"
+            shmem_buffer_name = self.getInstanceId() + ":dataOutput"
             memory_budget = self.get("outputShmemBufferSize") * 2 ** 30
             self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}")
             self._shmem_buffer = shmem_utils.ShmemCircularBuffer(
@@ -793,6 +819,7 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             pixels_x,
             pixels_y,
             memory_cells,
+            output_transpose=output_transpose,
             input_data_dtype=self.input_data_dtype,
             output_data_dtype=self.output_data_dtype,
         )
diff --git a/src/calng/dssc_gpu.py b/src/calng/dssc_gpu.py
index 7dacb363..53f82f87 100644
--- a/src/calng/dssc_gpu.py
+++ b/src/calng/dssc_gpu.py
@@ -38,18 +38,26 @@ class DsscGpuRunner:
         pixels_x,
         pixels_y,
         memory_cells,
+        output_transpose=(2, 1, 0),  # default: memorycells-fast
+        constant_memory_cells=None,
         input_data_dtype=np.uint16,
         output_data_dtype=np.float32,
     ):
         self.pixels_x = pixels_x
         self.pixels_y = pixels_y
         self.memory_cells = memory_cells
-        self.constant_memory_cells = 0
+        self.output_transpose = output_transpose
+        if constant_memory_cells is None:
+            self.constant_memory_cells = memory_cells
+        else:
+            self.constant_memory_cells = constant_memory_cells
         self.input_shape = (self.memory_cells, self.pixels_y, self.pixels_x)
-        self.output_shape = (self.pixels_x, self.pixels_y, self.memory_cells)
+        self.output_shape = utils.shape_after_transpose(
+            self.input_shape, self.output_transpose
+        )
         self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
         # preview will only be single memory cell
-        self.preview_shape = self.output_shape[:-1]
+        self.preview_shape = (self.pixels_x, self.pixels_y)
         self.input_data_dtype = input_data_dtype
         self.output_data_dtype = output_data_dtype
 
@@ -140,18 +148,23 @@ class DsscGpuRunner:
         )
 
     def reshape(self, out=None):
-        # TODO: make order configurable
         """Move axes to desired output order
 
-        equivalent to:
-        output_data[:] = np.moveaxis(
-            np.squeeze(input_data), (0, 1, 2), (2, 1, 0)
-        )
+        The out parameter is passed directly to the get function of GPU array: if
+        None, then a new ndarray (in host memory) is returned. If not None, then data
+        will be loaded into the provided array, which must match shape / dtype.
         """
         # TODO: avoid copy
-        self.reshaped_data_gpu = cupy.ascontiguousarray(
-            cupy.transpose(cupy.squeeze(self.processed_data_gpu))
-        )
+        if self.output_transpose is None:
+            self.reshaped_data_gpu = cupy.ascontiguousarray(
+                cupy.squeeze(self.processed_data_gpu)
+            )
+        else:
+            self.reshaped_data_gpu = cupy.ascontiguousarray(
+                cupy.transpose(
+                    cupy.squeeze(self.processed_data_gpu), self.output_transpose
+                )
+            )
         return self.reshaped_data_gpu.get(out=out)
 
     def compute_preview(self, preview_index, have_corrected=True, can_correct=True):
@@ -202,9 +215,9 @@ class DsscGpuRunner:
                 )
             elif preview_index in (-2, -3, -4):
                 stat_fun = {-2: cupy.mean, -3: cupy.sum, -4: cupy.std}[preview_index]
-                stat_fun(
-                    image_data, axis=0, dtype=cupy.float32
-                ).transpose().get(out=output_buffer)
+                stat_fun(image_data, axis=0, dtype=cupy.float32).transpose().get(
+                    out=output_buffer
+                )
         return self.preview_raw, self.preview_corrected
 
     def _init_kernels(self):
diff --git a/src/utils.py b/src/utils.py
index f294f1ec..ddd8d715 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -17,6 +17,14 @@ def ceil_div(num, denom):
     return (num + denom - 1) // denom
 
 
+def shape_after_transpose(input_shape, transpose_pattern, squeeze=True):
+    if squeeze:
+        input_shape = tuple(dim for dim in input_shape if dim>1)
+    if transpose_pattern is None:
+        return input_shape
+    return tuple(np.array(input_shape)[list(transpose_pattern)].tolist())
+
+
 class DelayableTimer:
     """Start a timer which can be extended
 
-- 
GitLab