diff --git a/DEPENDS b/DEPENDS index 6dbf7d2159b1530f96346ea641ac0f5116aacd47..53d4298919b3f4496601016e71e66b28d9ae9d4f 100644 --- a/DEPENDS +++ b/DEPENDS @@ -1,4 +1,4 @@ TrainMatcher, 2.4.5 calibrationClient, 11.3.0 calibration/geometryDevices, 0.0.2 -calibration/calngUtils, 0.0.3 +calibration/calngUtils, 0.0.4 diff --git a/docs/devices.md b/docs/devices.md index bbdaa9b88e13fd1b68c96996869217ce82327f70..242caf3a832a3a10f6ddef6b47c4563ed2c5d348 100644 --- a/docs/devices.md +++ b/docs/devices.md @@ -50,15 +50,13 @@ Therefore, settings related to how this is done should typically be managed via The parameters seen in the "Preview" box on the manager overview scene control: - How to select which frame to preview (for non-negative indices) - - See [Index selection mode](schemas/BaseCorrection.md#preview.selectionMode) - Frame can be extracted directly from image ndarray (`frame` mode), by looking up corresponding frame in cell table (`cell` mode), or by looking up in pulse table (`pulse` mode) - For XTDF detectors, the mapping tables are typically `image.cellId` and `image.pulseId`. Which selection mode makes sense depends on detector, veto pattern, and experimental setup. - If the specified cell or pulse is not found in the respective table, a warning is issued and the first frame is sent instead. - Which frame or statistic to send - - See [Index for preview](schemas/BaseCorrection.md#preview.index) - If the index is non-negative, a single frame is sliced from the image data. - How this frame is found depends on the the [index selection mode](schemas/BaseCorrection.md#preview.selectionMode). + How this frame is found depends on the the index selection mode (previous point). - If the index is negative, a statistic is computed across all frames for a summary preview. Note that this is done individually per pixel and that `NaN` values are ignored. 
Options are: diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py index 8de76153d77e6b7ebf8386bbba8867fb7e81a532..4ae0ed05eee7663944351f709690c56b975a2458 100644 --- a/src/calng/CalibrationManager.py +++ b/src/calng/CalibrationManager.py @@ -18,7 +18,7 @@ import re from tornado.httpclient import AsyncHTTPClient, HTTPError from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future -from calngUtils import scene_utils +from calngUtils.scene_utils import recursive_subschema_scene from karabo.middlelayer import ( KaraboError, Device, DeviceClientBase, Descriptor, Hash, Configurable, Slot, Node, Type, Schema, ProxyFactory, @@ -158,8 +158,8 @@ class PreviewLayerRow(Configurable): defaultValue='', accessMode=AccessMode.RECONFIGURABLE) - outputPipeline = String( - displayedName='Output pipeline', + previewName = String( + displayedName='Preview name', defaultValue='', accessMode=AccessMode.RECONFIGURABLE) @@ -307,7 +307,7 @@ class CalibrationManager(DeviceClientBase, Device): prefix = name[len('browse_schema:'):] else: prefix = 'managedKeys' - scene_data = scene_utils.recursive_subschema_scene( + scene_data = recursive_subschema_scene( self.deviceId, self.getDeviceSchema(), prefix, @@ -382,12 +382,6 @@ class CalibrationManager(DeviceClientBase, Device): accessMode=AccessMode.RECONFIGURABLE, assignment=Assignment.MANDATORY) - previewLayers = VectorHash( - displayedName='Preview layers', - rows=PreviewLayerRow, - accessMode=AccessMode.RECONFIGURABLE, - assignment=Assignment.MANDATORY) - @VectorHash( displayedName='Device servers', description='WARNING: It is strongly recommended to perform a full ' @@ -402,6 +396,21 @@ class CalibrationManager(DeviceClientBase, Device): # Switch to UNKNOWN state to suggest operator to restart pipeline. 
self.state = State.UNKNOWN + previewServer = String( + displayedName='Preview device server', + description='The server (from "Device servers" list) to use for ' + 'preview assemblers', + accessMode=AccessMode.RECONFIGURABLE, + assignment=Assignment.MANDATORY) + + previewLayers = VectorString( + displayedName='Preview layers', + description='List of previews (like raw, corrected, and anything else ' + 'detector-specific) to show. Automatically populated based ' + 'on correction device schema, just listed for debugging.', + defaultValue=[], + accessMode=AccessMode.READONLY) + geometryDevice = String( displayedName='Geometry device', description='Device ID for a geometry device defining the detector ' @@ -598,6 +607,10 @@ class CalibrationManager(DeviceClientBase, Device): # Inject schema for configuration of managed devices. await self._inject_managed_keys() + # Populate preview layers from correction device schema + self.previewLayers = list(self._correction_device_schema + .hash["preview"].getKeys()) + # Populate the device ID sets with what's out there right now. await self._check_topology() @@ -1262,12 +1275,9 @@ class CalibrationManager(DeviceClientBase, Device): # Servers by group and layer. server_by_group = {group: server for group, server, _, _, _, in self.moduleGroups.value} - server_by_layer = {layer: server for layer, _, server - in self.previewLayers.value} - - all_req_servers = set(server_by_group.values()).union( - server_by_layer.values()) + preview_server = self.previewServer.value + all_req_servers = set(server_by_group.values()) | {preview_server} if all_req_servers != up_servers: return self._set_error('One or more device servers are not ' 'listed in the device servers ' @@ -1352,9 +1362,9 @@ class CalibrationManager(DeviceClientBase, Device): )) # Instantiate preview layer assemblers. 
- for layer, output_pipeline, server in self.previewLayers.value: + for preview_name in self.previewLayers.value: sources = [Hash('select', True, - 'source', f'{device_id}:{output_pipeline}') + 'source', f'{device_id}:preview.{preview_name}.output') for (_, device_id) in correct_device_id_by_module.items()] config = Hash('sources', sources, @@ -1367,9 +1377,10 @@ class CalibrationManager(DeviceClientBase, Device): config[remote_key] = value.value awaitables.append(self._instantiate_device( - server, + preview_server, self._class_ids['assembler'], - self._device_id_templates['assembler'].format(layer=layer), + self._device_id_templates['assembler'].format( + layer=preview_name.upper()), config )) diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py index 1f3fb09d0b6c795f4c14903f5e5ae227c56c49a3..b4f6c06d94be9aa7b1ceea21673951818068ed77 100644 --- a/src/calng/DetectorAssembler.py +++ b/src/calng/DetectorAssembler.py @@ -27,7 +27,8 @@ from karabo.bound import ( ) from TrainMatcher import TrainMatcher -from . import preview_utils, scenes, schemas +from . 
import scenes, schemas +from .preview_utils import PreviewFriend, PreviewSpec from ._version import version as deviceVersion @@ -73,8 +74,12 @@ def module_index_schema(): @device_utils.with_unsafe_get @KARABO_CLASSINFO("DetectorAssembler", deviceVersion) class DetectorAssembler(TrainMatcher.TrainMatcher): - @staticmethod - def expectedParameters(expected): + _preview_outputs = [ + PreviewSpec("assembled", frame_reduction=False, flip_ss=True, flip_fs=True) + ] + + @classmethod + def expectedParameters(cls, expected): ( OVERWRITE_ELEMENT(expected) .key("availableScenes") @@ -178,18 +183,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): .reconfigurable() .commit(), ) - preview_utils.PreviewFriend.add_schema(expected, "preview", create_node=True) - ( - OVERWRITE_ELEMENT(expected) - .key("preview.flipSS") - .setNewDefaultValue(True) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("preview.flipFS") - .setNewDefaultValue(True) - .commit(), - ) + PreviewFriend.add_schema(expected, cls._preview_outputs) def __init__(self, conf): super().__init__(conf) @@ -204,7 +198,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): def initialization(self): super().initialization() - self._preview_friend = preview_utils.PreviewFriend(self) + self._preview_friend = PreviewFriend(self, self._preview_outputs, "preview") self._image_data_path = self.get("imageDataPath") self._image_mask_path = self.get("imageMaskPath") self._geometry = None @@ -469,8 +463,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self.zmq_output.update() self._preview_friend.write_outputs( - my_timestamp, - np.ma.masked_array(data=assembled_data, mask=assembled_mask), + np.ma.masked_array(data=assembled_data, mask=assembled_mask) ) self._processing_time_tracker.update(default_timer() - ts_start) @@ -495,7 +488,8 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): ): self._need_to_update_source_index_mapping = True - self._preview_friend.reconfigure(conf) + if conf.has("preview"): + 
self._preview_friend.reconfigure(conf["preview"]) def postReconfigure(self): if self._need_to_update_source_index_mapping: diff --git a/src/calng/LpdminiSplitter.py b/src/calng/LpdminiSplitter.py index 5c32d679e305c8586d40e93457721789602586c5..d2fee86678176a520cab2fe521940942734cc710 100644 --- a/src/calng/LpdminiSplitter.py +++ b/src/calng/LpdminiSplitter.py @@ -105,7 +105,7 @@ class LpdminiSplitter(PythonDevice): ( OUTPUT_CHANNEL(expected) .key(channel_name) - .dataSchema(schemas.xtdf_output_schema()) + .dataSchema(schemas.xtdf_output_schema(use_shmem_handles=True)) .commit(), ) diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 414b334211ecda2b7dfaa08b83c8a1477a0d7f35..a11f9b969a696211cc95742115b93dc76139eeef 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -1,7 +1,7 @@ import concurrent.futures -import enum +import contextlib import functools -import gc +import itertools import math import pathlib import threading @@ -13,17 +13,16 @@ import numpy as np from geometryDevices import utils as geom_utils from calngUtils import ( device as device_utils, - misc, scene_utils, shmem_utils, timing, trackers, ) +from calngUtils.misc import ChainHash from karabo.bound import ( BOOL_ELEMENT, DOUBLE_ELEMENT, INPUT_CHANNEL, - INT32_ELEMENT, KARABO_CLASSINFO, NODE_ELEMENT, OUTPUT_CHANNEL, @@ -45,96 +44,33 @@ from karabo.bound import ( ) from karabo.common.api import KARABO_SCHEMA_DISPLAY_TYPE_SCENES as DT_SCENES -from . import preview_utils, schemas, scenes, utils +from . 
import scenes +from .preview_utils import PreviewFriend +from .utils import StateContext, WarningContextSystem, WarningLampType, subset_of_hash from ._version import version as deviceVersion PROCESSING_STATE_TIMEOUT = 10 -class FramefilterSpecType(enum.Enum): - NONE = "none" - RANGE = "range" - COMMASEPARATED = "commaseparated" - - -class WarningLampType(enum.Enum): - FRAME_FILTER = enum.auto() - MEMORY_CELL_RANGE = enum.auto() - CONSTANT_OPERATING_PARAMETERS = enum.auto() - PREVIEW_SETTINGS = enum.auto() - CORRECTION_RUNNER = enum.auto() - OUTPUT_BUFFER = enum.auto() - GPU_MEMORY = enum.auto() - CALCAT_CONNECTION = enum.auto() - EMPTY_HASH = enum.auto() - MISC_INPUT_DATA = enum.auto() - TRAIN_ID = enum.auto() - TIMESERVER_CONNECTION = enum.auto() +class TrainFromTheFutureException(BaseException): + pass +@device_utils.with_config_overlay @device_utils.with_unsafe_get @KARABO_CLASSINFO("BaseCorrection", deviceVersion) class BaseCorrection(PythonDevice): _available_addons = [] # classes, filled by add_addon_nodes using entry_points _base_output_schema = None # subclass must set constructor _constant_enum_class = None # subclass must set - _correction_flag_class = None # subclass must set (ex.: dssc_gpu.CorrectionFlags) _correction_steps = None # subclass must set _kernel_runner_class = None # subclass must set (ex.: dssc_gpu.DsscGpuRunner) - _kernel_runner_init_args = {} # optional extra args for runner _image_data_path = "image.data" # customize for *some* subclasses _cell_table_path = "image.cellId" _pulse_table_path = "image.pulseId" _warn_memory_cell_range = True # can be disabled for some detectors _cuda_pin_buffers = False - - def _load_constant_to_runner(self, constant, constant_data): - """Subclass must define how to process constants into correction maps and store - into appropriate buffers in (GPU or main) memory.""" - # note: aim to refactor this away as kernel runner handles loading logic - raise NotImplementedError() - - def 
_successfully_loaded_constant_to_runner(self, constant): - for field_name in self._constant_to_correction_names[constant]: - key = f"corrections.{field_name}.available" - if not self.unsafe_get(key): - self.set(key, True) - - self._update_correction_flags() - self.log_status_info(f"Done loading {constant.name} to runner") - - @property - def input_data_shape(self): - """Subclass must define expected input data shape in terms of dataFormat.{ - memoryCells,pixelsX,pixelsY} and any other axes.""" - raise NotImplementedError() - - @property - def output_data_shape(self): - """Shape of corrected image data sent on dataOutput. Depends on data format - parameters (optionally including frame filter).""" - axis_lengths = { - "x": self.unsafe_get("dataFormat.pixelsX"), - "y": self.unsafe_get("dataFormat.pixelsY"), - "f": self.unsafe_get("dataFormat.filteredFrames"), - } - return tuple( - axis_lengths[axis] for axis in self.unsafe_get("dataFormat.outputAxisOrder") - ) - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - ): - """Subclass must define data processing (presumably using the kernel runner). - Will be called by input_handler, which will take care of some common checks and - extracting the parameters given to process_data.""" - raise NotImplementedError() + _use_shmem_handles = True @staticmethod def expectedParameters(expected): @@ -171,57 +107,6 @@ class BaseCorrection(PythonDevice): .defaultValue([]) .commit(), - NODE_ELEMENT(expected) - .key("frameFilter") - .displayedName("Frame filter") - .description( - "The frame filter - if set - slices the input data. Frames not in the " - "filter will be discarded before any processing happens and will not " - "get to dataOutput or preview. Note that this filter goes by frame " - "index rather than cell ID or pulse ID; set accordingly. Handle with " - "care - an invalid filter can prevent all processing. How the filter " - "is specified depends on frameFilter.type. 
See frameFilter.current to " - "inspect the currently set frame filter array (if any)." - ) - .commit(), - - STRING_ELEMENT(expected) - .key("frameFilter.type") - .tags("managed") - .displayedName("Type") - .description( - "Controls how frameFilter.spec is used. The default value of 'none' " - "means that no filter is set (regardless of frameFilter.spec). " - "'arange' allows between one and three integers separated by ',' which " - "are parsed and passed directly to numpy.arange. 'commaseparated' " - "reads a list of integers separated by commas." - ) - .options(",".join(spectype.value for spectype in FramefilterSpecType)) - .assignmentOptional() - .defaultValue("none") - .reconfigurable() - .commit(), - - STRING_ELEMENT(expected) - .key("frameFilter.spec") - .tags("managed") - .displayedName("Specification") - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - VECTOR_UINT32_ELEMENT(expected) - .key("frameFilter.current") - .displayedName("Current filter") - .description( - "This read-only value is used to display the contents of the current " - "frame filter. An empty array means no filtering is done." - ) - .readOnly() - .initialValue([]) - .commit(), - UINT32_ELEMENT(expected) .key("outputShmemBufferSize") .tags("managed") @@ -285,14 +170,6 @@ class BaseCorrection(PythonDevice): .displayedName("Data format (in/out)") .commit(), - STRING_ELEMENT(expected) - .key("dataFormat.inputImageDtype") - .displayedName("Input image data dtype") - .description("The (numpy) dtype to expect for incoming image data.") - .readOnly() - .initialValue("uint16") - .commit(), - STRING_ELEMENT(expected) .key("dataFormat.outputImageDtype") .tags("managed") @@ -304,41 +181,16 @@ class BaseCorrection(PythonDevice): "causes truncation rather than rounding." 
) # TODO: consider adding rounding / binning for integer output - .options("float16,float32,uint16") + .options("float32,float16,uint16") .assignmentOptional() .defaultValue("float32") .reconfigurable() .commit(), - # important: determines shape of data as going into correction UINT32_ELEMENT(expected) - .key("dataFormat.pixelsX") - .displayedName("Pixels x") - .description("Number of pixels of image data along X axis") - .assignmentOptional() - .defaultValue(512) - .commit(), - - UINT32_ELEMENT(expected) - .key("dataFormat.pixelsY") - .displayedName("Pixels y") - .description("Number of pixels of image data along Y axis") - .assignmentOptional() - .defaultValue(128) - .commit(), - - UINT32_ELEMENT(expected) - .key("dataFormat.frames") + .key("dataFormat.inputFrames") .displayedName("Frames") - .description("Number of image frames per train in incoming data") - .assignmentOptional() - .defaultValue(1) # subclass will want to set a default value - .commit(), - - UINT32_ELEMENT(expected) - .key("dataFormat.filteredFrames") - .displayedName("Frames after filter") - .description("Number of frames left after applying frame filter") + .description("Number of frames on input") .readOnly() .initialValue(0) .commit(), @@ -349,13 +201,15 @@ class BaseCorrection(PythonDevice): .displayedName("Output axis order") .description( "Axes of main data output can be reordered after correction. Axis " - "order is specified as string consisting of 'x', 'y', and 'f', with " - "the latter indicating the image frame. The default value of 'fxy' " - "puts pixels on the fast axes." + "order is specified as string consisting of 'f' (frames), 'ss' (slow " + "scan), and 'ff' (fast scan). Anything but the default f-ss-fs order " + "implies reordering - this axis ordering is based on the data we get " + "from the detector (/ receiver), so how ss and fs maps to what may be " + "considered x and y is detector-dependent." 
) - .options("fxy,fyx,xfy,xyf,yfx,yxf") + .options("f-ss-fs,f-fs-ss,ss-f-fs,ss-fs-f,fs-f-ss,fs-ss-f") .assignmentOptional() - .defaultValue("fxy") + .defaultValue("f-ss-fs") .reconfigurable() .commit(), @@ -363,9 +217,9 @@ class BaseCorrection(PythonDevice): .key("dataFormat.inputDataShape") .displayedName("Input data shape") .description( - "Image data shape in incoming data (from reader / DAQ). This value is " - "computed from pixelsX, pixelsY, and frames - this field just " - "shows what is currently expected." + "Image data shape in incoming data (from reader / DAQ). Updated based " + "on latest train processed. Note that axis order may look wonky due to " + "DAQ quirk." ) .readOnly() .initialValue([]) @@ -375,9 +229,10 @@ class BaseCorrection(PythonDevice): .key("dataFormat.outputDataShape") .displayedName("Output data shape") .description( - "Image data shape for data output from this device. This value is " - "computed from pixelsX, pixelsY, and the size of the frame filter - " - "this field just shows what is currently expected." + "Image data shape for data output from this device. Takes into account " + "axis reordering, if applicable. Even without that, " + "workarounds.overrideInputAxisOrder likely makes this differ from " + "dataFormat.inputDataShape." ) .readOnly() .initialValue([]) @@ -472,50 +327,6 @@ class BaseCorrection(PythonDevice): .commit(), ) - ( - NODE_ELEMENT(expected) - .key("preview") - .displayedName("Preview") - .commit(), - - INT32_ELEMENT(expected) - .key("preview.index") - .tags("managed") - .displayedName("Index (or stat) for preview") - .description( - "If this value is ≥ 0, the corresponding index (frame, cell, or pulse) " - "will be sliced for the preview output. If this value is < 0, preview " - "will be one of the following stats: -1: max, -2: mean, -3: sum, -4: " - "stdev. These stats are computed across frames, ignoring NaN values." 
- ) - .assignmentOptional() - .defaultValue(0) - .minInc(-4) - .reconfigurable() - .commit(), - - STRING_ELEMENT(expected) - .key("preview.selectionMode") - .tags("managed") - .displayedName("Index selection mode") - .description( - "The value of preview.index can be used in multiple ways, controlled " - "by this value. If this is set to 'frame', preview.index is sliced " - "directly from data. If 'cell' (or 'pulse') is selected, I will look " - "at cell (or pulse) table for the requested cell (or pulse ID). " - "Special (stat) index values <0 are not affected by this." - ) - .options( - ",".join( - spectype.value for spectype in utils.PreviewIndexSelectionMode - ) - ) - .assignmentOptional() - .defaultValue("frame") - .reconfigurable() - .commit(), - ) - # just measurements and counters to display ( UINT64_ELEMENT(expected) @@ -618,24 +429,8 @@ class BaseCorrection(PythonDevice): config["dataOutput.hostname"] = ib_ip super().__init__(config) - self.output_data_dtype = np.dtype(config["dataFormat.outputImageDtype"]) - self.sources = set(config.get("fastSources")) - self.kernel_runner = None # must call _update_buffers to initialize - self._shmem_buffer = None # ditto - self._use_shmem_handles = config.get("useShmemHandles") - self._preview_friend = None - - self._correction_flag_enabled = self._correction_flag_class.NONE - self._correction_flag_preview = self._correction_flag_class.NONE - self._correction_applied_hash = Hash() - # note: does not handle one constant enabling multiple correction steps - # (do we need to generalize "decide if this correction step is available"?) 
- self._constant_to_correction_names = {} - for (name, _, constants) in self._correction_steps: - for constant in constants: - self._constant_to_correction_names.setdefault(constant, set()).add(name) self._buffer_lock = threading.Lock() self._last_processing_started = 0 # used for processing time and timeout self._last_train_id_processed = 0 # used to keep track (and as fallback) @@ -657,43 +452,36 @@ class BaseCorrection(PythonDevice): constant_data = getattr(self.calcat_friend, friend_fun)( constant ) - self._load_constant_to_runner(constant, constant_data) + self.kernel_runner.load_constant(constant, constant_data) except Exception as e: warn(f"Failed to load {constant}: {e}") - else: - self._successfully_loaded_constant_to_runner(constant) # note: consider if multi-constant buffers need special care self._lock_and_update(aux) - for constant in self._constant_enum_class: - self.KARABO_SLOT( - functools.partial( - constant_override_fun, - friend_fun="get_constant_version", - constant=constant, - preserve_fields=set(), + for constant, (friend_fun, preserved_fields, slot_name) in itertools.product( + self._constant_enum_class, + ( + ("get_constant_version", set(), "loadMostRecent"), + ( + "get_constant_from_constant_version_id", + {"constantVersionId"}, + "overrideConstantFromVersion", ), - slotName=f"foundConstants_{constant.name}_loadMostRecent", - numArgs=0, - ) - self.KARABO_SLOT( - functools.partial( - constant_override_fun, - friend_fun="get_constant_from_constant_version_id", - constant=constant, - preserve_fields={"constantVersionId"}, + ( + "get_constant_from_file", + {"dataFilePath", "dataSetName"}, + "overrideConstantFromFile", ), - slotName=f"foundConstants_{constant.name}_overrideConstantFromVersion", - numArgs=0, - ) + ), + ): self.KARABO_SLOT( functools.partial( constant_override_fun, - friend_fun="get_constant_from_file", - constant=constant, - preserve_fields={"dataFilePath", "dataSetName"}, + friend_fun, + constant, + preserved_fields, ), - 
slotName=f"foundConstants_{constant.name}_overrideConstantFromFile", + slotName=f"foundConstants_{constant.name}_{slot_name}", numArgs=0, ) @@ -704,30 +492,27 @@ class BaseCorrection(PythonDevice): def _initialization(self): self.updateState(State.INIT) - self.warning_context = utils.WarningContextSystem( + self.warning_context = WarningContextSystem( self, on_success={ f"foundConstants.{constant.name}.state": "ON" for constant in self._constant_enum_class }, ) - self._preview_friend = preview_utils.PreviewFriend( - self, - output_channels=self._preview_outputs, + self.log.DEBUG("Opening shmem buffer") + self._shmem_buffer = shmem_utils.ShmemCircularBuffer( + self.get("outputShmemBufferSize") * 2**30, + # TODO: have it just allocate memory, then set_shape later per train + (1,), + np.float32, + self.getInstanceId() + ":dataOutput", ) - self["availableScenes"] = self["availableScenes"] + [ - f"preview:{channel}" for channel in self._preview_outputs - ] - - self._geometry = None - if self.get("geometryDevice"): - self.signalSlotable.connect( - self.get("geometryDevice"), - "signalNewGeometry", - "", # slot device ID (default: self) - "slotReceiveGeometry", - ) + if self._cuda_pin_buffers: + self.log.DEBUG("Trying to pin the shmem buffer memory") + self._shmem_buffer.cuda_pin() + self._shmem_receiver = shmem_utils.ShmemCircularBufferReceiver() + # CalCat friend comes before kernel runner with self.warning_context( "deviceInternalsState", WarningLampType.CALCAT_CONNECTION ) as warn: @@ -739,18 +524,21 @@ class BaseCorrection(PythonDevice): warn(f"Failed to connect to CalCat: {e}") # TODO: add raw fallback mode if CalCat fails (raw data still useful) return + self.kernel_runner = self._kernel_runner_class(self) + self.preview_friend = PreviewFriend(self, self._preview_outputs) + # TODO: can be static OVERWRITE_ELEMENT in expectedParameters + self["availableScenes"] = self["availableScenes"] + [ + f"preview:{spec.name}" for spec in self._preview_outputs + ] - with 
self.warning_context( - "processingState", WarningLampType.FRAME_FILTER - ) as warn: - try: - self._frame_filter = _parse_frame_filter(self._parameters) - except (ValueError, TypeError): - warn("Failed to parse initial frame filter, will not use") - self._frame_filter = None - # TODO: decide how persistent initial warning should be - self._update_correction_flags() - self._update_frame_filter() + self.geometry = None + if self.get("geometryDevice"): + self.signalSlotable.connect( + self.get("geometryDevice"), + "signalNewGeometry", + "", # slot device ID (default: self) + "slotReceiveGeometry", + ) self._buffered_status_update = Hash() self._processing_time_tracker = trackers.ExponentialMovingAverage( @@ -766,7 +554,7 @@ class BaseCorrection(PythonDevice): ) self._train_ratio_tracker = trackers.TrainRatioTracker() - self.KARABO_ON_INPUT("dataInput", self.input_handler) + self.KARABO_ON_DATA("dataInput", self.input_handler) self.KARABO_ON_EOS("dataInput", self.handle_eos) self._enabled_addons = [ @@ -776,9 +564,14 @@ class BaseCorrection(PythonDevice): ] for addon in self._enabled_addons: addon._device = self - if self._enabled_addons: + if ( + (self.get("useShmemHandles") != self._use_shmem_handles) + or self._enabled_addons + ): schema_override = Schema() - output_schema_override = self._base_output_schema + output_schema_override = self._base_output_schema( + use_shmem_handles=self.get("useShmemHandles") + ) for addon in self._enabled_addons: addon.extend_output_schema(output_schema_override) ( @@ -794,7 +587,7 @@ class BaseCorrection(PythonDevice): self.updateState(State.ON) def __del__(self): - del self._shmem_buffer + del self.shmem_buffer super().__del__() def preReconfigure(self, config): @@ -812,64 +605,31 @@ class BaseCorrection(PythonDevice): timestamp = dateutil.parser.isoparse(ts_string) config.set(ts_path, timestamp.isoformat()) - if config.has("constantParameters.deviceMappingSnapshotAt"): - self.calcat_friend.flush_pdu_mapping() + with 
self.config_overlay(config): + if config.has("constantParameters.deviceMappingSnapshotAt"): + self.calcat_friend.flush_pdu_mapping() - # update device based on changes - if config.has("frameFilter"): - self._frame_filter = _parse_frame_filter( - misc.ChainHash(config, self._parameters) - ) - - self._prereconfigure_update_hash = config - - def postReconfigure(self): - if not hasattr(self, "_prereconfigure_update_hash"): - self.log_status_warn("postReconfigure without knowing update hash") - return - - update = self._prereconfigure_update_hash - - if update.has("frameFilter"): - self._lock_and_update(self._update_frame_filter) - elif any( - update.has(shape_param) - for shape_param in ( - "dataFormat.pixelsX", - "dataFormat.pixelsY", - "dataFormat.outputImageDtype", - "dataFormat.outputAxisOrder", - "dataFormat.frames", - "constantParameters.memoryCells", - "frameFilter", - ) - ): - self._lock_and_update(self._update_buffers) - - if any( - ( - update.has(f"corrections.{field_name}.enable") - or update.has(f"corrections.{field_name}.preview") - ) - for field_name, *_ in self._correction_steps - ): - self._update_correction_flags() + if ( + runner_update := subset_of_hash( + config, "corrections", "dataFormat", "constantParameters" + ) + ): + self.kernel_runner.reconfigure(runner_update) - # TODO: only send subhash of reconfiguration (like with addons) - if update.has("preview"): - self._preview_friend.reconfigure(update) + if config.has("preview"): + self.preview_friend.reconfigure(config["preview"]) - if update.has("addons"): - # note: can avoid iterating, but it's not that costly - for addon in self._enabled_addons: - full_path = f"addons.{addon.__class__.__name__}" - if update.has(full_path): - addon.reconfigure(update[full_path]) + if config.has("addons"): + # note: can avoid iterating, but it's not that costly + for addon in self._enabled_addons: + full_path = f"addons.{addon.__class__.__name__}" + if config.has(full_path): + addon.reconfigure(config[full_path]) 
def _lock_and_update(self, method, background=True): # TODO: securely handle errors (postReconfigure may succeed, device state not) def runner(): - with self._buffer_lock, utils.StateContext(self, State.CHANGING): + with self._buffer_lock, StateContext(self, State.CHANGING): method() if background: @@ -896,11 +656,9 @@ class BaseCorrection(PythonDevice): ) as warn: try: constant_data = future.result() - self._load_constant_to_runner(constant, constant_data) + self.kernel_runner.load_constant(constant, constant_data) except Exception as e: warn(f"Failed to load {constant}: {e}") - else: - self._successfully_loaded_constant_to_runner(constant) self._lock_and_update(aux) def flush_constants(self, constants=None, preserve_fields=None): @@ -930,7 +688,6 @@ class BaseCorrection(PythonDevice): self.set(f"corrections.{field_name}.available", False) self.calcat_friend.flush_constants(constants, preserve_fields) self._reload_constants_from_cache_to_runner(constants) - self._update_correction_flags() def log_status_info(self, msg): self.log.INFO(msg) @@ -954,7 +711,7 @@ class BaseCorrection(PythonDevice): payload["data"] = scenes.correction_device_preview( device_id=self.getInstanceId(), schema=self.getFullSchema(), - preview_channel=channel_name, + name=channel_name, ) elif name.startswith("browse_schema"): if ":" in name: @@ -983,10 +740,35 @@ class BaseCorrection(PythonDevice): def slotReceiveGeometry(self, device_id, serialized_geometry): self.log.INFO(f"Received geometry from {device_id}") try: - self._geometry = geom_utils.deserialize_geometry(serialized_geometry) + self.geometry = geom_utils.deserialize_geometry(serialized_geometry) except Exception as e: self.log.WARN(f"Failed to deserialize geometry; {e}") + def _get_data_from_hash(self, data_hash): + """Will get image data, cell table, pulse table, and list of other arrays from + the data hash. Assumes XTDF (image.data, image.cellId, image.pulseId, and no + other data), non-XTDF subclass can override. 
Cell and pulse table can be + None.""" + + image_data = data_hash.get(self._image_data_path) + cell_table = data_hash[self._cell_table_path].ravel() + pulse_table = data_hash[self._pulse_table_path].ravel() + num_frames = cell_table.size + # DataAggregator typically tells us the wrong axis order + if self.unsafe_get("workarounds.overrideInputAxisOrder"): + # TODO: check that frames are always axis 0 + expected_shape = self.kernel_runner.expected_input_shape( + num_frames + ) + if expected_shape != image_data.shape: + image_data.shape = expected_shape + return ( + num_frames, + image_data, # not ndarray (runner can do that) + cell_table, + pulse_table, + ) + def _write_output(self, data, old_metadata): """For dataOutput: reusing incoming data hash and setting source and timestamp to be same as input""" @@ -994,103 +776,11 @@ class BaseCorrection(PythonDevice): old_metadata.get("source"), Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")), ) - data["corrections"] = self._correction_applied_hash channel = self.signalSlotable.getOutputChannel("dataOutput") channel.write(data, metadata, copyAllData=False) channel.update(safeNDArray=True) - def _update_correction_flags(self): - """Based on constants loaded and settings, update bit mask flags for kernel""" - enabled = self._correction_flag_class.NONE - output = Hash() # for informing downstream receivers what was applied - preview = self._correction_flag_class.NONE - for field_name, flag, constants in self._correction_steps: - output[field_name] = False - if self.get(f"corrections.{field_name}.available"): - if self.get(f"corrections.{field_name}.enable"): - enabled |= flag - output[field_name] = True - if self.get(f"corrections.{field_name}.preview"): - preview |= flag - self._correction_flag_enabled = enabled - self._correction_flag_preview = preview - self._correction_applied_hash = output - self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}") - self.log.DEBUG(f"Corrections for preview: 
{str(preview)}") - - def _update_frame_filter(self, update_buffers=True): - """Parse frameFilter string (if set) and update cached filter array. May update - dataFormat.filteredFrames - will therefore by default call _update_buffers - afterwards.""" - # TODO: add some validation to preReconfigure - self.log.DEBUG("Updating frame filter") - - if self._frame_filter is None: - self.set("dataFormat.filteredFrames", self.get("dataFormat.frames")) - self.set("frameFilter.current", []) - else: - self.set("dataFormat.filteredFrames", self._frame_filter.size) - self.set("frameFilter.current", list(map(int, self._frame_filter))) - - with self.warning_context( - "deviceInternalsState", WarningLampType.FRAME_FILTER - ) as warn: - if self._frame_filter is not None and ( - self._frame_filter.min() < 0 - or self._frame_filter.max() >= self.get("dataFormat.frames") - ): - warn("Invalid frame filter set, expect exceptions!") - # note: dataFormat.frames could change from detector and - # that could make this warning invalid - - if update_buffers: - self._update_buffers() - - def _update_buffers(self): - """(Re)initialize buffers / kernel runner according to expected data shapes""" - self.log.INFO("Updating buffers according to data shapes") - # reflect the axis reordering in the expected output shape - self.set("dataFormat.inputDataShape", list(self.input_data_shape)) - self.set("dataFormat.outputDataShape", list(self.output_data_shape)) - self.log.INFO(f"Input shape: {self.input_data_shape}") - self.log.INFO(f"Output shape: {self.output_data_shape}") - - if self._shmem_buffer is None: - shmem_buffer_name = self.getInstanceId() + ":dataOutput" - memory_budget = self.get("outputShmemBufferSize") * 2**30 - self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") - self._shmem_buffer = shmem_utils.ShmemCircularBuffer( - memory_budget, - self.output_data_shape, - self.output_data_dtype, - shmem_buffer_name, - ) - self._shmem_receiver = shmem_utils.ShmemCircularBufferReceiver() - 
if self._cuda_pin_buffers: - self.log.INFO("Trying to pin the shmem buffer memory") - self._shmem_buffer.cuda_pin() - self.log.INFO("Done, shmem buffer is ready") - else: - self._shmem_buffer.change_shape(self.output_data_shape) - - # give CuPy a chance to at least start memory cleanup before this - if self.kernel_runner is not None: - del self.kernel_runner - self.kernel_runner = None - gc.collect() - - self.kernel_runner = self._kernel_runner_class( - self.get("dataFormat.pixelsX"), - self.get("dataFormat.pixelsY"), - self.get("dataFormat.filteredFrames"), - int(self.get("constantParameters.memoryCells")), - output_data_dtype=self.output_data_dtype, - **self._kernel_runner_init_args, - ) - - self._reload_constants_from_cache_to_runner() - def _reload_constants_from_cache_to_runner(self, constants=None): # TODO: gracefully handle change in constantParameters.memoryCells if constants is None: @@ -1110,13 +800,11 @@ class BaseCorrection(PythonDevice): ) as warn: try: self.log_status_info(f"Reload constant {constant.name}") - self._load_constant_to_runner(constant, data) + self.kernel_runner.load_constant(constant, data) except Exception as e: warn(f"Failed to reload {constant.name}: {e}") - else: - self._successfully_loaded_constant_to_runner(constant) - def input_handler(self, input_channel): + def input_handler(self, data_hash, metadata): """Main handler for data input: Do a few simple checks to determine whether to even try processing. 
If yes, will pass data and information to process_data method provided by subclass.""" @@ -1134,161 +822,179 @@ class BaseCorrection(PythonDevice): warn("Received data before correction device was ready") return - all_metadata = input_channel.getMetaData() - for input_index in range(input_channel.size()): - with self.warning_context( - "inputDataState", WarningLampType.MISC_INPUT_DATA - ) as warn: - data_hash = input_channel.read(input_index) - metadata = all_metadata[input_index] - source = metadata.get("source") + with self.warning_context( + "inputDataState", WarningLampType.MISC_INPUT_DATA + ) as warn: + source = metadata.get("source") - if source not in self.sources: - continue - elif not data_hash.has(self._image_data_path): - warn("Ignoring hash without image node") - continue + if source not in self.sources: + return + elif not data_hash.has(self._image_data_path): + warn("Ignoring hash without image node") + return - with self.warning_context( - "inputDataState", - WarningLampType.EMPTY_HASH, - ) as warn: - self._shmem_receiver.dereference_shmem_handles(data_hash) - try: - image_data = np.asarray(data_hash.get(self._image_data_path)) - cell_table = ( - np.array( - # explicit copy to avoid mysterious segfault - data_hash.get(self._cell_table_path), copy=True - ).ravel() - if self._cell_table_path is not None - else None - ) - pulse_table = ( - np.array( - # explicit copy to avoid mysterious segfault - data_hash.get(self._pulse_table_path), copy=True - ).ravel() - if self._pulse_table_path is not None - else None - ) - except RuntimeError as err: - warn( - "Failed to load image data; " - f"probably empty hash from DAQ: {err}" - ) - self._maybe_update_cell_and_pulse_tables(None, None) - continue + with self.warning_context( + "inputDataState", + WarningLampType.EMPTY_HASH, + ) as warn: + self._shmem_receiver.dereference_shmem_handles(data_hash) + try: + ( + num_frames, + image_data, + cell_table, + pulse_table, + *additional_data + ) = 
self._get_data_from_hash(data_hash) + except RuntimeError as err: + warn( + "Failed to load data from hash, probably empy hash from DAQ. " + f"(Is detector sending data?)\n{err}" + ) + self._maybe_update_cell_and_pulse_tables(None, None) + return - # no more common reasons to skip input, so go to processing - self._last_processing_started = default_timer() - if state is State.ON: - self.updateState(State.PROCESSING) - self.log_status_info("Processing data") + # no more common reasons to skip input, so go to processing + self._last_processing_started = default_timer() + if state is State.ON: + self.updateState(State.PROCESSING) + self.log_status_info("Processing data") - self._maybe_update_cell_and_pulse_tables(cell_table, pulse_table) - timestamp = Timestamp.fromHashAttributes( - metadata.getAttributes("timestamp") - ) - train_id = timestamp.getTrainId() + self._maybe_update_cell_and_pulse_tables(cell_table, pulse_table) + timestamp = Timestamp.fromHashAttributes( + metadata.getAttributes("timestamp") + ) - # check time server connection - with self.warning_context( - "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION - ) as warn: - my_timestamp = self.getActualTimestamp() - my_train_id = my_timestamp.getTrainId() - self._input_delay_tracker.update( - (my_timestamp.toTimestamp() - timestamp.toTimestamp()) * 1000 - ) - self._buffered_status_update.set( - "performance.inputDelay", self._input_delay_tracker.get() - ) - if my_train_id == 0: - my_train_id = self._last_train_id_processed + 1 - warn( - "Failed to get current train ID, using previously seen train " - "ID for future train thresholding - if this persists, check " - "connection to timeserver." 
- ) + self._check_train_id_and_time(timestamp) + + if self._warn_memory_cell_range: with self.warning_context( - "inputDataState", WarningLampType.TRAIN_ID + "processingState", WarningLampType.MEMORY_CELL_RANGE ) as warn: - if train_id > ( - my_train_id - + self.unsafe_get("workarounds.trainFromFutureThreshold") + if ( + self.unsafe_get("constantParameters.memoryCells") + <= cell_table.max() ): - warn( - f"Suspecting train from the future: 'now' is {my_train_id}, " - f"received train ID {train_id}, dropping data" - ) - continue - try: - self._train_ratio_tracker.update(train_id) - self._buffered_status_update.set( - "performance.ratioOfRecentTrainsReceived", - self._train_ratio_tracker.get( - current_train=my_train_id, - expected_delay=math.ceil( - self.unsafe_get("performance.inputDelay") / 100 - ), - ) - if my_train_id != 0 - else self._train_ratio_tracker.get(), - ) - except trackers.NonMonotonicTrainIdWarning as ex: - warn( - f"Train ratio tracker noticed issue with train ID: {ex}\n" - f"For the record, I think now is: {my_train_id}" - ) - self._train_ratio_tracker.reset() - self._train_ratio_tracker.update(train_id) + warn("Input cell IDs out of range of constants") - if self._warn_memory_cell_range: - with self.warning_context( - "processingState", WarningLampType.MEMORY_CELL_RANGE - ) as warn: - if ( - self.unsafe_get("constantParameters.memoryCells") - <= cell_table.max() - ): - warn("Input cell IDs out of range of constants") - - # TODO: avoid branching multiple times on cell_table - if cell_table is not None and cell_table.size != self.unsafe_get( - "dataFormat.frames" - ): - self.log_status_info( - f"Updating new input shape to account for {cell_table.size} cells" - ) - self.set("dataFormat.frames", cell_table.size) - self._lock_and_update(self._update_frame_filter, background=False) - - # DataAggregator typically tells us the wrong axis order - if self.unsafe_get("workarounds.overrideInputAxisOrder"): - expected_shape = self.input_data_shape - if 
expected_shape != image_data.shape: - image_data.shape = expected_shape - - with self._buffer_lock: - self.process_data( - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - pulse_table, + # TODO: change ShmemCircularBuffer API so we don't have to poke like this + if ( + (output_shape := self.kernel_runner.expected_output_shape(num_frames)) + != self._shmem_buffer._buffer_ary.shape[1:] + or self.kernel_runner._output_dtype != self._shmem_buffer._buffer_ary.dtype + ): + self.set("dataFormat.outputDataShape", list(output_shape)) + self.log.DEBUG("Updating shmem buffer shape / dtype") + self._shmem_buffer.change_shape( + output_shape, self.kernel_runner._output_dtype + ) + buffer_handle, buffer_array = self._shmem_buffer.next_slot() + with self.warning_context( + "processingState", WarningLampType.CORRECTION_RUNNER + ): + corrections, processed_buffer, previews = self.kernel_runner.correct( + image_data, cell_table, *additional_data + ) + data_hash["corrections"] = corrections + for addon in self._enabled_addons: + addon.post_correction( + processed_buffer, cell_table, pulse_table, data_hash ) - self._last_train_id_processed = train_id - self._buffered_status_update.set("trainId", train_id) - self._processing_time_tracker.update( - default_timer() - self._last_processing_started + self.kernel_runner.reshape(processed_buffer, out=buffer_array) + + with self.warning_context( + "processingState", WarningLampType.PREVIEW_SETTINGS + ) as warn: + self.preview_friend.write_outputs( + *previews, + timestamp=timestamp, + cell_table=cell_table, + pulse_table=pulse_table, + warn_fun=warn, + ) + + for addon in self._enabled_addons: + addon.post_reshape( + buffer_array, cell_table, pulse_table, data_hash + ) + + if self.unsafe_get("useShmemHandles"): + data_hash.set(self._image_data_path, buffer_handle) + data_hash.set("calngShmemPaths", [self._image_data_path]) + else: + data_hash.set(self._image_data_path, buffer_array) + data_hash.set("calngShmemPaths", 
[]) + + self._write_output(data_hash, metadata) + self._processing_time_tracker.update( + default_timer() - self._last_processing_started + ) + self._buffered_status_update.set( + "performance.processingTime", self._processing_time_tracker.get() * 1000 + ) + self._rate_tracker.update() + + def _please_send_me_cached_constants(self, constants, callback): + for constant in constants: + if constant in self.calcat_friend.cached_constants: + callback(constant, self.calcat_friend.cached_constants[constant]) + + def _check_train_id_and_time(self, timestamp): + train_id = timestamp.getTrainId() + # check time server connection + with self.warning_context( + "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION + ) as warn: + my_timestamp = self.getActualTimestamp() + my_train_id = my_timestamp.getTrainId() + self._input_delay_tracker.update( + (my_timestamp.toTimestamp() - timestamp.toTimestamp()) * 1000 ) self._buffered_status_update.set( - "performance.processingTime", self._processing_time_tracker.get() * 1000 + "performance.inputDelay", self._input_delay_tracker.get() ) - self._rate_tracker.update() + if my_train_id == 0: + my_train_id = self._last_train_id_processed + 1 + warn( + "Failed to get current train ID, using previously seen train " + "ID for future train thresholding - if this persists, check " + "connection to timeserver." 
+ ) + with self.warning_context( + "inputDataState", WarningLampType.TRAIN_ID + ) as warn: + if train_id > ( + my_train_id + + self.unsafe_get("workarounds.trainFromFutureThreshold") + ): + warn( + f"Suspecting train from the future: 'now' is {my_train_id}, " + f"received train ID {train_id}, dropping data" + ) + raise TrainFromTheFutureException() + try: + self._train_ratio_tracker.update(train_id) + self._buffered_status_update.set( + "performance.ratioOfRecentTrainsReceived", + self._train_ratio_tracker.get( + current_train=my_train_id, + expected_delay=math.ceil( + self.unsafe_get("performance.inputDelay") / 100 + ), + ) + if my_train_id != 0 + else self._train_ratio_tracker.get(), + ) + except trackers.NonMonotonicTrainIdWarning as ex: + warn( + f"Train ratio tracker noticed issue with train ID: {ex}\n" + f"For the record, I think now is: {my_train_id}" + ) + self._train_ratio_tracker.reset() + self._train_ratio_tracker.update(train_id) + + self._last_train_id_processed = train_id + self._buffered_status_update.set("trainId", train_id) def _update_rate_and_state(self): if self.get("state") is State.PROCESSING: @@ -1329,119 +1035,6 @@ class BaseCorrection(PythonDevice): self.signalEndOfStream("dataOutput") -def add_correction_step_schema(schema, field_flag_constants_mapping): - """Using the fields in the provided mapping, will add nodes to schema - - field_flag_mapping is assumed to be iterable of pairs where first entry in each - pair is the name of a correction step as it will appear in device schema (second - entry - typically an enum field - is ignored). For correction step, a node and some - booleans are added to the schema (all tagged as managed). Subclass can customize / - add additional keys under node later. - - This method should be called in expectedParameters of subclass after the same for - BaseCorrection has been called. 
Would be nice to include in BaseCorrection instead, - but that is tricky: static method of superclass will need _correction_steps - of subclass or device server gets mad. A nice solution with classmethods would be - welcome. - """ - - for field_name, _, used_constants in field_flag_constants_mapping: - node_name = f"corrections.{field_name}" - ( - NODE_ELEMENT(schema).key(node_name).commit(), - - BOOL_ELEMENT(schema) - .key(f"{node_name}.available") - .displayedName("Available") - .description( - "This boolean indicates whether the necessary constants have been " - "loaded for this correction step to be applied. Enabling the " - "correction will have no effect unless this is True." - ) - .readOnly() - # some corrections available without constants - .initialValue(not used_constants or None in used_constants) - .commit(), - - BOOL_ELEMENT(schema) - .key(f"{node_name}.enable") - .tags("managed") - .displayedName("Enable") - .description( - "Controls whether to apply this correction step for main data " - "output - subject to availability." - ) - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(schema) - .key(f"{node_name}.preview") - .tags("managed") - .displayedName("Preview") - .description( - "Whether to apply this correction step for corrected preview " - "output - subject to availability." - ) - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - ) - - -def add_bad_pixel_config_node(schema, prefix="corrections.badPixels"): - ( - STRING_ELEMENT(schema) - .key("corrections.badPixels.maskingValue") - .tags("managed") - .displayedName("Bad pixel masking value") - .description( - "Any pixels masked by the bad pixel mask will have their value replaced " - "with this. Note that this parameter is to be interpreted as a " - "numpy.float32; use 'nan' to get NaN value." 
- ) - .assignmentOptional() - .defaultValue("nan") - .reconfigurable() - .commit(), - - NODE_ELEMENT(schema) - .key("corrections.badPixels.subsetToUse") - .displayedName("Bad pixel flags to use") - .description( - "The booleans under this node allow for selecting a subset of bad pixel " - "types to take into account when doing bad pixel masking. Upon updating " - "these flags, the map used for bad pixel masking will be ANDed with this " - "selection. Turning disabled flags back on causes reloading of cached " - "constants." - ) - .commit(), - ) - for field in utils.BadPixelValues: - ( - BOOL_ELEMENT(schema) - .key(f"corrections.badPixels.subsetToUse.{field.name}") - .tags("managed") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit() - ) - - -def add_preview_outputs(schema, channels): - for channel in channels: - ( - OUTPUT_CHANNEL(schema) - .key(f"preview.{channel}") - .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True)) - .commit(), - ) - preview_utils.PreviewFriend.add_schema(schema, output_channels=channels) - - def add_addon_nodes(schema, device_class, prefix="addons"): det_name = device_class.__name__[:-len("Correction")].lower() device_class._available_addons = [ @@ -1467,28 +1060,3 @@ def add_addon_nodes(schema, device_class, prefix="addons"): addon_class.extend_device_schema( schema, f"{prefix}.{addon_class.__name__}" ) - - -def get_bad_pixel_field_selection(self): - selection = 0 - for field in utils.BadPixelValues: - if self.get(f"corrections.badPixels.subsetToUse.{field.name}"): - selection |= field - return selection - - -def _parse_frame_filter(config): - filter_type = FramefilterSpecType(config["frameFilter.type"]) - filter_string = config["frameFilter.spec"] - - if filter_type is FramefilterSpecType.NONE or filter_string.strip() == "": - return None - elif filter_type is FramefilterSpecType.RANGE: - # allow exceptions - numbers = tuple(int(part) for part in filter_string.split(",")) - return np.arange(*numbers, 
dtype=np.uint16) - elif filter_type is FramefilterSpecType.COMMASEPARATED: - # np.fromstring is too lenient I think - return np.array([int(s) for s in filter_string.split(",")], dtype=np.uint16) - else: - raise TypeError(f"Unknown frame filter type {filter_type}") diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py index 5eaaf10155d23e50b73044da62cad62a335f97f6..e04ad559453a2f8ab954ae0b8b60e529733f94d8 100644 --- a/src/calng/base_kernel_runner.py +++ b/src/calng/base_kernel_runner.py @@ -1,19 +1,163 @@ import enum import functools -import operator +import itertools import pathlib -from calngUtils import misc as utils +from karabo.bound import ( + BOOL_ELEMENT, + NODE_ELEMENT, + STRING_ELEMENT, + Hash, +) +from .utils import ( + BadPixelValues, + maybe_get, + subset_of_hash, +) +from calngUtils import misc import jinja2 -import numpy as np class BaseKernelRunner: - _gpu_based = True _xp = None # subclass sets numpy or cupy + num_pixels_ss = None # subclass must set + num_pixels_fs = None # subclass must set + + _bad_pixel_constants = None # should be set by subclasses using bad pixels + bad_pixel_subset = 0xFFFFFFFF # will be set in reconfigure + + @classmethod + def add_schema(cls, schema): + """Using the steps in the provided mapping, will add nodes to schema + + correction_steps is assumed to be iterable of pairs where first entry in each + pair is the name of a correction step as it will appear in device schema (second + entry - typically an enum field - is ignored). For correction step, a node and + some booleans are added to the schema (all tagged as managed). Subclass can + customize / add additional keys under node later. 
+ """ + ( + NODE_ELEMENT(schema) + .key("corrections") + .commit(), + ) + + for field_name, _, used_constants in cls._correction_steps: + node_name = f"corrections.{field_name}" + ( + NODE_ELEMENT(schema).key(node_name).commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_name}.available") + .displayedName("Available") + .description( + "This boolean indicates whether the necessary constants have been " + "loaded for this correction step to be applied. Enabling the " + "correction will have no effect unless this is True." + ) + .readOnly() + # some corrections available without constants + .initialValue(not used_constants or None in used_constants) + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_name}.enable") + .tags("managed") + .displayedName("Enable") + .description( + "Controls whether to apply this correction step for main data " + "output - subject to availability." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_name}.preview") + .tags("managed") + .displayedName("Preview") + .description( + "Whether to apply this correction step for corrected preview " + "output - subject to availability." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + ) + + @staticmethod + def add_bad_pixel_config(schema): + ( + STRING_ELEMENT(schema) + .key("corrections.badPixels.maskingValue") + .tags("managed") + .displayedName("Bad pixel masking value") + .description( + "Any pixels masked by the bad pixel mask will have their value " + "replaced with this. Note that this parameter is to be interpreted as " + "a numpy.float32; use 'nan' to get NaN value." 
+ ) + .assignmentOptional() + .defaultValue("nan") + .reconfigurable() + .commit(), + + NODE_ELEMENT(schema) + .key("corrections.badPixels.subsetToUse") + .displayedName("Bad pixel flags to use") + .description( + "The booleans under this node allow for selecting a subset of bad " + "pixel types to take into account when doing bad pixel masking. Upon " + "updating these flags, the map used for bad pixel masking will be " + "ANDed with this selection. Turning disabled flags back on causes " + "reloading of cached constants." + ) + .commit(), + ) + for field in BadPixelValues: + ( + BOOL_ELEMENT(schema) + .key(f"corrections.badPixels.subsetToUse.{field.name}") + .tags("managed") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit() + ) + + def __init__(self, device): + self._device = device + # for now, we depend on multiple keys from device hash, so get multiple nodes + config = subset_of_hash( + self._device._parameters, + "corrections", + "dataFormat", + "constantParameters", + ) + self._constant_memory_cells = config["constantParameters.memoryCells"] + self._pre_init() + # note: does not handle one constant enabling multiple correction steps + # (do we need to generalize "decide if this correction step is available"?) + self._constant_to_correction_names = {} + for (name, _, constants) in self._correction_steps: + for constant in constants: + self._constant_to_correction_names.setdefault(constant, set()).add(name) + self._output_dtype = self._xp.float32 + self._setup_constant_buffers() + self.reconfigure(config) + self._post_init() def _pre_init(self): - # can be used, for example, to import cupy and set as _xp at runtime + """Hook used to set up things which need to be in place before __init__ is + called. Note that __init__ will do reconfigure with full configuration, which + means creating constant buffers and such (__init__ already takes care to set + _constant_memory_cells so this can be used in _setup_constant_buffers). 
+ + + See also _setup_constant_buffers + """ pass def _post_init(self): @@ -21,121 +165,230 @@ class BaseKernelRunner: # without overriding __init__ pass - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - ): - self._pre_init() - self.pixels_x = pixels_x - self.pixels_y = pixels_y - self.frames = frames - if constant_memory_cells == 0: - # if not set, guess same as input; may save one recompilation - self.constant_memory_cells = frames - else: - self.constant_memory_cells = constant_memory_cells - self.output_data_dtype = output_data_dtype - self._post_init() + def expected_input_shape(self, num_frames): + """Mostly used to inform the host device about what data shape should be. + Subclass should override as needed (ex. XTDF detectors with extra image/gain + axis).""" + return (num_frames, self.num_pixels_ss, self.num_pixels_fs) - @property - def preview_shape(self): - return (self.pixels_x, self.pixels_y) - - def correct(self, flags): - """Correct (already loaded) image data according to flags - - Detector-specific subclass must define this method. It should assume that image - data, cell table, and other data (including constants) has already been loaded. - It should probably run some {CPU,GPU} kernel and output should go into - self.processed_data{,_gpu}. + def expected_output_shape(self, num_frames): + """Used to inform the host device about what to expect (for reshaping the shmem + buffer). Takes into account transpose setting; override _expected_output_shape + for subclasses which do funky shape things.""" + base_shape = self._expected_output_shape(num_frames) + if self._output_transpose is None: + return base_shape + else: + return tuple(base_shape[i] for i in self._output_transpose) - Keep in mind that user only gets output from compute_preview or reshape - (either of these should come after correct). 
+ def _expected_output_shape(self, num_frames): + return (num_frames, self.num_pixels_ss, self.num_pixels_fs) - The submodules providing subclasses should have some IntFlag enums defining - which flags are available to pass along to the kernel. A zero flag should allow - the kernel to do no actual correction - but still copy the data between buffers - and cast it to desired output type. + def _correct(self, flags, image_data, cell_table, *rest): + """Subclass will implement the core correcton routine with this name. Observe + the parameter list: the "rest" part will always contain at least one output + buffer, but may have more entries such as additional raw input buffers and + multiple output buffers. See how it is called by correct.""" - """ raise NotImplementedError() - @property - def preview_data_views(self): + def _preview_data_views(self, *buffers): """Should provide tuple of views of data for different preview layers - as a minimum, raw and corrected. View is expected to have axis order frame, slow - scan, fast scan with the latter two matching preview_shape (which is X and which - is Y may differ between detectors).""" + scan, fast scan with the latter being the final preview shape (which is X and + which is Y may differ between detectors). The buffers passed vary from subclass + to subclass; will be all inputs and outputs used in _correct. By default, that + is just what gets returned.""" + + return buffers + + def _load_constant(self, constant_type, constant_data): + """Should update the appropriate internal buffers for constant type (maps to + "calibration" type in CalCat) with the data in constant_tdata. 
This may include + sanitization; data passed here has the axis order found in the stored constant + files.""" + raise NotImplementedError() + def _setup_constant_buffers(self): raise NotImplementedError() + def _make_output_buffers(self, num_frames, flags): + """Should, based on number of frames passed (may in the future be extended to + include other information), create the necessary buffers to store corrected + output(s). By default, will create a single processed data buffer of float32 + with the same shape as input image data, skipping any extra axes (like the + image/gain axis for XTDF). Subclass should override if it wants more outputs. + First buffer must be for corrected image data (passed to addons)""" + return [ + self._xp.empty( + (num_frames, self.num_pixels_ss, self.num_pixels_fs), + dtype=self._xp.float32, + ) + ] + + @functools.cached_property + def _all_relevant_constants(self): + return set( + itertools.chain.from_iterable( + constants for (_, _, constants) in self._correction_steps + ) + ) - {None} + + def reconfigure(self, config): + if config.has("constantParameters.memoryCells"): + self._constant_memory_cells = config.get("constantParameters.memoryCells") + self._setup_constant_buffers() + self._device._please_send_me_cached_constants( + self._all_relevant_constants, self.load_constant + ) + + if config.has("dataFormat.outputImageDtype"): + self._output_dtype = self._xp.dtype( + config.get("dataFormat.outputImageDtype") + ) + + if config.has("dataFormat.outputAxisOrder"): + order = config["dataFormat.outputAxisOrder"] + if order == "f-ss-fs": + self._output_transpose = None + else: + self._output_transpose = misc.transpose_order( + ("f", "ss", "fs"), order.split("-") + ) + + if config.has("corrections.badPixels.maskingValue"): + self.bad_pixel_mask_value = self._xp.float32( + config["corrections.badPixels.maskingValue"] + ) + + if config.has("corrections.badPixels.subsetToUse"): + # note: now just always reloading from cache for convenience + 
self._update_bad_pixel_subset() + self.flush_buffers(self._bad_pixel_constants) + self._device._please_send_me_cached_constants( + self._bad_pixel_constants, self.load_constant + ) + + # TODO: only trigger flag update if changed (cheap though) + self._update_correction_flags() + + def reshape(self, processed_data, out=None): + """Move axes to desired output order""" + if self._output_transpose is None: + return maybe_get(processed_data, out) + else: + return maybe_get(processed_data.transpose(self._output_transpose), out) + def flush_buffers(self, constants=None): """Optional reset internal buffers (implement in subclasses which need this)""" pass - def compute_previews(self, preview_index): - """Generate single slice or reduction previews for raw and corrected data and - any other layers, determined by self.preview_data_views + def _update_correction_flags(self): + enabled = self._correction_flag_class.NONE + output = Hash() # for informing downstream receivers what was applied + preview = self._correction_flag_class.NONE + for field_name, flag, constants in self._correction_steps: + output[field_name] = False + if self._get(f"corrections.{field_name}.available"): + if self._get(f"corrections.{field_name}.enable"): + enabled |= flag + output[field_name] = True + if self._get(f"corrections.{field_name}.preview"): + preview |= flag + self._correction_flag_enabled = enabled + self._correction_flag_preview = preview + self._correction_applied_hash = output + self._device.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}") + self._device.log.DEBUG(f"Corrections for preview: {str(preview)}") + + def _update_bad_pixel_subset(self): + selection = 0 + for field in BadPixelValues: + if self._get(f"corrections.badPixels.subsetToUse.{field.name}"): + selection |= field + self.bad_pixel_subset = selection - Special values of preview_index are -1 for max, -2 for mean, -3 for sum, and - -4 for stdev (across cells). 
+ def _get(self, key): + """Will look up key in host device's parameters. Note that device calls + reconfigure during preReconfigure, but with new unmerged config overlaid + temporarily, making _get get the "new" state of things.""" + return self._device.unsafe_get(key) - Note that preview_index is taken from data without checking cell table. - Caller has to figure out which index along memory cell dimension they - actually want to preview in case it needs to be a specific cell / pulse. + def _set(self, key, value): + """Helper function to host device's configuration. Note that this does *not* + call reconfigure, so either do so afterwards or make sure to update auxiliary + properties correctly yourself.""" + self._device.set(key, value) - Will typically reuse data from corrected output buffer. Therefore, - correct(...) must have been called with the appropriate flags before - compute_preview(...). + @property + def _map_shape(self): + return (self._constant_memory_cells, self.num_pixels_ss, self.num_pixels_fs) + + def load_constant(self, constant, constant_data): + self._load_constant(constant, constant_data) + self._post_load_constant(constant) + + def _post_load_constant(self, constant): + for field_name in self._constant_to_correction_names[constant]: + key = f"corrections.{field_name}.available" + if not self._get(key): + self._set(key, True) + + self._update_correction_flags() + self._device.log_status_info(f"Done loading {constant.name} to runner") + + def correct(self, image_data, cell_table, *additional_data): + """Correct image data, providing both full corrected data and previews + + This method relies on the following auxiliary methods which may need to be + overridden / defined in subclass: + + _make_output_buffers + _preview_data_views + _correct (must be implemented by subclass) + + Look at signature for _correct; it is given image_data, cell_table, + *additional_data (anything the host device class' _get_data_from_hash gives us + in addition to 
image data, cell table, and pulse table), and *processed_buffers + which is the result of _make_output_buffers. So all these four methods must + agree on the relevant set of buffers for input and output. + + The pulse_table is only used for preview index selection purposes as corrections + do not (yet, at any rate) take the pulse table into account. """ - if preview_index < -4: - raise ValueError(f"No statistic with code {preview_index} defined") - elif preview_index >= self.frames: - raise ValueError(f"Memory cell index {preview_index} out of range") - - if preview_index >= 0: - fun = operator.itemgetter(preview_index) - elif preview_index == -1: - # note: separate from next case because dtype not applicable here - fun = functools.partial(self._xp.nanmax, axis=0) - elif preview_index in (-2, -3, -4): - fun = functools.partial( - { - -2: self._xp.nanmean, - -3: self._xp.nansum, - -4: self._xp.nanstd, - }[preview_index], - axis=0, - dtype=self._xp.float32, + # TODO: check if this is robust enough (also base_correction) + num_frames = image_data.shape[0] + processed_buffers = self._make_output_buffers( + num_frames, self._correction_flag_preview + ) + image_data = self._xp.asarray(image_data) # if cupy, will put on GPU + if cell_table is not None: + cell_table = self._xp.asarray(cell_table) + additional_data = [self._xp.asarray(data) for data in additional_data] + self._correct( + self._correction_flag_preview, + image_data, + cell_table, + *additional_data, + *processed_buffers, + ) + + preview_buffers = self._preview_data_views( + image_data, *additional_data, *processed_buffers + ) + + if self._correction_flag_preview != self._correction_flag_enabled: + processed_buffers = self._make_output_buffers( + num_frames, self._correction_flag_enabled ) - # TODO: reuse output buffers - # TODO: integrate multithreading - res = (fun(in_buffer_view) for in_buffer_view in self.preview_data_views) - if self._gpu_based: - res = (buf.get() for buf in res) - return res - - def 
reshape(self, output_order, out=None): - """Move axes to desired output order""" - # TODO: avoid copy - if output_order == self._corrected_axis_order: - self.reshaped_data = self.processed_data - else: - self.reshaped_data = self.processed_data.transpose( - utils.transpose_order(self._corrected_axis_order, output_order) + self._correct( + self._correction_flag_enabled, + image_data, + cell_table, + *additional_data, + *processed_buffers, ) - - if self._gpu_based: - return self.reshaped_data.get(out=out) - else: - if out is None: - return self.reshaped_data - else: - out[:] = self.reshaped_data + return self._correction_applied_hash, processed_buffers[0], preview_buffers kernel_dir = pathlib.Path(__file__).absolute().parent / "kernels" diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py index d97f1fc4fc3738ef37bd94fa1bb3b8ee186b3bfb..6632b6ef2da64527d633b637d265125c7050583e 100644 --- a/src/calng/corrections/AgipdCorrection.py +++ b/src/calng/corrections/AgipdCorrection.py @@ -9,30 +9,36 @@ from karabo.bound import ( OUTPUT_CHANNEL, OVERWRITE_ELEMENT, STRING_ELEMENT, + UINT32_ELEMENT, ) from .. 
import ( base_calcat, base_correction, base_kernel_runner, + schemas, utils, ) +from ..preview_utils import FrameSelectionMode, PreviewFriend, PreviewSpec from .._version import version as deviceVersion class Constants(enum.Enum): ThresholdsDark = enum.auto() Offset = enum.auto() + SlopesCS = enum.auto() SlopesPC = enum.auto() SlopesFF = enum.auto() BadPixelsDark = enum.auto() + BadPixelsCS = enum.auto() BadPixelsPC = enum.auto() BadPixelsFF = enum.auto() bad_pixel_constants = { Constants.BadPixelsDark, + Constants.BadPixelsCS, Constants.BadPixelsPC, Constants.BadPixelsFF, } @@ -51,146 +57,389 @@ class CorrectionFlags(enum.IntFlag): NONE = 0 THRESHOLD = 1 OFFSET = 2 - BLSHIFT = 4 - REL_GAIN_PC = 8 - GAIN_XRAY = 16 - BPMASK = 32 - FORCE_MG_IF_BELOW = 64 - FORCE_HG_IF_BELOW = 128 + REL_GAIN_PC = 4 + GAIN_XRAY = 8 + BPMASK = 16 + FORCE_MG_IF_BELOW = 32 + FORCE_HG_IF_BELOW = 64 + COMMON_MODE = 128 + + +correction_steps = ( + # step name (used in schema), flag to enable for kernel, constants required + ("thresholding", CorrectionFlags.THRESHOLD, {Constants.ThresholdsDark}), + ("offset", CorrectionFlags.OFFSET, {Constants.Offset}), + ( # TODO: specify that /both/ constants are needed, not just one or the other + "forceMgIfBelow", + CorrectionFlags.FORCE_MG_IF_BELOW, + {Constants.ThresholdsDark, Constants.Offset}, + ), + ( + "forceHgIfBelow", + CorrectionFlags.FORCE_HG_IF_BELOW, + {Constants.ThresholdsDark, Constants.Offset}, + ), + ("relGain", CorrectionFlags.REL_GAIN_PC, {Constants.SlopesCS, Constants.SlopesPC}), + ("gainXray", CorrectionFlags.GAIN_XRAY, {Constants.SlopesFF}), + ( + "badPixels", + CorrectionFlags.BPMASK, + bad_pixel_constants | { + None, # means stay available even without constants loaded + }, + ), + ("commonMode", CorrectionFlags.COMMON_MODE, {None}), +) class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner): - _corrected_axis_order = "fxy" + _bad_pixel_constants = bad_pixel_constants + _correction_flag_class = CorrectionFlags + _correction_steps 
= correction_steps + num_pixels_ss = 512 + num_pixels_fs = 128 - @property - def input_shape(self): - return (self.frames, 2, self.pixels_x, self.pixels_y) + @classmethod + def add_schema(cls, schema): + # will add node, enable / preview toggles + super(cls, cls).add_schema(schema) + super(cls, cls).add_bad_pixel_config(schema) - @property - def processed_shape(self): - return (self.frames, self.pixels_x, self.pixels_y) + # additional settings specific to AGIPD correction steps + # turn off some advanced corrections by default + for flag in ("enable", "preview"): + ( + OVERWRITE_ELEMENT(schema) + .key(f"corrections.commonMode.{flag}") + .setNewDefaultValue(False) + .commit(), + ) + for step in ("forceMgIfBelow", "forceHgIfBelow"): + ( + OVERWRITE_ELEMENT(schema) + .key(f"corrections.{step}.{flag}") + .setNewDefaultValue(False) + .commit(), + ) + ( + FLOAT_ELEMENT(schema) + .key("corrections.forceMgIfBelow.hardThreshold") + .tags("managed") + .description( + "If enabled, any pixels assigned to low gain stage which would be " + "below this threshold after having medium gain offset subtracted will " + "be reassigned to medium gain." + ) + .assignmentOptional() + .defaultValue(1000) + .reconfigurable() + .commit(), - @property - def map_shape(self): - return (self.constant_memory_cells, self.pixels_x, self.pixels_y) + FLOAT_ELEMENT(schema) + .key("corrections.forceHgIfBelow.hardThreshold") + .tags("managed") + .description( + "Like forceMgIfBelow, but potentially reassigning from medium gain to " + "high gain based on threshold and pixel value minus low gain offset. " + "Applied after forceMgIfBelow, pixels can theoretically be reassigned " + "twice, from LG to MG and to HG." 
+ ) + .assignmentOptional() + .defaultValue(1000) + .reconfigurable() + .commit(), + + STRING_ELEMENT(schema) + .key("corrections.relGain.sourceOfSlopes") + .tags("managed") + .displayedName("Kind of slopes") + .description( + "Slopes for relative gain correction can be derived from pulse " + "capacitor (PC) scans or current source scans (CS). Choose one!" + ) + .assignmentOptional() + .defaultValue("PC") + .options("PC,CS") + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key("corrections.relGain.adjustMgBaseline") + .tags("managed") + .displayedName("Adjust MG baseline") + .description( + "If set, an additional offset is applied to pixels in medium gain. " + "The value of the offset is either computed per pixel (and memory " + "cell) based on currently loaded relative gain slopes (PC or CS) or " + "it is set to a specific value when overrideMdAdditionalOffset is set." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key("corrections.relGain.overrideMdAdditionalOffset") + .tags("managed") + .displayedName("Override md_additional_offset") + .description( + "If set, the additional offset applied to medium gain pixels " + "(assuming adjustMgBaseLine is set) will use the value from " + "mdAdditionalOffset globally for. Otherwise, the additional offset " + "is derived from the relevant constants." + ) + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(schema) + .key("corrections.relGain.mdAdditionalOffset") + .tags("managed") + .displayedName("Value for md_additional_offset (if overriding)") + .description( + "Normally, md_additional_offset (part of relative gain correction) is " + "computed when loading SlopesPC. In case you want to use a different " + "value (global for all medium gain pixels), you can specify it here " + "and set corrections.overrideMdAdditionalOffset to True." 
+ ) + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(schema) + .key("corrections.gainXray.gGainValue") + .tags("managed") + .displayedName("G_gain_value") + .description( + "Newer X-ray gain correction constants are absolute. The default " + "G_gain_value of 1 means that output is expected to be in keV. If " + "this is not desired, one can here specify the mean X-ray gain value " + "over all modules to get ADU values out - operator must manually " + "find this mean value." + ) + .assignmentOptional() + .defaultValue(1) + .reconfigurable() + .commit(), + + UINT32_ELEMENT(schema) + .key("corrections.commonMode.iterations") + .tags("managed") + .assignmentOptional() + .defaultValue(4) + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(schema) + .key("corrections.commonMode.minFrac") + .tags("managed") + .assignmentOptional() + .defaultValue(0.15) + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(schema) + .key("corrections.commonMode.noisePeakRange") + .tags("managed") + .assignmentOptional() + .defaultValue(35) + .reconfigurable() + .commit(), + ) + + def reconfigure(self, config): + super().reconfigure(config) + + if config.has("corrections.gainXray.gGainValue"): + self.g_gain_value = self._xp.float32( + config["corrections.gainXray.gGainValue"] + ) + + if config.has("corrections.forceHgIfBelow.hardThreshold"): + self.hg_hard_threshold = self._xp.float32( + config["corrections.forceHgIfBelow.hardThreshold"] + ) + + if config.has("corrections.forceMgIfBelow.hardThreshold"): + self.mg_hard_threshold = self._xp.float32( + config["corrections.forceMgIfBelow.hardThreshold"] + ) + + if ( + config.has("corrections.relGain.overrideMdAdditionalOffset") + or config.has("corrections.relGain.mdAdditionalOffset") + ): + if ( + self._get("corrections.relGain.overrideMdAdditionalOffset") + and self._get("corrections.relGain.adjustMgBaseLine") + ): + self.md_additional_offset.fill( + self._xp.float32( + 
self._get("corrections.relGain.mdAdditionalOffset") + ) + ) + else: + self._device._please_send_me_cached_constants( + { + Constants.SlopesPC + if self._get("corrections.relGain.sourceOfSlopes") == "PC" + else Constants.SlopesCS + }, + self.load_constant, + ) + + if config.has("corrections.relGain.sourceOfSlopes"): + self.flush_buffers( + { + Constants.SlopesPC, + Constants.SlopesCS, + Constants.BadPixelsCS, + Constants.BadPixelsPC, + } + ) + + # note: will try to reload all bad pixels, but + # _load_constant will reject PC/CS based on which we want + if self._get("corrections.relGain.sourceOfSlopes") == "PC": + to_get = {Constants.SlopesPC} | bad_pixel_constants + else: + to_get = {Constants.SlopesCS} | bad_pixel_constants + + self._device._please_send_me_cached_constants(to_get, self.load_constant) + + if config.has("corrections.commonMode"): + self.cm_min_frac = self._xp.float32( + self._get("corrections.commonMode.minFrac") + ) + self.cm_noise_peak = self._xp.float32( + self._get("corrections.commonMode.noisePeakRange") + ) + self.cm_iter = self._xp.uint16( + self._get("corrections.commonMode.iterations") + ) + + if config.has("constantParameters.gainMode"): + gain_mode = GainModes[config["constantParameters.gainMode"]] + if gain_mode is GainModes.ADAPTIVE_GAIN: + self.default_gain = self._xp.uint8(gain_mode) + else: + self.default_gain = self._xp.uint8(gain_mode - 1) + + def expected_input_shape(self, num_frames): + return ( + num_frames, + 2, + self.num_pixels_ss, + self.num_pixels_fs, + ) @property - def gm_map_shape(self): - return self.map_shape + (3,) # for gain-mapped constants + def _gm_map_shape(self): + return self._map_shape + (3,) # for gain-mapped constants @property def threshold_map_shape(self): - return self.map_shape + (2,) - - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - bad_pixel_mask_value=np.nan, - gain_mode=GainModes.ADAPTIVE_GAIN, - g_gain_value=1, - 
mg_hard_threshold=2000, - hg_hard_threshold=2000, - override_md_additional_offset=None, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - self.gain_mode = gain_mode - if self.gain_mode is GainModes.ADAPTIVE_GAIN: - self.default_gain = self._xp.uint8(gain_mode) - else: - self.default_gain = self._xp.uint8(gain_mode - 1) - self.gain_map = self._xp.empty(self.processed_shape, dtype=np.float32) + return self._map_shape + (2,) + def _setup_constant_buffers(self): # constants self.gain_thresholds = self._xp.empty( self.threshold_map_shape, dtype=np.float32 ) - self.offset_map = self._xp.empty(self.gm_map_shape, dtype=np.float32) - self.rel_gain_pc_map = self._xp.empty(self.gm_map_shape, dtype=np.float32) - # not gm_map_shape because it only applies to medium gain pixels - self.md_additional_offset = self._xp.empty(self.map_shape, dtype=np.float32) - self.rel_gain_xray_map = self._xp.empty(self.map_shape, dtype=np.float32) - self.bad_pixel_map = self._xp.empty(self.gm_map_shape, dtype=np.uint32) - self.override_md_additional_offset(override_md_additional_offset) - self.set_bad_pixel_mask_value(bad_pixel_mask_value) - self.set_mg_hard_threshold(mg_hard_threshold) - self.set_hg_hard_threshold(hg_hard_threshold) - self.set_g_gain_value(g_gain_value) + self.offset_map = self._xp.empty(self._gm_map_shape, dtype=np.float32) + self.rel_gain_map = self._xp.empty(self._gm_map_shape, dtype=np.float32) + # not _gm_map_shape because it only applies to medium gain pixels + self.md_additional_offset = self._xp.empty(self._map_shape, dtype=np.float32) + self.xray_gain_map = self._xp.empty(self._map_shape, dtype=np.float32) + self.bad_pixel_map = self._xp.empty(self._gm_map_shape, dtype=np.uint32) self.flush_buffers(set(Constants)) - self.processed_data = self._xp.empty(self.processed_shape, output_data_dtype) - @property - def preview_data_views(self): + def _make_output_buffers(self, num_frames, flags): + output_shape = 
(num_frames, self.num_pixels_ss, self.num_pixels_fs) + return [ + self._xp.empty(output_shape, dtype=self._xp.float32), # image + self._xp.empty(output_shape, dtype=self._xp.float32), # gain + ] + + def _preview_data_views(self, raw_data, processed_data, gain_map): return ( - self.input_data[:, 0], # raw - self.processed_data, # corrected - self.input_data[:, 1], # raw gain - self.gain_map, # digitized gain + raw_data[:, 0], # raw + processed_data, # corrected + raw_data[:, 1], # raw gain + gain_map, # digitized gain ) - def load_constant(self, constant, constant_data): + def _load_constant(self, constant, constant_data): if constant is Constants.ThresholdsDark: # shape: y, x, memory cell, thresholds and gain values # note: the gain values are something like means used to derive thresholds self.gain_thresholds[:] = self._xp.asarray( constant_data[..., :2], dtype=np.float32 - ).transpose((2, 1, 0, 3))[:self.constant_memory_cells] + ).transpose((2, 1, 0, 3))[:self._constant_memory_cells] elif constant is Constants.Offset: # shape: y, x, memory cell, gain stage self.offset_map[:] = self._xp.asarray( constant_data, dtype=np.float32 - ).transpose((2, 1, 0, 3))[:self.constant_memory_cells] + ).transpose((2, 1, 0, 3))[:self._constant_memory_cells] + elif constant is Constants.SlopesCS: + if self._get("corrections.relGain.sourceOfSlopes") == "PC": + return + self.rel_gain_map.fill(1) + self.rel_gain_map[..., 1] = self.rel_gain_map[..., 0] * self._xp.asarray( + constant_data[ ..., :self._constant_memory_cells, 6] + ).T + self.rel_gain_map[..., 2] = self.rel_gain_map[..., 1] * self._xp.asarray( + constant_data[ ..., :self._constant_memory_cells, 7] + ).T + if self._get("corrections.relGain.adjustMgBaseline"): + self.md_additional_offset.fill( + self._get("corrections.relGain.mdAdditionalOffset") + ) + else: + self.md_additional_offset.fill(0) elif constant is Constants.SlopesPC: - # pc has funny shape (11, 352, 128, 512) from file - # this is (fi, memory cell, y, x) - # the 
following may contain NaNs, though... + if self._get("corrections.relGain.sourceOfSlopes") == "CS": + return + # pc has funny shape (11, 352, 128, 512) from file (fi, memory cell, y, x) + # assume newer variant (TODO: check) which is pre-sanitized hg_slope = constant_data[0] hg_intercept = constant_data[1] mg_slope = constant_data[3] mg_intercept = constant_data[4] - # TODO: remove sanitization (should happen in constant preparation notebook) - # from agipdlib.py: replace NaN with median (per memory cell) - # note: suffixes in agipdlib are "_m" and "_l", should probably be "_I" - for naughty_array in (hg_slope, hg_intercept, mg_slope, mg_intercept): - medians = np.nanmedian(naughty_array, axis=(1, 2)) - nan_bool = np.isnan(naughty_array) - nan_cell, _, _ = np.where(nan_bool) - naughty_array[nan_bool] = medians[nan_cell] - - too_low_bool = naughty_array < 0.8 * medians[:, np.newaxis, np.newaxis] - too_low_cell, _, _ = np.where(too_low_bool) - naughty_array[too_low_bool] = medians[too_low_cell] - - too_high_bool = naughty_array > 1.2 * medians[:, np.newaxis, np.newaxis] - too_high_cell, _, _ = np.where(too_high_bool) - naughty_array[too_high_bool] = medians[too_high_cell] + + hg_slope_median = np.nanmedian(hg_slope, axis=(1, 2)) + mg_slope_median = np.nanmedian(mg_slope, axis=(1, 2)) frac_hg_mg = hg_slope / mg_slope rel_gain_map = np.ones( - (3, self.constant_memory_cells, self.pixels_y, self.pixels_x), + ( + 3, + self._constant_memory_cells, + self.num_pixels_fs, + self.num_pixels_ss, + ), dtype=np.float32, ) rel_gain_map[1] = rel_gain_map[0] * frac_hg_mg rel_gain_map[2] = rel_gain_map[1] * 4.48 - self.rel_gain_pc_map[:] = self._xp.asarray( + self.rel_gain_map[:] = self._xp.asarray( rel_gain_map.transpose((1, 3, 2, 0)), dtype=np.float32 - )[:self.constant_memory_cells] - if self._md_additional_offset_value is None: - md_additional_offset = ( - hg_intercept - mg_intercept * frac_hg_mg - ).astype(np.float32) - self.md_additional_offset[:] = self._xp.asarray( - 
md_additional_offset.transpose((0, 2, 1)), dtype=np.float32 - )[:self.constant_memory_cells] + )[:self._constant_memory_cells] + if self._get("corrections.relGain.adjustMgBaseline"): + if self._get("corrections.relGain.overrideMdAdditionalOffset"): + self.md_additional_offset.fill( + self._get("corrections.relGain.mdAdditionalOffset") + ) + else: + self.md_additional_offset[:] = self._xp.asarray( + ( + hg_intercept - mg_intercept * frac_hg_mg + ).astype(np.float32).transpose((0, 2, 1)), dtype=np.float32 + )[:self._constant_memory_cells] + else: + self.md_additional_offset.fill(0) elif constant is Constants.SlopesFF: # constant shape: y, x, memory cell if constant_data.shape[2] == 2: @@ -199,35 +448,46 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner): # note: we should not support this in online constant_data = self._xp.broadcast_to( constant_data[..., 0][..., np.newaxis], - (self.pixels_y, self.pixels_x, self.constant_memory_cells), + ( + self.num_pixels_fs, + self.num_pixels_ss, + self._constant_memory_cells, + ), ) - self.rel_gain_xray_map[:] = self._xp.asarray( + self.xray_gain_map[:] = self._xp.asarray( constant_data.transpose(), dtype=np.float32 - )[:self.constant_memory_cells] + )[:self._constant_memory_cells] else: assert constant in bad_pixel_constants + if ( + constant is Constants.BadPixelsCS + and self._get("corrections.relGain.sourceOfSlopes") == "PC" + or constant is Constants.BadPixelsPC + and self._get("corrections.relGain.sourceOfSlopes") == "CS" + ): + return # will simply OR with already loaded, does not take into account which ones constant_data = self._xp.asarray(constant_data, dtype=np.uint32) if len(constant_data.shape) == 3: if constant_data.shape == ( - self.pixels_y, - self.pixels_x, - self.constant_memory_cells, + self.num_pixels_fs, + self.num_pixels_ss, + self._constant_memory_cells, ): # BadPixelsFF is not per gain stage - broadcasting along gain constant_data = self._xp.broadcast_to( constant_data.transpose()[..., 
np.newaxis], - self.gm_map_shape, + self._gm_map_shape, ) elif constant_data.shape == ( - self.constant_memory_cells, - self.pixels_y, - self.pixels_x, + self._constant_memory_cells, + self.num_pixels_fs, + self.num_pixels_ss, ): # old BadPixelsPC have different axis order constant_data = self._xp.broadcast_to( constant_data.transpose((0, 2, 1))[..., np.newaxis], - self.gm_map_shape, + self._gm_map_shape, ) else: raise ValueError( @@ -236,142 +496,126 @@ class AgipdBaseRunner(base_kernel_runner.BaseKernelRunner): else: # gain mapped constants seem consistent constant_data = constant_data.transpose((2, 1, 0, 3)) - self.bad_pixel_map |= constant_data[:self.constant_memory_cells] - - def override_md_additional_offset(self, override_value): - self._md_additional_offset_value = override_value - if override_value is not None: - self.md_additional_offset.fill(override_value) - - def set_g_gain_value(self, override_value): - self.g_gain_value = self._xp.float32(override_value) - - def set_bad_pixel_mask_value(self, mask_value): - self.bad_pixel_mask_value = self._xp.float32(mask_value) - - def set_mg_hard_threshold(self, value): - self.mg_hard_threshold = self._xp.float32(value) - - def set_hg_hard_threshold(self, value): - self.hg_hard_threshold = self._xp.float32(value) + constant_data &= self._xp.uint32(self.bad_pixel_subset) + self.bad_pixel_map |= constant_data[:self._constant_memory_cells] def flush_buffers(self, constants): if Constants.Offset in constants: self.offset_map.fill(0) - if Constants.SlopesPC in constants: - self.rel_gain_pc_map.fill(1) + if constants & {Constants.SlopesCS, Constants.SlopesPC}: + self.rel_gain_map.fill(1) self.md_additional_offset.fill(0) if Constants.SlopesFF: - self.rel_gain_xray_map.fill(1) + self.xray_gain_map.fill(1) if constants & bad_pixel_constants: self.bad_pixel_map.fill(0) - self.bad_pixel_map[ - :, 64:512:64 - ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value - self.bad_pixel_map[ - :, 63:511:64 - ] |= 
utils.BadPixelValues.NON_STANDARD_SIZE.value + if self.bad_pixel_subset & utils.BadPixelValues.NON_STANDARD_SIZE: + self._mask_asic_seams + + def _mask_asic_seams(self): + self.bad_pixel_map[:, 64:512:64] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map[:, 63:511:64] |= utils.BadPixelValues.NON_STANDARD_SIZE.value class AgipdCpuRunner(AgipdBaseRunner): _xp = np - _gpu_based = False - def correct(self, flags): + def _correct(self, flags, image_data, cell_table, processed_data, gain_map): + if flags & CorrectionFlags.COMMON_MODE: + raise NotImplementedError("Common mode not available on CPU yet") self.correction_kernel( - self.input_data, - self.cell_table, + image_data, + cell_table, flags, self.default_gain, self.gain_thresholds, self.mg_hard_threshold, self.hg_hard_threshold, self.offset_map, - self.rel_gain_pc_map, + self.rel_gain_map, self.md_additional_offset, - self.rel_gain_xray_map, + self.xray_gain_map, self.g_gain_value, self.bad_pixel_map, self.bad_pixel_mask_value, - self.processed_data, + processed_data, ) - def _post_init(self): - self.input_data = None - self.cell_table = None - # NOTE: CPU kernel does not yet support anything other than float32 - self.processed_data = self._xp.empty( - self.processed_shape, dtype=self.output_data_dtype - ) + def _pre_init(self): from ..kernels import agipd_cython self.correction_kernel = agipd_cython.correct - def load_data(self, image_data, cell_table): - self.input_data = image_data - self.cell_table = cell_table - class AgipdGpuRunner(AgipdBaseRunner): - _gpu_based = True - def _pre_init(self): - import cupy as cp - - self._xp = cp + import cupy + self._xp = cupy def _post_init(self): - self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16) - self.cell_table = self._xp.empty(self.frames, dtype=np.uint16) - self.block_shape = (1, 1, 64) - self.grid_shape = utils.grid_to_cover_shape_with_blocks( - self.processed_shape, self.block_shape - ) - self.correction_kernel = 
self._xp.RawModule( + self.correction_kernel = self._xp.RawKernel( code=base_kernel_runner.get_kernel_template("agipd_gpu.cu").render( - { - "pixels_x": self.pixels_x, - "pixels_y": self.pixels_y, - "frames": self.frames, - "constant_memory_cells": self.constant_memory_cells, - "output_data_dtype": utils.np_dtype_to_c_type( - self.output_data_dtype - ), - "corr_enum": utils.enum_to_c_template(CorrectionFlags), - } - ) - ).get_function("correct") - - def load_data(self, image_data, cell_table): - self.input_data.set(image_data) - self.cell_table.set(cell_table) + pixels_x=self.num_pixels_ss, + pixels_y=self.num_pixels_fs, + output_data_dtype=utils.np_dtype_to_c_type(self._output_dtype), + corr_enum=utils.enum_to_c_template(CorrectionFlags), + ), + name="correct", + ) + self.cm_kernel = self._xp.RawKernel( + code=base_kernel_runner.get_kernel_template("common_gpu.cu").render( + ss_dim=self.num_pixels_ss, + fs_dim=self.num_pixels_fs, + asic_dim=64, + ), + name="common_mode_asic", + ) - def correct(self, flags): - if flags & CorrectionFlags.BLSHIFT: - raise NotImplementedError("Baseline shift not implemented yet") + def _correct(self, flags, image_data, cell_table, processed_data, gain_map): + num_frames = self._xp.uint16(image_data.shape[0]) + block_shape = (1, 1, 64) + grid_shape = utils.grid_to_cover_shape_with_blocks( + processed_data.shape, block_shape + ) self.correction_kernel( - self.grid_shape, - self.block_shape, + grid_shape, + block_shape, ( - self.input_data, - self.cell_table, - self._xp.uint8(flags), + image_data, + cell_table, + flags, + num_frames, + self._constant_memory_cells, self.default_gain, self.gain_thresholds, self.mg_hard_threshold, self.hg_hard_threshold, self.offset_map, - self.rel_gain_pc_map, + self.rel_gain_map, self.md_additional_offset, - self.rel_gain_xray_map, + self.xray_gain_map, self.g_gain_value, self.bad_pixel_map, self.bad_pixel_mask_value, - self.gain_map, - self.processed_data, + gain_map, + processed_data, ), ) + if flags & 
CorrectionFlags.COMMON_MODE: + # grid: one block per ss asic, fs asic, and frame + self.cm_kernel( + (2, 8, processed_data.shape[0]), + (64, 1, 1), + ( + processed_data, + num_frames, + self.cm_iter, + self.cm_min_frac, + self.cm_noise_peak, + ), + ) + class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = Constants @@ -381,9 +625,11 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): return { Constants.ThresholdsDark: self.dark_condition, Constants.Offset: self.dark_condition, + Constants.SlopesCS: self.dark_condition, Constants.SlopesPC: self.dark_condition, Constants.SlopesFF: self.illuminated_condition, Constants.BadPixelsDark: self.dark_condition, + Constants.BadPixelsCS: self.dark_condition, Constants.BadPixelsPC: self.dark_condition, Constants.BadPixelsFF: self.illuminated_condition, } @@ -489,69 +735,34 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("AgipdCorrection", deviceVersion) class AgipdCorrection(base_correction.BaseCorrection): # subclass *must* set these attributes - _base_output_schema = schemas.xtdf_output_schema() - _correction_flag_class = CorrectionFlags - _correction_steps = ( - # step name (used in schema), flag to enable for kernel, constants required - ("thresholding", CorrectionFlags.THRESHOLD, {Constants.ThresholdsDark}), - ("offset", CorrectionFlags.OFFSET, {Constants.Offset}), - ( # TODO: specify that /both/ constants are needed, not just one or the other - "forceMgIfBelow", - CorrectionFlags.FORCE_MG_IF_BELOW, - {Constants.ThresholdsDark, Constants.Offset}, - ), - ( - "forceHgIfBelow", - CorrectionFlags.FORCE_HG_IF_BELOW, - {Constants.ThresholdsDark, Constants.Offset}, - ), - ("relGainPc", CorrectionFlags.REL_GAIN_PC, {Constants.SlopesPC}), - ("gainXray", CorrectionFlags.GAIN_XRAY, {Constants.SlopesFF}), - ( - "badPixels", - CorrectionFlags.BPMASK, - bad_pixel_constants | { - None, # means stay available even without constants loaded - }, - ), - ) + _base_output_schema = 
schemas.xtdf_output_schema _calcat_friend_class = AgipdCalcatFriend _constant_enum_class = Constants + _correction_steps = correction_steps + _preview_outputs = [ - "outputRaw", - "outputCorrected", - "outputRawGain", - "outputGainMap", + PreviewSpec(name, frame_reduction_selection_mode=FrameSelectionMode.CELL) + for name in ("raw", "corrected", "rawGain", "gainMap") ] - @staticmethod - def expectedParameters(expected): + @classmethod + def expectedParameters(cls, expected): + # this is not automatically done by superclass for complicated class reasons + cls._calcat_friend_class.add_schema(expected) + AgipdBaseRunner.add_schema(expected) + base_correction.add_addon_nodes(expected, cls) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( OUTPUT_CHANNEL(expected) .key("dataOutput") - .dataSchema(AgipdCorrection._base_output_schema) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(352) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("preview.selectionMode") - .setNewDefaultValue("cell") + .dataSchema( + AgipdCorrection._base_output_schema( + use_shmem_handles=cls._use_shmem_handles + ) + ) .commit(), ) - # this is not automatically done by superclass for complicated class reasons - AgipdCalcatFriend.add_schema(expected) - base_correction.add_addon_nodes(expected, AgipdCorrection) - base_correction.add_correction_step_schema( - expected, - AgipdCorrection._correction_steps, - ) - base_correction.add_bad_pixel_config_node(expected) - base_correction.add_preview_outputs(expected, AgipdCorrection._preview_outputs) ( # support both CPU and GPU kernels STRING_ELEMENT(expected) @@ -569,130 +780,6 @@ class AgipdCorrection(base_correction.BaseCorrection): .commit(), ) - # turn off the force MG / HG steps by default - for step in ("forceMgIfBelow", "forceHgIfBelow"): - for flag in ("enable", "preview"): - ( - OVERWRITE_ELEMENT(expected) - .key(f"corrections.{step}.{flag}") - .setNewDefaultValue(False) - .commit(), - ) - 
# additional settings specific to AGIPD correction steps - ( - FLOAT_ELEMENT(expected) - .key("corrections.forceMgIfBelow.hardThreshold") - .tags("managed") - .description( - "If enabled, any pixels assigned to low gain stage which would be " - "below this threshold after having medium gain offset subtracted will " - "be reassigned to medium gain." - ) - .assignmentOptional() - .defaultValue(1000) - .reconfigurable() - .commit(), - - FLOAT_ELEMENT(expected) - .key("corrections.forceHgIfBelow.hardThreshold") - .tags("managed") - .description( - "Like forceMgIfBelow, but potentially reassigning from medium gain to " - "high gain based on threshold and pixel value minus low gain offset. " - "Applied after forceMgIfBelow, pixels can theoretically be reassigned " - "twice, from LG to MG and to HG." - ) - .assignmentOptional() - .defaultValue(1000) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(expected) - .key("corrections.relGainPc.overrideMdAdditionalOffset") - .tags("managed") - .displayedName("Override md_additional_offset") - .description( - "Toggling this on will use the value in the next field globally for " - "md_additional_offset. Note that the correction map on GPU gets " - "overwritten as long as this boolean is True, so reload constants " - "after turning off." - ) - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - - FLOAT_ELEMENT(expected) - .key("corrections.relGainPc.mdAdditionalOffset") - .tags("managed") - .displayedName("Value for md_additional_offset (if overriding)") - .description( - "Normally, md_additional_offset (part of relative gain correction) is " - "computed when loading SlopesPC. In case you want to use a different " - "value (global for all medium gain pixels), you can specify it here " - "and set corrections.overrideMdAdditionalOffset to True." 
- ) - .assignmentOptional() - .defaultValue(0) - .reconfigurable() - .commit(), - - FLOAT_ELEMENT(expected) - .key("corrections.gainXray.gGainValue") - .tags("managed") - .displayedName("G_gain_value") - .description( - "Newer X-ray gain correction constants are absolute. The default " - "G_gain_value of 1 means that output is expected to be in keV. If " - "this is not desired, one can here specify the mean X-ray gain value " - "over all modules to get ADU values out - operator must manually " - "find this mean value." - ) - .assignmentOptional() - .defaultValue(1) - .reconfigurable() - .commit(), - ) - - @property - def input_data_shape(self): - return ( - self.unsafe_get("dataFormat.frames"), - 2, - self.unsafe_get("dataFormat.pixelsX"), - self.unsafe_get("dataFormat.pixelsY"), - ) - - def __init__(self, config): - super().__init__(config) - # note: gain mode single sourced from constant retrieval node - try: - np.float32(config.get("corrections.badPixels.maskingValue")) - except ValueError: - config["corrections.badPixels.maskingValue"] = "nan" - - self._has_updated_bad_pixel_selection = False - - @property - def bad_pixel_mask_value(self): - return np.float32(self.unsafe_get("corrections.badPixels.maskingValue")) - - _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection) - - @property - def _kernel_runner_init_args(self): - return { - "gain_mode": self.gain_mode, - "bad_pixel_mask_value": self.bad_pixel_mask_value, - "g_gain_value": self.unsafe_get("corrections.gainXray.gGainValue"), - "mg_hard_threshold": self.unsafe_get( - "corrections.forceMgIfBelow.hardThreshold" - ), - "hg_hard_threshold": self.unsafe_get( - "corrections.forceHgIfBelow.hardThreshold" - ), - } - @property def _kernel_runner_class(self): kernel_type = base_kernel_runner.KernelRunnerTypes[ @@ -702,159 +789,3 @@ class AgipdCorrection(base_correction.BaseCorrection): return AgipdCpuRunner else: return AgipdGpuRunner - - @property - def gain_mode(self): - return 
GainModes[self.unsafe_get("constantParameters.gainMode")] - - @property - def _override_md_additional_offset(self): - if self.unsafe_get("corrections.relGainPc.overrideMdAdditionalOffset"): - return self.unsafe_get("corrections.relGainPc.mdAdditionalOffset") - else: - return None - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - pulse_table, - ): - """Called by input_handler for each data hash. Should correct data, optionally - compute preview, write data output, and optionally write preview outputs.""" - # original shape: frame, data/raw_gain, x, y - if self._frame_filter is not None: - try: - cell_table = cell_table[self._frame_filter] - pulse_table = pulse_table[self._frame_filter] - image_data = image_data[self._frame_filter] - except IndexError: - self.log_status_warn( - "Failed to apply frame filter, please check that it is valid!" - ) - return - - try: - self.kernel_runner.load_data(image_data, cell_table.ravel()) - except ValueError as e: - self.log_status_warn(f"Failed to load data: {e}") - self.log_status_warn( - f"For debugging - cell table: {type(cell_table)}, {cell_table}" - ) - self.log_status_warn( - f"For debugging - data: {type(image_data)}, {image_data.shape}" - ) - return - except Exception as e: - self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") - - # first prepare previews (not affected by addons) - self.kernel_runner.correct(self._correction_flag_preview) - with self.warning_context( - "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS - ) as warn: - ( - preview_slice_index, - preview_cell, - preview_pulse, - ), preview_warning = utils.pick_frame_index( - self.unsafe_get("preview.selectionMode"), - self.unsafe_get("preview.index"), - cell_table, - pulse_table, - ) - if preview_warning is not None: - warn(preview_warning) - ( - preview_raw, - preview_corrected, - preview_raw_gain, - preview_gain_map, - ) = 
self.kernel_runner.compute_previews(preview_slice_index) - # TODO: start writing out previews asynchronously in the background - - # then prepare full corrected data - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct(self._correction_flag_enabled) - for addon in self._enabled_addons: - addon.post_correction( - self.kernel_runner.processed_data, cell_table, pulse_table, data_hash - ) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - for addon in self._enabled_addons: - addon.post_reshape( - buffer_array, cell_table, pulse_table, data_hash - ) - - # reusing input data hash for sending - if self._use_shmem_handles: - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - data_hash.set(self._cell_table_path, cell_table[:, np.newaxis]) - data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis]) - - self._write_output(data_hash, metadata) - self._preview_friend.write_outputs( - metadata, preview_raw, preview_corrected, preview_raw_gain, preview_gain_map - ) - - def _load_constant_to_runner(self, constant, constant_data): - if constant in bad_pixel_constants: - constant_data &= self._override_bad_pixel_flags - self.kernel_runner.load_constant(constant, constant_data) - - def preReconfigure(self, config): - super().preReconfigure(config) - if config.has("corrections.badPixels.maskingValue"): - # only check if it is valid; postReconfigure will use it - np.float32(config.get("corrections.badPixels.maskingValue")) - - def postReconfigure(self): - super().postReconfigure() - - update = self._prereconfigure_update_hash - - if update.has("constantParameters.gainMode"): - self.flush_constants() - self._update_buffers() - - if 
update.has("corrections.forceMgIfBelow.hardThreshold"): - self.kernel_runner.set_mg_hard_threshold( - self.get("corrections.forceMgIfBelow.hardThreshold") - ) - - if update.has("corrections.forceHgIfBelow.hardThreshold"): - self.kernel_runner.set_hg_hard_threshold( - self.get("corrections.forceHgIfBelow.hardThreshold") - ) - - if update.has("corrections.gainXray.gGainValue"): - self.kernel_runner.set_g_gain_value( - self.get("corrections.gainXray.gGainValue") - ) - - if update.has("corrections.badPixels.maskingValue"): - self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) - - if update.has("corrections.badPixels.subsetToUse"): - self.log_status_info("Updating bad pixel maps based on subset specified") - # note: now just always reloading from cache for convenience - with self.calcat_friend.cached_constants_lock: - self.kernel_runner.flush_buffers(bad_pixel_constants) - for constant in bad_pixel_constants: - if constant in self.calcat_friend.cached_constants: - self._load_constant_to_runner( - constant, self.calcat_friend.cached_constants[constant] - ) diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py index 9a6cc613b0360c93249d47aed841cdce85045869..3611b26721e42ac4f500a7940e3bea9489cba4d8 100644 --- a/src/calng/corrections/DsscCorrection.py +++ b/src/calng/corrections/DsscCorrection.py @@ -8,7 +8,14 @@ from karabo.bound import ( STRING_ELEMENT, ) -from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils +from .. 
import ( + base_calcat, + base_correction, + base_kernel_runner, + schemas, + utils, +) +from ..preview_utils import FrameSelectionMode, PreviewFriend, PreviewSpec from .._version import version as deviceVersion @@ -21,131 +28,98 @@ class Constants(enum.Enum): Offset = enum.auto() -class DsscBaseRunner(base_kernel_runner.BaseKernelRunner): - _corrected_axis_order = "fyx" - _xp = None +correction_steps = (("offset", CorrectionFlags.OFFSET, {Constants.Offset}),) - @property - def map_shape(self): - return (self.constant_memory_cells, self.pixels_y, self.pixels_x) - @property - def input_shape(self): - return (self.frames, self.pixels_y, self.pixels_x) +class DsscBaseRunner(base_kernel_runner.BaseKernelRunner): + _correction_flag_class = CorrectionFlags + _correction_steps = correction_steps + num_pixels_ss = 128 + num_pixels_fs = 512 - @property - def processed_shape(self): - return self.input_shape - - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - self.offset_map = self._xp.empty(self.map_shape, dtype=np.float32) - self.processed_data = self._xp.empty( - self.processed_shape, dtype=output_data_dtype + def expected_input_shape(self, num_frames): + return ( + num_frames, + 1, + self.num_pixels_ss, + self.num_pixels_fs, ) + def _setup_constant_buffers(self): + self.offset_map = self._xp.empty(self._map_shape, dtype=np.float32) + def flush_buffers(self, constants): if Constants.Offset in constants: self.offset_map.fill(0) - @property - def preview_data_views(self): - return ( - self.input_data, - self.processed_data, - ) - - def load_offset_map(self, offset_map): + def _load_constant(self, constant_type, data): + assert constant_type is Constants.Offset # can have an extra dimension for some reason - if len(offset_map.shape) == 4: # old format (see offsetcorrection_dssc.py)? 
- offset_map = offset_map[..., 0] + if len(data.shape) == 4: # old format (see offsetcorrection_dssc.py)? + data = data[..., 0] # shape (now): x, y, memory cell - offset_map = self._xp.asarray(offset_map, dtype=np.float32).transpose() - self.offset_map[:] = offset_map[:self.constant_memory_cells] + data = self._xp.asarray(data, dtype=np.float32).transpose() + self.offset_map[:] = data[:self._constant_memory_cells] + + def _preview_data_views(self, raw_data, processed_data): + return ( + raw_data[:, 0], + processed_data, + ) class DsscCpuRunner(DsscBaseRunner): - _gpu_based = False _xp = np def _post_init(self): - self.input_data = None - self.cell_table = None from ..kernels import dssc_cython - self.correction_kernel = dssc_cython.correct - def load_data(self, image_data, cell_table): - self.input_data = image_data.astype(np.uint16, copy=False) - self.cell_table = cell_table.astype(np.uint16, copy=False) - - def correct(self, flags): + def _correct(self, flags, image_data, cell_table, output): self.correction_kernel( - self.input_data, - self.cell_table, + image_data, + cell_table, flags, self.offset_map, - self.processed_data, + output, ) class DsscGpuRunner(DsscBaseRunner): - _gpu_based = True - def _pre_init(self): - import cupy as cp - - self._xp = cp + import cupy + self._xp = cupy def _post_init(self): - self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16) - self.cell_table = self._xp.empty(self.frames, dtype=np.uint16) - self.block_shape = (1, 1, 64) - self.grid_shape = utils.grid_to_cover_shape_with_blocks( - self.input_shape, self.block_shape - ) self.correction_kernel = self._xp.RawModule( code=base_kernel_runner.get_kernel_template("dssc_gpu.cu").render( { - "pixels_x": self.pixels_x, - "pixels_y": self.pixels_y, - "frames": self.frames, - "constant_memory_cells": self.constant_memory_cells, + "pixels_x": self.num_pixels_fs, + "pixels_y": self.num_pixels_ss, "output_data_dtype": utils.np_dtype_to_c_type( - self.output_data_dtype + 
self._output_dtype ), "corr_enum": utils.enum_to_c_template(CorrectionFlags), } ) ).get_function("correct") - def load_data(self, image_data, cell_table): - self.input_data.set(image_data) - self.cell_table.set(cell_table) - - def correct(self, flags): + def _correct(self, flags, image_data, cell_table, output): + num_frames = self._xp.uint16(image_data.shape[0]) + block_shape = (1, 1, 64) + grid_shape = utils.grid_to_cover_shape_with_blocks(output.shape, block_shape) + image_data = image_data[:, 0] self.correction_kernel( - self.grid_shape, - self.block_shape, + grid_shape, + block_shape, ( - self.input_data, - self.cell_table, - np.uint8(flags), + image_data, + cell_table, + num_frames, + self._constant_memory_cells, + flags, self.offset_map, - self.processed_data, + output, ), ) @@ -185,44 +159,30 @@ class DsscCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("DsscCorrection", deviceVersion) class DsscCorrection(base_correction.BaseCorrection): - # subclass *must* set these attributes - _base_output_schema = schemas.xtdf_output_schema() - _correction_flag_class = CorrectionFlags - _correction_steps = (("offset", CorrectionFlags.OFFSET, {Constants.Offset}),) + _base_output_schema = schemas.xtdf_output_schema + _correction_steps = correction_steps _calcat_friend_class = DsscCalcatFriend _constant_enum_class = Constants - _preview_outputs = ["outputRaw", "outputCorrected"] - @staticmethod - def expectedParameters(expected): + _preview_outputs = [ + PreviewSpec(name, frame_reduction_selection_mode=FrameSelectionMode.PULSE) + for name in ("raw", "corrected") + ] + + @classmethod + def expectedParameters(cls, expected): + cls._calcat_friend_class.add_schema(expected) + DsscBaseRunner.add_schema(expected) + base_correction.add_addon_nodes(expected, cls) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( OUTPUT_CHANNEL(expected) .key("dataOutput") - .dataSchema(DsscCorrection._base_output_schema) - .commit(), - - OVERWRITE_ELEMENT(expected) - 
.key("dataFormat.frames") - .setNewDefaultValue(400) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.outputAxisOrder") - .setNewDefaultValue("fyx") - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("preview.selectionMode") - .setNewDefaultValue("pulse") + .dataSchema( + DsscCorrection._base_output_schema(cls._use_shmem_handles) + ) .commit(), ) - DsscCalcatFriend.add_schema(expected) - base_correction.add_addon_nodes(expected, DsscCorrection) - base_correction.add_correction_step_schema( - expected, - DsscCorrection._correction_steps, - ) - base_correction.add_preview_outputs(expected, DsscCorrection._preview_outputs) ( # support both CPU and GPU kernels STRING_ELEMENT(expected) @@ -240,15 +200,6 @@ class DsscCorrection(base_correction.BaseCorrection): .commit(), ) - @property - def input_data_shape(self): - return ( - self.get("dataFormat.frames"), - 1, - self.get("dataFormat.pixelsY"), - self.get("dataFormat.pixelsX"), - ) - @property def _kernel_runner_class(self): kernel_type = base_kernel_runner.KernelRunnerTypes[ @@ -258,80 +209,3 @@ class DsscCorrection(base_correction.BaseCorrection): return DsscCpuRunner else: return DsscGpuRunner - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - pulse_table, - ): - if self._frame_filter is not None: - try: - cell_table = cell_table[self._frame_filter] - pulse_table = pulse_table[self._frame_filter] - image_data = image_data[self._frame_filter] - except IndexError: - self.log_status_warn( - "Failed to apply frame filter, please check that it is valid!" 
- ) - return - - try: - # drop the singleton dimension for kernel runner - self.kernel_runner.load_data(image_data[:, 0], cell_table.ravel()) - except ValueError as e: - self.log_status_warn(f"Failed to load data: {e}") - return - except Exception as e: - self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.kernel_runner.correct(self._correction_flag_enabled) - for addon in self._enabled_addons: - addon.post_correction( - self.kernel_runner.processed_data, cell_table, pulse_table, data_hash - ) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - with self.warning_context( - "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS - ) as warn: - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ), preview_warning = utils.pick_frame_index( - self.unsafe_get("preview.selectionMode"), - self.unsafe_get("preview.index"), - cell_table, - pulse_table, - ) - if preview_warning is not None: - warn(preview_warning) - preview_raw, preview_corrected = self.kernel_runner.compute_previews( - preview_slice_index, - ) - - if self._use_shmem_handles: - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - data_hash.set(self._cell_table_path, cell_table[:, np.newaxis]) - data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis]) - - self._write_output(data_hash, metadata) - self._preview_friend.write_outputs(metadata, preview_raw, preview_corrected) - - def _load_constant_to_runner(self, constant, constant_data): - self.kernel_runner.load_offset_map(constant_data) diff --git 
a/src/calng/corrections/Epix100Correction.py b/src/calng/corrections/Epix100Correction.py index 5bd981bb7839ed41dff0135daccb47d1b2335e74..b541a72af5a96d70cd320140026284e0bfb40ace 100644 --- a/src/calng/corrections/Epix100Correction.py +++ b/src/calng/corrections/Epix100Correction.py @@ -1,6 +1,5 @@ import concurrent.futures import enum -import functools import numpy as np from karabo.bound import ( @@ -19,6 +18,7 @@ from .. import ( utils, ) from .._version import version as deviceVersion +from ..preview_utils import PreviewFriend, PreviewSpec from ..kernels import common_cython @@ -37,6 +37,14 @@ class CorrectionFlags(enum.IntFlag): BPMASK = 2**4 +correction_steps = ( + ("offset", CorrectionFlags.OFFSET, {Constants.OffsetEPix100}), + ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainEPix100}), + ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseEPix100}), + ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkEPix100}), +) + + class Epix100CalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = Constants @@ -138,84 +146,119 @@ class Epix100CalcatFriend(base_calcat.BaseCalcatFriend): class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner): - _corrected_axis_order = "xy" + _correction_steps = correction_steps + _correction_flag_class = CorrectionFlags + _bad_pixel_constants = {Constants.BadPixelsDarkEPix100} _xp = np - _gpu_based = False - @property - def input_shape(self): - return (self.pixels_x, self.pixels_y) + num_pixels_ss = 708 + num_pixels_fs = 768 - @property - def preview_shape(self): - return (self.pixels_x, self.pixels_y) + @classmethod + def add_schema(cls, schema): + super(cls, cls).add_schema(schema) + super(cls, cls).add_bad_pixel_config(schema) + ( + DOUBLE_ELEMENT(schema) + .key("corrections.commonMode.noiseSigma") + .tags("managed") + .assignmentOptional() + .defaultValue(5) + .reconfigurable() + .commit(), - @property - def processed_shape(self): - return self.input_shape + DOUBLE_ELEMENT(schema) + 
.key("corrections.commonMode.minFrac") + .tags("managed") + .assignmentOptional() + .defaultValue(0.25) + .reconfigurable() + .commit(), - @property - def map_shape(self): - return (self.pixels_x, self.pixels_y) + BOOL_ELEMENT(schema) + .key("corrections.commonMode.enableRow") + .tags("managed") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), - def __init__( - self, - pixels_x, - pixels_y, - frames, # will be 1, will be ignored - constant_memory_cells, - input_data_dtype=np.uint16, - output_data_dtype=np.float32, - ): - assert ( - output_data_dtype == np.float32 - ), "Alternative output types not supported yet" + BOOL_ELEMENT(schema) + .key("corrections.commonMode.enableCol") + .tags("managed") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key("corrections.commonMode.enableBlock") + .tags("managed") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + ) - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, + def reconfigure(self, config): + super().reconfigure(config) + if config.has("corrections.commonMode.noiseSigma"): + self._cm_sigma = np.float32(config["corrections.commonMode.noiseSigma"]) + if config.has("corrections.commonMode.minFrac"): + self._cm_minfrac = np.float32(config["corrections.commonMode.minFrac"]) + if config.has("corrections.commonMode.enableRow"): + self._cm_row = config["corrections.commonMode.enableRow"] + if config.has("corrections.commonMode.enableCol"): + self._cm_col = config["corrections.commonMode.enableCol"] + if config.has("corrections.commonMode.enableBlock"): + self._cm_block = config["corrections.commonMode.enableBlock"] + + def expected_input_data_shape(self, num_frames): + assert num_frames == 1 + return ( + self.num_pixels_ss, + self.num_pixels_fs, ) - self.input_data = None - self.processed_data = np.empty(self.processed_shape, dtype=np.float32) + def 
_expected_output_shape(self, num_frames): + return (self.num_pixels_ss, self.num_pixels_fs) - self.offset_map = np.empty(self.map_shape, dtype=np.float32) - self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32) - self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32) - self.noise_map = np.empty(self.map_shape, dtype=np.float32) + @property + def _map_shape(self): + return (self.num_pixels_ss, self.num_pixels_fs) + + def _make_output_buffers(self, num_frames, flags): + # ignore parameters + return [ + self._xp.empty( + (self.num_pixels_ss, self.num_pixels_fs), + dtype=self._xp.float32, + ) + ] + + def _setup_constant_buffers(self): + assert ( + self._output_dtype == np.float32 + ), "Alternative output types not supported yet" + + self.offset_map = np.empty(self._map_shape, dtype=np.float32) + self.rel_gain_map = np.empty(self._map_shape, dtype=np.float32) + self.bad_pixel_map = np.empty(self._map_shape, dtype=np.uint32) + self.noise_map = np.empty(self._map_shape, dtype=np.float32) # will do everything by quadrant self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4) - self._q_input_data = None - self._q_processed_data = utils.quadrant_views(self.processed_data) self._q_offset_map = utils.quadrant_views(self.offset_map) self._q_rel_gain_map = utils.quadrant_views(self.rel_gain_map) self._q_bad_pixel_map = utils.quadrant_views(self.bad_pixel_map) self._q_noise_map = utils.quadrant_views(self.noise_map) - def __del__(self): - self.thread_pool.shutdown() - - @property - def preview_data_views(self): - # NOTE: apparently there are "calibration rows" - # geometry assumes these are cut out already - return (self.input_data[2:-2], self.processed_data[2:-2]) - - def load_data(self, image_data): - # should almost never squeeze, but ePix doesn't do burst mode, right? 
- self.input_data = image_data.astype(np.uint16, copy=False).squeeze() - self._q_input_data = utils.quadrant_views(self.input_data) - - def load_constant(self, constant_type, data): + def _load_constant(self, constant_type, data): if constant_type is Constants.OffsetEPix100: self.offset_map[:] = data.squeeze().astype(np.float32) elif constant_type is Constants.RelativeGainEPix100: self.rel_gain_map[:] = data.squeeze().astype(np.float32) elif constant_type is Constants.BadPixelsDarkEPix100: - self.bad_pixel_map[:] = data.squeeze() + self.bad_pixel_map[:] = (data.squeeze() & self.bad_pixel_subset) elif constant_type is Constants.NoiseEPix100: self.noise_map[:] = data.squeeze() else: @@ -231,127 +274,115 @@ class Epix100CpuRunner(base_kernel_runner.BaseKernelRunner): if Constants.NoiseEPix100 in constants: self.noise_map.fill(np.inf) + def _preview_data_views(self, raw_data, processed_data): + # NOTE: apparently there are "calibration rows" + # geometry assumes these are cut out already + return (raw_data[2:-2], processed_data[2:-2]) + def _correct_quadrant( self, - q, flags, - bad_pixel_mask_value, - cm_noise_sigma, - cm_min_frac, - cm_row, - cm_col, - cm_block, + input_data, + offset_map, + noise_map, + bad_pixel_map, + rel_gain_map, + output, ): - output = self._q_processed_data[q] - output[:] = self._q_input_data[q].astype(np.float32) + output[:] = input_data.astype(np.float32) if flags & CorrectionFlags.OFFSET: - output -= self._q_offset_map[q] + output -= offset_map if flags & CorrectionFlags.COMMONMODE: # per rectangular block that looks like something is going on cm_mask = ( - (self._q_bad_pixel_map[q] != 0) - | (output > self._q_noise_map[q] * cm_noise_sigma) + (bad_pixel_map != 0) + | (output > noise_map * self._cm_sigma) ).astype(np.uint8, copy=False) masked = np.ma.masked_array(data=output, mask=cm_mask) - if cm_block: + if self._cm_block: for block in np.hsplit(masked, 4): - if block.count() < block.size * cm_min_frac: + if block.count() < block.size * 
self._cm_minfrac: continue block.data[:] -= np.ma.median(block) - if cm_row: - common_cython.cm_fs(output, cm_mask, cm_noise_sigma, cm_min_frac) + if self._cm_row: + common_cython.cm_fs(output, cm_mask, self._cm_sigma, self._cm_minfrac) - if cm_col: - common_cython.cm_ss(output, cm_mask, cm_noise_sigma, cm_min_frac) + if self._cm_col: + common_cython.cm_ss(output, cm_mask, self._cm_sigma, self._cm_minfrac) if flags & CorrectionFlags.RELGAIN: - output *= self._q_rel_gain_map[q] + output *= rel_gain_map if flags & CorrectionFlags.BPMASK: - output[self._q_bad_pixel_map[q] != 0] = bad_pixel_mask_value + output[bad_pixel_map != 0] = self.bad_pixel_mask_value + + def _correct(self, flags, image_data, cell_table, output): + # ignore cell table None + concurrent.futures.wait( + [ + self.thread_pool.submit( + self._correct_quadrant(flags, *parts) + ) + for parts in zip( + utils.quadrant_views(image_data), + self._q_offset_map, + self._q_noise_map, + self._q_bad_pixel_map, + self._q_rel_gain_map, + utils.quadrant_views(output), + ) + ] + ) - def correct( - self, - flags, - bad_pixel_mask_value=np.nan, - cm_noise_sigma=5, - cm_min_frac=0.25, - cm_row=True, - cm_col=True, - cm_block=True, - ): - # NOTE: how to best clean up all these duplicated parameters? 
- for result in self.thread_pool.map( - functools.partial( - self._correct_quadrant, - flags=flags, - bad_pixel_mask_value=bad_pixel_mask_value, - cm_noise_sigma=cm_noise_sigma, - cm_min_frac=cm_min_frac, - cm_row=cm_row, - cm_col=cm_col, - cm_block=cm_block, - ), - range(4), - ): - pass # just run through to await map + def __del__(self): + self.thread_pool.shutdown() @KARABO_CLASSINFO("Epix100Correction", deviceVersion) class Epix100Correction(base_correction.BaseCorrection): - _base_output_schema = schemas.jf_output_schema() - _correction_flag_class = CorrectionFlags - _correction_steps = ( - ("offset", CorrectionFlags.OFFSET, {Constants.OffsetEPix100}), - ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainEPix100}), - ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseEPix100}), - ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkEPix100}), - ) - _image_data_path = "data.image.pixels" + _base_output_schema = schemas.jf_output_schema _kernel_runner_class = Epix100CpuRunner - _calcat_friend_class = Epix100CalcatFriend + _correction_steps = correction_steps + _correction_flag_class = CorrectionFlags _constant_enum_class = Constants - _preview_outputs = ["outputRaw", "outputCorrected"] + _calcat_friend_class = Epix100CalcatFriend + + _image_data_path = "data.image.pixels" _cell_table_path = None _pulse_table_path = None - _warn_memory_cell_range = False - @staticmethod - def expectedParameters(expected): + _preview_outputs = [ + PreviewSpec( + name, + dimensions=2, + frame_reduction=False, + ) + for name in ("raw", "corrected") + ] + _warn_memory_cell_range = False + _use_shmem_handles = False + + @classmethod + def expectedParameters(cls, expected): + cls._calcat_friend_class.add_schema(expected) + cls._kernel_runner_class.add_schema(expected) + base_correction.add_addon_nodes(expected, cls) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( OUTPUT_CHANNEL(expected) .key("dataOutput") - 
.dataSchema(Epix100Correction._base_output_schema) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsX") - .setNewDefaultValue(708) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(768) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(1) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.outputAxisOrder") - .setNewOptions("xy,yx") - .setNewDefaultValue("xy") + .dataSchema( + cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles) + ) .commit(), # TODO: disable preview selection mode OVERWRITE_ELEMENT(expected) .key("useShmemHandles") - .setNewDefaultValue(False) + .setNewDefaultValue(cls._use_shmem_handles) .commit(), OVERWRITE_ELEMENT(expected) @@ -360,112 +391,11 @@ class Epix100Correction(base_correction.BaseCorrection): .commit(), ) - base_correction.add_preview_outputs( - expected, Epix100Correction._preview_outputs - ) - base_correction.add_correction_step_schema( - expected, - Epix100Correction._correction_steps, - ) - ( - DOUBLE_ELEMENT(expected) - .key("corrections.commonMode.noiseSigma") - .tags("managed") - .assignmentOptional() - .defaultValue(5) - .reconfigurable() - .commit(), - - DOUBLE_ELEMENT(expected) - .key("corrections.commonMode.minFrac") - .tags("managed") - .assignmentOptional() - .defaultValue(0.25) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(expected) - .key("corrections.commonMode.enableRow") - .tags("managed") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(expected) - .key("corrections.commonMode.enableCol") - .tags("managed") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(expected) - .key("corrections.commonMode.enableBlock") - .tags("managed") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - ) - Epix100CalcatFriend.add_schema(expected) - # TODO: bad pixel node? 
- - @property - def input_data_shape(self): - # TODO: check + def _get_data_from_hash(self, data_hash): + image_data = data_hash.get(self._image_data_path) return ( - self.unsafe_get("dataFormat.pixelsX"), - self.unsafe_get("dataFormat.pixelsY"), - ) - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, # will be None - pulse_table, # ditto - ): - self.kernel_runner.load_data(image_data) - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - args_which_should_be_cached = dict( - cm_noise_sigma=self.unsafe_get("corrections.commonMode.noiseSigma"), - cm_min_frac=self.unsafe_get("corrections.commonMode.minFrac"), - cm_row=self.unsafe_get("corrections.commonMode.enableRow"), - cm_col=self.unsafe_get("corrections.commonMode.enableCol"), - cm_block=self.unsafe_get("corrections.commonMode.enableBlock"), + 1, + image_data, + None, + None, ) - self.kernel_runner.correct( - flags=self._correction_flag_enabled, **args_which_should_be_cached - ) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct( - flags=self._correction_flag_preview, - **args_which_should_be_cached, - ) - - if self._use_shmem_handles: - # TODO: consider better name for key... 
- data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - self._write_output(data_hash, metadata) - - # note: base class preview machinery assumes burst mode, shortcut it - self._preview_friend.write_outputs( - metadata, *self.kernel_runner.preview_data_views - ) - - def _load_constant_to_runner(self, constant, constant_data): - self.kernel_runner.load_constant(constant, constant_data) diff --git a/src/calng/corrections/Gotthard2Correction.py b/src/calng/corrections/Gotthard2Correction.py index aca011ccdbdb16eaad0b1c747061164304968186..f8fa85bb6aa2684a9f33194bf7bdcbc77fd034ca 100644 --- a/src/calng/corrections/Gotthard2Correction.py +++ b/src/calng/corrections/Gotthard2Correction.py @@ -6,17 +6,18 @@ from karabo.bound import ( KARABO_CLASSINFO, OUTPUT_CHANNEL, OVERWRITE_ELEMENT, - Hash, - Timestamp, ) -from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils +from .. 
import ( + base_calcat, + base_correction, + base_kernel_runner, + schemas, +) +from ..preview_utils import PreviewFriend, FrameSelectionMode, PreviewSpec from .._version import version as deviceVersion -_pretend_pulse_table = np.arange(2720, dtype=np.uint8) - - class Constants(enum.Enum): LUTGotthard2 = enum.auto() OffsetGotthard2 = enum.auto() @@ -25,7 +26,7 @@ class Constants(enum.Enum): BadPixelsFFGotthard2 = enum.auto() -bp_constant_types = { +bad_pixel_constants = { Constants.BadPixelsDarkGotthard2, Constants.BadPixelsFFGotthard2, } @@ -39,106 +40,130 @@ class CorrectionFlags(enum.IntFlag): BPMASK = 8 -class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner): - _gpu_based = False - _xp = np +correction_steps = ( + ( + "lut", + CorrectionFlags.LUT, + {Constants.LUTGotthard2}, + ), + ( + "offset", + CorrectionFlags.OFFSET, + {Constants.OffsetGotthard2}, + ), + ( + "gain", + CorrectionFlags.GAIN, + {Constants.RelativeGainGotthard2}, + ), + ( + "badPixels", + CorrectionFlags.BPMASK, + { + Constants.BadPixelsDarkGotthard2, + Constants.BadPixelsFFGotthard2, + } + ), +) - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - bad_pixel_mask_value=np.nan, - bad_pixel_subset=None, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - from ..kernels import gotthard2_cython +def framesums_reduction_fun(data, cell_table, pulse_table, warn_fun): + return np.nansum(data, axis=1) - self.correction_kernel = gotthard2_cython.correct - self.input_shape = (frames, pixels_x) - self.processed_shape = self.input_shape + +class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner): + _bad_pixel_constants = bad_pixel_constants + _correction_steps = correction_steps + _correction_flag_class = CorrectionFlags + _xp = np + + num_pixels_ss = None # 1d detector, should never access this attribute + num_pixels_fs = 1280 + + @classmethod + def add_schema(cls, schema): 
+ super(cls, cls).add_schema(schema) + super(cls, cls).add_bad_pixel_config(schema) + + def expected_input_shape(self, num_frames): + return (num_frames, self.num_pixels_fs) + + def _expected_output_shape(self, num_frames): + return (num_frames, self.num_pixels_fs) + + def _make_output_buffers(self, num_frames, flags): + return [ + self._xp.empty( + (num_frames, self.num_pixels_fs), dtype=np.float32) + ] + + def _preview_data_views(self, raw_data, raw_gain, processed_data): + return [ + raw_data, # raw 1d + raw_gain, # gain 1d + processed_data, # corrected 1d + processed_data, # frame sums + raw_data, # raw streak (2d) + raw_gain, # gain streak + processed_data, # corrected streak + ] + + def _setup_constant_buffers(self): # model: 2 buffers (corresponding to actual memory cells), 2720 frames # lut maps from uint12 to uint10 values - self.lut_shape = (2, 4096, pixels_x) - self.map_shape = (3, self.constant_memory_cells, self.pixels_x) + # TODO: override superclass properties instead + self.lut_shape = (2, 4096, self.num_pixels_fs) + self.map_shape = (3, self._constant_memory_cells, self.num_pixels_fs) self.lut = np.empty(self.lut_shape, dtype=np.uint16) self.offset_map = np.empty(self.map_shape, dtype=np.float32) self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32) self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32) - self.bad_pixel_mask_value = bad_pixel_mask_value - self.flush_buffers() - self._bad_pixel_subset = bad_pixel_subset + self.flush_buffers(set(Constants)) - self.input_data = None # will just point to data coming in - self.input_gain_stage = None # will just point to data coming in - self.processed_data = None # will just point to buffer we're given - - @property - def preview_data_views(self): - return (self.input_data, self.input_gain_stage, self.processed_data) + def _post_init(self): + from ..kernels import gotthard2_cython + self.correction_kernel = gotthard2_cython.correct - def load_constant(self, constant_type, data): + def 
_load_constant(self, constant_type, data): if constant_type is Constants.LUTGotthard2: self.lut[:] = np.transpose(data.astype(np.uint16, copy=False), (1, 2, 0)) elif constant_type is Constants.OffsetGotthard2: self.offset_map[:] = np.transpose(data.astype(np.float32, copy=False)) elif constant_type is Constants.RelativeGainGotthard2: self.rel_gain_map[:] = np.transpose(data.astype(np.float32, copy=False)) - elif constant_type in bp_constant_types: + elif constant_type in bad_pixel_constants: # TODO: add the regular bad pixel subset configuration - data = data.astype(np.uint32, copy=False) - if self._bad_pixel_subset is not None: - data = data & self._bad_pixel_subset - self.bad_pixel_map |= np.transpose(data) + self.bad_pixel_map |= (np.transpose(data) & self.bad_pixel_subset) else: raise ValueError(f"What is this constant '{constant_type}'?") - def set_bad_pixel_subset(self, subset, apply_now=True): - self._bad_pixel_subset = subset - if apply_now: - self.bad_pixel_map &= self._bad_pixel_subset - - def load_data(self, image_data, input_gain_stage): - """Experiment: loading both in one function as they are tied""" - self.input_data = image_data.astype(np.uint16, copy=False) - self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False) - - def flush_buffers(self, constants=None): - # for now, always flushing all here... 
- default_lut = ( - np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12 - ).astype(np.uint16) - self.lut[:] = np.stack([np.stack([default_lut] * 2)] * self.pixels_x, axis=2) - self.offset_map.fill(0) - self.rel_gain_map.fill(1) - self.bad_pixel_map.fill(0) - - def correct(self, flags, out=None): - if out is None: - out = np.empty(self.processed_shape, dtype=np.float32) - + def flush_buffers(self, constants): + if Constants.LUTGotthard2 in constants: + default_lut = ( + np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12 + ).astype(np.uint16) + self.lut[:] = np.stack( + [np.stack([default_lut] * 2)] * self.num_pixels_fs, axis=2 + ) + if Constants.OffsetGotthard2 in constants: + self.offset_map.fill(0) + if Constants.RelativeGainGotthard2 in constants: + self.rel_gain_map.fill(1) + if bad_pixel_constants & constants: + self.bad_pixel_map.fill(0) + + def _correct(self, flags, raw_data, cell_table, gain_stages, processed_data): self.correction_kernel( - self.input_data, - self.input_gain_stage, - np.uint8(flags), + raw_data, + gain_stages, + flags, self.lut, self.offset_map, self.rel_gain_map, self.bad_pixel_map, self.bad_pixel_mask_value, - out, + processed_data, ) - self.processed_data = out - return out class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend): @@ -227,236 +252,109 @@ class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("Gotthard2Correction", deviceVersion) class Gotthard2Correction(base_correction.BaseCorrection): - _base_output_schema = schemas.jf_output_schema(use_shmem_handle=False) - _correction_flag_class = CorrectionFlags - _correction_steps = ( - ( - "lut", - CorrectionFlags.LUT, - {Constants.LUTGotthard2}, - ), - ( - "offset", - CorrectionFlags.OFFSET, - {Constants.OffsetGotthard2}, - ), - ( - "gain", - CorrectionFlags.GAIN, - {Constants.RelativeGainGotthard2}, - ), - ( - "badPixels", - CorrectionFlags.BPMASK, - { - Constants.BadPixelsDarkGotthard2, - Constants.BadPixelsFFGotthard2, - } - ), - ) + 
_base_output_schema = schemas.jf_output_schema _kernel_runner_class = Gotthard2CpuRunner _calcat_friend_class = Gotthard2CalcatFriend + _correction_steps = correction_steps _constant_enum_class = Constants + _image_data_path = "data.adc" - _cell_table_path = "data.memoryCell" + _cell_table_path = None _pulse_table_path = None + _warn_memory_cell_range = False # for now, receiver always writes 255 - _preview_outputs = ["outputStreak", "outputGainStreak"] _cuda_pin_buffers = False + _use_shmem_handles = False + + _preview_outputs = [ + PreviewSpec( + "raw", + dimensions=1, + frame_reduction=True, + wrap_in_imagedata=False, + valid_frame_selection_modes=[FrameSelectionMode.FRAME], + ), + PreviewSpec( + "gain", + dimensions=1, + frame_reduction=True, + wrap_in_imagedata=False, + valid_frame_selection_modes=[FrameSelectionMode.FRAME], + ), + PreviewSpec( + "corrected", + dimensions=1, + frame_reduction=True, + wrap_in_imagedata=False, + valid_frame_selection_modes=[FrameSelectionMode.FRAME], + ), + PreviewSpec( + "framesums", + dimensions=1, + frame_reduction=False, + frame_reduction_fun=framesums_reduction_fun, + wrap_in_imagedata=False, + ), + PreviewSpec( + "rawStreak", + dimensions=2, + frame_reduction=False, + wrap_in_imagedata=True, + ), + PreviewSpec( + "gainStreak", + dimensions=2, + frame_reduction=False, + wrap_in_imagedata=True, + ), + PreviewSpec( + "correctedStreak", + dimensions=2, + frame_reduction=False, + wrap_in_imagedata=True, + ), + ] - @staticmethod - def expectedParameters(expected): - super(Gotthard2Correction, Gotthard2Correction).expectedParameters(expected) + @classmethod + def expectedParameters(cls, expected): + cls._calcat_friend_class.add_schema(expected) + cls._kernel_runner_class.add_schema(expected) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( OUTPUT_CHANNEL(expected) .key("dataOutput") - .dataSchema(Gotthard2Correction._base_output_schema) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsX") - 
.setNewDefaultValue(1280) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(1) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(2720) # note: actually just frames... - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("preview.selectionMode") - .setNewDefaultValue("frame") + .dataSchema( + cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles) + ) .commit(), OVERWRITE_ELEMENT(expected) .key("outputShmemBufferSize") .setNewDefaultValue(2) .commit(), - ) - - base_correction.add_preview_outputs( - expected, Gotthard2Correction._preview_outputs - ) - for channel in ("outputRaw", "outputGain", "outputCorrected", "outputFrameSums"): - # add these "manually" as the automated bits wrap ImageData - ( - OUTPUT_CHANNEL(expected) - .key(f"preview.{channel}") - .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=False)) - .commit(), - ) - base_correction.add_correction_step_schema( - expected, - Gotthard2Correction._correction_steps, - ) - base_correction.add_bad_pixel_config_node(expected) - Gotthard2CalcatFriend.add_schema(expected) - @property - def input_data_shape(self): - return ( - self.unsafe_get("dataFormat.frames"), - self.unsafe_get("dataFormat.pixelsX"), + OVERWRITE_ELEMENT(expected) + .key("useShmemHandles") + .setNewDefaultValue(cls._use_shmem_handles) + .commit(), ) - @property - def output_data_shape(self): + def _get_data_from_hash(self, data_hash): + image_data = np.asarray( + data_hash.get(self._image_data_path) + ).astype(np.uint16, copy=False) + gain_data = np.asarray( + data_hash.get("data.gain") + ).astype(np.uint8, copy=False) + num_frames = image_data.shape[0] + if self.unsafe_get("workarounds.overrideInputAxisOrder"): + expected_shape = self.kernel_runner.expected_input_shape(num_frames) + image_data.shape = expected_shape + gain_data.shape = expected_shape return ( - self.unsafe_get("dataFormat.frames"), - 
self.unsafe_get("dataFormat.pixelsX"), + num_frames, + image_data, + None, # no cell table + None, # no pulse table + gain_data, ) - - @property - def _kernel_runner_init_args(self): - return { - "bad_pixel_mask_value": self._bad_pixel_mask_value, - "bad_pixel_subset": self._bad_pixel_subset, - } - - @property - def _bad_pixel_mask_value(self): - return np.float32(self.unsafe_get("corrections.badPixels.maskingValue")) - - _bad_pixel_subset = property(base_correction.get_bad_pixel_field_selection) - - def postReconfigure(self): - super().postReconfigure() - update = self._prereconfigure_update_hash - if update.has("corrections.badPixels.subsetToUse"): - if any( - update.get( - f"corrections.badPixels.subsetToUse.{field.name}", default=False - ) - for field in utils.BadPixelValues - ): - self.log_status_info( - "Some fields reenabled, reloading cached bad pixel constants" - ) - self.kernel_runner.set_bad_pixel_subset( - self._bad_pixel_subset, apply_now=False - ) - with self.calcat_friend.cached_constants_lock: - self.kernel_runner.flush_buffers(bp_constant_types) - for constant_type in ( - bp_constant_types & self.calcat_friend.cached_constants.keys() - ): - self._load_constant_to_runner( - constant_type, - self.calcat_friend.cached_constants[constant_type], - ) - else: - # just narrowing the subset - no reload, just AND - self.kernel_runner.set_bad_pixel_subset( - self._bad_pixel_subset, apply_now=True - ) - if update.has("corrections.badPixels.maskingvalue"): - self.kernel_runner.bad_pixel_mask_value = self._bad_pixel_mask_value - - def __init__(self, config): - super().__init__(config) - try: - np.float32(config.get("corrections.badPixels.maskingValue")) - except ValueError: - config["corrections.badPixels.maskingValue"] = "nan" - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - pulse_table, - ): - # cell table currently not used for GOTTHARD2 (assume alternating) - gain_map = 
np.asarray(data_hash.get("data.gain")) - if self.unsafe_get("workarounds.overrideInputAxisOrder"): - gain_map.shape = self.input_data_shape - try: - self.kernel_runner.load_data(image_data, gain_map) - except Exception as e: - self.log_status_warn(f"Unknown exception when loading data: {e}") - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.kernel_runner.correct(self._correction_flag_enabled, out=buffer_array) - - with self.warning_context( - "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS - ) as warn: - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ), preview_warning = utils.pick_frame_index( - self.unsafe_get("preview.selectionMode"), - self.unsafe_get("preview.index"), - cell_table, - _pretend_pulse_table, - ) - if preview_warning is not None: - warn(preview_warning) - ( - preview_raw, - preview_gain, - preview_corrected, - ) = self.kernel_runner.compute_previews(preview_slice_index) - - if self._use_shmem_handles: - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - self._write_output(data_hash, metadata) - - frame_sums = np.nansum(buffer_array, axis=1) - timestamp = Timestamp.fromHashAttributes(metadata.getAttributes("timestamp")) - for channel, data in ( - ("outputRaw", preview_raw), - ("outputGain", preview_gain), - ("outputCorrected", preview_corrected), - ("outputFrameSums", frame_sums), - ): - self.writeChannel( - f"preview.{channel}", - Hash( - "image.data", - data, - "image.mask", - (~np.isfinite(data)).astype(np.uint8), - ), - timestamp=timestamp, - ) - self._preview_friend.write_outputs(metadata, buffer_array, gain_map) - - def _load_constant_to_runner(self, constant, constant_data): - 
self.kernel_runner.load_constant(constant, constant_data) diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py index ab5af5f688fa74c32ab02195fffde302cf305ef9..c51cbe0865be0c9da5034f11785a70f219edf754 100644 --- a/src/calng/corrections/JungfrauCorrection.py +++ b/src/calng/corrections/JungfrauCorrection.py @@ -9,7 +9,6 @@ from karabo.bound import ( OUTPUT_CHANNEL, OVERWRITE_ELEMENT, STRING_ELEMENT, - Schema, ) from .. import ( @@ -19,12 +18,10 @@ from .. import ( schemas, utils, ) +from ..preview_utils import PreviewFriend, FrameSelectionMode, PreviewSpec from .._version import version as deviceVersion -_pretend_pulse_table = np.arange(16, dtype=np.uint8) - - class Constants(enum.Enum): Offset10Hz = enum.auto() BadPixelsDark10Hz = enum.auto() @@ -57,66 +54,95 @@ class CorrectionFlags(enum.IntFlag): STRIXEL = 8 -class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): - _corrected_axis_order = "fyx" - _xp = None # subclass sets CuPy (import at runtime) or numpy +correction_steps = ( + ("offset", CorrectionFlags.OFFSET, {Constants.Offset10Hz}), + ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain10Hz}), + ( + "badPixels", + CorrectionFlags.BPMASK, + { + Constants.BadPixelsDark10Hz, + Constants.BadPixelsFF10Hz, + None, + }, + ), + ("strixel", CorrectionFlags.STRIXEL, set()), +) - @property - def input_shape(self): - return self.frames, self.pixels_y, self.pixels_x - @property - def processed_shape(self): - return self.input_shape +class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): + _bad_pixel_constants = bad_pixel_constants + _correction_flag_class = CorrectionFlags + _correction_steps = correction_steps + num_pixels_ss = 512 + num_pixels_fs = 1024 @property - def map_shape(self): - return (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3) - - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - config, - output_data_dtype=np.float32, - 
bad_pixel_mask_value=np.nan, - ): - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) + def _map_shape(self): + return (self._constant_memory_cells, self.num_pixels_ss, self.num_pixels_fs, 3) + + def _setup_constant_buffers(self): # note: superclass creates cell table with wrong dtype - self.input_gain_stage = self._xp.empty(self.input_shape, dtype=np.uint8) - self.offset_map = self._xp.empty(self.map_shape, dtype=np.float32) - self.rel_gain_map = self._xp.empty(self.map_shape, dtype=np.float32) - self.bad_pixel_map = self._xp.empty(self.map_shape, dtype=np.uint32) - self.bad_pixel_mask_value = bad_pixel_mask_value - - self.output_dtype = output_data_dtype - self._processed_data_regular = self._xp.empty( - self.processed_shape, dtype=output_data_dtype - ) - self._processed_data_strixel = None + self.offset_map = self._xp.empty(self._map_shape, dtype=np.float32) + self.rel_gain_map = self._xp.empty(self._map_shape, dtype=np.float32) + self.bad_pixel_map = self._xp.empty(self._map_shape, dtype=np.uint32) self.flush_buffers(set(Constants)) - self.correction_kernel_strixel = None - self.reconfigure(config) + + def _make_output_buffers(self, num_frames, flags): + if flags & CorrectionFlags.STRIXEL: + return [ + self._xp.empty( + (num_frames,) + self._strixel_out_shape, dtype=self._output_dtype + ) + ] + else: + return [ + self._xp.empty( + (num_frames, self.num_pixels_ss, self.num_pixels_fs), + dtype=self._output_dtype, + ) + ] + + def _expected_output_shape(self, num_frames): + # TODO: think of how to unify with _make_output_buffers + if self._correction_flag_enabled & CorrectionFlags.STRIXEL: + return (num_frames,) + self._strixel_out_shape + else: + return (num_frames, self.num_pixels_ss, self.num_pixels_fs) + + @classmethod + def add_schema(cls, schema): + super(cls, cls).add_schema(schema) + super(cls, cls).add_bad_pixel_config(schema) + ( + OVERWRITE_ELEMENT(schema) + .key("corrections.strixel.enable") + 
.setNewDefaultValue(False) + .commit(), + + OVERWRITE_ELEMENT(schema) + .key("corrections.strixel.preview") + .setNewDefaultValue(False) + .commit(), + + STRING_ELEMENT(schema) + .key("corrections.strixel.type") + .description( + "Which kind of strixel layout is used for this module? cols_A0123 is " + "the first strixel layout deployed at HED and rows_A1256 is the one " + "later deployed at SCS." + ) + .assignmentOptional() + .defaultValue("cols_A0123(HED-type)") + .options("cols_A0123(HED-type),rows_A1256(SCS-type)") + .reconfigurable() + .commit(), + ) def reconfigure(self, config): - # note: regular bad pixel masking uses device property (TODO: unify) - if (mask_value := config.get("badPixels.maskingValue")) is not None: - self._bad_pixel_mask_value = self._xp.float32(mask_value) - # this is a functools.partial, can just update the captured parameter - if self.correction_kernel_strixel is not None: - self.correction_kernel_strixel.keywords[ - "missing" - ] = self._bad_pixel_mask_value - - if (strixel_type := config.get("strixel.type")) is not None: + super().reconfigure(config) + + if (strixel_type := config.get("corrections.strixel.type")) is not None: # drop the friendly parenthesized name strixel_type = strixel_type.partition("(")[0] strixel_package = np.load( @@ -124,18 +150,17 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): / f"strixel_{strixel_type}-lut_mask.npz" ) self._strixel_out_shape = tuple(strixel_package["frame_shape"]) - self._processed_data_strixel = None # TODO: use bad pixel masking config here self.correction_kernel_strixel = functools.partial( utils.apply_partial_lut, lut=self._xp.asarray(strixel_package["lut"]), mask=self._xp.asarray(strixel_package["mask"]), - missing=self._bad_pixel_mask_value, + missing=self.bad_pixel_mask_value, ) # note: always masking unmapped pixels (not respecting NON_STANDARD_SIZE) - def load_constant(self, constant, constant_data): - if constant_data.shape[0] == self.pixels_x: + def 
_load_constant(self, constant, constant_data): + if constant_data.shape[0] == self.num_pixels_fs: constant_data = np.transpose(constant_data, (2, 1, 0, 3)) else: constant_data = np.transpose(constant_data, (2, 0, 1, 3)) @@ -145,17 +170,14 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): elif constant is Constants.RelativeGain10Hz: self.rel_gain_map[:] = self._xp.asarray(constant_data, dtype=np.float32) elif constant in bad_pixel_constants: - self.bad_pixel_map |= self._xp.asarray(constant_data, dtype=np.uint32) + self.bad_pixel_map |= self._xp.asarray( + constant_data, dtype=np.uint32 + ) & self._xp.uint32(self.bad_pixel_subset) else: raise ValueError(f"Unexpected constant type {constant}") - @property - def preview_data_views(self): - return (self.input_data, self.processed_data, self.input_gain_stage) - - @property - def burst_mode(self): - return self.frames > 1 + def _preview_data_views(self, raw_data, gain_map, processed_data): + return (raw_data, gain_map, processed_data) def flush_buffers(self, constants): if Constants.Offset10Hz in constants: @@ -167,106 +189,80 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): Constants.BadPixelsFF10Hz, }: self.bad_pixel_map.fill(0) - self.bad_pixel_map[ - :, :, 255:1023:256 - ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value - self.bad_pixel_map[ - :, :, 256:1024:256 - ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value - self.bad_pixel_map[ - :, [255, 256] - ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + if self.bad_pixel_subset & utils.BadPixelValues.NON_STANDARD_SIZE: + self._mask_asic_seams() - def override_bad_pixel_flags_to_use(self, override_value): - self.bad_pixel_map &= self._xp.uint32(override_value) + def _mask_asic_seams(self): + self.bad_pixel_map[ + :, :, 255:1023:256 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map[ + :, :, 256:1024:256 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map[ + :, [255, 256] + ] |= 
utils.BadPixelValues.NON_STANDARD_SIZE.value class JungfrauGpuRunner(JungfrauBaseRunner): - _gpu_based = True - def _pre_init(self): import cupy as cp self._xp = cp def _post_init(self): - self.input_data = self._xp.empty(self.input_shape, dtype=np.uint16) - self.cell_table = self._xp.empty(self.frames, dtype=np.uint8) - self.block_shape = (1, 1, 64) - self.grid_shape = utils.grid_to_cover_shape_with_blocks( - self.input_shape, self.block_shape - ) - source_module = self._xp.RawModule( code=base_kernel_runner.get_kernel_template("jungfrau_gpu.cu").render( { - "pixels_x": self.pixels_x, - "pixels_y": self.pixels_y, - "frames": self.frames, - "constant_memory_cells": self.constant_memory_cells, "output_data_dtype": utils.np_dtype_to_c_type( - self.output_data_dtype + self._output_dtype ), "corr_enum": utils.enum_to_c_template(CorrectionFlags), - "burst_mode": self.burst_mode, } ) ) self.correction_kernel = source_module.get_function("correct") - def load_data(self, image_data, input_gain_stage, cell_table): - self.input_data.set(image_data) - self.input_gain_stage.set(input_gain_stage) - if self.burst_mode: - self.cell_table.set(cell_table) - - def correct(self, flags): + def _correct(self, flags, image_data, cell_table, gain_map, output): + num_frames = image_data.shape[0] + block_shape = (1, 1, 64) + grid_shape = utils.grid_to_cover_shape_with_blocks( + output.shape, block_shape + ) + if flags & CorrectionFlags.STRIXEL: + # do "regular" correction into "regular" buffer, strixel transform to output + final_output = output + output = self._xp.empty_like(image_data, dtype=np.float32) self.correction_kernel( - self.grid_shape, - self.block_shape, + grid_shape, + block_shape, ( - self.input_data, - self.input_gain_stage, - self.cell_table, + image_data, + gain_map, + cell_table, self._xp.uint8(flags), + self._xp.uint16(num_frames), + self._xp.uint16(self._constant_memory_cells), + self._xp.uint8(num_frames > 1), self.offset_map, self.rel_gain_map, self.bad_pixel_map, 
self.bad_pixel_mask_value, - self._processed_data_regular, + output, ), ) if flags & CorrectionFlags.STRIXEL: - if ( - self._processed_data_strixel is None - or self.frames != self._processed_data_strixel.shape[0] - ): - self._processed_data_strixel = self._xp.empty( - (self.frames,) + self._strixel_out_shape, - dtype=self.output_dtype, - ) - for pixel_frame, strixel_frame in zip( - self._processed_data_regular, self._processed_data_strixel - ): + for pixel_frame, strixel_frame in zip(output, final_output): self.correction_kernel_strixel( data=pixel_frame, out=strixel_frame, ) - self.processed_data = self._processed_data_strixel - else: - self.processed_data = self._processed_data_regular class JungfrauCpuRunner(JungfrauBaseRunner): - _gpu_based = False _xp = np - def _post_init(self): - # not actually allocating, will just point to incoming data - self.input_data = None - self.input_gain_stage = None - self.input_cell_table = None - + def _pre_init(self): # for computing previews faster self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16) @@ -278,45 +274,39 @@ class JungfrauCpuRunner(JungfrauBaseRunner): def __del__(self): self.thread_pool.shutdown() - def load_data(self, image_data, input_gain_stage, cell_table): - self.input_data = image_data.astype(np.uint16, copy=False) - self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False) - self.input_cell_table = cell_table.astype(np.uint8, copy=False) + def _correct(self, flags, image_data, cell_table, gain_map, output): + if flags & CorrectionFlags.STRIXEL: + # do "regular" correction into "regular" buffer, strixel transform to output + final_output = output + output = self._xp.empty_like(image_data, dtype=np.float32) - def correct(self, flags): - if self.burst_mode: + # burst mode + if image_data.shape[0] > 1: self.correction_kernel_burst( - self.input_data, - self.input_gain_stage, - self.input_cell_table, + image_data, + gain_map, + cell_table, flags, self.offset_map, 
self.rel_gain_map, self.bad_pixel_map, self.bad_pixel_mask_value, - self._processed_data_regular, + output, ) else: self.correction_kernel_single( - self.input_data, - self.input_gain_stage, + image_data, + gain_map, flags, self.offset_map, self.rel_gain_map, self.bad_pixel_map, self.bad_pixel_mask_value, - self._processed_data_regular, + output, ) if flags & CorrectionFlags.STRIXEL: - if ( - self._processed_data_strixel is None - or self.frames != self._processed_data_strixel.shape[0] - ): - self._processed_data_strixel = self._xp.empty( - (self.frames,) + self._strixel_out_shape, - dtype=self.output_dtype, - ) + # now do strixel transformation into actual output buffer concurrent.futures.wait( [ self.thread_pool.submit( @@ -324,40 +314,9 @@ class JungfrauCpuRunner(JungfrauBaseRunner): data=pixel_frame, out=strixel_frame, ) - for pixel_frame, strixel_frame in zip( - self._processed_data_regular, self._processed_data_strixel - ) + for pixel_frame, strixel_frame in zip(output, final_output) ] ) - self.processed_data = self._processed_data_strixel - else: - self.processed_data = self._processed_data_regular - - def compute_previews(self, preview_index): - if preview_index < -4: - raise ValueError(f"No statistic with code {preview_index} defined") - elif preview_index >= self.frames: - raise ValueError(f"Memory cell index {preview_index} out of range") - - if preview_index >= 0: - - def fun(a): - return a[preview_index] - - elif preview_index == -1: - # note: separate from next case because dtype not applicable here - fun = functools.partial(np.nanmax, axis=0) - elif preview_index in (-2, -3, -4): - fun = functools.partial( - { - -2: np.nanmean, - -3: np.nansum, - -4: np.nanstd, - }[preview_index], - axis=0, - dtype=np.float32, - ) - return self.thread_pool.map(fun, self.preview_data_views) class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend): @@ -475,59 +434,40 @@ class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("JungfrauCorrection", 
deviceVersion) class JungfrauCorrection(base_correction.BaseCorrection): - _base_output_schema = schemas.jf_output_schema() - _correction_flag_class = CorrectionFlags - _correction_steps = ( - ("offset", CorrectionFlags.OFFSET, {Constants.Offset10Hz}), - ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain10Hz}), - ( - "badPixels", - CorrectionFlags.BPMASK, - { - Constants.BadPixelsDark10Hz, - Constants.BadPixelsFF10Hz, - None, - }, - ), - ("strixel", CorrectionFlags.STRIXEL, set()), - ) + _base_output_schema = schemas.jf_output_schema + _correction_steps = correction_steps _calcat_friend_class = JungfrauCalcatFriend _constant_enum_class = Constants - _preview_outputs = [ - "outputRaw", - "outputCorrected", - "outputGainMap", - ] + _image_data_path = "data.adc" _cell_table_path = "data.memoryCell" _pulse_table_path = None - @staticmethod - def expectedParameters(expected): + _preview_outputs = [ + PreviewSpec( + name, + dimensions=2, + frame_reduction=True, + valid_frame_selection_modes=( + FrameSelectionMode.FRAME, + FrameSelectionMode.CELL, + ), + ) for name in ("raw", "gain", "corrected") + ] + _warn_memory_cell_range = False + _use_shmem_handles = False + + @classmethod + def expectedParameters(cls, expected): + cls._calcat_friend_class.add_schema(expected) + JungfrauBaseRunner.add_schema(expected) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( OUTPUT_CHANNEL(expected) .key("dataOutput") - .dataSchema(JungfrauCorrection._base_output_schema) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsX") - .setNewDefaultValue(1024) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(512) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(1) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("preview.selectionMode") - .setNewDefaultValue("frame") + .dataSchema( + cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles) + ) .commit(), 
# JUNGFRAU data is small, can fit plenty of trains in here @@ -535,6 +475,11 @@ class JungfrauCorrection(base_correction.BaseCorrection): .key("outputShmemBufferSize") .setNewDefaultValue(2) .commit(), + + OVERWRITE_ELEMENT(expected) + .key("useShmemHandles") + .setNewDefaultValue(cls._use_shmem_handles) + .commit(), ) ( # support both CPU and GPU kernels @@ -553,53 +498,30 @@ class JungfrauCorrection(base_correction.BaseCorrection): .commit(), ) - base_correction.add_preview_outputs( - expected, JungfrauCorrection._preview_outputs - ) - base_correction.add_correction_step_schema( - expected, - JungfrauCorrection._correction_steps, - ) - ( - OVERWRITE_ELEMENT(expected) - .key("corrections.strixel.enable") - .setNewDefaultValue(False) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("corrections.strixel.preview") - .setNewDefaultValue(False) - .commit(), - - STRING_ELEMENT(expected) - .key("corrections.strixel.type") - .description( - "Which kind of strixel layout is used for this module? cols_A0123 is " - "the first strixel layout deployed at HED and rows_A1256 is the one " - "later deployed at SCS." 
+ def _get_data_from_hash(self, data_hash): + image_data = data_hash.get(self._image_data_path) + gain_data = np.asarray(data_hash.get("data.gain")).astype(np.uint8, copy=False) + # explicit np.asarray because types are special here + cell_table = np.asarray( + data_hash[self._cell_table_path] + ).astype(np.uint8, copy=False) + num_frames = cell_table.size + if self.unsafe_get("workarounds.overrideInputAxisOrder"): + expected_shape = self.kernel_runner.expected_input_shape( + num_frames ) - .assignmentOptional() - .defaultValue("cols_A0123(HED-type)") - .options("cols_A0123(HED-type),rows_A1256(SCS-type)") - .reconfigurable() - .commit(), - ) - base_correction.add_bad_pixel_config_node(expected) - JungfrauCalcatFriend.add_schema(expected) + if expected_shape != image_data.shape: + image_data.shape = expected_shape + gain_data.shape = expected_shape - @property - def input_data_shape(self): return ( - self.unsafe_get("dataFormat.frames"), - self.unsafe_get("dataFormat.pixelsY"), - self.unsafe_get("dataFormat.pixelsX"), + num_frames, + image_data, # not as ndarray as runner can do that + cell_table, + None, + gain_data, ) - @property - def _warn_memory_cell_range(self): - # disable warning in normal operation as cell 15 is expected - return self.unsafe_get("dataFormat.frames") > 1 - @property def output_data_shape(self): if self._correction_flag_enabled & CorrectionFlags.STRIXEL: @@ -623,134 +545,3 @@ class JungfrauCorrection(base_correction.BaseCorrection): return JungfrauCpuRunner else: return JungfrauGpuRunner - - @property - def _kernel_runner_init_args(self): - return { - "bad_pixel_mask_value": self.bad_pixel_mask_value, - # temporary: will refactor base class to always pass config node - "config": self.get("corrections"), - } - - @property - def bad_pixel_mask_value(self): - return np.float32(self.unsafe_get("corrections.badPixels.maskingValue")) - - _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection) - - def __init__(self, 
config): - super().__init__(config) - try: - np.float32(config.get("corrections.badPixels.maskingValue", default="nan")) - except ValueError: - config["corrections.badPixels.maskingValue"] = "nan" - - if config.get("useShmemHandles", default=True): - schema_override = Schema() - ( - OUTPUT_CHANNEL(schema_override) - .key("dataOutput") - .dataSchema(schemas.jf_output_schema(use_shmem_handle=False)) - .commit(), - ) - self.updateSchema(schema_override) - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - pulse_table, - ): - if len(cell_table.shape) == 0: - cell_table = cell_table[np.newaxis] - try: - gain_map = data_hash["data.gain"] - if self.unsafe_get("workarounds.overrideInputAxisOrder"): - gain_map.shape = self.input_data_shape - self.kernel_runner.load_data(image_data, gain_map, cell_table) - except ValueError as e: - self.log_status_warn(f"Failed to load data: {e}") - return - except Exception as e: - self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.kernel_runner.correct(self._correction_flag_enabled) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - - with self.warning_context( - "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS - ) as warn: - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ), preview_warning = utils.pick_frame_index( - self.unsafe_get("preview.selectionMode"), - self.unsafe_get("preview.index"), - cell_table, - _pretend_pulse_table, - ) - if preview_warning is not None: - warn(preview_warning) - ( - preview_raw, - preview_corrected, - preview_gain_map, - ) = self.kernel_runner.compute_previews(preview_slice_index) - - # reusing input data hash for sending - if 
self._use_shmem_handles: - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - # TODO: use shmem for data.gain, too - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - self._write_output(data_hash, metadata) - - self._preview_friend.write_outputs( - metadata, preview_raw, preview_corrected, preview_gain_map - ) - - def _load_constant_to_runner(self, constant, constant_data): - if constant in bad_pixel_constants: - constant_data &= self._override_bad_pixel_flags - self.kernel_runner.load_constant(constant, constant_data) - - def preReconfigure(self, config): - super().preReconfigure(config) - if config.has("corrections"): - self.kernel_runner.reconfigure(config["corrections"]) - - def postReconfigure(self): - super().postReconfigure() - - if not hasattr(self, "_prereconfigure_update_hash"): - return - - update = self._prereconfigure_update_hash - - if update.has("corrections.strixel.enable"): - self._lock_and_update(self._update_frame_filter) - - if update.has("corrections.badPixels.subsetToUse"): - self.log_status_info("Updating bad pixel maps based on subset specified") - # note: now just always reloading from cache for convenience - with self.calcat_friend.cached_constants_lock: - self.kernel_runner.flush_buffers(bad_pixel_constants) - for constant in bad_pixel_constants: - if constant in self.calcat_friend.cached_constants: - self._load_constant_to_runner( - constant, self.calcat_friend.cached_constants[constant] - ) diff --git a/src/calng/corrections/LpdCorrection.py b/src/calng/corrections/LpdCorrection.py index 98e81e1e5e97bcc19f0c84dfe9c0f65635815282..325f5b2193de7abe0190b99b8ee26ced09358283 100644 --- a/src/calng/corrections/LpdCorrection.py +++ b/src/calng/corrections/LpdCorrection.py @@ -8,7 +8,6 @@ from karabo.bound import ( OUTPUT_CHANNEL, OVERWRITE_ELEMENT, STRING_ELEMENT, - Schema, ) from .. import ( @@ -18,6 +17,7 @@ from .. 
import ( schemas, utils, ) +from ..preview_utils import PreviewFriend, PreviewSpec from .._version import version as deviceVersion @@ -30,6 +30,12 @@ class Constants(enum.Enum): BadPixelsFF = enum.auto() +bad_pixel_constants = { + Constants.BadPixelsDark, + Constants.BadPixelsFF, +} + + class CorrectionFlags(enum.IntFlag): NONE = 0 OFFSET = 1 @@ -39,130 +45,85 @@ class CorrectionFlags(enum.IntFlag): BPMASK = 16 -class LpdGpuRunner(base_kernel_runner.BaseKernelRunner): - _gpu_based = True - _corrected_axis_order = "fyx" - - @property - def input_shape(self): - return (self.frames, 1, self.pixels_y, self.pixels_x) +correction_steps = ( + ("offset", CorrectionFlags.OFFSET, {Constants.Offset}), + ("gainAmp", CorrectionFlags.GAIN_AMP, {Constants.GainAmpMap}), + ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain}), + ("flatfield", CorrectionFlags.FF_CORR, {Constants.FFMap}), + ("badPixels", CorrectionFlags.BPMASK, bad_pixel_constants), +) - @property - def processed_shape(self): - return (self.frames, self.pixels_y, self.pixels_x) - - def __init__( - self, - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype=np.float32, - bad_pixel_mask_value=np.nan, - ): - import cupy as cp - self._xp = cp - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - self.gain_map = cp.empty(self.processed_shape, dtype=np.float32) - self.input_data = cp.empty(self.input_shape, dtype=np.uint16) - self.processed_data = cp.empty(self.processed_shape, dtype=output_data_dtype) - self.cell_table = cp.empty(frames, dtype=np.uint16) - - self.map_shape = (constant_memory_cells, pixels_y, pixels_x, 3) - self.offset_map = cp.empty(self.map_shape, dtype=np.float32) - self.gain_amp_map = cp.empty(self.map_shape, dtype=np.float32) - self.rel_gain_slopes_map = cp.empty(self.map_shape, dtype=np.float32) - self.flatfield_map = cp.empty(self.map_shape, dtype=np.float32) - self.bad_pixel_map = cp.empty(self.map_shape, 
dtype=np.uint32) - self.bad_pixel_mask_value = bad_pixel_mask_value - self.flush_buffers(set(Constants)) - self.correction_kernel = cp.RawModule( - code=base_kernel_runner.get_kernel_template("lpd_gpu.cu").render( - { - "pixels_x": self.pixels_x, - "pixels_y": self.pixels_y, - "frames": self.frames, - "constant_memory_cells": self.constant_memory_cells, - "output_data_dtype": utils.np_dtype_to_c_type( - self.output_data_dtype - ), - "corr_enum": utils.enum_to_c_template(CorrectionFlags), - } - ) - ).get_function("correct") +class LpdBaseRunner(base_kernel_runner.BaseKernelRunner): + _bad_pixel_constants = bad_pixel_constants + _correction_flag_class = CorrectionFlags + _correction_steps = correction_steps + num_pixels_ss = 256 + num_pixels_fs = 256 - self.block_shape = (1, 1, 64) - self.grid_shape = utils.grid_to_cover_shape_with_blocks( - self.processed_shape, self.block_shape - ) + @classmethod + def add_schema(cls, schema): + super(cls, cls).add_schema(schema) + super(cls, cls).add_bad_pixel_config(schema) - def load_data(self, image_data, cell_table): - self.input_data.set(image_data) - self.cell_table.set(cell_table) + def expected_input_shape(self, num_frames): + return (num_frames, 1, self.num_pixels_ss, self.num_pixels_fs) @property - def preview_data_views(self): + def _gm_map_shape(self): + return self._map_shape + (3,) # for gain-mapped constants + + def _setup_constant_buffers(self): + self.offset_map = self._xp.empty(self._gm_map_shape, dtype=np.float32) + self.gain_amp_map = self._xp.empty(self._gm_map_shape, dtype=np.float32) + self.rel_gain_slopes_map = self._xp.empty(self._gm_map_shape, dtype=np.float32) + self.flatfield_map = self._xp.empty(self._gm_map_shape, dtype=np.float32) + self.bad_pixel_map = self._xp.empty(self._gm_map_shape, dtype=np.uint32) + self.flush_buffers(set(Constants)) + + def _make_output_buffers(self, num_frames, flags): + output_shape = (num_frames, self.num_pixels_ss, self.num_pixels_fs) + return [ + 
self._xp.empty(output_shape, dtype=self._xp.float32), # image + self._xp.empty(output_shape, dtype=self._xp.float32), # gain + ] + + def _preview_data_views(self, raw_data, processed_data, gain_map): # TODO: always split off gain from raw to avoid messing up preview? return ( - self.input_data[:, 0], # raw - self.processed_data, # corrected - self.gain_map, # gain (split from raw) + raw_data[:, 0], # raw + processed_data, # corrected + gain_map, # gain (split from raw) ) - def correct(self, flags): - self.correction_kernel( - self.grid_shape, - self.block_shape, - ( - self.input_data, - self.cell_table, - self._xp.uint8(flags), - self.offset_map, - self.gain_amp_map, - self.rel_gain_slopes_map, - self.flatfield_map, - self.bad_pixel_map, - self.bad_pixel_mask_value, - self.gain_map, - self.processed_data, - ), - ) - - def load_constant(self, constant_type, constant_data): - # constant type → transpose order - bad_pixel_loading = { - Constants.BadPixelsDark: (2, 1, 0, 3), - Constants.BadPixelsFF: (2, 0, 1, 3), - } + def _load_constant(self, constant_type, constant_data): # constant type → transpose order, GPU buffer - other_constant_loading = { + transpose_order, my_buffer = { + Constants.BadPixelsDark: ((2, 1, 0, 3), self.bad_pixel_map), + Constants.BadPixelsFF: ((2, 0, 1, 3), self.bad_pixel_map), Constants.Offset: ((2, 1, 0, 3), self.offset_map), Constants.GainAmpMap: ((2, 0, 1, 3), self.gain_amp_map), Constants.FFMap: ((2, 0, 1, 3), self.flatfield_map), Constants.RelativeGain: ((2, 0, 1, 3), self.rel_gain_slopes_map), - } - if constant_type in bad_pixel_loading: - self.bad_pixel_map |= self._xp.asarray( - constant_data.transpose(bad_pixel_loading[constant_type]), - dtype=np.uint32, - )[:self.constant_memory_cells] - elif constant_type in other_constant_loading: - transpose_order, gpu_buffer = other_constant_loading[constant_type] - gpu_buffer.set( - np.transpose( - constant_data.astype(np.float32), - transpose_order, - )[:self.constant_memory_cells] + 
}[constant_type] + + if constant_type in bad_pixel_constants: + my_buffer |= ( + self._xp.asarray( + constant_data.transpose(transpose_order), + dtype=np.uint32, + )[:self._constant_memory_cells] + & self._xp.uint32(self.bad_pixel_subset) ) else: - raise ValueError(f"Unhandled constant type {constant_type}") + my_buffer[:] = ( + self._xp.asarray( + self._xp.transpose( + constant_data.astype(np.float32), + transpose_order, + )[:self._constant_memory_cells] + ) + ) def flush_buffers(self, constants): if Constants.Offset in constants: @@ -173,10 +134,79 @@ class LpdGpuRunner(base_kernel_runner.BaseKernelRunner): self.rel_gain_slopes_map.fill(1) if Constants.FFMap in constants: self.flatfield_map.fill(1) - if constants & {Constants.BadPixelsDark, Constants.BadPixelsFF}: + if constants & bad_pixel_constants: self.bad_pixel_map.fill(0) +class LpdGpuRunner(LpdBaseRunner): + def _pre_init(self): + import cupy + self._xp = cupy + + def _post_init(self): + self.correction_kernel = self._xp.RawModule( + code=base_kernel_runner.get_kernel_template("lpd_gpu.cu").render( + { + "ss_dim": self.num_pixels_ss, + "fs_dim": self.num_pixels_fs, + "output_data_dtype": utils.np_dtype_to_c_type( + self._output_dtype + ), + "corr_enum": utils.enum_to_c_template(self._correction_flag_class), + } + ) + ).get_function("correct") + + def _correct(self, flags, image_data, cell_table, processed_data, gain_map): + num_frames = self._xp.uint16(image_data.shape[0]) + block_shape = (1, 1, 64) + grid_shape = utils.grid_to_cover_shape_with_blocks( + processed_data.shape, block_shape + ) + self.correction_kernel( + grid_shape, + block_shape, + ( + image_data, + cell_table, + flags, + num_frames, + self._constant_memory_cells, + self.offset_map, + self.gain_amp_map, + self.rel_gain_slopes_map, + self.flatfield_map, + self.bad_pixel_map, + self.bad_pixel_mask_value, + gain_map, + processed_data, + ), + ) + + +class LpdCpuRunner(LpdBaseRunner): + _xp = np + + def _post_init(self): + from ..kernels import 
lpd_cython + self.correction_kernel = lpd_cython.correct + + def _correct(self, flags, image_data, cell_table, processed_data, gain_map): + self.correction_kernel( + image_data, + cell_table, + flags, + self.offset_map, + self.gain_amp_map, + self.rel_gain_slopes_map, + self.flatfield_map, + self.bad_pixel_map, + self.bad_pixel_mask_value, + gain_map, + processed_data, + ) + + class LpdCalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = Constants @@ -294,216 +324,25 @@ class LpdCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("LpdCorrection", deviceVersion) class LpdCorrection(base_correction.BaseCorrection): - _base_output_schema = schemas.xtdf_output_schema() - _correction_flag_class = CorrectionFlags - _correction_steps = ( - ("offset", CorrectionFlags.OFFSET, {Constants.Offset}), - ("gainAmp", CorrectionFlags.GAIN_AMP, {Constants.GainAmpMap}), - ("relGain", CorrectionFlags.REL_GAIN, {Constants.RelativeGain}), - ("flatfield", CorrectionFlags.FF_CORR, {Constants.FFMap}), - ( - "badPixels", - CorrectionFlags.BPMASK, - { - Constants.BadPixelsDark, - Constants.BadPixelsFF, - } - ), - ) + _base_output_schema = schemas.xtdf_output_schema + _correction_steps = correction_steps _kernel_runner_class = LpdGpuRunner _calcat_friend_class = LpdCalcatFriend _constant_enum_class = Constants - _preview_outputs = ["outputRaw", "outputCorrected", "outputGainMap"] - @staticmethod - def expectedParameters(expected): - ( - OUTPUT_CHANNEL(expected) - .key("dataOutput") - .dataSchema(LpdCorrection._base_output_schema) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsX") - .setNewDefaultValue(256) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(256) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(512) - .commit(), - - # TODO: determine ideal default - OVERWRITE_ELEMENT(expected) - .key("preview.selectionMode") - .setNewDefaultValue("frame") - 
.commit(), - ) + _preview_outputs = [PreviewSpec(name) for name in ("raw", "corrected", "gainMap")] - # TODO: add bad pixel config node + @classmethod + def expectedParameters(cls, expected): LpdCalcatFriend.add_schema(expected) - base_correction.add_addon_nodes(expected, LpdCorrection) - base_correction.add_correction_step_schema( - expected, LpdCorrection._correction_steps - ) - base_correction.add_preview_outputs(expected, LpdCorrection._preview_outputs) - - # additional settings for correction steps + LpdBaseRunner.add_schema(expected) + base_correction.add_addon_nodes(expected, cls) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( - STRING_ELEMENT(expected) - .key("corrections.badPixels.maskingValue") - .tags("managed") - .displayedName("Bad pixel masking value") - .description( - "Any pixels masked by the bad pixel mask will have their value " - "replaced with this. Note that this parameter is to be interpreted as " - "a numpy.float32; use 'nan' to get NaN value." + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema( + cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles) ) - .assignmentOptional() - .defaultValue("nan") - .reconfigurable() .commit(), ) - - @property - def input_data_shape(self): - return ( - self.unsafe_get("dataFormat.frames"), - 1, - self.unsafe_get("dataFormat.pixelsY"), - self.unsafe_get("dataFormat.pixelsX"), - ) - - def __init__(self, config): - super().__init__(config) - try: - np.float32(config.get("corrections.badPixels.maskingValue")) - except ValueError: - config["corrections.badPixels.maskingValue"] = "nan" - - if config.get("useShmemHandles", default=False): - def aux(): - schema_override = Schema() - ( - OUTPUT_CHANNEL(schema_override) - .key("dataOutput") - .dataSchema(schemas.xtdf_output_schema(use_shmem_handle=False)) - .commit(), - ) - - self.registerInitialFunction(aux) - - def _load_constant_to_runner(self, constant, constant_data): - self.kernel_runner.load_constant(constant, 
constant_data) - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - pulse_table, - ): - with self.warning_context( - "processingState", - base_correction.WarningLampType.CONSTANT_OPERATING_PARAMETERS, - ) as warn: - if ( - cell_table_string := utils.cell_table_to_string(cell_table) - ) != self.unsafe_get( - "constantParameters.memoryCellOrder" - ) and self.unsafe_get( - "constantParameters.useMemoryCellOrder" - ): - warn(f"Cell order does not match input; input: {cell_table_string}") - if self._frame_filter is not None: - try: - cell_table = cell_table[self._frame_filter] - pulse_table = pulse_table[self._frame_filter] - image_data = image_data[self._frame_filter] - except IndexError: - self.log_status_warn( - "Failed to apply frame filter, please check that it is valid!" - ) - return - - try: - # drop the singleton dimension for kernel runner - self.kernel_runner.load_data(image_data, cell_table) - except ValueError as e: - self.log_status_warn(f"Failed to load data: {e}") - return - except Exception as e: - self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.kernel_runner.correct(self._correction_flag_enabled) - for addon in self._enabled_addons: - addon.post_correction( - self.kernel_runner.processed_data, cell_table, pulse_table, data_hash - ) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - with self.warning_context( - "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS - ) as warn: - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ), preview_warning = utils.pick_frame_index( - self.unsafe_get("preview.selectionMode"), - self.unsafe_get("preview.index"), - cell_table, - pulse_table, - ) - if 
preview_warning is not None: - warn(preview_warning) - ( - preview_raw, - preview_corrected, - preview_gain_map, - ) = self.kernel_runner.compute_previews(preview_slice_index) - - if self._use_shmem_handles: - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - data_hash.set(self._cell_table_path, cell_table[:, np.newaxis]) - data_hash.set(self._pulse_table_path, pulse_table[:, np.newaxis]) - - self._write_output(data_hash, metadata) - self._preview_friend.write_outputs( - metadata, preview_raw, preview_corrected, preview_gain_map - ) - - def preReconfigure(self, config): - # TODO: DRY (taken from AGIPD device) - super().preReconfigure(config) - if config.has("corrections.badPixels.maskingValue"): - # only check if it is valid (let raise exception) - # if valid, postReconfigure will use it - np.float32(config.get("corrections.badPixels.maskingValue")) - - def postReconfigure(self): - super().postReconfigure() - if not hasattr(self, "_prereconfigure_update_hash"): - return - - update = self._prereconfigure_update_hash - if update.has("corrections.badPixels.maskingValue"): - self.kernel_runner.bad_pixel_mask_value = self.bad_pixel_mask_value diff --git a/src/calng/corrections/LpdminiCorrection.py b/src/calng/corrections/LpdminiCorrection.py index f43650ad78577b9b9c0be8d85ae62116017f7302..4a5d5c24582427c429b923a73bf69bdf373a1998 100644 --- a/src/calng/corrections/LpdminiCorrection.py +++ b/src/calng/corrections/LpdminiCorrection.py @@ -6,13 +6,13 @@ from karabo.bound import ( ) from .._version import version as deviceVersion -from ..base_correction import add_correction_step_schema from . 
import LpdCorrection class LpdminiGpuRunner(LpdCorrection.LpdGpuRunner): + num_pixels_ss = 32 + def load_constant(self, constant_type, constant_data): - print(f"Given: {constant_type} with shape {constant_data.shape}") # constant type → transpose order constant_buffer_map = { LpdCorrection.Constants.Offset: self.offset_map, @@ -74,20 +74,3 @@ class LpdminiCalcatFriend(LpdCorrection.LpdCalcatFriend): class LpdminiCorrection(LpdCorrection.LpdCorrection): _calcat_friend_class = LpdminiCalcatFriend _kernel_runner_class = LpdminiGpuRunner - - @classmethod - def expectedParameters(cls, expected): - ( - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(32) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(512) - .commit(), - ) - cls._calcat_friend_class.add_schema(expected) - # warning: this is redundant, but needed for now to get managed keys working - add_correction_step_schema(expected, cls._correction_steps) diff --git a/src/calng/corrections/PnccdCorrection.py b/src/calng/corrections/PnccdCorrection.py index 0b941e2ec9fd244ade05ad2a7cc90963a477e82a..0156e2001508255245b8ca66490f07f01b759ec0 100644 --- a/src/calng/corrections/PnccdCorrection.py +++ b/src/calng/corrections/PnccdCorrection.py @@ -1,6 +1,5 @@ import concurrent.futures import enum -import functools import numpy as np from karabo.bound import ( @@ -9,7 +8,6 @@ from karabo.bound import ( KARABO_CLASSINFO, OUTPUT_CHANNEL, OVERWRITE_ELEMENT, - Schema, ) from .. import ( @@ -19,6 +17,7 @@ from .. 
import ( schemas, utils, ) +from ..preview_utils import PreviewFriend, PreviewSpec from .._version import version as deviceVersion from ..kernels import common_cython @@ -39,6 +38,14 @@ class CorrectionFlags(enum.IntFlag): BPMASK = 2**4 +correction_steps = ( + ("offset", CorrectionFlags.OFFSET, {Constants.OffsetCCD}), + ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainCCD}), + ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseCCD}), + ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkCCD}), +) + + class PnccdCalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = Constants @@ -51,30 +58,30 @@ class PnccdCalcatFriend(base_calcat.BaseCalcatFriend): Constants.RelativeGainCCD: self.illuminated_condition, } - @staticmethod - def add_schema(schema): - super(PnccdCalcatFriend, PnccdCalcatFriend).add_schema(schema, "pnCCD-Type") + @classmethod + def add_schema(cls, schema): + super(cls, cls).add_schema(schema, "pnCCD-Type") # set some defaults for common parameters ( OVERWRITE_ELEMENT(schema) - .key("constantParameters.memoryCells") - .setNewDefaultValue(1) + .key("constantParameters.pixelsX") + .setNewDefaultValue(1024) .commit(), OVERWRITE_ELEMENT(schema) - .key("constantParameters.biasVoltage") - .setNewDefaultValue(300) + .key("constantParameters.pixelsY") + .setNewDefaultValue(1024) .commit(), OVERWRITE_ELEMENT(schema) - .key("constantParameters.pixelsX") - .setNewDefaultValue(1024) + .key("constantParameters.memoryCells") + .setNewDefaultValue(1) .commit(), OVERWRITE_ELEMENT(schema) - .key("constantParameters.pixelsY") - .setNewDefaultValue(1024) + .key("constantParameters.biasVoltage") + .setNewDefaultValue(300) .commit(), ) @@ -148,82 +155,104 @@ class PnccdCalcatFriend(base_calcat.BaseCalcatFriend): class PnccdCpuRunner(base_kernel_runner.BaseKernelRunner): - _corrected_axis_order = "xy" + _correction_steps = correction_steps + _correction_flag_class = CorrectionFlags + _bad_pixel_constants = {Constants.BadPixelsDarkCCD} 
_xp = np - _gpu_based = False - @property - def input_shape(self): - return (self.pixels_x, self.pixels_y) + num_pixels_ss = 1024 + num_pixels_fs = 1024 - @property - def preview_shape(self): - return (self.pixels_x, self.pixels_y) + @classmethod + def add_schema(cls, schema): + super(cls, cls).add_schema(schema) + super(cls, cls).add_bad_pixel_config(schema) + ( + DOUBLE_ELEMENT(schema) + .key("corrections.commonMode.noiseSigma") + .assignmentOptional() + .defaultValue(5) + .reconfigurable() + .commit(), - @property - def processed_shape(self): - return self.input_shape + DOUBLE_ELEMENT(schema) + .key("corrections.commonMode.minFrac") + .assignmentOptional() + .defaultValue(0.25) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key("corrections.commonMode.enableRow") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key("corrections.commonMode.enableCol") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + ) + + def reconfigure(self, config): + super().reconfigure(config) + if config.has("corrections.commonMode.noiseSigma"): + self._cm_sigma = np.float32(config["corrections.commonMode.noiseSigma"]) + if config.has("corrections.commonMode.minFrac"): + self._cm_minfrac = np.float32(config["corrections.commonMode.minFrac"]) + if config.has("corrections.commonMode.enableRow"): + self._cm_row = config["corrections.commonMode.enableRow"] + if config.has("corrections.commonMode.enableCol"): + self._cm_col = config["corrections.commonMode.enableCol"] + + def expected_input_shape(self, num_frames): + assert num_frames == 1, "pnCCD not expected to have multiple frames" + return (self.num_pixels_ss, self.num_pixels_fs) + + def _expected_output_shape(self, num_frames): + return (self.num_pixels_ss, self.num_pixels_fs) @property - def map_shape(self): - return (self.pixels_x, self.pixels_y) + def _map_shape(self): + return (self.num_pixels_ss, self.num_pixels_fs) + + def 
_make_output_buffers(self, num_frames, flags): + # ignore parameters + return [ + self._xp.empty( + (self.num_pixels_ss, self.num_pixels_fs), + dtype=self._xp.float32, + ) + ] - def __init__( - self, - pixels_x, - pixels_y, - frames, # will be 1, will be ignored - constant_memory_cells, - input_data_dtype=np.uint16, - output_data_dtype=np.float32, - ): + def _post_init(self): + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4) + + def _setup_constant_buffers(self): assert ( - output_data_dtype == np.float32 + self._output_dtype == np.float32 ), "Alternative output types not supported yet" - super().__init__( - pixels_x, - pixels_y, - frames, - constant_memory_cells, - output_data_dtype, - ) - - self.input_data = None - self.processed_data = np.empty(self.processed_shape, dtype=np.float32) - - self.offset_map = np.empty(self.map_shape, dtype=np.float32) - self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32) - self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32) - self.noise_map = np.empty(self.map_shape, dtype=np.float32) + self.offset_map = np.empty(self._map_shape, dtype=np.float32) + self.rel_gain_map = np.empty(self._map_shape, dtype=np.float32) + self.bad_pixel_map = np.empty(self._map_shape, dtype=np.uint32) + self.noise_map = np.empty(self._map_shape, dtype=np.float32) # will do everything by quadrant - self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4) - self._q_input_data = None - self._q_processed_data = utils.quadrant_views(self.processed_data) self._q_offset_map = utils.quadrant_views(self.offset_map) self._q_rel_gain_map = utils.quadrant_views(self.rel_gain_map) self._q_bad_pixel_map = utils.quadrant_views(self.bad_pixel_map) self._q_noise_map = utils.quadrant_views(self.noise_map) - def __del__(self): - self.thread_pool.shutdown() - - @property - def preview_data_views(self): - return (self.input_data, self.processed_data) - - def load_data(self, image_data): - # should almost never 
squeeze, but ePix doesn't do burst mode, right? - self.input_data = image_data.astype(np.uint16, copy=False).squeeze() - self._q_input_data = utils.quadrant_views(self.input_data) - - def load_constant(self, constant_type, data): + def _load_constant(self, constant_type, data): if constant_type is Constants.OffsetCCD: self.offset_map[:] = data.squeeze().astype(np.float32) elif constant_type is Constants.RelativeGainCCD: self.rel_gain_map[:] = data.squeeze().astype(np.float32) elif constant_type is Constants.BadPixelsDarkCCD: - self.bad_pixel_map[:] = data.squeeze() + self.bad_pixel_map[:] = (data.squeeze() & self.bad_pixel_subset) elif constant_type is Constants.NoiseCCD: self.noise_map[:] = data.squeeze() else: @@ -241,123 +270,116 @@ class PnccdCpuRunner(base_kernel_runner.BaseKernelRunner): def _correct_quadrant( self, - q, flags, - bad_pixel_mask_value, - cm_noise_sigma, - cm_min_frac, - cm_row, - cm_col, + input_data, + offset_map, + noise_map, + bad_pixel_map, + rel_gain_map, + output, ): - output = self._q_processed_data[q] - output[:] = self._q_input_data[q].astype(np.float32) + output[:] = input_data.astype(np.float32) if flags & CorrectionFlags.OFFSET: - output -= self._q_offset_map[q] + output -= offset_map if flags & CorrectionFlags.COMMONMODE: cm_mask = ( - (self._q_bad_pixel_map[q] != 0) - | (output > self._q_noise_map[q] * cm_noise_sigma) + (bad_pixel_map != 0) + | (output > noise_map * self._cm_sigma) ).astype(np.uint8, copy=False) - if cm_row: + if self._cm_row: common_cython.cm_fs( output, cm_mask, - cm_noise_sigma, - cm_min_frac, + self._cm_sigma, + self._cm_minfrac, ) - if cm_col: + if self._cm_col: common_cython.cm_ss( output, cm_mask, - cm_noise_sigma, - cm_min_frac, + self._cm_sigma, + self._cm_minfrac, ) if flags & CorrectionFlags.RELGAIN: - output *= self._q_rel_gain_map[q] + output *= rel_gain_map if flags & CorrectionFlags.BPMASK: - output[self._q_bad_pixel_map[q] != 0] = bad_pixel_mask_value + output[bad_pixel_map != 0] = 
self.bad_pixel_mask_value + + def _correct(self, flags, image_data, cell_table, output): + # cell_table is going to be None for now, just ignore + concurrent.futures.wait( + [ + self.thread_pool.submit( + self._correct_quadrant(flags, *parts) + ) + for parts in zip( + utils.quadrant_views(image_data), + self._q_offset_map, + self._q_noise_map, + self._q_bad_pixel_map, + self._q_rel_gain_map, + utils.quadrant_views(output), + ) + ] + ) - def correct( - self, - flags, - bad_pixel_mask_value=np.nan, - cm_noise_sigma=5, - cm_min_frac=0.25, - cm_row=True, - cm_col=True, - ): - # NOTE: how to best clean up all these duplicated parameters? - for result in self.thread_pool.map( - functools.partial( - self._correct_quadrant, - flags=flags, - bad_pixel_mask_value=bad_pixel_mask_value, - cm_noise_sigma=cm_noise_sigma, - cm_min_frac=cm_min_frac, - cm_row=cm_row, - cm_col=cm_col, - ), - range(4), - ): - pass # just run through to await map + def __del__(self): + self.thread_pool.shutdown() + + def _preview_data_views(self, raw_data, processed_data): + return [ + raw_data[np.newaxis], + processed_data[np.newaxis], + ] @KARABO_CLASSINFO("PnccdCorrection", deviceVersion) class PnccdCorrection(base_correction.BaseCorrection): - _correction_flag_class = CorrectionFlags - _correction_steps = ( - ("offset", CorrectionFlags.OFFSET, {Constants.OffsetCCD}), - ("relGain", CorrectionFlags.RELGAIN, {Constants.RelativeGainCCD}), - ("commonMode", CorrectionFlags.COMMONMODE, {Constants.NoiseCCD}), - ("badPixels", CorrectionFlags.BPMASK, {Constants.BadPixelsDarkCCD}), - ) - _image_data_path = "data.image" + _base_output_schema = schemas.pnccd_output_schema _kernel_runner_class = PnccdCpuRunner _calcat_friend_class = PnccdCalcatFriend + _correction_steps = correction_steps _constant_enum_class = Constants - _preview_outputs = ["outputRaw", "outputCorrected"] + + _image_data_path = "data.image" _cell_table_path = None _pulse_table_path = None - _warn_memory_cell_range = False - @staticmethod - 
def expectedParameters(expected): + _preview_outputs = [ + PreviewSpec( + name, + dimensions=2, + frame_reduction=False, + ) + for name in ("raw", "corrected") + ] + _warn_memory_cell_range = False + _cuda_pin_buffers = False + _use_shmem_handles = False + + @classmethod + def expectedParameters(cls, expected): + cls._calcat_friend_class.add_schema(expected) + cls._kernel_runner_class.add_schema(expected) + base_correction.add_addon_nodes(expected, cls) + PreviewFriend.add_schema(expected, cls._preview_outputs) ( OUTPUT_CHANNEL(expected) .key("dataOutput") - .dataSchema(schemas.pnccd_output_schema(use_shmem_handle=False)) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsX") - .setNewDefaultValue(1024) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(1024) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.frames") - .setNewDefaultValue(1) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.outputAxisOrder") - .setNewOptions("xy,yx") - .setNewDefaultValue("xy") + .dataSchema( + cls._base_output_schema(use_shmem_handles=cls._use_shmem_handles) + ) .commit(), # TODO: disable preview selection mode OVERWRITE_ELEMENT(expected) .key("useShmemHandles") - .setNewDefaultValue(False) + .setNewDefaultValue(cls._use_shmem_handles) .commit(), OVERWRITE_ELEMENT(expected) @@ -366,113 +388,11 @@ class PnccdCorrection(base_correction.BaseCorrection): .commit(), ) - base_correction.add_preview_outputs( - expected, PnccdCorrection._preview_outputs - ) - base_correction.add_correction_step_schema( - expected, - PnccdCorrection._correction_steps, - ) - ( - DOUBLE_ELEMENT(expected) - .key("corrections.commonMode.noiseSigma") - .assignmentOptional() - .defaultValue(5) - .reconfigurable() - .commit(), - - DOUBLE_ELEMENT(expected) - .key("corrections.commonMode.minFrac") - .assignmentOptional() - .defaultValue(0.25) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(expected) - 
.key("corrections.commonMode.enableRow") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - - BOOL_ELEMENT(expected) - .key("corrections.commonMode.enableCol") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), - ) - PnccdCalcatFriend.add_schema(expected) - # TODO: bad pixel node? - - @property - def input_data_shape(self): - # TODO: check + def _get_data_from_hash(self, data_hash): + image_data = data_hash.get(self._image_data_path) return ( - self.unsafe_get("dataFormat.pixelsX"), - self.unsafe_get("dataFormat.pixelsY"), - ) - - def __init__(self, config): - super().__init__(config) - if config.get("useShmemHandles", default=False): - def aux(): - schema_override = Schema() - ( - OUTPUT_CHANNEL(schema_override) - .key("dataOutput") - .dataSchema(schemas.pnccd_output_schema(use_shmem_handle=False)) - .commit(), - ) - self.updateSchema(schema_override) - - self.registerInitialFunction(aux) - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, # will be None - pulse_table, # ditto - ): - self.kernel_runner.load_data(image_data) - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - args_which_should_be_cached = dict( - cm_noise_sigma=self.unsafe_get("corrections.commonMode.noiseSigma"), - cm_min_frac=self.unsafe_get("corrections.commonMode.minFrac"), - cm_row=self.unsafe_get("corrections.commonMode.enableRow"), - cm_col=self.unsafe_get("corrections.commonMode.enableCol"), + 1, + image_data, + None, + None, ) - self.kernel_runner.correct( - flags=self._correction_flag_enabled, **args_which_should_be_cached - ) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct( - flags=self._correction_flag_preview, - **args_which_should_be_cached, - ) - - if self._use_shmem_handles: - 
data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - else: - data_hash.set(self._image_data_path, buffer_array) - data_hash.set("calngShmemPaths", []) - - self._write_output(data_hash, metadata) - - # note: base class preview machinery assumes burst mode, shortcut it - self._preview_friend.write_outputs( - metadata, *self.kernel_runner.preview_data_views - ) - - def _load_constant_to_runner(self, constant, constant_data): - self.kernel_runner.load_constant(constant, constant_data) diff --git a/src/calng/kernels/agipd_gpu.cu b/src/calng/kernels/agipd_gpu.cu index ac5fc7435f82ae8889a9e6389352b884b5364f58..f638af17e5e8b23463effae43bc89887aeb4096b 100644 --- a/src/calng/kernels/agipd_gpu.cu +++ b/src/calng/kernels/agipd_gpu.cu @@ -13,6 +13,8 @@ extern "C" { __global__ void correct(const unsigned short* data, const unsigned short* cell_table, const unsigned char corr_flags, + const unsigned short input_frames, + const unsigned short map_cells, // default_gain can be 0, 1, or 2, and is relevant for fixed gain mode (no THRESHOLD) const unsigned char default_gain, const float* threshold_map, @@ -27,71 +29,69 @@ extern "C" { const float bad_pixel_mask_value, float* gain_map, // TODO: more compact yet plottable representation {{output_data_dtype}}* output) { - const size_t X = {{pixels_x}}; - const size_t Y = {{pixels_y}}; - const size_t input_frames = {{frames}}; - const size_t map_cells = {{constant_memory_cells}}; + const size_t ss_dim = 512; + const size_t fs_dim = 128; - const size_t cell = blockIdx.x * blockDim.x + threadIdx.x; - const size_t x = blockIdx.y * blockDim.y + threadIdx.y; - const size_t y = blockIdx.z * blockDim.z + threadIdx.z; + const size_t frame = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ss = blockIdx.y * blockDim.y + threadIdx.y; + const size_t fs = blockIdx.z * blockDim.z + threadIdx.z; - if (cell >= input_frames || y >= Y || x >= X) { + if (frame >= input_frames || fs >= 
fs_dim || ss >= ss_dim) { return; } - // data shape: memory cell, data/raw_gain (dim size 2), x, y - const size_t data_stride_y = 1; - const size_t data_stride_x = Y * data_stride_y; - const size_t data_stride_raw_gain = X * data_stride_x; - const size_t data_stride_cell = 2 * data_stride_raw_gain; - const size_t data_index = cell * data_stride_cell + + // data shape: frame, data/raw_gain (dim size 2), ss, fs + const size_t data_stride_fs = 1; + const size_t data_stride_ss = fs_dim * data_stride_fs; + const size_t data_stride_raw_gain = ss_dim * data_stride_ss; + const size_t data_stride_frame = 2 * data_stride_raw_gain; + const size_t data_index = frame * data_stride_frame + 0 * data_stride_raw_gain + - y * data_stride_y + - x * data_stride_x; - const size_t raw_gain_index = cell * data_stride_cell + + fs * data_stride_fs + + ss * data_stride_ss; + const size_t raw_gain_index = frame * data_stride_frame + 1 * data_stride_raw_gain + - y * data_stride_y + - x * data_stride_x; + fs * data_stride_fs + + ss * data_stride_ss; float res = (float)data[data_index]; const float raw_gain_val = (float)data[raw_gain_index]; - const size_t output_stride_y = 1; - const size_t output_stride_x = output_stride_y * Y; - const size_t output_stride_cell = output_stride_x * X; - const size_t output_index = cell * output_stride_cell + x * output_stride_x + y * output_stride_y; + const size_t output_stride_fs = 1; + const size_t output_stride_ss = output_stride_fs * fs_dim; + const size_t output_stride_frame = output_stride_ss * ss_dim; + const size_t output_index = frame * output_stride_frame + ss * output_stride_ss + fs * output_stride_fs; - // per-pixel only constant: cell, x, y - const size_t map_stride_y = 1; - const size_t map_stride_x = Y * map_stride_y; - const size_t map_stride_cell = X * map_stride_x; + // per-pixel only constant: cell, ss, fs + const size_t map_stride_fs = 1; + const size_t map_stride_ss = fs_dim * map_stride_fs; + const size_t map_stride_cell = ss_dim * 
map_stride_ss; - // threshold constant shape: cell, x, y, threshold (dim size 2) + // threshold constant shape: cell, ss, fs, threshold (dim size 2) const size_t threshold_map_stride_threshold = 1; - const size_t threshold_map_stride_y = 2 * threshold_map_stride_threshold; - const size_t threshold_map_stride_x = Y * threshold_map_stride_y; - const size_t threshold_map_stride_cell = X * threshold_map_stride_x; + const size_t threshold_map_stride_fs = 2 * threshold_map_stride_threshold; + const size_t threshold_map_stride_ss = fs_dim * threshold_map_stride_fs; + const size_t threshold_map_stride_cell = ss_dim * threshold_map_stride_ss; - // gain mapped constant shape: cell, x, y, gain_level (dim size 3) + // gain mapped constant shape: cell, ss, fs, gain_level (dim size 3) const size_t gm_map_stride_gain = 1; - const size_t gm_map_stride_y = 3 * gm_map_stride_gain; - const size_t gm_map_stride_x = Y * gm_map_stride_y; - const size_t gm_map_stride_cell = X * gm_map_stride_x; - // note: assuming all maps have same shape (in terms of cells / x / y) + const size_t gm_map_stride_fs = 3 * gm_map_stride_gain; + const size_t gm_map_stride_ss = fs_dim * gm_map_stride_fs; + const size_t gm_map_stride_cell = ss_dim * gm_map_stride_ss; + // note: assuming all maps have same shape (in terms of cells / ss / fs) - const size_t map_cell = cell_table[cell]; + const size_t map_cell = cell_table[frame]; if (map_cell < map_cells) { unsigned char gain = default_gain; if (corr_flags & THRESHOLD) { const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold + map_cell * threshold_map_stride_cell + - y * threshold_map_stride_y + - x * threshold_map_stride_x]; + fs * threshold_map_stride_fs + + ss * threshold_map_stride_ss]; const float threshold_1 = threshold_map[1 * threshold_map_stride_threshold + map_cell * threshold_map_stride_cell + - y * threshold_map_stride_y + - x * threshold_map_stride_x]; + fs * threshold_map_stride_fs + + ss * threshold_map_stride_ss]; // could 
consider making this const using ternaries / tiny function if (raw_gain_val <= threshold_0) { gain = 0; @@ -103,8 +103,8 @@ extern "C" { } const size_t gm_map_index_without_gain = map_cell * gm_map_stride_cell + - y * gm_map_stride_y + - x * gm_map_stride_x; + fs * gm_map_stride_fs + + ss * gm_map_stride_ss; if ((corr_flags & FORCE_MG_IF_BELOW) && (gain == 2) && (res - offset_map[gm_map_index_without_gain + 1 * gm_map_stride_gain] < mg_hard_threshold)) { gain = 1; @@ -116,8 +116,8 @@ extern "C" { gain_map[output_index] = (float)gain; const size_t map_index = map_cell * map_stride_cell + - y * map_stride_y + - x * map_stride_x; + fs * map_stride_fs + + ss * map_stride_ss; const size_t gm_map_index = gm_map_index_without_gain + gain * gm_map_stride_gain; diff --git a/src/calng/kernels/common_gpu.cu b/src/calng/kernels/common_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..631ab7dd5d1cb966ee736bc39c8cae57ef6fd443 --- /dev/null +++ b/src/calng/kernels/common_gpu.cu @@ -0,0 +1,60 @@ +// https://excalidraw.com/#json=mgBWynet5WUbpgfSd0pzC,TMhzt6L-x5yShNCHRGKLvg +extern "C" __global__ void common_mode_asic( + float* const data, + const unsigned short n_frames, + const unsigned short num_iter, + const float min_dark_fraction, + const float noise_peak_range +) { + const int frame = blockIdx.z; + const int ss_asic = blockIdx.y; + const int fs_asic = blockIdx.x; + const int thread_in_asic = threadIdx.x; + assert(blockDim.x == {{asic_dim}}); + + __shared__ float shared_sum[{{asic_dim}}]; + __shared__ unsigned int shared_count[{{asic_dim}}]; + + const unsigned int min_dark_pixels = static_cast<int>(min_dark_fraction * static_cast<float>({{asic_dim}} * {{asic_dim}})); + float* const asic_start = data + frame * {{ss_dim}} * {{fs_dim}} + ss_asic * {{asic_dim}} * {{fs_dim}} + fs_asic * {{asic_dim}}; + + for (int iter=0; iter<num_iter; ++iter) { + // accumulate (over rows) to shared + unsigned int my_count = 0; + float my_sum = 0; + for (int row=0; 
row<{{asic_dim}}; ++row) { + const float pixel = asic_start[row * {{fs_dim}} + thread_in_asic]; + if (!isfinite(pixel) || fabs(pixel) > noise_peak_range) { + continue; + } + my_sum += pixel; + ++my_count; + } + shared_sum[thread_in_asic] = my_sum; + shared_count[thread_in_asic] = my_count; + __syncthreads(); + + // reduce (per ASIC) in shared + int active_threads = {{asic_dim}}; + while (active_threads > 1) { + active_threads /= 2; + const int other = thread_in_asic + active_threads; + if (thread_in_asic < active_threads) { + shared_sum[thread_in_asic] += shared_sum[other]; + shared_count[thread_in_asic] += shared_count[other]; + } + __syncthreads(); + } + + // now, each ASIC result should be in first index of shared by ASIC + const unsigned int dark_count = shared_count[0]; + if (dark_count < min_dark_pixels) { + return; + } + const float dark_sum = shared_sum[0]; + const float dark_mean = dark_sum / static_cast<float>(dark_count); + for (int row=0; row<{{asic_dim}}; ++row) { + asic_start[row * {{fs_dim}} + thread_in_asic] -= dark_mean; + } + } +} diff --git a/src/calng/kernels/dssc_cpu.pyx b/src/calng/kernels/dssc_cpu.pyx index 04803d3f0a28a0a36721504c5e97177f609639b8..613ef30755c3b09dae5cac23cc6435e1a5e0b82d 100644 --- a/src/calng/kernels/dssc_cpu.pyx +++ b/src/calng/kernels/dssc_cpu.pyx @@ -8,7 +8,7 @@ cdef unsigned char OFFSET = 1 from cython.parallel import prange def correct( - unsigned short[:, :, :] image_data, + unsigned short[:, :, :, :] image_data, unsigned short[:] cell_table, unsigned char flags, float[:, :, :] offset_map, @@ -19,13 +19,13 @@ def correct( for frame in prange(image_data.shape[0], nogil=True): map_cell = cell_table[frame] if map_cell >= offset_map.shape[0]: - for x in range(image_data.shape[1]): - for y in range(image_data.shape[2]): - output[frame, x, y] = <float>image_data[frame, x, y] + for x in range(image_data.shape[2]): + for y in range(image_data.shape[3]): + output[frame, x, y] = <float>image_data[frame, 0, x, y] continue - for 
x in range(image_data.shape[1]): - for y in range(image_data.shape[2]): - res = image_data[frame, x, y] + for x in range(image_data.shape[2]): + for y in range(image_data.shape[3]): + res = image_data[frame, 0, x, y] if flags & OFFSET: res = res - offset_map[map_cell, x, y] output[frame, x, y] = res diff --git a/src/calng/kernels/dssc_gpu.cu b/src/calng/kernels/dssc_gpu.cu index 4e21b02b4ade95779e1a31754c0ba8a54be280ae..6ebe7b9ae5f04941a175dc05b4266eabf36220ad 100644 --- a/src/calng/kernels/dssc_gpu.cu +++ b/src/calng/kernels/dssc_gpu.cu @@ -8,41 +8,41 @@ extern "C" { Take cell_table into account when getting correction values Converting to float while correcting Converting to output dtype at the end - Shape of input data: memory cell, 1, y, x - Shape of offset constant: x, y, memory cell + Shape of input data: memory cell, 1, ss, fs + Shape of offset constant: fs, ss, memory cell */ - __global__ void correct(const unsigned short* data, // shape: memory cell, 1, y, x + __global__ void correct(const unsigned short* data, // shape: memory cell, 1, ss, fs const unsigned short* cell_table, + const unsigned short input_frames, + const unsigned short map_memory_cells, const unsigned char corr_flags, const float* offset_map, {{output_data_dtype}}* output) { - const size_t X = {{pixels_x}}; - const size_t Y = {{pixels_y}}; - const size_t input_frames = {{frames}}; - const size_t map_memory_cells = {{constant_memory_cells}}; + const size_t ss_dim = 128; + const size_t fs_dim = 512; const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x; - const size_t y = blockIdx.y * blockDim.y + threadIdx.y; - const size_t x = blockIdx.z * blockDim.z + threadIdx.z; + const size_t ss = blockIdx.y * blockDim.y + threadIdx.y; + const size_t fs = blockIdx.z * blockDim.z + threadIdx.z; - if (memory_cell >= input_frames || y >= Y || x >= X) { + if (memory_cell >= input_frames || ss >= ss_dim || fs >= fs_dim) { return; } // note: strides differ from numpy strides because unit here is 
sizeof(...), not byte - const size_t data_stride_x = 1; - const size_t data_stride_y = X * data_stride_x; - const size_t data_stride_cell = Y * data_stride_y; - const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x; + const size_t data_stride_fs = 1; + const size_t data_stride_ss = fs_dim * data_stride_fs; + const size_t data_stride_cell = ss_dim * data_stride_ss; + const size_t data_index = memory_cell * data_stride_cell + ss * data_stride_ss + fs * data_stride_fs; float res = (float)data[data_index]; - const size_t map_stride_x = 1; - const size_t map_stride_y = X * map_stride_x; - const size_t map_stride_cell = Y * map_stride_y; + const size_t map_stride_fs = 1; + const size_t map_stride_ss = fs_dim * map_stride_fs; + const size_t map_stride_cell = ss_dim * map_stride_ss; const size_t map_cell = cell_table[memory_cell]; if (map_cell < map_memory_cells) { - const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x; + const size_t map_index = map_cell * map_stride_cell + ss * map_stride_ss + fs * map_stride_fs; if (corr_flags & OFFSET) { res -= offset_map[map_index]; } diff --git a/src/calng/kernels/jungfrau_cpu.pyx b/src/calng/kernels/jungfrau_cpu.pyx index 2a200063879007b54f92c762378863c0d3e71c22..5ff3148c779dc3395fb0472b444c0f08abad0157 100644 --- a/src/calng/kernels/jungfrau_cpu.pyx +++ b/src/calng/kernels/jungfrau_cpu.pyx @@ -24,20 +24,20 @@ def correct_burst( float badpixel_fill_value, float[:, :, :] output, ): - cdef int frame, map_cell, x, y + cdef int frame, map_cell, fs, ss cdef unsigned char gain cdef float corrected for frame in prange(image_data.shape[0], nogil=True): map_cell = cell_table[frame] if map_cell >= offset_map.shape[0]: - for y in range(image_data.shape[1]): - for x in range(image_data.shape[2]): - output[frame, y, x] = <float>image_data[frame, y, x] + for ss in range(image_data.shape[1]): + for fs in range(image_data.shape[2]): + output[frame, ss, fs] = 
<float>image_data[frame, ss, fs] continue - for y in range(image_data.shape[1]): - for x in range(image_data.shape[2]): - corrected = image_data[frame, y, x] - gain = gain_stage[frame, y, x] + for ss in range(image_data.shape[1]): + for fs in range(image_data.shape[2]): + corrected = image_data[frame, ss, fs] + gain = gain_stage[frame, ss, fs] # legal values: 0, 1, or 3 if gain == 2: corrected = badpixel_fill_value @@ -45,14 +45,14 @@ def correct_burst( if gain == 3: gain = 2 - if (flags & BPMASK) and badpixel_mask[map_cell, y, x, gain] != 0: + if (flags & BPMASK) and badpixel_mask[map_cell, ss, fs, gain] != 0: corrected = badpixel_fill_value else: if (flags & OFFSET): - corrected = corrected - offset_map[map_cell, y, x, gain] + corrected = corrected - offset_map[map_cell, ss, fs, gain] if (flags & REL_GAIN): - corrected = corrected / relgain_map[map_cell, y, x, gain] - output[frame, y, x] = corrected + corrected = corrected / relgain_map[map_cell, ss, fs, gain] + output[frame, ss, fs] = corrected def correct_single( @@ -66,13 +66,13 @@ def correct_single( float[:, :, :] output, ): # ignore "cell table", constant pretends cell 0 - cdef int x, y + cdef int fs, ss cdef unsigned char gain cdef float corrected - for y in range(image_data.shape[1]): - for x in range(image_data.shape[2]): - corrected = image_data[0, y, x] - gain = gain_stage[0, y, x] + for ss in range(image_data.shape[1]): + for fs in range(image_data.shape[2]): + corrected = image_data[0, ss, fs] + gain = gain_stage[0, ss, fs] # legal values: 0, 1, or 3 if gain == 2: corrected = badpixel_fill_value @@ -80,11 +80,11 @@ def correct_single( if gain == 3: gain = 2 - if (flags & BPMASK) and badpixel_mask[0, y, x, gain] != 0: + if (flags & BPMASK) and badpixel_mask[0, ss, fs, gain] != 0: corrected = badpixel_fill_value else: if (flags & OFFSET): - corrected = corrected - offset_map[0, y, x, gain] + corrected = corrected - offset_map[0, ss, fs, gain] if (flags & REL_GAIN): - corrected = corrected / 
relgain_map[0, y, x, gain] - output[0, y, x] = corrected + corrected = corrected / relgain_map[0, ss, fs, gain] + output[0, ss, fs] = corrected diff --git a/src/calng/kernels/jungfrau_gpu.cu b/src/calng/kernels/jungfrau_gpu.cu index a4f148f1c49d172b31f91a6a1a4ae4ec13c61c49..57a2d2a24a6adeaaeb5e4b04032782fb5f83502a 100644 --- a/src/calng/kernels/jungfrau_gpu.cu +++ b/src/calng/kernels/jungfrau_gpu.cu @@ -7,50 +7,45 @@ extern "C" { const unsigned char* gain_stage, // same shape const unsigned char* cell_table, const unsigned char corr_flags, + const unsigned short input_frames, + const unsigned short map_memory_cells, + const unsigned char burst_mode, const float* offset_map, const float* rel_gain_map, const unsigned int* bad_pixel_map, const float bad_pixel_mask_value, {{output_data_dtype}}* output) { - const size_t X = {{pixels_x}}; - const size_t Y = {{pixels_y}}; - const size_t input_frames = {{frames}}; - const size_t map_memory_cells = {{constant_memory_cells}}; + const size_t ss_dim = 512; + const size_t fs_dim = 1024; const size_t current_frame = blockIdx.x * blockDim.x + threadIdx.x; - const size_t y = blockIdx.y * blockDim.y + threadIdx.y; - const size_t x = blockIdx.z * blockDim.z + threadIdx.z; + const size_t ss = blockIdx.y * blockDim.y + threadIdx.y; + const size_t fs = blockIdx.z * blockDim.z + threadIdx.z; - if (current_frame >= input_frames || y >= Y || x >= X) { + if (current_frame >= input_frames || ss >= ss_dim || fs >= fs_dim) { return; } - const size_t data_stride_x = 1; - const size_t data_stride_y = X * data_stride_x; - const size_t data_stride_frame = Y * data_stride_y; + const size_t data_stride_fs = 1; + const size_t data_stride_ss = fs_dim * data_stride_fs; + const size_t data_stride_frame = ss_dim * data_stride_ss; const size_t data_index = current_frame * data_stride_frame + - y * data_stride_y + - x * data_stride_x; + ss * data_stride_ss + + fs * data_stride_fs; float res = (float)data[data_index]; // gain mapped constant shape: cell, 
y, x, gain_level (dim size 3) // note: in fixed gain mode, constant still provides data for three stages const size_t map_stride_gain = 1; - const size_t map_stride_x = 3 * map_stride_gain; - const size_t map_stride_y = X * map_stride_x; - const size_t map_stride_cell = Y * map_stride_y; + const size_t map_stride_fs = 3 * map_stride_gain; + const size_t map_stride_ss = fs_dim * map_stride_fs; + const size_t map_stride_cell = ss_dim * map_stride_ss; // TODO: warn user about cell_table value of 255 in either mode // note: cell table may contain 255 if data didn't arrive - {% if burst_mode %} - // burst mode: "cell 255" will get copied - // TODO: consider masking "cell 255" - const size_t map_cell = cell_table[current_frame]; - {% else %} + // burst mode: "cell 255" will get copied (not corrected) // single cell: "cell 255" will get "corrected" - const size_t map_cell = 0; - {% endif %} - + const size_t map_cell = burst_mode ? cell_table[current_frame] : 0; if (map_cell < map_memory_cells) { unsigned char gain = gain_stage[data_index]; if (gain == 2) { @@ -58,20 +53,24 @@ extern "C" { res = bad_pixel_mask_value; } else { if (gain == 3) { + // but 3 means 2 in the constants gain = 2; } - const size_t map_index = map_cell * map_stride_cell + - y * map_stride_y + - x * map_stride_x + - gain * map_stride_gain; - if ((corr_flags & BPMASK) && bad_pixel_map[map_index]) { - res = bad_pixel_mask_value; - } else { - if (corr_flags & OFFSET) { - res -= offset_map[map_index]; - } - if (corr_flags & REL_GAIN) { - res /= rel_gain_map[map_index]; + if (map_cell < map_memory_cells) { + // only correct if in range of constant data + const size_t map_index = map_cell * map_stride_cell + + ss * map_stride_ss + + fs * map_stride_fs + + gain * map_stride_gain; + if ((corr_flags & BPMASK) && bad_pixel_map[map_index]) { + res = bad_pixel_mask_value; + } else { + if (corr_flags & OFFSET) { + res -= offset_map[map_index]; + } + if (corr_flags & REL_GAIN) { + res /= rel_gain_map[map_index]; + 
} } } } diff --git a/src/calng/kernels/lpd_cpu.pyx b/src/calng/kernels/lpd_cpu.pyx index 0c3eb82d50c3e23d1414515fd967347ffe4a0876..3cd878fb847e7e1978620c63c3f8899dc3cec55f 100644 --- a/src/calng/kernels/lpd_cpu.pyx +++ b/src/calng/kernels/lpd_cpu.pyx @@ -12,6 +12,7 @@ cdef unsigned char BPMASK = 16 from cython.parallel import prange from libc.math cimport isinf, isnan + def correct( unsigned short[:, :, :, :] image_data, unsigned short[:] cell_table, @@ -22,7 +23,7 @@ def correct( float[:, :, :, :] flatfield_map, unsigned[:, :, :, :] bad_pixel_map, float bad_pixel_mask_value, - # TODO: support spitting out gain map for preview purposes + float[:, :, :] gain_map, float[:, :, :] output ): cdef int frame, map_cell, ss, fs @@ -31,25 +32,41 @@ def correct( cdef unsigned short raw_data_value for frame in prange(image_data.shape[0], nogil=True): map_cell = cell_table[frame] - if map_cell >= offset_map.shape[0]: - for ss in range(image_data.shape[1]): - for fs in range(image_data.shape[2]): - output[frame, ss, fs] = <float>image_data[frame, 0, ss, fs] - continue - for ss in range(image_data.shape[1]): - for fs in range(image_data.shape[3]): - raw_data_value = image_data[frame, 0, ss, fs] - gain_stage = (raw_data_value >> 12) & 0x0003 - res = <float>(raw_data_value & 0x0fff) - if gain_stage > 2 or (flags & BPMASK and bad_pixel_map[map_cell, ss, fs, gain_stage] != 0): - res = bad_pixel_mask_value + if map_cell < offset_map.shape[0]: + for ss in range(image_data.shape[2]): + for fs in range(image_data.shape[3]): + raw_data_value = image_data[frame, 0, ss, fs] + gain_stage = (raw_data_value >> 12) & 0x0003 + res = <float>(raw_data_value & 0x0fff) + if ( + gain_stage > 2 + or ( + flags & BPMASK + and bad_pixel_map[map_cell, ss, fs, gain_stage] != 0 + ) + ): + res = bad_pixel_mask_value + else: + if flags & OFFSET: + res = res - offset_map[map_cell, ss, fs, gain_stage] + if flags & GAIN_AMP: + res = res * gain_amp_map[map_cell, ss, fs, gain_stage] + if flags & FF_CORR: + res = 
res * rel_gain_slopes_map[map_cell, ss, fs, gain_stage] + if res < -1e7 or res > 1e7 or isnan(res) or isinf(res): + res = bad_pixel_mask_value + output[frame, ss, fs] = res + gain_map[frame, ss, fs] = <float>gain_stage else: - if flags & OFFSET: - res = res - offset_map[map_cell, ss, fs, gain_stage] - if flags & GAIN_AMP: - res = res * gain_amp_map[map_cell, ss, fs, gain_stage] - if flags & FF_CORR: - res = res * rel_gain_slopes_map[map_cell, ss, fs, gain_stage] - if res < 1e-7 or res > 1e7 or isnan(res) or isinf(res): - res = bad_pixel_mask_value - output[frame, ss, fs] = res + for ss in range(image_data.shape[2]): + for fs in range(image_data.shape[3]): + raw_data_value = image_data[frame, 0, ss, fs] + gain_stage = (raw_data_value >> 12) & 0x0003 + res = <float>(raw_data_value & 0x0fff) + if ( + gain_stage > 2 + or res < -1e7 or res > 1e7 + ): + res = bad_pixel_mask_value + output[frame, ss, fs] = res + gain_map[frame, ss, fs] = <float>gain_stage diff --git a/src/calng/kernels/lpd_gpu.cu b/src/calng/kernels/lpd_gpu.cu index 4e98cc851e836d83edffff1b449087dce8c3f190..7e935a38963af4c27038e7f40ee75ca9e92dcfa8 100644 --- a/src/calng/kernels/lpd_gpu.cu +++ b/src/calng/kernels/lpd_gpu.cu @@ -3,10 +3,12 @@ {{corr_enum}} extern "C" { - __global__ void correct(const unsigned short* data, // shape: memory cell, 1, y, x (I think) + __global__ void correct(const unsigned short* data, // shape: memory cell, 1, ss, fs const unsigned short* cell_table, const unsigned char corr_flags, - const float* offset_map, // shape: cell, y, x, gain + const unsigned short input_frames, + const unsigned short map_memory_cells, + const float* offset_map, // shape: cell, ss, fs, gain const float* gain_amp_map, const float* rel_gain_slopes_map, const float* flatfield_map, @@ -14,39 +16,37 @@ extern "C" { const float bad_pixel_mask_value, float* gain_map, // similar to preview for AGIPD {{output_data_dtype}}* output) { - const size_t X = {{pixels_x}}; - const size_t Y = {{pixels_y}}; - const 
size_t input_frames = {{frames}}; - const size_t map_memory_cells = {{constant_memory_cells}}; + const size_t fs_dim = {{fs_dim}}; + const size_t ss_dim = {{ss_dim}}; - const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x; - const size_t y = blockIdx.y * blockDim.y + threadIdx.y; - const size_t x = blockIdx.z * blockDim.z + threadIdx.z; + const size_t frame = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ss = blockIdx.y * blockDim.y + threadIdx.y; + const size_t fs = blockIdx.z * blockDim.z + threadIdx.z; - if (memory_cell >= input_frames || y >= Y || x >= X) { + if (frame >= input_frames || ss >= ss_dim || fs >= fs_dim) { return; } - const size_t data_stride_x = 1; - const size_t data_stride_y = X * data_stride_x; - const size_t data_stride_cell = Y * data_stride_y; - const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x; + const size_t data_stride_fs = 1; + const size_t data_stride_ss = fs_dim * data_stride_fs; + const size_t data_stride_frame = ss_dim * data_stride_ss; + const size_t data_index = frame * data_stride_frame + ss * data_stride_ss + fs * data_stride_fs; const unsigned short raw_data_value = data[data_index]; const unsigned char gain = (raw_data_value >> 12) & 0x0003; float corrected = (float)(raw_data_value & 0x0fff); float gain_for_preview = (float)gain; const size_t gm_map_stride_gain = 1; - const size_t gm_map_stride_x = 3 * gm_map_stride_gain; - const size_t gm_map_stride_y = X * gm_map_stride_x; - const size_t gm_map_stride_cell = Y * gm_map_stride_y; + const size_t gm_map_stride_fs = 3 * gm_map_stride_gain; + const size_t gm_map_stride_ss = fs_dim * gm_map_stride_fs; + const size_t gm_map_stride_cell = ss_dim * gm_map_stride_ss; - const size_t map_cell = cell_table[memory_cell]; + const size_t map_cell = cell_table[frame]; if (map_cell < map_memory_cells) { const size_t gm_map_index = gain * gm_map_stride_gain + map_cell * gm_map_stride_cell + - y * gm_map_stride_y + - x * 
gm_map_stride_x; + ss * gm_map_stride_ss + + fs * gm_map_stride_fs; if (gain > 2 || ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index])) { // now also checking for illegal gain value @@ -70,7 +70,13 @@ extern "C" { corrected = bad_pixel_mask_value; } } + } else { + if (gain > 2 || corrected < -1e7 || corrected > 1e7) { + corrected = bad_pixel_mask_value; + gain_for_preview = bad_pixel_mask_value; + } } + gain_map[data_index] = gain_for_preview; {% if output_data_dtype == "half" %} output[data_index] = __float2half(corrected); diff --git a/src/calng/preview_utils.py b/src/calng/preview_utils.py index 55e907423df5526e56a28a00bcf7eb4f88eb5671..12cb7e183a13133824a7964b852b84a524bf9e8b 100644 --- a/src/calng/preview_utils.py +++ b/src/calng/preview_utils.py @@ -1,81 +1,148 @@ +import enum +import functools +from dataclasses import dataclass +from typing import Callable, List, Optional + from karabo.bound import ( BOOL_ELEMENT, FLOAT_ELEMENT, + INT32_ELEMENT, NODE_ELEMENT, OUTPUT_CHANNEL, STRING_ELEMENT, UINT32_ELEMENT, - ChannelMetaData, Dims, Encoding, Hash, ImageData, + Schema, Timestamp, ) import numpy as np -from . import schemas, utils +from . 
import schemas +from .utils import downsample_1d, downsample_2d, maybe_get + + +class FrameSelectionMode(enum.Enum): + FRAME = "frame" + CELL = "cell" + PULSE = "pulse" + + +@dataclass +class PreviewSpec: + # used for output schema + name: str + channel_name: Optional[str] = None # set after init + dimensions: int = 2 + wrap_in_imagedata: bool = True + + # reconfigurables in node schema + swap_axes: bool = False # only appars in schema if dimensions > 1 + flip_ss: bool = False # only appars in schema if dimensions > 1 + flip_fs: bool = False + nan_replacement: float = 0 + downsampling_factor: int = 1 + downsampling_function: Callable = np.nanmax + # if frame_reduction, frame_reduction_fun is made automatically from index and mode + # otherwise, leave None or provide custom + frame_reduction: bool = True + frame_reduction_selection_mode: FrameSelectionMode = FrameSelectionMode.FRAME + # not all detectors provide cell and / or pulse IDs + valid_frame_selection_modes: List[FrameSelectionMode] = tuple(FrameSelectionMode) + frame_reduction_index: int = 0 + # either set explicitly or automatically with frame selection + frame_reduction_fun: Optional[Callable] = None class PreviewFriend: @staticmethod def add_schema( - schema, node_path="preview", output_channels=None, create_node=False + schema: Schema, + outputs: List[PreviewSpec], + node_name: str = "preview" ): - if output_channels is None: - output_channels = ["output"] + """Add the prerequisite schema configurables for all the outputs.""" + ( + NODE_ELEMENT(schema) + .key(node_name) + .displayedName("Preview") + .description( + "Output specifically intended for preview in Karabo GUI. Includes " + "some options for throttling and adjustments of the output data." 
+ ) + .commit(), + ) + + for spec in outputs: + PreviewFriend._add_subschema(schema, f"{node_name}.{spec.name}", spec) + + @staticmethod + def _add_subschema(schema, node_path, spec): + # each preview gets its own node with config and channel + ( + NODE_ELEMENT(schema) + .key(node_path) + .commit(), + ) - if create_node: + if spec.dimensions > 1: ( - NODE_ELEMENT(schema) - .key(node_path) - .displayedName("Preview") + BOOL_ELEMENT(schema) + .key(f"{node_path}.swapAxes") + .displayedName("Swap slow / fast scan") .description( - "Output specifically intended for preview in Karabo GUI. Includes " - "some options for throttling and adjustments of the output data." + "Swaps the two pixel axes around. Can be combined with " + "flipping to rotate the image" ) + .assignmentOptional() + .defaultValue(spec.swap_axes) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_path}.flipSS") + .displayedName("Flip SS") + .description("Flip image data along slow scan axis.") + .assignmentOptional() + .defaultValue(spec.flip_ss) + .reconfigurable() .commit(), ) - ( - BOOL_ELEMENT(schema) - .key(f"{node_path}.flipSS") - .displayedName("Flip SS") - .description("Flip image data along slow scan axis.") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), + # configurable bits; these correspond to spec parts + ( BOOL_ELEMENT(schema) .key(f"{node_path}.flipFS") .displayedName("Flip FS") .description("Flip image data along fast scan axis.") .assignmentOptional() - .defaultValue(False) + .defaultValue(spec.flip_fs) .reconfigurable() .commit(), UINT32_ELEMENT(schema) .key(f"{node_path}.downsamplingFactor") - .displayedName("Factor") + .displayedName("Downsampling factor") .description( "If greater than 1, the preview image will be downsampled by this " "factor before sending. This is mostly to save bandwidth in case GUI " "updates start lagging." 
) .assignmentOptional() - .defaultValue(1) + .defaultValue(spec.downsampling_factor) .options("1,2,4,8") .reconfigurable() .commit(), STRING_ELEMENT(schema) .key(f"{node_path}.downsamplingFunction") - .displayedName("Function") + .displayedName("Downsampling function") .description("Reduction function used during downsampling.") .assignmentOptional() - .defaultValue("nanmax") + .defaultValue(spec.downsampling_function.__name__) .options("nanmax,nanmean,nanmin,nanmedian") .reconfigurable() .commit(), @@ -91,97 +158,255 @@ class PreviewFriend: "from the image data you want to see." ) .assignmentOptional() - .defaultValue(0) + .defaultValue(spec.nan_replacement) .reconfigurable() .commit(), ) - for channel in output_channels: + + if spec.frame_reduction: ( - OUTPUT_CHANNEL(schema) - .key(f"{node_path}.{channel}") - .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True)) - .description("See description of parent node, 'preview'.") + STRING_ELEMENT(schema) + .key(f"{node_path}.selectionMode") + .tags("managed") + .displayedName("Index selection mode") + .description( + "The value of preview.index can be used in multiple ways, " + "controlled by this value. If this is set to 'frame', " + "preview.index is sliced directly from data. If 'cell' " + "(or 'pulse') is selected, I will look at cell (or pulse) table " + "for the requested cell (or pulse ID). Special (stat) index values " + "<0 are not affected by this." + ) + .options( + ",".join( + selectionmode.value + for selectionmode in spec.valid_frame_selection_modes + ) + ) + .assignmentOptional() + .defaultValue("frame") + .reconfigurable() + .commit(), + INT32_ELEMENT(schema) + .key(f"{node_path}.index") + .tags("managed") + .displayedName("Index (or stat) for preview") + .description( + "If this value is ≥ 0, the corresponding index (frame, cell, or " + "pulse) will be sliced for the preview output. 
If this value is " + "<0, preview will be one of the following stats: -1: max, " + "-2: mean, -3: sum, -4: stdev. These stats are computed across " + "frames, ignoring NaN values." + ) + .assignmentOptional() + .defaultValue(spec.frame_reduction_index) + .minInc(-4) + .reconfigurable() .commit(), ) - def __init__(self, device, node_name="preview", output_channels=None): - if output_channels is None: - output_channels = ["output"] - self.output_channels = output_channels - self.device = device - self.dev_id = self.device.getInstanceId() - self.node_name = node_name - self.outputs = [ - self.device.signalSlotable.getOutputChannel(f"{self.node_name}.{channel}") - for channel in self.output_channels - ] - self.reconfigure(device._parameters) - - def write_outputs(self, timestamp, *datas, inplace=True): + ( + OUTPUT_CHANNEL(schema) + .key(f"{node_path}.output") + .dataSchema( + schemas.preview_schema( + wrap_image_in_imagedata=spec.wrap_in_imagedata + ) + ) + .description("See description of grandparent node, 'preview'.") + .commit(), + ) + + def __init__(self, device, outputs, node_name="preview"): + self._device = device + self.output_specs = outputs + for spec in self.output_specs: + spec.channel_name = f"{node_name}.{spec.name}.output" + self.reconfigure(self._device.get(node_name)) + + def write_outputs( + self, + *datas, + timestamp=None, + inplace=True, + cell_table=None, + pulse_table=None, + warn_fun=None, + ): """Applies GUI-friendly preview settings (replace NaN, downsample, wrap as ImageData) and writes to output channels. Make sure datas length matches number - of channels!""" + of channels! 
If inplace, masking and such may write to provided buffers (will + otherwise copy).""" if isinstance(timestamp, Hash): timestamp = Timestamp.fromHashAttributes( timestamp.getAttributes("timestamp") ) - for data, output, channel_name in zip( - datas, self.outputs, self.output_channels - ): - if self.downsampling_factor > 1: - data = utils.downsample_2d( - data, - self.downsampling_factor, - reduction_fun=self.downsampling_function, - ) + if warn_fun is None: + warn_fun = self._device.log.WARN + for data, spec in zip(datas, self.output_specs): + if spec.frame_reduction_fun is not None: + data = spec.frame_reduction_fun(data, cell_table, pulse_table, warn_fun) + if spec.downsampling_factor > 1: + if spec.dimensions == 1: + data = downsample_1d( + data, + spec.downsampling_factor, + reduction_fun=spec.downsampling_function, + ) + else: + data = downsample_2d( + data, + spec.downsampling_factor, + reduction_fun=spec.downsampling_function, + ) elif not inplace: data = data.copy() - if self.flip_ss: - data = np.flip(data, 0) - if self.flip_fs: - data = np.flip(data, 1) + if spec.flip_ss: + data = np.flip(data, -2) + if spec.flip_fs: + data = np.flip(data, -1) if isinstance(data, np.ma.MaskedArray): data, mask = data.data, data.mask | ~np.isfinite(data) else: mask = ~np.isfinite(data) - data[mask] = self.nan_replacement + data[mask] = spec.nan_replacement mask = mask.astype(np.uint8) - output_hash = Hash( - "image.data", - ImageData( - data, - Dims(*data.shape), - Encoding.GRAY, - bitsPerPixel=32, - ), - "image.mask", - ImageData( - mask, - Dims(*mask.shape), - Encoding.GRAY, - bitsPerPixel=8, - ), - ) - output.write( - output_hash, - ChannelMetaData( - f"{self.dev_id}:{self.node_name}.{channel_name}", - timestamp, - ), - copyAllData=False, + if spec.wrap_in_imagedata: + output_hash = Hash( + "image.data", + ImageData( + maybe_get(data), + Dims(*data.shape), + Encoding.GRAY, + bitsPerPixel=32, + ), + "image.mask", + ImageData( + maybe_get(mask), + Dims(*mask.shape), + 
Encoding.GRAY, + bitsPerPixel=8, + ), + ) + else: + output_hash = Hash( + "image.data", maybe_get(data), "image.mask", maybe_get(mask) + ) + self._device.writeChannel( + spec.channel_name, output_hash, timestamp=timestamp, safeNDArray=True ) - output.update(safeNDArray=True) def reconfigure(self, conf): - if conf.has(f"{self.node_name}.downsamplingFunction"): - self.downsampling_function = getattr( - np, conf[f"{self.node_name}.downsamplingFunction"] + for spec in self.output_specs: + if not conf.has(spec.name): + continue + subconf = conf.get(spec.name) + if spec.frame_reduction and ( + subconf.has("selectionMode") or subconf.has("index") + ): + if subconf.has("selectionMode"): + spec.frame_reduction_selection_mode = FrameSelectionMode( + subconf["selectionMode"] + ) + if subconf.has("index"): + spec.frame_reduction_index = subconf["index"] + spec.frame_reduction_fun = _make_frame_reduction_fun(spec) + if subconf.has("downsamplingFunction"): + spec.downsampling_function = getattr( + np, subconf["downsamplingFunction"] + ) + if subconf.has("downsamplingFactor"): + spec.downsampling_factor = subconf["downsamplingFactor"] + if subconf.has("flipSS"): + spec.flip_ss = subconf["flipSS"] + if subconf.has("flipFS"): + spec.flip_fs = subconf["flipFS"] + if subconf.has("swapAxes"): + spec.flip_fs = subconf["flipFS"] + if subconf.has("replaceNanWith"): + spec.nan_replacement = subconf["replaceNanWith"] + + +def _make_frame_reduction_fun(spec) -> Optional[Callable]: + """Will be called during configuration in case frame_reduction is True. 
Will + use frame_reduction_selection_mode and frame_reduction_index to create a + function to either slice or reduce input data.""" + # note: in case of frame picking, signature includes cell and pulse table + if spec.frame_reduction_index < 0: + # so while those tables are irrelevant here, still include in wrapper + if spec.frame_reduction_index == -1: + # note: separate from next case because dtype not applicable here + reduction_fun = functools.partial(np.nanmax, axis=0) + elif spec.frame_reduction_index == -2: + # note: separate from next case because dtype not applicable here + reduction_fun = functools.partial(np.nanmean, axis=0, dtype=np.float32) + elif spec.frame_reduction_index == -3: + # note: separate from next case because dtype not applicable here + reduction_fun = functools.partial(np.nansum, axis=0, dtype=np.float32) + elif spec.frame_reduction_index == -4: + # note: separate from next case because dtype not applicable here + reduction_fun = functools.partial(np.nanstd, axis=0, dtype=np.float32) + + def fun(data, cell_table, pulse_table, warn_fun): + try: + return reduction_fun(data) + except Exception as ex: + warn_fun(f"Frame reduction error: {ex}") + + return fun + else: + # here, we pick out a frame + if spec.frame_reduction_selection_mode is FrameSelectionMode.FRAME: + return _pick_preview_index_frame(spec.frame_reduction_index) + elif spec.frame_reduction_selection_mode is FrameSelectionMode.CELL: + return _pick_preview_index_cell(spec.frame_reduction_index) + else: + return _pick_preview_index_pulse(spec.frame_reduction_index) + + +def _pick_preview_index_frame(index): + def aux(data, cell_table, pulse_table, warn_fun): + if index >= data.shape[0]: + warn_fun( + f"Frame index {index} out of bounds for data with " + f"{data.shape[0]} frames, previewing frame 0" + ) + return data[0] + return data[index] + return aux + + +def _pick_preview_index_cell(index): + def aux(data, cell_table, pulse_table, warn_fun): + found = np.nonzero(cell_table == 
index)[0] + if found.size == 0: + warn_fun( + f"Cell ID {index} not found in cell mapping, " + "previewing frame 0 instead" + ) + return data[0] + if found.size > 1: + warn_fun( + f"Cell ID {index} not unique in cell mapping, " + f"previewing first occurrence (frame {found[0]}" + ) + return data[found[0]] + return aux + + +def _pick_preview_index_pulse(index): + def aux(data, cell_table, pulse_table, warn_fun): + found = np.nonzero(pulse_table == index)[0] + if found.size == 0: + warn_fun( + f"Pulse ID {index} not found in pulse table, " + "previewing frame 0 instead" + ) + return data[0] + if found.size > 1: + warn_fun( + f"Pulse ID {index} not unique in pulse table, " + f"previewing first occurrence (frame {found[0]})" ) - if conf.has(f"{self.node_name}.downsamplingFactor"): - self.downsampling_factor = conf[f"{self.node_name}.downsamplingFactor"] - if conf.has(f"{self.node_name}.flipSS"): - self.flip_ss = conf[f"{self.node_name}.flipSS"] - if conf.has(f"{self.node_name}.flipFS"): - self.flip_fs = conf[f"{self.node_name}.flipFS"] - if conf.has(f"{self.node_name}.replaceNanWith"): - self.nan_replacement = conf[f"{self.node_name}.replaceNanWith"] + return data[found[0]] + return aux diff --git a/src/calng/scenes.py b/src/calng/scenes.py index 56ae3bfd2f1098efce69a38f8f677e0f3ddd8fa7..3eb66b0fdc9ba8e5ce507023ddaa1c7caededafe 100644 --- a/src/calng/scenes.py +++ b/src/calng/scenes.py @@ -34,6 +34,8 @@ from karabo.common.scenemodel.api import ( WebLinkModel, ) +docs_url = "https://rtd.xfel.eu/docs/calng/en/latest" + @titled("Found constants", width=6 * NARROW_INC) @boxed @@ -140,7 +142,7 @@ class ManagerDeviceStatus(VerticalLayout): text="Documentation", width=7 * BASE_INC, height=BASE_INC, - target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#calibration-manager", + target=f"{docs_url}/devices/#calibration-manager", ) self.children.extend( [ @@ -446,63 +448,77 @@ class RoiBox(VerticalLayout): @titled("Preview settings") @boxed class 
PreviewSettings(HorizontalLayout): - def __init__(self, device_id, schema_hash, node_name="preview", extras=None): + def __init__(self, device_id, schema_hash, node_name): super().__init__() - self.children.extend( - [ - VerticalLayout( - EditableRow( - device_id, - schema_hash, - f"{node_name}.replaceNanWith", - 6, - 4, - ), - HorizontalLayout( - EditableRow( - device_id, - schema_hash, - f"{node_name}.flipSS", - 3, - 2, - ), - EditableRow( - device_id, - schema_hash, - f"{node_name}.flipFS", - 3, - 2, - ), - padding=0, - ), - ), - Vline(height=3 * BASE_INC), - VerticalLayout( - LabelModel( - text="Image downsampling", - width=8 * BASE_INC, - height=BASE_INC, - ), + tweaks = VerticalLayout( + EditableRow( + device_id, + schema_hash, + f"{node_name}.replaceNanWith", + 6, + 4, + ), + ) + if schema_hash.has(f"{node_name}.flipSS"): + tweaks.children.append( + HorizontalLayout( EditableRow( device_id, schema_hash, - f"{node_name}.downsamplingFactor", - 5, - 5, + f"{node_name}.flipSS", + 3, + 2, ), EditableRow( device_id, schema_hash, - f"{node_name}.downsamplingFunction", - 5, - 5, + f"{node_name}.flipFS", + 3, + 2, ), - ), + padding=0, + ) + ) + downsampling = VerticalLayout( + LabelModel( + text="Downsampling", + width=8 * BASE_INC, + height=BASE_INC, + ), + EditableRow( + device_id, + schema_hash, + f"{node_name}.downsamplingFactor", + 5, + 5, + ), + EditableRow( + device_id, + schema_hash, + f"{node_name}.downsamplingFunction", + 5, + 5, + ), + ) + self.children.extend( + [ + tweaks, + Vline(height=3 * BASE_INC), + downsampling, ] ) - if extras is not None: - self.children.append(Vline(height=3 * BASE_INC)) - self.children.extend(extras) + if schema_hash.has(f"{node_name}.index"): + self.children.extend( + [ + Vline(height=3 * BASE_INC), + VerticalLayout( + EditableRow(device_id, schema_hash, f"{node_name}.index", 8, 4), + EditableRow( + device_id, schema_hash, f"{node_name}.selectionMode", 8, 4 + ), + ), + ] + ) @titled("Histogram settings") @@ -633,26 +649,23 @@ 
def correction_device_overview(device_id, schema): ) return VerticalLayout( main_overview, - LabelModel( - text="Preview (corrected):", - width=20 * BASE_INC, - height=BASE_INC, - ), - PreviewDisplayArea(device_id, schema_hash, "preview.outputCorrected"), *( DeviceSceneLinkModel( - text=f"Preview: {channel}", + text=f"Preview: {preview_name}", keys=[f"{device_id}.availableScenes"], - target=f"preview:preview.{channel}", + target=f"preview:{preview_name}", target_window=SceneTargetWindow.Dialog, width=16 * BASE_INC, height=BASE_INC, ) - for channel in schema_hash.get("preview").getKeys() - if schema_hash.hasAttribute(f"preview.{channel}", "classId") - and schema_hash.getAttribute(f"preview.{channel}", "classId") - == "OutputChannel" + for preview_name in schema_hash.get("preview").getKeys() + ), + LabelModel( + text="Preview (corrected):", + width=20 * BASE_INC, + height=BASE_INC, ), + PreviewDisplayArea(device_id, schema_hash, "preview.corrected.output"), ) @@ -666,25 +679,20 @@ def lpdmini_splitter_overview(device_id, schema): @scene_generator -def correction_device_preview(device_id, schema, preview_channel): +def correction_device_preview(device_id, schema, name): schema_hash = schema_to_hash(schema) return VerticalLayout( LabelModel( - text=f"Preview: {preview_channel}", + text=f"Preview: {name}", width=10 * BASE_INC, height=BASE_INC, ), PreviewSettings( device_id, schema_hash, - extras=[ - VerticalLayout( - EditableRow(device_id, schema_hash, "preview.index", 8, 4), - EditableRow(device_id, schema_hash, "preview.selectionMode", 8, 4), - ) - ], + f"preview.{name}", ), - PreviewDisplayArea(device_id, schema_hash, preview_channel), + PreviewDisplayArea(device_id, schema_hash, f"preview.{name}.output"), ) @@ -758,7 +766,8 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan ), DisplayCommandModel( keys=[ - f"{device_id}.{prefix}.{constant}.overrideConstantFromVersion" + f"{device_id}.{prefix}.{constant}" + 
".overrideConstantFromVersion" ], width=8 * BASE_INC, height=BASE_INC, @@ -829,67 +838,47 @@ def manager_device_overview( mds_hash = schema_to_hash(manager_device_schema) cds_hash = schema_to_hash(correction_device_schema) - data_throttling_children = [ - LabelModel( - text="Frame filter", - width=11 * BASE_INC, - height=BASE_INC, - ), - EditableRow( + config_column = [ + recursive_editable( manager_device_id, mds_hash, - "managedKeys.frameFilter.type", - 7, - 4, + "managedKeys.constantParameters", ), - EditableRow( - manager_device_id, - mds_hash, - "managedKeys.frameFilter.spec", - 7, - 4, + DisplayCommandModel( + keys=[f"{manager_device_id}.managedKeys.loadMostRecentConstants"], + width=10 * BASE_INC, + height=BASE_INC, ), ] - if "managedKeys.daqTrainStride" in mds_hash: - # Only add DAQ train stride if present on the schema, may be - # disabled on the manager. - data_throttling_children.insert( - 0, - EditableRow( + if "managedKeys.preview" in mds_hash: + config_column.append( + recursive_editable( manager_device_id, mds_hash, - "managedKeys.daqTrainStride", - 7, - 4, + "managedKeys.preview", + max_depth=2, ), ) - return VerticalLayout( - HorizontalLayout( - ManagerDeviceStatus(manager_device_id), - VerticalLayout( - recursive_editable( - manager_device_id, - mds_hash, - "managedKeys.constantParameters", - ), - DisplayCommandModel( - keys=[f"{manager_device_id}.managedKeys.loadMostRecentConstants"], - width=10 * BASE_INC, - height=BASE_INC, - ), - recursive_editable( + if "managedKeys.daqTrainStride" in mds_hash: + config_column.append( + titled("Data throttling")(boxed(VerticalLayout))( + EditableRow( manager_device_id, mds_hash, - "managedKeys.preview", - max_depth=2, - ), - titled("Data throttling")(boxed(VerticalLayout))( - children=data_throttling_children, - padding=0, + "managedKeys.daqTrainStride", + 7, + 4, ), + padding=0, ), + ) + + return VerticalLayout( + HorizontalLayout( + ManagerDeviceStatus(manager_device_id), + VerticalLayout(*config_column), 
recursive_editable( manager_device_id, mds_hash, @@ -932,7 +921,7 @@ def manager_device_overview( text="Documentation", width=6 * BASE_INC, height=BASE_INC, - target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#correction-devices", + target=f"{docs_url}/devices/#correction-devices", ), padding=0, ) @@ -1043,12 +1032,12 @@ def detector_assembler_overview(device_id, schema): return VerticalLayout( HorizontalLayout( AssemblerDeviceStatus(device_id), - PreviewSettings(device_id, schema_hash), + PreviewSettings(device_id, schema_hash, "preview.assembled"), ), PreviewDisplayArea( device_id, schema_hash, - "preview.output", + "preview.assembled.output", data_width=40 * BASE_INC, data_height=40 * BASE_INC, mask_width=40 * BASE_INC, diff --git a/src/calng/schemas.py b/src/calng/schemas.py index 79524daade964745896b32a143491b8e8587a6d7..3b4820b0c7b90239bd70807a7d90cd250404d675 100644 --- a/src/calng/schemas.py +++ b/src/calng/schemas.py @@ -44,7 +44,7 @@ def preview_schema(wrap_image_in_imagedata=False): return res -def xtdf_output_schema(use_shmem_handle=True): +def xtdf_output_schema(use_shmem_handles): # TODO: trim output schema / adapt to specific detectors # currently: based on snapshot of actual output reusing AGIPD hash res = Schema() @@ -178,7 +178,7 @@ def xtdf_output_schema(use_shmem_handle=True): .commit(), ) - if use_shmem_handle: + if use_shmem_handles: ( STRING_ELEMENT(res) .key("image.data") @@ -194,7 +194,7 @@ def xtdf_output_schema(use_shmem_handle=True): return res -def jf_output_schema(use_shmem_handle=True): +def jf_output_schema(use_shmem_handles): res = Schema() ( NODE_ELEMENT(res) @@ -227,7 +227,7 @@ def jf_output_schema(use_shmem_handle=True): .readOnly() .commit(), ) - if use_shmem_handle: + if use_shmem_handles: ( STRING_ELEMENT(res) .key("data.adc") @@ -243,7 +243,7 @@ def jf_output_schema(use_shmem_handle=True): return res -def pnccd_output_schema(use_shmem_handle=True): +def pnccd_output_schema(use_shmem_handles): res = Schema() ( 
NODE_ELEMENT(res) @@ -255,7 +255,7 @@ def pnccd_output_schema(use_shmem_handle=True): .readOnly() .commit(), ) - if use_shmem_handle: + if use_shmem_handles: ( STRING_ELEMENT(res) .key("data.image") diff --git a/src/calng/stacking_utils.py b/src/calng/stacking_utils.py index 4d6e7008afa8a68ae44deabfef3d651cf2f7f910..8c81d1fe3c1a498cd2a7774e71fd31f04ae41818 100644 --- a/src/calng/stacking_utils.py +++ b/src/calng/stacking_utils.py @@ -166,8 +166,6 @@ class StackingFriend: self.reconfigure(source_config, merge_config) def reconfigure(self, source_config, merge_config): - print("merge_config", type(merge_config)) - print("source_config", type(source_config)) if source_config is not None: self._source_config = source_config if merge_config is not None: diff --git a/src/calng/utils.py b/src/calng/utils.py index fd7c6c284b4409e33c5ea9816eac8cfc17d6744c..c955d80b38d4b9f01cfce1c6e799fa691695512a 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -8,6 +8,21 @@ import numpy as np from calngUtils import misc +class WarningLampType(enum.Enum): + FRAME_FILTER = enum.auto() + MEMORY_CELL_RANGE = enum.auto() + CONSTANT_OPERATING_PARAMETERS = enum.auto() + PREVIEW_SETTINGS = enum.auto() + CORRECTION_RUNNER = enum.auto() + OUTPUT_BUFFER = enum.auto() + GPU_MEMORY = enum.auto() + CALCAT_CONNECTION = enum.auto() + EMPTY_HASH = enum.auto() + MISC_INPUT_DATA = enum.auto() + TRAIN_ID = enum.auto() + TIMESERVER_CONNECTION = enum.auto() + + class WarningContextSystem: """Helper object to trigger warning lamps based on different warning types in contexts. 
Intended use: something that is checked multiple times and is good or bad @@ -90,63 +105,6 @@ class WarningContextSystem: self.device.set("status", message) -class PreviewIndexSelectionMode(enum.Enum): - FRAME = "frame" - CELL = "cell" - PULSE = "pulse" - - -def pick_frame_index(selection_mode, index, cell_table, pulse_table): - """When selecting a single frame to preview, an obvious question is whether the - number the operator provides is a frame index, a cell ID, or a pulse ID. This - function allows any of the three, translating into frame index. - - Indices below zero are special values and thus returned directly. - - Returns: (frame index, cell ID, pulse ID), optional warning""" - - if index < 0: - return (index, index, index), None - - warning = None - selection_mode = PreviewIndexSelectionMode(selection_mode) - - if selection_mode is PreviewIndexSelectionMode.FRAME: - if index < cell_table.size: - frame_index = index - else: - warning = ( - f"Frame index {index} out of range for cell table of length " - f"{len(cell_table)}. Will use frame index 0 instead." - ) - frame_index = 0 - else: - if selection_mode is PreviewIndexSelectionMode.CELL: - index_table = cell_table - else: - index_table = pulse_table - found = np.where(index_table == index)[0] - if len(found) == 1: - frame_index = found[0] - elif len(found) == 0: - frame_index = 0 - warning = ( - f"{selection_mode.value.capitalize()} ID {index} not found in " - f"{selection_mode.name} table. Will use frame {frame_index}, which " - f"corresponds to {selection_mode.name} {index_table[0]} instead." - ) - elif len(found) > 1: - warning = ( - f"{selection_mode.value.capitalize()} ID {index} was not unique in " - f"{selection_mode} table. Will use first occurrence out of " - f"{len(found)} occurrences (frame {frame_index})." 
- ) - - cell_id = cell_table[frame_index] - pulse_id = pulse_table[frame_index] - return (frame_index, cell_id, pulse_id), warning - - _np_typechar_to_c_typestring = { "?": "bool", "B": "unsigned char", @@ -234,22 +192,37 @@ class BadPixelValues(enum.IntFlag): def downsample_2d(arr, factor, reduction_fun=np.nanmax): """Generalization of downsampling from FemDataAssembler - Expects first two dimensions of arr to be multiple of 2 ** factor - Useful if you're sitting at home and ssh connection is slow to get full-resolution - previews.""" + Expects last two dimensions of arr to be multiple of 2 ** factor. Useful for + looking at detector previews over a slow connection (say, ssh).""" for i in range(factor // 2): + # downsample slow scan arr = reduction_fun( ( - arr[:-1:2], - arr[1::2], + arr[..., :-1:2, :], + arr[..., 1::2, :], ), axis=0, ) + # downsample fast scan arr = reduction_fun( ( - arr[:, :-1:2], - arr[:, 1::2], + arr[..., :-1:2], + arr[..., 1::2], + ), + axis=0, + ) + return arr + + +def downsample_1d(arr, factor, reduction_fun=np.nanmax): + """Same as downsample_2d, but only 1d (applied to last axis of input)""" + + for i in range(factor // 2): + arr = reduction_fun( + ( + arr[..., :-1:2], + arr[..., 1::2], ), axis=0, ) @@ -289,3 +262,25 @@ def apply_partial_lut(data, lut, mask, out, missing=np.nan): tmp = out.ravel() tmp[~mask] = data.ravel()[lut] tmp[mask] = missing + + +def subset_of_hash(input_hash, *keys): + # don't bother importing here, just construct empty + res = input_hash.__class__() + for key in keys: + if input_hash.has(key): + res.set(key, input_hash.get(key)) + return res + + +def maybe_get(a, out=None): + """For getting CuPy ndarray data to system memory - tries to behave like + cupy.ndarray.get with respect to the 'out' parameter""" + if hasattr(a, "get"): + return a.get(out=out) + else: + if out is None: + return a + else: + np.copyto(out, a) + return out diff --git a/tests/common_setup.py b/tests/common_setup.py new file mode 100644 index 
0000000000000000000000000000000000000000..fafe2cb7078a4a20b6ef4beb7d17b41d0c05d1b9 --- /dev/null +++ b/tests/common_setup.py @@ -0,0 +1,100 @@ +import contextlib +import pathlib +import threading + +import h5py +import numpy as np +from calngUtils import device as device_utils +from calngUtils.misc import ChainHash +from karabo.bound import Hash, Schema + + +def maybe_get(b): + if hasattr(b, "get"): + return b.get() + else: + return b + + +class DummyLogger: + DEBUG = print + INFO = print + WARN = print + + +@device_utils.with_config_overlay +@device_utils.with_unsafe_get +class DummyCorrectionDevice: + log = DummyLogger() + + def log_status_info(self, msg): + self.log.INFO(msg) + + def log_status_warn(self, msg): + self.log.WARN(msg) + + def get(self, key): + return self._parameters.get(key) + + def set(self, key, value): + self._parameters[key] = value + + def reconfigure(self, config): + with self.new_config_context(config): + self.runner.reconfigure(config) + self._parameters.merge(config) + + def __init__(self, *friends): + self._parameters = Hash() + s = Schema() + self._temporary_config = [] + self._temporary_config_lock = threading.RLock() + for friend in friends: + if isinstance(friend, dict): + friend_class = friend.pop("friend_class") + extra_args = friend + else: + friend_class = friend + extra_args = {} + friend_class.add_schema(s, **extra_args) + h = s.getParameterHash() + for path in h.getPaths(): + if h.hasAttribute(path, "defaultValue"): + if "output" in path.split("."): + continue + self._parameters[path] = h.getAttribute(path, "defaultValue") + + self.warnings = {} + + def _please_send_me_cached_constants(self, constants, callback): + pass + + @contextlib.contextmanager + def warning_context(self, key, warn_type): + def aux(message): + print(message) + self.warnings[(key, warn_type)] = message + + yield aux + + @contextlib.contextmanager + def new_config_context(self, config): + with self._temporary_config_lock: + old_config = self._parameters + 
if isinstance(self._parameters, ChainHash): + # in case we need to support recursive case / configuration stacking + self._parameters = ChainHash(config, *old_config._hashes) + else: + self._parameters = ChainHash(config, old_config) + yield + self._parameters = old_config + + +caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store") + + +def get_constant_from_file(path_to_file, file_name, data_set_name): + """Intended to work with copy-pasting paths from CalCat. Specifially, the "Path to + File", "File Name", and "Data Set Name" fields displayed for a given CCV.""" + with h5py.File(caldb_store / path_to_file / file_name, "r") as fd: + return np.array(fd[f"{data_set_name}/data"]) diff --git a/tests/test_agipd_kernels.py b/tests/test_agipd_kernels.py index afbb2ca057dd3ec4b95aca4b2c16f235f59e9519..25021ae92bf2c221acec0f0ecf1a919a6659310a 100644 --- a/tests/test_agipd_kernels.py +++ b/tests/test_agipd_kernels.py @@ -1,19 +1,23 @@ -import h5py import numpy as np -import pathlib - +import pytest from calng.corrections.AgipdCorrection import ( + AgipdBaseRunner, + AgipdCalcatFriend, AgipdCpuRunner, AgipdGpuRunner, Constants, CorrectionFlags, ) +from karabo.bound import Hash + +from common_setup import DummyCorrectionDevice, get_constant_from_file, maybe_get output_dtype = np.float32 corr_dtype = np.float32 pixels_x = 512 pixels_y = 128 memory_cells = 352 +num_frames = memory_cells raw_data = np.random.randint( low=0, high=2000, size=(memory_cells, 2, pixels_x, pixels_y), dtype=np.uint16 @@ -23,29 +27,26 @@ raw_gain = raw_data[:, 1] cell_table = np.arange(memory_cells, dtype=np.uint16) np.random.shuffle(cell_table) -caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store/xfel/cal") -caldb_prefix = caldb_store / "agipd-type/agipd_siv1_agipdv11_m305" - -with h5py.File(caldb_prefix / "cal.1619543695.4679213.h5", "r") as fd: - thresholds = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0/data"]) -with h5py.File(caldb_prefix / "cal.1619543664.1545036.h5", "r") as 
fd: - offset_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/Offset/0/data"]) -with h5py.File(caldb_prefix / "cal.1615377705.8904035.h5", "r") as fd: - slopes_pc_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0/data"]) - -kernel_runners = [ - runner_class( - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells=memory_cells, - output_data_dtype=output_dtype, - ) - for runner_class in (AgipdCpuRunner, AgipdGpuRunner) -] +thresholds = get_constant_from_file( + "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305", + "cal.1619543695.4679213.h5", + "/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0", +) +offset_map = get_constant_from_file( + "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305", + "cal.1619543664.1545036.h5", + "/AGIPD_SIV1_AGIPDV11_M305/Offset/0", +) +slopes_pc_map = get_constant_from_file( + "xfel/cal/agipd-type/agipd_siv1_agipdv11_m305", + "cal.1615377705.8904035.h5", + "/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0", +) +# first compute things naively -def thresholding_cpu(data, cell_table, thresholds): + +def thresholding_naive(data, cell_table, thresholds): # get to memory_cell, x, y raw_gain = data[:, 1, ...].astype(corr_dtype, copy=False) # get to threshold, memory_cell, x, y @@ -56,16 +57,13 @@ def thresholding_cpu(data, cell_table, thresholds): return res -gain_map_cpu = thresholding_cpu(raw_data, cell_table, thresholds) - - -def corr_offset_cpu(data, cell_table, gain_map, offset): +def corr_offset_naive(data, cell_table, gain_map, offset): image_data = data[:, 0].astype(corr_dtype, copy=False) offset = np.transpose(offset)[:, cell_table] return (image_data - np.choose(gain_map, offset)).astype(output_dtype) -def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc): +def corr_rel_gain_pc_naive(data, cell_table, gain_map, slopes_pc): slopes_pc = slopes_pc.astype(np.float32, copy=False) pc_high_m = slopes_pc[0] pc_high_I = slopes_pc[1] @@ -77,7 +75,7 @@ def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc): rel_gain_map[0] = 1 # rel xray gain 
can come after rel_gain_map[1] = rel_gain_map[0] * np.transpose(frac_high_med) rel_gain_map[2] = rel_gain_map[1] * 4.48 - res = data[:, 0].astype(corr_dtype, copy=True) + res = data.astype(corr_dtype, copy=True) res *= np.choose(gain_map, np.transpose(rel_gain_map, (0, 3, 1, 2))) pixels_in_medium_gain = gain_map == 1 res[pixels_in_medium_gain] += np.transpose(md_additional_offset, (0, 2, 1))[ @@ -86,35 +84,90 @@ def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc): return res -def test_thresholding(): - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - runner.load_constant(Constants.ThresholdsDark, thresholds) - runner.correct(np.uint8(CorrectionFlags.THRESHOLD)) - assert np.allclose(kernel_runners[0].gain_map, gain_map_cpu) - assert np.allclose(kernel_runners[1].gain_map.get(), gain_map_cpu) - - -def test_offset(): - reference = corr_offset_cpu(raw_data, cell_table, gain_map_cpu, offset_map) - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - runner.load_constant(Constants.ThresholdsDark, thresholds) - runner.load_constant(Constants.Offset, offset_map) - # have to do thresholding, otherwise all is treated as high gain - runner.correct(np.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.OFFSET)) - res = runner.reshape(runner._corrected_axis_order) - assert np.allclose(res, reference), f"{runner.__class__}" - - -def test_rel_gain_pc(): - reference = corr_rel_gain_pc_cpu(raw_data, cell_table, gain_map_cpu, slopes_pc_map) - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - runner.load_constant(Constants.ThresholdsDark, thresholds) - runner.load_constant(Constants.SlopesPC, slopes_pc_map) - runner.correct( - np.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.REL_GAIN_PC) +gain_map_naive = thresholding_naive(raw_data, cell_table, thresholds) +offset_corrected_naive = corr_offset_naive( + raw_data, cell_table, gain_map_naive, offset_map +) +gain_corrected_naive = corr_rel_gain_pc_naive( + 
offset_corrected_naive, cell_table, gain_map_naive, slopes_pc_map +) + + +@pytest.fixture(params=[AgipdCpuRunner, AgipdGpuRunner]) +def kernel_runner(request): + device = DummyCorrectionDevice(AgipdBaseRunner, AgipdCalcatFriend) + runner = request.param(device) + # TODO: also test ASIC seam masking + runner.reconfigure( + Hash( + "corrections.badPixels.subsetToUse.NON_STANDARD_SIZE", + False, ) - res = runner.reshape(runner._corrected_axis_order) - assert np.allclose(res, reference) + ) + return runner + + +# then start testing individually + + +def test_thresholding(kernel_runner): + kernel_runner.load_constant(Constants.ThresholdsDark, thresholds) + # AgipdBaseRunner's buffers don't depend on flags + output, output_gain = kernel_runner._make_output_buffers(num_frames, None) + kernel_runner._correct( + kernel_runner._xp.uint8(CorrectionFlags.THRESHOLD), + kernel_runner._xp.asarray(raw_data), + kernel_runner._xp.asarray(cell_table), + output, + output_gain, + ) + assert np.allclose(maybe_get(output_gain), gain_map_naive.astype(np.float32)) + + +def test_offset(kernel_runner): + # have to also do thresholding, otherwise all is treated as high gain + kernel_runner.load_constant(Constants.ThresholdsDark, thresholds) + kernel_runner.load_constant(Constants.Offset, offset_map) + output, output_gain = kernel_runner._make_output_buffers(num_frames, None) + kernel_runner._correct( + kernel_runner._xp.uint8(CorrectionFlags.THRESHOLD | CorrectionFlags.OFFSET), + kernel_runner._xp.asarray(raw_data), + kernel_runner._xp.asarray(cell_table), + output, + output_gain, + ) + assert np.allclose(maybe_get(output), offset_corrected_naive) + + +def test_rel_gain_pc(kernel_runner): + # similarly, do previous steps first + kernel_runner.load_constant(Constants.ThresholdsDark, thresholds) + kernel_runner.load_constant(Constants.Offset, offset_map) + kernel_runner.load_constant(Constants.SlopesPC, slopes_pc_map) + output, output_gain = kernel_runner._make_output_buffers(num_frames, None) 
+ kernel_runner._correct( + kernel_runner._xp.uint8( + CorrectionFlags.THRESHOLD + | CorrectionFlags.OFFSET + | CorrectionFlags.REL_GAIN_PC + ), + kernel_runner._xp.asarray(raw_data), + kernel_runner._xp.asarray(cell_table), + output, + output_gain, + ) + assert np.allclose(maybe_get(output), gain_corrected_naive) + + +def test_with_preview(kernel_runner): + kernel_runner.load_constant(Constants.ThresholdsDark, thresholds) + kernel_runner.load_constant(Constants.Offset, offset_map) + ( + _, + processed_data, + (preview_raw, preview_corrected, preview_raw_gain, preview_gain_stage), + ) = kernel_runner.correct(raw_data, cell_table) + assert np.allclose(maybe_get(processed_data), offset_corrected_naive) + assert np.allclose(raw_data[:, 0], maybe_get(preview_raw)) + assert np.allclose(offset_corrected_naive, maybe_get(preview_corrected)) + assert np.allclose(raw_data[:, 1], maybe_get(preview_raw_gain)) diff --git a/tests/test_dssc_kernels.py b/tests/test_dssc_kernels.py index 6ab37a8c0cadf40175b8a6acfede720d9651477a..f7df8ef9470460a28a57c2f85b60c3db1df3f3c5 100644 --- a/tests/test_dssc_kernels.py +++ b/tests/test_dssc_kernels.py @@ -1,29 +1,34 @@ import numpy as np import pytest - from calng.corrections.DsscCorrection import ( - CorrectionFlags, + Constants, + DsscBaseRunner, + DsscCalcatFriend, DsscCpuRunner, DsscGpuRunner, ) +from karabo.bound import Hash + +from common_setup import DummyCorrectionDevice, maybe_get output_dtype = np.float32 corr_dtype = np.float32 pixels_x = 512 pixels_y = 128 memory_cells = 400 +num_frames = memory_cells offset_map = ( np.random.random(size=(pixels_x, pixels_y, memory_cells)).astype(corr_dtype) * 20 ) cell_table = np.arange(memory_cells, dtype=np.uint16) np.random.shuffle(cell_table) raw_data = np.random.randint( - low=0, high=2000, size=(memory_cells, pixels_y, pixels_x), dtype=np.uint16 + low=0, high=2000, size=(memory_cells, 1, pixels_y, pixels_x), dtype=np.uint16 ) # TODO: gather CPU implementations elsewhere -def 
correct_cpu(data, cell_table, offset_map): +def correct_naive(data, cell_table, offset_map): corr = np.squeeze(data).astype(corr_dtype, copy=True) safe_cell_bool = cell_table < offset_map.shape[-1] safe_cell_index = cell_table[safe_cell_bool] @@ -31,127 +36,76 @@ def correct_cpu(data, cell_table, offset_map): return corr.astype(output_dtype, copy=False) -corrected_data = correct_cpu(raw_data, cell_table, offset_map) +corrected_data = correct_naive(raw_data, cell_table, offset_map) only_cast_data = np.squeeze(raw_data).astype(output_dtype) -kernel_runners = [ - runner_class( - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells=memory_cells, - output_data_dtype=output_dtype, - ) - for runner_class in (DsscCpuRunner, DsscGpuRunner) -] - - -def test_only_cast(): - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - runner.correct(CorrectionFlags.NONE) - assert np.allclose(runner.reshape(runner._corrected_axis_order), only_cast_data) - - -def test_correct(): - for runner in kernel_runners: - runner.load_offset_map(offset_map) - runner.load_data(raw_data, cell_table) - runner.correct(CorrectionFlags.OFFSET) - assert np.allclose(runner.reshape(runner._corrected_axis_order), corrected_data) - - -def test_correct_oob_cells(): - wild_cell_table = cell_table * 2 - reference = correct_cpu(raw_data, wild_cell_table, offset_map) - for runner in kernel_runners: - runner.load_offset_map(offset_map) - # here, half the cell IDs will be out of bounds - runner.load_data(raw_data, wild_cell_table) - # should not crash - runner.correct(CorrectionFlags.OFFSET) - # should correct as much as possible - assert np.allclose(runner.reshape(runner._corrected_axis_order), reference) - - -def test_reshape(): - kernel_runners[0].processed_data = corrected_data - kernel_runners[1].processed_data.set(corrected_data) - for runner in kernel_runners: - assert np.allclose( - runner.reshape(output_order="xyf"), corrected_data.transpose() - ) - - -def test_preview_slice(): 
- kernel_runners[0].processed_data = corrected_data - kernel_runners[1].processed_data.set(corrected_data) - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - preview_raw, preview_corrected = runner.compute_previews(42) - assert np.allclose( - preview_raw, - raw_data[42].astype(np.float32), - ) - assert np.allclose( - preview_corrected, - corrected_data[42].astype(np.float32), - ) +@pytest.fixture(params=[DsscCpuRunner, DsscGpuRunner]) +def kernel_runner(request): + device = DummyCorrectionDevice(DsscBaseRunner, DsscCalcatFriend) + runner = request.param(device) + device.runner = runner + return runner -def test_preview_max(): - kernel_runners[0].processed_data = corrected_data - kernel_runners[1].processed_data.set(corrected_data) - for runner in kernel_runners: - # note: in case correction failed, still test this separately - runner.load_data(raw_data, cell_table) - preview_raw, preview_corrected = runner.compute_previews(-1) - assert np.allclose(preview_raw, np.max(raw_data, axis=0).astype(np.float32)) - assert np.allclose( - preview_corrected, np.max(corrected_data, axis=0).astype(np.float32) - ) - - -def test_preview_mean(): - kernel_runners[0].processed_data = corrected_data - kernel_runners[1].processed_data.set(corrected_data) - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - preview_raw, preview_corrected = runner.compute_previews(-2) - assert np.allclose(preview_raw, np.nanmean(raw_data, axis=0, dtype=np.float32)) - assert np.allclose( - preview_corrected, np.nanmean(corrected_data, axis=0, dtype=np.float32) - ) - - -def test_preview_sum(): - kernel_runners[0].processed_data = corrected_data - kernel_runners[1].processed_data.set(corrected_data) - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - preview_raw, preview_corrected = runner.compute_previews(-3) - assert np.allclose(preview_raw, np.nansum(raw_data, axis=0, dtype=np.float32)) - assert np.allclose( - preview_corrected, 
np.nansum(corrected_data, axis=0, dtype=np.float32) +def test_only_cast(kernel_runner): + assert raw_data.shape == kernel_runner.expected_input_shape(num_frames) + # note: previews are just buffers, not yet reduced + _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct( + raw_data, cell_table + ) + assert np.allclose(maybe_get(processed_data), only_cast_data) + assert np.allclose(maybe_get(preview_raw), only_cast_data) + assert np.allclose(maybe_get(preview_corrected), only_cast_data) + + +def test_correct(kernel_runner): + for preview in ("raw", "corrected"): + kernel_runner._device.reconfigure( + Hash( + f"preview.{preview}.index", + 0, + f"preview.{preview}.selectionMode", + "frame", + ) ) + kernel_runner.load_constant(Constants.Offset, offset_map) + _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct( + raw_data, cell_table + ) + assert np.allclose(maybe_get(processed_data), corrected_data) + assert np.allclose(maybe_get(preview_raw), raw_data[:, 0]) + assert np.allclose(maybe_get(preview_corrected), corrected_data) - -def test_preview_std(): - kernel_runners[0].processed_data = corrected_data - kernel_runners[1].processed_data.set(corrected_data) - for runner in kernel_runners: - runner.load_data(raw_data, cell_table) - preview_raw, preview_corrected = runner.compute_previews(-4) - assert np.allclose(preview_raw, np.nanstd(raw_data, axis=0, dtype=np.float32)) - assert np.allclose( - preview_corrected, np.nanstd(corrected_data, axis=0, dtype=np.float32) + kernel_runner._device.reconfigure(Hash("corrections.offset.enable", False)) + _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct( + raw_data, cell_table + ) + assert np.allclose(maybe_get(processed_data), np.squeeze(raw_data)) + assert np.allclose(maybe_get(preview_raw), np.squeeze(raw_data)) + assert np.allclose(maybe_get(preview_corrected), corrected_data) + + kernel_runner._device.reconfigure( + Hash( + "corrections.offset.enable", + 
True, + "corrections.offset.preview", + False, ) + ) + _, processed_data, (preview_raw, preview_corrected) = kernel_runner.correct( + raw_data, cell_table + ) + assert np.allclose(maybe_get(processed_data), corrected_data) + assert np.allclose(maybe_get(preview_raw), np.squeeze(raw_data)) + assert np.allclose(maybe_get(preview_corrected), np.squeeze(raw_data)) -def test_preview_valid_index(): - for runner in kernel_runners: - with pytest.raises(ValueError): - runner.compute_previews(-5) - with pytest.raises(ValueError): - runner.compute_previews(memory_cells) +def test_correct_oob_cells(kernel_runner): + wild_cell_table = cell_table * 2 + reference = correct_naive(raw_data, wild_cell_table, offset_map) + kernel_runner.load_constant(Constants.Offset, offset_map) + _, processed_data, previews = kernel_runner.correct( + raw_data, wild_cell_table + ) + assert np.allclose(maybe_get(processed_data), reference) diff --git a/tests/test_jungfrau_kernels.py b/tests/test_jungfrau_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..94f18794ce17fa0c4028e681256a09af49f5ec40 --- /dev/null +++ b/tests/test_jungfrau_kernels.py @@ -0,0 +1,108 @@ +import numpy as np +import pytest +from calng.corrections.JungfrauCorrection import ( + Constants, + JungfrauBaseRunner, + JungfrauCalcatFriend, + JungfrauCpuRunner, + JungfrauGpuRunner, +) +from calng.utils import BadPixelValues +from karabo.bound import Hash + +from common_setup import DummyCorrectionDevice, get_constant_from_file, maybe_get + + +raw_data = np.rint(np.random.random((16, 512, 1024)) * 5000).astype(np.uint16) +gain = np.random.choice( + np.uint8([0, 1, 3]), size=raw_data.shape +) # not testing gain value 2 (wrong) for now +gain_as_index = gain.copy() +gain_as_index[gain_as_index == 3] = 2 +constants = { + Constants.Offset10Hz: get_constant_from_file( + "xfel/cal/jungfrau-type/jungfrau_m572/", + "cal.1714364366.1478188.h5", + "/Jungfrau_M572/Offset10Hz/0", + ), + Constants.BadPixelsDark10Hz: 
get_constant_from_file( + "xfel/cal/jungfrau-type/jungfrau_m572/", + "cal.1714364369.7054155.h5", + "/Jungfrau_M572/BadPixelsDark10Hz/0", + ), + Constants.BadPixelsFF10Hz: get_constant_from_file( + "xfel/cal/jungfrau-type/jungfrau_m572/", + "cal.1711037632.1430287.h5", + "/Jungfrau_M572/BadPixelsFF10Hz/0", + ), + Constants.RelativeGain10Hz: get_constant_from_file( + "xfel/cal/jungfrau-type/jungfrau_m572/", + "cal.1711036758.5906427.h5", + "/Jungfrau_M572/RelativeGain10Hz/0", + ), +} +cell_table = np.arange(16, dtype=np.uint8) +np.random.shuffle(cell_table) +cast_data = raw_data.astype(np.float32) +offset_corrected = cast_data - np.choose( + gain_as_index, + np.transpose(constants[Constants.Offset10Hz][:, :, cell_table], (3, 2, 0, 1)), +) +gain_corrected = offset_corrected / np.choose( + gain_as_index, constants[Constants.RelativeGain10Hz][:, :, cell_table].T +) +masked_dark = gain_corrected.copy() +masked_dark[ + np.choose( + gain_as_index, + np.transpose(constants[Constants.BadPixelsDark10Hz], (3, 2, 0, 1))[ + :, cell_table + ], + ) + != 0 +] = np.nan +masked_ff = masked_dark.copy() +masked_ff[ + np.choose( + gain_as_index, + np.transpose(constants[Constants.BadPixelsFF10Hz], (3, 2, 1, 0))[:, cell_table], + ) + != 0 +] = np.nan + + +@pytest.fixture(params=[JungfrauCpuRunner, JungfrauGpuRunner]) +def kernel_runner(request): + device = DummyCorrectionDevice(JungfrauBaseRunner, JungfrauCalcatFriend) + runner = request.param(device) + device.runner = runner + # TODO: test masking around ASIC seams + # TODO: test single-frame mode + device.reconfigure( + Hash( + "constantParameters.memoryCells", + 16, + ) + ) + device.reconfigure( + Hash( + "corrections.badPixels.subsetToUse.NON_STANDARD_SIZE", + False, + ) + ) + return runner + + +def test_cast(kernel_runner): + _, result, previews = kernel_runner.correct(raw_data, cell_table, gain) + result = maybe_get(result) + assert np.allclose(result, cast_data, equal_nan=True) + + +def test_correct(kernel_runner): + for
constant, constant_data in constants.items(): + kernel_runner.load_constant(constant, constant_data) + + _, result, previews = kernel_runner.correct(raw_data, cell_table, gain) + result = maybe_get(result) + assert np.allclose(result, masked_ff, equal_nan=True) diff --git a/tests/test_lpd_kernels.py b/tests/test_lpd_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..cf51fb209a62d32bd7929c59af30464f7e2a8d2c --- /dev/null +++ b/tests/test_lpd_kernels.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +from calng.corrections.LpdCorrection import ( + Constants, + LpdBaseRunner, + LpdCalcatFriend, + LpdCpuRunner, + LpdGpuRunner, +) + +from common_setup import DummyCorrectionDevice, maybe_get + + +@pytest.fixture(params=[LpdCpuRunner, LpdGpuRunner]) +def kernel_runner(request): + device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend) + return request.param(device) + + +@pytest.fixture +def cpu_runner(): + device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend) + return LpdCpuRunner(device) + + +@pytest.fixture +def gpu_runner(): + device = DummyCorrectionDevice(LpdBaseRunner, LpdCalcatFriend) + return LpdGpuRunner(device) + + +# generate some "raw data" +raw_data_shape = (100, 1, 256, 256) +# the image data part +raw_image = (np.arange(np.product(raw_data_shape)).reshape(raw_data_shape)).astype( + np.uint16 +) +# keeping in mind 12-bit ADC +raw_image = np.bitwise_and(raw_image, 0x0FFF) +raw_image_32 = raw_image.astype(np.float32) +# the gain values (TODO: test with some being illegal) +gain_data = (np.random.random(raw_data_shape) * 3).astype(np.uint16) +assert np.array_equal(gain_data, np.bitwise_and(gain_data, 0x0003)) +gain_data_32 = gain_data.astype(np.float32) +raw_data = np.bitwise_or(raw_image, np.left_shift(gain_data, 12)) +# just checking that we constructed something reasonable +assert np.array_equal(gain_data, np.bitwise_and(np.right_shift(raw_data, 12), 0x0003)) +assert np.array_equal(raw_image, 
np.bitwise_and(raw_data, 0x0FFF)) + +# generate some defects +wrong_gain_value_indices = (np.random.random(raw_data_shape) < 0.005).astype(np.bool_) +raw_data_with_some_wrong_gain = raw_data.copy() +raw_data_with_some_wrong_gain[wrong_gain_value_indices] = np.bitwise_or( + raw_data_with_some_wrong_gain[wrong_gain_value_indices], 0x3000 +) + +# generate some constants +dark_constant_shape = (256, 256, 512, 3) +# okay, as slow and fast scan have same size, can't tell the difference here... +funky_constant_shape = (256, 256, 512, 3) +offset_constant = np.random.random(dark_constant_shape).astype(np.float32) +gain_amp_constant = np.random.random(funky_constant_shape).astype(np.float32) +bp_dark_constant = (np.random.random(dark_constant_shape) < 0.01).astype(np.uint32) + +cell_table = np.linspace(0, 512, 100).astype(np.uint16) +np.random.shuffle(cell_table) + + +def test_only_cast(kernel_runner): + _, processed, previews = kernel_runner.correct(raw_data, cell_table) + assert np.allclose(maybe_get(processed), raw_image_32[:, 0]) + # TODO: make raw preview show image data (remove gain bits) + assert np.allclose(previews[0], raw_data[:, 0]) # raw: raw + assert np.allclose(previews[1], raw_image_32[:, 0]) # processed: raw + assert np.allclose(previews[2], gain_data_32[:, 0]) # gain: unchanged + + +def test_only_cast_with_some_wrong_gain(kernel_runner): + ( + _, + processed, + _, + ) = kernel_runner.correct(raw_data_with_some_wrong_gain, cell_table) + assert np.all(np.isnan(processed[wrong_gain_value_indices[:, 0]])) + + +def test_correct(cpu_runner, gpu_runner): + # naive numpy version way too slow, just compare the two runners + # the CPU one uses kernel almost identical to pycalibration + cpu_runner.load_constant(Constants.Offset, offset_constant) + cpu_runner.load_constant(Constants.GainAmpMap, gain_amp_constant) + cpu_runner.load_constant(Constants.BadPixelsDark, bp_dark_constant) + gpu_runner.load_constant(Constants.Offset, offset_constant) + 
gpu_runner.load_constant(Constants.GainAmpMap, gain_amp_constant) + gpu_runner.load_constant(Constants.BadPixelsDark, bp_dark_constant) + _, processed_cpu, _ = cpu_runner.correct(raw_data_with_some_wrong_gain, cell_table) + _, processed_gpu, _ = gpu_runner.correct(raw_data_with_some_wrong_gain, cell_table) + assert np.allclose(processed_cpu, processed_gpu.get(), equal_nan=True) diff --git a/tests/test_pnccd_kernels.py b/tests/test_pnccd_kernels.py index 18f438fa7de2ce7cd77bb45245726689d2508563..35da5f926acec7e3a3bf80298025729f2b4922f2 100644 --- a/tests/test_pnccd_kernels.py +++ b/tests/test_pnccd_kernels.py @@ -1,50 +1,78 @@ +import itertools + import numpy as np +import pytest from calng import utils -from calng.corrections.PnccdCorrection import Constants, CorrectionFlags, PnccdCpuRunner +from calng.corrections.PnccdCorrection import ( + Constants, + PnccdCalcatFriend, + PnccdCpuRunner, +) +from karabo.bound import Hash +from common_setup import DummyCorrectionDevice + + +@pytest.fixture +def kernel_runner(): + device = DummyCorrectionDevice(PnccdCpuRunner, PnccdCalcatFriend) + runner = PnccdCpuRunner(device) + device.runner = runner + return runner -def test_common_mode(): - def slow_common_mode( - data, bad_pixel_map, noise_map, cm_min_frac, cm_noise_sigma, cm_ss, cm_fs - ): - masked = np.ma.masked_array( - data=data, - mask=(bad_pixel_map != 0) | (data > noise_map * cm_noise_sigma), - ) - if cm_fs: - subset_fs = masked.count(axis=1) >= masked.shape[1] * cm_min_frac - data[subset_fs] -= np.ma.median(masked[subset_fs], axis=1, keepdims=True) +def slow_common_mode( + data, bad_pixel_map, noise_map, cm_min_frac, cm_noise_sigma, cm_ss, cm_fs +): + masked = np.ma.masked_array( + data=data, + mask=(bad_pixel_map != 0) | (data > noise_map * cm_noise_sigma), + ) - if cm_ss: - subset_ss = masked.count(axis=0) >= masked.shape[0] * cm_min_frac - data[:, subset_ss] -= np.ma.median(masked[:, subset_ss], axis=0) + if cm_fs: + subset_fs = masked.count(axis=1) >= 
masked.shape[1] * cm_min_frac + data[subset_fs] -= np.ma.median(masked[subset_fs], axis=1, keepdims=True) + if cm_ss: + subset_ss = masked.count(axis=0) >= masked.shape[0] * cm_min_frac + data[:, subset_ss] -= np.ma.median(masked[:, subset_ss], axis=0) + + +def test_common_mode(kernel_runner): data_shape = (1024, 1024) data = (np.random.random(data_shape) * 1000).astype(np.uint16) - runner = PnccdCpuRunner(1024, 1024, 1, 1) - for noise_level in (0, 5, 100, 1000): - noise = (np.random.random(data_shape) * noise_level).astype(np.float32) - runner.load_constant(Constants.NoiseCCD, noise) - for bad_pixel_percentage in (0, 1, 20, 80, 100): - bad_pixels = ( - np.random.random(data_shape) < (bad_pixel_percentage / 100) - ).astype(np.uint32) - runner.load_constant(Constants.BadPixelsDarkCCD, bad_pixels) - runner.load_data(data.copy()) - runner.correct( - CorrectionFlags.COMMONMODE, - cm_min_frac=0.25, - cm_noise_sigma=10, - cm_row=True, - cm_col=True, - ) - - slow_version = data.astype(np.float32, copy=True) - for q_data, q_noise, q_bad_pixels in zip( - utils.quadrant_views(slow_version), - utils.quadrant_views(noise), - utils.quadrant_views(bad_pixels), - ): - slow_common_mode(q_data, q_bad_pixels, q_noise, 0.25, 10, True, True) - assert np.allclose(runner.processed_data, slow_version, equal_nan=True) + noise_maps = [ + (np.random.random(data_shape) * noise_level).astype(np.float32) + for noise_level in (0, 5, 100, 1000) + ] + bad_pixel_maps = [ + (np.random.random(data_shape) < (bad_pixel_percentage / 100)).astype(np.uint32) + for bad_pixel_percentage in (0, 1, 20, 80, 100) + ] + kernel_runner._device.reconfigure( + Hash( + "corrections.commonMode.noiseSigma", + 10, + "corrections.commonMode.minFrac", + 0.25, + "corrections.commonMode.enableRow", + True, + "corrections.commonMode.enableCol", + True, + "corrections.badPixels.enable", + False, + ) + ) + for noise_map, bad_pixel_map in itertools.product(noise_maps, bad_pixel_maps): + 
kernel_runner.load_constant(Constants.NoiseCCD, noise_map) + kernel_runner.load_constant(Constants.BadPixelsDarkCCD, bad_pixel_map) + _, processed_data, _ = kernel_runner.correct(data, None) + + slow_version = data.astype(np.float32, copy=True) + for q_data, q_noise, q_bad_pixels in zip( + utils.quadrant_views(slow_version), + utils.quadrant_views(noise_map), + utils.quadrant_views(bad_pixel_map), + ): + slow_common_mode(q_data, q_bad_pixels, q_noise, 0.25, 10, True, True) + assert np.allclose(processed_data, slow_version, equal_nan=True) diff --git a/tests/test_preview_utils.py b/tests/test_preview_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..043d7ac6704b4384fb18a2152201984ae64d0969 --- /dev/null +++ b/tests/test_preview_utils.py @@ -0,0 +1,160 @@ +import numpy as np +import pytest +from karabo.bound import Hash +from calng.preview_utils import PreviewFriend, PreviewSpec + +from common_setup import DummyCorrectionDevice + + +cell_table = np.array([0, 2, 1], dtype=np.uint16) +pulse_table = np.array([1, 2, 0], dtype=np.uint16) +raw_data = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, np.nan, 5.0], + [6.0, 7.0, 8.0], + ], + dtype=np.float32, +) +raw_without_nan = raw_data.copy() +raw_without_nan[~np.isfinite(raw_data)] = 0 + + +@pytest.fixture +def device(): + output_specs = [ + PreviewSpec( + "dummyPreview", + dimensions=1, + wrap_in_imagedata=False, + frame_reduction=True, + ) + ] + device = DummyCorrectionDevice( + {"friend_class": PreviewFriend, "outputs": output_specs} + ) + + def writeChannel(name, hsh, **kwargs): + if not hasattr(device, "_written"): + device._written = [] + device._written.append(("name", hsh)) + + device.writeChannel = writeChannel + device.preview_friend = PreviewFriend(device, output_specs) + + return device + + +def test_preview_frame(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + 1, + "dummyPreview.selectionMode", + "frame", + ) + ) + device.preview_friend.write_outputs( + 
raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + assert hasattr(device, "_written") + assert len(device._written) == 1 + preview = device._written[0][1] + assert np.allclose(preview["image.data"], raw_without_nan[1]) + + +def test_preview_cell(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + 1, + "dummyPreview.selectionMode", + "cell", + ) + ) + device.preview_friend.write_outputs( + raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + preview = device._written[0][1] + assert np.allclose(preview["image.data"], raw_without_nan[2]) + + +def test_preview_pulse(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + 2, + "dummyPreview.selectionMode", + "pulse", + ) + ) + device.preview_friend.write_outputs( + raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + preview = device._written[0][1] + assert np.allclose(preview["image.data"], raw_without_nan[1]) + + +def test_preview_max(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + -1, + ) + ) + device.preview_friend.write_outputs( + raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + preview = device._written[0][1] + assert np.allclose(preview["image.data"], np.nanmax(raw_data, axis=0)) + + +def test_preview_mean(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + -2, + ) + ) + device.preview_friend.write_outputs( + raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + preview = device._written[0][1] + assert np.allclose( + preview["image.data"], np.nanmean(raw_data, axis=0, dtype=np.float32) + ) + + +def test_preview_sum(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + -3, + ) + ) + device.preview_friend.write_outputs( + raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + preview = device._written[0][1] + assert np.allclose( + preview["image.data"], np.nansum(raw_data, axis=0, dtype=np.float32) + ) + + +def 
test_preview_std(device): + device.preview_friend.reconfigure( + Hash( + "dummyPreview.index", + -4, + ) + ) + device.preview_friend.write_outputs( + raw_data, cell_table=cell_table, pulse_table=pulse_table + ) + preview = device._written[0][1] + assert np.allclose( + preview["image.data"], np.nanstd(raw_data, axis=0, dtype=np.float32) + ) + + +# TODO: also test alternative preview index selection modes