Compare revisions: calibration/calng

Changes are shown as if the source revision were being merged into the target revision.
Commits on Source (12)
  • TrainMatcher, 2.4.3
  • TrainMatcher, 2.4.4
  • calibrationClient, 11.3.0
  • calibration/geometryDevices, 0.0.1
  • calibration/calngUtils, 0.0.1
  • calibration/calngUtils, 0.0.3
@@ -21,5 +21,5 @@ xarray:
extra-geom:
$(PYPI) extra_geom==1.11.0
calng: cupy jinja2 h5py extra-geom
calng: cupy jinja2 h5py extra-geom xarray
pip install --upgrade .
@@ -39,6 +39,7 @@ setup(
"DetectorAssembler = calng.DetectorAssembler:DetectorAssembler",
"Gotthard2Assembler = calng.Gotthard2Assembler:Gotthard2Assembler",
"LpdminiSplitter = calng.LpdminiSplitter:LpdminiSplitter",
"SaturationWarningAggregator = calng.SaturationWarningAggregator:SaturationWarningAggregator", # noqa
],
"karabo.middlelayer_device": [
"CalibrationManager = calng.CalibrationManager:CalibrationManager",
@@ -51,7 +52,8 @@ setup(
"IntegratedIntensity = calng.correction_addons.integrated_intensity:IntegratedIntensity", # noqa
"LitPixelCounter = calng.correction_addons.litpixel_counter:LitPixelCounter [agipd]", # noqa
"Peakfinder9 = calng.correction_addons.peakfinder9:Peakfinder9", # noqa
"RandomFrames = calng.correction_addons.random_frames:RandomFrames", # noqa
"RandomFrames = calng.correction_addons.random_frames:RandomFrames",
"SaturationMonitor = calng.correction_addons.saturation_monitor:SaturationMonitor",# noqa
],
"calng.arbiter_kernel": [
"BooleanCombination = calng.arbiter_kernels.boolean_ops:BooleanCombination", # noqa
@@ -18,6 +18,7 @@ import re
from tornado.httpclient import AsyncHTTPClient, HTTPError
from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
from calngUtils import scene_utils
from karabo.middlelayer import (
KaraboError, Device, DeviceClientBase, Descriptor, Hash, Configurable,
Slot, Node, Type, Schema, ProxyFactory,
@@ -306,7 +307,7 @@ class CalibrationManager(DeviceClientBase, Device):
prefix = name[len('browse_schema:'):]
else:
prefix = 'managedKeys'
scene_data = scenes.recursive_subschema_scene(
scene_data = scene_utils.recursive_subschema_scene(
self.deviceId,
self.getDeviceSchema(),
prefix,
@@ -465,14 +466,6 @@ class CalibrationManager(DeviceClientBase, Device):
defaultValue=[],
accessMode=AccessMode.READONLY)
@Slot(
displayedName='Discover managed devices',
description='',
allowedStates=[State.ACTIVE])
async def discoverManagedDevices(self):
self.state = State.CHANGING
background(self._discover_managed_devices())
@Slot(
displayedName='Apply managed values',
description='Set all managed keys to the values currently active on '
@@ -581,7 +574,6 @@ class CalibrationManager(DeviceClientBase, Device):
if info['type'] == 'device':
self._check_new_device(instance_id, info['serverId'],
info['classId'])
self._update_managed_devices()
@slot
def slotInstanceGone(self, instance_id, info):
@@ -590,12 +582,7 @@ class CalibrationManager(DeviceClientBase, Device):
super().slotInstanceGone(instance_id, info)
if info['type'] == 'device':
self._daq_device_ids.discard(instance_id)
self._domain_device_ids.discard(instance_id)
self._managed_device_ids.discard(instance_id)
self._correction_device_ids.discard(instance_id)
self._assembler_device_ids.discard(instance_id)
self._update_managed_devices()
self._remove_managed_device(instance_id)
async def _async_init(self):
# Set up Tornado.
@@ -651,6 +638,16 @@ class CalibrationManager(DeviceClientBase, Device):
# This device is also an assembler
self._assembler_device_ids.add(device_id)
self._update_managed_devices()
def _remove_managed_device(self, device_id):
self._daq_device_ids.discard(device_id)
self._domain_device_ids.discard(device_id)
self._managed_device_ids.discard(device_id)
self._correction_device_ids.discard(device_id)
self._assembler_device_ids.discard(device_id)
self._update_managed_devices()
async def _check_topology(self):
for i in range(10):
try:
@@ -673,34 +670,30 @@ class CalibrationManager(DeviceClientBase, Device):
self._check_new_device(device_id, devices[device_id, 'serverId'],
devices[device_id, 'classId'])
self._update_managed_devices(True)
async def _discover_managed_devices(self):
self._daq_device_ids.clear()
self._domain_device_ids.clear()
self._managed_device_ids.clear()
self._correction_device_ids.clear()
self._assembler_device_ids.clear()
async def _ping_managed_devices(self):
device_ids = sorted(self._managed_device_ids)
await self._check_topology()
device_infos = await gather(
*[wait_for(getInstanceInfo(device_id), timeout=5)
for device_id in device_ids],
return_exceptions=True)
self.state = State.ACTIVE
for device_id, device_info in zip(device_ids, device_infos):
if isinstance(device_info, AsyncTimeoutError):
self._remove_managed_device(device_id)
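
For reference, the ping-and-prune pattern above maps onto plain asyncio as follows (a minimal sketch; `ping` stands in for karabo's `getInstanceInfo`, and all names here are illustrative):

import asyncio

async def ping_all(device_ids, ping, timeout=5):
    # Launch all pings concurrently; return_exceptions=True keeps a single
    # timeout from cancelling the remaining pings (same pattern as above).
    results = await asyncio.gather(
        *[asyncio.wait_for(ping(device_id), timeout) for device_id in device_ids],
        return_exceptions=True,
    )
    # Keep only the devices that answered in time.
    return [device_id for device_id, result in zip(device_ids, results)
            if not isinstance(result, asyncio.TimeoutError)]
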
async def _delayed_managed_devices_update(self):
await sleep(1.0) # Throttle updates to at most once a second.
self.managedDevices = sorted(self._managed_device_ids)
self._managed_devices_updater = None # Clear task again.
def _update_managed_devices(self, forced=False):
def _update_managed_devices(self):
if self._managed_devices_updater is not None:
# Update already in progress, ignore.
return
if forced or len(self._managed_device_ids) != len(self.managedDevices):
# Trigger an update either if forced or the number of
# devices changed.
self._managed_devices_updater = background(
self._delayed_managed_devices_update())
self._managed_devices_updater = background(
self._delayed_managed_devices_update())
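
The throttling idiom used here, a background task that sleeps, publishes, then clears itself while further triggers are ignored until it finishes, looks like this in isolation (a sketch, not the device code):

import asyncio

class ThrottledPublisher:
    """Coalesce bursts of triggers into at most one publish per second."""

    def __init__(self, publish):
        self._publish = publish
        self._pending = None

    def trigger(self):
        if self._pending is None:  # no update in flight yet
            self._pending = asyncio.ensure_future(self._delayed())

    async def _delayed(self):
        await asyncio.sleep(1.0)  # throttle window
        self._publish()
        self._pending = None  # allow the next trigger
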
async def _get_shared_keys(self, device_ids, keys):
"""Find the most common property values on devices."""
@@ -1073,7 +1066,7 @@ class CalibrationManager(DeviceClientBase, Device):
# need are in there, and obtain their API names.
for host, req_names in hosts.items():
# Retrieve hostname for nice error messages.
hostname = urlparse(self._server_hosts[name]).hostname
hostname = urlparse(host).hostname
try:
reply = await to_asyncio_future(self._http_client.fetch(
@@ -1277,6 +1270,9 @@
'listed in the device servers '
'configuration')
# Ping all known managed devices to make sure they're alive.
await self._ping_managed_devices()
# Instantiate modules.
modules_by_group = defaultdict(list)
correct_device_id_by_module = {}
@@ -48,6 +48,25 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
.assignmentOptional()
.defaultValue("")
.commit(),
STRING_ELEMENT(expected)
.key("assemblyMode")
.displayedName("Assembly mode")
.description(
"Previews for 25 μm GOTTHARD-II are generally generated by "
"interleaving the preview outputs of the two constituent 50 μm "
"modules. However, the frame sum previews is temporal (sum across all "
"pixels per frame), so this preview should be the sum of the two "
"constituent previews. Additionally, previews are either 1D (regular "
"ones) or 2D (streak previews) and the latter need special image data "
"wrapping. This 'assembler' can therefore either interleave (1D or 2D) "
"or sum - 'auto' means it will guess which to do based on the primary "
"input source name."
)
.options("auto,interleave1d,interleave2d,sum")
.assignmentOptional()
.defaultValue("auto")
.commit(),
)
def initialization(self):
@@ -59,6 +78,36 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
self._primary_source, self._secondary_source = [
row["source"].partition("@")[0] for row in self.get("sources")
]
# figure out assembly mode and set handler
if self.get("assemblyMode") == "auto":
_, _, source_output = self._primary_source.partition(":")
if source_output.lower().endswith("streak"):
self.set("assemblyMode", "interleave2d")
elif source_output.lower().endswith("sums"):
self.set("assemblyMode", "sum")
else:
self.set("assemblyMode", "interleave1d")
mode = self.get("assemblyMode")
if mode == "interleave1d":
self._do_the_assembly = self._interleave_1d
elif mode == "interleave2d":
self._do_the_assembly = self._interleave_2d
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=True)
)
.commit(),
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
else:
self._do_the_assembly = self._sum_1d
self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
self._interleaving_buffer = np.ma.empty(0, dtype=np.float32)
self._wrap_in_imagedata = False
@@ -91,6 +140,14 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
)
def on_matched_data(self, train_id, sources):
if (
missing_sources := {self._primary_source, self._secondary_source}
- sources.keys()
):
self.log.WARN(
f"Missing preview source(s): {missing_sources}, skipping train"
)
return
for (data, _) in sources.values():
self._shmem_handler.dereference_shmem_handles(data)
@@ -108,83 +165,87 @@ class Gotthard2Assembler(TrainMatcher.TrainMatcher):
mask_1 = False
mask_2 = False
# the streak preview arrives as ImageData and should be put back into ImageData
wrap_in_imagedata = isinstance(data_1, ImageData)
if wrap_in_imagedata:
data_1 = data_1.getData()
data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
meta = ChannelMetaData(
f"{self.getInstanceId()}:output",
Timestamp(Epochstamp(), Trainstamp(train_id)),
)
output_hash = self._do_the_assembly(data_1, mask_1, data_2, mask_2)
self.output.write(output_hash, meta, copyAllData=False)
self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _interleave_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# now to figure out the interleaving
axis = 0 if data_1.ndim == 1 else 1
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, axis)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 0)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
)
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], axis)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], axis)
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 0)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 0)
# TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
if self._wrap_in_imagedata != wrap_in_imagedata:
# we may need to re-inject output channel to satisfy GUI :D
schema_update = Schema()
(
OUTPUT_CHANNEL(schema_update)
.key("output")
.dataSchema(
schemas.preview_schema(wrap_image_in_imagedata=wrap_in_imagedata)
)
.commit(),
return Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
)
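
The even/odd writes done through `utils.set_on_axis` amount to strided slice assignment. A minimal numpy sketch of the 1D case (assuming `set_on_axis` simply assigns the given index expression along the given axis):

import numpy as np

# Two 50 μm modules, offset from each other by half a pixel pitch:
a = np.array([0.0, 2.0, 4.0], dtype=np.float32)
b = np.array([1.0, 3.0, 5.0], dtype=np.float32)

out = np.empty(a.size * 2, dtype=np.float32)  # interleaving buffer
out[0::2] = a  # even positions from module 1
out[1::2] = b  # odd positions from module 2
print(out)  # [0. 1. 2. 3. 4. 5.] -> effective 25 μm sampling
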
def _interleave_2d(self, data_1, mask_1, data_2, mask_2):
data_1 = data_1.getData()
data_2 = data_2.getData()
mask_1 = mask_1.getData()
mask_2 = mask_2.getData()
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
out_shape = utils.interleaving_buffer_shape(data_1.shape, 2, 1)
if self._interleaving_buffer.shape != out_shape:
self._interleaving_buffer = np.ma.masked_array(
np.empty(shape=out_shape, dtype=np.float32),
mask=False,
)
self.updateSchema(schema_update)
self.output = self._ss.getOutputChannel("output")
self._wrap_in_imagedata = wrap_in_imagedata
utils.set_on_axis(self._interleaving_buffer, image_1, np.index_exp[0::2], 1)
utils.set_on_axis(self._interleaving_buffer, image_2, np.index_exp[1::2], 1)
# TODO: replace this part with preview friend
self._interleaving_buffer.mask |= ~np.isfinite(self._interleaving_buffer.data)
meta = ChannelMetaData(
f"{self.getInstanceId()}:output",
Timestamp(Epochstamp(), Trainstamp(train_id)),
return Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
)
if wrap_in_imagedata:
self.output.write(
Hash(
"image.data",
ImageData(
self._interleaving_buffer.data,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
"image.mask",
ImageData(
self._interleaving_buffer.mask,
Dims(*self._interleaving_buffer.shape),
Encoding.GRAY,
bitsPerPixel=32,
),
),
meta,
copyAllData=False,
)
else:
self.output.write(
Hash(
"image.data",
self._interleaving_buffer.data,
"image.mask",
self._interleaving_buffer.mask,
),
meta,
copyAllData=False,
)
self.output.update(safeNDArray=True)
self.info["sent"] += 1
self.info["trainId"] = train_id
self.rate_out.update()
def _sum_1d(self, data_1, mask_1, data_2, mask_2):
image_1 = np.ma.masked_array(data=data_1, mask=mask_1)
image_2 = np.ma.masked_array(data=data_2, mask=mask_2)
# don't bother with self._interleaving_buffer
res = image_1 + image_2
return Hash(
"image.data",
res.data,
"image.mask",
res.mask,
)
import numpy as np
import xarray as xr
from karabo.bound import (
BOOL_ELEMENT,
FLOAT_ELEMENT,
KARABO_CLASSINFO,
NODE_ELEMENT,
OVERWRITE_ELEMENT,
UINT32_ELEMENT,
UINT64_ELEMENT,
Epochstamp,
Hash,
ImageData,
Timestamp,
Trainstamp,
)
from .DetectorAssembler import DetectorAssembler
from ._version import version as deviceVersion
@KARABO_CLASSINFO("SaturationWarningAggregator", deviceVersion)
class SaturationWarningAggregator(DetectorAssembler):
@staticmethod
def expectedParameters(expected):
(
OVERWRITE_ELEMENT(expected)
.key("imageDataPath")
.setNewDefaultValue("saturationMonitor.maxImage")
.commit(),
OVERWRITE_ELEMENT(expected)
.key("imageMaskPath")
.setNewDefaultValue("")
.commit(),
# The node exists for compatibility with the saturation monitor
# add-on in calng. That way, an MDL device can use this aggregator,
# the add-on, or the SaturationMonitor from the ImageProcessor
# package with the same code.
NODE_ELEMENT(expected)
.key("saturationMonitor")
.commit(),
BOOL_ELEMENT(expected)
.key("saturationMonitor.warning")
.readOnly()
.initialValue(False)
.commit(),
BOOL_ELEMENT(expected)
.key("saturationMonitor.alarm")
.readOnly()
.initialValue(False)
.commit(),
UINT32_ELEMENT(expected)
.key("saturationMonitor.warnCount")
.description(
"Total number of pixels above warning threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells)."
)
.readOnly()
.initialValue(0)
.commit(),
UINT32_ELEMENT(expected)
.key("saturationMonitor.alarmCount")
.description(
"Total number of pixels above alarm threshold. Each pixel is "
"only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells)."
)
.readOnly()
.initialValue(0)
.commit(),
UINT64_ELEMENT(expected)
.key("saturationMonitor.trainId")
.description(
"Total number of pixels above alarm threshold. Each pixel is "
"only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells)."
)
.readOnly()
.initialValue(0)
.commit(),
FLOAT_ELEMENT(expected)
.key("saturationMonitor.maxValue")
.description("Max pixel value in latest train with warning or alarm.")
.readOnly()
.initialValue(0)
.commit(),
)
def on_matched_data(self, tid, sources):
my_timestamp = Timestamp(Epochstamp(), Trainstamp(tid))
warn = False
alarm = False
warn_count = 0
alarm_count = 0
max_value = 0
processed_any = False
# TODO: look at image, too (DetectorAssembler functionality needed)
image_datas, module_indices = [], []
for source, (data, timestamp) in sources.items():
if not data.has("saturationMonitor"):
continue
processed_any = True
warn = warn or data["saturationMonitor.warning"]
alarm = alarm or data["saturationMonitor.alarm"]
warn_count += data["saturationMonitor.warnCount"]
alarm_count += data["saturationMonitor.alarmCount"]
if data["saturationMonitor.maxValue"] > max_value:
max_value = data["saturationMonitor.maxValue"]
# this part mirrors DetectorAssembler
image_data = data[self._image_data_path]
if isinstance(image_data, ImageData):
image_data = image_data.getData()
# note: not using image mask, let preview handler zero out NaN
image_datas.append(image_data.astype(np.float32, copy=False))
module_indices.append(self._source_to_index[source])
if not processed_any:
self.log.WARN(
"No sources in match had 'saturationMonitor', "
"please check upstream configuration!"
)
# handle warning property updates
update = Hash()
if warn != self.get("saturationMonitor.warning"):
update["saturationMonitor.warning"] = warn
if alarm != self.get("saturationMonitor.alarm"):
update["saturationMonitor.alarm"] = alarm
if warn or alarm:
update["saturationMonitor.warnCount"] = warn_count
update["saturationMonitor.alarmCount"] = alarm_count
update["saturationMonitor.trainId"] = tid
update["saturationMonitor.maxValue"] = max_value
if update:
self.set(update)
# handle maybe sending preview
if self._geometry is None:
self.log.WARN("Have not received a geometry yet, cannot show preview")
elif warn or alarm:
dims = ["module", "slow_scan", "fast_scan"]
coords = {"module": module_indices}
assembled_data, _ = self._geometry.position_modules(
xr.DataArray(
data=image_datas,
dims=dims,
coords=coords,
)
)
self._preview_friend.write_outputs(my_timestamp, assembled_data)
self.info["sent"] += 1
self.info["trainId"] = tid
self.rate_out.update()
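
The geometry assembly above relies on EXtra-geom accepting a labelled array whose "module" coordinate names the modules actually present. A sketch with hypothetical quadrant positions (the real ones come from the geometry device):

import numpy as np
import xarray as xr
from extra_geom import AGIPD_1MGeometry

# Illustrative quadrant positions only.
geom = AGIPD_1MGeometry.from_quad_positions(
    [(-525, 625), (-550, -10), (520, -160), (542.5, 475)])

# Suppose only two of the sixteen modules sent data this train.
image_datas = [np.zeros((512, 128), dtype=np.float32)] * 2
module_indices = [0, 7]

assembled, centre = geom.position_modules(xr.DataArray(
    data=image_datas,
    dims=["module", "slow_scan", "fast_scan"],
    coords={"module": module_indices},
))
print(assembled.shape)  # full detector canvas; absent modules are NaN
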
@@ -11,7 +11,14 @@ from timeit import default_timer
import dateutil.parser
import numpy as np
from geometryDevices import utils as geom_utils
from calngUtils import device as device_utils, misc, shmem_utils, timing, trackers
from calngUtils import (
device as device_utils,
misc,
scene_utils,
shmem_utils,
timing,
trackers,
)
from karabo.bound import (
BOOL_ELEMENT,
DOUBLE_ELEMENT,
@@ -954,7 +961,7 @@ class BaseCorrection(PythonDevice):
prefix = name[len("browse_schema:") :]
else:
prefix = "managed"
payload["data"] = scenes.recursive_subschema_scene(
payload["data"] = scene_utils.recursive_subschema_scene(
self.getInstanceId(),
self.getFullSchema(),
prefix,
import numpy as np
from karabo.bound import (
BOOL_ELEMENT,
FLOAT_ELEMENT,
IMAGEDATA_ELEMENT,
NODE_ELEMENT,
UINT32_ELEMENT,
UINT64_ELEMENT,
Dims,
Encoding,
ImageData,
)
from .base_addon import BaseCorrectionAddon
def maybe_get(a):
# TODO: proper check for cupy
if hasattr(a, "get"):
return a.get()
return a
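
The TODO above could be resolved with an explicit type check, so that only genuine cupy arrays are copied to the host and other objects that happen to define .get() pass through untouched; a sketch:

def maybe_get_strict(a):
    # Import lazily: cupy may not be installed on CPU-only nodes.
    try:
        import cupy
    except ImportError:
        return a
    if isinstance(a, cupy.ndarray):
        return a.get()  # device -> host copy
    return a
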
class SaturationMonitor(BaseCorrectionAddon):
def __init__(self, config):
global cupy
import cupy
self._alarmThreshold = config["alarmThreshold"]
self._warnThreshold = config["warnThreshold"]
self._alarmMaxCount = config["alarmMaxCount"]
self._warnMaxCount = config["warnMaxCount"]
self._frameAxis = config["frameAxis"]
def reconfigure(self, changed_config):
if changed_config.has("alarmThreshold"):
self._alarmThreshold = changed_config["alarmThreshold"]
if changed_config.has("warnThreshold"):
self._warnThreshold = changed_config["warnThreshold"]
if changed_config.has("alarmMaxCount"):
self._alarmMaxCount = changed_config["alarmMaxCount"]
if changed_config.has("warnMaxCount"):
self._warnMaxCount = changed_config["warnMaxCount"]
if changed_config.has("frameAxis"):
self._frameAxis = changed_config["frameAxis"]
@staticmethod
def extend_output_schema(schema):
(
NODE_ELEMENT(schema)
.key("saturationMonitor")
.commit(),
BOOL_ELEMENT(schema)
.key("saturationMonitor.warning")
.readOnly()
.commit(),
BOOL_ELEMENT(schema)
.key("saturationMonitor.alarm")
.readOnly()
.commit(),
UINT32_ELEMENT(schema)
.key("saturationMonitor.warnCount")
.description(
"Total number of pixels above warning threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells = given axis)."
)
.readOnly()
.commit(),
UINT32_ELEMENT(schema)
.key("saturationMonitor.alarmCount")
.description(
"Total number of pixels above alarm threshold. Each pixel "
"is only counted once even if it exceeds the threshold in "
"multiple frames (/ memory cells = given axis)."
)
.readOnly()
.commit(),
FLOAT_ELEMENT(schema)
.key("saturationMonitor.maxValue")
.readOnly()
.commit(),
# TODO: switch to image data
IMAGEDATA_ELEMENT(schema)
.key("saturationMonitor.alarmImage")
.commit(),
)
@staticmethod
def extend_device_schema(schema, prefix):
(
FLOAT_ELEMENT(schema)
.key(f"{prefix}.alarmThreshold")
.description("Alarm threshold per pixel.")
.tags("managed")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
FLOAT_ELEMENT(schema)
.key(f"{prefix}.warnThreshold")
.description("Warning threshold per pixel.")
.tags("managed")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
UINT64_ELEMENT(schema)
.key(f"{prefix}.alarmMaxCount")
.description("Maximum number of pixel above alarm threshold.")
.tags("managed")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
UINT64_ELEMENT(schema)
.key(f"{prefix}.warnMaxCount")
.description("Maximum number of pixel above warn threshold.")
.tags("managed")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
UINT64_ELEMENT(schema)
.key(f"{prefix}.frameAxis")
.displayedName("Multi-frame axis")
.description("Axis for frames. Used to take the max over this axis.")
.tags("managed")
.assignmentOptional()
.defaultValue(0)
.reconfigurable()
.commit(),
)
def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
# only take the max if data has frames -> more than 2 dimensions
if processed_data.ndim > 2:
max_image = np.nanmax(processed_data, axis=self._frameAxis)
else:
max_image = processed_data
nb_pix_warning = int(np.nansum(max_image > self._warnThreshold))
nb_pix_alarm = int(np.nansum(max_image > self._alarmThreshold))
output_hash["saturationMonitor.warning"] = nb_pix_warning > self._warnMaxCount
output_hash["saturationMonitor.alarm"] = nb_pix_alarm > self._alarmMaxCount
output_hash["saturationMonitor.warnCount"] = nb_pix_warning
output_hash["saturationMonitor.alarmCount"] = nb_pix_alarm
output_hash["saturationMonitor.maxValue"] = float(np.nanmax(max_image))
max_image[max_image <= self._warnThreshold] = 0
output_hash["saturationMonitor.alarmImage"] = ImageData(
maybe_get(max_image), Dims(*max_image.shape), Encoding.GRAY, bitsPerPixel=32
)
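
A small numpy check of the counted-once semantics described in the schema: reducing over the frame axis first means a pixel that saturates in several frames still contributes a single count.

import numpy as np

# 3 frames of a 2x2 detector; pixel (0, 0) exceeds the threshold twice.
data = np.array([[[10., 0.], [0., 0.]],
                 [[12., 0.], [0., 5.]],
                 [[ 1., 0.], [0., 0.]]], dtype=np.float32)

max_image = np.nanmax(data, axis=0)  # reduce over the frame axis
warn_count = int(np.nansum(max_image > 4.0))
print(warn_count)  # 2 -> pixels (0, 0) and (1, 1), each counted once
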
@@ -84,7 +84,7 @@ class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner):
@property
def preview_data_views(self):
return (self.input_data, self.processed_data)
return (self.input_data, self.input_gain_stage, self.processed_data)
def load_constant(self, constant_type, data):
if constant_type is Constants.LUTGotthard2:
@@ -261,7 +261,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
_cell_table_path = "data.memoryCell"
_pulse_table_path = None
_warn_memory_cell_range = False # for now, receiver always writes 255
_preview_outputs = ["outputStreak"]
_preview_outputs = ["outputStreak", "outputGainStreak"]
_cuda_pin_buffers = False
@staticmethod
@@ -302,7 +302,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
base_correction.add_preview_outputs(
expected, Gotthard2Correction._preview_outputs
)
for channel in ("outputRaw", "outputCorrected", "outputFrameSums"):
for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"):
# add these "manually" as the automated bits wrap ImageData
(
OUTPUT_CHANNEL(expected)
@@ -425,6 +425,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
warn(preview_warning)
(
preview_raw,
preview_gain,
preview_corrected,
) = self.kernel_runner.compute_previews(preview_slice_index)
@@ -441,6 +442,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp"))
for channel, data in (
("outputRaw", preview_raw),
("outputGain", preview_gain),
("outputCorrected", preview_corrected),
("outputFrameSums", frame_sums),
):
@@ -454,7 +456,7 @@ class Gotthard2Correction(base_correction.BaseCorrection):
),
timestamp=timestamp,
)
self._preview_friend.write_outputs(metadata, buffer_array)
self._preview_friend.write_outputs(metadata, buffer_array, gain_map)
def _load_constant_to_runner(self, constant, constant_data):
self.kernel_runner.load_constant(constant, constant_data)
@@ -2,6 +2,8 @@
# cython: cdivision=True
# cython: wraparound=False
from libc.math cimport isinf, isnan
# TODO: get these automatically from enum definition
cdef unsigned char NONE = 0
cdef unsigned char LUT = 1
@@ -47,4 +49,7 @@ def correct(
if (flags & GAIN):
res /= gain_map[gain, cell, x]
if isnan(res) or isinf(res):
res = badpixel_fill_value
output[frame, x] = res
@@ -117,7 +117,7 @@ class PreviewFriend:
]
self.reconfigure(device._parameters)
def write_outputs(self, timestamp, *datas, inplace=True, source=None):
def write_outputs(self, timestamp, *datas, inplace=True):
"""Applies GUI-friendly preview settings (replace NaN, downsample, wrap as
ImageData) and writes to output channels. Make sure datas length matches number
of channels!"""
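
A usage sketch of the narrowed signature, with hypothetical variable names (one preview array per configured output channel, in order):

# e.g. a device with preview channels ("outputRaw", "outputCorrected"):
self._preview_friend.write_outputs(timestamp, raw_preview, corrected_preview)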