Commit 53cca8ca authored by David Hammer

Hotfixes applied during beamtime at SQS

parent d70a481b
3 merge requests:
!12 Snapshot: field test deployed version as of end of run 202201
!3 Base correction device, CalCat interaction, DSSC and AGIPD devices
!1 WIP: Add DSSC device
@@ -24,6 +24,7 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     "applyCorrection",
     "doAnything",
     "dataFormat.memoryCells",
+    "dataFormat.memoryCellsCorrection",
     "dataFormat.pixelsX",
     "dataFormat.pixelsY",
     "preview.enable",
@@ -131,6 +132,15 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     ).description(
         "Full number of memory cells in incoming data"
     ).assignmentMandatory().commit()
+    bound.UINT32_ELEMENT(expected).key(
+        "dataFormat.memoryCellsCorrection"
+    ).displayedName("(Debug) Memory cells in correction map").description(
+        "Full number of memory cells in currently loaded correction map. "
+        "May exceed memory cell number in input if veto is on. "
+        "This value just displayed for debugging."
+    ).readOnly().initialValue(
+        0
+    ).commit()
     bound.VECTOR_UINT32_ELEMENT(expected).key(
         "dataFormat.inputDataShape"
     ).displayedName("Input data shape").description(
@@ -186,9 +196,6 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         "-4: stdev\n"
         "Max means selecting the pulse with the maximum integrated value. "
         "The others are computed across all filtered pulses in the train."
-        "Note that index slicing (≥ 0 case) currently does not take "
-        "image.pulseId into account, so certain pulse filters may yield "
-        "unexpected preview pulse shown."
     ).assignmentOptional().defaultValue(
         0
     ).reconfigurable().commit()
@@ -318,6 +325,7 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     self.input_data_dtype = getattr(np, config.get("dataFormat.inputImageDtype"))
     self.output_data_dtype = getattr(np, config.get("dataFormat.outputImageDtype"))
+    self._offset_map = None
     self._update_pulse_filter(config.get("pulseFilter"))
     self._update_shapes(
         config.get("dataFormat.pixelsX"),
@@ -373,6 +381,7 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     # apply new pulse filter
     self._update_pulse_filter(config.get("pulseFilter"))
     # but existing shapes (not reconfigurable)
+    # TODO: avoid double compilation here if constants are loaded
     self._update_shapes(
         self.get("dataFormat.pixelsX"),
         self.get("dataFormat.pixelsY"),
@@ -425,10 +434,10 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     train_id = metadata.getAttribute("timestamp", "tid")
     cell_table = np.squeeze(data.get("image.cellId"))
     assert isinstance(cell_table, np.ndarray), "image.cellId should be ndarray"
-    if not len(cell_table.shape) == 1:
-        self.set(
-            "status", f"Failed to process, cell table had shape {cell_table.shape}"
-        )
+    if len(cell_table.shape) == 0:
+        msg = "cellId had 0 dimensions. DAQ may not be sending data."
+        self.set("status", msg)
+        self.log.WARN(msg)
         return
     # original shape: 400, 1, 128, 512 (memory cells, something, y, x)
     # TODO: consider making paths configurable
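For context on the new guard: np.squeeze collapses the singleton axis of a (cells, 1) cellId array down to 1-D, but a degenerate single-value payload collapses all the way to 0-D. A minimal NumPy sketch (illustrative only, shapes made up):

    import numpy as np

    normal = np.squeeze(np.zeros((400, 1), dtype=np.uint16))    # shape (400,)
    degenerate = np.squeeze(np.zeros((1, 1), dtype=np.uint16))  # shape (), i.e. 0-D
    print(normal.shape, degenerate.shape)
    # the new check flags only the 0-D case instead of rejecting anything non-1-D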
@@ -440,6 +449,7 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     # TODO: truncate if > 800
     self.set("dataFormat.memoryCells", image_data.shape[0])
     with self._buffer_lock:
+        self._update_pulse_filter(self.get("pulseFilter"))
         self._update_shapes(
             self.get("dataFormat.pixelsX"),
             self.get("dataFormat.pixelsY"),
@@ -452,6 +462,7 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     do_generate_preview = train_id % self.get(
         "preview.trainIdModulo"
     ) == 0 and self.get("preview.enable")
+    do_apply_correction = self.get("applyCorrection")
     if not self.get("state") is State.PROCESSING:
         self.updateState(State.PROCESSING)
@@ -468,6 +479,23 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     cell_table = cell_table[self.pulse_filter]
     pulse_table = np.squeeze(data.get("image.pulseId"))[self.pulse_filter]
+    cell_table_max = np.max(cell_table)
+    correction_cell_num = self.get("dataFormat.memoryCellsCorrection")
+    if do_apply_correction:
+        if correction_cell_num == 0:
+            msg = "No constant loaded, correction will not be applied."
+            self.log.WARN(msg)
+            self.set("status", msg)
+            do_apply_correction = False
+        elif cell_table_max >= correction_cell_num:
+            msg = (
+                f"Max cell ID ({cell_table_max}) exceeds range for loaded "
+                f"constant (has {correction_cell_num} cells). Some frames "
+                "will not be corrected."
+            )
+            self.log.WARN(msg)
+            self.set("status", msg)
     with gpu_utils.GPUContextContext(self.gpu_context):
         self.gpu_buffer_cell_table.set(cell_table)
         self.gpu_buffer_input_image_data.set(image_data)
@@ -475,7 +503,7 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         self.gpu_buffer_input_image_data,
         self.gpu_buffer_reshaped_image_data,
     )
-    if self.get("applyCorrection"):
+    if do_apply_correction:
         buffer_handle, result = self.pipeline.correct(
             self.gpu_buffer_reshaped_image_data,
             self.gpu_buffer_cell_table,
@@ -485,9 +513,27 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
             self.gpu_buffer_reshaped_image_data
         )
     if do_generate_preview:
+        preview_slice_index = self.get("preview.pulse")
+        if preview_slice_index >= 0:
+            # look at pulse_table to find which index this pulse ID is in
+            pulse_id_found = np.where(pulse_table == preview_slice_index)[0]
+            if len(pulse_id_found) == 0:
+                pulse_found_instead = pulse_table[0]
+                msg = (
+                    f"Pulse {preview_slice_index} not found in "
+                    f"image.pulseId, arbitrary pulse "
+                    f"{pulse_found_instead} will be shown."
+                )
+                preview_slice_index = 0
+                self.log.WARN(msg)
+                self.set("status", msg)
+            else:
+                preview_slice_index = pulse_id_found[0]
         preview_raw, preview_corrected = self.pipeline.compute_preview(
             self.gpu_buffer_reshaped_image_data,
-            self.get("preview.pulse"),
+            preview_slice_index,
+            reuse_corrected=do_apply_correction,
+            cell_table=self.gpu_buffer_cell_table,
         )
     data.set("image.data", buffer_handle)
@@ -579,6 +625,44 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     """
+    offset_map = self.getConstant("Offset")
+    input_memory_cells = self.get("dataFormat.memoryCells")
+    if offset_map is None:
+        msg = (
+            "Warning: Did not find offset constant, offset correction "
+            "will not be applied"
+        )
+        self.set("status", msg)
+        self.log.WARN(msg)
+        self._offset_map = None
+    elif len(offset_map.shape) not in (3, 4):
+        msg = (
+            f"Offset map had unexpected shape {offset_map.shape}, "
+            "offset correction will not be applied"
+        )
+        self.set("status", msg)
+        self.log.WARN(msg)
+    else:
+        self.log.INFO(f"Offset map loaded has shape {offset_map.shape}")
+        if len(offset_map.shape) == 4:  # old format (see offsetcorrection_dssc.py)?
+            offset_map = offset_map[..., 0]
+        constant_memory_cells = offset_map.shape[-1]
+        if input_memory_cells > constant_memory_cells:
+            msg = (
+                f"Warning: Memory cells in input {input_memory_cells} > "
+                f"memory cells in constant {constant_memory_cells}, some "
+                "frames may not get correction applied."
+            )
+            self.set("status", msg)
+            self.log.WARN(msg)
+        self._offset_map = offset_map.astype(np.float32)
+        msg = f"Offset map with shape {self._offset_map.shape} ready to load to GPU"
+        self.set("status", msg)
+        self.log.INFO(msg)
+        if constant_memory_cells != self.get("dataFormat.memoryCellsCorrection"):
+            self.log.INFO("Will first have to update buffers on GPU")
+            self.set("dataFormat.memoryCellsCorrection", constant_memory_cells)
     self._update_maps_on_gpu()

 def registerManager(self, instance_id):
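The branch above handles both offset-map layouts. A minimal sketch with made-up shapes showing how the old 4-D format is reduced and where the memory cell count comes from:

    import numpy as np

    old_format = np.zeros((512, 128, 400, 1))     # hypothetical 4-D "old format" map
    offset_map = old_format
    if len(offset_map.shape) == 4:
        offset_map = offset_map[..., 0]           # drop trailing axis -> (512, 128, 400)
    constant_memory_cells = offset_map.shape[-1]  # 400, stored in memoryCellsCorrection
    offset_map = offset_map.astype(np.float32)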
@@ -692,37 +776,12 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
     self.set("status", "Updating constants on GPU using known constants")
     self.updateState(State.CHANGING)
-    offset_map = self.getConstant("Offset")
-    memory_cells = self.get("dataFormat.memoryCells")
-    if offset_map is None:
-        msg = f"Warning: Did not find offset constant, offset correction will not be applied"
-        self.set("status", msg)
-        self.log.WARN(msg)
-    else:
-        if len(offset_map.shape) in (3, 4):
-            self.log.INFO(f"Offset map known has shape {offset_map.shape}")
-            # this is from offsetcorrection_dssc.py
-            if len(offset_map.shape) == 4:  # old format?
-                offset_map = np.squeeze(offset_map[..., 0])
-            constant_memory_cells = offset_map.shape[-1]
-            if memory_cells > constant_memory_cells:
-                msg = (
-                    f"Warning: Memory cells in input ({memory_cells}) exceeded memory cells in constant ({constant_memory_cells}), offset correction will not be applied",
-                )
-                self.set("status", msg)
-                self.log.WARN(msg)
-            else:
-                offset_map = offset_map[..., :memory_cells, :].astype(np.float32)
-                with gpu_utils.GPUContextContext(self.gpu_context):
-                    self.pipeline.offset_map.set(offset_map)
-                msg = "Done transferring known constant(s) to GPU"
-                self.set("status", msg)
-                self.log.INFO(msg)
-        else:
-            msg = f"Offset map had unexpected shape {offset_map.shape}, offset correction will not be applied"
-            self.set("status", msg)
-            self.log.WARN(msg)
+    if self._offset_map is not None:
+        with gpu_utils.GPUContextContext(self.gpu_context):
+            self.pipeline.load_constants(self._offset_map)
+    msg = "Done transferring known constant(s) to GPU"
+    self.log.INFO(msg)
+    self.set("status", msg)
     self.updateState(State.ON)
@@ -736,11 +795,13 @@ class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice):
         self._rate_update_timer.delay()
         return
     self._buffered_status_update.set("performance.rate", self._rate_tracker.rate())
-    theoretical_rate = 1000 / self._buffered_status_update.get(
-        "performance.lastProcessingDuration"
-    )
-    self._buffered_status_update.set(
-        "performance.theoreticalRate", theoretical_rate
-    )
+    last_processing = self._buffered_status_update.get(
+        "performance.lastProcessingDuration"
+    )
+    if last_processing > 0:
+        theoretical_rate = 1000 / last_processing
+        self._buffered_status_update.set(
+            "performance.theoreticalRate", theoretical_rate
+        )
     self.set(self._buffered_status_update)
     self._rate_update_timer.delay()
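performance.lastProcessingDuration is in milliseconds, so the theoretical rate is 1000 divided by it; the new guard just avoids dividing by an as-yet-unmeasured duration of 0. Sketch with a made-up value:

    last_processing = 8.0                           # ms, hypothetical measurement
    if last_processing > 0:
        theoretical_rate = 1000 / last_processing   # 125 trains/s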
@@ -34,43 +34,18 @@ class PyCudaPipeline:
     self.pixels_x = pixels_x
     self.pixels_y = pixels_y
     self.memory_cells = memory_cells
+    self.constant_memory_cells = 0
     self.pulse_filter = pulse_filter
     self.output_shape = (self.pixels_x, self.pixels_y, self.pulse_filter.size)
-    self.map_shape = (self.pixels_x, self.pixels_y, self.memory_cells)
+    self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
     # preview will only be single memory cell
     self.preview_shape = self.output_shape[:-1]
-    kernel_source = self._kernel_template.render(
-        {
-            "pixels_x": self.pixels_x,
-            "pixels_y": self.pixels_y,
-            "memory_cells": self.memory_cells,
-            "input_data_dtype": utils.numpy_dtype_to_c_type_str[input_data_dtype],
-            "output_data_dtype": utils.numpy_dtype_to_c_type_str[output_data_dtype],
-            "pulse_filter": pulse_filter,
-        }
-    )
-    self.source_module = pycuda.compiler.SourceModule(
-        kernel_source, no_extern_c=True
-    )
-    self.reshaping_kernel = self.source_module.get_function("reshape_4_3")
-    self.correction_kernel = self.source_module.get_function("correct")
-    self.casting_kernel = self.source_module.get_function("only_cast")
-    self.preview_slice_raw_kernel = self.source_module.get_function(
-        "cell_slice_preview_raw"
-    )
-    self.preview_slice_corrected_kernel = self.source_module.get_function(
-        "cell_slice_preview_corrected"
-    )
-    self.preview_stat_raw_kernel = self.source_module.get_function(
-        "cell_stat_preview_raw"
-    )
-    self.preview_stat_corrected_kernel = self.source_module.get_function(
-        "cell_stat_preview_corrected"
-    )
-    self.frame_sum_kernel = self.source_module.get_function("sum_frames")
-    self.offset_map = pycuda.gpuarray.zeros(self.map_shape, dtype=np.float32)
+    self.input_data_dtype = input_data_dtype
+    self.output_data_dtype = output_data_dtype
+    self._init_kernels()
+    self.offset_map = pycuda.gpuarray.empty(self.map_shape, dtype=np.float32)
     # reuse output arrays
     self.gpu_result = pycuda.gpuarray.empty(
@@ -105,6 +80,51 @@ class PyCudaPipeline:
     self.update_block_size(full_block=(1, 1, 64), preview_block=(1, 64, 1))

+def load_constants(self, offset_map_host):
+    constant_memory_cells = offset_map_host.shape[-1]
+    if constant_memory_cells != self.constant_memory_cells:
+        self.constant_memory_cells = constant_memory_cells
+        self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
+        self.offset_map = pycuda.gpuarray.empty(self.map_shape, dtype=np.float32)
+        self._init_kernels()
+    self.offset_map.set(offset_map_host)
+
+def _init_kernels(self):
+    kernel_source = self._kernel_template.render(
+        {
+            "pixels_x": self.pixels_x,
+            "pixels_y": self.pixels_y,
+            "memory_cells": self.memory_cells,
+            "constant_memory_cells": self.constant_memory_cells,
+            "input_data_dtype": utils.numpy_dtype_to_c_type_str[
+                self.input_data_dtype
+            ],
+            "output_data_dtype": utils.numpy_dtype_to_c_type_str[
+                self.output_data_dtype
+            ],
+            "pulse_filter": self.pulse_filter,
+        }
+    )
+    self.source_module = pycuda.compiler.SourceModule(
+        kernel_source, no_extern_c=True
+    )
+    self.reshaping_kernel = self.source_module.get_function("reshape_4_3")
+    self.correction_kernel = self.source_module.get_function("correct")
+    self.casting_kernel = self.source_module.get_function("only_cast")
+    self.preview_slice_raw_kernel = self.source_module.get_function(
+        "cell_slice_preview_raw"
+    )
+    self.preview_slice_corrected_kernel = self.source_module.get_function(
+        "cell_slice_preview_corrected"
+    )
+    self.preview_stat_raw_kernel = self.source_module.get_function(
+        "cell_stat_preview_raw"
+    )
+    self.preview_stat_corrected_kernel = self.source_module.get_function(
+        "cell_stat_preview_corrected"
+    )
+    self.frame_sum_kernel = self.source_module.get_function("sum_frames")
+
 def update_block_size(self, full_block=None, preview_block=None):
     """Execution is scheduled with 3d "blocks" of CUDA threads, tuning can
     affect performance
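load_constants re-renders and recompiles the kernels only when the constant's memory cell count changes; loading a map of the same shape just overwrites the existing GPU buffer. A toy version of that caching pattern (Jinja2 assumed available; the pycuda compilation step is replaced by a plain string render):

    import numpy as np
    from jinja2 import Template

    kernel_template = Template(
        "const size_t map_memory_cells = {{ constant_memory_cells }};"
    )

    class ToyPipeline:
        def __init__(self):
            self.constant_memory_cells = 0
            self.kernel_source = None

        def load_constants(self, offset_map_host):
            cells = offset_map_host.shape[-1]
            if cells != self.constant_memory_cells:
                # shape changed: re-render (and, in the real pipeline, recompile)
                self.constant_memory_cells = cells
                self.kernel_source = kernel_template.render(constant_memory_cells=cells)
            # ...then copy offset_map_host into the (re)allocated GPU buffer

    pipeline = ToyPipeline()
    pipeline.load_constants(np.zeros((512, 128, 800), dtype=np.float32))
    print(pipeline.kernel_source)   # const size_t map_memory_cells = 800;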
@@ -161,7 +181,6 @@ class PyCudaPipeline:
     Will return string encoded handle to shared memory output buffer and
     (view of) said buffer as an ndarray. Keep in mind that the output
     buffers will get overwritten eventually (circular buffer).
-
     """
     self.correction_kernel(
         data,
@@ -201,17 +220,20 @@ class PyCudaPipeline:
     return handle, output_buffer

 def compute_preview(
-    self, raw_data, cell_to_preview, has_just_corrected=True, verify=False
+    self,
+    raw_data,
+    preview_index,
+    reuse_corrected=True,
+    cell_table=None,
 ):
     """Generate single slice or reduction preview of raw and corrected data

-    Special values of cell_to_preview are -1 for max, -2 for mean, -3 for
+    Special values of preview_index are -1 for max, -2 for mean, -3 for
     sum, and -4 for stdev (across cells).

-    Note that cell_to_preview is taken from data without checking cell
-    table, so if a pulse filter not contiguous from 0 has been applied
-    first, the resulting cell will be offset. Cell table is only used to
-    get the correct slice of the correction map.
+    Note that preview_index is taken from data without checking cell table.
+    Caller has to figure out which index along memory cell dimension they
+    actually want to preview.

     raw_data should be a gpuarray
@@ -220,35 +242,57 @@ class PyCudaPipeline:
     """
-    if cell_to_preview < -4:
-        raise ValueError(f"No statistic with code {cell_to_preview} defined")
-    elif cell_to_preview >= self.memory_cells:
-        raise ValueError(f"Memory cell index {cell_to_preview} out of range")
+    if preview_index < -4:
+        raise ValueError(f"No statistic with code {preview_index} defined")
+    elif preview_index >= self.memory_cells:
+        raise ValueError(f"Memory cell index {preview_index} out of range")
+    if not reuse_corrected:
+        # if we didn't already correct, need to do so to get corrected data in buffer
+        if self.offset_map.size == 0 or cell_table is None:
+            self.casting_kernel(
+                raw_data,
+                self.gpu_result,
+                block=self.full_block,
+                grid=self.full_grid,
+            )
+            if self.offset_map.size == 0:
+                print(
+                    "Warning: no offset map loaded, corrected preview "
+                    "will be not actually have correction applied."
+                )
+            if cell_table is None:
+                print(
+                    "Warning: missing parameter cell_table for applying "
+                    "correction for preview."
+                )
+        else:
+            self.correction_kernel(
+                raw_data,
+                cell_table,
+                self.offset_map,
+                self.gpu_result,
+                block=self.full_block,
+                grid=self.full_grid,
+            )
-    # TODO: lift this restriction.
-    assert has_just_corrected
     # TODO: enum around reduction type
-    if cell_to_preview >= 0:
+    if preview_index >= 0:
         self.preview_slice_raw_kernel(
             raw_data,
-            np.int16(cell_to_preview),
+            np.int16(preview_index),
             self.gpu_preview_raw,
             block=self.preview_block,
             grid=self.preview_grid,
         )
         self.preview_slice_corrected_kernel(
             self.gpu_result,
-            np.int16(cell_to_preview),
+            np.int16(preview_index),
             self.gpu_preview_corrected,
             block=self.preview_block,
             grid=self.preview_grid,
         )
-        if verify:
-            assert np.allclose(
-                self.gpu_preview_raw.get(),
-                raw_data.get()[..., cell_to_preview],
-            )
-    elif cell_to_preview == -1:
+    elif preview_index == -1:
         # TODO: select argmax independently for raw and corrected?
         # TODO: send frame sums somewhere to compute global max frame
         self.frame_sum_kernel(
@@ -272,38 +316,21 @@ class PyCudaPipeline:
             block=self.preview_block,
             grid=self.preview_grid,
         )
-        if verify:
-            assert np.allclose(
-                self.gpu_preview_raw.get(),
-                raw_data.get()[
-                    ...,
-                    np.argmax(
-                        np.sum(raw_data.get(), axis=(0, 1), dtype=np.float32)
-                    ),
-                ],
-            )
-    elif cell_to_preview in (-2, -3, -4):
+    elif preview_index in (-2, -3, -4):
         self.preview_stat_raw_kernel(
             raw_data, # this is input_data_dtype
-            np.int16(cell_to_preview),
+            np.int16(preview_index),
             self.gpu_preview_raw,
             block=self.preview_block,
             grid=self.preview_grid,
         )
         self.preview_stat_corrected_kernel(
             self.gpu_result, # this is output_data_dtype
-            np.int16(cell_to_preview),
+            np.int16(preview_index),
             self.gpu_preview_corrected,
             block=self.preview_block,
             grid=self.preview_grid,
         )
-        if verify:
-            assert np.allclose(
-                self.gpu_preview_raw.get(),
-                {-2: np.mean, -3: np.sum, -4: np.std}[cell_to_preview](
-                    raw_data.get(), axis=2
-                ),
-            )
     self.gpu_preview_raw.get(ary=self.preview_raw)
     self.gpu_preview_corrected.get(ary=self.preview_corrected)
     return self.preview_raw, self.preview_corrected
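For reference, the special preview_index codes correspond to these NumPy reductions over the memory cell axis (the removed verify branches asserted essentially this). A host-side sketch with a random stack and made-up shapes:

    import numpy as np

    stack = np.random.rand(512, 128, 400).astype(np.float32)   # x, y, cells
    previews = {
        -1: stack[..., np.argmax(stack.sum(axis=(0, 1)))],  # frame with max integrated value
        -2: stack.mean(axis=2),
        -3: stack.sum(axis=2),
        -4: stack.std(axis=2),
    }
    preview = previews[-2]   # e.g. mean across all (filtered) cells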
@@ -55,8 +55,8 @@ extern "C" {
     const size_t Y = {{pixels_y}};
     // reshaped and output data have pulse filter length memory cells dim
     const size_t filtered_memory_cells = {{pulse_filter|length}};
-    // but correction map has "full size"
-    const size_t full_memory_cells = {{memory_cells}};
+    // but correction map has some number which may even exceed input data's (due to veto pattern)
+    const size_t map_memory_cells = {{constant_memory_cells}};
     const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
     const size_t j = blockIdx.y * blockDim.y + threadIdx.y;
@@ -71,20 +71,27 @@ extern "C" {
     const size_t data_stride_1 = filtered_memory_cells * data_stride_2;
     const size_t data_stride_0 = Y * data_stride_1;
     const size_t data_index = i * data_stride_0 + j * data_stride_1 + k * data_stride_2;
+    const float raw = (float)data[data_index];

     const size_t map_stride_2 = 1;
-    const size_t map_stride_1 = full_memory_cells * map_stride_2;
+    const size_t map_stride_1 = map_memory_cells * map_stride_2;
     const size_t map_stride_0 = Y * map_stride_1;
     const size_t map_cell = cell_table[k];
-    const size_t map_index = i * map_stride_0 + j * map_stride_1 + map_cell * map_stride_2;
-
-    const float raw = (float)data[data_index];
-    const float corrected = raw - offset_map[map_index];
-    {% if output_data_dtype == "half" %}
-    output[data_index] = __float2half(corrected);
-    {% else %}
-    output[data_index] = ({{output_data_dtype}})corrected;
-    {% endif %}
+    if (map_cell < map_memory_cells) {
+        const size_t map_index = i * map_stride_0 + j * map_stride_1 + map_cell * map_stride_2;
+        const float corrected = raw - offset_map[map_index];
+        {% if output_data_dtype == "half" %}
+        output[data_index] = __float2half(corrected);
+        {% else %}
+        output[data_index] = ({{output_data_dtype}})corrected;
+        {% endif %}
+    } else {
+        {% if output_data_dtype == "half" %}
+        output[data_index] = __float2half(raw);
+        {% else %}
+        output[data_index] = ({{output_data_dtype}})raw;
+        {% endif %}
+    }
 }

 /*
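The kernel change above makes the per-frame correction conditional on the cell ID being covered by the loaded map; out-of-range frames are passed through uncorrected. A host-side NumPy sketch of the same behaviour with toy shapes:

    import numpy as np

    x, y, cells_in, cells_in_map = 4, 3, 6, 4
    data = np.random.rand(x, y, cells_in).astype(np.float32)
    cell_table = np.array([0, 1, 2, 3, 700, 701])   # last two exceed the map
    offset_map = np.random.rand(x, y, cells_in_map).astype(np.float32)

    in_range = cell_table < cells_in_map
    corrected = data.copy()
    corrected[..., in_range] -= offset_map[..., cell_table[in_range]]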