Skip to content
Snippets Groups Projects
Commit fc3b3be2 authored by David Hammer's avatar David Hammer
Browse files

Fill in more mandatory methods

parent 1888eb3d
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!6Draft: add Jungfrau correction device
import enum import enum
import timeit
import cupy import cupy
import numpy as np import numpy as np
...@@ -7,15 +6,17 @@ from karabo.bound import ( ...@@ -7,15 +6,17 @@ from karabo.bound import (
DOUBLE_ELEMENT, DOUBLE_ELEMENT,
KARABO_CLASSINFO, KARABO_CLASSINFO,
OVERWRITE_ELEMENT, OVERWRITE_ELEMENT,
STRING_ELEMENT,
VECTOR_STRING_ELEMENT, VECTOR_STRING_ELEMENT,
) )
from karabo.common.states import State
from . import base_gpu, calcat_utils, utils from . import base_gpu, calcat_utils, utils
from ._version import version as deviceVersion from ._version import version as deviceVersion
from .base_correction import BaseCorrection, add_correction_step_schema from .base_correction import BaseCorrection, add_correction_step_schema
_pretend_pulse_table = np.arange(16, dtype=np.uint8)
class JungfrauConstants(enum.Enum): class JungfrauConstants(enum.Enum):
Offset10Hz = enum.auto() Offset10Hz = enum.auto()
BadPixelsDark10Hz = enum.auto() BadPixelsDark10Hz = enum.auto()
...@@ -53,6 +54,7 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner): ...@@ -53,6 +54,7 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner):
output_data_dtype=cupy.float32, output_data_dtype=cupy.float32,
bad_pixel_mask_value=cupy.nan, bad_pixel_mask_value=cupy.nan,
burst_mode=False, burst_mode=False,
gain_mode=JungfrauGainMode.dynamicgain,
): ):
self.burst_mode = burst_mode self.burst_mode = burst_mode
self.input_shape = (memory_cells, pixels_y, pixels_x) self.input_shape = (memory_cells, pixels_y, pixels_x)
...@@ -65,6 +67,9 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner): ...@@ -65,6 +67,9 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner):
input_data_dtype, input_data_dtype,
output_data_dtype, output_data_dtype,
) )
# TODO: avoid superclass creating cell table with wrong dtype first
self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=cupy.uint8)
self.input_gain_map_gpu = cupy.empty(self.input_shape, dtype=cupy.uint8)
self.map_shape = self.input_shape + (3,) self.map_shape = self.input_shape + (3,)
# is jungfrau stuff gain mapped? # is jungfrau stuff gain mapped?
self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32) self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32)
...@@ -84,9 +89,41 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner): ...@@ -84,9 +89,41 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner):
"burst_mode": self.burst_mode, "burst_mode": self.burst_mode,
} }
) )
print(kernel_source) for i, line in enumerate(kernel_source.split("\n")):
print(f"{i}: {line}")
self.source_module = cupy.RawModule(code=kernel_source) self.source_module = cupy.RawModule(code=kernel_source)
print("Got raw module")
self.correction_kernel = self.source_module.get_function("correct") self.correction_kernel = self.source_module.get_function("correct")
print("Got kernel")
def _get_raw_for_preview(self):
return self.input_data_gpu.transpose(0, 2, 1)
def _get_corrected_for_preview(self):
return self.processed_data_gpu.transpose(0, 2, 1)
def load_data(self, image_data, input_gain_map, cell_table):
"""Experiment: loading all three in one function as they are tied"""
self.input_data_gpu.set(image_data)
self.input_gain_map_gpu.set(input_gain_map)
self.cell_table_gpu.set(cell_table)
def correct(self, flags):
self.correction_kernel(
self.full_grid,
self.full_block,
(
self.input_data_gpu,
self.input_gain_map_gpu,
self.cell_table_gpu,
cupy.uint8(flags),
self.offset_map_gpu,
self.rel_gain_map_gpu,
self.bad_pixel_map_gpu,
self.bad_pixel_mask_value,
self.processed_data_gpu,
)
)
class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend): class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend):
...@@ -187,6 +224,10 @@ class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend): ...@@ -187,6 +224,10 @@ class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend):
res["Integration Time"] = self._get_param("integrationTime") res["Integration Time"] = self._get_param("integrationTime")
res["Sensor Temperature"] = self._get_param("sensorTemperature") res["Sensor Temperature"] = self._get_param("sensorTemperature")
res["Gain Setting"] = self._get_param("gainSetting") res["Gain Setting"] = self._get_param("gainSetting")
gain_mode = JungfrauGainMode[self._get_param("gainMode")]
if gain_mode is not JungfrauGainMode.dynamicgain:
# TODO: figure out what to set
res["Gain mode"] = 1
return res return res
...@@ -207,6 +248,16 @@ class JungfrauCorrection(BaseCorrection): ...@@ -207,6 +248,16 @@ class JungfrauCorrection(BaseCorrection):
def expectedParameters(expected): def expectedParameters(expected):
super(JungfrauCorrection, JungfrauCorrection).expectedParameters(expected) super(JungfrauCorrection, JungfrauCorrection).expectedParameters(expected)
( (
OVERWRITE_ELEMENT(expected)
.key("dataFormat.pixelsX")
.setNewDefaultValue(1024)
.commit(),
OVERWRITE_ELEMENT(expected)
.key("dataFormat.pixelsY")
.setNewDefaultValue(512)
.commit(),
OVERWRITE_ELEMENT(expected) OVERWRITE_ELEMENT(expected)
.key("dataFormat.memoryCells") .key("dataFormat.memoryCells")
.setNewDefaultValue(1) .setNewDefaultValue(1)
...@@ -238,15 +289,27 @@ class JungfrauCorrection(BaseCorrection): ...@@ -238,15 +289,27 @@ class JungfrauCorrection(BaseCorrection):
def input_data_shape(self): def input_data_shape(self):
return ( return (
self._schema_cache["dataFormat.memoryCells"], self._schema_cache["dataFormat.memoryCells"],
self._schema_cache["dataFormat.pixelsX"],
self._schema_cache["dataFormat.pixelsY"], self._schema_cache["dataFormat.pixelsY"],
self._schema_cache["dataFormat.pixelsX"],
) )
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
# TODO: gain mode as constant parameter and / or device configuration # TODO: gain mode as constant parameter and / or device configuration
self.gain_mode = JungfrauGainMode[config.get("constantParameters.gainMode")] self.gain_mode = JungfrauGainMode[config.get("constantParameters.gainMode")]
# TODO: rest of this
try:
self.bad_pixel_mask_value = np.float32(
config.get("corrections.badPixels.maskingValue")
)
except ValueError:
self.bad_pixel_mask_value = np.float32("nan")
self._kernel_runner_init_args = {
"gain_mode": self.gain_mode,
"bad_pixel_mask_value": self.bad_pixel_mask_value,
"burst_mode": False, # TODO
}
def process_data( def process_data(
self, self,
...@@ -258,5 +321,66 @@ class JungfrauCorrection(BaseCorrection): ...@@ -258,5 +321,66 @@ class JungfrauCorrection(BaseCorrection):
cell_table, cell_table,
do_generate_preview, do_generate_preview,
): ):
# TODO if self._frame_filter is not None:
... try:
cell_table = cell_table[self._frame_filter]
image_data = image_data[self._frame_filter]
except IndexError:
self.log_status_warn(
"Failed to apply frame filter, please check that it is valid!"
)
return
try:
self.kernel_runner.load_data(
image_data, data_hash.get("data.gain"), cell_table
)
except ValueError as e:
self.log_status_warn(f"Failed to load data: {e}")
return
except Exception as e:
self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
buffer_handle, buffer_array = self._shmem_buffer.next_slot()
self.kernel_runner.correct(self._correction_flag_enabled)
self.kernel_runner.reshape(
output_order=self._schema_cache["dataFormat.outputAxisOrder"],
out=buffer_array,
)
if do_generate_preview:
if self._correction_flag_enabled != self._correction_flag_preview:
self.kernel_runner.correct(self._correction_flag_preview)
(
preview_slice_index,
preview_cell,
preview_pulse,
) = utils.pick_frame_index(
self._schema_cache["preview.selectionMode"],
self._schema_cache["preview.index"],
cell_table,
_pretend_pulse_table,
warn_func=self.log_status_warn,
)
preview_raw, preview_corrected = self.kernel_runner.compute_previews(
preview_slice_index
)
# reusing input data hash for sending
data_hash.set(self._image_data_path, buffer_handle)
data_hash.set("calngShmemPaths", [self._image_data_path])
data_hash.set(self._cell_table_path, cell_table)
data_hash.set("image.pulseId", pulse_table[:, np.newaxis])
self._write_output(data_hash, metadata)
if do_generate_preview:
self._write_combiner_previews(
(
("preview.outputRaw", preview_raw),
("preview.outputCorrected", preview_corrected),
),
train_id,
source,
)
...@@ -3,18 +3,13 @@ ...@@ -3,18 +3,13 @@
{{corr_enum}} {{corr_enum}}
extern "C" { extern "C" {
/*
TODO
Shape of input data: memory cell, y, x
Shape of offset constant: x, y, memory cell
*/
__global__ void correct(const {{input_data_dtype}}* data, // shape: memory cell, y, x __global__ void correct(const {{input_data_dtype}}* data, // shape: memory cell, y, x
const unsigned char* gain_stage, // same shape const unsigned char* gain_stage, // same shape
const unsigned char* cell_table, const unsigned char* cell_table,
const unsigned char corr_flags, const unsigned char corr_flags,
const float* offset_map, const float* offset_map,
const float* rel_gain_map, const float* rel_gain_map,
const unsigned int bad_pixel_map, const unsigned int* bad_pixel_map,
const float bad_pixel_mask_value, const float bad_pixel_mask_value,
{{output_data_dtype}}* output) { {{output_data_dtype}}* output) {
const size_t X = {{pixels_x}}; const size_t X = {{pixels_x}};
...@@ -76,8 +71,8 @@ extern "C" { ...@@ -76,8 +71,8 @@ extern "C" {
if (corr_flags & OFFSET) { if (corr_flags & OFFSET) {
res -= offset_map[map_index]; res -= offset_map[map_index];
} }
if (corr_flags & GAIN) { if (corr_flags & REL_GAIN) {
res /= gain_map[map_index]; res /= rel_gain_map[map_index];
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment