Skip to content
Snippets Groups Projects
Commit 6f3fa272 authored by David Hammer's avatar David Hammer
Browse files

GH2 updates after initial HIREX test

parent b8a7b338
No related branches found
No related tags found
1 merge request!91GH2 updates after initial HIREX test
......@@ -48,6 +48,25 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
.assignmentOptional()
.defaultValue("")
.commit(),
STRING_ELEMENT(expected)
.key("assemblyMode")
.displayedName("Assembly mode")
.description(
"Previews for 25 μm GOTTHARD-II are generally generated by "
"interleaving the preview outputs of the two constituent 50 μm "
"modules. However, the frame sum previews is temporal (sum across all "
"pixels per frame), so this preview should be the sum of the two "
"constituent previews. Additionally, previews are either 1D (regular "
"ones) or 2D (streak previews) and the latter need special image data "
"wrapping. This 'assembler' can therefore either interleave (1D or 2D) "
"or sum - 'auto' means it will guess which to do based on the primary "
"input source name."
)
.options("auto,interleave1d,interleave2d,sum")
.assignmentOptional()
.defaultValue("auto")
.commit(),
)
def initialization(self):
......@@ -59,6 +78,36 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
self._primary_source, self._secondary_source = [
row["source"].partition("@")[0] for row in self.get("sources")
]
# figure out assembly mode and set handler
if self.get("assemblyMode") == "auto":
_, _, source_output = self._primary_source.partition(":")
if source_output.lower().endswith("streak"):
self.set("assemblyMode", "interleave2d")
elif source_output.lower().endswith("sums"):
self.set("assemblyMode", "sum")
else:
self.set("assemblyMode", "interleave1d")
mode = self.get("assemblyMode")
if mode == "interleave1d":
self._do_the_assembly = self._interleave_1d
elif mode == "interleave2d":
self._do_the_assembly = self._interleave_2d
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=True)
)
.commit(),
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
else:
self._do_the_assembly = self._sum_1d
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
self._interleaving_buffer = np.ma.empty(0, dtype=np.float32)
self._wrap_in_imagedata = False
......@@ -91,6 +140,14 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
)
def on_matched_data(self, train_id, sources):
if (
missing_sources := {self._primary_source, self._secondary_source}
- sources.keys()
):
self.log.WARN(
f"Missing preview source(s): {missing_sources}, skipping train"
)
return
for (data, _) in sources.values():
self._shmem_handler.dereference_shmem_handles(data)
......@@ -108,83 +165,87 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
mask_1 = False
mask_2 = False
# streak preview is in and should be put back into ImageData
wrap_in_imagedata = isinstance(data_1, ImageData)
if wrap_in_imagedata:
data_1 = data_1.getData()
data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
meta = ChannelMetaData(
f"{self.getInstanceId()}:output",
Timestamp(Epochstamp(), Trainstamp(train_id)),
)
output_hash = self._do_the_assembly(data_1, mask_1, data_2, mask_2)
self.output.write(output_hash, meta, copyAllData=False)
self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _interleave_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# now to figure out the interleaving
axis = 0 if data_1.ndim == 1 else 1
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, axis)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 0)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
)
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], axis)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], axis)
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 0)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 0)
# TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
if self._wrap_in_imagedata != wrap_in_imagedata:
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=wrap_in_imagedata)
)
.commit(),
return Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
)
def _interleave_2d(self, data_1, mask_1, data_2, mask_2):
data_1 = data_1.getData()
data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 1)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
self._wrap_in_imagedata = wrap_in_imagedata
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 1)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 1)
# TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
meta = ChannelMetaData(
f"{self.getInstanceId()}:output",
Timestamp(Epochstamp(), Trainstamp(train_id)),
return Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
)
if wrap_in_imagedata:
self.output.write(
Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
),
meta,
copyAllData=False,
)
else:
self.output.write(
Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
),
meta,
copyAllData=False,
)
self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _sum_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# don't bother with self._interleaving_buffer
res = image_1 + image_2
return Hash(
"image.data",
res.data,
"image.mask",
res.mask,
)
......@@ -84,7 +84,7 @@ class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
@property
def preview_data_views(self):
return (self.input_data, self.processed_data)
return (self.input_data, self.input_gain_stage, self.processed_data)
def load_constant(self, constant_type, data):
if constant_type is Constants.LUTGotthard2:
......@@ -261,7 +261,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
_cell_table_path = "data.memoryCell"
_pulse_table_path = None
_warn_memory_cell_range = False # for now, receiver always writes 255
_preview_outputs = ["outputStreak"]
_preview_outputs = ["outputStreak", "outputGainStreak"]
_cuda_pin_buffers = False
@staticmethod
......@@ -302,7 +302,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
base_correction.add_preview_outputs(
expected, Gotthard2Correction._preview_outputs
)
for channel in ("outputRaw", "outputCorrected", "outputFrameSums"):
for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"):
# add these "manually" as the automated bits wrap ImageData
(
OUTPUT_CHANNEL(expected)
......@@ -425,6 +425,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
warn(preview_warning)
(
preview_raw,
preview_gain,
preview_corrected,
) = self.kernel_runner.compute_previews(preview_slice_index)
......@@ -441,6 +442,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp"))
for channel, data in (
("outputRaw", preview_raw),
("outputGain", preview_gain),
("outputCorrected", preview_corrected),
("outputFrameSums", frame_sums),
):
......@@ -454,7 +456,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
),
timestamp=timestamp,
)
self._preview_friend.write_outputs(metadata, buffer_array)
self._preview_friend.write_outputs(metadata, buffer_array, gain_map)
def _load_constant_to_runner(self, constant, constant_data):
self.kernel_runner.load_constant(constant, constant_data)
......@@ -2,6 +2,8 @@
# cython: cdivision=True
# cython: wrapararound=False
from libc.math cimport isinf, isnan
# TODO: get these automatically from enum definition
cdef unsigned char NONE = 0
cdef unsigned char LUT = 1
......@@ -47,4 +49,7 @@ def correct(
if (flags & GAIN):
res /= gain_map[gain, cell, x]
if isnan(res) or isinf(res):
res = badpixel_fill_value
output[frame, x] = res
......@@ -117,7 +117,7 @@ class PreviewFriend:
]
self.reconfigure(device._parameters)
def write_outputs(self, timestamp, *datas, inplace=True, source=None):
def write_outputs(self, timestamp, *datas, inplace=True):
"""Applies GUI-friendly preview settings (replace NaN, downsample, wrap as
ImageData) and writes to output channels. Make sure datas length matches number
of channels!"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment