Skip to content
Snippets Groups Projects
Commit b2869ca7 authored by David Hammer's avatar David Hammer
Browse files

Merge branch 'gh2-add-gain-streak' into 'master'

GH2 updates after initial HIREX test

See merge request !91
parents b8a7b338 6f3fa272
No related branches found
No related tags found
1 merge request!91GH2 updates after initial HIREX test
......@@ -48,6 +48,25 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
.assignmentOptional()
.defaultValue("")
.commit(),
STRING_ELEMENT(expected)
.key("assemblyMode")
.displayedName("Assembly mode")
.description(
"Previews for 25 μm GOTTHARD-II are generally generated by "
"interleaving the preview outputs of the two constituent 50 μm "
"modules. However, the frame sum previews is temporal (sum across all "
"pixels per frame), so this preview should be the sum of the two "
"constituent previews. Additionally, previews are either 1D (regular "
"ones) or 2D (streak previews) and the latter need special image data "
"wrapping. This 'assembler' can therefore either interleave (1D or 2D) "
"or sum - 'auto' means it will guess which to do based on the primary "
"input source name."
)
.options("auto,interleave1d,interleave2d,sum")
.assignmentOptional()
.defaultValue("auto")
.commit(),
)
def initialization(self):
......@@ -59,6 +78,36 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
self._primary_source, self._secondary_source = [
row["source"].partition("@")[0] for row in self.get("sources")
]
# figure out assembly mode and set handler
if self.get("assemblyMode") == "auto":
_, _, source_output = self._primary_source.partition(":")
if source_output.lower().endswith("streak"):
self.set("assemblyMode", "interleave2d")
elif source_output.lower().endswith("sums"):
self.set("assemblyMode", "sum")
else:
self.set("assemblyMode", "interleave1d")
mode = self.get("assemblyMode")
if mode == "interleave1d":
self._do_the_assembly = self._interleave_1d
elif mode == "interleave2d":
self._do_the_assembly = self._interleave_2d
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=True)
)
.commit(),
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
else:
self._do_the_assembly = self._sum_1d
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
self._interleaving_buffer = np.ma.empty(0, dtype=np.float32)
self._wrap_in_imagedata = False
......@@ -91,6 +140,14 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
)
def on_matched_data(self, train_id, sources):
if (
missing_sources := {self._primary_source, self._secondary_source}
- sources.keys()
):
self.log.WARN(
f"Missing preview source(s): {missing_sources}, skipping train"
)
return
for (data, _) in sources.values():
self._shmem_handler.dereference_shmem_handles(data)
......@@ -108,83 +165,87 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
mask_1 = False
mask_2 = False
# streak preview is in and should be put back into ImageData
wrap_in_imagedata = isinstance(data_1, ImageData)
if wrap_in_imagedata:
data_1 = data_1.getData()
data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
meta = ChannelMetaData(
f"{self.getInstanceId()}:output",
Timestamp(Epochstamp(), Trainstamp(train_id)),
)
output_hash = self._do_the_assembly(data_1, mask_1, data_2, mask_2)
self.output.write(output_hash, meta, copyAllData=False)
self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _interleave_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# now to figure out the interleaving
axis = 0 if data_1.ndim == 1 else 1
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, axis)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 0)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
)
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], axis)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], axis)
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 0)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 0)
# TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
if self._wrap_in_imagedata != wrap_in_imagedata:
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=wrap_in_imagedata)
)
.commit(),
return Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
)
def _interleave_2d(self, data_1, mask_1, data_2, mask_2):
data_1 = data_1.getData()
data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 1)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
self._wrap_in_imagedata = wrap_in_imagedata
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 1)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 1)
# TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
meta = ChannelMetaData(
f"{self.getInstanceId()}:output",
Timestamp(Epochstamp(), Trainstamp(train_id)),
return Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
)
if wrap_in_imagedata:
self.output.write(
Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
),
meta,
copyAllData=False,
)
else:
self.output.write(
Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
),
meta,
copyAllData=False,
)
self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _sum_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# don't bother with self._interleaving_buffer
res = image_1 + image_2
return Hash(
"image.data",
res.data,
"image.mask",
res.mask,
)
......@@ -84,7 +84,7 @@ class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
@property
def preview_data_views(self):
return (self.input_data, self.processed_data)
return (self.input_data, self.input_gain_stage, self.processed_data)
def load_constant(self, constant_type, data):
if constant_type is Constants.LUTGotthard2:
......@@ -261,7 +261,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
_cell_table_path = "data.memoryCell"
_pulse_table_path = None
_warn_memory_cell_range = False # for now, receiver always writes 255
_preview_outputs = ["outputStreak"]
_preview_outputs = ["outputStreak", "outputGainStreak"]
_cuda_pin_buffers = False
@staticmethod
......@@ -302,7 +302,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
base_correction.add_preview_outputs(
expected, Gotthard2Correction._preview_outputs
)
for channel in ("outputRaw", "outputCorrected", "outputFrameSums"):
for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"):
# add these "manually" as the automated bits wrap ImageData
(
OUTPUT_CHANNEL(expected)
......@@ -425,6 +425,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
warn(preview_warning)
(
preview_raw,
preview_gain,
preview_corrected,
) = self.kernel_runner.compute_previews(preview_slice_index)
......@@ -441,6 +442,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp"))
for channel, data in (
("outputRaw", preview_raw),
("outputGain", preview_gain),
("outputCorrected", preview_corrected),
("outputFrameSums", frame_sums),
):
......@@ -454,7 +456,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
),
timestamp=timestamp,
)
self._preview_friend.write_outputs(metadata, buffer_array)
self._preview_friend.write_outputs(metadata, buffer_array, gain_map)
def _load_constant_to_runner(self, constant, constant_data):
self.kernel_runner.load_constant(constant, constant_data)
......@@ -2,6 +2,8 @@
# cython: cdivision=True
# cython: wrapararound=False
from libc.math cimport isinf, isnan
# TODO: get these automatically from enum definition
cdef unsigned char NONE = 0
cdef unsigned char LUT = 1
......@@ -47,4 +49,7 @@ def correct(
if (flags & GAIN):
res /= gain_map[gain, cell, x]
if isnan(res) or isinf(res):
res = badpixel_fill_value
output[frame, x] = res
......@@ -117,7 +117,7 @@ class PreviewFriend:
]
self.reconfigure(device._parameters)
def write_outputs(self, timestamp, *datas, inplace=True, source=None):
def write_outputs(self, timestamp, *datas, inplace=True):
"""Applies GUI-friendly preview settings (replace NaN, downsample, wrap as
ImageData) and writes to output channels. Make sure datas length matches number
of channels!"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment