Skip to content
Snippets Groups Projects
Commit c0bffab0 authored by David Hammer's avatar David Hammer
Browse files

Prototype Cython-based GOTTHARD2 correction device

parent 319a0a11
No related branches found
No related tags found
1 merge request!12Snapshot: field test deployed version as of end of run 202201
#!/usr/bin/env python
from os.path import dirname, join, realpath
from setuptools import setup, find_packages
from setuptools import setup, find_packages, Extension
from Cython.Build import cythonize
from karabo.packaging.versioning import device_scm_version
......@@ -26,6 +27,7 @@ setup(name='calng',
'karabo.bound_device': [
'AgipdCorrection = calng.AgipdCorrection:AgipdCorrection',
'DsscCorrection = calng.DsscCorrection:DsscCorrection',
'Gotthard2Correction = calng.Gotthard2Correction:Gotthard2Correction',
'JungfrauCorrection = calng.JungfrauCorrection:JungfrauCorrection',
'LpdCorrection = calng.LpdCorrection:LpdCorrection',
'ShmemToZMQ = calng.ShmemToZMQ:ShmemToZMQ',
......@@ -43,4 +45,12 @@ setup(name='calng',
},
package_data={'': ['kernels/*']},
requires=[],
ext_modules=cythonize(
[
Extension(
'calng.kernels.gotthard2_cython',
['src/calng/kernels/gotthard2_cpu.pyx']
)
]
),
)
import enum
import numpy as np
from karabo.bound import (
IMAGEDATA_ELEMENT,
KARABO_CLASSINFO,
NODE_ELEMENT,
OUTPUT_CHANNEL,
OVERWRITE_ELEMENT,
UINT64_ELEMENT,
VECTOR_STRING_ELEMENT,
ImageData,
Schema,
)
from . import base_calcat, utils
from ._version import version as deviceVersion
from .base_correction import BaseCorrection, add_correction_step_schema
_pretend_pulse_table = np.arange(2720, dtype=np.uint8)
streak_preview_schema = Schema()
(
NODE_ELEMENT(streak_preview_schema).key("image").commit(),
IMAGEDATA_ELEMENT(streak_preview_schema).key("image.data").commit(),
UINT64_ELEMENT(streak_preview_schema).key("trainId").readOnly().commit(),
)
class Gotthard2Constants(enum.Enum):
Lut = enum.auto()
Offset = enum.auto()
Gain = enum.auto()
class CorrectionFlags(enum.IntFlag):
NONE = 0
LUT = 1
OFFSET = 2
GAIN = 4
class Gotthard2CpuRunner:
def __init__(
self,
pixels_x,
pixels_y,
memory_cells,
constant_memory_cells,
input_data_dtype=np.uint16,
output_data_dtype=np.float32,
bad_pixel_mask_value=np.nan,
):
from .kernels import gotthard2_cython # TODO
self.correction_kernel = gotthard2_cython.correct
self.pixels_x = pixels_x
self.memory_cells = memory_cells
self.constant_memory_cells = constant_memory_cells
self.input_shape = (memory_cells, pixels_x)
self.processed_shape = self.input_shape
# model: 2 buffers (corresponding to actual memory cells), 2720 frames
# lut maps from uint12 to uint10 values
self.lut_shape = (2, 4096, pixels_x)
self.map_shape = (3, self.constant_memory_cells, self.pixels_x)
self.lut = np.empty(self.lut_shape, dtype=np.uint16)
self.offset_map = np.empty(self.map_shape, dtype=np.float32)
self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32)
self.flush_buffers()
self.input_data = None # will just point to data coming in
self.input_gain_stage = None # will just point to data coming in
self.processed_data = None # will just point to buffer we're given
self.preview_buffer_getters = [
self._get_raw_for_preview,
self._get_corrected_for_preview,
]
def _get_raw_for_preview(self):
return self.input_data
def _get_corrected_for_preview(self):
return self.processed_data
def load_data(self, image_data, input_gain_stage):
"""Experiment: loading all three in one function as they are tied"""
self.input_data = image_data.astype(np.uint16, copy=False)
self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False)
def flush_buffers(self):
default_lut = (
np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12
).astype(np.uint16)
self.lut[:] = np.stack([np.stack([default_lut] * 2)] * self.pixels_x, axis=2)
self.offset_map.fill(0)
self.rel_gain_map.fill(1)
def correct(self, flags, out=None):
if out is None:
out = np.empty(self.processed_shape, dtype=np.float32)
self.correction_kernel(
self.input_data,
self.input_gain_stage,
np.uint8(flags),
self.lut,
self.offset_map,
self.rel_gain_map,
out,
)
self.processed_data = out
return out
def compute_previews(self, preview_index):
"""See BaseGpuRunner.compute_previews"""
if preview_index < -4:
raise ValueError(f"No statistic with code {preview_index} defined")
elif preview_index >= self.memory_cells:
raise ValueError(f"Memory cell index {preview_index} out of range")
# TODO: enum around reduction type
return tuple(
self._compute_a_preview(image_data=getter(), preview_index=preview_index)
for getter in self.preview_buffer_getters
)
def _compute_a_preview(self, image_data, preview_index):
if preview_index >= 0:
return image_data[preview_index].astype(np.float32, copy=False)
elif preview_index == -1:
return np.nanmax(image_data, axis=0).astype(np.float32, copy=False)
elif preview_index in (-2, -3, -4):
stat_fun = {
-2: np.nanmean,
-3: np.nansum,
-4: np.nanstd,
}[preview_index]
return stat_fun(image_data, axis=0, dtype=np.float32)
class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend):
_constant_enum_class = Gotthard2Constants
def __init__(self, device, *args, **kwargs):
super().__init__(device, *args, **kwargs)
self._constants_need_conditions = {} # TODO
@staticmethod
def add_schema(
schema,
managed_keys,
param_prefix="constantParameters",
status_prefix="foundConstants",
):
super(Gotthard2CalcatFriend, Gotthard2CalcatFriend).add_schema(
schema, managed_keys, "gotthard-Type", param_prefix, status_prefix
)
# set some defaults for common parameters
(
OVERWRITE_ELEMENT(schema)
.key(f"{param_prefix}.pixelsX")
.setNewDefaultValue(1280)
.commit(),
OVERWRITE_ELEMENT(schema)
.key(f"{param_prefix}.pixelsY")
.setNewDefaultValue(1)
.commit(),
OVERWRITE_ELEMENT(schema)
.key(f"{param_prefix}.memoryCells")
.setNewDefaultValue(2)
.commit(),
)
base_calcat.add_status_schema_from_enum(
schema, status_prefix, Gotthard2Constants
)
@KARABO_CLASSINFO("Gotthard2Correction", deviceVersion)
class Gotthard2Correction(BaseCorrection):
_correction_flag_class = CorrectionFlags
_correction_field_names = (
("lut", CorrectionFlags.LUT),
("offset", CorrectionFlags.OFFSET),
("gain", CorrectionFlags.GAIN),
)
_kernel_runner_class = Gotthard2CpuRunner
_calcat_friend_class = Gotthard2CalcatFriend
_constant_enum_class = Gotthard2Constants
_managed_keys = BaseCorrection._managed_keys.copy()
_image_data_path = "data.adc"
_cell_table_path = "data.memoryCell"
@staticmethod
def expectedParameters(expected):
super(Gotthard2Correction, Gotthard2Correction).expectedParameters(expected)
(
OVERWRITE_ELEMENT(expected)
.key("dataFormat.pixelsX")
.setNewDefaultValue(1280)
.commit(),
OVERWRITE_ELEMENT(expected)
.key("dataFormat.pixelsY")
.setNewDefaultValue(1)
.commit(),
OVERWRITE_ELEMENT(expected)
.key("dataFormat.memoryCells")
.setNewDefaultValue(2720) # note: actually just frames...
.commit(),
OVERWRITE_ELEMENT(expected)
.key("preview.selectionMode")
.setNewDefaultValue("frame")
.commit(),
)
(
OUTPUT_CHANNEL(expected)
.key("preview.outputStreak")
.dataSchema(streak_preview_schema)
.commit(),
)
Gotthard2CalcatFriend.add_schema(expected, Gotthard2Correction._managed_keys)
add_correction_step_schema(
expected,
Gotthard2Correction._managed_keys,
Gotthard2Correction._correction_field_names,
)
# mandatory: manager needs this in schema
(
VECTOR_STRING_ELEMENT(expected)
.key("managedKeys")
.assignmentOptional()
.defaultValue(list(Gotthard2Correction._managed_keys))
.commit()
)
@property
def input_data_shape(self):
return (
self.unsafe_get("dataFormat.memoryCells"),
self.unsafe_get("dataFormat.pixelsX"),
)
@property
def output_data_shape(self):
return (
self.unsafe_get("dataFormat.memoryCells"),
self.unsafe_get("dataFormat.pixelsX"),
)
def __init__(self, config):
super().__init__(config)
# TODO: gain mode as constant parameter and / or device configuration
try:
self.bad_pixel_mask_value = np.float32(
config.get("corrections.badPixels.maskingValue")
)
except ValueError:
self.bad_pixel_mask_value = np.float32("nan")
self._kernel_runner_init_args = {
"bad_pixel_mask_value": self.bad_pixel_mask_value,
}
def process_data(
self,
data_hash,
metadata,
source,
train_id,
image_data,
cell_table,
do_generate_preview,
):
# cell table currently not used for GOTTHARD2 (assume alternating)
try:
self.kernel_runner.load_data(
image_data, data_hash.get("data.gain")
)
except Exception as e:
self.log_status_warn(f"Unknown exception when loading data: {e}")
buffer_handle, buffer_array = self._shmem_buffer.next_slot()
self.kernel_runner.correct(self._correction_flag_enabled, out=buffer_array)
if do_generate_preview:
if self._correction_flag_enabled != self._correction_flag_preview:
self.kernel_runner.correct(self._correction_flag_preview)
(
preview_slice_index,
preview_cell,
preview_pulse,
) = utils.pick_frame_index(
self.unsafe_get("preview.selectionMode"),
self.unsafe_get("preview.index"),
cell_table,
_pretend_pulse_table,
warn_func=self.log_status_warn,
)
(
preview_raw,
preview_corrected,
) = self.kernel_runner.compute_previews(preview_slice_index)
# reusing input data hash for sending
data_hash.set(self._image_data_path, buffer_handle)
data_hash.set("calngShmemPaths", [self._image_data_path])
self._write_output(data_hash, metadata)
if do_generate_preview:
self._write_preview_outputs(
(
("preview.outputRaw", preview_raw),
("preview.outputCorrected", preview_corrected),
),
metadata,
)
def _load_constant_to_runner(self, constant, constant_data):
if constant is Gotthard2Constants.Lut:
self.kernel_runner.lut[:] = constant_data.astype(np.uint16, copy=False)
if not self.get("corrections.lut.available"):
self.set("corrections.lut.available", True)
elif constant is Gotthard2Constants.Offset:
self.kernel_runner.offset_map[:] = constant_data.astype(
np.float32, copy=False
)
if not self.get("corrections.offset.available"):
self.set("corrections.offset.available", True)
elif constant is Gotthard2Constants.Gain:
self.kernel_runner.rel_gain_map[:] = constant_data.astype(
np.float32, copy=False
)
if not self.get("corrections.gain.available"):
self.set("corrections.gain.available", True)
self._update_correction_flags()
self.log_status_info(f"Done loading {constant.name}")
# cython: boundscheck=False
# cython: cdivision=True
# cython: wrapararound=False
# TODO: get these automatically from enum definition
cdef unsigned char NONE = 0
cdef unsigned char LUT = 1
cdef unsigned char OFFSET = 2
cdef unsigned char GAIN = 4
def correct(
unsigned short[:, :] raw_data,
unsigned char[:, :] raw_gain,
unsigned char flags,
unsigned short[:, :, :] lut,
float[:, :, :] offset_map,
float[:, :, :] gain_map,
float[:, :] output,
):
cdef unsigned frame, x
cdef unsigned short cell, raw_value, looked_up_value
cdef float res
cdef unsigned char gain
for frame in range(raw_data.shape[0]):
for x in range(raw_data.shape[1]):
cell = frame % 2
gain = raw_gain[frame, x]
if gain == 3:
gain = 2
raw_value = raw_data[frame, x]
if (flags & LUT):
res = <float>lut[cell, raw_value, x]
else:
res = <float>raw_value
if (flags & OFFSET):
res -= offset_map[gain, cell, x]
if (flags & GAIN):
res /= gain_map[gain, cell, x]
output[frame, x] = res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment