Skip to content
Snippets Groups Projects
Commit ce6ff152 authored by David Hammer's avatar David Hammer
Browse files

Tidying up, improving documentation for MR

parent 7b5a74b5
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
...@@ -10,7 +10,7 @@ from karabo.bound import ( ...@@ -10,7 +10,7 @@ from karabo.bound import (
) )
from karabo.common.states import State from karabo.common.states import State
from . import shmem_utils, utils from . import utils
from ._version import version as deviceVersion from ._version import version as deviceVersion
from .agipd_gpu import AgipdGainMode, AgipdGpuRunner, BadPixelValues, CorrectionFlags from .agipd_gpu import AgipdGainMode, AgipdGpuRunner, BadPixelValues, CorrectionFlags
from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema
...@@ -28,7 +28,7 @@ class AgipdCorrection(BaseCorrection): ...@@ -28,7 +28,7 @@ class AgipdCorrection(BaseCorrection):
("relGainXray", CorrectionFlags.REL_GAIN_XRAY), ("relGainXray", CorrectionFlags.REL_GAIN_XRAY),
("badPixels", CorrectionFlags.BPMASK), ("badPixels", CorrectionFlags.BPMASK),
) )
_gpu_runner_class = AgipdGpuRunner _kernel_runner_class = AgipdGpuRunner
_calcat_friend_class = AgipdCalcatFriend _calcat_friend_class = AgipdCalcatFriend
_constant_enum_class = AgipdConstants _constant_enum_class = AgipdConstants
...@@ -181,12 +181,12 @@ class AgipdCorrection(BaseCorrection): ...@@ -181,12 +181,12 @@ class AgipdCorrection(BaseCorrection):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
# TODO: different gpu runner for fixed gain mode # TODO: consider different gpu runner for fixed gain mode
self.gain_mode = AgipdGainMode[config.get("gainMode")] self.gain_mode = AgipdGainMode[config.get("gainMode")]
self.bad_pixel_mask_value = eval( self.bad_pixel_mask_value = eval(
config.get("corrections.badPixels.maskingValue") config.get("corrections.badPixels.maskingValue")
) )
self._gpu_runner_init_args = { self._kernel_runner_init_args = {
"gain_mode": self.gain_mode, "gain_mode": self.gain_mode,
"bad_pixel_mask_value": self.bad_pixel_mask_value, "bad_pixel_mask_value": self.bad_pixel_mask_value,
"g_gain_value": config.get("corrections.relGainXray.gGainValue"), "g_gain_value": config.get("corrections.relGainXray.gGainValue"),
...@@ -235,7 +235,7 @@ class AgipdCorrection(BaseCorrection): ...@@ -235,7 +235,7 @@ class AgipdCorrection(BaseCorrection):
return return
try: try:
self.gpu_runner.load_data(image_data) self.kernel_runner.load_data(image_data)
except ValueError as e: except ValueError as e:
self.log_status_warn(f"Failed to load data: {e}") self.log_status_warn(f"Failed to load data: {e}")
return return
...@@ -243,16 +243,16 @@ class AgipdCorrection(BaseCorrection): ...@@ -243,16 +243,16 @@ class AgipdCorrection(BaseCorrection):
self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
buffer_handle, buffer_array = self._shmem_buffer.next_slot() buffer_handle, buffer_array = self._shmem_buffer.next_slot()
self.gpu_runner.load_cell_table(cell_table) self.kernel_runner.load_cell_table(cell_table)
self.gpu_runner.correct(self._correction_flag_enabled) self.kernel_runner.correct(self._correction_flag_enabled)
self.gpu_runner.reshape( self.kernel_runner.reshape(
output_order=self._schema_cache["dataFormat.outputAxisOrder"], output_order=self._schema_cache["dataFormat.outputAxisOrder"],
out=buffer_array, out=buffer_array,
) )
# after reshape, data for dataOutput is now safe in its own buffer # after reshape, data for dataOutput is now safe in its own buffer
if do_generate_preview: if do_generate_preview:
if self._correction_flag_enabled != self._correction_flag_preview: if self._correction_flag_enabled != self._correction_flag_preview:
self.gpu_runner.correct(self._correction_flag_preview) self.kernel_runner.correct(self._correction_flag_preview)
( (
preview_slice_index, preview_slice_index,
preview_cell, preview_cell,
...@@ -269,7 +269,7 @@ class AgipdCorrection(BaseCorrection): ...@@ -269,7 +269,7 @@ class AgipdCorrection(BaseCorrection):
preview_corrected, preview_corrected,
preview_raw_gain, preview_raw_gain,
preview_gain_map, preview_gain_map,
) = self.gpu_runner.compute_previews(preview_slice_index) ) = self.kernel_runner.compute_previews(preview_slice_index)
# reusing input data hash for sending # reusing input data hash for sending
data_hash.set("image.data", buffer_handle) data_hash.set("image.data", buffer_handle)
...@@ -291,30 +291,30 @@ class AgipdCorrection(BaseCorrection): ...@@ -291,30 +291,30 @@ class AgipdCorrection(BaseCorrection):
source, source,
) )
def _load_constant_to_gpu(self, constant, constant_data): def _load_constant_to_runner(self, constant, constant_data):
# TODO: encode correction / constant dependencies in a clever way # TODO: encode correction / constant dependencies in a clever way
if constant is AgipdConstants.ThresholdsDark: if constant is AgipdConstants.ThresholdsDark:
field_name = "thresholding" # TODO: (reverse) mapping, DRY field_name = "thresholding" # TODO: (reverse) mapping, DRY
if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN: if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN:
self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode") self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode")
return return
self.gpu_runner.load_thresholds(constant_data) self.kernel_runner.load_thresholds(constant_data)
elif constant is AgipdConstants.Offset: elif constant is AgipdConstants.Offset:
field_name = "offset" field_name = "offset"
self.gpu_runner.load_offset_map(constant_data) self.kernel_runner.load_offset_map(constant_data)
elif constant is AgipdConstants.SlopesPC: elif constant is AgipdConstants.SlopesPC:
field_name = "relGainPc" field_name = "relGainPc"
self.gpu_runner.load_rel_gain_pc_map(constant_data) self.kernel_runner.load_rel_gain_pc_map(constant_data)
if self._override_md_additional_offset is not None: if self._override_md_additional_offset is not None:
self.gpu_runner.md_additional_offset_gpu.fill( self.kernel_runner.md_additional_offset_gpu.fill(
self._override_md_additional_offset self._override_md_additional_offset
) )
elif constant is AgipdConstants.SlopesFF: elif constant is AgipdConstants.SlopesFF:
field_name = "relGainXray" field_name = "relGainXray"
self.gpu_runner.load_rel_gain_ff_map(constant_data) self.kernel_runner.load_rel_gain_ff_map(constant_data)
elif "BadPixels" in constant.name: elif "BadPixels" in constant.name:
field_name = "badPixels" field_name = "badPixels"
self.gpu_runner.load_bad_pixels_map( self.kernel_runner.load_bad_pixels_map(
constant_data, override_flags_to_use=self._override_bad_pixel_flags constant_data, override_flags_to_use=self._override_bad_pixel_flags
) )
...@@ -340,7 +340,7 @@ class AgipdCorrection(BaseCorrection): ...@@ -340,7 +340,7 @@ class AgipdCorrection(BaseCorrection):
self._override_md_additional_offset = self.get( self._override_md_additional_offset = self.get(
"corrections.relGainPc.mdAdditionalOffset" "corrections.relGainPc.mdAdditionalOffset"
) )
self.gpu_runner.override_md_additional_offset( self.kernel_runner.override_md_additional_offset(
self._override_md_additional_offset self._override_md_additional_offset
) )
else: else:
...@@ -352,10 +352,10 @@ class AgipdCorrection(BaseCorrection): ...@@ -352,10 +352,10 @@ class AgipdCorrection(BaseCorrection):
update = self._prereconfigure_update_hash update = self._prereconfigure_update_hash
if update.has("corrections.relGainXray.gGainValue"): if update.has("corrections.relGainXray.gGainValue"):
self.gpu_runner.set_g_gain_value( self.kernel_runner.set_g_gain_value(
self.get("corrections.relGainXray.gGainValue") self.get("corrections.relGainXray.gGainValue")
) )
self._gpu_runner_init_args["g_gain_value"] = self.get( self._kernel_runner_init_args["g_gain_value"] = self.get(
"corrections.relGainXray.gGainValue" "corrections.relGainXray.gGainValue"
) )
...@@ -363,8 +363,8 @@ class AgipdCorrection(BaseCorrection): ...@@ -363,8 +363,8 @@ class AgipdCorrection(BaseCorrection):
self.bad_pixel_mask_value = eval( self.bad_pixel_mask_value = eval(
self.get("corrections.badPixels.maskingValue") self.get("corrections.badPixels.maskingValue")
) )
self.gpu_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value)
self._gpu_runner_init_args[ self._kernel_runner_init_args[
"bad_pixel_mask_value" "bad_pixel_mask_value"
] = self.bad_pixel_mask_value ] = self.bad_pixel_mask_value
...@@ -388,8 +388,8 @@ class AgipdCorrection(BaseCorrection): ...@@ -388,8 +388,8 @@ class AgipdCorrection(BaseCorrection):
data, data,
) in self.calcat_friend.cached_constants.items(): ) in self.calcat_friend.cached_constants.items():
if "BadPixels" in constant.name: if "BadPixels" in constant.name:
self._load_constant_to_gpu(constant, data) self._load_constant_to_runner(constant, data)
self._update_bad_pixel_selection() self._update_bad_pixel_selection()
self.gpu_runner.override_bad_pixel_flags_to_use( self.kernel_runner.override_bad_pixel_flags_to_use(
self._override_bad_pixel_flags self._override_bad_pixel_flags
) )
...@@ -14,7 +14,7 @@ class DsscCorrection(BaseCorrection): ...@@ -14,7 +14,7 @@ class DsscCorrection(BaseCorrection):
# subclass *must* set these attributes # subclass *must* set these attributes
_correction_flag_class = CorrectionFlags _correction_flag_class = CorrectionFlags
_correction_field_names = (("offset", CorrectionFlags.OFFSET),) _correction_field_names = (("offset", CorrectionFlags.OFFSET),)
_gpu_runner_class = DsscGpuRunner _kernel_runner_class = DsscGpuRunner
_calcat_friend_class = DsscCalcatFriend _calcat_friend_class = DsscCalcatFriend
_constant_enum_class = DsscConstants _constant_enum_class = DsscConstants
_managed_keys = BaseCorrection._managed_keys.copy() _managed_keys = BaseCorrection._managed_keys.copy()
...@@ -74,7 +74,7 @@ class DsscCorrection(BaseCorrection): ...@@ -74,7 +74,7 @@ class DsscCorrection(BaseCorrection):
return return
try: try:
self.gpu_runner.load_data(image_data) self.kernel_runner.load_data(image_data)
except ValueError as e: except ValueError as e:
self.log_status_warn(f"Failed to load data: {e}") self.log_status_warn(f"Failed to load data: {e}")
return return
...@@ -82,15 +82,15 @@ class DsscCorrection(BaseCorrection): ...@@ -82,15 +82,15 @@ class DsscCorrection(BaseCorrection):
self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
buffer_handle, buffer_array = self._shmem_buffer.next_slot() buffer_handle, buffer_array = self._shmem_buffer.next_slot()
self.gpu_runner.load_cell_table(cell_table) self.kernel_runner.load_cell_table(cell_table)
self.gpu_runner.correct(self._correction_flag_enabled) self.kernel_runner.correct(self._correction_flag_enabled)
self.gpu_runner.reshape( self.kernel_runner.reshape(
output_order=self._schema_cache["dataFormat.outputAxisOrder"], output_order=self._schema_cache["dataFormat.outputAxisOrder"],
out=buffer_array, out=buffer_array,
) )
if do_generate_preview: if do_generate_preview:
if self._correction_flag_enabled != self._correction_flag_preview: if self._correction_flag_enabled != self._correction_flag_preview:
self.gpu_runner.correct(self._correction_flag_preview) self.kernel_runner.correct(self._correction_flag_preview)
( (
preview_slice_index, preview_slice_index,
preview_cell, preview_cell,
...@@ -102,7 +102,7 @@ class DsscCorrection(BaseCorrection): ...@@ -102,7 +102,7 @@ class DsscCorrection(BaseCorrection):
pulse_table, pulse_table,
warn_func=self.log_status_warn, warn_func=self.log_status_warn,
) )
preview_raw, preview_corrected = self.gpu_runner.compute_previews( preview_raw, preview_corrected = self.kernel_runner.compute_previews(
preview_slice_index, preview_slice_index,
) )
...@@ -121,9 +121,9 @@ class DsscCorrection(BaseCorrection): ...@@ -121,9 +121,9 @@ class DsscCorrection(BaseCorrection):
source, source,
) )
def _load_constant_to_gpu(self, constant, constant_data): def _load_constant_to_runner(self, constant, constant_data):
assert constant is DsscConstants.Offset assert constant is DsscConstants.Offset
self.gpu_runner.load_offset_map(constant_data) self.kernel_runner.load_offset_map(constant_data)
if not self.get("corrections.offset.available"): if not self.get("corrections.offset.available"):
self.set("corrections.offset.available", True) self.set("corrections.offset.available", True)
......
This diff is collapsed.
...@@ -103,13 +103,6 @@ def _add_status_schema_from_enum(schema, prefix, enum_class): ...@@ -103,13 +103,6 @@ def _add_status_schema_from_enum(schema, prefix, enum_class):
) )
class DetectorStandin(typing.NamedTuple):
detector_name: str
modno_to_source: dict
frames_per_train: int
module_shape: tuple
class OperatingConditions(dict): class OperatingConditions(dict):
# TODO: support deviation? # TODO: support deviation?
def encode(self): def encode(self):
...@@ -132,6 +125,16 @@ class OperatingConditions(dict): ...@@ -132,6 +125,16 @@ class OperatingConditions(dict):
class BaseCalcatFriend: class BaseCalcatFriend:
"""Base class for CalCat friends - handles interacting with CalCat for the device
A CalCat friend uses the device schema to build up parameters for CalCat queries.
It focuses on two nodes (added by static method add_schema): param_prefix and
status_prefix. The former is primarily used to get parameters which are (via
condition methods - see for example dark_condition of DsscCalcatFriend) used
to look for constants. The latter is primarily used to give user information
about what was found.
"""
_constant_enum_class = None # subclass should set _constant_enum_class = None # subclass should set
_constants_need_conditions = None # subclass should set _constants_need_conditions = None # subclass should set
...@@ -272,6 +275,7 @@ class BaseCalcatFriend: ...@@ -272,6 +275,7 @@ class BaseCalcatFriend:
self.status_prefix = status_prefix self.status_prefix = status_prefix
self.cached_constants = {} self.cached_constants = {}
self.cached_constants_lock = threading.Lock() self.cached_constants_lock = threading.Lock()
# api lock used to force queries to be sequential (SSL issue on ONC)
self.api_lock = threading.Lock() self.api_lock = threading.Lock()
if not secrets_fn.is_file(): if not secrets_fn.is_file():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment