Skip to content
Snippets Groups Projects
Commit 43b679ab authored by David Hammer's avatar David Hammer
Browse files

WIP: add handling of fixed gain mode

parent 4d38ca5b
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
...@@ -15,7 +15,7 @@ from karabo.common.states import State ...@@ -15,7 +15,7 @@ from karabo.common.states import State
from . import utils from . import utils
from ._version import version as deviceVersion from ._version import version as deviceVersion
from .agipd_gpu import AgipdGpuRunner, BadPixelValues, CorrectionFlags from .agipd_gpu import AgipdGainMode, AgipdGpuRunner, BadPixelValues, CorrectionFlags
from .base_correction import BaseCorrection from .base_correction import BaseCorrection
...@@ -111,6 +111,15 @@ class AgipdCorrection(BaseCorrection): ...@@ -111,6 +111,15 @@ class AgipdCorrection(BaseCorrection):
.reconfigurable() .reconfigurable()
.commit(), .commit(),
) )
(
STRING_ELEMENT(expected)
.key("gainMode")
.displayedName("Gain mode")
.assignmentOptional()
.defaultValue("ADAPTIVE_GAIN")
.options("ADAPTIVE_GAIN,FIXED_HIGH_GAIN,FIXED_MEDIUM_GAIN,FIXED_LOW_GAIN")
.commit()
)
# TODO: hook this up to actual correction done # TODO: hook this up to actual correction done
bad_pixel_selection_schema = Schema() bad_pixel_selection_schema = Schema()
( (
...@@ -157,6 +166,10 @@ class AgipdCorrection(BaseCorrection): ...@@ -157,6 +166,10 @@ class AgipdCorrection(BaseCorrection):
) )
def __init__(self, config): def __init__(self, config):
# TODO: different gpu runner for fixed gain mode
self.gain_mode = AgipdGainMode[config.get("gainMode")]
self._gpu_runner_init_args = {"gain_mode": self.gain_mode}
super().__init__(config) super().__init__(config)
output_axis_order = config.get("dataFormat.outputAxisOrder") output_axis_order = config.get("dataFormat.outputAxisOrder")
if output_axis_order == "pixels-fast": if output_axis_order == "pixels-fast":
...@@ -297,6 +310,9 @@ class AgipdCorrection(BaseCorrection): ...@@ -297,6 +310,9 @@ class AgipdCorrection(BaseCorrection):
def _load_constant_to_gpu(self, constant_name, constant_data): def _load_constant_to_gpu(self, constant_name, constant_data):
if constant_name == "ThresholdsDark": if constant_name == "ThresholdsDark":
if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN:
self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode")
return
self.gpu_runner.load_thresholds(constant_data) self.gpu_runner.load_thresholds(constant_data)
# TODO: encode correction / constant dependencies in a clever way # TODO: encode correction / constant dependencies in a clever way
if not self.get("corrections.available.thresholding"): if not self.get("corrections.available.thresholding"):
...@@ -344,16 +360,12 @@ class AgipdCorrection(BaseCorrection): ...@@ -344,16 +360,12 @@ class AgipdCorrection(BaseCorrection):
assert np.max(new_filter) < self.get("dataFormat.memoryCells") assert np.max(new_filter) < self.get("dataFormat.memoryCells")
self.pulse_filter = new_filter self.pulse_filter = new_filter
def preReconfigure(self, config): def postReconfigure(self):
super().preReconfigure(config) super().postReconfigure()
if config.has("corrections.overrideMdAdditionalOffset"): if self.get("corrections.overrideMdAdditionalOffset"):
if config.get("corrections.overrideMdAdditionalOffset"): self._override_md_additional_offset = self.get("corrections.mdAdditionalOffset")
md_additional_offset = self.get("corrections.mdAdditionalOffset") self.gpu_runner.md_additional_offset_gpu.fill(
if config.has("corrections.mdAdditionalOffset"): self._override_md_additional_offset
md_additional_offset = config.get("corrections.mdAdditionalOffset") )
self._override_md_additional_offset = md_additional_offset else:
self.gpu_runner.md_additional_offset_gpu.fill( self._override_md_additional_offset = None
self._override_md_additional_offset
)
else:
self._override_md_additional_offset = None
...@@ -16,6 +16,14 @@ class CorrectionFlags(enum.IntFlag): ...@@ -16,6 +16,14 @@ class CorrectionFlags(enum.IntFlag):
BPMASK = 32 BPMASK = 32
# from pycalibration's enum.py
class AgipdGainMode(enum.IntEnum):
ADAPTIVE_GAIN = 0
FIXED_HIGH_GAIN = 1
FIXED_MEDIUM_GAIN = 2
FIXED_LOW_GAIN = 3
class AgipdGpuRunner(base_gpu.BaseGpuRunner): class AgipdGpuRunner(base_gpu.BaseGpuRunner):
_kernel_source_filename = "agipd_gpu_kernels.cpp" _kernel_source_filename = "agipd_gpu_kernels.cpp"
...@@ -29,7 +37,13 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): ...@@ -29,7 +37,13 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
input_data_dtype=np.uint16, input_data_dtype=np.uint16,
output_data_dtype=np.float32, output_data_dtype=np.float32,
badpixel_mask_value=np.float32(np.nan), badpixel_mask_value=np.float32(np.nan),
gain_mode=AgipdGainMode.ADAPTIVE_GAIN,
): ):
self.gain_mode = gain_mode
if self.gain_mode is AgipdGainMode.ADAPTIVE_GAIN:
self.default_gain = np.uint8(gain_mode)
else:
self.default_gain = np.uint8(gain_mode - 1)
self.input_shape = (memory_cells, 2, pixels_x, pixels_y) self.input_shape = (memory_cells, 2, pixels_x, pixels_y)
self.processed_shape = (memory_cells, pixels_x, pixels_y) self.processed_shape = (memory_cells, pixels_x, pixels_y)
super().__init__( super().__init__(
...@@ -162,6 +176,10 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): ...@@ -162,6 +176,10 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
def correct(self, flags): def correct(self, flags):
if flags & CorrectionFlags.BLSHIFT: if flags & CorrectionFlags.BLSHIFT:
raise NotImplementedError("Baseline shift not implemented yet") raise NotImplementedError("Baseline shift not implemented yet")
if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN and (
flags & CorrectionFlags.THRESHOLD
):
raise ValueError("Cannot do gain thresholding in fixed gain mode")
self.correction_kernel( self.correction_kernel(
self.full_grid, self.full_grid,
self.full_block, self.full_block,
...@@ -169,6 +187,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): ...@@ -169,6 +187,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner):
self.input_data_gpu, self.input_data_gpu,
self.cell_table_gpu, self.cell_table_gpu,
np.uint8(flags), np.uint8(flags),
self.default_gain,
self.gain_thresholds_gpu, self.gain_thresholds_gpu,
self.offset_map_gpu, self.offset_map_gpu,
self.rel_gain_pc_map_gpu, self.rel_gain_pc_map_gpu,
......
...@@ -12,14 +12,16 @@ extern "C" { ...@@ -12,14 +12,16 @@ extern "C" {
*/ */
__global__ void correct(const {{input_data_dtype}}* data, __global__ void correct(const {{input_data_dtype}}* data,
const unsigned short* cell_table, const unsigned short* cell_table,
const unsigned char corr_flags, const unsigned char corr_flags,
// default_gain can be 0, 1, or 2, and is relevant for fixed gain mode (no THRESHOLD)
const unsigned char default_gain,
const float* threshold_map, const float* threshold_map,
const float* offset_map, const float* offset_map,
const float* rel_gain_pc_map, const float* rel_gain_pc_map,
const float* md_additional_offset, const float* md_additional_offset,
const float* rel_gain_xray_map, const float* rel_gain_xray_map,
const unsigned int* bad_pixel_map, const unsigned int* bad_pixel_map,
const float bad_pixel_mask_value, const float bad_pixel_mask_value,
unsigned char* gain_map, unsigned char* gain_map,
{{output_data_dtype}}* output) { {{output_data_dtype}}* output) {
const size_t X = {{pixels_x}}; const size_t X = {{pixels_x}};
...@@ -77,7 +79,7 @@ extern "C" { ...@@ -77,7 +79,7 @@ extern "C" {
const size_t map_cell = cell_table[cell]; const size_t map_cell = cell_table[cell];
if (map_cell < map_cells) { if (map_cell < map_cells) {
unsigned char gain = 0; unsigned char gain = default_gain;
if (corr_flags & THRESHOLD) { if (corr_flags & THRESHOLD) {
const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold + const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold +
map_cell * threshold_map_stride_cell + map_cell * threshold_map_stride_cell +
......
...@@ -32,6 +32,8 @@ from . import shmem_utils, utils ...@@ -32,6 +32,8 @@ from . import shmem_utils, utils
class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
_correction_flag_class = None # subclass must override this with some enum class _correction_flag_class = None # subclass must override this with some enum class
_gpu_runner_class = None # subclass must set this
_gpu_runner_init_args = {} # subclass can set this (TODO: remove, design better)
_schema_cache_slots = { _schema_cache_slots = {
"doAnything", "doAnything",
"dataFormat.memoryCells", "dataFormat.memoryCells",
...@@ -387,6 +389,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -387,6 +389,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype")) self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype"))
self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype")) self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype"))
self.gpu_runner = None # must call _update_shapes() in subclass init
self._correction_flag_enabled = self._correction_flag_class.NONE self._correction_flag_enabled = self._correction_flag_class.NONE
self._correction_flag_preview = self._correction_flag_class.NONE self._correction_flag_preview = self._correction_flag_class.NONE
...@@ -590,6 +593,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice): ...@@ -590,6 +593,7 @@ class BaseCorrection(calibrationBase.CalibrationReceiverBaseDevice):
output_transpose=self._output_transpose, output_transpose=self._output_transpose,
input_data_dtype=self.input_data_dtype, input_data_dtype=self.input_data_dtype,
output_data_dtype=self.output_data_dtype, output_data_dtype=self.output_data_dtype,
**self._gpu_runner_init_args,
) )
for constant_name, constant_data in self._cached_constants.items(): for constant_name, constant_data in self._cached_constants.items():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment