Skip to content
Snippets Groups Projects
Commit b2869ca7 authored by David Hammer's avatar David Hammer
Browse files

Merge branch 'gh2-add-gain-streak' into 'master'

GH2 updates after initial HIREX test

See merge request !91
parents b8a7b338 6f3fa272
No related branches found
No related tags found
1 merge request!91GH2 updates after initial HIREX test
...@@ -48,6 +48,25 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher): ...@@ -48,6 +48,25 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
.assignmentOptional() .assignmentOptional()
.defaultValue("") .defaultValue("")
.commit(), .commit(),
STRING_ELEMENT(expected)
.key("assemblyMode")
.displayedName("Assembly mode")
.description(
"Previews for 25 μm GOTTHARD-II are generally generated by "
"interleaving the preview outputs of the two constituent 50 μm "
"modules. However, the frame sum previews is temporal (sum across all "
"pixels per frame), so this preview should be the sum of the two "
"constituent previews. Additionally, previews are either 1D (regular "
"ones) or 2D (streak previews) and the latter need special image data "
"wrapping. This 'assembler' can therefore either interleave (1D or 2D) "
"or sum - 'auto' means it will guess which to do based on the primary "
"input source name."
)
.options("auto,interleave1d,interleave2d,sum")
.assignmentOptional()
.defaultValue("auto")
.commit(),
) )
def initialization(self): def initialization(self):
...@@ -59,6 +78,36 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher): ...@@ -59,6 +78,36 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
self._primary_source, self._secondary_source = [ self._primary_source, self._secondary_source = [
row["source"].partition("@")[0] for row in self.get("sources") row["source"].partition("@")[0] for row in self.get("sources")
] ]
# figure out assembly mode and set handler
if self.get("assemblyMode") == "auto":
_, _, source_output = self._primary_source.partition(":")
if source_output.lower().endswith("streak"):
self.set("assemblyMode", "interleave2d")
elif source_output.lower().endswith("sums"):
self.set("assemblyMode", "sum")
else:
self.set("assemblyMode", "interleave1d")
mode = self.get("assemblyMode")
if mode == "interleave1d":
self._do_the_assembly = self._interleave_1d
elif mode == "interleave2d":
self._do_the_assembly = self._interleave_2d
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=True)
)
.commit(),
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
else:
self._do_the_assembly = self._sum_1d
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver() self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
self._interleaving_buffer = np.ma.empty(0, dtype=np.float32) self._interleaving_buffer = np.ma.empty(0, dtype=np.float32)
self._wrap_in_imagedata = False self._wrap_in_imagedata = False
...@@ -91,6 +140,14 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher): ...@@ -91,6 +140,14 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
) )
def on_matched_data(self, train_id, sources): def on_matched_data(self, train_id, sources):
if (
missing_sources := {self._primary_source, self._secondary_source}
- sources.keys()
):
self.log.WARN(
f"Missing preview source(s): {missing_sources}, skipping train"
)
return
for (data, _) in sources.values(): for (data, _) in sources.values():
self._shmem_handler.dereference_shmem_handles(data) self._shmem_handler.dereference_shmem_handles(data)
...@@ -108,83 +165,87 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher): ...@@ -108,83 +165,87 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
mask_1 = False mask_1 = False
mask_2 = False mask_2 = False
# streak preview is in and should be put back into ImageData meta = ChannelMetaData(
wrap_in_imagedata = isinstance(data_1, ImageData) f"{self.getInstanceId()}:output",
if wrap_in_imagedata: Timestamp(Epochstamp(), Trainstamp(train_id)),
data_1 = data_1.getData() )
data_2 = data_2.getData() output_hash = self._do_the_assembly(data_1, mask_1, data_2, mask_2)
mask_1 = mask_1.getData() self.output.write(output_hash, meta, copyAllData=False)
mask_2 = mask_2.getData() self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _interleave_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1) image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2) image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# now to figure out the interleaving # now to figure out the interleaving
axis = 0 if data_1.ndim == 1 else 1 out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 0)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, axis)
if self._interleaving_buffer.shape != out_shape: if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array( self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32), np.empty(shape=out_shape, dtype=np.float32),
mask=False, mask=False,
) )
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], axis) utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 0)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], axis) utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 0)
# TODO: replace this part with preview friend # TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data) self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
if self._wrap_in_imagedata != wrap_in_imagedata: return Hash(
# we may need to re-inject output channel to satisfy GUI :D "image.data",
schema_update = Schema() self._interleaving_buffer.data,
( "image.mask",
OUTPUT_CHANNEL(schema_update) self._interleaving_buffer.mask,
.key("output") )
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=wrap_in_imagedata) def _interleave_2d(self, data_1, mask_1, data_2, mask_2):
) data_1 = data_1.getData()
.commit(), data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 1)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
) )
self.updateSchema(schema_update) utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 1)
self.output = self._ss.getOutputChannel("output") utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 1)
self._wrap_in_imagedata = wrap_in_imagedata # TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
meta = ChannelMetaData( return Hash(
f"{self.getInstanceId()}:output", "image.data",
Timestamp(Epochstamp(), Trainstamp(train_id)), ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
) )
if wrap_in_imagedata:
self.output.write(
Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
),
meta,
copyAllData=False,
)
else:
self.output.write(
Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
),
meta,
copyAllData=False,
)
self.output.update(safeNDArray=True)
self.info["sent"] += 1 def _sum_1d(self, data_1, mask_1, data_2, mask_2):
self.info["trainId"] = train_id image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
self.rate_out.update() image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# don't bother with self._interleaving_buffer
res = image_1 + image_2
return Hash(
"image.data",
res.data,
"image.mask",
res.mask,
)
...@@ -84,7 +84,7 @@ class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner): ...@@ -84,7 +84,7 @@ class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
@property @property
def preview_data_views(self): def preview_data_views(self):
return (self.input_data, self.processed_data) return (self.input_data, self.input_gain_stage, self.processed_data)
def load_constant(self, constant_type, data): def load_constant(self, constant_type, data):
if constant_type is Constants.LUTGotthard2: if constant_type is Constants.LUTGotthard2:
...@@ -261,7 +261,7 @@ class Gotthard2Correction(base_correction.BaseCorrection): ...@@ -261,7 +261,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
_cell_table_path = "data.memoryCell" _cell_table_path = "data.memoryCell"
_pulse_table_path = None _pulse_table_path = None
_warn_memory_cell_range = False # for now, receiver always writes 255 _warn_memory_cell_range = False # for now, receiver always writes 255
_preview_outputs = ["outputStreak"] _preview_outputs = ["outputStreak", "outputGainStreak"]
_cuda_pin_buffers = False _cuda_pin_buffers = False
@staticmethod @staticmethod
...@@ -302,7 +302,7 @@ class Gotthard2Correction(base_correction.BaseCorrection): ...@@ -302,7 +302,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
base_correction.add_preview_outputs( base_correction.add_preview_outputs(
expected, Gotthard2Correction._preview_outputs expected, Gotthard2Correction._preview_outputs
) )
for channel in ("outputRaw", "outputCorrected", "outputFrameSums"): for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"):
# add these "manually" as the automated bits wrap ImageData # add these "manually" as the automated bits wrap ImageData
( (
OUTPUT_CHANNEL(expected) OUTPUT_CHANNEL(expected)
...@@ -425,6 +425,7 @@ class Gotthard2Correction(base_correction.BaseCorrection): ...@@ -425,6 +425,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
warn(preview_warning) warn(preview_warning)
( (
preview_raw, preview_raw,
preview_gain,
preview_corrected, preview_corrected,
) = self.kernel_runner.compute_previews(preview_slice_index) ) = self.kernel_runner.compute_previews(preview_slice_index)
...@@ -441,6 +442,7 @@ class Gotthard2Correction(base_correction.BaseCorrection): ...@@ -441,6 +442,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp")) timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp"))
for channel, data in ( for channel, data in (
("outputRaw", preview_raw), ("outputRaw", preview_raw),
("outputGain", preview_gain),
("outputCorrected", preview_corrected), ("outputCorrected", preview_corrected),
("outputFrameSums", frame_sums), ("outputFrameSums", frame_sums),
): ):
...@@ -454,7 +456,7 @@ class Gotthard2Correction(base_correction.BaseCorrection): ...@@ -454,7 +456,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
), ),
timestamp=timestamp, timestamp=timestamp,
) )
self._preview_friend.write_outputs(metadata, buffer_array) self._preview_friend.write_outputs(metadata, buffer_array, gain_map)
def _load_constant_to_runner(self, constant, constant_data): def _load_constant_to_runner(self, constant, constant_data):
self.kernel_runner.load_constant(constant, constant_data) self.kernel_runner.load_constant(constant, constant_data)
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# cython: cdivision=True # cython: cdivision=True
# cython: wrapararound=False # cython: wrapararound=False
from libc.math cimport isinf, isnan
# TODO: get these automatically from enum definition # TODO: get these automatically from enum definition
cdef unsigned char NONE = 0 cdef unsigned char NONE = 0
cdef unsigned char LUT = 1 cdef unsigned char LUT = 1
...@@ -47,4 +49,7 @@ def correct( ...@@ -47,4 +49,7 @@ def correct(
if (flags & GAIN): if (flags & GAIN):
res /= gain_map[gain, cell, x] res /= gain_map[gain, cell, x]
if isnan(res) or isinf(res):
res = badpixel_fill_value
output[frame, x] = res output[frame, x] = res
...@@ -117,7 +117,7 @@ class PreviewFriend: ...@@ -117,7 +117,7 @@ class PreviewFriend:
] ]
self.reconfigure(device._parameters) self.reconfigure(device._parameters)
def write_outputs(self, timestamp, *datas, inplace=True, source=None): def write_outputs(self, timestamp, *datas, inplace=True):
"""Applies GUI-friendly preview settings (replace NaN, downsample, wrap as """Applies GUI-friendly preview settings (replace NaN, downsample, wrap as
ImageData) and writes to output channels. Make sure datas length matches number ImageData) and writes to output channels. Make sure datas length matches number
of channels!""" of channels!"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment