diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py index 84996e46ba96557750defc14342ef0ea5bde4959..22332af9ebab3eac62a53937cf0fe307669d7c83 100644 --- a/src/calng/CalibrationManager.py +++ b/src/calng/CalibrationManager.py @@ -424,6 +424,16 @@ class CalibrationManager(DeviceClientBase, Device): self.state = State.CHANGING background(self._instantiate_pipeline()) + outputAxisOrder = String( + displayedName='Output axis order', + options=('pixels-fast','memorycells-fast','no-reshape'), + defaultValue='pixels-fast', + description='Axes of main data output can be reordered after correction. ' + 'Choose between "pixels-fast" (memory_cell, x, y), ' + '"memorycells-fast" (x, y, memory_cell), and "no-reshape" ' + '(memory_cell, y, x)' + ) + # TODO: Inject at runtime by scanning correction device schema. runtimeParameters = Node( RuntimeParametersNode, @@ -881,6 +891,7 @@ class CalibrationManager(DeviceClientBase, Device): config['dataInput.connectedOutputChannels'] = [input_channel] config['fastSources'] = [input_source] config['dataFormat.outputImageDtype'] = 'float16' + config['dataFormat.outputAxisOrder'] = self.outputAxisOrder config['dataFormat.pixelsX'] = 512 config['dataFormat.pixelsY'] = 128 config['dataFormat.memoryCells'] = 400 diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index 995b8689e4079cbb78647ffe43fc4a6701e69b5e..6468a769dbb522d78a8c41a720709e4d7fe060e1 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -179,6 +179,18 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice): .description("Full number of memory cells in incoming data") .assignmentMandatory() .commit(), + STRING_ELEMENT(expected) + .key("dataFormat.outputAxisOrder") + .displayedName("Output axis order") + .description( + "Axes of main data output can be reordered after correction. 
Choose " + "between 'pixels-fast' (memory_cell, x, y), 'memorycells-fast' " + "(x, y, memory_cell), and 'no-reshape' (memory_cell, y, x)" + ) + .options("pixels-fast,memorycells-fast,no-reshape") + .assignmentOptional() + .defaultValue("pixels-fast") + .commit(), UINT32_ELEMENT(expected) .key("dataFormat.memoryCellsCorrection") .displayedName("(Debug) Memory cells in correction map") @@ -374,6 +386,13 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice): self.input_data_dtype = getattr(np, config.get("dataFormat.inputImageDtype")) self.output_data_dtype = getattr(np, config.get("dataFormat.outputImageDtype")) + output_axis_order = config.get("dataFormat.outputAxisOrder") + if output_axis_order == "pixels-fast": + self._output_transpose = (0, 2, 1) + elif output_axis_order == "memorycells-fast": + self._output_transpose = (2, 1, 0) + else: + self._output_transpose = None self._offset_map = None self._update_pulse_filter(config.get("pulseFilter")) self._shmem_buffer = None @@ -382,6 +401,7 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice): config.get("dataFormat.pixelsY"), config.get("dataFormat.memoryCells"), self.pulse_filter, + self._output_transpose, ) self._has_set_output_schema = False self._has_set_preview_output_schema = False @@ -501,6 +521,7 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice): self.get("dataFormat.pixelsY"), self.get("dataFormat.memoryCells"), self.pulse_filter, + self._output_transpose, ) # TODO: check shape (DAQ fake data and RunToPipe don't agree) # TODO: consider just updating shapes based on whatever comes in @@ -768,16 +789,21 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice): assert np.max(new_filter) < self.get("dataFormat.memoryCells") self.pulse_filter = new_filter - def _update_shapes(self, pixels_x, pixels_y, memory_cells, pulse_filter): + def _update_shapes( + self, pixels_x, pixels_y, memory_cells, pulse_filter, output_transpose + ): 
"""(Re)initialize (GPU) buffers according to expected data shapes""" input_data_shape = (memory_cells, 1, pixels_y, pixels_x) - output_data_shape = (pixels_x, pixels_y, pulse_filter.size) + # reflect the axis reordering in the expected output shape + output_data_shape = utils.shape_after_transpose( + input_data_shape, output_transpose + ) self.set("dataFormat.inputDataShape", list(input_data_shape)) self.set("dataFormat.outputDataShape", list(output_data_shape)) if self._shmem_buffer is None: - shmem_buffer_name = self.getInstanceId() + f":dataOutput" + shmem_buffer_name = self.getInstanceId() + ":dataOutput" memory_budget = self.get("outputShmemBufferSize") * 2 ** 30 self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") self._shmem_buffer = shmem_utils.ShmemCircularBuffer( @@ -793,6 +819,7 @@ class DsscCorrection(calibrationBase.CalibrationReceiverBaseDevice): pixels_x, pixels_y, memory_cells, + output_transpose=output_transpose, input_data_dtype=self.input_data_dtype, output_data_dtype=self.output_data_dtype, ) diff --git a/src/calng/dssc_gpu.py b/src/calng/dssc_gpu.py index 7dacb363835c93928232ab3b2a5ee37161d51c38..53f82f87b9cb80254dfab64dfde7411fa5e8517d 100644 --- a/src/calng/dssc_gpu.py +++ b/src/calng/dssc_gpu.py @@ -38,18 +38,26 @@ class DsscGpuRunner: pixels_x, pixels_y, memory_cells, + output_transpose=(2, 1, 0), # default: memorycells-fast + constant_memory_cells=None, input_data_dtype=np.uint16, output_data_dtype=np.float32, ): self.pixels_x = pixels_x self.pixels_y = pixels_y self.memory_cells = memory_cells - self.constant_memory_cells = 0 + self.output_transpose = output_transpose + if constant_memory_cells is None: + self.constant_memory_cells = memory_cells + else: + self.constant_memory_cells = constant_memory_cells self.input_shape = (self.memory_cells, self.pixels_y, self.pixels_x) - self.output_shape = (self.pixels_x, self.pixels_y, self.memory_cells) + self.output_shape = utils.shape_after_transpose( + self.input_shape, 
self.output_transpose + ) self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells) # preview will only be single memory cell - self.preview_shape = self.output_shape[:-1] + self.preview_shape = (self.pixels_x, self.pixels_y) self.input_data_dtype = input_data_dtype self.output_data_dtype = output_data_dtype @@ -140,18 +148,23 @@ class DsscGpuRunner: ) def reshape(self, out=None): - # TODO: make order configurable """Move axes to desired output order - equivalent to: - output_data[:] = np.moveaxis( - np.squeeze(input_data), (0, 1, 2), (2, 1, 0) - ) + The out parameter is passed directly to the get function of GPU array: if + None, then a new ndarray (in host memory) is returned. If not None, then data + will be loaded into the provided array, which must match shape / dtype. """ # TODO: avoid copy - self.reshaped_data_gpu = cupy.ascontiguousarray( - cupy.transpose(cupy.squeeze(self.processed_data_gpu)) - ) + if self.output_transpose is None: + self.reshaped_data_gpu = cupy.ascontiguousarray( + cupy.squeeze(self.processed_data_gpu) + ) + else: + self.reshaped_data_gpu = cupy.ascontiguousarray( + cupy.transpose( + cupy.squeeze(self.processed_data_gpu), self.output_transpose + ) + ) return self.reshaped_data_gpu.get(out=out) def compute_preview(self, preview_index, have_corrected=True, can_correct=True): @@ -202,9 +215,9 @@ class DsscGpuRunner: ) elif preview_index in (-2, -3, -4): stat_fun = {-2: cupy.mean, -3: cupy.sum, -4: cupy.std}[preview_index] - stat_fun( - image_data, axis=0, dtype=cupy.float32 - ).transpose().get(out=output_buffer) + stat_fun(image_data, axis=0, dtype=cupy.float32).transpose().get( + out=output_buffer + ) return self.preview_raw, self.preview_corrected def _init_kernels(self): diff --git a/src/utils.py b/src/utils.py index f294f1ecbed450157defb934b49b2b05637dcf8c..ddd8d7158f64293c356e3903cdefa5c90ba10da3 100644 --- a/src/utils.py +++ b/src/utils.py @@ -17,6 +17,14 @@ def ceil_div(num, denom): return (num + denom - 1) // denom 
def shape_after_transpose(input_shape, transpose_pattern, squeeze=True):
    """Predict the shape an array has after an optional squeeze and transpose.

    Only the shape is computed; no array is created or touched.

    Parameters
    ----------
    input_shape : sequence of int
        Shape of the array before reordering.
    transpose_pattern : sequence of int or None
        Axis order in the sense of ``numpy.transpose``: entry ``i`` names the
        (squeezed) input axis that becomes output axis ``i``. ``None`` means
        the axes are not reordered.
    squeeze : bool, optional
        If true (the default), length-1 axes are removed before the pattern
        is applied, so the pattern indexes the squeezed shape.

    Returns
    -------
    tuple
        The resulting shape.

    Raises
    ------
    IndexError
        If ``transpose_pattern`` refers to an axis that does not exist, for
        example because squeezing removed it.
    """
    shape = tuple(input_shape)
    if squeeze:
        # Like numpy.squeeze, drop only axes of length exactly 1; an empty
        # (length-0) axis is still an axis and has to stay in the shape.
        shape = tuple(dim for dim in shape if dim != 1)
    if transpose_pattern is None:
        return shape
    # Permuting a shape is plain indexing; no array library is required.
    return tuple(shape[axis] for axis in transpose_pattern)