Skip to content
Snippets Groups Projects
Commit aeb03f6a authored by David Hammer's avatar David Hammer
Browse files

Add override to ignore strange input axis order

parent cd39e039
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
......@@ -35,16 +35,26 @@ class AgipdCorrection(BaseCorrection):
_gpu_runner_class = AgipdGpuRunner
_calcat_friend_class = AgipdCalcatFriend
_constant_enum_class = AgipdConstants
_managed_keys = BaseCorrection._managed_keys[:]
_managed_keys = BaseCorrection._managed_keys[:] + ["overrideInputAxisOrder"]
# this is just extending (not mandatory)
_schema_cache_fields = BaseCorrection._schema_cache_fields | {"sendGainMap"}
_schema_cache_fields = BaseCorrection._schema_cache_fields | {
"sendGainMap",
"overrideInputAxisOrder",
}
@staticmethod
def expectedParameters(expected):
super(AgipdCorrection, AgipdCorrection).expectedParameters(expected)
expected.setDefaultValue("dataFormat.memoryCells", 352)
(
BOOL_ELEMENT(expected)
.key("overrideInputAxisOrder")
.displayedName("Override input axis order")
.assignmentOptional()
.defaultValue(False)
.reconfigurable()
.commit(),
STRING_ELEMENT(expected)
.key("gainMode")
.displayedName("Gain mode")
......@@ -188,10 +198,10 @@ class AgipdCorrection(BaseCorrection):
@property
def input_data_shape(self):
return (
self.get("dataFormat.memoryCells"),
self._schema_cache["dataFormat.memoryCells"],
2,
self.get("dataFormat.pixelsX"),
self.get("dataFormat.pixelsY"),
self._schema_cache["dataFormat.pixelsX"],
self._schema_cache["dataFormat.pixelsY"],
)
def __init__(self, config):
......@@ -269,6 +279,14 @@ class AgipdCorrection(BaseCorrection):
and self._schema_cache["preview.enable"]
)
if self._schema_cache["overrideInputAxisOrder"]:
expected_shape = self.input_data_shape
if expected_shape != image_data.shape:
self.log.INFO(
f"Overriding input data order from {image_data.shape} to {expected_shape}"
)
image_data.shape = expected_shape
with self._buffer_lock:
# cell_table = cell_table[self.pulse_filter]
pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter]
......@@ -280,7 +298,12 @@ class AgipdCorrection(BaseCorrection):
"corrected."
)
self.gpu_runner.load_data(image_data)
try:
self.gpu_runner.load_data(image_data)
except ValueError as e:
self.log_status_warn(f"Failed to load data: {e}")
return
buffer_handle, buffer_array = self._shmem_buffer.next_slot()
self.gpu_runner.load_cell_table(cell_table)
self.gpu_runner.correct(self._correction_flag_enabled)
......
......@@ -57,8 +57,8 @@ class BaseCorrection(PythonDevice):
] # subclass must extend this and put it in schema
_schema_cache_fields = {
"doAnything",
"dataFormat.memoryCells",
"constantParameters.memoryCells",
"dataFormat.memoryCells",
"dataFormat.pixelsX",
"dataFormat.pixelsY",
"dataFormat.outputAxisOrder",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment