diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index a2a0779799adc84a861113fc8ede9db9dc92ae50..50ed129d5e80fbdd86160fce2598250721b34635 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -1,11 +1,10 @@ import timeit -import calibrationBase import numpy as np -from karabo.bound import KARABO_CLASSINFO +from karabo.bound import BOOL_ELEMENT, KARABO_CLASSINFO from karabo.common.states import State -from . import shmem_utils, utils +from . import utils from ._version import version as deviceVersion from .base_correction import BaseCorrection from .agipd_gpu import AgipdGpuRunner, CorrectionFlags @@ -13,6 +12,17 @@ from .agipd_gpu import AgipdGpuRunner, CorrectionFlags @KARABO_CLASSINFO("AgipdCorrection", deviceVersion) class AgipdCorrection(BaseCorrection): + # subclass *must* set these attributes + _correction_flag_class = CorrectionFlags + _correction_slot_names = ( + ("thresholding", CorrectionFlags.THRESHOLD), + ("offset", CorrectionFlags.OFFSET), + ("relGainPc", CorrectionFlags.REL_GAIN_PC), + ("relGainXray", CorrectionFlags.REL_GAIN_XRAY), + ("badpixels", CorrectionFlags.BPMASK), + ) + _gpu_runner_class = AgipdGpuRunner + @staticmethod def expectedParameters(expected): AgipdCorrection.addConstant( @@ -49,6 +59,46 @@ class AgipdCorrection(BaseCorrection): mandatoryForIteration=True, ) super(AgipdCorrection, AgipdCorrection).expectedParameters(expected) + for slot_name, _ in AgipdCorrection._correction_slot_names: + ( + BOOL_ELEMENT(expected) + .key(f"corrections.available.{slot_name}") + .readOnly() + .initialValue(False) + .commit(), + BOOL_ELEMENT(expected) + .key(f"corrections.enabled.{slot_name}") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + BOOL_ELEMENT(expected) + .key(f"corrections.preview.{slot_name}") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + ) + + @property + def input_data_shape(self): + return ( + 
self.get("dataFormat.memoryCells"), + 2, + self.get("dataFormat.pixelsX"), + self.get("dataFormat.pixelsY"), + ) + + @property + def output_data_shape(self): + return utils.shape_after_transpose( + ( + self.get("dataFormat.memoryCells"), + self.get("dataFormat.pixelsX"), + self.get("dataFormat.pixelsY"), + ), + self._output_transpose, + ) def __init__(self, config): super().__init__(config) @@ -59,17 +109,7 @@ class AgipdCorrection(BaseCorrection): self._output_transpose = (2, 1, 0) else: self._output_transpose = None - self._offset_map = None - self._update_pulse_filter(config.get("pulseFilter")) - self._update_shapes( - config.get("dataFormat.pixelsX"), - config.get("dataFormat.pixelsY"), - config.get("dataFormat.memoryCells"), - self.pulse_filter, - self._output_transpose, - ) - self._cached_constants = {} - + self._update_shapes() self.updateState(State.ON) def process_input(self, data, metadata): @@ -102,7 +142,6 @@ class AgipdCorrection(BaseCorrection): self.log.WARN(msg) return # original shape: memory_cell, data/raw_gain, x, y - # TODO: consider making paths configurable image_data = data.get("image.data") if image_data.shape[0] != self.get("dataFormat.memoryCells"): self.set( @@ -111,23 +150,8 @@ class AgipdCorrection(BaseCorrection): # TODO: truncate if > 800 self.set("dataFormat.memoryCells", image_data.shape[0]) with self._buffer_lock: - self._update_pulse_filter(self.get("pulseFilter")) - self._update_shapes( - self.get("dataFormat.pixelsX"), - self.get("dataFormat.pixelsY"), - self.get("dataFormat.memoryCells"), - self.pulse_filter, - self._output_transpose, - ) - # TODO: check shape (DAQ fake data and RunToPipe don't agree) - # TODO: consider just updating shapes based on whatever comes in - - correction_cell_num = self.get("dataFormat.memoryCellsCorrection") - do_generate_preview = train_id % self.get( - "preview.trainIdModulo" - ) == 0 and self.get("preview.enable") - can_apply_correction = True - do_apply_correction = self.get("applyCorrection") + 
# TODO: pulse filter update after reimplementation + self._update_shapes() if not self.get("state") is State.PROCESSING: self.updateState(State.PROCESSING) @@ -140,43 +164,38 @@ class AgipdCorrection(BaseCorrection): else: self._state_reset_timer.set_timeout(self.get("processingStateTimeout")) - with self._buffer_lock: - cell_table = cell_table[self.pulse_filter] - pulse_table = np.squeeze(data.get("image.pulseId"))[self.pulse_filter] + correction_cell_num = self.get("dataFormat.constantMemoryCells") + do_generate_preview = train_id % self.get( + "preview.trainIdModulo" + ) == 0 and self.get("preview.enable") + with self._buffer_lock: + # cell_table = cell_table[self.pulse_filter] + pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter] cell_table_max = np.max(cell_table) - if do_apply_correction: - if not can_apply_correction: - msg = "No constant loaded, correction will not be applied." - self.log.WARN(msg) - self.set("status", msg) - do_apply_correction = False - elif cell_table_max >= correction_cell_num: - msg = ( - f"Max cell ID ({cell_table_max}) exceeds range for loaded " - f"constant (has {correction_cell_num} cells). Some frames " - "will not be corrected." - ) - self.log.WARN(msg) - self.set("status", msg) + # TODO: all this checking and warning can go into GPU runner + if cell_table_max >= correction_cell_num: + msg = ( + f"Max cell ID ({cell_table_max}) exceeds range for loaded " + f"constants ({correction_cell_num} cells). Some frames will not be " + "corrected." 
+ ) + self.log.WARN(msg) + self.set("status", msg) self.gpu_runner.load_data(image_data) buffer_handle, buffer_array = self._shmem_buffer.next_slot() - if do_apply_correction: - self.gpu_runner.load_cell_table(cell_table) - self.gpu_runner.correct( - CorrectionFlags.THRESHOLD - | CorrectionFlags.OFFSET - | CorrectionFlags.REL_GAIN_PC - | CorrectionFlags.REL_GAIN_XRAY - ) - else: - self.gpu_runner.only_cast() + self.gpu_runner.load_cell_table(cell_table) + self.gpu_runner.correct(self._correction_flag_enabled) self.gpu_runner.reshape(out=buffer_array) + # after reshape, data for dataOutput is now safe in its own buffer if do_generate_preview: + if self._correction_flag_enabled != self._correction_flag_preview: + self.gpu_runner.correct(self._correction_flag_preview) preview_slice_index = self.get("preview.pulse") if preview_slice_index >= 0: # look at pulse_table to find which index this pulse ID is in + # TODO: move this to GPU pulse_id_found = np.where(pulse_table == preview_slice_index)[0] if len(pulse_id_found) == 0: pulse_found_instead = pulse_table[0] @@ -185,24 +204,13 @@ class AgipdCorrection(BaseCorrection): f"image.pulseId, arbitrary pulse " f"{pulse_found_instead} will be shown." ) - preview_slice_index = 0 self.log.WARN(msg) self.set("status", msg) + preview_slice_index = 0 else: preview_slice_index = pulse_id_found[0] - if not do_apply_correction: - if can_apply_correction: - # in this case, cell table has not been loaded, but needs to be now - self.gpu_runner.load_cell_table(cell_table) - else: - # in this case, there will be no corrected preview - self.log.WARN( - "Corrected preview will not actually be corrected." 
- ) preview_raw, preview_corrected = self.gpu_runner.compute_preview( - preview_slice_index, - have_corrected=do_apply_correction, - can_correct=can_apply_correction, + preview_slice_index ) data.set("image.data", buffer_handle) @@ -225,21 +233,36 @@ class AgipdCorrection(BaseCorrection): if self.get("performance.rateUpdateOnEachInput"): self._update_actual_rate() - def requestConstant(self, name, mostRecent=False, tryRemote=True): - """constantLoaded hook would have gotten called without naming which constant, - so here we go. Ugly hooking it.""" - # TODO: clear from device, too - # TODO: update correction capability flag - if name in self._cached_constants: - del self._cached_constants[name] - super().requestConstant(name, mostRecent, tryRemote) - constant = self.getConstant(name) - if constant is not None: - self._cached_constants[name] = constant - if name == "ThresholdsDark": - self.gpu_runner.load_thresholds(constant) - elif name == "Offset": - self.gpu_runner.load_offset_map(constant) + def _load_constant_to_gpu(self, constant_name, constant_data): + # TODO: also hook flushConstants or whatever it is called + if constant_name == "ThresholdsDark": + self.gpu_runner.load_thresholds(constant_data) + # TODO: encode correction / constant dependencies in a clever way + if not self.get("corrections.available.thresholding"): + self.set("corrections.available.thresholding", True) + self.set("corrections.enabled.thresholding", True) + self.set("corrections.preview.thresholding", True) + elif constant_name == "Offset": + self.gpu_runner.load_offset_map(constant_data) + if not self.get("corrections.available.offset"): + self.set("corrections.available.offset", True) + self.set("corrections.enabled.offset", True) + self.set("corrections.preview.offset", True) + elif constant_name == "SlopesPC": + self.gpu_runner.load_rel_gain_pc_map(constant_data) + if not self.get("corrections.available.relGainPc"): + self.set("corrections.available.relGainPc", True) + 
self.set("corrections.enabled.relGainPc", True) + self.set("corrections.preview.relGainPc", True) + elif constant_name == "SlopesFF": + self.gpu_runner.load_rel_gain_ff_map(constant_data) + if not self.get("corrections.available.relGainXray"): + self.set("corrections.available.relGainXray", True) + self.set("corrections.enabled.relGainXray", True) + self.set("corrections.preview.relGainXray", True) + elif "BadPixels" in constant_name: + # TODO: implement loading bad pixels + ... def _update_pulse_filter(self, filter_string): """Called whenever the pulse filter changes, typically followed by @@ -251,61 +274,3 @@ class AgipdCorrection(BaseCorrection): new_filter = np.array(eval(filter_string), dtype=np.uint16) assert np.max(new_filter) < self.get("dataFormat.memoryCells") self.pulse_filter = new_filter - - def _update_shapes( - self, pixels_x, pixels_y, memory_cells, pulse_filter, output_transpose - ): - """(Re)initialize (GPU) buffers according to expected data shapes""" - - # TODO: report "actual" input shape (incl. 
raw gain) - input_data_shape = (memory_cells, pixels_x, pixels_y) - # reflect the axis reordering in the expected output shape - output_data_shape = utils.shape_after_transpose( - input_data_shape, output_transpose - ) - self.set("dataFormat.inputDataShape", list(input_data_shape)) - self.set("dataFormat.outputDataShape", list(output_data_shape)) - - if self._shmem_buffer is None: - shmem_buffer_name = self.getInstanceId() + ":dataOutput" - memory_budget = self.get("outputShmemBufferSize") * 2 ** 30 - self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") - self._shmem_buffer = shmem_utils.ShmemCircularBuffer( - memory_budget, - output_data_shape, - self.output_data_dtype, - shmem_buffer_name, - ) - else: - self._shmem_buffer.change_shape(output_data_shape) - - self.gpu_runner = AgipdGpuRunner( - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells=250, - output_transpose=output_transpose, - input_data_dtype=self.input_data_dtype, - output_data_dtype=self.output_data_dtype, - ) - - self._update_maps_on_gpu() - - def _update_maps_on_gpu(self): - """Updates the correction maps stored on GPU based on constants known - - This only does something useful if constants have been retrieved from - CalCat. Should be called automatically upon retrieval and after - changing the data shape. 
- - """ - - self.set("status", "Updating constants on GPU using known constants") - self.updateState(State.CHANGING) - if self._offset_map is not None: - self.gpu_runner.load_constants(self._offset_map) - msg = "Done transferring known constant(s) to GPU" - self.log.INFO(msg) - self.set("status", msg) - - self.updateState(State.ON) diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py index 39ae148256b152f7115b20dca11dbade6c34f982..c5ab67cde704f3fe798ca85419a22ebfa152c96a 100644 --- a/src/calng/CalibrationManager.py +++ b/src/calng/CalibrationManager.py @@ -292,6 +292,7 @@ class ManagedKeysCloneFactory(ProxyFactory): class CalibrationManager(DeviceClientBase, Device): __version__ = deviceVersion + _conditions_to_set = {} interfaces = VectorString( displayedName='Device interfaces', @@ -1082,6 +1083,16 @@ class CalibrationManager(DeviceClientBase, Device): ): return + # get device schema sneakily + self.logger.debug(f'Trying to figure out which constant parameters to set later') + # note: this should "obviously" happen earlier, but also more generally + correction_device_schema, _, _ = await call(server_by_group[group], "slotGetClassSchema", class_ids["correction"]) + self._conditions_to_set = { + constant_name: constant_schema.getAttribute('detector_condition', 'defaultValue') + for constant_name, constant_schema in correction_device_schema.hash['constants'].items() + } + self.logger.debug(str(self._conditions_to_set)) + # Instantiate group matchers and bridges. 
for row in self.moduleGroups.value: group, server, with_matcher, with_bridge, bridge_port, \ diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index fff86e9dd708c50c8523e1a469fd6fd21288685b..7a2cc5ec8a71e0cb74659030b64edec8ae1708b5 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -1,24 +1,61 @@ import timeit -import calibrationBase import numpy as np -from karabo.bound import KARABO_CLASSINFO +from karabo.bound import BOOL_ELEMENT, KARABO_CLASSINFO from karabo.common.states import State -from . import shmem_utils, utils +from . import utils from ._version import version as deviceVersion from .base_correction import BaseCorrection -from .dssc_gpu import DsscGpuRunner +from .dssc_gpu import DsscGpuRunner, CorrectionFlags @KARABO_CLASSINFO("DsscCorrection", deviceVersion) class DsscCorrection(BaseCorrection): + # subclass *must* set these attributes + _correction_flag_class = CorrectionFlags + _correction_slot_names = (("offset", CorrectionFlags.OFFSET),) + _gpu_runner_class = DsscGpuRunner + @staticmethod def expectedParameters(expected): DsscCorrection.addConstant( "Offset", "Dark", expected, optional=True, mandatoryForIteration=True ) super(DsscCorrection, DsscCorrection).expectedParameters(expected) + for slot_name, _ in DsscCorrection._correction_slot_names: + ( + BOOL_ELEMENT(expected) + .key(f"corrections.available.{slot_name}") + .readOnly() + .initialValue(False) + .commit(), + BOOL_ELEMENT(expected) + .key(f"corrections.enabled.{slot_name}") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + BOOL_ELEMENT(expected) + .key(f"corrections.preview.{slot_name}") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + ) + + @property + def input_data_shape(self): + return ( + self.get("dataFormat.memoryCells"), + 1, + self.get("dataFormat.pixelsY"), + self.get("dataFormat.pixelsX"), + ) + + @property + def output_data_shape(self): + return 
utils.shape_after_transpose(self.input_data_shape, self._output_transpose) def __init__(self, config): super().__init__(config) @@ -29,39 +66,11 @@ class DsscCorrection(BaseCorrection): self._output_transpose = (2, 1, 0) else: self._output_transpose = None - self._offset_map = None - self._update_pulse_filter(config.get("pulseFilter")) - self._update_shapes( - config.get("dataFormat.pixelsX"), - config.get("dataFormat.pixelsY"), - config.get("dataFormat.memoryCells"), - self.pulse_filter, - self._output_transpose, - ) - + self._update_shapes() self.updateState(State.ON) - def preReconfigure(self, config): - super().preReconfigure(config) - if config.has("pulseFilter"): - with self._buffer_lock: - # apply new pulse filter - self._update_pulse_filter(config.get("pulseFilter")) - # but existing shapes (not reconfigurable) - # TODO: avoid double compilation here if constants are loaded - self._update_shapes( - self.get("dataFormat.pixelsX"), - self.get("dataFormat.pixelsY"), - self.get("dataFormat.memoryCells"), - self.pulse_filter, - ) - def process_input(self, data, metadata): - """Registered for dataInput, handles all processing and sending - - Comparable to StreamBase.onInput but hopefully faster - - """ + """Registered for dataInput, handles all processing and sending""" if not self.get("doAnything"): if self.get("state") is State.PROCESSING: @@ -75,7 +84,6 @@ class DsscCorrection(BaseCorrection): self.log.INFO(f"Ignoring unknown source {source}") return - # TODO: what are these empty things for? 
if not data.has("image"): self.log.INFO("Ignoring hash without image node") return @@ -91,7 +99,6 @@ class DsscCorrection(BaseCorrection): self.log.WARN(msg) return # original shape: 400, 1, 128, 512 (memory cells, something, y, x) - # TODO: consider making paths configurable image_data = data.get("image.data") if image_data.shape[0] != self.get("dataFormat.memoryCells"): self.set( @@ -100,23 +107,8 @@ class DsscCorrection(BaseCorrection): # TODO: truncate if > 800 self.set("dataFormat.memoryCells", image_data.shape[0]) with self._buffer_lock: - self._update_pulse_filter(self.get("pulseFilter")) - self._update_shapes( - self.get("dataFormat.pixelsX"), - self.get("dataFormat.pixelsY"), - self.get("dataFormat.memoryCells"), - self.pulse_filter, - self._output_transpose, - ) - # TODO: check shape (DAQ fake data and RunToPipe don't agree) - # TODO: consider just updating shapes based on whatever comes in - - correction_cell_num = self.get("dataFormat.memoryCellsCorrection") - do_generate_preview = train_id % self.get( - "preview.trainIdModulo" - ) == 0 and self.get("preview.enable") - can_apply_correction = correction_cell_num > 0 - do_apply_correction = self.get("applyCorrection") + # TODO: pulse filter update after reimplementation + self._update_shapes() if not self.get("state") is State.PROCESSING: self.updateState(State.PROCESSING) @@ -129,35 +121,33 @@ class DsscCorrection(BaseCorrection): else: self._state_reset_timer.set_timeout(self.get("processingStateTimeout")) + correction_cell_num = self.get("dataFormat.constantMemoryCells") + do_generate_preview = train_id % self.get( + "preview.trainIdModulo" + ) == 0 and self.get("preview.enable") + with self._buffer_lock: - cell_table = cell_table[self.pulse_filter] - pulse_table = np.squeeze(data.get("image.pulseId"))[self.pulse_filter] + # cell_table = cell_table[self.pulse_filter] + pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter] cell_table_max = np.max(cell_table) - if do_apply_correction: 
- if not can_apply_correction: - msg = "No constant loaded, correction will not be applied." - self.log.WARN(msg) - self.set("status", msg) - do_apply_correction = False - elif cell_table_max >= correction_cell_num: - msg = ( - f"Max cell ID ({cell_table_max}) exceeds range for loaded " - f"constant (has {correction_cell_num} cells). Some frames " - "will not be corrected." - ) - self.log.WARN(msg) - self.set("status", msg) + if cell_table_max >= correction_cell_num: + msg = ( + f"Max cell ID ({cell_table_max}) exceeds range for loaded " + f"constant ({correction_cell_num} cells). Some frames will not be " + "corrected." + ) + self.log.WARN(msg) + self.set("status", msg) self.gpu_runner.load_data(image_data) buffer_handle, buffer_array = self._shmem_buffer.next_slot() - if do_apply_correction: - self.gpu_runner.load_cell_table(cell_table) - self.gpu_runner.correct() - else: - self.gpu_runner.only_cast() + self.gpu_runner.load_cell_table(cell_table) + self.gpu_runner.correct(self._correction_flag_enabled) self.gpu_runner.reshape(out=buffer_array) if do_generate_preview: + if self._correction_flag_enabled != self._correction_flag_preview: + self.gpu_runner.correct(self._correction_flag_preview) preview_slice_index = self.get("preview.pulse") if preview_slice_index >= 0: # look at pulse_table to find which index this pulse ID is in @@ -169,24 +159,13 @@ class DsscCorrection(BaseCorrection): f"image.pulseId, arbitrary pulse " f"{pulse_found_instead} will be shown." ) - preview_slice_index = 0 self.log.WARN(msg) self.set("status", msg) + preview_slice_index = 0 else: preview_slice_index = pulse_id_found[0] - if not do_apply_correction: - if can_apply_correction: - # in this case, cell table has not been loaded, but needs to be now - self.gpu_runner.load_cell_table(cell_table) - else: - # in this case, there will be no corrected preview - self.log.WARN( - "Corrected preview will not actually be corrected." 
- ) preview_raw, preview_corrected = self.gpu_runner.compute_preview( preview_slice_index, - have_corrected=do_apply_correction, - can_correct=can_apply_correction, ) data.set("image.data", buffer_handle) @@ -209,57 +188,6 @@ class DsscCorrection(BaseCorrection): if self.get("performance.rateUpdateOnEachInput"): self._update_actual_rate() - def constantLoaded(self): - """Hook from CalibrationReceiverBaseDevice called after each getConstant - - Here, used to load the received constants (or correction maps derived - fromt them) onto GPU. - - TODO: call after receiving *all* constants instead of calling once per - new constant (will cause some overhead for bigger devices) - - """ - - offset_map = self.getConstant("Offset") - input_memory_cells = self.get("dataFormat.memoryCells") - if offset_map is None: - msg = ( - "Warning: Did not find offset constant, offset correction " - "will not be applied" - ) - self.set("status", msg) - self.log.WARN(msg) - self._offset_map = None - elif len(offset_map.shape) not in (3, 4): - msg = ( - f"Offset map had unexpected shape {offset_map.shape}, " - "offset correction will not be applied" - ) - self.set("status", msg) - self.log.WARN(msg) - else: - self.log.INFO(f"Offset map loaded has shape {offset_map.shape}") - if len(offset_map.shape) == 4: # old format (see offsetcorrection_dssc.py)? - offset_map = offset_map[..., 0] - constant_memory_cells = offset_map.shape[-1] - if input_memory_cells > constant_memory_cells: - msg = ( - f"Warning: Memory cells in input {input_memory_cells} > " - f"memory cells in constant {constant_memory_cells}, some " - "frames may not get correction applied." 
- ) - self.set("status", msg) - self.log.WARN(msg) - self._offset_map = offset_map.astype(np.float32) - msg = f"Offset map with shape {self._offset_map.shape} ready to load to GPU" - self.set("status", msg) - self.log.INFO(msg) - if constant_memory_cells != self.get("dataFormat.memoryCellsCorrection"): - self.log.INFO("Will first have to update buffers on GPU") - self.set("dataFormat.memoryCellsCorrection", constant_memory_cells) - - self._update_maps_on_gpu() - def _update_pulse_filter(self, filter_string): """Called whenever the pulse filter changes, typically followed by _update_shapes""" @@ -271,58 +199,10 @@ class DsscCorrection(BaseCorrection): assert np.max(new_filter) < self.get("dataFormat.memoryCells") self.pulse_filter = new_filter - def _update_shapes( - self, pixels_x, pixels_y, memory_cells, pulse_filter, output_transpose - ): - """(Re)initialize (GPU) buffers according to expected data shapes""" - - input_data_shape = (memory_cells, 1, pixels_y, pixels_x) - # reflect the axis reordering in the expected output shape - output_data_shape = utils.shape_after_transpose( - input_data_shape, output_transpose - ) - self.set("dataFormat.inputDataShape", list(input_data_shape)) - self.set("dataFormat.outputDataShape", list(output_data_shape)) - - if self._shmem_buffer is None: - shmem_buffer_name = self.getInstanceId() + ":dataOutput" - memory_budget = self.get("outputShmemBufferSize") * 2 ** 30 - self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") - self._shmem_buffer = shmem_utils.ShmemCircularBuffer( - memory_budget, - output_data_shape, - self.output_data_dtype, - shmem_buffer_name, - ) - else: - self._shmem_buffer.change_shape(output_data_shape) - - self.gpu_runner = DsscGpuRunner( - pixels_x, - pixels_y, - memory_cells, - output_transpose=output_transpose, - input_data_dtype=self.input_data_dtype, - output_data_dtype=self.output_data_dtype, - ) - - self._update_maps_on_gpu() - - def _update_maps_on_gpu(self): - """Updates the correction maps 
stored on GPU based on constants known - - This only does something useful if constants have been retrieved from - CalCat. Should be called automatically upon retrieval and after - changing the data shape. - - """ - - self.set("status", "Updating constants on GPU using known constants") - self.updateState(State.CHANGING) - if self._offset_map is not None: - self.gpu_runner.load_constants(self._offset_map) - msg = "Done transferring known constant(s) to GPU" - self.log.INFO(msg) - self.set("status", msg) - - self.updateState(State.ON) + def _load_constant_to_gpu(self, constant_name, constant_data): + assert constant_name == "Offset" + self.gpu_runner.load_offset_map(constant_data) + if not self.get("corrections.available.offset"): + self.set("corrections.available.offset", True) + self.set("corrections.enabled.offset", True) + self.set("corrections.preview.offset", True) diff --git a/src/calng/agipd_gpu.py b/src/calng/agipd_gpu.py index c445ae6a2358cb4812495a67bd2703a8dbb00201..56a36ac1560e938e11a7205cf33dc8a27e840c65 100644 --- a/src/calng/agipd_gpu.py +++ b/src/calng/agipd_gpu.py @@ -7,6 +7,7 @@ from . import base_gpu, utils class CorrectionFlags(enum.IntFlag): + NONE = 0 THRESHOLD = 1 OFFSET = 2 BLSHIFT = 4 @@ -23,8 +24,8 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): pixels_x, pixels_y, memory_cells, + constant_memory_cells, output_transpose=(1, 2, 0), # default: memorycells-fast - constant_memory_cells=None, input_data_dtype=np.uint16, output_data_dtype=np.float32, ): @@ -34,8 +35,8 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): pixels_x, pixels_y, memory_cells, - output_transpose, constant_memory_cells, + output_transpose, input_data_dtype, output_data_dtype, ) @@ -84,10 +85,14 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): pc_med_m = slopes_pc_map[3] pc_med_I = slopes_pc_map[4] frac_high_med = pc_high_m / pc_med_m + # TODO: handle NaN somehow? + if np.isnan(frac_high_med).any(): + ... 
# TODO: verify formula md_additional_offset = (pc_high_I - pc_med_I * frac_high_med).astype(np.float32) rel_gain_map = np.ones( - (3, self.constant_memory_cells, self.pixels_y, self.pixels_x), dtype=np.float32 + (3, self.constant_memory_cells, self.pixels_y, self.pixels_x), + dtype=np.float32, ) rel_gain_map[1] = rel_gain_map[0] * frac_high_med rel_gain_map[2] = rel_gain_map[1] * 4.48 @@ -100,22 +105,9 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): if slopes_ff_map.shape[2] == 2: # old format, is per pixel only raise NotImplementedError("Old slopes FF map format") - self.rel_gain_xray_map_gpu.set( - np.transpose(slopes_ff_map).astype(np.float32) - ) + self.rel_gain_xray_map_gpu.set(np.transpose(slopes_ff_map).astype(np.float32)) def correct(self, flags): - """Apply corrections to data (must load constant, data, and cell_table first) - - Applies corrections to input data and casts to desired output dtype. - Parameter cell_table allows out of order or non-contiguous memory cells - in input data. Both input ndarrays are assumed to be on GPU already, - preferably wrapped in GPU arrays (cupy array). - - Will return string encoded handle to shared memory output buffer and - (view of) said buffer as an ndarray. Keep in mind that the output - buffers will get overwritten eventually (circular buffer). - """ if flags & CorrectionFlags.BLSHIFT: raise NotImplementedError("Baseline shift not implemented yet") self.correction_kernel( @@ -150,4 +142,3 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): ) self.source_module = cupy.RawModule(code=kernel_source) self.correction_kernel = self.source_module.get_function("correct") - self.casting_kernel = self.source_module.get_function("only_cast") diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index dfdc83e65e7a7e77ddecd04256d7056d5a75a47b..6bdc4d377d265b4b302bd550876f0f81f0c6c5b0 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -31,11 +31,11 @@ from . 
import shmem_utils, utils class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): + _correction_flag_class = None # subclass must override this with some enum class _dict_cache_slots = { - "applyCorrection", "doAnything", "dataFormat.memoryCells", - "dataFormat.memoryCellsCorrection", + "dataFormat.constantMemoryCells", "dataFormat.pixelsX", "dataFormat.pixelsY", "preview.enable", @@ -46,6 +46,17 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): "state", } + def _load_constant_to_gpu(constant_name, constant_data): + raise NotImplementedError() + + @property + def input_data_shape(self): + raise NotImplementedError() + + @property + def output_data_shape(self): + raise NotImplementedError() + @staticmethod def expectedParameters(expected): ( @@ -61,20 +72,6 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): .defaultValue(True) .reconfigurable() .commit(), - BOOL_ELEMENT(expected) - .key("applyCorrection") - .displayedName("Enable correction(s)") - .description( - "Toggle whether not correction(s) are applied to image data. If " - "false, this device still reshapes data to output shape, applies the " - "pulse filter, and casts to output dtype. Useful if constants are " - "missing / bad, or if data is sent to application doing its own " - "correction." 
# Schema nodes grouping the per-correction toggles. The individual boolean
# slots below "available" / "enabled" / "preview" are added by the subclass.
(
    NODE_ELEMENT(expected)
    .key("corrections")
    .displayedName("Corrections to apply")
    .commit(),
    NODE_ELEMENT(expected)
    .key("corrections.available")
    .displayedName("Available given loaded constants")
    .description(
        "Corrections typically require some constants to be loaded. These "
        "flags indicate which corrections can be done, given the constants "
        "found and loaded so far. Corrections - for dataOutput or preview - "
        "only happen if selected AND available."
    )
    .commit(),
    NODE_ELEMENT(expected)
    .key("corrections.enabled")
    .displayedName("Enabled (if available)")
    .description("Corrections applied to data for dataOutput (main output)")
    .commit(),
    NODE_ELEMENT(expected)
    .key("corrections.preview")
    .displayedName("Preview (if available)")
    .description("Corrections applied for corrected preview output")
    .commit(),
    # Description fixed: this flag has the opposite polarity of the removed
    # "applyCorrection" flag it replaces - setting it to TRUE is what turns
    # corrections off (the enabled flags get cleared when it is set).
    BOOL_ELEMENT(expected)
    .key("corrections.disableAll")
    .displayedName("Disable corrections for dataOutput")
    .description(
        "Disable all corrections for dataOutput. If true, no correction is "
        "applied to image data, but this device still reshapes data to output "
        "shape, applies the pulse filter, and casts to output dtype. Useful if "
        "constants are missing / bad, or if data is sent to application doing "
        "its own correction. Preview is still corrected based on selection of "
        "corrections independently of this."
    )
    .assignmentOptional()
    .defaultValue(False)
    .reconfigurable()
    .commit(),
)
"dataFormat.pixelsX", + "dataFormat.pixelsY", + "dataFormat.memoryCells", + "dataFormat.constantMemoryCells", + ) + ): + # will make postReconfigure handle shape update after merging schema + self._has_updated_shapes = False + + def postReconfigure(self): + self.log.INFO("postReconfigure") + if not self._has_updated_shapes: + self._update_shapes() + # TODO: only call this if they are changed (is cheap, though) + self._update_correction_flags() + def get(self, key): if key in self._dict_cache_slots: return self._dict_cache.get(key) @@ -414,6 +481,23 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): self._dict_cache[key] = value super().set(*args) + def requestConstant(self, name, mostRecent=False, tryRemote=True): + """constantLoaded hook would have gotten called without naming constant, so here + we go. Ugly hooking it.""" + # TODO: clear from device, too + # TODO: update correction capability flag + if name in self._cached_constants: + del self._cached_constants[name] + super().requestConstant(name, mostRecent, tryRemote) + constant = self.getConstant(name) + if constant is None: + return + + # TODO: remaining constants, DRY + self._cached_constants[name] = constant + self._load_constant_to_gpu(name, constant) + self._update_correction_flags() + def _write_output(self, data, old_metadata): metadata = ChannelMetaData( old_metadata.get("source"), @@ -452,6 +536,25 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): channel.write(preview_hash, metadata, False) channel.update() + def _update_correction_flags(self): + available = self._correction_flag_class.NONE + enabled = self._correction_flag_class.NONE + preview = self._correction_flag_class.NONE + for slot_name, flag in self._correction_slot_names: + if self.get(f"corrections.available.{slot_name}"): + available |= flag + if self.get(f"corrections.enabled.{slot_name}"): + enabled |= flag + if self.get(f"corrections.preview.{slot_name}"): + preview |= flag + enabled &= 
available + preview &= available + if self.get("corrections.disableAll"): + enabled &= self._correction_flag_class.NONE + self._correction_flag_enabled = enabled + self._correction_flag_preview = preview + self.log.INFO(str(enabled)) + def _update_output_schema(self, data): """Updates the schema of dataOutput based on parameter data (a Hash) @@ -469,9 +572,45 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): .dataSchema(data_schema) .commit() ) - self.updateSchema(my_schema_update) + self.appendSchema(my_schema_update) self._has_set_output_schema = True + def _update_shapes(self): + """(Re)initialize buffers according to expected data shapes""" + self.log.INFO("Updating shapes") + + # reflect the axis reordering in the expected output shape + self.set("dataFormat.inputDataShape", list(self.input_data_shape)) + self.set("dataFormat.outputDataShape", list(self.output_data_shape)) + + if self._shmem_buffer is None: + shmem_buffer_name = self.getInstanceId() + ":dataOutput" + memory_budget = self.get("outputShmemBufferSize") * 2 ** 30 + self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") + self._shmem_buffer = shmem_utils.ShmemCircularBuffer( + memory_budget, + self.output_data_shape, + self.output_data_dtype, + shmem_buffer_name, + ) + else: + self._shmem_buffer.change_shape(self.output_data_shape) + + self.gpu_runner = self._gpu_runner_class( + self.get("dataFormat.pixelsX"), + self.get("dataFormat.pixelsY"), + self.get("dataFormat.memoryCells"), + self.get("dataFormat.constantMemoryCells"), + output_transpose=self._output_transpose, + input_data_dtype=self.input_data_dtype, + output_data_dtype=self.output_data_dtype, + ) + + for constant_name, constant_data in self._cached_constants.items(): + self._load_constant_to_gpu(constant_name, constant_data) + + self._has_updated_shapes = True + def _reset_state_from_processing(self): if self.get("state") is State.PROCESSING: self.updateState(State.ON) diff --git a/src/calng/base_gpu.py 
b/src/calng/base_gpu.py index 486043499683a6e7b5a352bb3f44286d759d3aa2..5347fc2a1b730c10cd3fd5663e9c16a73342a603 100644 --- a/src/calng/base_gpu.py +++ b/src/calng/base_gpu.py @@ -9,13 +9,14 @@ from . import utils class BaseGpuRunner: - """Class to handle instantiation and execution of CUDA kernels on trains + """Class to handle GPU buffers and execution of CUDA kernels on image data - All GPU buffers are kept within this class. This generally means that you will - want to load data into it and then do something. Typical usage in correct order: + All GPU buffers are kept within this class and it is intentionally very stateful. + This generally means that you will want to load data into it and then do something. + Typical usage in correct order: 1. instantiate - 2. load_constants + 2. load constants 3. load_data 4. load_cell_table 5. correct @@ -24,9 +25,9 @@ class BaseGpuRunner: repeat from 2. or 3. - In case no constants are available / correction is not desired, can skip 3 and 4 - and use only_cast in step 5 instead of correct (taking care to call - compute_preview with parameters set accordingly). + In case no constants are available / correction is not desired, can skip 3 and 4 and + pass CorrectionFlags.NONE to correct(...). Generally, user must handle which + correction steps are appropriate given the constants loaded so far. 
""" def __init__( @@ -34,8 +35,8 @@ class BaseGpuRunner: pixels_x, pixels_y, memory_cells, + constant_memory_cells, output_transpose=(2, 1, 0), # default: memorycells-fast - constant_memory_cells=None, input_data_dtype=np.uint16, output_data_dtype=np.float32, ): @@ -48,7 +49,8 @@ class BaseGpuRunner: self.pixels_y = pixels_y self.memory_cells = memory_cells self.output_transpose = output_transpose - if constant_memory_cells is None: + if constant_memory_cells == 0: + # if not set, guess same as input; may save one recompilation self.constant_memory_cells = memory_cells else: self.constant_memory_cells = constant_memory_cells @@ -62,10 +64,12 @@ class BaseGpuRunner: self._init_kernels() - # reuse output arrays + # reuse buffers for input / output self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=np.uint16) self.input_data_gpu = cupy.empty(self.input_shape, dtype=input_data_dtype) - self.processed_data_gpu = cupy.empty(self.processed_shape, dtype=output_data_dtype) + self.processed_data_gpu = cupy.empty( + self.processed_shape, dtype=output_data_dtype + ) self.reshaped_data_gpu = cupy.empty(self.output_shape, dtype=output_data_dtype) self.preview_raw = cupyx.empty_pinned(self.preview_shape, dtype=np.float32) self.preview_corrected = cupyx.empty_pinned( @@ -73,25 +77,32 @@ class BaseGpuRunner: ) # functions to get data from respective buffers to cell, x, y shape for preview computation + # TODO: handle shape juggling programmatically, removing need for these two helpers def _preview_preprocess_raw(): + """Should return view of self.input_data_gpu with shape (cell, x, y)""" raise NotImplementedError() def _preview_preprocess_corr(): + """Should return view of self.processed_data_gpu with shape (cell, x, y)""" raise NotImplementedError() - def only_cast(self): - """Like correct without the correction + def correct(self, flags): + """Correct (already loaded) image data according to flags + + Subclass must define this method. 
def load_cell_table(self, cell_table):
    """Copy the cell table for the current data into the GPU-side buffer.

    The destination is the preallocated ``self.cell_table_gpu`` (uint16, one
    entry per memory cell); ``cell_table`` is handed unchanged to its ``set``
    method. Presumably ``cell_table`` must match that buffer's size - verify
    against the caller.
    """
    device_table = self.cell_table_gpu
    device_table.set(cell_table)
must have + been called with the appropriate flags before compute_preview(...). """ if preview_index < -4: @@ -139,10 +149,6 @@ class BaseGpuRunner: elif preview_index >= self.memory_cells: raise ValueError(f"Memory cell index {preview_index} out of range") - if (not have_corrected) and can_correct: - self.correct() - # if not have_corrected and not can_correct, assume only_cast already done - # TODO: enum around reduction type for (preprocces, output_buffer) in ( (self._preview_preprocess_raw, self.preview_raw), @@ -151,23 +157,17 @@ class BaseGpuRunner: image_data = preprocces() if preview_index >= 0: # TODO: change axis order when moving reshape to after correction - image_data[preview_index].astype(np.float32).get( - out=output_buffer - ) + image_data[preview_index].astype(np.float32).get(out=output_buffer) elif preview_index == -1: # TODO: select argmax independently for raw and corrected? # TODO: send frame sums somewhere to compute global max frame max_index = cupy.argmax( cupy.sum(image_data, axis=(1, 2), dtype=cupy.float32) ) - image_data[max_index].astype(np.float32).get( - out=output_buffer - ) + image_data[max_index].astype(np.float32).get(out=output_buffer) elif preview_index in (-2, -3, -4): stat_fun = {-2: cupy.mean, -3: cupy.sum, -4: cupy.std}[preview_index] - stat_fun(image_data, axis=0, dtype=cupy.float32).get( - out=output_buffer - ) + stat_fun(image_data, axis=0, dtype=cupy.float32).get(out=output_buffer) return self.preview_raw, self.preview_corrected def update_block_size(self, full_block, target_shape=None): diff --git a/src/calng/dssc_gpu.py b/src/calng/dssc_gpu.py index cd7e2d80595fe6c252cc497b1b4c86d6f8aaa007..8c1995e39733f462c36f7a663da3b1ec2dbe0d16 100644 --- a/src/calng/dssc_gpu.py +++ b/src/calng/dssc_gpu.py @@ -1,9 +1,16 @@ +import enum + import cupy import numpy as np from . 
def load_offset_map(self, offset_map):
    """Upload the offset constant to ``self.offset_map_gpu`` as float32.

    A four-dimensional map (old format, see offsetcorrection_dssc.py?) has
    its trailing axis dropped by taking index 0. The remaining map - noted by
    the original author as (x, y, memory cell) - has its axes fully reversed
    to match the GPU buffer layout (constant memory cells, y, x), is cast to
    float32 and copied into the GPU buffer.
    """
    has_extra_axis = len(offset_map.shape) == 4
    trimmed = offset_map[..., 0] if has_extra_axis else offset_map
    host_map = np.transpose(trimmed).astype(np.float32)
    self.offset_map_gpu.set(host_map)
- Parameter cell_table allows out of order or non-contiguous memory cells - in input data. Both input ndarrays are assumed to be on GPU already, - preferably wrapped in GPU arrays (cupy array). - - Will return string encoded handle to shared memory output buffer and - (view of) said buffer as an ndarray. Keep in mind that the output - buffers will get overwritten eventually (circular buffer). - """ + def correct(self, flags): self.correction_kernel( self.full_grid, self.full_block, ( self.input_data_gpu, self.cell_table_gpu, + np.uint8(flags), self.offset_map_gpu, self.processed_data_gpu, ), @@ -84,8 +81,8 @@ class DsscGpuRunner(base_gpu.BaseGpuRunner): "constant_memory_cells": self.constant_memory_cells, "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), + "corr_enum": utils.enum_to_c_template(CorrectionFlags), } ) self.source_module = cupy.RawModule(code=kernel_source) self.correction_kernel = self.source_module.get_function("correct") - self.casting_kernel = self.source_module.get_function("only_cast") diff --git a/src/calng/dssc_gpu_kernels.cpp b/src/calng/dssc_gpu_kernels.cpp index 2412a86ac89d3eb07335f59d2630793e2f04b1d7..b82159c0885ca6026fddc7e813f96b0c8eb9c1b4 100644 --- a/src/calng/dssc_gpu_kernels.cpp +++ b/src/calng/dssc_gpu_kernels.cpp @@ -1,5 +1,7 @@ #include <cuda_fp16.h> +{{corr_enum}} + extern "C" { /* Perform correction: offset @@ -9,6 +11,7 @@ extern "C" { */ __global__ void correct(const {{input_data_dtype}}* data, const unsigned short* cell_table, + const unsigned char corr_flags, const float* offset_map, {{output_data_dtype}}* output) { const size_t X = {{pixels_x}}; @@ -31,13 +34,16 @@ extern "C" { const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x; const float raw = (float)data[data_index]; - const size_t map_stride_cell = 1; - const size_t map_stride_y = map_memory_cells * map_stride_cell; - const size_t 
map_stride_x = Y * map_stride_y; + const size_t map_stride_x = 1; + const size_t map_stride_y = X * map_stride_x; + const size_t map_stride_cell = Y * map_stride_y; const size_t map_cell = cell_table[memory_cell]; if (map_cell < map_memory_cells) { const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x; - const float corrected = raw - offset_map[map_index]; + float corrected = raw; + if (corr_flags & OFFSET) { + corrected -= offset_map[map_index]; + } {% if output_data_dtype == "half" %} output[data_index] = __float2half(corrected); {% else %} @@ -51,34 +57,4 @@ extern "C" { {% endif %} } } - - /* - Same as correction, except don't do any correction - */ - __global__ void only_cast(const {{input_data_dtype}}* data, - {{output_data_dtype}}* output) { - const size_t X = {{pixels_x}}; - const size_t Y = {{pixels_y}}; - const size_t memory_cells = {{data_memory_cells}}; - - const size_t data_stride_x = 1; - const size_t data_stride_y = X * data_stride_x; - const size_t data_stride_cell = Y * data_stride_y; - - const size_t cell = blockIdx.x * blockDim.x + threadIdx.x; - const size_t y = blockIdx.y * blockDim.y + threadIdx.y; - const size_t x = blockIdx.z * blockDim.z + threadIdx.z; - - if (cell >= memory_cells || y >= Y || x >= X) { - return; - } - - const size_t data_index = cell * data_stride_cell + y * data_stride_y + x * data_stride_x; - const float raw = (float)data[data_index]; - {% if output_data_dtype == "half" %} - output[data_index] = __float2half(raw); - {% else %} - output[data_index] = ({{output_data_dtype}})raw; - {% endif %} - } }