Skip to content
Snippets Groups Projects
Commit 39ef1be4 authored by David Hammer's avatar David Hammer
Browse files

Reimplementing frame filter

parent 3de68767
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
......@@ -240,8 +240,17 @@ class AgipdCorrection(BaseCorrection):
compute preview, write data output, and optionally write preview outputs."""
# original shape: memory_cell, data/raw_gain, x, y
# TODO: add pulse filter back in
pulse_table = np.squeeze(data_hash.get("image.pulseId"))
if self._frame_filter is not None:
try:
cell_table = cell_table[self._frame_filter]
pulse_table = pulse_table[self._frame_filter]
image_data = image_data[self._frame_filter]
except IndexError:
self.log_status_warn(
"Failed to apply frame filter, please check that it is valid!"
)
return
try:
self.gpu_runner.load_data(image_data)
......@@ -362,19 +371,8 @@ class AgipdCorrection(BaseCorrection):
self._update_correction_flags()
self.log_status_info(f"Done loading {constant.name} to GPU")
def _update_pulse_filter(self, filter_string):
"""Called whenever the pulse filter changes, typically followed by
_update_shapes"""
if filter_string.strip() == "":
new_filter = np.arange(self.get("dataFormat.memoryCells"), dtype=np.uint16)
else:
new_filter = np.array(eval(filter_string), dtype=np.uint16)
assert np.max(new_filter) < self.get("dataFormat.memoryCells")
self.pulse_filter = new_filter
def _update_shapes(self):
super()._update_shapes()
def _update_buffers(self):
super()._update_buffers()
# TODO: pack four pixels per byte
if self._schema_cache["sendGainMap"]:
if self._shmem_buffer_gain_map is None:
......
......@@ -61,8 +61,17 @@ class DsscCorrection(BaseCorrection):
cell_table,
do_generate_preview,
):
# cell_table = cell_table[self.pulse_filter]
pulse_table = np.squeeze(data_hash.get("image.pulseId")) # [self.pulse_filter]
pulse_table = np.squeeze(data_hash.get("image.pulseId"))
if self._frame_filter is not None:
try:
cell_table = cell_table[self._frame_filter]
pulse_table = pulse_table[self._frame_filter]
image_data = image_data[self._frame_filter]
except IndexError:
self.log_status_warn(
"Failed to apply frame filter, please check that it is valid!"
)
return
try:
self.gpu_runner.load_data(image_data)
......@@ -112,17 +121,6 @@ class DsscCorrection(BaseCorrection):
source,
)
def _update_pulse_filter(self, filter_string):
"""Called whenever the pulse filter changes, typically followed by
_update_shapes"""
if filter_string.strip() == "":
new_filter = np.arange(self.get("dataFormat.memoryCells"), dtype=np.uint16)
else:
new_filter = np.array(eval(filter_string), dtype=np.uint16)
assert np.max(new_filter) < self.get("dataFormat.memoryCells")
self.pulse_filter = new_filter
def _load_constant_to_gpu(self, constant, constant_data):
assert constant is DsscConstants.Offset
self.gpu_runner.load_offset_map(constant_data)
......
......@@ -51,6 +51,7 @@ class BaseCorrection(PythonDevice):
"dataFormat.outputAxisOrder",
"dataFormat.outputImageDtype",
"dataFormat.overrideInputAxisOrder",
"frameFilter",
"preview.enable",
"preview.index",
"preview.selectionMode",
......@@ -60,6 +61,7 @@ class BaseCorrection(PythonDevice):
_schema_cache_fields = {
"doAnything",
"constantParameters.memoryCells",
"dataFormat.filteredFrames",
"dataFormat.memoryCells",
"dataFormat.pixelsX",
"dataFormat.pixelsY",
......@@ -85,7 +87,7 @@ class BaseCorrection(PythonDevice):
axis_lengths = {
"x": self._schema_cache["dataFormat.pixelsX"],
"y": self._schema_cache["dataFormat.pixelsY"],
"c": self._schema_cache["dataFormat.memoryCells"],
"c": self._schema_cache["dataFormat.filteredFrames"],
}
return tuple(
axis_lengths[axis]
......@@ -200,16 +202,20 @@ class BaseCorrection(PythonDevice):
.defaultValue([])
.commit(),
STRING_ELEMENT(expected)
.key("pulseFilter")
.displayedName("[disabled] Pulse filter")
.key("frameFilter")
.displayedName("Frame filter")
.description(
"Filter pulses: will be evaluated as array of indices to keep from "
"data. Can be anything which can be turned into numpy uint16 array. "
"Numpy is available as np. Take care not to include duplicates. If "
"empty, will not filter at all."
"The frame filter - if set - slices the input data. Frames not in the "
"filter will be discarded before any processing happens (will not get "
"to dataOutput or preview. Note that this filter goes by frame index "
"rather than cell ID or pulse ID; set accordingly (and only if you "
"know what you are doing as it can break everything). Filter is "
"evaluated into numpy uint16 array; a valid filter could be "
"'np.arange(0, 352, 2)'."
)
.readOnly()
.initialValue("")
.assignmentOptional()
.defaultValue("")
.reconfigurable()
.commit(),
UINT32_ELEMENT(expected)
.key("outputShmemBufferSize")
......@@ -293,6 +299,13 @@ class BaseCorrection(PythonDevice):
.assignmentOptional()
.noDefaultValue()
.commit(),
UINT32_ELEMENT(expected)
.key("dataFormat.filteredFrames")
.displayedName("Frames after filter")
.description("Number of frames left after applying frame filter")
.readOnly()
.initialValue(0)
.commit(),
STRING_ELEMENT(expected)
.key("dataFormat.outputAxisOrder")
.displayedName("Output axis order")
......@@ -471,8 +484,9 @@ class BaseCorrection(PythonDevice):
self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype"))
self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype"))
self.gpu_runner = None # must call _update_shapes()
self.registerInitialFunction(self._update_shapes)
self.gpu_runner = None # must call _update_buffers()
self.registerInitialFunction(self._update_frame_filter)
self.registerInitialFunction(self._update_buffers)
self.calcat_friend = self._calcat_friend_class(
self, pathlib.Path.cwd() / "calibration-client-secrets.json"
......@@ -482,7 +496,6 @@ class BaseCorrection(PythonDevice):
self._correction_flag_preview = self._correction_flag_class.NONE
self._shmem_buffer = None
self._has_updated_shapes = False
self._processing_time_ema = utils.ExponentialMovingAverage(alpha=0.3)
self._rate_tracker = utils.WindowRateTracker()
......@@ -557,6 +570,9 @@ class BaseCorrection(PythonDevice):
self._schema_cache[path] = update.get(path)
# TODO: pulse filter (after reimplementing)
if update.has("frameFilter"):
with self._buffer_lock:
self._update_frame_filter()
if any(
update.has(shape_param)
for shape_param in (
......@@ -564,9 +580,11 @@ class BaseCorrection(PythonDevice):
"dataFormat.pixelsY",
"dataFormat.memoryCells",
"constantParameters.memoryCells",
"frameFilter",
)
):
self._update_shapes()
with self._buffer_lock:
self._update_buffers()
# TODO: only call this if they are changed (is cheap, though)
self._update_correction_flags()
......@@ -645,7 +663,6 @@ class BaseCorrection(PythonDevice):
channel.update()
def _write_combiner_previews(self, channel_data_pairs, train_id, source):
# TODO: take into account updated pulse table after pulse filter
# TODO: send as ImageData (requires updated assembler)
# TODO: allow sending *all* frames for commissioning (request: Jola)
preview_hash = Hash()
......@@ -679,12 +696,27 @@ class BaseCorrection(PythonDevice):
self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}")
self.log.DEBUG(f"Corrections for preview: {str(preview)}")
def _update_shapes(self):
def _update_frame_filter(self):
filter_string = self.get("frameFilter")
if filter_string.strip() == "":
self._frame_filter = None
self.set("dataFormat.filteredFrames", self.get("dataFormat.memoryCells"))
else:
self._frame_filter = np.array(eval(filter_string), dtype=np.uint16)
self.set("dataFormat.filteredFrames", self._frame_filter.size)
if self._frame_filter.min() < 0 or self._frame_filter.max() >= self.get(
"dataFormat.memoryCells"
):
self.log_status_warn("Invalid frame filter set, expect exceptions!")
def _update_buffers(self):
"""(Re)initialize buffers according to expected data shapes"""
self.log.INFO("Updating shapes")
self.log.INFO("Updating buffers according to data shapes")
# reflect the axis reordering in the expected output shape
self.set("dataFormat.inputDataShape", list(self.input_data_shape))
self.set("dataFormat.outputDataShape", list(self.output_data_shape))
self.log.INFO(f"Input shape: {self.input_data_shape}")
self.log.INFO(f"Output shape: {self.output_data_shape}")
if self._shmem_buffer is None:
shmem_buffer_name = self.getInstanceId() + ":dataOutput"
......@@ -704,7 +736,7 @@ class BaseCorrection(PythonDevice):
self.gpu_runner = self._gpu_runner_class(
self.get("dataFormat.pixelsX"),
self.get("dataFormat.pixelsY"),
self.get("dataFormat.memoryCells"),
self.get("dataFormat.filteredFrames"),
int(self.get("constantParameters.memoryCells")),
input_data_dtype=self.input_data_dtype,
output_data_dtype=self.output_data_dtype,
......@@ -719,8 +751,6 @@ class BaseCorrection(PythonDevice):
self.log_status_info(f"Reload constant {constant}")
self._load_constant_to_gpu(constant, data)
self._has_updated_shapes = True
def input_handler(self, input_channel):
"""Main handler for data input: Do a few simple checks to determine whether to
even try processing. If yes, will pass data and information to subclass'
......@@ -779,8 +809,8 @@ class BaseCorrection(PythonDevice):
)
self.set("dataFormat.memoryCells", image_data.shape[0])
with self._buffer_lock:
# TODO: pulse filter update after reimplementation
self._update_shapes()
# TODO: re-validate frame filter against new number of cells
self._update_buffers()
# DataAggregator typically tells us the wrong axis order
if self._schema_cache["dataFormat.overrideInputAxisOrder"]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment