diff --git a/DEPENDS b/DEPENDS index 225ed6a034814485fb2f298679177547fe7c8a2e..a08c346474e4810330b3d7f5fd70eade68630827 100644 --- a/DEPENDS +++ b/DEPENDS @@ -1,3 +1,3 @@ -TrainMatcher, 2.2.0-2.14.1 +TrainMatcher, 2.3.0-2.14.2 calngDeps, 0.2.0-2.14.0 -calibrationClient, 9.1.1 +calibrationClient, 11.0.0 diff --git a/setup.py b/setup.py index 5555d8dc6832f0807be98fa16b4ba8cb524ad9ab..04c4f70ab985ee570d35aa939434aa4fe9b0397d 100644 --- a/setup.py +++ b/setup.py @@ -25,11 +25,11 @@ setup(name='calng', packages=find_packages('src'), entry_points={ 'karabo.bound_device': [ - 'AgipdCorrection = calng.AgipdCorrection:AgipdCorrection', - 'DsscCorrection = calng.DsscCorrection:DsscCorrection', - 'Gotthard2Correction = calng.Gotthard2Correction:Gotthard2Correction', - 'JungfrauCorrection = calng.JungfrauCorrection:JungfrauCorrection', - 'LpdCorrection = calng.LpdCorrection:LpdCorrection', + 'AgipdCorrection = calng.corrections.AgipdCorrection:AgipdCorrection', + 'DsscCorrection = calng.corrections.DsscCorrection:DsscCorrection', + 'Gotthard2Correction = calng.corrections.Gotthard2Correction:Gotthard2Correction', + 'JungfrauCorrection = calng.corrections.JungfrauCorrection:JungfrauCorrection', + 'LpdCorrection = calng.corrections.LpdCorrection:LpdCorrection', 'ShmemToZMQ = calng.ShmemToZMQ:ShmemToZMQ', 'ShmemTrainMatcher = calng.ShmemTrainMatcher:ShmemTrainMatcher', 'DetectorAssembler = calng.DetectorAssembler:DetectorAssembler', @@ -37,10 +37,13 @@ setup(name='calng', 'karabo.middlelayer_device': [ 'CalibrationManager = calng.CalibrationManager:CalibrationManager', + 'AgipdCondition = calng.conditions.AgipdCondition:AgipdCondition', + 'JungfrauCondition = calng.conditions.JungfrauCondition:JungfrauCondition', 'Agipd1MGeometry = calng.geometries.Agipd1MGeometry:Agipd1MGeometry', 'Dssc1MGeometry = calng.geometries:Dssc1MGeometry.Dssc1MGeometry', 'Lpd1MGeometry = calng.geometries:Lpd1MGeometry.Lpd1MGeometry', 'JungfrauGeometry = calng.geometries:JungfrauGeometry.JungfrauGeometry', + 'RoiTool = calng.RoiTool:RoiTool', ], }, package_data={'': ['kernels/*']}, @@ -50,12 +53,12 @@ setup(name='calng', Extension( 'calng.kernels.gotthard2_cython', ['src/calng/kernels/gotthard2_cpu.pyx'], - extra_compile_args = ['-O3', '-march=native'], + extra_compile_args=['-O3', '-march=native'], ), Extension( 'calng.kernels.jungfrau_cython', ['src/calng/kernels/jungfrau_cpu.pyx'], - extra_compile_args = ['-O3', '-march=native', '-fopenmp' ], + extra_compile_args=['-O3', '-march=native', '-fopenmp'], extra_link_args=['-fopenmp'], ), ], diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py index c90e64ab1912b2f88d54e7a5bf53535035177c7f..8af755ee38e21967a0c16614938b7eba134e1706 100644 --- a/src/calng/CalibrationManager.py +++ b/src/calng/CalibrationManager.py @@ -150,7 +150,7 @@ class ModuleGroupRow(Configurable): bridgePattern = String( displayedName='Bridge pattern', - options=['PUSH', 'REP', 'PUBLISH'], + options=['PUSH', 'REP', 'PUB'], defaultValue='PUSH') @@ -1300,7 +1300,6 @@ class CalibrationManager(DeviceClientBase, Device): layer=layer) config = Hash() - # TODO: put _image_data_path in corr dev schema, get from there config['sources'] = [ Hash('select', True, 'source', @@ -1310,6 +1309,7 @@ class CalibrationManager(DeviceClientBase, Device): in correct_device_id_by_module.items()] config['geometryDevice'] = self.geometryDevice.value config['maxIdle'] = self.maxIdle.value + # TODO: enable live reconfiguration of maxIdle via manager awaitables.append(self._instantiate_device( server,
class_ids['assembler'], assembler_device_id, config)) @@ -1323,12 +1323,12 @@ class CalibrationManager(DeviceClientBase, Device): async def _apply_managed_values(self): """Apply all managed keys to local values.""" - for daq_key, local_key in ManagedKeysNode.DAQ_KEYS.items(): + for daq_key, local_key, _ in self._get_managed_daq_keys(): await self._set_on_daq( - daq_key, get_property(self, f'managed.{local_key}')) + daq_key, get_property(self, f'managedKeys.{local_key}')) for key in self._managed_keys: - value = get_property(self, f'managed.{key}') + value = get_property(self, f'managedKeys.{key}') if not ismethod(value): await self._set_on_corrections(key, value) diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py index 0259846f4f7a09e16457bbbf26233b7cccbdad09..eee5cfeb090732ec1fd0d19244481aa5983315cd 100644 --- a/src/calng/DetectorAssembler.py +++ b/src/calng/DetectorAssembler.py @@ -5,52 +5,22 @@ import re import numpy as np from karabo.bound import ( DOUBLE_ELEMENT, - FLOAT_ELEMENT, - IMAGEDATA_ELEMENT, - NDARRAY_ELEMENT, - NODE_ELEMENT, KARABO_CLASSINFO, OUTPUT_CHANNEL, OVERWRITE_ELEMENT, STRING_ELEMENT, - UINT32_ELEMENT, - UINT64_ELEMENT, ChannelMetaData, - Dims, - Encoding, - Epochstamp, Hash, - ImageData, MetricPrefix, - Schema, Timestamp, - Trainstamp, Unit, ) from TrainMatcher import TrainMatcher -from . import geom_utils, scenes, utils +from . import geom_utils, scenes, schemas, preview_utils from ._version import version as deviceVersion -assembled_schema = Schema() -( - NODE_ELEMENT(assembled_schema).key("image").commit(), - - NDARRAY_ELEMENT(assembled_schema).key("image.data").commit(), - - UINT64_ELEMENT(assembled_schema).key("trainId").readOnly().commit(), -) - -preview_schema = Schema() -( - NODE_ELEMENT(preview_schema).key("image").commit(), - - IMAGEDATA_ELEMENT(preview_schema).key("image.data").commit(), - - UINT64_ELEMENT(preview_schema).key("trainId").readOnly().commit(), -) - xtdf_source_re = re.compile(r".*\/DET\/(\d+)CH0:xtdf") daq_source_re = re.compile(r".*\/DET\/.*?(\d+):daqOutput") @@ -61,7 +31,6 @@ class BridgeOutputOptions(enum.Enum): PREVIEW = "preview" -# TODO: merge scene with TrainMatcher's nice overview @KARABO_CLASSINFO("DetectorAssembler", deviceVersion) class DetectorAssembler(TrainMatcher.TrainMatcher): @staticmethod @@ -92,72 +61,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): OUTPUT_CHANNEL(expected) .key("assembledOutput") - .dataSchema(assembled_schema) - .commit(), - - NODE_ELEMENT(expected) - .key("preview") - .description( - "The preview output is intended for Karabo GUI previews. It differs " - "from the main output in that it is rate throttled, can be " - "downsampled, and is given the ImageData type for use within Karabo." - ) - .commit(), - - OUTPUT_CHANNEL(expected) - .key("preview.output") - .dataSchema(preview_schema) - .description("See description of parent node, 'preview'.") - .commit(), - - UINT32_ELEMENT(expected) - .key("preview.downsamplingFactor") - .description( - "If greater than 1, the assembled image will be downsampled by this " - "factor in x and y dimensions before sending. This is only to save " - "bandwidth in case GUI updates start lagging." 
- ) - .assignmentOptional() - .defaultValue(1) - .options("1,2,4,8") - .reconfigurable() - .commit(), - - STRING_ELEMENT(expected) - .key("preview.downsamplingFunction") - .description("Reduction function used during downsampling.") - .assignmentOptional() - .defaultValue("nanmax") - .options("nanmax,nanmean,nanmin,nanmedian") - .reconfigurable() - .commit(), - - FLOAT_ELEMENT(expected) - .key("preview.replaceNanWith") - .description( - "Displaying images in KaraboGUI seems to not go well when there are " - "NaN values in data. And there will be with bad pixel masking or just " - "geometry space between modules. NaN values get replaced with this " - "value to get around this; choose a value which clearly stands out " - "from the image data you want to see." - ) - .assignmentOptional() - .defaultValue(-1000) - .reconfigurable() - .commit(), - - DOUBLE_ELEMENT(expected) - .key("preview.maxRate") - .displayedName("Max rate") - .description( - "Preview output is throttled to (at most) this speed. New trains " - "matched 'too soon' get dropped here (instead of sending to be dropped " - "by GUI)." - ) - .unit(Unit.HERTZ) - .assignmentOptional() - .defaultValue(2) - .reconfigurable() + .dataSchema(schemas.preview_schema()) .commit(), STRING_ELEMENT(expected) @@ -186,6 +90,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): .assignmentMandatory() .commit(), ) + preview_utils.PreviewFriend.add_schema(expected, "preview") def __init__(self, conf): super().__init__(conf) @@ -196,7 +101,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): # TODO: match inside device, fill multiple independent buffers - self._throttler = utils.SkippingThrottler(1 / self.get("preview.maxRate")) + self._preview_friend = preview_utils.PreviewFriend(self) self._path_to_stack = self.get("pathToStack") self._geometry = None self._stack_input_buffer = None @@ -221,16 +126,15 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self.remote().registerDeviceMonitor(geometry_device, self._receive_geometry) self.assembled_output = self.signalSlotable.getOutputChannel("assembledOutput") - self.preview_output = self.signalSlotable.getOutputChannel("preview.output") self.start() def requestScene(self, params): - # TODO: unify with TrainMatcher overview scene_name = params.get("name", default="") if scene_name == "overview": payload = Hash("name", scene_name, "success", True) payload["data"] = scenes.detector_assembler_overview( - device_id=self.getInstanceId() + device_id=self.getInstanceId(), + schema=self.getFullSchema(), ) self.reply( Hash( @@ -265,8 +169,8 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self.log.WARN("Have not received a geometry yet, will not send anything") return - my_timestamp = Timestamp(Epochstamp(), Trainstamp(train_id)) - my_source = self.getInstanceId() + my_timestamp = self.getActualTimestamp() + my_device_id = self.getInstanceId() bridge_output_choice = BridgeOutputOptions( self.unsafe_get("outputForBridgeOutput") ) @@ -275,7 +179,9 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): earliest_source_timestamp = float("inf") for source, (data, source_timestamp) in sources.items(): # regular TrainMatcher output - self.output.write(data, ChannelMetaData(source, source_timestamp)) + self.output.write( + data, ChannelMetaData(source, source_timestamp), copyAllData=False + ) if bridge_output_choice is BridgeOutputOptions.MATCHED: self.zmq_output.write(source, data, source_timestamp) @@ -295,8 +201,8 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self.zmq_output.update() for 
module_index in module_indices_unfilled: - self._stack_input_buffer[module_index].fill(0) - # TODO: configurable treatment of missing modules + self._stack_input_buffer[module_index].fill(np.nan) + # consider configurable treatment of missing modules # TODO: reusable output buffer to save on allocation assembled, _ = self._geometry.position_modules_fast(self._stack_input_buffer) @@ -305,51 +211,30 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): output_hash = Hash( "image.data", assembled, - "trainId", - train_id, ) - output_metadata = ChannelMetaData(my_source, my_timestamp) - self.assembled_output.write(output_hash, output_metadata) + self.assembled_output.write( + output_hash, + ChannelMetaData(f"{my_device_id}:assembledOutput", my_timestamp), + copyAllData=False, + ) self.assembled_output.update() if bridge_output_choice is BridgeOutputOptions.ASSEMBLED: - self.zmq_output.write(my_source, output_hash, my_timestamp) + self.zmq_output.write( + f"{my_device_id}:assembledOutput", output_hash, my_timestamp + ) self.zmq_output.update() - if self._throttler.test_and_set(): - downsampling_factor = self.unsafe_get("preview.downsamplingFactor") - if downsampling_factor > 1: - assembled = downsample_2d( - assembled, - downsampling_factor, - reduction_fun=getattr( - np, self.unsafe_get("preview.downsamplingFunction") - ), - ) - assembled[np.isnan(assembled)] = self.unsafe_get("preview.replaceNanWith") - output_hash = Hash( - "image.data", - ImageData( - # TODO: get around this being mirrored... - assembled[::-1, ::-1], - Dims(*assembled.shape), - Encoding.GRAY, - bitsPerPixel=32, - ), - "trainId", - train_id, - ) - self.preview_output.write( - output_hash, - output_metadata, + preview_hash_sent = self._preview_friend.maybe_write([assembled]) + if ( + bridge_output_choice is BridgeOutputOptions.PREVIEW + and preview_hash_sent is not None + ): + self.zmq_output.write( + f"{my_device_id}:preview.output", + preview_hash_sent, + my_timestamp, ) - self.preview_output.update() - if bridge_output_choice is BridgeOutputOptions.PREVIEW: - self.zmq_output.write( - my_source, - output_hash, - my_timestamp, - ) - self.zmq_output.update() + self.zmq_output.update() self.info["timeOfFlight"] = ( Timestamp().toTimestamp() - earliest_source_timestamp @@ -381,35 +266,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): def preReconfigure(self, conf): super().preReconfigure(conf) - if conf.has("preview.maxRate"): - self._throttler = utils.SkippingThrottler( - 1 / conf["preview.maxRate"] - ) - - -def downsample_2d(arr, factor, reduction_fun=np.nanmax): - """Generalization of downsampling from FemDataAssembler - - Expects first two dimensions of arr to be multiple of 2 ** factor - Useful if you're sitting at home and ssh connection is slow to get full-resolution - previews.""" - - for i in range(factor // 2): - arr = reduction_fun( - ( - arr[:-1:2], - arr[1::2], - ), - axis=0, - ) - arr = reduction_fun( - ( - arr[:, :-1:2], - arr[:, 1::2], - ), - axis=0, - ) - return arr + self._preview_friend.reconfigure(conf) # forward-compatible unsafe_get proposed by @haufs diff --git a/src/calng/JungfrauCorrection.py b/src/calng/JungfrauCorrection.py deleted file mode 100644 index 788da429a28dfada638c3f0401e86c486374b9ce..0000000000000000000000000000000000000000 --- a/src/calng/JungfrauCorrection.py +++ /dev/null @@ -1,553 +0,0 @@ -import enum - -import cupy -import numpy as np -from karabo.bound import ( - DOUBLE_ELEMENT, - KARABO_CLASSINFO, - OUTPUT_CHANNEL, - OVERWRITE_ELEMENT, - STRING_ELEMENT, - 
VECTOR_STRING_ELEMENT, -) - -from . import base_calcat, utils -from ._version import version as deviceVersion -from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema -from .base_kernel_runner import BaseGpuRunner, BaseKernelRunner - - -_pretend_pulse_table = np.arange(16, dtype=np.uint8) - - -class JungfrauConstants(enum.Enum): - Offset10Hz = enum.auto() - BadPixelsDark10Hz = enum.auto() - BadPixelsFF10Hz = enum.auto() - RelativeGain10Hz = enum.auto() - - -class CorrectionFlags(enum.IntFlag): - NONE = 0 - OFFSET = 1 - REL_GAIN = 2 - BPMASK = 4 - - -class KernelRunnerVersions(enum.Enum): - CPU = enum.auto() - GPU = enum.auto() - - -class JungfrauGpuRunner(BaseGpuRunner): - _kernel_source_filename = "jungfrau_gpu.cu" - _corrected_axis_order = "cyx" - - def __init__( - self, - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells, - input_data_dtype=cupy.uint16, - output_data_dtype=cupy.float32, - bad_pixel_mask_value=cupy.nan, - ): - self.input_shape = (memory_cells, pixels_y, pixels_x) - self.processed_shape = self.input_shape - super().__init__( - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells, - input_data_dtype, - output_data_dtype, - ) - # TODO: avoid superclass creating cell table with wrong dtype first - self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=cupy.uint8) - self.input_gain_stage_gpu = cupy.empty(self.input_shape, dtype=cupy.uint8) - self.preview_buffer_getters.append(self._get_gain_stage_for_preview) - self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3) - self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32) - self.rel_gain_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) - self.bad_pixel_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.uint32) - self.bad_pixel_mask_value = bad_pixel_mask_value - - self.update_block_size((1, 1, 64)) - - def _init_kernels(self): - kernel_source = self._kernel_template.render( - { - "pixels_x": self.pixels_x, - "pixels_y": self.pixels_y, - "data_memory_cells": self.memory_cells, - "constant_memory_cells": self.constant_memory_cells, - "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), - "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), - "corr_enum": utils.enum_to_c_template(CorrectionFlags), - "burst_mode": self.burst_mode, - } - ) - self.source_module = cupy.RawModule(code=kernel_source) - self.correction_kernel = self.source_module.get_function("correct") - - @property - def burst_mode(self): - return self.memory_cells > 1 - - def _get_raw_for_preview(self): - return self.input_data_gpu - - def _get_corrected_for_preview(self): - return self.processed_data_gpu - - def _get_gain_stage_for_preview(self): - return self.input_gain_stage_gpu - - def load_data(self, image_data, input_gain_stage, cell_table): - """Experiment: loading all three in one function as they are tied""" - self.input_data_gpu.set(image_data) - self.input_gain_stage_gpu.set(input_gain_stage) - if self.burst_mode: - self.cell_table_gpu.set(cell_table) - - def flush_buffers(self): - self.offset_map_gpu.fill(0) - self.rel_gain_map_gpu.fill(1) - self.bad_pixel_map_gpu.fill(0) - - def correct(self, flags): - self.correction_kernel( - self.full_grid, - self.full_block, - ( - self.input_data_gpu, - self.input_gain_stage_gpu, - self.cell_table_gpu, - cupy.uint8(flags), - self.offset_map_gpu, - self.rel_gain_map_gpu, - self.bad_pixel_map_gpu, - self.bad_pixel_mask_value, - self.processed_data_gpu, - ) - ) - - -class 
JungfrauCpuRunner(BaseKernelRunner): - _corrected_axis_order = "cyx" - - def __init__( - self, - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells, - input_data_dtype=np.uint16, - output_data_dtype=np.float32, # TODO: configurable - bad_pixel_mask_value=np.nan, - ): - super().__init__( - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells, - input_data_dtype, - output_data_dtype, - ) - - from .kernels import jungfrau_cython - self.correction_kernel_single = jungfrau_cython.correct_single - self.correction_kernel_burst = jungfrau_cython.correct_burst - self.input_shape = (memory_cells, pixels_y, pixels_x) - self.preview_buffer_getters.append(self._get_gain_stage_for_preview) - self.processed_shape = self.input_shape - self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3) - self.offset_map = np.empty(self.map_shape, dtype=np.float32) - self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32) - self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32) - self.bad_pixel_mask_value = bad_pixel_mask_value - self.flush_buffers() - - self.input_data = None # will just point to data coming in - self.input_gain_stage = None # will just point to data coming in - self.processed_data = None # will just point to buffer we're given - - def _get_raw_for_preview(self): - return self.input_data - - def _get_corrected_for_preview(self): - return self.processed_data - - def _get_gain_stage_for_preview(self): - return self.input_gain_stage - - @property - def burst_mode(self): - return self.memory_cells > 1 - - def load_data(self, image_data, input_gain_stage, cell_table): - self.input_data = image_data.astype(np.uint16, copy=False) - self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False) - self.input_cell_table = cell_table.astype(np.uint8, copy=False) - - def flush_buffers(self): - self.offset_map.fill(0) - self.rel_gain_map.fill(1) - self.bad_pixel_map.fill(0) - - def correct(self, flags, out=None): - if out is None: - out = np.empty(self.processed_shape, dtype=np.float32) - - if self.burst_mode: - self.correction_kernel_burst( - self.input_data, - self.input_gain_stage, - self.input_cell_table, - flags, - self.offset_map, - self.rel_gain_map, - self.bad_pixel_map, - self.bad_pixel_mask_value, - out, - ) - else: - self.correction_kernel_single( - self.input_data, - self.input_gain_stage, - flags, - self.offset_map, - self.rel_gain_map, - self.bad_pixel_map, - self.bad_pixel_mask_value, - out, - ) - self.processed_data = out - - -class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend): - _constant_enum_class = JungfrauConstants - - def __init__(self, device, *args, **kwargs): - super().__init__(device, *args, **kwargs) - self._constants_need_conditions = { - JungfrauConstants.Offset10Hz: self.dark_condition, - JungfrauConstants.BadPixelsDark10Hz: self.dark_condition, - JungfrauConstants.BadPixelsFF10Hz: self.dark_condition, - JungfrauConstants.RelativeGain10Hz: self.dark_condition, - } - - @staticmethod - def add_schema( - schema, - managed_keys, - param_prefix="constantParameters", - status_prefix="foundConstants", - ): - super(JungfrauCalcatFriend, JungfrauCalcatFriend).add_schema( - schema, managed_keys, "jungfrau-Type", param_prefix, status_prefix - ) - - # set some defaults for common parameters - ( - OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.pixelsX") - .setNewDefaultValue(1024) - .commit(), - - OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.pixelsY") - .setNewDefaultValue(512) - .commit(), - - OVERWRITE_ELEMENT(schema) 
- .key(f"{param_prefix}.memoryCells") - .setNewDefaultValue(1) - .commit(), - - OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.biasVoltage") - .setNewDefaultValue(90) - .commit(), - ) - - # add extra parameters - ( - DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.integrationTime") - .displayedName("Integration time") - .description("Integration time in ms") - .assignmentOptional() - .defaultValue(350) - .reconfigurable() - .commit(), - - DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.sensorTemperature") - .displayedName("Sensor temperature") - .description("Sensor temperature in K") - .assignmentOptional() - .defaultValue(291) - .reconfigurable() - .commit(), - - DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.gainSetting") - .displayedName("Gain setting") - .description("Feedback capacitor setting; 0 is default, 1 is HG0") - .assignmentOptional() - .defaultValue(0) - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.gainMode") - .displayedName("Gain mode") - .description( - "Detector may be operating in one of several gain modes. For this " - "device to query appropriate constants, it is sufficient to know " - "whether gain mode is dynamic or fixed." - ) - .assignmentOptional() - .defaultValue("dynamicgain") - .options("dynamicgain,fixedgain") - .commit(), - ) - managed_keys.add(f"{param_prefix}.integrationTime") - managed_keys.add(f"{param_prefix}.sensorTemperature") - managed_keys.add(f"{param_prefix}.gainSetting") - managed_keys.add(f"{param_prefix}.gainMode") - - base_calcat.add_status_schema_from_enum( - schema, status_prefix, JungfrauConstants - ) - - def dark_condition(self): - res = base_calcat.OperatingConditions() - res["Memory cells"] = self._get_param("memoryCells") - res["Sensor Bias Voltage"] = self._get_param("biasVoltage") - res["Pixels X"] = self._get_param("pixelsX") - res["Pixels Y"] = self._get_param("pixelsY") - res["Integration Time"] = self._get_param("integrationTime") - res["Sensor Temperature"] = self._get_param("sensorTemperature") - res["Gain Setting"] = self._get_param("gainSetting") - gain_mode = self._get_param("gainMode") - if gain_mode != "dynamicgain": - # NOTE: always include if CalCat is updated for this - res["Gain mode"] = 1 - return res - - -@KARABO_CLASSINFO("JungfrauCorrection", deviceVersion) -class JungfrauCorrection(BaseCorrection): - _correction_flag_class = CorrectionFlags - _correction_field_names = ( - ("offset", CorrectionFlags.OFFSET), - ("relGain", CorrectionFlags.REL_GAIN), - ("badPixels", CorrectionFlags.BPMASK), - ) - _kernel_runner_class = None # note: set in __init__ based on config - _calcat_friend_class = JungfrauCalcatFriend - _constant_enum_class = JungfrauConstants - _managed_keys = BaseCorrection._managed_keys.copy() - _image_data_path = "data.adc" - _cell_table_path = "data.memoryCell" - - @staticmethod - def expectedParameters(expected): - ( - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsX") - .setNewDefaultValue(1024) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.pixelsY") - .setNewDefaultValue(512) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("dataFormat.memoryCells") - .setNewDefaultValue(1) - .commit(), - - OVERWRITE_ELEMENT(expected) - .key("preview.selectionMode") - .setNewDefaultValue("frame") - .commit(), - - # JUNGFRAU data is small, can fit plenty of trains in here - OVERWRITE_ELEMENT(expected) - .key("outputShmemBufferSize") - .setNewDefaultValue(2) - .commit(), - ) - - # support both CPU and GPU kernels - ( - STRING_ELEMENT(expected) - 
.key("kernelType") - .assignmentOptional() - .defaultValue(KernelRunnerVersions.CPU.name) - .options(",".join(kernel_type.name for kernel_type in KernelRunnerVersions)) - .reconfigurable() - .commit(), - ) - JungfrauCorrection._managed_keys.add("kernelType") - - ( - OUTPUT_CHANNEL(expected) - .key("preview.outputGainMap") - .dataSchema(preview_schema) - .commit(), - ) - - JungfrauCalcatFriend.add_schema(expected, JungfrauCorrection._managed_keys) - add_correction_step_schema( - expected, - JungfrauCorrection._managed_keys, - JungfrauCorrection._correction_field_names, - ) - - # mandatory: manager needs this in schema - ( - VECTOR_STRING_ELEMENT(expected) - .key("managedKeys") - .assignmentOptional() - .defaultValue(list(JungfrauCorrection._managed_keys)) - .commit() - ) - - @property - def input_data_shape(self): - return ( - self.unsafe_get("dataFormat.memoryCells"), - self.unsafe_get("dataFormat.pixelsY"), - self.unsafe_get("dataFormat.pixelsX"), - ) - - def __init__(self, config): - super().__init__(config) - kernel_type = KernelRunnerVersions[config["kernelType"]] - if kernel_type is KernelRunnerVersions.CPU: - self._kernel_runner_class = JungfrauCpuRunner - else: - self._kernel_runner_class = JungfrauGpuRunner - # TODO: gain mode as constant parameter and / or device configuration - - try: - self.bad_pixel_mask_value = np.float32( - config.get("corrections.badPixels.maskingValue") - ) - except ValueError: - self.bad_pixel_mask_value = np.float32("nan") - - self._kernel_runner_init_args = { - "bad_pixel_mask_value": self.bad_pixel_mask_value, - } - - def process_data( - self, - data_hash, - metadata, - source, - train_id, - image_data, - cell_table, - do_generate_preview, - ): - if len(cell_table.shape) == 0: - cell_table = cell_table[np.newaxis] - try: - gain_map = data_hash.get("data.gain") - if self.unsafe_get("dataFormat.overrideInputAxisOrder"): - gain_map.shape = self.input_data_shape - self.kernel_runner.load_data( - image_data, gain_map, cell_table - ) - except ValueError as e: - self.log_status_warn(f"Failed to load data: {e}") - return - except Exception as e: - self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") - - buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.kernel_runner.correct(self._correction_flag_enabled) - self.kernel_runner.reshape( - output_order=self.unsafe_get("dataFormat.outputAxisOrder"), - out=buffer_array, - ) - - if do_generate_preview: - if self._correction_flag_enabled != self._correction_flag_preview: - self.kernel_runner.correct(self._correction_flag_preview) - ( - preview_slice_index, - preview_cell, - preview_pulse, - ) = utils.pick_frame_index( - self.unsafe_get("preview.selectionMode"), - self.unsafe_get("preview.index"), - cell_table, - _pretend_pulse_table, - warn_func=self.log_status_warn, - ) - ( - preview_raw, - preview_corrected, - preview_gain_map - ) = self.kernel_runner.compute_previews(preview_slice_index) - - # reusing input data hash for sending - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set("calngShmemPaths", [self._image_data_path]) - - self._write_output(data_hash, metadata) - - if do_generate_preview: - self._write_preview_outputs( - ( - ("preview.outputRaw", preview_raw), - ("preview.outputCorrected", preview_corrected), - ("preview.outputGainMap", preview_gain_map), - ), - metadata, - ) - - def _load_constant_to_runner(self, constant, constant_data): - if constant_data.shape[0] == self.get("dataFormat.pixelsX"): - constant_data = np.transpose(constant_data, (2, 1, 0, 
3)) - else: - constant_data = np.transpose(constant_data, (2, 0, 1, 3)) - - kernel_type = KernelRunnerVersions[self.get("kernelType")] - if constant is JungfrauConstants.Offset10Hz: - if kernel_type is KernelRunnerVersions.CPU: - self.kernel_runner.offset_map[:] = constant_data.astype(np.float32) - else: - self.kernel_runner.offset_map_gpu.set(constant_data.astype(np.float32)) - if not self.get("corrections.offset.available"): - self.set("corrections.offset.available", True) - elif constant is JungfrauConstants.RelativeGain10Hz: - if kernel_type is KernelRunnerVersions.CPU: - self.kernel_runner.rel_gain_map[:] = constant_data.astype(np.float32) - else: - self.kernel_runner.rel_gain_map_gpu.set( - constant_data.astype(np.float32) - ) - if not self.get("corrections.relGain.available"): - self.set("corrections.relGain.available", True) - elif constant in ( - JungfrauConstants.BadPixelsDark10Hz, - JungfrauConstants.BadPixelsFF10Hz, - ): - if kernel_type is KernelRunnerVersions.CPU: - self.kernel_runner.bad_pixel_map |= constant_data - else: - self.kernel_runner.bad_pixel_map_gpu |= cupy.asarray(constant_data) - if not self.get("corrections.badPixels.available"): - self.set("corrections.badPixels.available", True) - - self._update_correction_flags() - self.log_status_info(f"Done loading {constant.name} to runner") diff --git a/src/calng/ModuleStacker.py b/src/calng/ModuleStacker.py deleted file mode 100644 index 95abe95ee79f05007fffe98bb4c73b1fad6c6f08..0000000000000000000000000000000000000000 --- a/src/calng/ModuleStacker.py +++ /dev/null @@ -1,142 +0,0 @@ -import numpy as np -from karabo.bound import ( - FLOAT_ELEMENT, - KARABO_CLASSINFO, - NODE_ELEMENT, - STRING_ELEMENT, - ChannelMetaData, - Epochstamp, - Hash, - MetricPrefix, - Schema, - Timestamp, - Trainstamp, - Unit, -) -from TrainMatcher import TrainMatcher - -from ._version import version as deviceVersion - - -@KARABO_CLASSINFO("ModuleStacker", deviceVersion) -class ModuleStacker(TrainMatcher.TrainMatcher): - """This will be deprecated: now just equivalent to ModuleMatcher except this - stacks `image.data` instead of `data.adc` - - """ - - def __init__(self, conf): - super().__init__(conf) - self.info.merge(Hash("timeOfFlight", 0)) - - @staticmethod - def expectedParameters(expected): - ( - FLOAT_ELEMENT(expected) - .key("timeOfFlight") - .displayedName("Time of flight") - .description( - "Time elapsed from DAQ sent data until train was matched and ready to " - "send from here. Measured for latest train matched. Maximum over all " - "sources included in said train." - ) - .unit(Unit.SECOND) - .metricPrefix(MetricPrefix.MILLI) - .readOnly() - .commit(), - - STRING_ELEMENT(expected) - .key("pathToStack") - .displayedName("Data path to stack") - .description( - "Typically, image.data will be used for full data going through the " - "pipeline. Set this when input is dataOutput from a correction device " - "and output goes to a bridge. For previews, data.adc is used as part " - "of the combiner format. Set to data.adc when input is a preview " - "output and output goes to a femDataAssembler." - ) - .options("image.data,data.adc") - .assignmentOptional() - .defaultValue("image.data") - .commit(), - ) - - def initialization(self): - """Apply configuration and automatically start. - - Upon instantiation, the device will automatically start matching data - from the specified data sources. - """ - super().initialization() - - # Disable the start and stop slots. - # It's not possible to pop items from the full schema. 
The slots are - # therefore converted to unused nodes. - desc = ( - "Disable slots from the parent class. The acquisition start automatically " - "on instantiation. Check that the `fastSources` table is populated and " - "its booleans set to True the project configuration (not instantiated). " - "These nodes are not used." - ) - schema = Schema() - ( - NODE_ELEMENT(schema).key("start").description(desc).commit(), - - NODE_ELEMENT(schema).key("stop").description(desc).commit(), - ) - self.path_to_stack = self.get("pathToStack") - self.updateSchema(schema) - - super().start() - - def _send(self, tid, sources): - # Add control data - timestamp = Timestamp(Epochstamp(), Trainstamp(tid)) - - # Reuse arbitrary hash from existing ones (among the ones to stack) - try: - out_hash = next( - data - for (data, _) in iter(sources.values()) - if data.has(self.path_to_stack) - ) - except StopIteration: - out_hash = Hash() - - # TODO: handle missing modules properly (track all sources) - stacked_data = [] - stacked_sources = [] - stacked_present = [] - # TODO: should this be threaded? - time_of_flight = 0 - for source, (data, metadata) in sources.items(): - if not data.has(self.path_to_stack): - # may add sources not for stacking - # TODO: make stack or no part of source configuration - out_hash[f"unstacked.{source}"] = data - continue - old_ts = metadata.getTimestamp() - elapsed = timestamp.toTimestamp() - old_ts.toTimestamp() - time_of_flight = max(time_of_flight, elapsed) - image_data = data.get(self.path_to_stack) - stacked_data.append(image_data) - stacked_sources.append(source) - stacked_present.append(True) - - for source, data in self.ctrlmon.get(tid): - out_hash[f"unstacked.{source}"] = data - - if stacked_data: - if not isinstance(stacked_data[0], str): - # strings (like shmem handles) should stay list - stacked_data = np.stack(stacked_data, axis=0) - # TODO: merge with super().update_info (throttled updates) - self.info["timeOfFlight"] = time_of_flight * 1000 - - out_hash[self.path_to_stack] = stacked_data - out_hash["sources"] = stacked_sources - out_hash["modulesPresent"] = stacked_present - channel = self.signalSlotable.getOutputChannel("output") - channel.write(out_hash, ChannelMetaData(self.getInstanceId(), timestamp)) - channel.update() - self.rate_out.update() diff --git a/src/calng/RoiTool.py b/src/calng/RoiTool.py new file mode 100644 index 0000000000000000000000000000000000000000..af66fffcfe6282f1a09a7a110e760a15d930b0dd --- /dev/null +++ b/src/calng/RoiTool.py @@ -0,0 +1,312 @@ +import threading + + +import numpy as np +from karabo.middlelayer import ( + AccessLevel, + AccessMode, + Bool, + Configurable, + DaqPolicy, + Device, + Double, + EncodingType, + Hash, + Image, + ImageData, + InputChannel, + Node, + OutputChannel, + Overwrite, + Slot, + State, + String, + UInt32, + Unit, + VectorDouble, + VectorInt32, + VectorString, + slot, +) + +from . 
import scenes, utils + + +def image_data_node(): + return Image( + ImageData(np.zeros((100, 100), dtype=np.float32), encoding=EncodingType.GRAY), + ) + + +class DownsamplingNode(Configurable): + method = String(options=["nanmax", "nanmean", "nanmin", "nanmedian"]) + factor = UInt32(options=[1, 2, 4, 8]) + + +class PreviewSettingsNode(Configurable): + downsampling = Node(DownsamplingNode) + flipX = Bool(defaultValue=True, accessMode=AccessMode.RECONFIGURABLE) + flipY = Bool(defaultValue=True, accessMode=AccessMode.RECONFIGURABLE) + + @Double(accessMode=AccessMode.RECONFIGURABLE, defaultValue=2, unitSymbol=Unit.HERTZ) + def maxPreviewRate(self, value): + self.maxPreviewRate = value + parent = self.get_root() + parent._throttler = utils.SkippingThrottler(1 / self.maxPreviewRate.value) + + +class HistogramSettingsNode(Configurable): + @Slot(displayedName="Reset bins") + async def resetBins(self): + parent = self.get_root() + with parent._bin_lock: + parent._bins = None + parent._means = None + + resetBinsOnRoiChange = Bool( + defaultValue=True, + accessMode=AccessMode.RECONFIGURABLE, + displayedName="Reset on ROI change", + ) + automaticallyExpandRange = Bool( + defaultValue=True, + accessMode=AccessMode.RECONFIGURABLE, + displayedName="Auto-expand", + ) + + @Double(defaultValue=0, accessMode=AccessMode.RECONFIGURABLE, displayedName="Min") + async def rangeMin(self, value): + self.rangeMin = value + await self.resetBins() + + @Double( + defaultValue=1, accessMode=AccessMode.RECONFIGURABLE, displayedName="Max" + ) + async def rangeMax(self, value): + self.rangeMax = value + await self.resetBins() + + @UInt32( + defaultValue=10, + maxInc=100_000, + accessMode=AccessMode.RECONFIGURABLE, + displayedName="#Bins", + ) + async def numBins(self, value): + self.numBins = value + await self.resetBins() + + @UInt32( + defaultValue=10, + maxInc=100_000, + accessMode=AccessMode.RECONFIGURABLE, + displayedName="Window size", + ) + async def rollingWindowSize(self, value): + self.rollingWindowSize = value + await self.resetBins() + + +class RedundantImageNode(Configurable): + image = image_data_node() + + +class RectRoiNode(Configurable): + # same as in gaussian.py + displayType = "WidgetNode|RectRoiGraph" + + @VectorInt32(defaultValue=[0, 100, 0, 100]) + async def roi(self, value): + self.roi = value + parent = self.get_root() + if parent.histogram.resetBinsOnRoiChange: + await parent.histogram.resetBins() + + data = Node(RedundantImageNode) + + +class ManualHistogramDisplayableNode(Configurable): + x = VectorDouble() + y = VectorDouble() + yMean = VectorDouble() + + +class OutputNode(Configurable): + roiImage = Node(RectRoiNode) + zoomImage = image_data_node() + manualHistogram = Node(ManualHistogramDisplayableNode) + + +class RoiTool(Device): + state = Overwrite(defaultValue=State.INIT) + output = OutputChannel(OutputNode) + imageDataPath = String( + defaultValue="image.data", accessMode=AccessMode.RECONFIGURABLE + ) + histogram = Node(HistogramSettingsNode) + preview = Node(PreviewSettingsNode) + rate = Double(defaultValue=0, accessMode=AccessMode.READONLY) + numPixelsIncluded = UInt32( + defaultValue=0, accessMode=AccessMode.READONLY, displayedName="#Pixels counted" + ) + + def __init__(self, config): + super().__init__(config) + self._bins = None + self._means = None + self._bin_lock = threading.RLock() + self._rate_tracker = utils.WindowRateTracker() + + async def onInitialization(self): + self._throttler = utils.SkippingThrottler(1 / self.preview.maxPreviewRate.value) + self.state = State.ON +
@InputChannel() + async def imageInput(self, data, meta): + image_data = rec_getattr(data, self.imageDataPath.value).astype( + np.float32, copy=False + ) + if self.preview.flipX.value: + image_data = image_data[:, ::-1] + if self.preview.flipY.value: + image_data = image_data[::-1] + + # TODO: make handling of extra dimension(s) configurable + if len(image_data.shape) == 3: + image_data = image_data[0] + image_data = np.ascontiguousarray(image_data) + x_min, x_max, y_min, y_max = self.output.schema.roiImage.roi.value + # clamp ROI to image bounds; x indexes axis 1, y indexes axis 0 + x_min = max(x_min, 0) + y_min = max(y_min, 0) + x_max = min(x_max, image_data.shape[1]) + y_max = min(y_max, image_data.shape[0]) + if (downsampling_factor := self.preview.downsampling.factor.value) > 1: + x_min *= downsampling_factor + y_min *= downsampling_factor + x_max *= downsampling_factor + y_max *= downsampling_factor + + # data for analysis may contain NaNs; filter them out + zoomed = image_data[y_min:y_max, x_min:x_max] + zoomed = zoomed[np.isfinite(zoomed)] + finite_pixel_count = zoomed.size + update_histogram = finite_pixel_count > 0 + + if update_histogram: + with self._bin_lock: + if self.histogram.automaticallyExpandRange.value: + # the device updating its own value doesn't trigger the change function + ranges_changed = False + if (min_val := np.min(zoomed)) < self.histogram.rangeMin.value: + self.histogram.rangeMin = min_val + ranges_changed = True + if (max_val := np.max(zoomed)) > self.histogram.rangeMax.value: + self.histogram.rangeMax = max_val + ranges_changed = True + if ranges_changed: + await self.histogram.resetBins() + if self._bins is None: + counts, bin_edges = np.histogram( + zoomed, + bins=self.histogram.numBins.value, + range=( + self.histogram.rangeMin.value, + self.histogram.rangeMax.value, + ), + ) + self._bins = bin_edges.astype(np.float32) + self._window_counts = np.zeros( + ( + self.histogram.rollingWindowSize.value, + self.histogram.numBins.value, + ), + dtype=np.float32, + ) + self._window_count_index = 0 + else: + # zoomed is already finite-filtered above + counts, bin_edges = np.histogram(zoomed, bins=self._bins) + counts = counts.astype(np.float32) + self._window_counts[self._window_count_index] = counts + self._window_count_index = ( + self._window_count_index + 1 + ) % self._window_counts.shape[0] + + self._rate_tracker.update() + + if self._throttler.test_and_set(): + if update_histogram: + ( + self.output.schema.manualHistogram.x, + self.output.schema.manualHistogram.y, + ) = _histogram_plot_helper(counts / np.sum(counts), bin_edges) + _, self.output.schema.manualHistogram.yMean = _histogram_plot_helper( + np.sum(self._window_counts, axis=0) / np.sum(self._window_counts), + bin_edges, + ) + + # data for preview should not contain NaNs + np.nan_to_num(image_data, copy=False) + zoomed = image_data[y_min:y_max, x_min:x_max] + self.output.schema.roiImage.data.image = ImageData( + utils.downsample_2d( + image_data, + downsampling_factor, + getattr(np, self.preview.downsampling.method.value), + ), + encoding=EncodingType.GRAY, + bitsPerPixel=32, + ) + self.output.schema.zoomImage = ImageData( + np.ascontiguousarray(zoomed), + encoding=EncodingType.GRAY, + bitsPerPixel=32, + ) + await self.output.writeData() + self.rate = self._rate_tracker.get() + if finite_pixel_count != self.numPixelsIncluded: + self.numPixelsIncluded = finite_pixel_count + + availableScenes = VectorString( + displayedName="Available scenes", + displayType="Scenes", + requiredAccessLevel=AccessLevel.OBSERVER, + accessMode=AccessMode.READONLY, + defaultValue=[ + "overview", + ],
+ daqPolicy=DaqPolicy.OMIT, + ) + + @slot + def requestScene(self, params): + name = params.get("name", default="overview") + # fall-back payload so unknown scene names don't leave payload unbound + payload = Hash("success", False, "name", name) + if name == "overview": + scene_data = scenes.histogram_overview( + self.deviceId, + schema=self.getDeviceSchema(), + ) + payload = Hash("success", True, "name", name, "data", scene_data) + + return Hash("type", "deviceScene", "origin", self.deviceId, "payload", payload) + + +def rec_getattr(obj, path): + res = obj + for part in path.split("."): + res = getattr(res, part) + return res + + +def _histogram_plot_helper(counts, bin_edges): + x = np.zeros(bin_edges.size * 2) + y = np.zeros_like(x) + + x[::2] = bin_edges + x[1::2] = bin_edges + + y[1:-1:2] = counts + y[2:-1:2] = counts + + return x, y diff --git a/src/calng/ShmemTrainMatcher.py b/src/calng/ShmemTrainMatcher.py index a3f919ee197c1313e2e8f1b9bca835860f0e1d2c..223a7c51cdd073a1d27d689255e3e5fae85b524e 100644 --- a/src/calng/ShmemTrainMatcher.py +++ b/src/calng/ShmemTrainMatcher.py @@ -7,6 +7,7 @@ from karabo.bound import ( BOOL_ELEMENT, INT32_ELEMENT, KARABO_CLASSINFO, + OVERWRITE_ELEMENT, STRING_ELEMENT, TABLE_ELEMENT, ChannelMetaData, @@ -116,6 +117,13 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .reconfigurable() .commit(), + # source order matters for stacking; disable sorting + OVERWRITE_ELEMENT(expected) + .key("sortSources") + .setNowReadOnly() + .setNewDefaultValue(False) + .commit(), + BOOL_ELEMENT(expected) .key("useThreadPool") .displayedName("Use thread pool") @@ -133,8 +141,25 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): .defaultValue(True) .reconfigurable() .commit(), + + BOOL_ELEMENT(expected) + .key("useInfiniband") + .description( + "If enabled, the device will try to bind its data output channel " + "(output) to its node's infiniband interface during initialization. " + "The default interface is used if no 'ib0' interface is found."
+ ) + .assignmentOptional() + .defaultValue(True) + .commit(), ) + def __init__(self, config): + if config.get("useInfiniband", default=True): + from PipeToZeroMQ.utils import find_infiniband_ip + config["output.hostname"] = find_infiniband_ip() + super().__init__(config) + def initialization(self): self._stacking_buffers = {} self._source_stacking_indices = {} @@ -332,7 +357,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): # karabo output if self.output is not None: for source, (data, timestamp) in sources.items(): - self.output.write(data, ChannelMetaData(source, timestamp)) + self.output.write( + data, ChannelMetaData(source, timestamp), copyAllData=False + ) self.output.update() # karabo bridge output @@ -343,36 +370,3 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher): self.info["sent"] += 1 self.info["trainId"] = train_id self.rate_out.update() - - def _maybe_connect_data(self, conf, update=False, state=None): - """Temporary override on _maybe_connect_data to avoid sorting sources list (we - need it for stacking order)""" - if self["state"] not in (State.CHANGING, State.ACTIVE): - return - - last_state = self["state"] - self.updateState(State.CHANGING) - - # unwatch removed sources - def src_names(c): - # do not assign a lambda expression, use a def - return {s["source"] for s in c["sources"]} - - for source in src_names(self).difference(src_names(conf)): - self.monitor.unwatch_source(source) - - new_conf = VectorHash() - for src in conf["sources"]: - source = src["source"] - if src["select"]: - src["status"] = self.monitor.watch_source(source, src["offset"]) - else: - self.monitor.unwatch_source(source) - src["status"] = "" - new_conf.append(src) - - if update: - self.set("sources", new_conf) - else: - conf["sources"] = new_conf - self.updateState(state or last_state) diff --git a/src/calng/base_calcat.py b/src/calng/base_calcat.py index 655300436dc2695587e24ae1627f25c2c6447ca6..40965e9cdec7744bee23270bb7baa6b62fc1218d 100644 --- a/src/calng/base_calcat.py +++ b/src/calng/base_calcat.py @@ -1,6 +1,7 @@ import copy import functools import json +import multiprocessing import pathlib import threading @@ -16,7 +17,6 @@ from calibration_client.modules import ( PhysicalDetectorUnit, ) from karabo.bound import ( - BOOL_ELEMENT, DOUBLE_ELEMENT, NODE_ELEMENT, SLOT_ELEMENT, @@ -48,17 +48,22 @@ class CalibrationClientConfigError(Exception): pass -def add_status_schema_from_enum(schema, prefix, enum_class): +def add_status_schema_from_enum(schema, enum_class): for constant in enum_class: - constant_node = f"{prefix}.{constant.name}" + constant_node = f"foundConstants.{constant.name}" ( NODE_ELEMENT(schema).key(constant_node).commit(), - BOOL_ELEMENT(schema) - .key(f"{constant_node}.found") - .displayedName("Found and loaded") + STRING_ELEMENT(schema) + .key(f"{constant_node}.state") + .setSpecialDisplayType("State") .readOnly() - .initialValue(False) + .initialValue("OFF") + .commit(), + + SLOT_ELEMENT(schema) + .key(f"{constant_node}.loadMostRecent") + .displayedName("Load most recent") .commit(), STRING_ELEMENT(schema) @@ -153,7 +158,7 @@ def add_status_schema_from_enum(schema, prefix, enum_class): .commit(), SLOT_ELEMENT(schema) - .key(f"{constant_node}.overrideConstantVersion") + .key(f"{constant_node}.overrideConstantFromVersion") .displayedName("Override constant version") .description("See description of constant version id.") .commit(), @@ -213,8 +218,8 @@ class BaseCalcatFriend: """Base class for CalCat friends - handles interacting with CalCat for the device A CalCat friend 
uses the device schema to build up parameters for CalCat queries. - It focuses on two nodes (added by static method add_schema): param_prefix and - status_prefix. The former is primarily used to get parameters which are (via + It focuses on two nodes (added by static method add_schema): constantParameters and + foundConstants. The former is primarily used to get parameters which are (via condition methods - see for example dark_condition of DsscCalcatFriend) used to look for constants. The latter is primarily used to give user information about what was found. @@ -228,8 +233,6 @@ class BaseCalcatFriend: schema, managed_keys, detector_type, - param_prefix="constantParameters", - status_prefix="foundConstants", ): """Add elements needed by this object to device's schema (expectedSchema) @@ -237,24 +240,25 @@ class BaseCalcatFriend: node which does not exist yet. To change default values and add more fields, extend this method in subclass. - The param_prefix node will hold all the parameters needed to build constant - condition dicts for querying CalCat. These values are set either directly on - the device or via manager and this class gets them from the device using helper - function _get_param. See for example AgipdCalcatFriend.dark_condition. + The constantParameters node will hold all the parameters needed to build + constant condition dicts for querying CalCat. These values are set either + directly on the device or via manager and this class gets them from the device + using helper function _get_param. See for example + AgipdCalcatFriend.dark_condition. - The status_prefix node is used to report information about what was found in + The foundConstants node is used to report information about what was found in CalCat. This class will update the values on the device using the helper function _set_status. This should not need to happen in subclass methods. """ ( NODE_ELEMENT(schema) - .key(param_prefix) + .key("constantParameters") .displayedName("Constant retrieval parameters") .commit(), NODE_ELEMENT(schema) - .key(status_prefix) + .key("foundConstants") .displayedName("Constants retrieved") .commit(), ) @@ -263,7 +267,7 @@ class BaseCalcatFriend: # TODO: probably switch to floating point for everything, including mem cells ( STRING_ELEMENT(schema) - .key(f"{param_prefix}.deviceMappingSnapshotAt") + .key("constantParameters.deviceMappingSnapshotAt") .displayedName("Snapshot timestamp (for device mapping)") .description( "CalCat supports querying with a specific snapshot of the database. 
" @@ -279,7 +283,7 @@ class BaseCalcatFriend: .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.constantVersionEventAt") + .key("constantParameters.constantVersionEventAt") .displayedName("Event at timestamp (for constant version)") .description("TODO") .assignmentOptional() @@ -288,7 +292,7 @@ class BaseCalcatFriend: .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorType") + .key("constantParameters.detectorType") .displayedName("Detector type name") .description( "Name of detector type in CalCat; typically has suffix '-Type'" @@ -298,37 +302,37 @@ class BaseCalcatFriend: .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorTypeId") + .key("constantParameters.detectorTypeId") .readOnly() .initialValue("") .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorName") + .key("constantParameters.detectorName") .assignmentOptional() .defaultValue("") .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorId") + .key("constantParameters.detectorId") .readOnly() .initialValue("") .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.karaboDa") + .key("constantParameters.karaboDa") .assignmentOptional() .defaultValue("") .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.moduleId") + .key("constantParameters.moduleId") .readOnly() .initialValue("") .commit(), UINT32_ELEMENT(schema) - .key(f"{param_prefix}.memoryCells") + .key("constantParameters.memoryCells") .displayedName("Memory cells") .description( "Number of memory cells / frames per train. Relevant for burst mode." @@ -339,21 +343,21 @@ class BaseCalcatFriend: .commit(), UINT32_ELEMENT(schema) - .key(f"{param_prefix}.pixelsX") + .key("constantParameters.pixelsX") .displayedName("Pixels X") .assignmentOptional() .defaultValue(512) .commit(), UINT32_ELEMENT(schema) - .key(f"{param_prefix}.pixelsY") + .key("constantParameters.pixelsY") .displayedName("Pixels Y") .assignmentOptional() .defaultValue(128) .commit(), DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.biasVoltage") + .key("constantParameters.biasVoltage") .displayedName("Bias voltage") .description("Sensor bias voltage") .assignmentOptional() @@ -361,27 +365,23 @@ class BaseCalcatFriend: .reconfigurable() .commit(), ) - managed_keys.add(f"{param_prefix}.deviceMappingSnapshotAt") - managed_keys.add(f"{param_prefix}.constantVersionEventAt") - managed_keys.add(f"{param_prefix}.memoryCells") - managed_keys.add(f"{param_prefix}.pixelsX") - managed_keys.add(f"{param_prefix}.pixelsY") - managed_keys.add(f"{param_prefix}.biasVoltage") + managed_keys.add("constantParameters.deviceMappingSnapshotAt") + managed_keys.add("constantParameters.constantVersionEventAt") + managed_keys.add("constantParameters.memoryCells") + managed_keys.add("constantParameters.pixelsX") + managed_keys.add("constantParameters.pixelsY") + managed_keys.add("constantParameters.biasVoltage") def __init__( self, device, secrets_fn: pathlib.Path, - param_prefix="constantParameters", - status_prefix="foundConstants", ): self.device = device - self.param_prefix = param_prefix - self.status_prefix = status_prefix self.cached_constants = {} - self.cached_constants_lock = threading.Lock() + self.cached_constants_lock = threading.RLock() # api lock used to force queries to be sequential (SSL issue on ONC) - self.api_lock = threading.Lock() + self.api_lock = threading.RLock() if not secrets_fn.is_file(): self.device.log_status_warn( @@ -411,22 +411,23 @@ class BaseCalcatFriend: def _get_param(self, key): """Helper to get value from attached device schema""" - 
return self.device.get(f"{self.param_prefix}.{key}") + return self.device.get(f"constantParameters.{key}") def _set_param(self, key, value): - self.device.set(f"{self.param_prefix}.{key}", value) + self.device.set(f"constantParameters.{key}", value) def _get_status(self, constant, key): - return self.device.get(f"{self.status_prefix}.{constant.name}.{key}") + return self.device.get(f"foundConstants.{constant.name}.{key}") def _set_status(self, constant, key, value): """Helper to update information about found constants on device""" - self.device.set(f"{self.status_prefix}.{constant.name}.{key}", value) + self.device.set(f"foundConstants.{constant.name}.{key}", value) @functools.cached_property def detector_id(self): detector_name = self._get_param("detectorName") - resp = Detector.get_by_identifier(self.client, detector_name) + with self.api_lock: + resp = Detector.get_by_identifier(self.client, detector_name) self._check_resp(resp, DetectorNotFound, f"Detector {detector_name} not found") res = resp["data"]["id"] self._set_param("detectorId", str(res)) @@ -435,7 +436,8 @@ class BaseCalcatFriend: @functools.cached_property def detector_type_id(self): detector_type = self._get_param("detectorType") - resp = DetectorType.get_by_name(self.client, detector_type) + with self.api_lock: + resp = DetectorType.get_by_name(self.client, detector_type) self._check_resp( resp, DetectorNotFound, f"Detector type {detector_type} not found" ) @@ -445,9 +447,12 @@ class BaseCalcatFriend: @functools.cached_property def pdus(self): - resp = PhysicalDetectorUnit.get_all_by_detector( - self.client, self.detector_id, self._get_param("deviceMappingSnapshotAt") - ) + with self.api_lock: + resp = PhysicalDetectorUnit.get_all_by_detector( + self.client, + self.detector_id, + self._get_param("deviceMappingSnapshotAt"), + ) self._check_resp(resp, warning="Failed to retrieve module mapping") for irrelevant_key in ("detector", "detector_type", "flg_available"): for pdu in resp["data"]: @@ -470,20 +475,24 @@ class BaseCalcatFriend: @utils.threadsafe_cache def calibration_id(self, calibration_name: str): - resp = Calibration.get_by_name(self.client, calibration_name) + with self.api_lock: + resp = Calibration.get_by_name(self.client, calibration_name) self._check_resp( - resp, CalibrationNotFound, f"Calibration type {calibration_name} not found!" + resp, + CalibrationNotFound, + f"Calibration type {calibration_name} not found!" 
) return resp["data"]["id"] def condition_ids(self, pdu, condition): - # modifying condition parameter messes with cache + # note: do not cache, let CalCat search new condition IDs as they may differ condition_with_detector = copy.copy(condition) condition_with_detector["Detector UUID"] = pdu self.device.log.DEBUG(f"Look for condition: {condition_with_detector}") - resp = self.client.search_possible_conditions_from_dict( - "", condition_with_detector.encode() - ) + with self.api_lock: + resp = self.client.search_possible_conditions_from_dict( + "", condition_with_detector.encode() + ) self._check_resp( resp, ConditionNotFound, @@ -492,17 +501,17 @@ class BaseCalcatFriend: return [d["id"] for d in resp["data"]] def constant_ids(self, calibration_id, condition_ids): - resp = CalibrationConstant.get_all_by_conditions( - self.client, - calibration_id=calibration_id, - detector_type_id=self.detector_type_id, - condition_ids=condition_ids, - ) + with self.api_lock: + resp = CalibrationConstant.get_all_by_conditions( + self.client, + calibration_id=calibration_id, + detector_type_id=self.detector_type_id, + condition_ids=condition_ids, + ) self._check_resp(resp, warning="Failed to retrieve constant ID") return [d["id"] for d in resp["data"]] def get_constant_version(self, constant): - # TODO: catch exceptions, give warnings appropriately karabo_da = self._get_param("karaboDa") self.device.log_status_info(f"Attempting to find {constant} for {karabo_da}") @@ -530,15 +539,16 @@ class BaseCalcatFriend: ) self._set_status(constant, "constantIds", constant_ids) - resp = CalibrationConstantVersion.get_closest_by_time( - self.client, - calibration_constant_ids=constant_ids, - physical_detector_unit_id=self._karabo_da_to_id[karabo_da], - event_at=self._get_param("constantVersionEventAt"), - snapshot_at=None, - ) + with self.api_lock: + resp = CalibrationConstantVersion.get_closest_by_time( + self.client, + calibration_constant_ids=constant_ids, + physical_detector_unit_id=self._karabo_da_to_id[karabo_da], + event_at=self._get_param("constantVersionEventAt"), + snapshot_at=None, + ) self._check_resp(resp, warning="Failed to find calibration constant version") - # TODO: replace with start date and end date + # note: could consider adding end date timestamp = resp["data"]["begin_validity_at"] self._set_status(constant, "beginValidityAt", timestamp) self._set_status(constant, "constantVersionId", resp["data"]["id"]) @@ -556,43 +566,37 @@ class BaseCalcatFriend: ) self._set_status(constant, "dataFilePath", str(file_path)) self._set_status(constant, "dataSetName", resp["data"]["data_set_name"]) - # TODO: handle FileNotFoundError if we are led astray - with h5py.File(file_path, "r") as fd: - constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) + + constant_data = _read_dataset_externally( + file_path, resp["data"]["data_set_name"] + ) + with self.cached_constants_lock: self.cached_constants[constant] = constant_data - self._set_status(constant, "found", True) - self.device.log_status_info(f"Done finding {constant} for {karabo_da}") + self.device.log_status_info(f"Done finding {constant.name} for {karabo_da}") return constant_data - def get_constant_version_and_call_me_back(self, constant, callback): - """Runs get_constant_version in thread, will call callback on completion""" - # TODO: do we want to use asyncio / "modern" async? 
- # TODO: consider moving out of this class, closer to correction device - def aux(): - with self.api_lock: - data = self.get_constant_version(constant) - callback(constant, data) - - thread = threading.Thread(target=aux) - thread.start() - return thread - - def get_overridden_constant_version(self, constant): + def get_constant_from_constant_version_id(self, constant): # TODO: warn if PDU or constant type does not match # TODO: warn if result is list (happens for empty version ID) constant_version_id = self.device.get( - f"{self.status_prefix}.{constant.name}.constantVersionId" + f"foundConstants.{constant.name}.constantVersionId" ) - resp = CalibrationConstantVersion.get_by_id(self.client, constant_version_id) + with self.api_lock: + resp = CalibrationConstantVersion.get_by_id( + self.client, constant_version_id + ) self._check_resp(resp, warning="Failed to find calibration constant version") file_path = ( self.caldb_store / resp["data"]["path_to_file"] / resp["data"]["file_name"] ) - with h5py.File(file_path, "r") as fd: - constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) + self._set_status(constant, "dataFilePath", str(file_path)) + self._set_status(constant, "dataSetName", resp["data"]["data_set_name"]) + constant_data = _read_dataset_externally( + file_path, resp["data"]["data_set_name"] + ) with self.cached_constants_lock: self.cached_constants[constant] = constant_data self._set_status(constant, "beginValidityAt", resp["data"]["begin_at"]) @@ -600,51 +604,46 @@ class BaseCalcatFriend: self._set_status(constant, "usedConditionId", "manual override") self._set_status(constant, "usedConstantId", "manual override") self._set_status(constant, "constantVersionId", constant_version_id) - self._set_status(constant, "found", True) - return constant_data - def get_overridden_constant_version_and_call_me_back(self, constant, callback): - """Blindly load whatever CalCat points to for CCV - user must be confident that - this CCV corresponds to correct kind of constant.""" + return constant_data - # TODO: warn user about all the things that go wrong - def aux(): - with self.api_lock: - data = self.get_overridden_constant_version(constant) - callback(constant, data) + def get_constant_from_file(self, constant): + constant_data = _read_dataset_externally( + self.device.get(f"foundConstants.{constant.name}.dataFilePath"), + self.device.get(f"foundConstants.{constant.name}.dataSetName"), + ) + with self.cached_constants_lock: + self.cached_constants[constant] = constant_data + self._set_status(constant, "beginValidityAt", "manual override") + self._set_status(constant, "calibrationId", "manual override") + self._set_status(constant, "usedConditionId", "manual override") + self._set_status(constant, "usedConstantId", "manual override") + self._set_status(constant, "constantVersionId", "manual override") - thread = threading.Thread(target=aux) - thread.start() - return thread + return constant_data - def get_overridden_constant_from_file_and_call_me_back(self, constant, callback): - def aux(): - file_path = self.device.get( - f"{self.status_prefix}.{constant.name}.dataFilePath" - ) - data_set_name = self.device.get( - f"{self.status_prefix}.{constant.name}.dataSetName" - ) - with h5py.File(file_path, "r") as fd: - constant_data = np.array(fd[data_set_name]["data"]) - with self.cached_constants_lock: - self.cached_constants[constant] = constant_data - self._set_status(constant, "beginValidityAt", "manual override") - self._set_status(constant, "calibrationId", "manual override") - 
self._set_status(constant, "usedConditionId", "manual override") - self._set_status(constant, "usedConstantId", "manual override") - self._set_status(constant, "constantVersionId", "manual override") - self._set_status(constant, "found", True) - callback(constant, constant_data) - - thread = threading.Thread(target=aux) - thread.start() - return thread - - def flush_constants(self): - for constant in self._constant_enum_class: - self._set_status(constant, "beginValidityAt", "") - self._set_status(constant, "found", False) + def flush_constants(self, constants=None, preserve_fields=None): + if preserve_fields is None: + preserve_fields = set() + reset_fields = { + "beginValidityAt", + "calibrationId", + "usedConditionId", + "usedConstantId", + "constantVersionId", + "dataFilePath", + "dataSetName", + } - set(preserve_fields) + if constants is None: + constants = self._constant_enum_class + with self.cached_constants_lock: + for constant in constants: + for field in reset_fields: + self._set_status(constant, field, "") + if "state" not in preserve_fields: + self._set_status(constant, "state", "OFF") + if constant in self.cached_constants: + del self.cached_constants[constant] def _check_resp(self, resp, exception=Exception, warning=None): # TODO: probably verify using "info" that exception is the right one @@ -663,3 +662,24 @@ class BaseCalcatFriend: if warning is not None: self.device.log_status_warn(warning) raise to_raise + + +def _read_dataset_externally(file_path, data_set_name, append_data=True): + def aux(queue): + try: + with h5py.File(file_path, "r") as fd: + if append_data: + res = fd[data_set_name]["data"] + else: + res = fd[data_set_name] + queue.put(np.array(res)) + except Exception as ex: + queue.put(ex) + + res_queue = multiprocessing.Queue() + process = multiprocessing.Process(target=aux, args=(res_queue,)) + process.start() + constant_data = res_queue.get() + if isinstance(constant_data, Exception): + raise constant_data + return constant_data diff --git a/src/calng/base_condition.py b/src/calng/base_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..2583ab7bdf63a46f5160b1fd5faee9373392b401 --- /dev/null +++ b/src/calng/base_condition.py @@ -0,0 +1,330 @@ +import enum + +from karabo.middlelayer import ( + Assignment, + AccessMode, + AccessLevel, + Bool, + DaqPolicy, + Configurable, + Device, + Hash, + Proxy, + Slot, + State, + String, + VectorHash, + VectorString, + background, + connectDevice, + disconnectDevice, + getConfiguration, + slot, + waitUntilNew, +) +from . 
import scenes, utils
+
+
+class PipelineOperationMode(enum.Enum):
+    MANAGED = enum.auto()
+    STANDALONE = enum.auto()
+
+
+class KeyMappingRow(Configurable):
+    managerKey = String(defaultValue="")
+    controlValue = String(defaultValue="")
+    idealValue = String(defaultValue="")
+    managerValue = String(defaultValue="")
+    matches = String(defaultValue="NORMAL", displayType="State")
+
+
+class ConditionBase(Device):
+    managerDeviceId = String(
+        displayedName="Manager device ID",
+        assignment=Assignment.MANDATORY,
+        accessMode=AccessMode.INITONLY,
+    )
+
+    keyMapping = VectorHash(
+        displayedName="Key mapping",
+        description="Read-only table used to show which properties go where.",
+        rows=KeyMappingRow,
+        accessMode=AccessMode.READONLY,
+    )
+
+    conditionsMatch = String(displayType="State")
+
+    @property
+    def keys_to_get(self):
+        """Must return a dict mapping control device IDs to iterables of tuples of
+        (control device key, correction constant condition parameter key, optional
+        translator function (may be None) turning the control device value into
+        the condition parameter value).
+        """
+        raise NotImplementedError("Subclass must implement!")
+
+    async def onInitialization(self):
+        self.state = State.INIT
+        self._manager_dev = await connectDevice(self.managerDeviceId.value)
+        if self._manager_dev.classId.value == "CalibrationManager":
+            self._operation_mode = PipelineOperationMode.MANAGED
+        else:
+            self._operation_mode = PipelineOperationMode.STANDALONE
+        self.state = State.ON
+
+    def _trigger_constant_loading(self):
+        if self._operation_mode is PipelineOperationMode.MANAGED:
+            background(self._manager_dev.managedKeys.loadMostRecentConstants())
+        else:
+            background(self._manager_dev.loadMostRecentConstants())
+
+    @property
+    def _manager_parameters_node(self):
+        if self._operation_mode is PipelineOperationMode.MANAGED:
+            return self._manager_dev.managedKeys.constantParameters
+        else:
+            return self._manager_dev.constantParameters
+
+    updateManagerOnMonitor = Bool(
+        displayedName="Monitor: update manager",
+        description="Whenever a parameter changes on monitored control devices, "
+        "automatically reflect the change in the operating conditions set on the "
+        "manager.",
+        defaultValue=True,
+        assignment=Assignment.OPTIONAL,
+        accessMode=AccessMode.RECONFIGURABLE,
+    )
+
+    loadConstantsOnMonitor = Bool(
+        displayedName="Monitor: load constants",
+        description="After automatically reflecting control device parameter "
+        "settings on the manager, automatically trigger loading of the most "
+        "recent constants matching these parameters. Should probably only be "
+        "used in conjunction with updateManagerOnMonitor.",
+        defaultValue=False,
+        assignment=Assignment.OPTIONAL,
+        accessMode=AccessMode.RECONFIGURABLE,
+    )
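To make the expected `keys_to_get` structure concrete, a condition subclass might implement it as below. The device ID, keys, and translator are hypothetical, chosen only for illustration; the real subclasses are AgipdCondition and JungfrauCondition:

```python
class ExampleCondition(ConditionBase):
    """Hypothetical condition device; IDs and keys are made up."""

    @property
    def keys_to_get(self):
        return {
            "SOME/CTRL/DEVICE": [
                # (control device key, condition parameter key, translator)
                ("exposureTime", "integrationTime", lambda v: v * 1e6),
                ("gainModeIndex", "gainMode", None),  # None: use value as-is
            ],
        }
```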
+    @Slot(
+        allowedStates=[State.ON],
+        displayedName="Start monitoring",
+        description="Will keep settings on manager in sync with control device(s) "
+        "automagically by monitoring for updates.",
+    )
+    async def startMonitoring(self):
+        self.state = State.CHANGING
+        control_devs = {
+            control_id: await connectDevice(control_id)
+            for control_id in self.keys_to_get.keys()
+        }
+
+        # first bring manager in line with control devices
+        if (
+            await self._check_or_update(
+                update_manager=self.updateManagerOnMonitor.value,
+                control_devs=control_devs,
+            )
+            and self.loadConstantsOnMonitor.value
+        ):
+            self._trigger_constant_loading()
+
+        # then start monitoring
+        async def aux():
+            while True:
+                await waitUntilNew(
+                    *(
+                        utils.rec_getattr(control_devs[control_id], control_key)
+                        for control_id, v in self.keys_to_get.items()
+                        for control_key, *_ in v  # "v" for "variable naming is hard"
+                    )
+                )
+                await self._check_or_update(
+                    self.updateManagerOnMonitor.value, control_devs=control_devs
+                )
+                # TODO: maybe debounce
+                if self.loadConstantsOnMonitor.value:
+                    self._trigger_constant_loading()
+
+        self._monitor = background(aux)
+        self._control_devs = control_devs
+        self.state = State.MONITORING
+
+    @Slot(
+        allowedStates=[State.MONITORING],
+        displayedName="Stop monitoring",
+    )
+    async def stopMonitoring(self):
+        self.state = State.CHANGING
+        self._monitor.cancel()
+        for dev in self._control_devs.values():
+            await disconnectDevice(dev)
+        self.state = State.ON
+
+    async def _check_or_update(self, update_manager, control_devs=None):
+        """Compare the values of keys_to_get on the control device(s) to the actual
+        values on the manager, optionally changing manager settings accordingly.
+        A dict mapping control device IDs to control device configurations or
+        proxies can optionally be provided via control_devs.
+ """ + manager_params_checked = set() + key_mapping = [] + happy = True + updated_manager = False + for control_id, to_get in self.keys_to_get.items(): + if control_devs is None: + control_dev = await getConfiguration(control_id) + else: + control_dev = control_devs[control_id] + + for control_key, manager_key, translator in to_get: + ( + control_value, + ideal_value, + manager_value, + ), could_look_up = self._look_up( + control_dev, control_key, manager_key, translator + ) + row_state = State.ON + if not could_look_up: + happy = False + row_state = State.UNKNOWN + elif ideal_value != manager_value: + if update_manager: + try: + setattr( + self._manager_parameters_node, + manager_key, + ideal_value, + ) + except Exception as ex: + manager_value = f"Failed to set {ideal_value}: {ex}" + row_state = State.ERROR + else: + manager_value = ideal_value + updated_manager = True + else: + happy = False + row_state = State.ERROR + key_mapping.append( + ( + manager_key, + control_value, + ideal_value, + manager_value, + row_state.value, + ) + ) + manager_params_checked.add(manager_key) + + for unchecked_parameter in ( + set(dir(self._manager_parameters_node)) - manager_params_checked + ): + key_mapping.append( + ( + unchecked_parameter, + "", + "", + utils.rec_getattr( + self._manager_parameters_node, unchecked_parameter + ).value, + State.IGNORING.value, + ) + ) + + self.keyMapping = key_mapping + if happy: + self.conditionsMatch = State.ON.value + else: + self.conditionsMatch = State.ERROR.value + + return updated_manager + + def _look_up( + self, + control_dev, + control_key, + manager_key, + translator, + ): + could_look_up = True + if isinstance(control_dev, Proxy): + # device proxy via connectDevice + try: + control_value = utils.rec_getattr(control_dev, control_key).value + except AttributeError: + control_value = "key not found" + could_look_up = False + else: + # device configuration hash via getConfiguration + if control_dev.has(control_key): + control_value = control_dev[control_key] + else: + control_value = "key not found" + could_look_up = False + + if could_look_up: + try: + ideal_value = ( + control_value if translator is None else translator(control_value) + ) + except Exception as ex: + warning = f"Failed to process control value {control_value}; {ex}" + ideal_value = warning + self.log.WARN(warning) + could_look_up = False + else: + ideal_value = "control key not found" + + try: + manager_value = utils.rec_getattr( + self._manager_parameters_node, manager_key + ).value + except AttributeError: + manager_value = "key not found" + could_look_up = False + + return (control_value, ideal_value, manager_value), could_look_up + + @Slot( + allowedStates=[State.ON], + displayedName="Check conditions", + description="Compares current settings on manager to control device(s), but " + "does not change anything on manager.", + ) + async def checkConditions(self): + self.state = State.CHANGING + await self._check_or_update(False) + self.state = State.ON + + @Slot( + allowedStates=[State.ON], + displayedName="Copy conditions", + description="Copies current settings from control device(s) to manager. 
See " + "also 'Start monitoring' for automatic version of this.", + ) + async def updateConditions(self): + self.state = State.CHANGING + await self._check_or_update(True) + self.state = State.ON + + availableScenes = VectorString( + displayedName="Available scenes", + displayType="Scenes", + requiredAccessLevel=AccessLevel.OBSERVER, + accessMode=AccessMode.READONLY, + defaultValue=[ + "overview", + ], + daqPolicy=DaqPolicy.OMIT, + ) + + @slot + def requestScene(self, params): + name = params.get("name", default="overview") + if name == "overview": + scene_data = scenes.condition_checker_overview( + self.deviceId, + self.getDeviceSchema(), + ) + payload = Hash("success", True, "name", name, "data", scene_data) + + return Hash("type", "deviceScene", "origin", self.deviceId, "payload", payload) diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 3f58ca800c5e2a5c1837a2fbb0bd51711db4b43d..bd88ad7df4d59ae3602358ecd6045dcd3ab95c75 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -1,4 +1,7 @@ +import concurrent.futures +import contextlib import enum +import functools import gc import pathlib import threading @@ -11,9 +14,7 @@ from karabo.bound import ( DOUBLE_ELEMENT, INPUT_CHANNEL, INT32_ELEMENT, - INT64_ELEMENT, KARABO_CLASSINFO, - NDARRAY_ELEMENT, NODE_ELEMENT, OUTPUT_CHANNEL, OVERWRITE_ELEMENT, @@ -27,14 +28,13 @@ from karabo.bound import ( Hash, MetricPrefix, PythonDevice, - Schema, State, Timestamp, Unit, ) from karabo.common.api import KARABO_SCHEMA_DISPLAY_TYPE_SCENES as DT_SCENES -from . import scenes, shmem_utils, utils +from . import scenes, shmem_utils, schemas, utils from ._version import version as deviceVersion PROCESSING_STATE_TIMEOUT = 10 @@ -46,164 +46,59 @@ class FramefilterSpecType(enum.Enum): COMMASEPARATED = "commaseparated" -preview_schema = Schema() -( - NODE_ELEMENT(preview_schema).key("image").commit(), - - NDARRAY_ELEMENT(preview_schema).key("image.data").dtype("FLOAT").commit(), - - UINT64_ELEMENT(preview_schema) - .key("image.trainId") - .displayedName("Train ID") - .assignmentOptional() - .defaultValue(0) - .commit(), -) - -# TODO: trim output schema / adapt to specific detectors -# currently: based on snapshot of actual output reusing AGIPD hash -output_schema = Schema() -( - NODE_ELEMENT(output_schema).key("image").commit(), - - STRING_ELEMENT(output_schema) - .key("image.data") - .assignmentOptional() - .defaultValue("") - .commit(), - - NDARRAY_ELEMENT(output_schema).key("image.length").dtype("UINT32").commit(), - - NDARRAY_ELEMENT(output_schema).key("image.cellId").dtype("UINT16").commit(), - - NDARRAY_ELEMENT(output_schema).key("image.pulseId").dtype("UINT64").commit(), - - NDARRAY_ELEMENT(output_schema).key("image.status").commit(), - - NDARRAY_ELEMENT(output_schema).key("image.trainId").dtype("UINT64").commit(), - - VECTOR_STRING_ELEMENT(output_schema) - .key("calngShmemPaths") - .assignmentOptional() - .defaultValue(["image.data"]) - .commit(), - - NODE_ELEMENT(output_schema).key("metadata").commit(), - - STRING_ELEMENT(output_schema) - .key("metadata.source") - .assignmentOptional() - .defaultValue("") - .commit(), - - NODE_ELEMENT(output_schema).key("metadata.timestamp").commit(), - - INT32_ELEMENT(output_schema) - .key("metadata.timestamp.tid") - .assignmentOptional() - .defaultValue(0) - .commit(), - - NODE_ELEMENT(output_schema).key("header").commit(), - - INT32_ELEMENT(output_schema) - .key("header.minorTrainFormatVersion") - .assignmentOptional() - .defaultValue(0) - .commit(), - - 
INT32_ELEMENT(output_schema) - .key("header.majorTrainFormatVersion") - .assignmentOptional() - .defaultValue(0) - .commit(), - - INT32_ELEMENT(output_schema) - .key("header.trainId") - .assignmentOptional() - .defaultValue(0) - .commit(), - - INT64_ELEMENT(output_schema) - .key("header.linkId") - .assignmentOptional() - .defaultValue(0) - .commit(), - - INT64_ELEMENT(output_schema) - .key("header.dataId") - .assignmentOptional() - .defaultValue(0) - .commit(), - - INT64_ELEMENT(output_schema) - .key("header.pulseCount") - .assignmentOptional() - .defaultValue(0) - .commit(), - - NDARRAY_ELEMENT(output_schema).key("header.reserved").commit(), - - NDARRAY_ELEMENT(output_schema).key("header.magicNumberBegin").commit(), - - NODE_ELEMENT(output_schema).key("detector").commit(), - - INT32_ELEMENT(output_schema) - .key("detector.trainId") - .assignmentOptional() - .defaultValue(0) - .commit(), - - NDARRAY_ELEMENT(output_schema).key("detector.data").commit(), - - NODE_ELEMENT(output_schema).key("trailer").commit(), - - NDARRAY_ELEMENT(output_schema).key("trailer.checksum").commit(), - - NDARRAY_ELEMENT(output_schema).key("trailer.magicNumberEnd").commit(), - - INT32_ELEMENT(output_schema) - .key("trailer.status") - .assignmentOptional() - .defaultValue(0) - .commit(), - - INT32_ELEMENT(output_schema) - .key("trailer.trainId") - .assignmentOptional() - .defaultValue(0) - .commit(), -) +class WarningLampType(enum.Enum): + FRAME_FILTER = enum.auto() + MEMORY_CELL_RANGE = enum.auto() + PREVIEW_SETTINGS = enum.auto() + CORRECTION_RUNNER = enum.auto() + OUTPUT_BUFFER = enum.auto() + GPU_MEMORY = enum.auto() + CALCAT_CONNECTION = enum.auto() + EMPTY_HASH = enum.auto() + MISC_INPUT_DATA = enum.auto() + TRAIN_ID = enum.auto() + TIMESERVER_CONNECTION = enum.auto() @KARABO_CLASSINFO("BaseCorrection", deviceVersion) class BaseCorrection(PythonDevice): + _constant_enum_class = None # subclass must set _correction_flag_class = None # subclass must set (ex.: dssc_gpu.CorrectionFlags) + _correction_steps = None # subclass must set _kernel_runner_class = None # subclass must set (ex.: dssc_gpu.DsscGpuRunner) _kernel_runner_init_args = {} # optional extra args for runner _managed_keys = { - "outputShmemBufferSize", "dataFormat.outputAxisOrder", "dataFormat.outputImageDtype", "dataFormat.overrideInputAxisOrder", - "frameFilter.type", "frameFilter.spec", - "preview.enable", + "frameFilter.type", + "loadMostRecentConstants", + "outputShmemBufferSize", "preview.index", "preview.selectionMode", - "preview.trainIdModulo", - "loadMostRecentConstants", + "runAsStandaloneModule", + "useInfiniband", } # subclass can extend this, /must/ put it in schema as managedKeys _image_data_path = "image.data" # customize for *some* subclasses _cell_table_path = "image.cellId" _warn_memory_cell_range = True # can be disabled for some detectors - _cuda_pin_buffers = True + _cuda_pin_buffers = False - def _load_constant_to_runner(self, constant_name, constant_data): + def _load_constant_to_runner(self, constant, constant_data): """Subclass must define how to process constants into correction maps and store into appropriate buffers in (GPU or main) memory.""" raise NotImplementedError() + def _successfully_loaded_constant_to_runner(self, constant): + field_name = self._constant_to_correction_name[constant] + key = f"corrections.{field_name}.available" + if not self.unsafe_get(key): + self.set(key, True) + + self._update_correction_flags() + self.log_status_info(f"Done loading {constant.name} to GPU") + @property def 
input_data_shape(self):
         """Subclass must define expected input data shape in terms of dataFormat.{
@@ -213,15 +108,14 @@ class BaseCorrection(PythonDevice):
 
     @property
     def output_data_shape(self):
         """Shape of corrected image data sent on dataOutput. Depends on data format
-        parameters pixels x / y, and number of cells (optionally after frame filter)."""
+        parameters (optionally including frame filter)."""
         axis_lengths = {
             "x": self.unsafe_get("dataFormat.pixelsX"),
             "y": self.unsafe_get("dataFormat.pixelsY"),
-            "c": self.unsafe_get("dataFormat.filteredFrames"),
+            "f": self.unsafe_get("dataFormat.filteredFrames"),
         }
         return tuple(
-            axis_lengths[axis]
-            for axis in self.unsafe_get("dataFormat.outputAxisOrder")
+            axis_lengths[axis] for axis in self.unsafe_get("dataFormat.outputAxisOrder")
         )
 
     def process_data(
@@ -232,7 +126,6 @@
         train_id,
         image_data,
         cell_table,
-        do_generate_preview,
     ):
         """Subclass must define data processing (presumably using the kernel runner).
         Will be called by input_handler, which will take care of some common checks and
@@ -249,11 +142,6 @@
             INPUT_CHANNEL(expected).key("dataInput").commit(),
 
-            OUTPUT_CHANNEL(expected)
-            .key("dataOutput")
-            .dataSchema(output_schema)
-            .commit(),
-
             VECTOR_STRING_ELEMENT(expected)
             .key("fastSources")
             .displayedName("Fast data sources")
@@ -282,7 +170,7 @@
             STRING_ELEMENT(expected)
             .key("frameFilter.type")
-            .displayedName("Filter definition type")
+            .displayedName("Type")
            .description(
                 "Controls how frameFilter.spec is used. The default value of 'none' "
                 "means that no filter is set (regardless of frameFilter.spec). "
@@ -298,6 +186,7 @@
             STRING_ELEMENT(expected)
             .key("frameFilter.spec")
+            .displayedName("Specification")
             .assignmentOptional()
             .defaultValue("")
             .reconfigurable()
@@ -334,6 +223,31 @@
             .readOnly()
             .initialValue(["overview", "constant_overrides"])
             .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("runAsStandaloneModule")
+            .displayedName("Standalone mode")
+            .description(
+                "If enabled, full corrected data (not using shared memory handles) "
+                "will be sent on the main output, and preview outputs will be "
+                "configured for use directly in the Karabo GUI rather than through "
+                "an assembler."
+            )
+            .assignmentOptional()
+            .defaultValue(False)
+            .commit(),
+
+            BOOL_ELEMENT(expected)
+            .key("useInfiniband")
+            .displayedName("Use infiniband")
+            .description(
+                "If enabled, the device will, during initialization, try to bind "
+                "its main data output channel (dataOutput) to its node's "
+                "infiniband interface. The default interface is used if no 'ib0' "
+                "interface is found."
+            )
+            .assignmentOptional()
+            .defaultValue(True)
+            .commit(),
         )
 
         (
@@ -357,6 +271,21 @@
             .reconfigurable()
             .commit(),
 
+            UINT64_ELEMENT(expected)
+            .key("dataFormat.trainFromFutureThreshold")
+            .displayedName("Spurious future train ID threshold")
+            .description(
+                "Some detectors occasionally send a train with an incorrect and "
+                "much too large train ID. To prevent these 'future trains' from "
+                "interfering with train matching, use this threshold to discard "
+                "them immediately. If a train arrives with an ID which exceeds "
+                "the current train ID from the time server by more than this "
+                "threshold, the train is ignored."
+            )
+            .assignmentOptional()
+            .defaultValue(10000)
+            .reconfigurable()
+            .commit(),
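The threshold semantics described here reduce to a simple guard in the input handler (see `input_handler` further down); a standalone sketch with made-up numbers:

```python
import numpy as np

def train_is_from_future(train_id, current_train_id, threshold):
    """True if the incoming train ID is implausibly far ahead of the
    current train ID from the time server (such data is dropped)."""
    return train_id > current_train_id + threshold

assert train_is_from_future(1_000_000_000, 1_000, threshold=10_000)

# With no time server connection, the device sets the threshold to
# np.iinfo(np.uint64).max, effectively disabling the check:
assert not train_is_from_future(1_000_000_000, 0, np.iinfo(np.uint64).max)
```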
 
             STRING_ELEMENT(expected)
             .key("dataFormat.inputImageDtype")
             .displayedName("Input image data dtype")
@@ -400,9 +329,9 @@
             .commit(),
 
             UINT32_ELEMENT(expected)
-            .key("dataFormat.memoryCells")
-            .displayedName("Memory cells")
-            .description("Full number of memory cells in incoming data")
+            .key("dataFormat.frames")
+            .displayedName("Frames")
+            .description("Number of image frames per train in incoming data")
             .assignmentOptional()
             .defaultValue(1)  # subclass will want to set a default value
             .commit(),
@@ -420,13 +349,13 @@
             .displayedName("Output axis order")
             .description(
                 "Axes of main data output can be reordered after correction. Axis "
-                "order is specified as string consisting of 'x', 'y', and 'c', with "
-                "the latter indicating the memory cell axis. The default value of "
-                "'cxy' puts pixels on the fast axes."
+                "order is specified as a string consisting of 'x', 'y', and 'f', "
+                "with 'f' indicating the image frame axis. The default value of "
+                "'fxy' puts pixels on the fast axes."
             )
-            .options("cxy,cyx,xcy,xyc,ycx,yxc")
+            .options("fxy,fyx,xfy,xyf,yfx,yxf")
             .assignmentOptional()
-            .defaultValue("cxy")
+            .defaultValue("fxy")
             .reconfigurable()
             .commit(),
 
@@ -435,7 +364,7 @@
             .displayedName("Input data shape")
             .description(
                 "Image data shape in incoming data (from reader / DAQ). This value is "
-                "computed from pixelsX, pixelsY, and memoryCells - this field just "
+                "computed from pixelsX, pixelsY, and frames - this field just "
                 "shows what is currently expected."
             )
             .readOnly()
@@ -461,9 +390,9 @@
             .displayedName("Load most recent constants")
             .description(
                 "Calling this slot will flush all constant buffers and cause the "
-                "device to start querying CalCat for the most recent constants - all "
-                "constants applicable for this device - available with the currently "
-                "set constant parameters. This is typically called after "
+                "device to start querying CalCat for the most recent constants (all "
+                "constants applicable for this device) available with the currently "
+                "set constant parameters. This should typically be called after "
                 "instantiating pipeline, after changing parameters, or after "
                 "generating new constants."
             )
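The 'fxy'-style axis-order strings above map directly onto an ndarray transpose; a sketch of how an order string translates into shapes (`transpose_order` here mimics the helper used by the kernel runners, it is not the calng implementation):

```python
import numpy as np

def transpose_order(current_order, wanted_order):
    # e.g. ("fxy", "yxf") -> (2, 1, 0)
    return tuple(current_order.index(axis) for axis in wanted_order)

frames, pixels_x, pixels_y = 352, 128, 512
data = np.zeros((frames, pixels_x, pixels_y), dtype=np.float32)  # "fxy"
reordered = np.transpose(data, transpose_order("fxy", "yxf"))
assert reordered.shape == (pixels_y, pixels_x, frames)
```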
@@ -471,24 +400,42 @@
         (
-            NODE_ELEMENT(expected).key("preview").displayedName("Preview").commit(),
+            STRING_ELEMENT(expected)
+            .key("inputDataState")
+            .setSpecialDisplayType("State")
+            .readOnly()
+            .initialValue("NORMAL")
+            .commit(),
+
+            STRING_ELEMENT(expected)
+            .key("deviceInternalsState")
+            .setSpecialDisplayType("State")
+            .readOnly()
+            .initialValue("NORMAL")
+            .commit(),
+
+            STRING_ELEMENT(expected)
+            .key("processingState")
+            .setSpecialDisplayType("State")
+            .readOnly()
+            .initialValue("NORMAL")
+            .commit(),
+        )
+
+        (
+            NODE_ELEMENT(expected)
+            .key("preview")
+            .displayedName("Preview")
+            .commit(),
 
             OUTPUT_CHANNEL(expected)
             .key("preview.outputRaw")
-            .dataSchema(preview_schema)
+            .dataSchema(schemas.preview_schema())
             .commit(),
 
             OUTPUT_CHANNEL(expected)
             .key("preview.outputCorrected")
-            .dataSchema(preview_schema)
-            .commit(),
-
-            BOOL_ELEMENT(expected)
-            .key("preview.enable")
-            .displayedName("Enable preview")
-            .assignmentOptional()
-            .defaultValue(True)
-            .reconfigurable()
+            .dataSchema(schemas.preview_schema())
             .commit(),
 
             INT32_ELEMENT(expected)
@@ -498,7 +445,7 @@
                 "If this value is ≥ 0, the corresponding index (frame, cell, or pulse) "
                 "will be sliced for the preview output. If this value is < 0, preview "
                 "will be one of the following stats: -1: max, -2: mean, -3: sum, -4: "
-                "stdev. These stats are computed across memory cells."
+                "stdev. These stats are computed across frames, ignoring NaN values."
             )
             .assignmentOptional()
             .defaultValue(0)
@@ -516,24 +463,13 @@
                 "at cell (or pulse) table for the requested cell (or pulse ID). "
                 "Special (stat) index values <0 are not affected by this."
             )
-            .options("frame,cell,pulse")
-            .assignmentOptional()
-            .defaultValue("frame")
-            .reconfigurable()
-            .commit(),
-
-            UINT32_ELEMENT(expected)
-            .key("preview.trainIdModulo")
-            .displayedName("Preview train stride")
-            .description(
-                "Preview will only be generated for trains whose ID modulo this "
-                "number is zero. Higher values means less frequent preview updates. "
-                "Keep in mind that the GUI has limited refresh rate. Extra care must "
-                "be taken if DAQ train stride is >1."
+ .options( + ",".join( + spectype.value for spectype in utils.PreviewIndexSelectionMode + ) ) - .unit(Unit.COUNT) .assignmentOptional() - .defaultValue(1) + .defaultValue("frame") .reconfigurable() .commit(), ) @@ -598,8 +534,31 @@ class BaseCorrection(PythonDevice): .displayedName("Correction steps") .commit(), ) + ( + VECTOR_STRING_ELEMENT(expected) + .key("warningLamps") + .assignmentOptional() + .defaultValue( + [ + "deviceInternalsState", + "inputDataState", + "processingState", + ] + ) + .commit() + ) def __init__(self, config): + if config.get("useInfiniband", default=True): + from PipeToZeroMQ.utils import find_infiniband_ip + + ib_ip = find_infiniband_ip() + config["dataOutput.hostname"] = ib_ip + if not config.get("runAsStandaloneModule", default=False): + for key in config.get("preview").getKeys(): + path = f"preview.{key}" + if key.startswith("output") and config.has(f"{path}.hostname"): + config[f"{path}.hostname"] = ib_ip super().__init__(config) self.input_data_dtype = np.dtype(config["dataFormat.inputImageDtype"]) @@ -609,43 +568,72 @@ class BaseCorrection(PythonDevice): self.kernel_runner = None # must call _update_buffers to initialize self._shmem_buffer = None # ditto + self._preview_friend = None # used in standalone mode (see JungfrauCorrection) self._correction_flag_enabled = self._correction_flag_class.NONE self._correction_flag_preview = self._correction_flag_class.NONE + self._correction_applied_hash = Hash() + # note: does not handle one constant enabling multiple correction steps + # (do we need to generalize "decide if this correction step is available"?) + self._constant_to_correction_name = { + constant: name + for (name, _, constants) in self._correction_steps + for constant in constants + } self._buffer_lock = threading.Lock() self._last_processing_started = 0 # used for processing time and timeout # register slots - # TODO: the CalCatFriend could add these for us - # note: overly complicated for closure to work - def make_wrapper_1(constant): - def aux(): - self.calcat_friend.get_overridden_constant_version_and_call_me_back( - constant, self._load_constant_to_runner - ) - - return aux - - def make_wrapper_2(constant): + def constant_override_fun(friend_fun, constant, preserve_fields): def aux(): - self.calcat_friend.get_overridden_constant_from_file_and_call_me_back( - constant, self._load_constant_to_runner + self.flush_constants( + constants={constant}, preserve_fields=preserve_fields ) - - return aux + with self.warning_context( + f"foundConstants.{constant.name}.state", on_success="ON" + ) as warn: + try: + constant_data = getattr(self.calcat_friend, friend_fun)( + constant + ) + self._load_constant_to_runner(constant, constant_data) + except Exception as e: + warn(f"Failed to load {constant}: {e}") + else: + self._successfully_loaded_constant_to_runner(constant) + # note: consider if multi-constant buffers need special care + self._lock_and_update(aux) for constant in self._constant_enum_class: - slot_name = f"foundConstants.{constant.name}.overrideConstantVersion" - meth_name = slot_name.replace(".", "_") self.KARABO_SLOT( - make_wrapper_1(constant), - slotName=meth_name, + functools.partial( + constant_override_fun, + friend_fun="get_constant_version", + constant=constant, + preserve_fields=set(), + ), + slotName=f"foundConstants_{constant.name}_loadMostRecent", + numArgs=0, ) - slot_name = f"foundConstants.{constant.name}.overrideConstantFromFile" - meth_name = slot_name.replace(".", "_") self.KARABO_SLOT( - make_wrapper_2(constant), - 
slotName=meth_name, + functools.partial( + constant_override_fun, + friend_fun="get_constant_from_constant_version_id", + constant=constant, + preserve_fields={"constantVersionId"}, + ), + slotName=f"foundConstants_{constant.name}_overrideConstantFromVersion", + numArgs=0, + ) + self.KARABO_SLOT( + functools.partial( + constant_override_fun, + friend_fun="get_constant_from_file", + constant=constant, + preserve_fields={"dataFilePath", "dataSetName"}, + ), + slotName=f"foundConstants_{constant.name}_overrideConstantFromFile", + numArgs=0, ) self.KARABO_SLOT(self.loadMostRecentConstants) @@ -653,15 +641,69 @@ class BaseCorrection(PythonDevice): self.registerInitialFunction(self._initialization) - def _initialization(self): - self.calcat_friend = self._calcat_friend_class( - self, pathlib.Path.cwd() / "calibration-client-secrets.json" - ) + @contextlib.contextmanager + def warning_context( + self, + schema_key, + warn_type=None, + on_success="NORMAL", + on_error="ERROR", + reraise=True, + only_print_once=False, + ): + tracker = self._warning_trackers[schema_key] + warn_fun = tracker.new_context(warn_type, only_print_once) try: - self._frame_filter = _parse_frame_filter(self._parameters) - except (ValueError, TypeError): - self.log_status_warn("Failed to parse initial frame filter, will not use") - self._frame_filter = None + yield warn_fun + except Exception as e: + warn_fun(f"Exception happened for {schema_key}, {warn_type}: {e}") + if reraise: + raise e + finally: + tracker.update_state(on_success=on_success, on_error=on_error) + + def _initialization(self): + self._warning_trackers = { + key: utils.ContextWarningLamp(self, key) + for key in self.getFullSchema().getDefaultValue("warningLamps") + } + for constant in self._constant_enum_class: + key = f"foundConstants.{constant.name}.state" + self._warning_trackers[key] = utils.ContextWarningLamp(self, key) + + with self.warning_context( + "deviceInternalsState", WarningLampType.CALCAT_CONNECTION + ) as warn: + try: + self.calcat_friend = self._calcat_friend_class( + self, pathlib.Path.cwd() / "calibration-client-secrets.json" + ) + except Exception as e: + warn(f"Failed to connect to CalCat: {e}") + # TODO: add raw fallback mode if CalCat fails (raw data still useful) + return + + # check time server connection + with self.warning_context( + "deviceInternalsState", WarningLampType.TIMESERVER_CONNECTION + ) as warn: + if self.getActualTimestamp().getTrainId() == 0: + warn( + "Warning: likely missing connection to time server, " + "cannot threshold against future train IDs" + ) + self.set("dataFormat.trainFromFutureThreshold", np.iinfo(np.uint64).max) + + with self.warning_context( + "processingState", WarningLampType.FRAME_FILTER + ) as warn: + try: + self._frame_filter = _parse_frame_filter(self._parameters) + except (ValueError, TypeError): + warn("Failed to parse initial frame filter, will not use") + self._frame_filter = None + # TODO: decide how persistent initial warning should be + self._update_correction_flags() self._update_frame_filter() self._buffered_status_update = Hash( @@ -680,9 +722,7 @@ class BaseCorrection(PythonDevice): interval=1, callback=self._update_rate_and_state, ) - self._train_ratio_tracker = utils.TrainRatioTracker( - warn_callback=self.log_status_warn - ) + self._train_ratio_tracker = utils.TrainRatioTracker() self.KARABO_ON_INPUT("dataInput", self.input_handler) self.KARABO_ON_EOS("dataInput", self.handle_eos) @@ -727,7 +767,7 @@ class BaseCorrection(PythonDevice): update = self._prereconfigure_update_hash if 
update.has("frameFilter"): - self._lock_and_update_in_background(self._update_frame_filter) + self._lock_and_update(self._update_frame_filter) elif any( update.has(shape_param) for shape_param in ( @@ -735,35 +775,86 @@ class BaseCorrection(PythonDevice): "dataFormat.pixelsY", "dataFormat.outputImageDtype", "dataFormat.outputAxisOrder", - "dataFormat.memoryCells", + "dataFormat.frames", "constantParameters.memoryCells", "frameFilter", ) ): - self._lock_and_update_in_background(self._update_buffers) - # TODO: only call this if they are changed (is cheap, though) - self._update_correction_flags() + self._lock_and_update(self._update_buffers) + + if any( + ( + update.has(f"corrections.{field_name}.enable") + or update.has(f"corrections.{field_name}.preview") + ) + for field_name, *_ in self._correction_steps + ): + self._update_correction_flags() - def _lock_and_update_in_background(self, method): + def _lock_and_update(self, method, background=True): # TODO: securely handle errors (postReconfigure may succeed, device state not) def runner(): with self._buffer_lock, utils.StateContext(self, State.CHANGING): method() - threading.Thread(target=runner, daemon=True).start() - def loadMostRecentConstants(self): - self.flush_constants() - self.calcat_friend.flush_constants() - for constant in self._constant_enum_class: - self.calcat_friend.get_constant_version_and_call_me_back( - constant, self._load_constant_to_runner - ) + if background: + threading.Thread(target=runner, daemon=True).start() + else: + runner() - def flush_constants(self): - """Reset constant buffers and disable corresponding correction steps""" - for correction_step, _ in self._correction_field_names: - self.set(f"corrections.{correction_step}.available", False) - self.kernel_runner.flush_buffers() + def loadMostRecentConstants(self): + def aux(): + self.flush_constants() + # TODO: ignore irrelevant constants (like agpid thresholding in fixed gain) + with concurrent.futures.ThreadPoolExecutor() as executor: + future_to_constant = { + executor.submit( + self.calcat_friend.get_constant_version, + constant, + ): constant + for constant in self._constant_enum_class + } + for future in concurrent.futures.as_completed(future_to_constant): + constant = future_to_constant[future] + with self.warning_context( + f"foundConstants.{constant.name}.state", on_success="ON" + ) as warn: + try: + constant_data = future.result() + self._load_constant_to_runner(constant, constant_data) + except Exception as e: + warn(f"Failed to load {constant}: {e}") + else: + self._successfully_loaded_constant_to_runner(constant) + self._lock_and_update(aux) + + def flush_constants(self, constants=None, preserve_fields=None): + """Reset constant buffers; the correction maps (often buffers on GPU) should be + set to identity values and constant caches cleared. By default do this for all + constants - a subset can instead be specified via the constants parameter + (should be iterable of constant enum values). + + After flushing, the kernel runner will be instructed to reload any remaining + cached constants after flushing. This can take a second. + + The parameter preserve_fields is passed to the CalCat friend; this is used for + overrides; in those cases, we want to get value from an overridden field after + flushing everything else, so don't wipe it out. 
+ """ + if constants is None: + constants = set(self._constant_enum_class) + else: + constants = set(constants) + self.kernel_runner.flush_buffers(constants) + for field_name, _, used_constants in self._correction_steps: + if ( + used_constants + and None not in used_constants + and constants >= used_constants + ): + self.set(f"corrections.{field_name}.available", False) + self.calcat_friend.flush_constants(constants, preserve_fields) + self._reload_constants_from_cache_to_runner(constants) self._update_correction_flags() def log_status_info(self, msg): @@ -774,11 +865,6 @@ class BaseCorrection(PythonDevice): self.log.WARN(msg) self.set("status", msg) - def log_status_error(self, msg): - self.set("status", msg) - self.log.ERROR(msg) - self.updateState(State.ERROR) - def requestScene(self, params): payload = Hash() name = params.get("name", default="") @@ -788,15 +874,23 @@ class BaseCorrection(PythonDevice): payload["data"] = scenes.correction_device_overview( device_id=self.getInstanceId(), schema=self.getFullSchema(), + direct_preview=self.get("runAsStandaloneModule"), ) elif name == "constant_overrides": payload["data"] = scenes.correction_device_constant_overrides( device_id=self.getInstanceId(), schema=self.getFullSchema(), ) + elif name.startswith("preview:"): + channel_name = name[len("preview:") :] + payload["data"] = scenes.correction_device_preview( + device_id=self.getInstanceId(), + schema=self.getFullSchema(), + preview_channel=channel_name, + ) elif name.startswith("browse_schema"): if ":" in name: - prefix = name[len("browse_schema:"):] + prefix = name[len("browse_schema:") :] else: prefix = "managed" payload["data"] = scenes.recursive_subschema_scene( @@ -819,13 +913,14 @@ class BaseCorrection(PythonDevice): old_metadata.get("source"), Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")), ) + data["corrections"] = self._correction_applied_hash channel = self.signalSlotable.getOutputChannel("dataOutput") - channel.write(data, metadata, False) + channel.write(data, metadata, copyAllData=False) channel.update() def _write_preview_outputs(self, channel_data_pairs, old_metadata): - # TODO: allow sending *all* frames for commissioning (request: Jola) + # consider allowing sending *all* frames for commissioning (request: Jola) timestamp = Timestamp.fromHashAttributes( old_metadata.getAttributes("timestamp") ) @@ -835,25 +930,25 @@ class BaseCorrection(PythonDevice): for channel_name, data in channel_data_pairs: preview_hash.set("image.data", data) channel = self.signalSlotable.getOutputChannel(channel_name) - channel.write(preview_hash, metadata, False) + channel.write(preview_hash, metadata, copyAllData=False) channel.update() def _update_correction_flags(self): """Based on constants loaded and settings, update bit mask flags for kernel""" - available = self._correction_flag_class.NONE enabled = self._correction_flag_class.NONE + output = Hash() # for informing downstream receivers what was applied preview = self._correction_flag_class.NONE - for field_name, flag in self._correction_field_names: + for field_name, flag, constants in self._correction_steps: + output[field_name] = False if self.get(f"corrections.{field_name}.available"): - available |= flag - if self.get(f"corrections.{field_name}.enable"): - enabled |= flag - if self.get(f"corrections.{field_name}.preview"): - preview |= flag - enabled &= available - preview &= available + if self.get(f"corrections.{field_name}.enable"): + enabled |= flag + output[field_name] = True + if 
self.get(f"corrections.{field_name}.preview"): + preview |= flag self._correction_flag_enabled = enabled self._correction_flag_preview = preview + self._correction_applied_hash = output self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}") self.log.DEBUG(f"Corrections for preview: {str(preview)}") @@ -865,17 +960,22 @@ class BaseCorrection(PythonDevice): self.log.DEBUG("Updating frame filter") if self._frame_filter is None: - self.set("dataFormat.filteredFrames", self.get("dataFormat.memoryCells")) + self.set("dataFormat.filteredFrames", self.get("dataFormat.frames")) self.set("frameFilter.current", []) else: self.set("dataFormat.filteredFrames", self._frame_filter.size) self.set("frameFilter.current", list(map(int, self._frame_filter))) - if self._frame_filter is not None and ( - self._frame_filter.min() < 0 - or self._frame_filter.max() >= self.get("dataFormat.memoryCells") - ): - self.log_status_warn("Invalid frame filter set, expect exceptions!") + with self.warning_context( + "deviceInternalsState", WarningLampType.FRAME_FILTER + ) as warn: + if self._frame_filter is not None and ( + self._frame_filter.min() < 0 + or self._frame_filter.max() >= self.get("dataFormat.frames") + ): + warn("Invalid frame filter set, expect exceptions!") + # note: dataFormat.frames could change from detector and + # that could make this warning invalid if update_buffers: self._update_buffers() @@ -891,7 +991,7 @@ class BaseCorrection(PythonDevice): if self._shmem_buffer is None: shmem_buffer_name = self.getInstanceId() + ":dataOutput" - memory_budget = self.get("outputShmemBufferSize") * 2 ** 30 + memory_budget = self.get("outputShmemBufferSize") * 2**30 self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") self._shmem_buffer = shmem_utils.ShmemCircularBuffer( memory_budget, @@ -922,15 +1022,32 @@ class BaseCorrection(PythonDevice): **self._kernel_runner_init_args, ) + self._reload_constants_from_cache_to_runner() + + def _reload_constants_from_cache_to_runner(self, constants=None): # TODO: gracefully handle change in constantParameters.memoryCells - # TODO: offload to thread to avoid reconfigure of framefilter timing out + if constants is None: + constants = set(self._constant_enum_class) + else: + constants = set(constants) + with self.calcat_friend.cached_constants_lock: for ( constant, data, ) in self.calcat_friend.cached_constants.items(): - self.log_status_info(f"Reload constant {constant}") - self._load_constant_to_runner(constant, data) + if constant not in constants: + continue + with self.warning_context( + f"foundConstants.{constant.name}.state", on_success="ON" + ) as warn: + try: + self.log_status_info(f"Reload constant {constant.name}") + self._load_constant_to_runner(constant, data) + except Exception as e: + warn(f"Failed to reload {constant.name}: {e}") + else: + self._successfully_loaded_constant_to_runner(constant) def input_handler(self, input_channel): """Main handler for data input: Do a few simple checks to determine whether to @@ -942,32 +1059,47 @@ class BaseCorrection(PythonDevice): if state is State.ERROR: # in this case, we should have already issued warning return - elif self.kernel_runner is None: - self.log_status_warn("Received data, but have not initialized kernels yet") - return + + with self.warning_context( + "deviceInternalsState", WarningLampType.CORRECTION_RUNNER + ) as warn: + if self.kernel_runner is None: + warn("Received data, but have not initialized kernels yet") + return all_metadata = input_channel.getMetaData() for input_index in 
range(input_channel.size()): - self._last_processing_started = default_timer() - data_hash = input_channel.read(input_index) - metadata = all_metadata[input_index] - source = metadata.get("source") - - if source not in self.sources: - self.log_status_info(f"Ignoring hash with unknown source {source}") - continue - elif not data_hash.has(self._image_data_path): - self.log_status_info("Ignoring hash without image node") - continue - - try: - image_data = np.asarray(data_hash.get(self._image_data_path)) - cell_table = np.asarray(data_hash.get(self._cell_table_path)).ravel() - except RuntimeError as err: - self.log_status_info( - f"Failed to load image data; probably empty hash from DAQ: {err}" - ) - continue + with self.warning_context( + "inputDataState", WarningLampType.MISC_INPUT_DATA + ) as warn: + self._last_processing_started = default_timer() + data_hash = input_channel.read(input_index) + metadata = all_metadata[input_index] + source = metadata.get("source") + + if source not in self.sources: + warn(f"Ignoring hash with unknown source {source}") + continue + elif not data_hash.has(self._image_data_path): + warn("Ignoring hash without image node") + continue + + with self.warning_context( + "inputDataState", + WarningLampType.EMPTY_HASH, + only_print_once=True, + ) as warn: + try: + image_data = np.asarray(data_hash.get(self._image_data_path)) + cell_table = np.asarray( + data_hash.get(self._cell_table_path) + ).ravel() + except RuntimeError as err: + warn( + "Failed to load image data; " + f"probably empty hash from DAQ: {err}" + ) + continue # no more common reasons to skip input, so go to processing if state is State.ON: @@ -975,21 +1107,48 @@ class BaseCorrection(PythonDevice): self.log_status_info("Processing data") train_id = metadata.getAttribute("timestamp", "tid") - self._train_ratio_tracker.update(train_id) - - if ( - self._warn_memory_cell_range - and self.unsafe_get("constantParameters.memoryCells") - <= cell_table.max() - ): - self.log_status_warn("Input cell IDs out of range of constants") - - if cell_table.size != self.unsafe_get("dataFormat.memoryCells"): + my_tid = self.getActualTimestamp().getTrainId() + with self.warning_context( + "inputDataState", WarningLampType.TRAIN_ID + ) as warn: + if train_id > ( + my_tid + self.unsafe_get("dataFormat.trainFromFutureThreshold") + ): + warn( + f"Suspecting train from the future: now is {my_tid}, " + f"received train ID {train_id}, dropping data" + ) + continue + try: + self._train_ratio_tracker.update(train_id) + self._buffered_status_update.set( + "performance.ratioOfRecentTrainsReceived", + self._train_ratio_tracker.get(), + ) + except utils.NonMonotonicTrainIdWarning as ex: + warn( + f"Train ratio tracker noticed issue with train ID: {ex}\n" + f"For the record, I think now is: {my_tid}" + ) + self._train_ratio_tracker.reset() + self._train_ratio_tracker.update(train_id) + + with self.warning_context( + "processingState", WarningLampType.MEMORY_CELL_RANGE + ) as warn: + if ( + self._warn_memory_cell_range + and self.unsafe_get("constantParameters.memoryCells") + <= cell_table.max() + ): + warn("Input cell IDs out of range of constants") + + if cell_table.size != self.unsafe_get("dataFormat.frames"): self.log_status_info( f"Updating new input shape to account for {cell_table.size} cells" ) - self.set("dataFormat.memoryCells", cell_table.size) - self._lock_and_update_in_background(self._update_frame_filter) + self.set("dataFormat.frames", cell_table.size) + self._lock_and_update(self._update_frame_filter, background=False) # 
DataAggregator typically tells us the wrong axis order if self.unsafe_get("dataFormat.overrideInputAxisOrder"): @@ -997,11 +1156,6 @@ class BaseCorrection(PythonDevice): if expected_shape != image_data.shape: image_data.shape = expected_shape - do_generate_preview = ( - train_id % self.unsafe_get("preview.trainIdModulo") == 0 - and self.unsafe_get("preview.enable") - ) - with self._buffer_lock: self.process_data( data_hash, @@ -1010,28 +1164,25 @@ class BaseCorrection(PythonDevice): train_id, image_data, cell_table, - do_generate_preview, ) self._buffered_status_update.set("trainId", train_id) self._processing_time_ema.update( default_timer() - self._last_processing_started ) + self._buffered_status_update.set( + "performance.processingTime", self._processing_time_ema.get() * 1000 + ) self._rate_tracker.update() def _update_rate_and_state(self): if self.get("state") is State.PROCESSING: + # always update rate: estimate depends on query time self._buffered_status_update.set( "performance.rate", self._rate_tracker.get() ) - self._buffered_status_update.set( - "performance.processingTime", self._processing_time_ema.get() * 1000 - ) - self._buffered_status_update.set( - "performance.ratioOfRecentTrainsReceived", - self._train_ratio_tracker.get(), - ) - # trainId in _buffered_status_update should be updated in input handler + # remaining stats are set _buffered_status_update in input handler self.set(self._buffered_status_update) + self._buffered_status_update.clear() if ( default_timer() - self._last_processing_started > PROCESSING_STATE_TIMEOUT @@ -1063,7 +1214,7 @@ if not hasattr(BaseCorrection, "unsafe_get"): setattr(BaseCorrection, "unsafe_get", unsafe_get) -def add_correction_step_schema(schema, managed_keys, field_flag_mapping): +def add_correction_step_schema(schema, managed_keys, field_flag_constants_mapping): """Using the fields in the provided mapping, will add nodes to schema field_flag_mapping is assumed to be iterable of pairs where first entry in each @@ -1074,12 +1225,12 @@ def add_correction_step_schema(schema, managed_keys, field_flag_mapping): This method should be called in expectedParameters of subclass after the same for BaseCorrection has been called. Would be nice to include in BaseCorrection instead, - but that is tricky: static method of superclass will need _correction_field_names + but that is tricky: static method of superclass will need _correction_steps of subclass or device server gets mad. A nice solution with classmethods would be welcome. """ - for field_name, _ in field_flag_mapping: + for field_name, _, used_constants in field_flag_constants_mapping: node_name = f"corrections.{field_name}" ( NODE_ELEMENT(schema).key(node_name).commit(), @@ -1093,7 +1244,8 @@ def add_correction_step_schema(schema, managed_keys, field_flag_mapping): "correction will have no effect unless this is True." ) .readOnly() - .initialValue(False) + # some corrections available without constants + .initialValue(not used_constants or None in used_constants) .commit(), BOOL_ELEMENT(schema) @@ -1124,6 +1276,53 @@ def add_correction_step_schema(schema, managed_keys, field_flag_mapping): managed_keys.add(f"{node_name}.preview") +def add_bad_pixel_config_node(schema, managed_keys, prefix="corrections.badPixels"): + ( + STRING_ELEMENT(schema) + .key("corrections.badPixels.maskingValue") + .displayedName("Bad pixel masking value") + .description( + "Any pixels masked by the bad pixel mask will have their value replaced " + "with this. 
Note that this parameter is to be interpreted as a " + "numpy.float32; use 'nan' to get NaN value." + ) + .assignmentOptional() + .defaultValue("nan") + .reconfigurable() + .commit(), + NODE_ELEMENT(schema) + .key("corrections.badPixels.subsetToUse") + .displayedName("Bad pixel flags to use") + .description( + "The booleans under this node allow for selecting a subset of bad pixel " + "types to take into account when doing bad pixel masking. Upon updating " + "these flags, the map used for bad pixel masking will be ANDed with this " + "selection. Turning disabled flags back on causes reloading of cached " + "constants." + ) + .commit(), + ) + managed_keys.add("corrections.badPixels.maskingValue") + for field in utils.BadPixelValues: + ( + BOOL_ELEMENT(schema) + .key(f"corrections.badPixels.subsetToUse.{field.name}") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit() + ) + managed_keys.add(f"corrections.badPixels.subsetToUse.{field.name}") + + +def get_bad_pixel_field_selection(self): + selection = 0 + for field in utils.BadPixelValues: + if self.get(f"corrections.badPixels.subsetToUse.{field.name}"): + selection |= field + return selection + + def _parse_frame_filter(config): filter_type = FramefilterSpecType(config["frameFilter.type"]) filter_string = config["frameFilter.spec"] diff --git a/src/calng/base_geometry.py b/src/calng/base_geometry.py index 0b0485f55d9da88cc4a6e07c7b2ceed2ec83f095..c4964c787dcb1f24f2db251ff8ed305884422d97 100644 --- a/src/calng/base_geometry.py +++ b/src/calng/base_geometry.py @@ -370,6 +370,7 @@ class ManualQuadrantsGeometryBase(ManualGeometryBase): if name == "overview": scene_data = scenes.quadrant_geometry_overview( self.deviceId, + self.getDeviceSchema(), ) payload = Hash("success", True, "name", name, "data", scene_data) @@ -436,6 +437,7 @@ class ManualModuleListGeometryBase(ManualGeometryBase): # Assumes there are correction devices known to manager scene_data = scenes.modules_geometry_overview( self.deviceId, + self.getDeviceSchema(), ) payload = Hash("success", True, "name", name, "data", scene_data) diff --git a/src/calng/base_gpu.py b/src/calng/base_gpu.py deleted file mode 100644 index c0619d0ae253b7cb736dbc2fe7275092c7493a3c..0000000000000000000000000000000000000000 --- a/src/calng/base_gpu.py +++ /dev/null @@ -1,195 +0,0 @@ -import pathlib - -import cupy -import jinja2 -import numpy as np - -from . import utils - - -class BaseGpuRunner: - """Class to handle GPU buffers and execution of CUDA kernels on image data - - All GPU buffers are kept within this class and it is intentionally very stateful. - This generally means that you will want to load data into it and then do something. - Typical usage in correct order: - - 1. instantiate - 2. load constants - 3. load_data - 4. load_cell_table - 5. correct - 6a. reshape (only here does data transfer back to host) - 6b. compute_preview (optional) - - repeat from 2. or 3. - - In case no constants are available / correction is not desired, can skip 3 and 4 and - pass CorrectionFlags.NONE to correct(...). Generally, user must handle which - correction steps are appropriate given the constants loaded so far. 
- """ - - # These must be set by subclass - _kernel_source_filename = None - _corrected_axis_order = None - - def __init__( - self, - pixels_x, - pixels_y, - memory_cells, - constant_memory_cells, - input_data_dtype=np.uint16, - output_data_dtype=np.float32, - ): - _src_dir = pathlib.Path(__file__).absolute().parent - # subclass must define _kernel_source_filename - with (_src_dir / "kernels" / self._kernel_source_filename).open("r") as fd: - self._kernel_template = jinja2.Template(fd.read()) - - self.pixels_x = pixels_x - self.pixels_y = pixels_y - self.memory_cells = memory_cells - if constant_memory_cells == 0: - # if not set, guess same as input; may save one recompilation - self.constant_memory_cells = memory_cells - else: - self.constant_memory_cells = constant_memory_cells - # preview will only be single memory cell - self.preview_shape = (self.pixels_x, self.pixels_y) - self.input_data_dtype = input_data_dtype - self.output_data_dtype = output_data_dtype - - self._init_kernels() - - # reuse buffers for input / output - self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=np.uint16) - self.input_data_gpu = cupy.empty(self.input_shape, dtype=input_data_dtype) - self.processed_data_gpu = cupy.empty( - self.processed_shape, dtype=output_data_dtype - ) - self.reshaped_data_gpu = None # currently not reusing buffer - - # default preview layers: raw and corrected (subclass can extend) - self.preview_buffer_getters = [ - self._get_raw_for_preview, - self._get_corrected_for_preview, - ] - - # to get data from respective buffers to cell, x, y shape for preview computation - def _get_raw_for_preview(self): - """Should return view of self.input_data_gpu with shape (cell, x/y, x/y)""" - raise NotImplementedError() - - def _get_corrected_for_preview(self): - """Should return view of self.processed_data_gpu with shape (cell, x/y, x/y)""" - raise NotImplementedError() - - def flush_buffers(self): - """Optional reset GPU buffers (implement in subclasses which need this)""" - pass - - def correct(self, flags): - """Correct (already loaded) image data according to flags - - Subclass must define this method. It should assume that image data, cell table, - and other data (including constants) has already been loaded. It should - probably run some GPU kernel and output should go into self.processed_data_gpu. - - Keep in mind that user only gets output from compute_preview or reshape - (either of these should come after correct). - - The submodules providing subclasses should have some IntFlag enums defining - which flags are available to pass along to the kernel. A zero flag should allow - the kernel to do no actual correction - but still copy the data between buffers - and cast it to desired output type. - - """ - raise NotImplementedError() - - def reshape(self, output_order, out=None): - """Move axes to desired output order and copy to host memory - - The out parameter is passed directly to the get function of GPU array: if - None, then a new ndarray (in host memory) is returned. If not None, then data - will be loaded into the provided array, which must match shape / dtype. 
- """ - # TODO: avoid copy - if output_order == self._corrected_axis_order: - self.reshaped_data_gpu = self.processed_data_gpu - else: - self.reshaped_data_gpu = cupy.transpose( - self.processed_data_gpu, - utils.transpose_order(self._corrected_axis_order, output_order), - ) - - return self.reshaped_data_gpu.get(out=out) - - def load_data(self, raw_data): - self.input_data_gpu.set(np.squeeze(raw_data)) - - def load_cell_table(self, cell_table): - self.cell_table_gpu.set(cell_table) - - def compute_previews(self, preview_index): - """Generate single slice or reduction preview of raw and corrected data - - Special values of preview_index are -1 for max, -2 for mean, -3 for sum, and - -4 for stdev (across cells). - - Note that preview_index is taken from data without checking cell table. - Caller has to figure out which index along memory cell dimension they - actually want to preview in case it needs to be a specific pulse. - - Will reuse data from corrected output buffer. Therefore, correct(...) must have - been called with the appropriate flags before compute_preview(...). - """ - - if preview_index < -4: - raise ValueError(f"No statistic with code {preview_index} defined") - elif preview_index >= self.memory_cells: - raise ValueError(f"Memory cell index {preview_index} out of range") - - # TODO: enum around reduction type - return tuple( - self._compute_a_preview(image_data=getter(), preview_index=preview_index) - for getter in self.preview_buffer_getters - ) - - def _compute_a_preview(self, image_data, preview_index): - """image_data must have cells on first axis; X and Y order is not important - here for now (and can differ between AGIPD and DSSC)""" - if preview_index >= 0: - # TODO: reuse pinned buffers for this - return image_data[preview_index].astype(np.float32).get() - elif preview_index == -1: - # TODO: confirm that max is pixel and not integrated intensity - # separate from next case because dtype not applicable here - return cupy.nanmax(image_data, axis=0).astype(cupy.float32).get() - elif preview_index in (-2, -3, -4): - stat_fun = { - -2: cupy.nanmean, - -3: cupy.nansum, - -4: cupy.nanstd, - }[preview_index] - return stat_fun(image_data, axis=0, dtype=cupy.float32).get() - - def update_block_size(self, full_block): - """Set execution grid such that it covers processed_shape with full_blocks - - Execution is scheduled with 3d "blocks" of CUDA threads. Tuning can affect - performance. Correction kernels are "monolithic" for simplicity (i.e. each - logical thread handles one entry in output data), so in each dimension we - parallelize, grid * block >= length to cover all entries. - - Note that individual kernels must themselves check whether they go out of - bounds; grid dimensions get rounded up in case ndarray size is not multiple of - block size. 
- - """ - assert len(full_block) == 3 - self.full_block = tuple(full_block) - self.full_grid = tuple( - utils.ceil_div(a_length, block_length) - for (a_length, block_length) in zip(self.processed_shape, full_block) - ) diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py index 1d0d81233725b25ae1cfe7acf89c5d1086b0afb7..c24192b5b82982b3e312b10a63171e3a082263bd 100644 --- a/src/calng/base_kernel_runner.py +++ b/src/calng/base_kernel_runner.py @@ -1,6 +1,6 @@ +import functools import pathlib -import cupy import jinja2 import numpy as np @@ -12,17 +12,17 @@ class BaseKernelRunner: self, pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype=np.uint16, output_data_dtype=np.float32, ): self.pixels_x = pixels_x self.pixels_y = pixels_y - self.memory_cells = memory_cells + self.frames = frames if constant_memory_cells == 0: # if not set, guess same as input; may save one recompilation - self.constant_memory_cells = memory_cells + self.constant_memory_cells = frames else: self.constant_memory_cells = constant_memory_cells # preview will only be single memory cell @@ -30,20 +30,13 @@ class BaseKernelRunner: self.input_data_dtype = input_data_dtype self.output_data_dtype = output_data_dtype - # default preview layers: raw and corrected (subclass can extend) - self.preview_buffer_getters = [ - self._get_raw_for_preview, - self._get_corrected_for_preview, - ] - - def correct(self, flags): """Correct (already loaded) image data according to flags Detector-specific subclass must define this method. It should assume that image data, cell table, and other data (including constants) has already been loaded. - It should probably run some GPU kernel and output should go into - self.processed_data(_gpu). + It should probably run some {CPU,GPU} kernel and output should go into + self.processed_data{,_gpu}. Keep in mind that user only gets output from compute_preview or reshape (either of these should come after correct). @@ -56,61 +49,58 @@ class BaseKernelRunner: """ raise NotImplementedError() - # to get data from respective buffers to cell, x, y shape for preview computation - def _get_raw_for_preview(self): - """Should return view of self.input_data_gpu with shape (cell, x/y, x/y)""" - raise NotImplementedError() + @property + def preview_data_views(self): + """Should provide tuple of views of data for different preview layers - as a + minimum, raw and corrected. View is expected to have axis order frame, slow + scan, fast scan with the latter two matching preview_shape (which is X and which + is Y may differ between detectors).""" - def _get_corrected_for_preview(self): - """Should return view of self.processed_data(_gpu) with shape (cell, x/y, x/y)""" raise NotImplementedError() - def flush_buffers(self): - """Optional reset GPU buffers (implement in subclasses which need this)""" + def flush_buffers(self, constants=None): + """Optional reset internal buffers (implement in subclasses which need this)""" pass def compute_previews(self, preview_index): - """Generate single slice or reduction preview of raw and corrected data + """Generate single slice or reduction previews for raw and corrected data and + any other layers, determined by self.preview_data_views Special values of preview_index are -1 for max, -2 for mean, -3 for sum, and -4 for stdev (across cells). Note that preview_index is taken from data without checking cell table. 
Caller has to figure out which index along memory cell dimension they - actually want to preview in case it needs to be a specific pulse. + actually want to preview in case it needs to be a specific cell / pulse. - Will reuse data from corrected output buffer. Therefore, correct(...) must have - been called with the appropriate flags before compute_preview(...). + Will typically reuse data from corrected output buffer. Therefore, + correct(...) must have been called with the appropriate flags before + compute_preview(...). """ - if preview_index < -4: raise ValueError(f"No statistic with code {preview_index} defined") - elif preview_index >= self.memory_cells: + elif preview_index >= self.frames: raise ValueError(f"Memory cell index {preview_index} out of range") - # TODO: enum around reduction type - return tuple( - self._compute_a_preview(image_data=getter(), preview_index=preview_index) - for getter in self.preview_buffer_getters - ) - - def _compute_a_preview(self, image_data, preview_index): - """image_data must have cells on first axis; X and Y order is not important - here for now (and can differ between AGIPD and DSSC)""" if preview_index >= 0: - # TODO: reuse pinned buffers for this - return image_data[preview_index].astype(np.float32) + def fun(a): + return a[preview_index] elif preview_index == -1: - # TODO: confirm that max is pixel and not integrated intensity - # separate from next case because dtype not applicable here - return np.nanmax(image_data, axis=0).astype(np.float32) + # note: separate from next case because dtype not applicable here + fun = functools.partial(np.nanmax, axis=0) elif preview_index in (-2, -3, -4): - stat_fun = { - -2: np.nanmean, - -3: np.nansum, - -4: np.nanstd, - }[preview_index] - return stat_fun(image_data, axis=0, dtype=np.float32) + fun = functools.partial( + { + -2: np.nanmean, + -3: np.nansum, + -4: np.nanstd, + }[preview_index], + axis=0, + dtype=np.float32, + ) + # TODO: reuse output buffers + # TODO: integrate multithreading + return (fun(in_buffer_view) for in_buffer_view in self.preview_data_views) def reshape(self, output_order, out=None): """Move axes to desired output order""" @@ -138,11 +128,10 @@ class BaseGpuRunner(BaseKernelRunner): 1. instantiate 2. load constants - 3. load_data - 4. load_cell_table - 5. correct - 6a. reshape (only here does data transfer back to host) - 6b. compute_preview (optional) + 3. load data and cell table + 4. correct + 5a. reshape (only here does data transfer back to host) + 5b. compute previews (optional) repeat from 2. or 3. 
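For reference, the reduction dispatch used by compute_previews above can be sketched standalone in numpy. Names and shapes here are illustrative, not part of calng; the real code iterates self.preview_data_views and, in the GPU runner, operates on CuPy arrays before calling .get():

```python
import functools

import numpy as np

# toy stand-ins for two preview layers, axis order (frame, slow scan, fast scan)
raw = np.random.rand(16, 512, 128).astype(np.float32)
corrected = raw - np.nanmean(raw, axis=0)


def preview_reduction(preview_index):
    if preview_index >= 0:
        # non-negative index: slice out a single frame
        return lambda a: a[preview_index]
    elif preview_index == -1:
        # nanmax takes no dtype kwarg, hence the separate case
        return functools.partial(np.nanmax, axis=0)
    else:
        # -2 / -3 / -4: nanmean / nansum / nanstd across frames
        return functools.partial(
            {-2: np.nanmean, -3: np.nansum, -4: np.nanstd}[preview_index],
            axis=0,
            dtype=np.float32,
        )


fun = preview_reduction(-2)  # mean across frames
previews = tuple(fun(view) for view in (raw, corrected))
assert all(p.shape == (512, 128) for p in previews)
```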
@@ -159,7 +148,7 @@ class BaseGpuRunner(BaseKernelRunner): self, pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype=np.uint16, output_data_dtype=np.float32, @@ -167,11 +156,13 @@ class BaseGpuRunner(BaseKernelRunner): super().__init__( pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype, output_data_dtype, ) + global cupy + import cupy _src_dir = pathlib.Path(__file__).absolute().parent # subclass must define _kernel_source_filename with (_src_dir / "kernels" / self._kernel_source_filename).open("r") as fd: @@ -180,7 +171,7 @@ class BaseGpuRunner(BaseKernelRunner): self._init_kernels() # reuse buffers for input / output - self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=np.uint16) + self.cell_table_gpu = cupy.empty(self.frames, dtype=np.uint16) self.input_data_gpu = cupy.empty(self.input_shape, dtype=input_data_dtype) self.processed_data_gpu = cupy.empty( self.processed_shape, dtype=output_data_dtype @@ -231,5 +222,8 @@ class BaseGpuRunner(BaseKernelRunner): for (a_length, block_length) in zip(self.processed_shape, full_block) ) - def _compute_a_preview(self, image_data, preview_index): - return super()._compute_a_preview(image_data, preview_index).get() + def compute_previews(self, preview_index): + """See BaseKernelRunner.compute_previews; GPU version uses the same logic (numpy + calls dispatched via CuPy) and just .get()s results.""" + + return (res.get() for res in super().compute_previews(preview_index)) diff --git a/src/calng/calcat_utils.py b/src/calng/calcat_utils.py deleted file mode 100644 index 76b27adf0d18e3129daf5e331bd5492caef59ce1..0000000000000000000000000000000000000000 --- a/src/calng/calcat_utils.py +++ /dev/null @@ -1,558 +0,0 @@ -import copy -import functools -import json -import pathlib -import threading - -import calibration_client -import h5py -import numpy as np -from calibration_client.modules import ( - Calibration, - CalibrationConstant, - CalibrationConstantVersion, - Detector, - DetectorType, - PhysicalDetectorUnit, -) -from karabo.bound import ( - BOOL_ELEMENT, - DOUBLE_ELEMENT, - NODE_ELEMENT, - SLOT_ELEMENT, - STRING_ELEMENT, - UINT32_ELEMENT, - VECTOR_UINT32_ELEMENT, -) -from karabo import version as karaboVersion -from pkg_resources import parse_version - -from . 
import utils - - -class ConditionNotFound(Exception): - pass - - -class DetectorNotFound(Exception): - pass - - -class ModuleNotFound(Exception): - pass - - -class CalibrationNotFound(Exception): - pass - - -class CalibrationClientConfigError(Exception): - pass - - -def add_status_schema_from_enum(schema, prefix, enum_class): - for constant in enum_class: - constant_node = f"{prefix}.{constant.name}" - ( - NODE_ELEMENT(schema).key(constant_node).commit(), - - BOOL_ELEMENT(schema) - .key(f"{constant_node}.found") - .readOnly() - .initialValue(False) - .commit(), - - STRING_ELEMENT(schema) - .key(f"{constant_node}.validFrom") - .readOnly() - .initialValue("") - .commit(), - - STRING_ELEMENT(schema) - .key(f"{constant_node}.calibrationId") - .readOnly() - .initialValue("") - .commit(), - - VECTOR_UINT32_ELEMENT(schema) - .key(f"{constant_node}.conditionIds") - .readOnly() - .initialValue([]) - .commit(), - - VECTOR_UINT32_ELEMENT(schema) - .key(f"{constant_node}.constantIds") - .readOnly() - .initialValue([]) - .commit(), - - STRING_ELEMENT(schema) - .key(f"{constant_node}.constantVersionId") - .description( - "This field is editable - if for any reason a specific constant " - "version is desired, the constant version ID (as used in CalCat) can " - "be set here and the slot below can be called to load this particular " - "version, overriding the automatic loading of latest constants." - ) - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - ) - if parse_version(karaboVersion) >= parse_version("2.11"): - ( - SLOT_ELEMENT(schema) - .key(f"{constant_node}.overrideConstantVersion") - .displayedName("Override constant version") - .commit(), - ) - - -class OperatingConditions(dict): - # TODO: support deviation? - def encode(self): - return { - "parameters": [ - { - "parameter_name": key, - "lower_deviation_value": 0.0, - "upper_deviation_value": 0.0, - "flg_available": False, - "value": value, - } - for (key, value) in self.items() - ] - } - - def __hash__(self): - # this takes me back to pre-screening interview time... - return hash(tuple(sorted(self.items()))) - - -class BaseCalcatFriend: - """Base class for CalCat friends - handles interacting with CalCat for the device - - A CalCat friend uses the device schema to build up parameters for CalCat queries. - It focuses on two nodes (added by static method add_schema): param_prefix and - status_prefix. The former is primarily used to get parameters which are (via - condition methods - see for example dark_condition of DsscCalcatFriend) used - to look for constants. The latter is primarily used to give user information - about what was found. - """ - - _constant_enum_class = None # subclass should set - _constants_need_conditions = None # subclass should set - - @staticmethod - def add_schema( - schema, - managed_keys, - detector_type, - param_prefix="constantParameters", - status_prefix="foundConstants", - ): - """Add elements needed by this object to device's schema (expectedSchema) - - All elements added to schema go under prefixes which should end with name of - node which does not exist yet. To change default values and add more fields, - extend this method in subclass. - - The param_prefix node will hold all the parameters needed to build constant - condition dicts for querying CalCat. These values are set either directly on - the device or via manager and this class gets them from the device using helper - function _get_param. See for example AgipdCalcatFriend.dark_condition. 
- - The status_prefix node is used to report information about what was found in - CalCat. This class will update the values on the device using the helper - function _set_status. This should not need to happen in subclass methods. - """ - - ( - NODE_ELEMENT(schema) - .key(param_prefix) - .displayedName("Constant retrieval parameters") - .commit(), - - NODE_ELEMENT(schema) - .key(status_prefix) - .displayedName("Constants retrieved") - .commit(), - ) - - # Parameters which any detector would probably have (extend this in subclass) - # TODO: probably switch to floating point for everything, including mem cells - ( - STRING_ELEMENT(schema) - .key(f"{param_prefix}.deviceMappingSnapshotAt") - .displayedName("Snapshot timestamp (for device mapping)") - .description( - "CalCat supports querying with a specific snapshot of the database. " - "When playing back a run from the file system, this feature is useful " - "to look up the device mapping at the time of the run. If this field " - "is left empty, the latest device mapping is used. Date format should " - "be 'YYYY-MM-DD' with optional time of day starting with 'T' followed " - "by 'hh:mm:ss.mil+02:00'." - ) - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.constantVersionEventAt") - .displayedName("Event at timestamp (for constant version)") - .description("TODO") - .assignmentOptional() - .defaultValue("") - .reconfigurable() - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorType") - .displayedName("Detector type name") - .description( - "Name of detector type in CalCat; typically has suffix '-Type'" - ) - .readOnly() - .initialValue(detector_type) - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorTypeId") - .readOnly() - .initialValue("") - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorName") - .assignmentOptional() - .defaultValue("") - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.detectorId") - .readOnly() - .initialValue("") - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.karaboDa") - .assignmentOptional() - .defaultValue("") - .commit(), - - STRING_ELEMENT(schema) - .key(f"{param_prefix}.moduleId") - .readOnly() - .initialValue("") - .commit(), - - UINT32_ELEMENT(schema) - .key(f"{param_prefix}.memoryCells") - .displayedName("Memory cells") - .description( - "Number of memory cells / frames per train. Relevant for burst mode." 
- ) - .assignmentOptional() - .defaultValue(1) - .reconfigurable() - .commit(), - - UINT32_ELEMENT(schema) - .key(f"{param_prefix}.pixelsX") - .displayedName("Pixels X") - .assignmentOptional() - .defaultValue(512) - .commit(), - - UINT32_ELEMENT(schema) - .key(f"{param_prefix}.pixelsY") - .displayedName("Pixels Y") - .assignmentOptional() - .defaultValue(128) - .commit(), - - DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.biasVoltage") - .displayedName("Bias voltage") - .description("Sensor bias voltage") - .assignmentOptional() - .defaultValue(300) - .reconfigurable() - .commit(), - ) - managed_keys.add(f"{param_prefix}.deviceMappingSnapshotAt") - managed_keys.add(f"{param_prefix}.constantVersionEventAt") - managed_keys.add(f"{param_prefix}.memoryCells") - managed_keys.add(f"{param_prefix}.pixelsX") - managed_keys.add(f"{param_prefix}.pixelsY") - managed_keys.add(f"{param_prefix}.biasVoltage") - - def __init__( - self, - device, - secrets_fn: pathlib.Path, - param_prefix="constantParameters", - status_prefix="foundConstants", - ): - self.device = device - self.param_prefix = param_prefix - self.status_prefix = status_prefix - self.cached_constants = {} - self.cached_constants_lock = threading.Lock() - # api lock used to force queries to be sequential (SSL issue on ONC) - self.api_lock = threading.Lock() - - if not secrets_fn.is_file(): - self.device.log_status_warn( - f"Missing CalCat secrets file (expected {secrets_fn})" - ) - with secrets_fn.open("r") as fd: - calcat_secrets = json.load(fd) - - self.caldb_store = pathlib.Path(calcat_secrets["caldb_store_path"]) - if not self.caldb_store.is_dir(): - raise ValueError(f"caldb_store location '{self.caldb_store}' is not dir") - - self.device.log.INFO(f"Connecting to CalCat at {calcat_secrets['base_url']}") - base_url = calcat_secrets["base_url"] - self.client = calibration_client.CalibrationClient( - client_id=calcat_secrets["client_id"], - client_secret=calcat_secrets["client_secret"], - user_email=calcat_secrets["user_email"], - base_api_url=f"{base_url}/api/", - token_url=f"{base_url}/oauth/token", - refresh_url=f"{base_url}/oauth/token", - auth_url=f"{base_url}/oauth/authorize", - scope="public", - session_token=None, - ) - self.device.log_status_info("CalCat connection established") - - def _get_param(self, key): - """Helper to get value from attached device schema""" - return self.device.get(f"{self.param_prefix}.{key}") - - def _set_param(self, key, value): - self.device.set(f"{self.param_prefix}.{key}", value) - - def _get_status(self, constant, key): - return self.device.get(f"{self.status_prefix}.{constant.name}.{key}") - - def _set_status(self, constant, key, value): - """Helper to update information about found constants on device""" - self.device.set(f"{self.status_prefix}.{constant.name}.{key}", value) - - # Python 3.6 does not have functools.cached_property or even functools.cache - @property - @functools.lru_cache() - def detector_id(self): - detector_name = self._get_param("detectorName") - resp = Detector.get_by_identifier(self.client, detector_name) - self._check_resp(resp, DetectorNotFound, f"Detector {detector_name} not found") - res = resp["data"]["id"] - self._set_param("detectorId", str(res)) - return res - - @property - @functools.lru_cache() - def detector_type_id(self): - detector_type = self._get_param("detectorType") - resp = DetectorType.get_by_name(self.client, detector_type) - self._check_resp( - resp, DetectorNotFound, f"Detector type {detector_type} not found" - ) - res = resp["data"]["id"] - 
self._set_param("detectorTypeId", str(res)) - return res - - @property - @functools.lru_cache() - def pdus(self): - resp = PhysicalDetectorUnit.get_all_by_detector( - self.client, self.detector_id, self._get_param("deviceMappingSnapshotAt") - ) - self._check_resp(resp, warning="Failed to retrieve module mapping") - for irrelevant_key in ("detector", "detector_type", "flg_available"): - for pdu in resp["data"]: - del pdu[irrelevant_key] - return resp["data"] - - @property - @functools.lru_cache() - def _karabo_da_to_float_uuid(self): - return {pdu["karabo_da"]: pdu["float_uuid"] for pdu in self.pdus} - - @property - @functools.lru_cache() - def _karabo_da_to_id(self): - return {pdu["karabo_da"]: pdu["id"] for pdu in self.pdus} - - def flush_pdu_mapping(self): - for attr in ("pdus", "_karabo_da_to_float_uuid", "_karabo_da_to_id"): - if hasattr(self, attr): - delattr(self, attr) - self._set_param("moduleId", "") - - @utils.threadsafe_cache - def calibration_id(self, calibration_name: str): - resp = Calibration.get_by_name(self.client, calibration_name) - self._check_resp( - resp, CalibrationNotFound, f"Calibration type {calibration_name} not found!" - ) - return resp["data"]["id"] - - @utils.threadsafe_cache - def condition_ids(self, pdu, condition): - # modifying condition parameter messes with cache - condition_with_detector = copy.copy(condition) - condition_with_detector["Detector UUID"] = pdu - self.device.log.DEBUG(f"Look for condition: {condition_with_detector}") - resp = self.client.search_possible_conditions_from_dict( - "", condition_with_detector.encode() - ) - self._check_resp( - resp, - ConditionNotFound, - f"Failed to find condition {condition} for pdu {pdu}", - ) - return [d["id"] for d in resp["data"]] - - def constant_ids(self, calibration_id, condition_ids): - resp = CalibrationConstant.get_all_by_conditions( - self.client, - calibration_id=calibration_id, - detector_type_id=self.detector_type_id, - condition_ids=condition_ids, - ) - self._check_resp(resp, warning="Failed to retrieve constant ID") - return [d["id"] for d in resp["data"]] - - def get_constant_version(self, constant): - # TODO: catch exceptions, give warnings appropriately - karabo_da = self._get_param("karaboDa") - self.device.log_status_info(f"Attempting to find {constant} for {karabo_da}") - - if karabo_da not in self._karabo_da_to_float_uuid: - self.device.log_status_warn( - f"Module {karabo_da} not found in mapping, check configuration!" 
- ) - raise ModuleNotFound(f"Module map did not include {karabo_da}") - self._set_param("moduleId", str(self._karabo_da_to_id[karabo_da])) - - if isinstance(constant, str): - constant = self._constant_enum_class[constant] - - calibration_id = self.calibration_id(constant.name) - self._set_status(constant, "calibrationId", calibration_id) - - condition = self._constants_need_conditions[constant]() - condition_ids = self.condition_ids( - self._karabo_da_to_float_uuid[karabo_da], condition - ) - self._set_status(constant, "conditionIds", condition_ids) - - constant_ids = self.constant_ids( - calibration_id=calibration_id, condition_ids=condition_ids - ) - self._set_status(constant, "constantIds", constant_ids) - - resp = CalibrationConstantVersion.get_closest_by_time( - self.client, - calibration_constant_ids=constant_ids, - physical_detector_unit_id=self._karabo_da_to_id[karabo_da], - event_at=self._get_param("constantVersionEventAt"), - snapshot_at=None, - ) - self._check_resp(resp, warning="Failed to find calibration constant version") - # TODO: replace with start date and end date - timestamp = resp["data"]["begin_validity_at"] - self._set_status(constant, "validFrom", timestamp) - self._set_status(constant, "constantVersionId", resp["data"]["id"]) - - file_path = ( - self.caldb_store / resp["data"]["path_to_file"] / resp["data"]["file_name"] - ) - # TODO: handle FileNotFoundError if we are led astray - with h5py.File(file_path, "r") as fd: - constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) - with self.cached_constants_lock: - self.cached_constants[constant] = constant_data - self._set_status(constant, "found", True) - self.device.log_status_info(f"Done finding {constant} for {karabo_da}") - - return constant_data - - def get_specific_constant_version(self, constant): - # TODO: warn if PDU or constant type does not match - # TODO: warn if result is list (happens for empty version ID) - constant_version_id = self.device.get( - f"{self.status_prefix}.{constant.name}.constantVersionId" - ) - - resp = CalibrationConstantVersion.get_by_id(self.client, constant_version_id) - self._check_resp(resp, warning="Failed to find calibration constant version") - file_path = ( - self.caldb_store / resp["data"]["path_to_file"] / resp["data"]["file_name"] - ) - with h5py.File(file_path, "r") as fd: - constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) - with self.cached_constants_lock: - self.cached_constants[constant] = constant_data - self._set_status(constant, "validFrom", resp["data"]["begin_at"]) - self._set_status(constant, "calibrationId", "manual override") - self._set_status(constant, "conditionId", "manual override") - self._set_status(constant, "constantId", "manual override") - self._set_status(constant, "constantVersionId", constant_version_id) - self._set_status(constant, "found", True) - return constant_data - - def get_constant_version_and_call_me_back(self, constant, callback): - """Runs get_constant_version in thread, will call callback on completion""" - # TODO: do we want to use asyncio / "modern" async? 
- # TODO: consider moving out of this class, closer to correction device - def aux(): - with self.api_lock: - data = self.get_constant_version(constant) - callback(constant, data) - - thread = threading.Thread(target=aux) - thread.start() - return thread - - def get_specific_constant_version_and_call_me_back(self, constant, callback): - """Blindly load whatever CalCat points to for CCV - user must be confident that - this CCV corresponds to correct kind of constant.""" - - # TODO: warn user about all the things that go wrong - def aux(): - with self.api_lock: - data = self.get_specific_constant_version(constant) - callback(constant, data) - - thread = threading.Thread(target=aux) - thread.start() - return thread - - def flush_constants(self): - for constant in self._constant_enum_class: - self._set_status(constant, "validFrom", "") - self._set_status(constant, "found", False) - - def _check_resp(self, resp, exception=Exception, warning=None): - # TODO: probably verify using "info" that exception is the right one - to_raise = None - if not resp["success"]: - # TODO: probably more types of app_info errors? - if resp["app_info"]: - if "not found" in resp["info"]: - # this was likely the exception exception - to_raise = exception(resp["info"]) - else: - # but could also be authorization or similar issue - to_raise = CalibrationClientConfigError(resp["app_info"]) - to_raise = exception(resp["info"]) - if to_raise is not None: - if warning is not None: - self.device.log_status_warn(warning) - raise to_raise diff --git a/src/calng/conditions/AgipdCondition.py b/src/calng/conditions/AgipdCondition.py new file mode 100644 index 0000000000000000000000000000000000000000..92200ce5a2f73064428e515038612fae21d9bb27 --- /dev/null +++ b/src/calng/conditions/AgipdCondition.py @@ -0,0 +1,48 @@ +from karabo.middlelayer import AccessMode, Assignment, String +from .. import base_condition +from ..corrections.AgipdCorrection import AgipdGainMode + + +class AgipdCondition(base_condition.ConditionBase): + controlDeviceId = String( + displayedName="Control device ID", + assignment=Assignment.MANDATORY, + accessMode=AccessMode.INITONLY, + ) + + voltageSourceDevice = String() + voltageSourceKey = String() + + @property + def keys_to_get(self): + return { + self.controlDeviceId.value: [ + # rep rate: 1.1, 2.2, 4.5 + ("bunchStructure.repetitionRate", "acquisitionRate", None), + # TODO: check if appropriate (agipdlib looks at image.cellId) + ("bunchStructure.nPulses", "memoryCells", None), + ("gain", "gainSetting", None), + ( + "gainModeIndex", + "gainMode", + lambda i: AgipdGainMode(i).name, + ), + ("integrationTime", "integrationTime", None), + ], + # observed voltages in calcat: + # 59.0 + # 64.0 + # 74.0 + # 100.0 + # 199.0 + # 200.0 + # 300.0 + # 500.0 + self.voltageSourceDevice.value: [ + ( + self.voltageSourceKey.value, + "biasVoltage", + lambda n: round(n, ndigits=0), + ) + ], + } diff --git a/src/calng/conditions/JungfrauCondition.py b/src/calng/conditions/JungfrauCondition.py new file mode 100644 index 0000000000000000000000000000000000000000..ca71a8524b2460ca019abe503a506ead6f84cf94 --- /dev/null +++ b/src/calng/conditions/JungfrauCondition.py @@ -0,0 +1,59 @@ +from karabo.middlelayer import AccessMode, Assignment, String +from .. 
import base_condition +from ..corrections.JungfrauCorrection import JungfrauGainMode + + +def settings_to_gain_mode(setting): + gain_mode = JungfrauGainMode(setting) + if gain_mode in (JungfrauGainMode.FIX_GAIN_1, JungfrauGainMode.FIX_GAIN_2): + return 1 + else: + return 0 + + +def settings_to_gain_setting(setting): + gain_mode = JungfrauGainMode(setting) + if gain_mode is JungfrauGainMode.DYNAMIC_GAIN_HG0: + return 1 + else: + return 0 + + +class JungfrauCondition(base_condition.ConditionBase): + controlDeviceId = String( + displayedName="Control device ID", + assignment=Assignment.MANDATORY, + accessMode=AccessMode.INITONLY, + ) + + @property + def keys_to_get(self): + return { + self.controlDeviceId.value: [ + # cells: 1.0 or 16.0 + ("storageCells", "memoryCells", lambda n: n + 1), + # observed voltages in parameter conditions: + # 90.0 + # 180.0 + # 200.0 + # note: control device parameter is a vector + ("vHighVoltage", "biasVoltage", lambda arr: round(arr[0], ndigits=0)), + # observed integration times: + # 9.999999747378752 + # 12.999999853491317 + # 49.99999873689376 + # 300.0000142492354 + # 349.9999875202775 + # 399.99998989515007 + # 500.00002374872565 + ( + "exposureTime", + "integrationTime", + lambda n: round(n * 1e6, ndigits=0), + ), + # gain mode: omitted or 1.0 + ("settings", "gainMode", settings_to_gain_mode), + # gain setting: 0.0 or 1.0 (derived from gain mode on device) + ("settings", "gainSetting", settings_to_gain_setting), + ] + } diff --git a/src/calng/conditions/__init__.py b/src/calng/conditions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c1b78ef083174014fc9b1c58ed9df2f04451c54 --- /dev/null +++ b/src/calng/conditions/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa: F401 +from . import AgipdCondition, JungfrauCondition diff --git a/src/calng/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py similarity index 71% rename from src/calng/AgipdCorrection.py rename to src/calng/corrections/AgipdCorrection.py index 7d2946113c9f124432bfa8b5fecc34a01cb61380..1eaf7aa9fa0336626baacf4ec9be5c894b175a2f 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/corrections/AgipdCorrection.py @@ -1,29 +1,26 @@ import enum -import cupy import numpy as np from karabo.bound import ( BOOL_ELEMENT, DOUBLE_ELEMENT, FLOAT_ELEMENT, KARABO_CLASSINFO, - NODE_ELEMENT, OUTPUT_CHANNEL, OVERWRITE_ELEMENT, STRING_ELEMENT, VECTOR_STRING_ELEMENT, ) -from . import base_calcat, base_kernel_runner, utils -from ._version import version as deviceVersion -from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema +from .. 
import base_calcat, base_correction, base_kernel_runner, schemas, utils +from .._version import version as deviceVersion class AgipdConstants(enum.Enum): - SlopesFF = enum.auto() ThresholdsDark = enum.auto() Offset = enum.auto() SlopesPC = enum.auto() + SlopesFF = enum.auto() BadPixelsDark = enum.auto() BadPixelsPC = enum.auto() BadPixelsFF = enum.auto() @@ -49,20 +46,22 @@ class CorrectionFlags(enum.IntFlag): class AgipdGpuRunner(base_kernel_runner.BaseGpuRunner): _kernel_source_filename = "agipd_gpu.cu" - _corrected_axis_order = "cxy" + _corrected_axis_order = "fxy" def __init__( self, pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, - input_data_dtype=cupy.uint16, - output_data_dtype=cupy.float32, - bad_pixel_mask_value=cupy.nan, + input_data_dtype=np.uint16, + output_data_dtype=np.float32, + bad_pixel_mask_value=np.nan, gain_mode=AgipdGainMode.ADAPTIVE_GAIN, g_gain_value=1, ): + global cupy + import cupy self.gain_mode = gain_mode # default gain only matters when not thresholding (missing constant or fixed) # note: gain stage (result of thresholding) is 0, 1, or 2 @@ -70,51 +69,45 @@ class AgipdGpuRunner(base_kernel_runner.BaseGpuRunner): self.default_gain = cupy.uint8(gain_mode) else: self.default_gain = cupy.uint8(gain_mode - 1) - self.input_shape = (memory_cells, 2, pixels_x, pixels_y) - self.processed_shape = (memory_cells, pixels_x, pixels_y) + self.input_shape = (frames, 2, pixels_x, pixels_y) + self.processed_shape = (frames, pixels_x, pixels_y) super().__init__( pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype, output_data_dtype, ) - self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=cupy.float32) - self.preview_buffer_getters.extend( - [self._get_raw_gain_for_preview, self._get_gain_map_for_preview] - ) + self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=np.float32) self.map_shape = (self.constant_memory_cells, self.pixels_x, self.pixels_y) self.gm_map_shape = self.map_shape + (3,) # for gain-mapped constants self.threshold_map_shape = self.map_shape + (2,) # constants self.gain_thresholds_gpu = cupy.empty( - self.threshold_map_shape, dtype=cupy.float32 + self.threshold_map_shape, dtype=np.float32 ) - self.offset_map_gpu = cupy.zeros(self.gm_map_shape, dtype=cupy.float32) - self.rel_gain_pc_map_gpu = cupy.ones(self.gm_map_shape, dtype=cupy.float32) + self.offset_map_gpu = cupy.empty(self.gm_map_shape, dtype=np.float32) + self.rel_gain_pc_map_gpu = cupy.empty(self.gm_map_shape, dtype=np.float32) # not gm_map_shape because it only applies to medium gain pixels - self.md_additional_offset_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32) - self.rel_gain_xray_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) - self.bad_pixel_map_gpu = cupy.zeros(self.gm_map_shape, dtype=cupy.uint32) + self.md_additional_offset_gpu = cupy.empty(self.map_shape, dtype=np.float32) + self.rel_gain_xray_map_gpu = cupy.empty(self.map_shape, dtype=np.float32) + self.bad_pixel_map_gpu = cupy.empty(self.gm_map_shape, dtype=np.uint32) self.set_bad_pixel_mask_value(bad_pixel_mask_value) self.set_g_gain_value(g_gain_value) + self.flush_buffers(set(AgipdConstants)) self.update_block_size((1, 1, 64)) - def _get_raw_for_preview(self): - return self.input_data_gpu[:, 0] - - def _get_corrected_for_preview(self): - return self.processed_data_gpu - - # special to AGIPD - def _get_raw_gain_for_preview(self): - return self.input_data_gpu[:, 1] - - def _get_gain_map_for_preview(self): - return self.gain_map_gpu + @property + def 
preview_data_views(self): + return ( + self.input_data_gpu[:, 0], # raw + self.processed_data_gpu, # corrected + self.input_data_gpu[:, 1], # raw gain + self.gain_map_gpu, # digitized gain + ) def load_thresholds(self, threshold_map): # shape: y, x, memory cell, thresholds and gain values @@ -193,7 +186,6 @@ class AgipdGpuRunner(base_kernel_runner.BaseGpuRunner): def load_bad_pixels_map(self, bad_pixels_map, override_flags_to_use=None): # will simply OR with already loaded, does not take into account which ones - # TODO: inquire what "mask for double size pixels" means if len(bad_pixels_map.shape) == 3: if bad_pixels_map.shape == ( self.pixels_y, @@ -239,12 +231,26 @@ class AgipdGpuRunner(base_kernel_runner.BaseGpuRunner): def set_bad_pixel_mask_value(self, mask_value): self.bad_pixel_mask_value = cupy.float32(mask_value) - def flush_buffers(self): - self.offset_map_gpu.fill(0) - self.rel_gain_pc_map_gpu.fill(1) - self.md_additional_offset_gpu.fill(0) - self.rel_gain_xray_map_gpu.fill(1) - self.bad_pixel_map_gpu.fill(0) + def flush_buffers(self, constants): + if AgipdConstants.Offset in constants: + self.offset_map_gpu.fill(0) + if AgipdConstants.SlopesPC in constants: + self.rel_gain_pc_map_gpu.fill(1) + self.md_additional_offset_gpu.fill(0) + if AgipdConstants.SlopesFF in constants: + self.rel_gain_xray_map_gpu.fill(1) + if constants & { + AgipdConstants.BadPixelsDark, + AgipdConstants.BadPixelsPC, + AgipdConstants.BadPixelsFF, + }: + self.bad_pixel_map_gpu.fill(0) + self.bad_pixel_map_gpu[ + :, 64:512:64 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map_gpu[ + :, 63:511:64 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value # TODO: baseline shift @@ -278,7 +284,7 @@ class AgipdGpuRunner(base_kernel_runner.BaseGpuRunner): { "pixels_x": self.pixels_x, "pixels_y": self.pixels_y, - "data_memory_cells": self.memory_cells, + "frames": self.frames, "constant_memory_cells": self.constant_memory_cells, "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), @@ -292,9 +298,9 @@ class AgipdGpuRunner(base_kernel_runner.BaseGpuRunner): class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = AgipdConstants - def __init__(self, device, *args, **kwargs): - super().__init__(device, *args, **kwargs) - self._constants_need_conditions = { + @property + def _constants_need_conditions(self): + return { AgipdConstants.ThresholdsDark: self.dark_condition, AgipdConstants.Offset: self.dark_condition, AgipdConstants.SlopesPC: self.dark_condition, @@ -305,52 +311,47 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): } @staticmethod - def add_schema( - schema, - managed_keys, - param_prefix="constantParameters", - status_prefix="foundConstants", - ): + def add_schema(schema, managed_keys): super(AgipdCalcatFriend, AgipdCalcatFriend).add_schema( - schema, managed_keys, "AGIPD-Type", param_prefix, status_prefix + schema, managed_keys, "AGIPD-Type" ) ( OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.memoryCells") + .key("constantParameters.memoryCells") .setNewDefaultValue(352) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.biasVoltage") + .key("constantParameters.biasVoltage") .setNewDefaultValue(300) .commit() ) ( DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.acquisitionRate") + .key("constantParameters.acquisitionRate") .assignmentOptional() .defaultValue(1.1) .reconfigurable() .commit(), DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.gainSetting") +
.key("constantParameters.gainSetting") .assignmentOptional() .defaultValue(0) .reconfigurable() .commit(), DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.photonEnergy") + .key("constantParameters.photonEnergy") .assignmentOptional() .defaultValue(9.2) .reconfigurable() .commit(), STRING_ELEMENT(schema) - .key(f"{param_prefix}.gainMode") + .key("constantParameters.gainMode") .assignmentOptional() .defaultValue("ADAPTIVE_GAIN") .options(",".join(gain_mode.name for gain_mode in AgipdGainMode)) @@ -358,19 +359,19 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): .commit(), DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.integrationTime") + .key("constantParameters.integrationTime") .assignmentOptional() .defaultValue(12) .reconfigurable() .commit(), ) - managed_keys.add(f"{param_prefix}.acquisitionRate") - managed_keys.add(f"{param_prefix}.gainSetting") - managed_keys.add(f"{param_prefix}.photonEnergy") - managed_keys.add(f"{param_prefix}.gainMode") - managed_keys.add(f"{param_prefix}.integrationTime") + managed_keys.add("constantParameters.acquisitionRate") + managed_keys.add("constantParameters.gainSetting") + managed_keys.add("constantParameters.photonEnergy") + managed_keys.add("constantParameters.gainMode") + managed_keys.add("constantParameters.integrationTime") - base_calcat.add_status_schema_from_enum(schema, status_prefix, AgipdConstants) + base_calcat.add_status_schema_from_enum(schema, AgipdConstants) def dark_condition(self): res = base_calcat.OperatingConditions() @@ -410,26 +411,41 @@ class AgipdCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("AgipdCorrection", deviceVersion) -class AgipdCorrection(BaseCorrection): +class AgipdCorrection(base_correction.BaseCorrection): # subclass *must* set these attributes _correction_flag_class = CorrectionFlags - _correction_field_names = ( - ("thresholding", CorrectionFlags.THRESHOLD), - ("offset", CorrectionFlags.OFFSET), - ("relGainPc", CorrectionFlags.REL_GAIN_PC), - ("gainXray", CorrectionFlags.GAIN_XRAY), - ("badPixels", CorrectionFlags.BPMASK), + _correction_steps = ( + # step name (used in schema), flag to enable for kernel, constants required + ("thresholding", CorrectionFlags.THRESHOLD, {AgipdConstants.ThresholdsDark}), + ("offset", CorrectionFlags.OFFSET, {AgipdConstants.Offset}), + ("relGainPc", CorrectionFlags.REL_GAIN_PC, {AgipdConstants.SlopesPC}), + ("gainXray", CorrectionFlags.GAIN_XRAY, {AgipdConstants.SlopesFF}), + ( + "badPixels", + CorrectionFlags.BPMASK, + { + AgipdConstants.BadPixelsDark, + AgipdConstants.BadPixelsPC, + AgipdConstants.BadPixelsFF, + None, # means stay available even without constants loaded + }, + ), ) _kernel_runner_class = AgipdGpuRunner _calcat_friend_class = AgipdCalcatFriend _constant_enum_class = AgipdConstants - _managed_keys = BaseCorrection._managed_keys.copy() + _managed_keys = base_correction.BaseCorrection._managed_keys.copy() @staticmethod def expectedParameters(expected): ( + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(schemas.xtdf_output_schema()) + .commit(), + OVERWRITE_ELEMENT(expected) - .key("dataFormat.memoryCells") + .key("dataFormat.frames") .setNewDefaultValue(352) .commit(), @@ -442,21 +458,24 @@ class AgipdCorrection(BaseCorrection): ( OUTPUT_CHANNEL(expected) .key("preview.outputRawGain") - .dataSchema(preview_schema) + .dataSchema(schemas.preview_schema()) .commit(), OUTPUT_CHANNEL(expected) .key("preview.outputGainMap") - .dataSchema(preview_schema) + .dataSchema(schemas.preview_schema()) .commit(), ) 
AgipdCalcatFriend.add_schema(expected, AgipdCorrection._managed_keys) # this is not automatically done by superclass for complicated class reasons - add_correction_step_schema( + base_correction.add_correction_step_schema( expected, AgipdCorrection._managed_keys, - AgipdCorrection._correction_field_names, + AgipdCorrection._correction_steps, + ) + base_correction.add_bad_pixel_config_node( + expected, AgipdCorrection._managed_keys ) # additional settings specific to AGIPD correction steps @@ -503,51 +522,12 @@ class AgipdCorrection(BaseCorrection): .defaultValue(1) .reconfigurable() .commit(), - - STRING_ELEMENT(expected) - .key("corrections.badPixels.maskingValue") - .displayedName("Bad pixel masking value") - .description( - "Any pixels masked by the bad pixel mask will have their value " - "replaced with this. Note that this parameter is to be interpreted as " - "a numpy.float32; use 'nan' to get NaN value." - ) - .assignmentOptional() - .defaultValue("nan") - .reconfigurable() - .commit(), - - NODE_ELEMENT(expected) - .key("corrections.badPixels.subsetToUse") - .displayedName("Bad pixel flags to use") - .description( - "The booleans under this node allow for selecting a subset of bad " - "pixel types to take into account when doing bad pixel masking. " - "Upon updating these flags, the map used for bad pixel masking will " - "be ANDed with this selection. Turning disabled flags back on causes " - "reloading of cached constants." - ) - .commit(), ) AgipdCorrection._managed_keys.add( "corrections.relGainPc.overrideMdAdditionalOffset" ) AgipdCorrection._managed_keys.add("corrections.relGainPc.mdAdditionalOffset") AgipdCorrection._managed_keys.add("corrections.gainXray.gGainValue") - AgipdCorrection._managed_keys.add("corrections.badPixels.maskingValue") - # TODO: DRY / encapsulate - for field in utils.BadPixelValues: - ( - BOOL_ELEMENT(expected) - .key(f"corrections.badPixels.subsetToUse.{field.name}") - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit() - ) - AgipdCorrection._managed_keys.add( - f"corrections.badPixels.subsetToUse.{field.name}" - ) # mandatory: manager needs this in schema ( @@ -561,7 +541,7 @@ class AgipdCorrection(BaseCorrection): @property def input_data_shape(self): return ( - self.unsafe_get("dataFormat.memoryCells"), + self.unsafe_get("dataFormat.frames"), 2, self.unsafe_get("dataFormat.pixelsX"), self.unsafe_get("dataFormat.pixelsY"), @@ -570,34 +550,37 @@ class AgipdCorrection(BaseCorrection): def __init__(self, config): super().__init__(config) # note: gain mode single sourced from constant retrieval node - self.gain_mode = AgipdGainMode[config.get("constantParameters.gainMode")] - try: - self.bad_pixel_mask_value = np.float32( - config.get("corrections.badPixels.maskingValue") - ) + np.float32(config.get("corrections.badPixels.maskingValue")) except ValueError: - self.bad_pixel_mask_value = np.float32("nan") + config["corrections.badPixels.maskingValue"] = "nan" + + self._has_updated_bad_pixel_selection = False - self._kernel_runner_init_args = { + @property + def bad_pixel_mask_value(self): + return np.float32(self.unsafe_get("corrections.badPixels.maskingValue")) + + _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection) + + @property + def _kernel_runner_init_args(self): + return { "gain_mode": self.gain_mode, "bad_pixel_mask_value": self.bad_pixel_mask_value, - "g_gain_value": config.get("corrections.gainXray.gGainValue"), + "g_gain_value": self.unsafe_get("corrections.gainXray.gGainValue"), } - # 
configurability: overriding md_additional_offset - if config.get("corrections.relGainPc.overrideMdAdditionalOffset"): - self._override_md_additional_offset = config.get( - "corrections.relGainPc.mdAdditionalOffset" - ) - else: - self._override_md_additional_offset = None - - self._has_updated_bad_pixel_selection = False + @property + def gain_mode(self): + return AgipdGainMode[self.unsafe_get("constantParameters.gainMode")] - def _initialization(self): - self._update_bad_pixel_selection() - super()._initialization() + @property + def _override_md_additional_offset(self): + if self.unsafe_get("corrections.relGainPc.overrideMdAdditionalOffset"): + return self.unsafe_get("corrections.relGainPc.mdAdditionalOffset") + else: + return None def process_data( self, @@ -607,11 +590,10 @@ class AgipdCorrection(BaseCorrection): train_id, image_data, cell_table, - do_generate_preview, ): """Called by input_handler for each data hash. Should correct data, optionally compute preview, write data output, and optionally write preview outputs.""" - # original shape: memory_cell, data/raw_gain, x, y + # original shape: frame, data/raw_gain, x, y pulse_table = np.squeeze(data_hash.get("image.pulseId")) if self._frame_filter is not None: @@ -647,26 +629,29 @@ class AgipdCorrection(BaseCorrection): out=buffer_array, ) # after reshape, data for dataOutput is now safe in its own buffer - if do_generate_preview: + with self.warning_context( + "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS + ) as warn: if self._correction_flag_enabled != self._correction_flag_preview: self.kernel_runner.correct(self._correction_flag_preview) ( preview_slice_index, preview_cell, preview_pulse, - ) = utils.pick_frame_index( + ), preview_warning = utils.pick_frame_index( self.unsafe_get("preview.selectionMode"), self.unsafe_get("preview.index"), cell_table, pulse_table, - warn_func=self.log_status_warn, ) - ( - preview_raw, - preview_corrected, - preview_raw_gain, - preview_gain_map, - ) = self.kernel_runner.compute_previews(preview_slice_index) + if preview_warning is not None: + warn(preview_warning) + ( + preview_raw, + preview_corrected, + preview_raw_gain, + preview_gain_map, + ) = self.kernel_runner.compute_previews(preview_slice_index) # reusing input data hash for sending data_hash.set("image.data", buffer_handle) @@ -676,82 +661,52 @@ class AgipdCorrection(BaseCorrection): data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) self._write_output(data_hash, metadata) - if do_generate_preview: - self._write_preview_outputs( - ( - ("preview.outputRaw", preview_raw), - ("preview.outputCorrected", preview_corrected), - ("preview.outputRawGain", preview_raw_gain), - ("preview.outputGainMap", preview_gain_map), - ), - metadata, - ) + self._write_preview_outputs( + ( + ("preview.outputRaw", preview_raw), + ("preview.outputCorrected", preview_corrected), + ("preview.outputRawGain", preview_raw_gain), + ("preview.outputGainMap", preview_gain_map), + ), + metadata, + ) def _load_constant_to_runner(self, constant, constant_data): - # TODO: encode correction / constant dependencies in a clever way if constant is AgipdConstants.ThresholdsDark: - field_name = "thresholding" # TODO: (reverse) mapping, DRY if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN: self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode") + # TODO: set constant status to ignoring return self.kernel_runner.load_thresholds(constant_data) elif constant is AgipdConstants.Offset: - field_name = "offset" 
self.kernel_runner.load_offset_map(constant_data) elif constant is AgipdConstants.SlopesPC: - field_name = "relGainPc" self.kernel_runner.load_rel_gain_pc_map(constant_data) if self._override_md_additional_offset is not None: self.kernel_runner.md_additional_offset_gpu.fill( self._override_md_additional_offset ) elif constant is AgipdConstants.SlopesFF: - field_name = "gainXray" self.kernel_runner.load_rel_gain_ff_map(constant_data) elif "BadPixels" in constant.name: - field_name = "badPixels" self.kernel_runner.load_bad_pixels_map( constant_data, override_flags_to_use=self._override_bad_pixel_flags ) - # switch relevant correction on if it just now became available - if not self.get(f"corrections.{field_name}.available"): - # TODO: turn off again when flushing - self.set(f"corrections.{field_name}.available", True) - - self._update_correction_flags() - self.log_status_info(f"Done loading {constant.name} to GPU") - - def _update_bad_pixel_selection(self): - selection = 0 - for field in utils.BadPixelValues: - if self.get(f"corrections.badPixels.subsetToUse.{field.name}"): - selection |= field - self._override_bad_pixel_flags = selection - def preReconfigure(self, config): super().preReconfigure(config) if config.has("corrections.badPixels.maskingValue"): # only check if it is valid; postReconfigure will use it - try: - np.float32(config.get("corrections.badPixels.maskingValue")) - except ValueError: - self.log_status_warn("Invalid masking value, ignoring.") - config.erase("corrections.badPixels.maskingValue") + np.float32(config.get("corrections.badPixels.maskingValue")) def postReconfigure(self): super().postReconfigure() # TODO: move after getting cached update, check if necessary - if self.get("corrections.relGainPc.overrideMdAdditionalOffset"): - self._override_md_additional_offset = self.get( - "corrections.relGainPc.mdAdditionalOffset" - ) + if self._override_md_additional_offset is not None: self.kernel_runner.override_md_additional_offset( self._override_md_additional_offset ) - else: - self._override_md_additional_offset = None if not hasattr(self, "_prereconfigure_update_hash"): return @@ -759,25 +714,16 @@ class AgipdCorrection(BaseCorrection): update = self._prereconfigure_update_hash if update.has("constantParameters.gainMode"): - self.gain_mode = AgipdGainMode[update["constantParameters.gainMode"]] + self.flush_constants() self._update_buffers() if update.has("corrections.gainXray.gGainValue"): self.kernel_runner.set_g_gain_value( self.get("corrections.gainXray.gGainValue") ) - self._kernel_runner_init_args["g_gain_value"] = self.get( - "corrections.gainXray.gGainValue" - ) if update.has("corrections.badPixels.maskingValue"): - self.bad_pixel_mask_value = np.float32( - self.get("corrections.badPixels.maskingValue") - ) self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) - self._kernel_runner_init_args[ - "bad_pixel_mask_value" - ] = self.bad_pixel_mask_value if any( path.startswith("corrections.badPixels.subsetToUse") @@ -794,13 +740,19 @@ class AgipdCorrection(BaseCorrection): "Some fields reenabled, reloading cached bad pixel constants" ) with self.calcat_friend.cached_constants_lock: + self.kernel_runner.flush_buffers( + { + AgipdConstants.BadPixelsDark, + AgipdConstants.BadPixelsPC, + AgipdConstants.BadPixelsFF, + } + ) for ( constant, data, ) in self.calcat_friend.cached_constants.items(): if "BadPixels" in constant.name: self._load_constant_to_runner(constant, data) - self._update_bad_pixel_selection() 
self.kernel_runner.override_bad_pixel_flags_to_use( self._override_bad_pixel_flags ) diff --git a/src/calng/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py similarity index 74% rename from src/calng/DsscCorrection.py rename to src/calng/corrections/DsscCorrection.py index c6605dcf535417f1fa415c50ee10243c4b0da4be..6c1b2d0ed7246bf61c17d5781244e775e36064e0 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/corrections/DsscCorrection.py @@ -1,17 +1,15 @@ import enum -import cupy import numpy as np from karabo.bound import ( - DOUBLE_ELEMENT, KARABO_CLASSINFO, + OUTPUT_CHANNEL, OVERWRITE_ELEMENT, VECTOR_STRING_ELEMENT, ) -from . import base_calcat, base_kernel_runner, utils -from ._version import version as deviceVersion -from .base_correction import BaseCorrection, add_correction_step_schema +from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils +from .._version import version as deviceVersion class CorrectionFlags(enum.IntFlag): @@ -25,23 +23,25 @@ class DsscConstants(enum.Enum): class DsscGpuRunner(base_kernel_runner.BaseGpuRunner): _kernel_source_filename = "dssc_gpu.cu" - _corrected_axis_order = "cyx" + _corrected_axis_order = "fyx" def __init__( self, pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype=np.uint16, output_data_dtype=np.float32, ): - self.input_shape = (memory_cells, pixels_y, pixels_x) + global cupy + import cupy + self.input_shape = (frames, pixels_y, pixels_x) self.processed_shape = self.input_shape super().__init__( pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype, output_data_dtype, @@ -49,18 +49,15 @@ class DsscGpuRunner(base_kernel_runner.BaseGpuRunner): self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x) self.offset_map_gpu = cupy.empty(self.map_shape, dtype=np.float32) - self._init_kernels() - - self.offset_map_gpu = cupy.empty(self.map_shape, dtype=np.float32) - self.update_block_size((1, 1, 64)) - def _get_raw_for_preview(self): - return self.input_data_gpu - - def _get_corrected_for_preview(self): - return self.processed_data_gpu + @property + def preview_data_views(self): + return ( + self.input_data_gpu, + self.processed_data_gpu, + ) def load_offset_map(self, offset_map): # can have an extra dimension for some reason @@ -88,7 +85,7 @@ class DsscGpuRunner(base_kernel_runner.BaseGpuRunner): { "pixels_x": self.pixels_x, "pixels_y": self.pixels_y, - "data_memory_cells": self.memory_cells, + "frames": self.frames, "constant_memory_cells": self.constant_memory_cells, "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), @@ -102,35 +99,28 @@ class DsscGpuRunner(base_kernel_runner.BaseGpuRunner): class DsscCalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = DsscConstants - def __init__(self, device, *args, **kwargs): - super().__init__(device, *args, **kwargs) - self._constants_need_conditions = { - DsscConstants.Offset: self.dark_condition, - } + @property + def _constants_need_conditions(self): + return {DsscConstants.Offset: self.dark_condition} @staticmethod - def add_schema( - schema, - managed_keys, - param_prefix="constantParameters", - status_prefix="foundConstants", - ): + def add_schema(schema, managed_keys): super(DsscCalcatFriend, DsscCalcatFriend).add_schema( - schema, managed_keys, "DSSC-Type", param_prefix, status_prefix + schema, managed_keys, "DSSC-Type" ) ( OVERWRITE_ELEMENT(schema) - 
.key(f"{param_prefix}.memoryCells") + .key("constantParameters.memoryCells") .setNewDefaultValue(400) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.biasVoltage") + .key("constantParameters.biasVoltage") .setNewDefaultValue(100) # TODO: proper .commit() ) - base_calcat.add_status_schema_from_enum(schema, status_prefix, DsscConstants) + base_calcat.add_status_schema_from_enum(schema, DsscConstants) def dark_condition(self): res = base_calcat.OperatingConditions() @@ -142,26 +132,31 @@ class DsscCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("DsscCorrection", deviceVersion) -class DsscCorrection(BaseCorrection): +class DsscCorrection(base_correction.BaseCorrection): # subclass *must* set these attributes _correction_flag_class = CorrectionFlags - _correction_field_names = (("offset", CorrectionFlags.OFFSET),) + _correction_steps = (("offset", CorrectionFlags.OFFSET, {DsscConstants.Offset}),) _kernel_runner_class = DsscGpuRunner _calcat_friend_class = DsscCalcatFriend _constant_enum_class = DsscConstants - _managed_keys = BaseCorrection._managed_keys.copy() + _managed_keys = base_correction.BaseCorrection._managed_keys.copy() @staticmethod def expectedParameters(expected): ( + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(schemas.xtdf_output_schema()) + .commit(), + OVERWRITE_ELEMENT(expected) - .key("dataFormat.memoryCells") + .key("dataFormat.frames") .setNewDefaultValue(400) .commit(), OVERWRITE_ELEMENT(expected) .key("dataFormat.outputAxisOrder") - .setNewDefaultValue("cyx") + .setNewDefaultValue("fyx") .commit(), OVERWRITE_ELEMENT(expected) @@ -170,10 +165,10 @@ class DsscCorrection(BaseCorrection): .commit(), ) DsscCalcatFriend.add_schema(expected, DsscCorrection._managed_keys) - add_correction_step_schema( + base_correction.add_correction_step_schema( expected, DsscCorrection._managed_keys, - DsscCorrection._correction_field_names, + DsscCorrection._correction_steps, ) ( VECTOR_STRING_ELEMENT(expected) @@ -186,7 +181,7 @@ class DsscCorrection(BaseCorrection): @property def input_data_shape(self): return ( - self.get("dataFormat.memoryCells"), + self.get("dataFormat.frames"), 1, self.get("dataFormat.pixelsY"), self.get("dataFormat.pixelsX"), @@ -200,7 +195,6 @@ class DsscCorrection(BaseCorrection): train_id, image_data, cell_table, - do_generate_preview, ): pulse_table = np.ravel(data_hash.get("image.pulseId")) if self._frame_filter is not None: @@ -230,20 +224,23 @@ class DsscCorrection(BaseCorrection): output_order=self.unsafe_get("dataFormat.outputAxisOrder"), out=buffer_array, ) - if do_generate_preview: + with self.warning_context( + "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS + ) as warn: if self._correction_flag_enabled != self._correction_flag_preview: self.kernel_runner.correct(self._correction_flag_preview) ( preview_slice_index, preview_cell, preview_pulse, - ) = utils.pick_frame_index( + ), preview_warning = utils.pick_frame_index( self.unsafe_get("preview.selectionMode"), self.unsafe_get("preview.index"), cell_table, pulse_table, - warn_func=self.log_status_warn, ) + if preview_warning is not None: + warn(preview_warning) preview_raw, preview_corrected = self.kernel_runner.compute_previews( preview_slice_index, ) @@ -253,20 +250,13 @@ class DsscCorrection(BaseCorrection): data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) data_hash.set("calngShmemPaths", [self._image_data_path]) self._write_output(data_hash, metadata) - if do_generate_preview: - self._write_preview_outputs( - ( - 
("preview.outputRaw", preview_raw), - ("preview.outputCorrected", preview_corrected), - ), - metadata, - ) + self._write_preview_outputs( + ( + ("preview.outputRaw", preview_raw), + ("preview.outputCorrected", preview_corrected), + ), + metadata, + ) def _load_constant_to_runner(self, constant, constant_data): - assert constant is DsscConstants.Offset self.kernel_runner.load_offset_map(constant_data) - if not self.get("corrections.offset.available"): - self.set("corrections.offset.available", True) - - self._update_correction_flags() - self.log_status_info(f"Done loading {constant.name} to GPU") diff --git a/src/calng/Gotthard2Correction.py b/src/calng/corrections/Gotthard2Correction.py similarity index 66% rename from src/calng/Gotthard2Correction.py rename to src/calng/corrections/Gotthard2Correction.py index f561ec1732d390c95c60a7988797933af0b5c76b..5db3b05919e7e58034c0fc10eba6ae9fd90ed0d0 100644 --- a/src/calng/Gotthard2Correction.py +++ b/src/calng/corrections/Gotthard2Correction.py @@ -3,43 +3,22 @@ import enum import numpy as np from karabo.bound import ( FLOAT_ELEMENT, - IMAGEDATA_ELEMENT, KARABO_CLASSINFO, - NODE_ELEMENT, OUTPUT_CHANNEL, OVERWRITE_ELEMENT, - UINT64_ELEMENT, VECTOR_STRING_ELEMENT, Dims, Encoding, - Hash, ImageData, - Schema, ) -from . import base_calcat, utils -from ._version import version as deviceVersion -from .base_correction import ( - BaseCorrection, - add_correction_step_schema, - preview_schema, -) -from .base_kernel_runner import BaseKernelRunner +from .. import base_calcat, base_correction, base_kernel_runner, schemas, utils +from .._version import version as deviceVersion _pretend_pulse_table = np.arange(2720, dtype=np.uint8) -streak_preview_schema = Schema() -( - NODE_ELEMENT(streak_preview_schema).key("image").commit(), - - IMAGEDATA_ELEMENT(streak_preview_schema).key("image.data").commit(), - - UINT64_ELEMENT(streak_preview_schema).key("trainId").readOnly().commit(), -) - - class Gotthard2Constants(enum.Enum): Lut = enum.auto() Offset = enum.auto() @@ -53,12 +32,12 @@ class CorrectionFlags(enum.IntFlag): GAIN = 4 -class Gotthard2CpuRunner(BaseKernelRunner): +class Gotthard2CpuRunner(base_kernel_runner.BaseKernelRunner): def __init__( self, pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype=np.uint16, output_data_dtype=np.float32, @@ -67,16 +46,16 @@ class Gotthard2CpuRunner(BaseKernelRunner): super().__init__( pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype, output_data_dtype, ) - from .kernels import gotthard2_cython + from ..kernels import gotthard2_cython self.correction_kernel = gotthard2_cython.correct - self.input_shape = (memory_cells, pixels_x) + self.input_shape = (frames, pixels_x) self.processed_shape = self.input_shape # model: 2 buffers (corresponding to actual memory cells), 2720 frames # lut maps from uint12 to uint10 values @@ -91,18 +70,17 @@ class Gotthard2CpuRunner(BaseKernelRunner): self.input_gain_stage = None # will just point to data coming in self.processed_data = None # will just point to buffer we're given - def _get_raw_for_preview(self): - return self.input_data - - def _get_corrected_for_preview(self): - return self.processed_data + @property + def input_data_views(self): + return (self.input_data, self.processed_data) def load_data(self, image_data, input_gain_stage): """Experiment: loading both in one function as they are tied""" self.input_data = image_data.astype(np.uint16, copy=False) self.input_gain_stage = input_gain_stage.astype(np.uint8, 
copy=False) - def flush_buffers(self): + def flush_buffers(self, constants=None): + # for now, always flushing all here... default_lut = ( np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12 ).astype(np.uint16) @@ -135,51 +113,58 @@ class Gotthard2CalcatFriend(base_calcat.BaseCalcatFriend): self._constants_need_conditions = {} # TODO @staticmethod - def add_schema( - schema, - managed_keys, - param_prefix="constantParameters", - status_prefix="foundConstants", - ): + def add_schema(schema, managed_keys): super(Gotthard2CalcatFriend, Gotthard2CalcatFriend).add_schema( - schema, managed_keys, "gotthard-Type", param_prefix, status_prefix + schema, managed_keys, "gotthard-Type" ) # set some defaults for common parameters ( OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.pixelsX") + .key("constantParameters.pixelsX") .setNewDefaultValue(1280) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.pixelsY") + .key("constantParameters.pixelsY") .setNewDefaultValue(1) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.memoryCells") + .key("constantParameters.memoryCells") .setNewDefaultValue(2) .commit(), ) base_calcat.add_status_schema_from_enum( - schema, status_prefix, Gotthard2Constants + schema, Gotthard2Constants ) @KARABO_CLASSINFO("Gotthard2Correction", deviceVersion) -class Gotthard2Correction(BaseCorrection): +class Gotthard2Correction(base_correction.BaseCorrection): _correction_flag_class = CorrectionFlags - _correction_field_names = ( - ("lut", CorrectionFlags.LUT), - ("offset", CorrectionFlags.OFFSET), - ("gain", CorrectionFlags.GAIN), + _correction_steps = ( + ( + "lut", + CorrectionFlags.LUT, + {Gotthard2Constants.Lut}, + ), + ( + "offset", + CorrectionFlags.OFFSET, + {Gotthard2Constants.Offset}, + ), + ( + "gain", + CorrectionFlags.GAIN, + {Gotthard2Constants.Gain}, + ), ) _kernel_runner_class = Gotthard2CpuRunner _calcat_friend_class = Gotthard2CalcatFriend _constant_enum_class = Gotthard2Constants - _managed_keys = BaseCorrection._managed_keys.copy() + _managed_keys = base_correction.BaseCorrection._managed_keys.copy() _image_data_path = "data.adc" _cell_table_path = "data.memoryCell" _warn_memory_cell_range = False # for now, receiver always writes 255 @@ -189,6 +174,11 @@ class Gotthard2Correction(BaseCorrection): def expectedParameters(expected): super(Gotthard2Correction, Gotthard2Correction).expectedParameters(expected) ( + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(schemas.jf_output_schema(use_shmem_handle=False)) + .commit(), + OVERWRITE_ELEMENT(expected) .key("dataFormat.pixelsX") .setNewDefaultValue(1280) @@ -200,7 +190,7 @@ class Gotthard2Correction(BaseCorrection): .commit(), OVERWRITE_ELEMENT(expected) - .key("dataFormat.memoryCells") + .key("dataFormat.frames") .setNewDefaultValue(2720) # note: actually just frames... 
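# --- editor's aside (not part of the change): the default LUT installed by
# Gotthard2CpuRunner.flush_buffers above is the identity mapping from the
# 12-bit raw range onto the 10-bit output range ("lut maps from uint12 to
# uint10 values"). A minimal sketch verifying that claim:
import numpy as np

default_lut = (
    np.arange(2 ** 12).astype(np.float64) * 2 ** 10 / 2 ** 12
).astype(np.uint16)
assert default_lut[0] == 0
assert default_lut[4] == 1  # every fourth raw value advances the 10-bit output
assert default_lut[4095] == 1023  # top of 12-bit range maps to top of 10-bit range
# --- end of aside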
.commit(),
@@ -218,12 +208,12 @@ class Gotthard2Correction(BaseCorrection):
         (
             OUTPUT_CHANNEL(expected)
             .key("preview.outputStreak")
-            .dataSchema(streak_preview_schema)
+            .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True))
             .commit(),
 
             OUTPUT_CHANNEL(expected)
             .key("preview.outputFrameSums")
-            .dataSchema(preview_schema)
+            .dataSchema(schemas.preview_schema())
             .commit(),
 
             FLOAT_ELEMENT(expected)
@@ -236,10 +226,10 @@ class Gotthard2Correction(BaseCorrection):
         )
 
         Gotthard2CalcatFriend.add_schema(expected, Gotthard2Correction._managed_keys)
-        add_correction_step_schema(
+        base_correction.add_correction_step_schema(
             expected,
             Gotthard2Correction._managed_keys,
-            Gotthard2Correction._correction_field_names,
+            Gotthard2Correction._correction_steps,
         )
 
         # mandatory: manager needs this in schema
@@ -254,31 +244,31 @@ class Gotthard2Correction(BaseCorrection):
     @property
     def input_data_shape(self):
         return (
-            self.unsafe_get("dataFormat.memoryCells"),
+            self.unsafe_get("dataFormat.frames"),
             self.unsafe_get("dataFormat.pixelsX"),
         )
 
     @property
     def output_data_shape(self):
         return (
-            self.unsafe_get("dataFormat.memoryCells"),
+            self.unsafe_get("dataFormat.frames"),
             self.unsafe_get("dataFormat.pixelsX"),
         )
 
+    @property
+    def _kernel_runner_init_args(self):
+        return {"bad_pixel_mask_value": self.bad_pixel_mask_value}
+
+    @property
+    def bad_pixel_mask_value(self):
+        return np.float32(self.unsafe_get("corrections.badPixels.maskingValue"))
+
     def __init__(self, config):
         super().__init__(config)
-        # TODO: gain mode as constant parameter and / or device configuration
         try:
-            self.bad_pixel_mask_value = np.float32(
-                config.get("corrections.badPixels.maskingValue")
-            )
+            np.float32(config.get("corrections.badPixels.maskingValue"))
         except ValueError:
-            self.bad_pixel_mask_value = np.float32("nan")
-
-        self._kernel_runner_init_args = {
-            "bad_pixel_mask_value": self.bad_pixel_mask_value,
-        }
+            config["corrections.badPixels.maskingValue"] = "nan"
 
     def process_data(
         self,
@@ -288,40 +278,40 @@ class Gotthard2Correction(BaseCorrection):
         train_id,
         image_data,
         cell_table,
-        do_generate_preview,
     ):
         # cell table currently not used for GOTTHARD2 (assume alternating)
         gain_map = np.asarray(data_hash.get("data.gain"))
         if self.unsafe_get("dataFormat.overrideInputAxisOrder"):
             gain_map.shape = self.input_data_shape
         try:
-            self.kernel_runner.load_data(
-                image_data, gain_map
-            )
+            self.kernel_runner.load_data(image_data, gain_map)
         except Exception as e:
             self.log_status_warn(f"Unknown exception when loading data: {e}")
 
         buffer_handle, buffer_array = self._shmem_buffer.next_slot()
         self.kernel_runner.correct(self._correction_flag_enabled, out=buffer_array)
-        if do_generate_preview:
+        with self.warning_context(
+            "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS
+        ) as warn:
             if self._correction_flag_enabled != self._correction_flag_preview:
                 self.kernel_runner.correct(self._correction_flag_preview)
             (
                 preview_slice_index,
                 preview_cell,
                 preview_pulse,
-            ) = utils.pick_frame_index(
+            ), preview_warning = utils.pick_frame_index(
                 self.unsafe_get("preview.selectionMode"),
                 self.unsafe_get("preview.index"),
                 cell_table,
                 _pretend_pulse_table,
-                warn_func=self.log_status_warn,
             )
-            (
-                preview_raw,
-                preview_corrected,
-            ) = self.kernel_runner.compute_previews(preview_slice_index)
+            if preview_warning is not None:
+                warn(preview_warning)
+            (
+                preview_raw,
+                preview_corrected,
+            ) = self.kernel_runner.compute_previews(preview_slice_index)
 
         # reusing input data hash for sending
         data_hash.set(self._image_data_path,
buffer_handle)
@@ -329,47 +319,39 @@ class Gotthard2Correction(BaseCorrection):
 
         self._write_output(data_hash, metadata)
 
-        if do_generate_preview:
-            streak_preview = buffer_array.copy()
-            replacement = self.unsafe_get("preview.replaceNanWith")
-            streak_preview[np.isnan(streak_preview)] = replacement
-            streak_preview[np.isinf(streak_preview)] = replacement
-            frame_sums = np.sum(streak_preview, axis=1)
-            self._write_preview_outputs(
+        streak_preview = buffer_array.copy()
+        replacement = self.unsafe_get("preview.replaceNanWith")
+        streak_preview = np.nan_to_num(
+            streak_preview,
+            copy=False,
+            nan=replacement,
+            posinf=replacement,
+            neginf=replacement,
+        )
+        frame_sums = np.sum(streak_preview, axis=1)
+        self._write_preview_outputs(
+            (
+                ("preview.outputRaw", preview_raw),
+                ("preview.outputCorrected", preview_corrected),
                 (
-                    ("preview.outputRaw", preview_raw),
-                    ("preview.outputCorrected", preview_corrected),
-                    (
-                        "preview.outputStreak",
-                        ImageData(
-                            streak_preview,
-                            Dims(*streak_preview.shape),
-                            Encoding.GRAY,
-                            bitsPerPixel=32,
-                        ),
+                    "preview.outputStreak",
+                    ImageData(
+                        streak_preview,
+                        Dims(*streak_preview.shape),
+                        Encoding.GRAY,
+                        bitsPerPixel=32,
                     ),
-                    ("preview.outputFrameSums", frame_sums),
                 ),
-                metadata,
-            )
+                ("preview.outputFrameSums", frame_sums),
+            ),
+            metadata,
+        )
 
     def _load_constant_to_runner(self, constant, constant_data):
         if constant is Gotthard2Constants.Lut:
             self.kernel_runner.lut[:] = constant_data.astype(np.uint16, copy=False)
-            if not self.get("corrections.lut.available"):
-                self.set("corrections.lut.available", True)
         elif constant is Gotthard2Constants.Offset:
             self.kernel_runner.offset_map[:] = constant_data.astype(
                 np.float32, copy=False
             )
-            if not self.get("corrections.offset.available"):
-                self.set("corrections.offset.available", True)
         elif constant is Gotthard2Constants.Gain:
             self.kernel_runner.rel_gain_map[:] = constant_data.astype(
                 np.float32, copy=False
             )
-            if not self.get("corrections.gain.available"):
-                self.set("corrections.gain.available", True)
-
-        self._update_correction_flags()
-        self.log_status_info(f"Done loading {constant.name}")
diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py
new file mode 100644
index 0000000000000000000000000000000000000000..77eeabcd73cbc0cde41dd664127e470abf0c4522
--- /dev/null
+++ b/src/calng/corrections/JungfrauCorrection.py
@@ -0,0 +1,813 @@
+import concurrent.futures
+import enum
+import functools
+
+import numpy as np
+from karabo.bound import (
+    DOUBLE_ELEMENT,
+    KARABO_CLASSINFO,
+    OUTPUT_CHANNEL,
+    OVERWRITE_ELEMENT,
+    STRING_ELEMENT,
+    VECTOR_STRING_ELEMENT,
+    Schema,
+)
+
+from ..
import (
+    base_calcat,
+    base_correction,
+    base_kernel_runner,
+    schemas,
+    preview_utils,
+    utils,
+)
+from .._version import version as deviceVersion
+
+
+_pretend_pulse_table = np.arange(16, dtype=np.uint8)
+
+
+class JungfrauConstants(enum.Enum):
+    Offset10Hz = enum.auto()
+    BadPixelsDark10Hz = enum.auto()
+    BadPixelsFF10Hz = enum.auto()
+    RelativeGain10Hz = enum.auto()
+
+
+# from pycalibration (TODO: move to common shared lib)
+class JungfrauGainMode(enum.Enum):
+    DYNAMIC_GAIN = "dynamicgain"
+    DYNAMIC_GAIN_HG0 = "dynamichg0"
+    FIX_GAIN_1 = "fixgain1"
+    FIX_GAIN_2 = "fixgain2"
+    FORCE_SWITCH_HG1 = "forceswitchg1"
+    FORCE_SWITCH_HG2 = "forceswitchg2"
+
+
+class CorrectionFlags(enum.IntFlag):
+    NONE = 0
+    OFFSET = 1
+    REL_GAIN = 2
+    BPMASK = 4
+    STRIXEL = 8
+
+
+class KernelRunnerVersions(enum.Enum):
+    GPU = enum.auto()
+    CPU = enum.auto()
+
+
+class JungfrauGpuRunner(base_kernel_runner.BaseGpuRunner):
+    _kernel_source_filename = "jungfrau_gpu.cu"
+    _corrected_axis_order = "fyx"
+
+    def __init__(
+        self,
+        pixels_x,
+        pixels_y,
+        frames,
+        constant_memory_cells,
+        input_data_dtype=np.uint16,
+        output_data_dtype=np.float32,
+        bad_pixel_mask_value=np.nan,
+    ):
+        global cupy
+        import cupy
+        self.input_shape = (frames, pixels_y, pixels_x)
+        self.processed_shape = self.input_shape
+        self.update_block_size((1, 1, 64))
+
+        super().__init__(
+            pixels_x,
+            pixels_y,
+            frames,
+            constant_memory_cells,
+            input_data_dtype,
+            output_data_dtype,
+        )
+        # note: superclass creates cell table with wrong dtype
+        self.cell_table_gpu = cupy.empty(self.frames, dtype=np.uint8)
+        self.input_gain_stage_gpu = cupy.empty(self.input_shape, dtype=np.uint8)
+        self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3)
+        self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=np.float32)
+        self.rel_gain_map_gpu = cupy.ones(self.map_shape, dtype=np.float32)
+        self.bad_pixel_map_gpu = cupy.zeros(self.map_shape, dtype=np.uint32)
+        self.bad_pixel_mask_value = bad_pixel_mask_value
+
+        # strixel support
+        self._strixel_out_shape = (frames, 86, 3090)
+        self._strixel_block = (1, 1, 64)
+        # note: only executing kernel on lower half of y range, hence 256
+        self._strixel_grid = tuple(
+            utils.ceil_div(a_length, block_length)
+            for (a_length, block_length) in zip(
+                (frames, 256, 1024), self._strixel_block
+            )
+        )
+        self._processed_data_regular_gpu = self.processed_data_gpu
+        self._processed_data_strixel_gpu = cupy.empty(
+            self._strixel_out_shape, dtype=output_data_dtype
+        )
+
+    def _init_kernels(self):
+        kernel_source = self._kernel_template.render(
+            {
+                "pixels_x": self.pixels_x,
+                "pixels_y": self.pixels_y,
+                "frames": self.frames,
+                "constant_memory_cells": self.constant_memory_cells,
+                "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype),
+                "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype),
+                "corr_enum": utils.enum_to_c_template(CorrectionFlags),
+                "burst_mode": self.burst_mode,
+            }
+        )
+        self.source_module = cupy.RawModule(code=kernel_source)
+        self.correction_kernel = self.source_module.get_function("correct")
+        self.strixel_transform_kernel = self.source_module.get_function(
+            "strixel_transform"
+        )
+
+    @property
+    def burst_mode(self):
+        return self.frames > 1
+
+    @property
+    def preview_data_views(self):
+        return (self.input_data_gpu, self.processed_data_gpu, self.input_gain_stage_gpu)
+
+    def load_data(self, image_data, input_gain_stage, cell_table):
+        """Experiment: loading all three in one function as they are tied"""
+        self.input_data_gpu.set(image_data)
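# --- editor's aside (not part of the change): several runners in this change
# swap a module-level "import cupy" for "global cupy; import cupy" inside
# __init__. A minimal sketch of that deferred-import pattern (GpuRunnerSketch
# is a made-up name); the module stays importable on hosts without CUDA, and
# only the GPU code path pays for the import:
import numpy as np

class GpuRunnerSketch:
    def __init__(self, shape):
        global cupy  # binds the module-level name on first instantiation
        import cupy  # raises here, not at module import time, if CUDA is absent
        self.data_gpu = cupy.empty(shape, dtype=np.float32)
# --- end of aside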
+ self.input_gain_stage_gpu.set(input_gain_stage) + if self.burst_mode: + self.cell_table_gpu.set(cell_table) + + def flush_buffers(self, constants): + if JungfrauConstants.Offset10Hz in constants: + self.offset_map_gpu.fill(0) + if JungfrauConstants.RelativeGain10Hz in constants: + self.rel_gain_map_gpu.fill(1) + if constants & { + JungfrauConstants.BadPixelsDark10Hz, + JungfrauConstants.BadPixelsFF10Hz, + }: + self.bad_pixel_map_gpu.fill(0) + self.bad_pixel_map_gpu[ + :, :, 255:1023:256 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map_gpu[ + :, :, 256:1024:256 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map_gpu[ + :, [255, 256] + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + + def override_bad_pixel_flags_to_use(self, override_value): + self.bad_pixel_map_gpu &= cupy.uint32(override_value) + + def correct(self, flags): + self.correction_kernel( + self.full_grid, + self.full_block, + ( + self.input_data_gpu, + self.input_gain_stage_gpu, + self.cell_table_gpu, + cupy.uint8(flags), + self.offset_map_gpu, + self.rel_gain_map_gpu, + self.bad_pixel_map_gpu, + self.bad_pixel_mask_value, + self._processed_data_regular_gpu, + ), + ) + if flags & CorrectionFlags.STRIXEL: + self.strixel_transform_kernel( + self._strixel_grid, + self._strixel_block, + ( + self._processed_data_regular_gpu, + self._processed_data_strixel_gpu, + ), + ) + self.processed_data_gpu = self._processed_data_strixel_gpu + else: + self.processed_data_gpu = self._processed_data_regular_gpu + + +class JungfrauCpuRunner(base_kernel_runner.BaseKernelRunner): + _corrected_axis_order = "fyx" + + def __init__( + self, + pixels_x, + pixels_y, + frames, + constant_memory_cells, + input_data_dtype=np.uint16, + output_data_dtype=np.float32, # TODO: configurable + bad_pixel_mask_value=np.nan, + ): + self.input_shape = (frames, pixels_y, pixels_x) + self.preview_shape = (pixels_y, pixels_x) + self.processed_shape = self.input_shape + self._strixel_out_shape = (frames, 86, 3090) + + super().__init__( + pixels_x, + pixels_y, + frames, + constant_memory_cells, + input_data_dtype, + output_data_dtype, + ) + + # not actually allocating, will just point to incoming data + self.input_data = None + self.input_gain_stage = None + self.processed_data = None + self._processed_data_regular = np.empty( + self.processed_shape, dtype=output_data_dtype + ) + self._processed_data_strixel = np.empty( + self._strixel_out_shape, dtype=output_data_dtype + ) + + # for computing previews faster + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=3) + + from ..kernels import jungfrau_cython + self.correction_kernel_single = jungfrau_cython.correct_single + self.correction_kernel_burst = jungfrau_cython.correct_burst + self.correction_kernel_strixel = jungfrau_cython.strixel_transform + + self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x, 3) + self.offset_map = np.empty(self.map_shape, dtype=np.float32) + self.rel_gain_map = np.empty(self.map_shape, dtype=np.float32) + self.bad_pixel_map = np.empty(self.map_shape, dtype=np.uint32) + self.bad_pixel_mask_value = bad_pixel_mask_value + self.flush_buffers(set(JungfrauConstants)) + + def __del__(self): + self.thread_pool.shutdown() + + @property + def preview_data_views(self): + return (self.input_data, self.processed_data, self.input_gain_stage) + + @property + def burst_mode(self): + return self.frames > 1 + + def load_data(self, image_data, input_gain_stage, cell_table): + self.input_data = image_data.astype(np.uint16, 
copy=False) + self.input_gain_stage = input_gain_stage.astype(np.uint8, copy=False) + self.input_cell_table = cell_table.astype(np.uint8, copy=False) + + def flush_buffers(self, constants): + if JungfrauConstants.Offset10Hz in constants: + self.offset_map.fill(0) + if JungfrauConstants.RelativeGain10Hz in constants: + self.rel_gain_map.fill(1) + if constants & { + JungfrauConstants.BadPixelsDark10Hz, + JungfrauConstants.BadPixelsFF10Hz, + }: + self.bad_pixel_map.fill(0) + self.bad_pixel_map[ + :, :, 255:1023:256 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map[ + :, :, 256:1024:256 + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + self.bad_pixel_map[ + :, [255, 256] + ] |= utils.BadPixelValues.NON_STANDARD_SIZE.value + + def override_bad_pixel_flags_to_use(self, override_value): + self.bad_pixel_map &= np.uint32(override_value) + + def correct(self, flags): + if self.burst_mode: + self.correction_kernel_burst( + self.input_data, + self.input_gain_stage, + self.input_cell_table, + flags, + self.offset_map, + self.rel_gain_map, + self.bad_pixel_map, + self.bad_pixel_mask_value, + self._processed_data_regular, + ) + else: + self.correction_kernel_single( + self.input_data, + self.input_gain_stage, + flags, + self.offset_map, + self.rel_gain_map, + self.bad_pixel_map, + self.bad_pixel_mask_value, + self._processed_data_regular, + ) + + if flags & CorrectionFlags.STRIXEL: + self.correction_kernel_strixel( + self._processed_data_regular, self._processed_data_strixel + ) + self.processed_data = self._processed_data_strixel + else: + self.processed_data = self._processed_data_regular + + def compute_previews(self, preview_index): + if preview_index < -4: + raise ValueError(f"No statistic with code {preview_index} defined") + elif preview_index >= self.frames: + raise ValueError(f"Memory cell index {preview_index} out of range") + + if preview_index >= 0: + def fun(a): + return a[preview_index] + elif preview_index == -1: + # note: separate from next case because dtype not applicable here + fun = functools.partial(np.nanmax, axis=0) + elif preview_index in (-2, -3, -4): + fun = functools.partial( + { + -2: np.nanmean, + -3: np.nansum, + -4: np.nanstd, + }[preview_index], + axis=0, + dtype=np.float32, + ) + return self.thread_pool.map(fun, self.preview_data_views) + + +class JungfrauCalcatFriend(base_calcat.BaseCalcatFriend): + _constant_enum_class = JungfrauConstants + + @property + def _constants_need_conditions(self): + return { + JungfrauConstants.Offset10Hz: self.dark_condition, + JungfrauConstants.BadPixelsDark10Hz: self.dark_condition, + JungfrauConstants.BadPixelsFF10Hz: self.dark_condition, + JungfrauConstants.RelativeGain10Hz: self.dark_condition, + } + + @staticmethod + def add_schema(schema, managed_keys): + super(JungfrauCalcatFriend, JungfrauCalcatFriend).add_schema( + schema, managed_keys, "jungfrau-Type" + ) + + # set some defaults for common parameters + ( + OVERWRITE_ELEMENT(schema) + .key("constantParameters.pixelsX") + .setNewDefaultValue(1024) + .commit(), + + OVERWRITE_ELEMENT(schema) + .key("constantParameters.pixelsY") + .setNewDefaultValue(512) + .commit(), + + OVERWRITE_ELEMENT(schema) + .key("constantParameters.memoryCells") + .setNewDefaultValue(1) + .commit(), + + OVERWRITE_ELEMENT(schema) + .key("constantParameters.biasVoltage") + .setNewDefaultValue(90) + .commit(), + ) + + # add extra parameters + ( + DOUBLE_ELEMENT(schema) + .key("constantParameters.integrationTime") + .displayedName("Integration time") + .description("Integration time in 
ms") + .assignmentOptional() + .defaultValue(350) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key("constantParameters.sensorTemperature") + .displayedName("Sensor temperature") + .description("Sensor temperature in K") + .assignmentOptional() + .defaultValue(291) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key("constantParameters.gainMode") + .displayedName("Gain mode") + .description( + "Detector may be operating in one of several gain modes. In CalCat, " + "these map to two CalCat parameters: 'Gain Setting' which is 1 for " + "dynamic gain with HG0 (0 otherwise) and 'Gain mode' which is 1 " + "for fixed gain (omitted otherwise)." + ) + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key("constantParameters.gainSetting") + .displayedName("Gain setting") + .description("See description of gainMode") + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + ) + managed_keys.add("constantParameters.integrationTime") + managed_keys.add("constantParameters.sensorTemperature") + managed_keys.add("constantParameters.gainMode") + managed_keys.add("constantParameters.gainSetting") + + base_calcat.add_status_schema_from_enum( + schema, JungfrauConstants + ) + + def dark_condition(self): + res = base_calcat.OperatingConditions() + res["Memory cells"] = self._get_param("memoryCells") + res["Sensor Bias Voltage"] = self._get_param("biasVoltage") + res["Pixels X"] = self._get_param("pixelsX") + res["Pixels Y"] = self._get_param("pixelsY") + res["Integration Time"] = self._get_param("integrationTime") + res["Sensor Temperature"] = self._get_param("sensorTemperature") + + if self._get_param("gainMode") != 0: + # NOTE: always include if CalCat is updated for this + res["Gain mode"] = self._get_param("gainMode") + res["Gain Setting"] = self._get_param("gainSetting") + + return res + + +@KARABO_CLASSINFO("JungfrauCorrection", deviceVersion) +class JungfrauCorrection(base_correction.BaseCorrection): + _correction_flag_class = CorrectionFlags + _correction_steps = ( + ("offset", CorrectionFlags.OFFSET, {JungfrauConstants.Offset10Hz}), + ("relGain", CorrectionFlags.REL_GAIN, {JungfrauConstants.RelativeGain10Hz}), + ( + "badPixels", + CorrectionFlags.BPMASK, + { + JungfrauConstants.BadPixelsDark10Hz, + JungfrauConstants.BadPixelsFF10Hz, + None, + } + ), + ( + "strixel", + CorrectionFlags.STRIXEL, set() + ) + ) + _kernel_runner_class = None # note: set in __init__ based on config + _calcat_friend_class = JungfrauCalcatFriend + _constant_enum_class = JungfrauConstants + _managed_keys = base_correction.BaseCorrection._managed_keys.copy() + _image_data_path = "data.adc" + _cell_table_path = "data.memoryCell" + + @staticmethod + def expectedParameters(expected): + ( + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(schemas.jf_output_schema()) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("dataFormat.pixelsX") + .setNewDefaultValue(1024) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("dataFormat.pixelsY") + .setNewDefaultValue(512) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("dataFormat.frames") + .setNewDefaultValue(1) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("preview.selectionMode") + .setNewDefaultValue("frame") + .commit(), + + # JUNGFRAU data is small, can fit plenty of trains in here + OVERWRITE_ELEMENT(expected) + .key("outputShmemBufferSize") + .setNewDefaultValue(2) + .commit(), + ) + ( + # support both CPU and GPU kernels + STRING_ELEMENT(expected) + 
.key("kernelType") + .assignmentOptional() + .defaultValue(KernelRunnerVersions.GPU.name) + .options(",".join(kernel_type.name for kernel_type in KernelRunnerVersions)) + .reconfigurable() + .commit(), + ) + JungfrauCorrection._managed_keys.add("kernelType") + base_correction.add_correction_step_schema( + expected, + JungfrauCorrection._managed_keys, + JungfrauCorrection._correction_steps, + ) + ( + OVERWRITE_ELEMENT(expected) + .key("corrections.strixel.enable") + .setNewDefaultValue(False) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("corrections.strixel.preview") + .setNewDefaultValue(False) + .commit(), + ) + JungfrauCalcatFriend.add_schema(expected, JungfrauCorrection._managed_keys) + base_correction.add_bad_pixel_config_node( + expected, JungfrauCorrection._managed_keys + ) + ( + OUTPUT_CHANNEL(expected) + .key("preview.outputGainMap") + .dataSchema(schemas.preview_schema()) + .commit(), + ) + + # mandatory: manager needs this in schema + ( + VECTOR_STRING_ELEMENT(expected) + .key("managedKeys") + .assignmentOptional() + .defaultValue(list(JungfrauCorrection._managed_keys)) + .commit() + ) + + @property + def input_data_shape(self): + return ( + self.unsafe_get("dataFormat.frames"), + self.unsafe_get("dataFormat.pixelsY"), + self.unsafe_get("dataFormat.pixelsX"), + ) + + @property + def _warn_memory_cell_range(self): + # disable warning in normal operation as cell 15 is expected + return self.unsafe_get("dataFormat.frames") > 1 + + @property + def output_data_shape(self): + if self._correction_flag_enabled & CorrectionFlags.STRIXEL: + axis_lengths = { + "x": 3090, + "y": 86, + "f": self.unsafe_get("dataFormat.filteredFrames"), + } + return tuple( + axis_lengths[axis] + for axis in self.unsafe_get("dataFormat.outputAxisOrder") + ) + return super().output_data_shape + + @property + def _kernel_runner_class(self): + kernel_type = KernelRunnerVersions[self.unsafe_get("kernelType")] + if kernel_type is KernelRunnerVersions.CPU: + return JungfrauCpuRunner + else: + return JungfrauGpuRunner + + @property + def _kernel_runner_init_args(self): + return { + "bad_pixel_mask_value": self.bad_pixel_mask_value, + } + + @property + def bad_pixel_mask_value(self): + return np.float32(self.unsafe_get("corrections.badPixels.maskingValue")) + + _override_bad_pixel_flags = property(base_correction.get_bad_pixel_field_selection) + + def __init__(self, config): + super().__init__(config) + try: + np.float32(config.get("corrections.badPixels.maskingValue", default="nan")) + except ValueError: + config["corrections.badPixels.maskingValue"] = "nan" + + if config.get("runAsStandaloneModule", default=False): + schema_override = Schema() + ( + OUTPUT_CHANNEL(schema_override) + .key("dataOutput") + .dataSchema(schemas.jf_output_schema(use_shmem_handle=False)) + .commit(), + ) + preview_utils.PreviewFriend.add_schema( + schema_override, + output_channels=["outputRaw", "outputCorrected", "outputGainMap"], + ) + self.updateSchema(schema_override) + + def aux(): + self._preview_friend = preview_utils.PreviewFriend( + self, + output_channels=["outputRaw", "outputCorrected", "outputGainMap"], + ) + self["availableScenes"] = self["availableScenes"] + [ + "preview:outputRaw", + "preview:outputAssembled", + "preview:outputGainMap", + ] + + self.registerInitialFunction(aux) + + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + ): + if len(cell_table.shape) == 0: + cell_table = cell_table[np.newaxis] + try: + gain_map = data_hash["data.gain"] + if 
self.unsafe_get("dataFormat.overrideInputAxisOrder"): + gain_map.shape = self.input_data_shape + self.kernel_runner.load_data( + image_data, gain_map, cell_table + ) + except ValueError as e: + self.log_status_warn(f"Failed to load data: {e}") + return + except Exception as e: + self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") + + buffer_handle, buffer_array = self._shmem_buffer.next_slot() + self.kernel_runner.correct(self._correction_flag_enabled) + self.kernel_runner.reshape( + output_order=self.unsafe_get("dataFormat.outputAxisOrder"), + out=buffer_array, + ) + + with self.warning_context( + "processingState", base_correction.WarningLampType.PREVIEW_SETTINGS + ) as warn: + if self._correction_flag_enabled != self._correction_flag_preview: + self.kernel_runner.correct(self._correction_flag_preview) + ( + preview_slice_index, + preview_cell, + preview_pulse, + ), preview_warning = utils.pick_frame_index( + self.unsafe_get("preview.selectionMode"), + self.unsafe_get("preview.index"), + cell_table, + _pretend_pulse_table, + ) + if preview_warning is not None: + warn(preview_warning) + ( + preview_raw, + preview_corrected, + preview_gain_map, + ) = self.kernel_runner.compute_previews(preview_slice_index) + + # reusing input data hash for sending + if self.unsafe_get("runAsStandaloneModule"): + # TODO: use shmem for data.gain, too + data_hash.set(self._image_data_path, buffer_array) + data_hash.set("calngShmemPaths", []) + else: + data_hash.set(self._image_data_path, buffer_handle) + data_hash.set("calngShmemPaths", [self._image_data_path]) + + self._write_output(data_hash, metadata) + + if self.unsafe_get("runAsStandaloneModule"): + self._preview_friend.maybe_write( + [preview_raw, preview_corrected, preview_gain_map] + ) + else: + self._write_preview_outputs( + ( + ("preview.outputRaw", preview_raw), + ("preview.outputCorrected", preview_corrected), + ("preview.outputGainMap", preview_gain_map), + ), + metadata, + ) + + @property + def _kernel_type(self): + return KernelRunnerVersions[self.unsafe_get("kernelType")] + + def _load_constant_to_runner(self, constant, constant_data): + if constant_data.shape[0] == self.get("dataFormat.pixelsX"): + constant_data = np.transpose(constant_data, (2, 1, 0, 3)) + else: + constant_data = np.transpose(constant_data, (2, 0, 1, 3)) + + if constant is JungfrauConstants.Offset10Hz: + if self._kernel_type is KernelRunnerVersions.CPU: + self.kernel_runner.offset_map[:] = constant_data.astype(np.float32) + else: + self.kernel_runner.offset_map_gpu.set(constant_data.astype(np.float32)) + if not self.get("corrections.offset.available"): + self.set("corrections.offset.available", True) + elif constant is JungfrauConstants.RelativeGain10Hz: + if self._kernel_type is KernelRunnerVersions.CPU: + self.kernel_runner.rel_gain_map[:] = constant_data.astype(np.float32) + else: + self.kernel_runner.rel_gain_map_gpu.set( + constant_data.astype(np.float32) + ) + if not self.get("corrections.relGain.available"): + self.set("corrections.relGain.available", True) + elif constant in ( + JungfrauConstants.BadPixelsDark10Hz, + JungfrauConstants.BadPixelsFF10Hz, + ): + if self._kernel_type is KernelRunnerVersions.CPU: + self.kernel_runner.bad_pixel_map |= constant_data + else: + self.kernel_runner.bad_pixel_map_gpu |= cupy.asarray(constant_data) + if not self.get("corrections.badPixels.available"): + self.set("corrections.badPixels.available", True) + self.kernel_runner.override_bad_pixel_flags_to_use( + self._override_bad_pixel_flags + ) + + 
self._update_correction_flags() + self.log_status_info(f"Done loading {constant.name} to runner") + + def postReconfigure(self): + super().postReconfigure() + + if not hasattr(self, "_prereconfigure_update_hash"): + return + + update = self._prereconfigure_update_hash + + if update.has("corrections.strixel.enable"): + self._lock_and_update(self._update_frame_filter) + + if any( + path.startswith("corrections.badPixels.subsetToUse") + for path in update.getPaths() + ): + self.log_status_info("Updating bad pixel maps based on subset specified") + if any( + update.get( + f"corrections.badPixels.subsetToUse.{field.name}", default=False + ) + for field in utils.BadPixelValues + ): + self.log_status_info( + "Some fields reenabled, reloading cached bad pixel constants" + ) + with self.calcat_friend.cached_constants_lock: + self.kernel_runner.flush_buffers( + { + JungfrauConstants.BadPixelsDark10Hz, + JungfrauConstants.BadPixelsFF10Hz, + } + ) + for ( + constant, + data, + ) in self.calcat_friend.cached_constants.items(): + if "BadPixels" in constant.name: + self._load_constant_to_runner(constant, data) + self.kernel_runner.override_bad_pixel_flags_to_use( + self._override_bad_pixel_flags + ) + + if self._preview_friend is not None: + self._preview_friend.reconfigure(update) diff --git a/src/calng/LpdCorrection.py b/src/calng/corrections/LpdCorrection.py similarity index 67% rename from src/calng/LpdCorrection.py rename to src/calng/corrections/LpdCorrection.py index 2f8321302f34df976bd85c36dfee29fc79b6fdc7..a009d62ec29d59fb9c1f903bf1caacbe795b1826 100644 --- a/src/calng/LpdCorrection.py +++ b/src/calng/corrections/LpdCorrection.py @@ -1,6 +1,5 @@ import enum -import cupy import numpy as np from karabo.bound import ( DOUBLE_ELEMENT, @@ -9,19 +8,30 @@ from karabo.bound import ( OVERWRITE_ELEMENT, STRING_ELEMENT, VECTOR_STRING_ELEMENT, + Schema, ) -from . import base_kernel_runner, base_calcat, utils -from ._version import version as deviceVersion -from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema +from .. 
import ( + base_calcat, + base_kernel_runner, + schemas, + preview_utils, + utils, +) +from .._version import version as deviceVersion +from ..base_correction import ( + BaseCorrection, + WarningLampType, + add_correction_step_schema, +) class LpdConstants(enum.Enum): Offset = enum.auto() - BadPixelsDark = enum.auto() GainAmpMap = enum.auto() - FFMap = enum.auto() RelativeGain = enum.auto() + FFMap = enum.auto() + BadPixelsDark = enum.auto() BadPixelsFF = enum.auto() @@ -36,49 +46,50 @@ class CorrectionFlags(enum.IntFlag): class LpdGpuRunner(base_kernel_runner.BaseGpuRunner): _kernel_source_filename = "lpd_gpu.cu" - _corrected_axis_order = "cxy" + _corrected_axis_order = "fxy" def __init__( self, pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, - input_data_dtype=cupy.uint16, - output_data_dtype=cupy.float32, - bad_pixel_mask_value=cupy.nan, + input_data_dtype=np.uint16, + output_data_dtype=np.float32, + bad_pixel_mask_value=np.nan, ): - self.input_shape = (memory_cells, 1, pixels_y, pixels_x) - self.processed_shape = (memory_cells, pixels_y, pixels_x) + global cupy + import cupy + self.input_shape = (frames, 1, pixels_y, pixels_x) + self.processed_shape = (frames, pixels_y, pixels_x) super().__init__( pixels_x, pixels_y, - memory_cells, + frames, constant_memory_cells, input_data_dtype, output_data_dtype, ) - self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=cupy.float32) - self.preview_buffer_getters.append(self._get_gain_map_for_preview) + self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=np.float32) self.map_shape = (constant_memory_cells, pixels_x, pixels_y, 3) - self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32) - self.gain_amp_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) - self.rel_gain_slopes_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) - self.flatfield_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) - self.bad_pixel_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.uint32) + self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=np.float32) + self.gain_amp_map_gpu = cupy.ones(self.map_shape, dtype=np.float32) + self.rel_gain_slopes_map_gpu = cupy.ones(self.map_shape, dtype=np.float32) + self.flatfield_map_gpu = cupy.ones(self.map_shape, dtype=np.float32) + self.bad_pixel_map_gpu = cupy.zeros(self.map_shape, dtype=np.uint32) self.bad_pixel_mask_value = bad_pixel_mask_value self.update_block_size((1, 1, 64)) - def _get_raw_for_preview(self): - return self.input_data_gpu[:, 0] - - def _get_corrected_for_preview(self): - return self.processed_data_gpu - - def _get_gain_map_for_preview(self): - return self.gain_map_gpu + @property + def preview_data_views(self): + # TODO: always split off gain from raw to avoid messing up preview? 
+ return ( + self.input_data_gpu[:, 0], # raw + self.processed_data_gpu, # corrected + self.gain_map_gpu, # gain (split from raw) + ) def correct(self, flags): self.correction_kernel( @@ -136,7 +147,7 @@ class LpdGpuRunner(base_kernel_runner.BaseGpuRunner): { "pixels_x": self.pixels_x, "pixels_y": self.pixels_y, - "data_memory_cells": self.memory_cells, + "frames": self.frames, "constant_memory_cells": self.constant_memory_cells, "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), @@ -146,20 +157,25 @@ class LpdGpuRunner(base_kernel_runner.BaseGpuRunner): self.source_module = cupy.RawModule(code=kernel_source) self.correction_kernel = self.source_module.get_function("correct") - def flush_buffers(self): - self.offset_map_gpu.fill(0) - self.gain_amp_map_gpu.fill(1) - self.rel_gain_slopes_map_gpu.fill(1) - self.flatfield_map_gpu.fill(1) - self.bad_pixel_map_gpu.fill(0) + def flush_buffers(self, constants): + if LpdConstants.Offset in constants: + self.offset_map_gpu.fill(0) + if LpdConstants.GainAmpMap in constants: + self.gain_amp_map_gpu.fill(1) + if LpdConstants.RelativeGain in constants: + self.rel_gain_slopes_map_gpu.fill(1) + if LpdConstants.FFMap in constants: + self.flatfield_map_gpu.fill(1) + if constants & {LpdConstants.BadPixelsDark, LpdConstants.BadPixelsFF}: + self.bad_pixel_map_gpu.fill(0) class LpdCalcatFriend(base_calcat.BaseCalcatFriend): _constant_enum_class = LpdConstants - def __init__(self, device, *args, **kwargs): - super().__init__(device, *args, **kwargs) - self._constants_need_conditions = { + @property + def _constants_need_conditions(self): + return { LpdConstants.Offset: self.dark_condition, LpdConstants.BadPixelsDark: self.dark_condition, LpdConstants.GainAmpMap: self.category_condition, @@ -172,63 +188,61 @@ class LpdCalcatFriend(base_calcat.BaseCalcatFriend): def add_schema( schema, managed_keys, - param_prefix="constantParameters", - status_prefix="foundConstants", ): super(LpdCalcatFriend, LpdCalcatFriend).add_schema( - schema, managed_keys, "LPD-Type", param_prefix, status_prefix + schema, managed_keys, "LPD-Type" ) ( OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.pixelsX") + .key("constantParameters.pixelsX") .setNewDefaultValue(256) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.pixelsY") + .key("constantParameters.pixelsY") .setNewDefaultValue(256) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.memoryCells") + .key("constantParameters.memoryCells") .setNewDefaultValue(512) .commit(), OVERWRITE_ELEMENT(schema) - .key(f"{param_prefix}.biasVoltage") + .key("constantParameters.biasVoltage") .setNewDefaultValue(250) .commit(), ) ( DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.feedbackCapacitor") + .key("constantParameters.feedbackCapacitor") .assignmentOptional() .defaultValue(5) .reconfigurable() .commit(), DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.photonEnergy") + .key("constantParameters.photonEnergy") .assignmentOptional() .defaultValue(9.3) .reconfigurable() .commit(), DOUBLE_ELEMENT(schema) - .key(f"{param_prefix}.category") + .key("constantParameters.category") .displayedName("Category") .assignmentOptional() .defaultValue(0) .reconfigurable() .commit(), ) - managed_keys.add(f"{param_prefix}.feedbackCapacitor") - managed_keys.add(f"{param_prefix}.photonEnergy") - managed_keys.add(f"{param_prefix}.category") + managed_keys.add("constantParameters.feedbackCapacitor") + managed_keys.add("constantParameters.photonEnergy") + 
managed_keys.add("constantParameters.category") - base_calcat.add_status_schema_from_enum(schema, status_prefix, LpdConstants) + base_calcat.add_status_schema_from_enum(schema, LpdConstants) def dark_condition(self): res = base_calcat.OperatingConditions() @@ -254,12 +268,19 @@ class LpdCalcatFriend(base_calcat.BaseCalcatFriend): @KARABO_CLASSINFO("LpdCorrection", deviceVersion) class LpdCorrection(BaseCorrection): _correction_flag_class = CorrectionFlags - _correction_field_names = ( - ("offset", CorrectionFlags.OFFSET), - ("gainAmp", CorrectionFlags.GAIN_AMP), - ("relGain", CorrectionFlags.REL_GAIN), - ("flatfield", CorrectionFlags.FF_CORR), - ("badPixels", CorrectionFlags.BPMASK), + _correction_steps = ( + ("offset", CorrectionFlags.OFFSET, {LpdConstants.Offset}), + ("gainAmp", CorrectionFlags.GAIN_AMP, {LpdConstants.GainAmpMap}), + ("relGain", CorrectionFlags.REL_GAIN, {LpdConstants.RelativeGain}), + ("flatfield", CorrectionFlags.FF_CORR, {LpdConstants.FFMap}), + ( + "badPixels", + CorrectionFlags.BPMASK, + { + LpdConstants.BadPixelsDark, + LpdConstants.BadPixelsFF, + } + ), ) _kernel_runner_class = LpdGpuRunner _calcat_friend_class = LpdCalcatFriend @@ -269,6 +290,11 @@ class LpdCorrection(BaseCorrection): @staticmethod def expectedParameters(expected): ( + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(schemas.xtdf_output_schema()) + .commit(), + OVERWRITE_ELEMENT(expected) .key("dataFormat.pixelsX") .setNewDefaultValue(256) @@ -280,7 +306,7 @@ class LpdCorrection(BaseCorrection): .commit(), OVERWRITE_ELEMENT(expected) - .key("dataFormat.memoryCells") + .key("dataFormat.frames") .setNewDefaultValue(512) .commit(), @@ -294,14 +320,16 @@ class LpdCorrection(BaseCorrection): ( OUTPUT_CHANNEL(expected) .key("preview.outputGainMap") - .dataSchema(preview_schema) + .dataSchema(schemas.preview_schema()) .commit(), ) - LpdCalcatFriend.add_schema(expected, LpdCorrection._managed_keys) add_correction_step_schema( - expected, LpdCorrection._managed_keys, LpdCorrection._correction_field_names + expected, + LpdCorrection._managed_keys, + LpdCorrection._correction_steps, ) + LpdCalcatFriend.add_schema(expected, LpdCorrection._managed_keys) # additional settings for correction steps ( @@ -332,7 +360,7 @@ class LpdCorrection(BaseCorrection): @property def input_data_shape(self): return ( - self.unsafe_get("dataFormat.memoryCells"), + self.unsafe_get("dataFormat.frames"), 1, self.unsafe_get("dataFormat.pixelsX"), self.unsafe_get("dataFormat.pixelsY"), @@ -341,28 +369,39 @@ class LpdCorrection(BaseCorrection): def __init__(self, config): super().__init__(config) try: - bad_pixel_mask_value = np.float32( - config.get("corrections.badPixels.maskingValue") - ) + np.float32(config.get("corrections.badPixels.maskingValue")) except ValueError: - bad_pixel_mask_value = np.float32("nan") - self._kernel_runner_init_args = {"bad_pixel_mask_value": bad_pixel_mask_value} + config["corrections.badPixels.maskingValue"] = "nan" + + if config.get("runAsStandaloneModule", default=False): + schema_override = Schema() + ( + OUTPUT_CHANNEL(schema_override) + .key("dataOutput") + .dataSchema(schemas.xtdf_output_schema(use_shmem_handle=False)) + .commit(), + ) + preview_utils.PreviewFriend.add_schema( + schema_override, + output_channels=["outputRaw", "outputCorrected", "outputGainMap"], + ) + self.updateSchema(schema_override) + + def aux(): + self._preview_friend = preview_utils.PreviewFriend( + self, + output_channels=["outputRaw", "outputCorrected", "outputGainMap"], + ) + self["availableScenes"] = 
self["availableScenes"] + [ + "preview:outputRaw", + "preview:outputAssembled", + "preview:outputGainMap", + ] + + self.registerInitialFunction(aux) def _load_constant_to_runner(self, constant, constant_data): self.kernel_runner.load_constant(constant, constant_data) - correction_step = { - LpdConstants.Offset: "offset", - LpdConstants.GainAmpMap: "gainAmp", - LpdConstants.RelativeGain: "relGain", - LpdConstants.FFMap: "flatfield", - LpdConstants.BadPixelsDark: "badPixels", - LpdConstants.BadPixelsFF: "badPixels", - }[constant] - correction_node = f"corrections.{correction_step}" - if not self.get(f"{correction_node}.available"): - self.set(f"{correction_node}.available", True) - self._update_correction_flags() - self.log_status_info(f"Done loading {constant.name} to GPU") def process_data( self, @@ -372,7 +411,6 @@ class LpdCorrection(BaseCorrection): train_id, image_data, cell_table, - do_generate_preview, ): pulse_table = np.ravel(data_hash.get("image.pulseId")) if self._frame_filter is not None: @@ -402,34 +440,47 @@ class LpdCorrection(BaseCorrection): output_order=self.unsafe_get("dataFormat.outputAxisOrder"), out=buffer_array, ) - if do_generate_preview: + with self.warning_context( + "processingState", WarningLampType.PREVIEW_SETTINGS + ) as warn: if self._correction_flag_enabled != self._correction_flag_preview: self.kernel_runner.correct(self._correction_flag_preview) ( preview_slice_index, preview_cell, preview_pulse, - ) = utils.pick_frame_index( + ), preview_warning = utils.pick_frame_index( self.unsafe_get("preview.selectionMode"), self.unsafe_get("preview.index"), cell_table, pulse_table, - warn_func=self.log_status_warn, ) + if preview_warning is not None: + warn(preview_warning) ( preview_raw, preview_corrected, preview_gain_map, - ) = self.kernel_runner.compute_previews( - preview_slice_index, - ) + ) = self.kernel_runner.compute_previews(preview_slice_index) + + if self.unsafe_get("runAsStandaloneModule"): + data_hash.set(self._image_data_path, buffer_array) + data_hash.set(self._cell_table_path, cell_table[:, np.newaxis]) + data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) + data_hash.set("calngShmemPaths", []) + else: + data_hash.set(self._image_data_path, buffer_handle) + data_hash.set(self._cell_table_path, cell_table[:, np.newaxis]) + data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) + data_hash.set("calngShmemPaths", [self._image_data_path]) - data_hash.set(self._image_data_path, buffer_handle) - data_hash.set(self._cell_table_path, cell_table[:, np.newaxis]) - data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) - data_hash.set("calngShmemPaths", [self._image_data_path]) self._write_output(data_hash, metadata) - if do_generate_preview: + + if self.unsafe_get("runAsStandaloneModule"): + self._preview_friend.maybe_write( + [preview_raw, preview_corrected, preview_gain_map] + ) + else: self._write_preview_outputs( ( ("preview.outputRaw", preview_raw), @@ -454,6 +505,4 @@ class LpdCorrection(BaseCorrection): update = self._prereconfigure_update_hash if update.has("corrections.badPixels.maskingValue"): - masking_value = np.float32(update["corrections.badPixels.maskingValue"]) - self._kernel_runner_init_args["bad_pixel_mask_value"] = masking_value - self.kernel_runner.bad_pixel_mask_value = masking_value + self.kernel_runner.bad_pixel_mask_value = self.bad_pixel_mask_value diff --git a/src/calng/corrections/__init__.py b/src/calng/corrections/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..405275c5e267553efcc9c9cbe32aa64f72201f2b --- /dev/null +++ b/src/calng/corrections/__init__.py @@ -0,0 +1,8 @@ +# flake8: noqa: F401 +from . import ( + AgipdCorrection, + DsscCorrection, + Gotthard2Correction, + JungfrauCorrection, + LpdCorrection, +) diff --git a/src/calng/geometries/__init__.py b/src/calng/geometries/__init__.py index b82477f10ce25946ac036ffbaee347e407cb5824..091660d02ce4110a0cf979269dabe1c171b5805d 100644 --- a/src/calng/geometries/__init__.py +++ b/src/calng/geometries/__init__.py @@ -1 +1,2 @@ +# flake8: noqa: F401 from . import Agipd1MGeometry, Dssc1MGeometry, Lpd1MGeometry, JungfrauGeometry diff --git a/src/calng/kernels/agipd_gpu.cu b/src/calng/kernels/agipd_gpu.cu index 20b08d43d26b5f37d25c8633adb60ed3179db648..1e8cf672614fe108970530cc46fae419049a01ff 100644 --- a/src/calng/kernels/agipd_gpu.cu +++ b/src/calng/kernels/agipd_gpu.cu @@ -27,14 +27,14 @@ extern "C" { {{output_data_dtype}}* output) { const size_t X = {{pixels_x}}; const size_t Y = {{pixels_y}}; - const size_t input_cells = {{data_memory_cells}}; + const size_t input_frames = {{frames}}; const size_t map_cells = {{constant_memory_cells}}; const size_t cell = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = blockIdx.y * blockDim.y + threadIdx.y; const size_t y = blockIdx.z * blockDim.z + threadIdx.z; - if (cell >= input_cells || y >= Y || x >= X) { + if (cell >= input_frames || y >= Y || x >= X) { return; } @@ -51,7 +51,7 @@ extern "C" { 1 * data_stride_raw_gain + y * data_stride_y + x * data_stride_x; - float corrected = (float)data[data_index]; + float res = (float)data[data_index]; const float raw_gain_val = (float)data[raw_gain_index]; const size_t output_stride_y = 1; @@ -111,38 +111,29 @@ extern "C" { x * gm_map_stride_x; if ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index]) { - corrected = bad_pixel_mask_value; + res = bad_pixel_mask_value; gain_map[output_index] = bad_pixel_mask_value; } else { if (corr_flags & OFFSET) { - corrected -= offset_map[gm_map_index]; + res -= offset_map[gm_map_index]; // TODO: optionally reassign gain stage for this pixel based on new value } // TODO: baseline shift if (corr_flags & REL_GAIN_PC) { - corrected *= rel_gain_pc_map[gm_map_index]; + res *= rel_gain_pc_map[gm_map_index]; if (gain == 1) { - corrected += md_additional_offset[map_index]; + res += md_additional_offset[map_index]; } } if (corr_flags & GAIN_XRAY) { - corrected = (corrected / rel_gain_xray_map[map_index]) * g_gain_value; + res = (res / rel_gain_xray_map[map_index]) * g_gain_value; } } - {% if output_data_dtype == "half" %} - output[output_index] = __float2half(corrected); - {% else %} - output[output_index] = ({{output_data_dtype}})corrected; - {% endif %} - } else { - // TODO: decide what to do when we cannot threshold - {% if output_data_dtype == "half" %} - output[data_index] = __float2half(corrected); - {% else %} - output[data_index] = ({{output_data_dtype}})corrected; - {% endif %} - - gain_map[data_index] = 255; } + {% if output_data_dtype == "half" %} + output[output_index] = __float2half(res); + {% else %} + output[output_index] = ({{output_data_dtype}})res; + {% endif %} } } diff --git a/src/calng/kernels/dssc_gpu.cu b/src/calng/kernels/dssc_gpu.cu index a35eed986a4483e0b84ca84e35d2cc3d56d11cb5..61cd77eacc9b81a5f3d6d7f0e32c5bcaf03f169b 100644 --- a/src/calng/kernels/dssc_gpu.cu +++ b/src/calng/kernels/dssc_gpu.cu @@ -18,14 +18,14 @@ extern "C" { {{output_data_dtype}}* output) { const size_t X = {{pixels_x}}; const size_t Y = 
{{pixels_y}}; - const size_t memory_cells = {{data_memory_cells}}; + const size_t input_frames = {{frames}}; const size_t map_memory_cells = {{constant_memory_cells}}; const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x; const size_t y = blockIdx.y * blockDim.y + threadIdx.y; const size_t x = blockIdx.z * blockDim.z + threadIdx.z; - if (memory_cell >= memory_cells || y >= Y || x >= X) { + if (memory_cell >= input_frames || y >= Y || x >= X) { return; } @@ -34,29 +34,23 @@ extern "C" { const size_t data_stride_y = X * data_stride_x; const size_t data_stride_cell = Y * data_stride_y; const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x; - const float raw = (float)data[data_index]; + float res = (float)data[data_index]; const size_t map_stride_x = 1; const size_t map_stride_y = X * map_stride_x; const size_t map_stride_cell = Y * map_stride_y; const size_t map_cell = cell_table[memory_cell]; + if (map_cell < map_memory_cells) { const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x; - float corrected = raw; if (corr_flags & OFFSET) { - corrected -= offset_map[map_index]; + res -= offset_map[map_index]; } - {% if output_data_dtype == "half" %} - output[data_index] = __float2half(corrected); - {% else %} - output[data_index] = ({{output_data_dtype}})corrected; - {% endif %} - } else { - {% if output_data_dtype == "half" %} - output[data_index] = __float2half(raw); - {% else %} - output[data_index] = ({{output_data_dtype}})raw; - {% endif %} } + {% if output_data_dtype == "half" %} + output[data_index] = __float2half(res); + {% else %} + output[data_index] = ({{output_data_dtype}})res; + {% endif %} } } diff --git a/src/calng/kernels/jungfrau_cpu.pyx b/src/calng/kernels/jungfrau_cpu.pyx index 02d92537c8d214bd30103bf29ee7acaba734db4f..9b6072d41fa3c588941e400959b38728dabd8643 100644 --- a/src/calng/kernels/jungfrau_cpu.pyx +++ b/src/calng/kernels/jungfrau_cpu.pyx @@ -6,8 +6,11 @@ cdef unsigned char NONE = 0 cdef unsigned char OFFSET = 1 cdef unsigned char REL_GAIN = 2 cdef unsigned char BPMASK = 4 +cdef unsigned char STRIXEL = 8 from cython.parallel import prange +from cython.view cimport contiguous +import numpy as np def correct_burst( @@ -21,20 +24,20 @@ def correct_burst( float badpixel_fill_value, float[:, :, :] output, ): - cdef int input_cell, map_cell, x, y + cdef int frame, map_cell, x, y cdef unsigned char gain cdef float corrected - for input_cell in prange(image_data.shape[0], nogil=True): - map_cell = cell_table[input_cell] + for frame in prange(image_data.shape[0], nogil=True): + map_cell = cell_table[frame] if map_cell >= offset_map.shape[0]: for y in range(image_data.shape[1]): for x in range(image_data.shape[2]): - output[input_cell, y, x] = <float>image_data[input_cell, y, x] + output[frame, y, x] = <float>image_data[frame, y, x] continue for y in range(image_data.shape[1]): for x in range(image_data.shape[2]): - corrected = image_data[input_cell, y, x] - gain = gain_stage[input_cell, y, x] + corrected = image_data[frame, y, x] + gain = gain_stage[frame, y, x] # legal values: 0, 1, or 3 if gain == 2: corrected = badpixel_fill_value @@ -49,16 +52,16 @@ def correct_burst( corrected = corrected - offset_map[map_cell, y, x, gain] if (flags & REL_GAIN): corrected = corrected / relgain_map[map_cell, y, x, gain] - output[input_cell, y, x] = corrected + output[frame, y, x] = corrected def correct_single( unsigned short[:, :, :] image_data, unsigned char[:, :, :] gain_stage, unsigned char flags, - 
float[:, :, :] offset_map, - float[:, :, :] relgain_map, - unsigned[:, :, :] badpixel_mask, + float[:, :, :, :] offset_map, + float[:, :, :, :] relgain_map, + unsigned[:, :, :, :] badpixel_mask, float badpixel_fill_value, float[:, :, :] output, ): @@ -77,11 +80,41 @@ def correct_single( if gain == 3: gain = 2 - if (flags & BPMASK) and badpixel_mask[y, x, gain] != 0: + if (flags & BPMASK) and badpixel_mask[0, y, x, gain] != 0: corrected = badpixel_fill_value else: if (flags & OFFSET): - corrected = corrected - offset_map[y, x, gain] + corrected = corrected - offset_map[0, y, x, gain] if (flags & REL_GAIN): - corrected = corrected / relgain_map[y, x, gain] + corrected = corrected / relgain_map[0, y, x, gain] output[0, y, x] = corrected + + +def strixel_transform( + float[:, :, ::contiguous] image_data, + float[:, :, ::contiguous] output +): + cdef int yin, xin, igap, ichip, xout, yout, frame + + for frame in range(image_data.shape[0]): + for yin in range(256) : + yout = int(yin / 3) + for xin in range(1024) : + ichip = <int>(xin / 256) + xout = (ichip * 774) + (xin % 256) * 3 + yin % 3 + # 774 is the chip period, 256*3+6 + output[frame, yout, xout] = image_data[frame, yin, xin] + # now the gap pixels... + for yin in range(256): + yout = <int>(yin / 6) * 2 + for igap in range(3) : + # first the left side of gap + xin = igap * 256 + 255 + xout = igap * 774 + 765 + yin % 6 + output[frame, yout, xout] = image_data[frame, yin, xin] + output[frame, yout+1, xout] = image_data[frame, yin, xin] + # then the right side is mirrored + xin = igap * 256 + 255 + 1 + xout = igap * 774 + 765 + 11 - yin % 6 + output[frame, yout, xout] = image_data[frame, yin, xin] + output[frame, yout+1, xout] = image_data[frame, yin, xin] diff --git a/src/calng/kernels/jungfrau_gpu.cu b/src/calng/kernels/jungfrau_gpu.cu index d111c0b903e1d67aa85783f5a7914271a15c8d9c..8a7bac69a922bc59dc8a5ed6c857d783926793ea 100644 --- a/src/calng/kernels/jungfrau_gpu.cu +++ b/src/calng/kernels/jungfrau_gpu.cu @@ -14,21 +14,21 @@ extern "C" { {{output_data_dtype}}* output) { const size_t X = {{pixels_x}}; const size_t Y = {{pixels_y}}; - const size_t memory_cells = {{data_memory_cells}}; + const size_t input_frames = {{frames}}; const size_t map_memory_cells = {{constant_memory_cells}}; - const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x; + const size_t current_frame = blockIdx.x * blockDim.x + threadIdx.x; const size_t y = blockIdx.y * blockDim.y + threadIdx.y; const size_t x = blockIdx.z * blockDim.z + threadIdx.z; - if (memory_cell >= memory_cells || y >= Y || x >= X) { + if (current_frame >= input_frames || y >= Y || x >= X) { return; } const size_t data_stride_x = 1; const size_t data_stride_y = X * data_stride_x; - const size_t data_stride_cell = Y * data_stride_y; - const size_t data_index = memory_cell * data_stride_cell + + const size_t data_stride_frame = Y * data_stride_y; + const size_t data_index = current_frame * data_stride_frame + y * data_stride_y + x * data_stride_x; float res = (float)data[data_index]; @@ -45,7 +45,7 @@ extern "C" { {% if burst_mode %} // burst mode: "cell 255" will get copied // TODO: consider masking "cell 255" - const size_t map_cell = cell_table[memory_cell]; + const size_t map_cell = cell_table[current_frame]; {% else %} // single cell: "cell 255" will get "corrected" const size_t map_cell = 0; @@ -83,4 +83,108 @@ extern "C" { output[data_index] = ({{output_data_dtype}})res; {% endif %} } + + __global__ void strixel_transform(const {{output_data_dtype}}* data, // shape: memory cell, y, x 
+ {{output_data_dtype}}* output) { + const size_t Xin = {{pixels_x}}; + const size_t Yin = {{pixels_y}}; + const size_t Xout = 3090; + const size_t Yout = 86; + const size_t input_frames = {{frames}}; + + const size_t current_frame = blockIdx.x * blockDim.x + threadIdx.x; + // following naming from cython version + const size_t yin = blockIdx.y * blockDim.y + threadIdx.y; + size_t xin = blockIdx.z * blockDim.z + threadIdx.z; + + // note: hardcoded limits here as only half of y-axis is used + if (current_frame >= input_frames || yin >= 256 || xin >= 1024) { + return; + } + + // avoid race conditions by only writing these once + const size_t overwritten_columns[18] = { + 765, + 766, + 767, + 774, + 775, + 776, + 1539, + 1540, + 1541, + 1548, + 1549, + 1550, + 2313, + 2314, + 2315, + 2322, + 2323, + 2324 + }; + + const size_t data_stride_x = 1; + const size_t data_stride_y = Xin * data_stride_x; + const size_t data_stride_frame = Yin * data_stride_y; + + const size_t output_stride_x = 1; + const size_t output_stride_y = Xout * output_stride_x; + const size_t output_stride_frame = Yout * output_stride_y; + + const size_t ichip = xin / 256; + size_t xout = (ichip * 774) + (xin % 256) * 3 + (yin % 3); + size_t yout = yin / 3; + bool will_be_overwritten = false; + size_t out_index, data_index; + for (int i=0; i<18; ++i) { + if (xout == overwritten_columns[i]) { + will_be_overwritten = true; + } + } + if (!will_be_overwritten) { + out_index = current_frame * output_stride_frame + + yout * output_stride_y + + xout * output_stride_x; + data_index = current_frame * data_stride_frame + + yin * data_stride_y + + xin * data_stride_x; + output[out_index] = data[data_index]; + } + if (xin < 3) { + // reuse for the gap pixel case (see cython version) + const size_t igap = xin; + yout = (yin / 6) * 2; + + // left side + xin = igap * 256 + 255; + xout = igap * 774 + 765 + yin % 6; + data_index = current_frame * data_stride_frame + + yin * data_stride_y + + xin * data_stride_x; + out_index = current_frame * output_stride_frame + + yout * output_stride_y + + xout * output_stride_x; + output[out_index] = data[data_index]; + out_index = current_frame * output_stride_frame + + (yout + 1) * output_stride_y + + xout * output_stride_x; + output[out_index] = data[data_index]; + + // mirror right side + xin = igap * 256 + 255 + 1; + xout = igap * 774 + 765 + 11 - yin % 6; + data_index = current_frame * data_stride_frame + + yin * data_stride_y + + xin * data_stride_x; + out_index = current_frame * output_stride_frame + + yout * output_stride_y + + xout * output_stride_x; + output[out_index] = data[data_index]; + out_index = current_frame * output_stride_frame + + (yout + 1) * output_stride_y + + xout * output_stride_x; + output[out_index] = data[data_index]; + } + } } diff --git a/src/calng/kernels/lpd_gpu.cu b/src/calng/kernels/lpd_gpu.cu index 4fb24fb3e0d4658b3d111efb3b8aae28af035ac1..6c84e7c3f9166cc8296622cadc22f406dbb5f223 100644 --- a/src/calng/kernels/lpd_gpu.cu +++ b/src/calng/kernels/lpd_gpu.cu @@ -16,14 +16,14 @@ extern "C" { {{output_data_dtype}}* output) { const size_t X = {{pixels_x}}; const size_t Y = {{pixels_y}}; - const size_t memory_cells = {{data_memory_cells}}; + const size_t input_frames = {{frames}}; const size_t map_memory_cells = {{constant_memory_cells}}; const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x; const size_t y = blockIdx.y * blockDim.y + threadIdx.y; const size_t x = blockIdx.z * blockDim.z + threadIdx.z; - if (memory_cell >= memory_cells || y >= Y || x >= X) { + if 
(memory_cell >= input_frames || y >= Y || x >= X) { return; } diff --git a/src/calng/preview_utils.py b/src/calng/preview_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a4c0139ba1dcfb84013028a72e427485ab7e30 --- /dev/null +++ b/src/calng/preview_utils.py @@ -0,0 +1,187 @@ +from karabo.bound import ( + BOOL_ELEMENT, + DOUBLE_ELEMENT, + FLOAT_ELEMENT, + NODE_ELEMENT, + OUTPUT_CHANNEL, + STRING_ELEMENT, + UINT32_ELEMENT, + ChannelMetaData, + Dims, + Encoding, + Hash, + ImageData, + Unit, +) + +import numpy as np + +from . import schemas, utils + + +class PreviewFriend: + @staticmethod + def add_schema(schema, node_path="preview", output_channels=None): + if output_channels is None: + output_channels = ["output"] + ( + NODE_ELEMENT(schema) + .key(node_path) + .displayedName("Preview") + .description( + "Output specifically intended for preview in Karabo GUI. Includes " + "some options for throttling and adjustments of the output data." + ) + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_path}.flipSS") + .displayedName("Flip SS") + .description("Flip image data along slow scan axis.") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_path}.flipFS") + .displayedName("Flip FS") + .description("Flip image data along fast scan axis.") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + UINT32_ELEMENT(schema) + .key(f"{node_path}.downsamplingFactor") + .displayedName("Factor") + .description( + "If greater than 1, the preview image will be downsampled by this " + "factor before sending. This is mostly to save bandwidth in case GUI " + "updates start lagging." + ) + .assignmentOptional() + .defaultValue(1) + .options("1,2,4,8") + .reconfigurable() + .commit(), + + STRING_ELEMENT(schema) + .key(f"{node_path}.downsamplingFunction") + .displayedName("Function") + .description("Reduction function used during downsampling.") + .assignmentOptional() + .defaultValue("nanmax") + .options("nanmax,nanmean,nanmin,nanmedian") + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(schema) + .key(f"{node_path}.replaceNanWith") + .displayedName("NaN replacement") + .description( + "Displaying images in KaraboGUI seems to not go well when there are " + "NaN values in data. And there will be with bad pixel masking or just " + "geometry space between modules. NaN values get replaced with this " + "value to get around this; choose a value which clearly stands out " + "from the image data you want to see." + ) + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{node_path}.maxRate") + .displayedName("Max rate") + .description( + "Preview output is throttled to (at most) this speed. Data arriving " + "too quickly after last send is silently dropped." 
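+                # note: implemented via utils.SkippingThrottler, recreated in
+                # reconfigure() with minimum send interval 1 / maxRate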
+            )
+            .unit(Unit.HERTZ)
+            .assignmentOptional()
+            .defaultValue(2)
+            .reconfigurable()
+            .commit(),
+        )
+        for channel in output_channels:
+            (
+                OUTPUT_CHANNEL(schema)
+                .key(f"{node_path}.{channel}")
+                .dataSchema(schemas.preview_schema(wrap_image_in_imagedata=True))
+                .description("See description of parent node, 'preview'.")
+                .commit(),
+            )
+
+    def __init__(self, device, node_name="preview", output_channels=None):
+        if output_channels is None:
+            output_channels = ["output"]
+        self.output_channels = output_channels
+        self.device = device
+        self.node_name = node_name
+        self.outputs = [
+            self.device.signalSlotable.getOutputChannel(f"{self.node_name}.{channel}")
+            for channel in self.output_channels
+        ]
+        self.reconfigure(device._parameters)
+
+    def maybe_write(self, datas, inplace=True):
+        """If doing so would not exceed maxRate, apply preview settings to each
+        entry of datas and write the resulting hashes to the corresponding output
+        channels; otherwise silently skip this round."""
+        if self.throttler.test_and_set():
+            timestamp = self.device.getActualTimestamp()
+            dev_id = self.device.getInstanceId()
+            for data, output, channel_name in zip(
+                datas, self.outputs, self.output_channels
+            ):
+                if self.downsampling_factor > 1:
+                    data = utils.downsample_2d(
+                        data,
+                        self.downsampling_factor,
+                        reduction_fun=self.downsampling_function,
+                    )
+                elif not inplace:
+                    data = data.copy()
+                data = np.nan_to_num(
+                    data,
+                    copy=False,
+                    nan=self.nan_replacement,
+                )
+                if self.flip_ss:
+                    data = np.flip(data, 0)
+                if self.flip_fs:
+                    data = np.flip(data, 1)
+                output_hash = Hash(
+                    "image.data",
+                    ImageData(
+                        data,
+                        Dims(*data.shape),
+                        Encoding.GRAY,
+                        bitsPerPixel=32,
+                    ),
+                )
+                output.write(
+                    output_hash,
+                    ChannelMetaData(f"{dev_id}:{channel_name}", timestamp),
+                    copyAllData=False,
+                )
+                output.update()
+
+    def reconfigure(self, conf):
+        if conf.has(f"{self.node_name}.maxRate"):
+            self.throttler = utils.SkippingThrottler(
+                1 / conf[f"{self.node_name}.maxRate"]
+            )
+        if conf.has(f"{self.node_name}.downsamplingFunction"):
+            self.downsampling_function = getattr(
+                np, conf[f"{self.node_name}.downsamplingFunction"]
+            )
+        if conf.has(f"{self.node_name}.downsamplingFactor"):
+            self.downsampling_factor = conf[f"{self.node_name}.downsamplingFactor"]
+        if conf.has(f"{self.node_name}.flipSS"):
+            self.flip_ss = conf[f"{self.node_name}.flipSS"]
+        if conf.has(f"{self.node_name}.flipFS"):
+            self.flip_fs = conf[f"{self.node_name}.flipFS"]
+        if conf.has(f"{self.node_name}.replaceNanWith"):
+            self.nan_replacement = conf[f"{self.node_name}.replaceNanWith"]
diff --git a/src/calng/scenes.py b/src/calng/scenes.py
index 6fa0926608691500b0d76718c7c4d2f32507f4a2..08dd04a6c6dd5b0e8d30fbd499df1fb0869ba399 100644
--- a/src/calng/scenes.py
+++ b/src/calng/scenes.py
@@ -4,7 +4,6 @@ import karabo.native
 import karathon
 from karabo.common.scenemodel.api import (
     CheckBoxModel,
-    ColorBoolModel,
     ComboBoxModel,
     DeviceSceneLinkModel,
     DetectorGraphModel,
@@ -16,6 +15,7 @@ from karabo.common.scenemodel.api import (
     EvaluatorModel,
     IntLineEditModel,
     LabelModel,
+    LampModel,
     LineEditModel,
     LineModel,
     RectangleModel,
@@ -23,9 +23,13 @@ from karabo.common.scenemodel.api import (
     SceneTargetWindow,
     TableElementModel,
     TrendGraphModel,
+    UnknownWidgetDataModel,
+    VectorXYGraphModel,
     WebCamGraphModel,
+    WebLinkModel,
     write_scene,
 )
+import natsort
 
 
 # section: common setup
@@ -34,10 +38,14 @@ from karabo.common.scenemodel.api import (
 BASE_INC = 25
 NARROW_INC = 20
 PADDING = 5
-RECONFIGURABLE = 4  # TODO: look up proper enum
-NODE_TYPE_NODE = 1
-_type_to_display_model = {"BOOL": CheckBoxModel} + +def DisplayRoundedFloat(*args, decimals=2, **kwargs): + # note: naive subclass breaks as registry looks for writer based on exact class + return EvaluatorModel(*args, expression=f"f'{{x:.{decimals}f}}'", **kwargs) + + +_type_to_display_model = {"BOOL": CheckBoxModel, "FLOAT": DisplayRoundedFloat} _type_to_line_editable = { "BOOL": (CheckBoxModel, {"klass": "EditableCheckBox"}), "DOUBLE": (DoubleLineEditModel, {}), @@ -125,11 +133,6 @@ def boxed(component_class): # section: useful layout and utility classes -def DisplayRoundedFloat(*args, decimals=2, **kwargs): - # note: naive subclass breaks as registry looks for writer based on exact class - return EvaluatorModel(*args, expression=f"f'{{x:.{decimals}f}}'", **kwargs) - - class Space: def __init__(self, width, height): self.width = width @@ -156,6 +159,23 @@ class Hline: ] +class Vline: + def __init__(self, height): + self.width = 0 + self.height = height + + def render(self, x, y): + return [ + LineModel( + stroke="#000000", + x1=x, + x2=x, + y1=y, + y2=y + self.height, + ) + ] + + def dummy_wrap(model_class): class Wrapper: def __init__(self, *args, **kwargs): @@ -245,18 +265,21 @@ class VerticalLayout: ) -class MaybeEditableRow(HorizontalLayout): +class DisplayAndEditableRow(HorizontalLayout): def __init__( self, device_id, schema_hash, key_path, - label_width=7 * NARROW_INC, - display_width=5 * NARROW_INC, - edit_width=5 * NARROW_INC, - height=NARROW_INC, + label_width=7, + display_width=5, + edit_width=5, + height=None, + size_scale=BASE_INC, ): super().__init__(padding=0) + if height is None: + height = size_scale key_attr = schema_hash.getAttributes(key_path) label_text = ( key_attr["displayedName"] @@ -267,26 +290,30 @@ class MaybeEditableRow(HorizontalLayout): print(f"Key {key_path} on {device_id} had no valueType") return value_type = key_attr["valueType"] - self.children.extend( - [ - LabelModel( - text=label_text, - width=label_width, - height=height, - ), + + self.children.append( + LabelModel( + text=label_text, + width=label_width * size_scale, + height=height, + ) + ) + + if self.include_display(key_attr): + self.children.append( _type_to_display_model.get(value_type, DisplayLabelModel)( keys=[f"{device_id}.{key_path}"], - width=display_width, + width=display_width * size_scale, height=height, - ), - ] - ) - if key_attr["accessMode"] == RECONFIGURABLE: + ) + ) + + if self.include_editable(key_attr): if "options" in key_attr: self.children.append( ComboBoxModel( keys=[f"{device_id}.{key_path}"], - width=edit_width, + width=edit_width * size_scale, height=height, klass="EditableComboBox", ) @@ -296,7 +323,7 @@ class MaybeEditableRow(HorizontalLayout): self.children.append( line_editable_class( keys=[f"{device_id}.{key_path}"], - width=edit_width, + width=edit_width * size_scale, height=height, **extra_args, ) @@ -305,11 +332,89 @@ class MaybeEditableRow(HorizontalLayout): self.children.append( LabelModel( text=f"Not implemented: editing {value_type} ({key_path})", - width=edit_width, + width=edit_width * size_scale, height=height, ) ) + def include_display(self, key_attr): + return True + + def include_editable(self, key_attr): + return True + + +class DisplayAndMaybeEditableRow(DisplayAndEditableRow): + def include_editable(self, key_attr): + return key_attr["accessMode"] == karabo.native.AccessMode.RECONFIGURABLE.value + + +class MaybeDisplayMaybeEditableRow(DisplayAndEditableRow): + # overriding init to unify display_width and edit_width + def __init__( + self, + 
device_id, + schema_hash, + key_path, + label_width=7, + display_or_edit_width=5, + height=None, + size_scale=BASE_INC, + ): + super().__init__( + device_id, + schema_hash, + key_path, + label_width=label_width, + display_width=display_or_edit_width, + edit_width=display_or_edit_width, + height=height, + size_scale=size_scale, + ) + + def include_display(self, key_attr): + return not self.include_editable(key_attr) + + def include_editable(self, key_attr): + return key_attr["accessMode"] == karabo.native.AccessMode.RECONFIGURABLE.value + + +class EditableRow(DisplayAndEditableRow): + # overriding init to get label_width, edit_width without label_width + def __init__( + self, + device_id, + schema_hash, + key_path, + label_width=7, + edit_width=5, + height=None, + size_scale=BASE_INC, + ): + super().__init__( + device_id, + schema_hash, + key_path, + label_width=label_width, + edit_width=edit_width, + height=height, + size_scale=size_scale, + ) + + def include_display(self, key_attr): + return False + + def include_editable(self, key_attr): + return True + + +class DisplayRow(DisplayAndEditableRow): + def include_display(self, key_attr): + return True + + def include_editable(self, key_attr): + return False + # section: specific handcrafted components for device classes @@ -331,10 +436,10 @@ class FoundConstantsColumn(VerticalLayout): width=6 * NARROW_INC, height=NARROW_INC, ), - ColorBoolModel( + LampModel( width=NARROW_INC, height=NARROW_INC, - keys=[f"{device_id}.{prefix}.{constant_name}.found"], + keys=[f"{device_id}.{prefix}.{constant_name}.state"], ), DisplayLabelModel( keys=[f"{device_id}.{prefix}.{constant_name}.beginValidityAt"], @@ -366,8 +471,8 @@ class ConstantLoadedAmpeln(HorizontalLayout): super().__init__(padding=0) self.children.extend( [ - ColorBoolModel( - keys=[f"{device_id}.{prefix}.{key}.found"], + LampModel( + keys=[f"{device_id}.{prefix}.{key}.state"], width=BASE_INC, height=BASE_INC, ) @@ -402,36 +507,41 @@ class ManagerDeviceStatus(VerticalLayout): width=7 * BASE_INC, height=BASE_INC, ) - apply_button = DisplayCommandModel( - keys=[f"{device_id}.applyManagedValues"], - width=7 * BASE_INC, - height=BASE_INC, - ) status_log = DisplayTextLogModel( keys=[f"{device_id}.status"], width=14 * BASE_INC, height=14 * BASE_INC, ) + managed_properties_link = DeviceSceneLinkModel( + text="All managed properties", + keys=[f"{device_id}.availableScenes"], + target="browse_schema", + target_window=SceneTargetWindow.Dialog, + width=7 * BASE_INC, + height=BASE_INC, + ) + docs_link = WebLinkModel( + text="Documentation", + width=7 * BASE_INC, + height=BASE_INC, + target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#calibration-manager", + ) self.children.extend( [ name, HorizontalLayout( state, - restart_button, + docs_link, padding=0, ), HorizontalLayout( + restart_button, instantiate_button, - apply_button, padding=0, ), - DeviceSceneLinkModel( - text="All managed properties", - keys=[f"{device_id}.availableScenes"], - target="browse_schema", - target_window=SceneTargetWindow.Dialog, - width=7 * BASE_INC, - height=BASE_INC, + HorizontalLayout( + managed_properties_link, + padding=0, ), status_log, ] @@ -441,7 +551,7 @@ class ManagerDeviceStatus(VerticalLayout): @titled("Device status", width=6 * NARROW_INC) @boxed class CorrectionDeviceStatus(VerticalLayout): - def __init__(self, device_id): + def __init__(self, device_id, schema_hash): super().__init__(padding=0) name = DisplayLabelModel( keys=[f"{device_id}.deviceId"], @@ -487,9 +597,31 @@ class 
CorrectionDeviceStatus(VerticalLayout): processing_time, padding=0, ), - status_log, ] ) + self.children.append( + VerticalLayout( + children=[ + HorizontalLayout( + LampModel( + keys=[f"{device_id}.{warning_lamp}"], + width=BASE_INC, + height=BASE_INC, + ), + LabelModel( + text=warning_lamp, + width=8 * BASE_INC, + height=BASE_INC, + ), + ) + for warning_lamp in schema_hash.getAttribute( + "warningLamps", "defaultValue" + ) + ], + padding=0, + ) + ) + self.children.append(status_log) class CompactCorrectionDeviceOverview(HorizontalLayout): @@ -511,6 +643,18 @@ class CompactCorrectionDeviceOverview(HorizontalLayout): width=5 * BASE_INC, height=BASE_INC, ), + ] + ) + for warning_lamp in schema_hash.getAttribute("warningLamps", "defaultValue"): + self.children.append( + LampModel( + keys=[f"{device_id}.{warning_lamp}"], + width=BASE_INC, + height=BASE_INC, + ) + ) + self.children.extend( + [ DisplayRoundedFloat( keys=[f"{device_id}.performance.rate"], width=4 * BASE_INC, @@ -556,7 +700,7 @@ class CompactDeviceLinkList(VerticalLayout): DeviceSceneLinkModel( text=device_id.split("/")[-1], keys=[f"{device_id}.availableScenes"], - target="overview", + target="", target_window=SceneTargetWindow.Dialog, width=7 * BASE_INC, height=BASE_INC, @@ -797,37 +941,13 @@ class GeometryPreview(VerticalLayout): @titled("Geometry from file") @boxed class GeometryFromFileSettings(VerticalLayout): - def __init__(self, device_id): + def __init__(self, device_id, schema_hash): super().__init__(padding=0) - self.children.append( - LabelModel( - text="File path:", - width=4 * BASE_INC, - height=BASE_INC, - ) - ) - self.children.append( - LineEditModel( - keys=[f"{device_id}.geometryFile.filePath"], - klass="EditableLineEdit", - width=8 * BASE_INC, - height=BASE_INC, - ) - ) - self.children.append( - LabelModel( - text="File type:", - width=4 * BASE_INC, - height=BASE_INC, - ) - ) - self.children.append( - ComboBoxModel( - keys=[f"{device_id}.geometryFile.fileType"], - klass="EditableComboBox", - width=8 * BASE_INC, - height=BASE_INC, - ) + self.children.extend( + [ + EditableRow(device_id, schema_hash, "geometryFile.filePath", 4, 8), + EditableRow(device_id, schema_hash, "geometryFile.fileType", 4, 8), + ] ) self.children.append( HorizontalLayout( @@ -845,17 +965,7 @@ class GeometryFromFileSettings(VerticalLayout): ) ) self.children.append( - HorizontalLayout( - LabelModel( - text="Update manual settings", width=6 * BASE_INC, height=BASE_INC - ), - CheckBoxModel( - keys=[f"{device_id}.geometryFile.updateManualOnLoad"], - klass="EditableCheckBox", - width=2 * BASE_INC, - height=BASE_INC, - ), - ) + EditableRow(device_id, schema_hash, "geometryFile.updateManualOnLoad", 6, 2) ) self.children.append( DisplayCommandModel( @@ -866,6 +976,161 @@ class GeometryFromFileSettings(VerticalLayout): ) +@titled("Stats") +@boxed +class StatsBox(HorizontalLayout): + def __init__(self, device_id, schema_hash): + super().__init__() + self.children.extend( + [ + DisplayRow(device_id, schema_hash, "rate", 5, 3), + DisplayRow(device_id, schema_hash, "numPixelsIncluded", 5, 3), + ] + ) + + +@titled("ROI selection") +@boxed +class RoiSelection(VerticalLayout): + def __init__(self, device_id): + super().__init__() + self.children.append( + UnknownWidgetDataModel( + keys=[f"{device_id}.output.schema.roiImage"], + klass="RectRoiGraph", + height=25 * BASE_INC, + width=30 * BASE_INC, + ) + ) + + +@titled("ROI preview") +@boxed +class RoiBox(VerticalLayout): + def __init__(self, device_id): + super().__init__() + self.children.append( + 
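+            # zoomed view of the currently selected ROI; assumes the device
+            # publishes it on output.schema.zoomImage (cf. RoiSelection above)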
DetectorGraphModel( + keys=[f"{device_id}.output.schema.zoomImage"], + height=17 * BASE_INC, + width=20 * BASE_INC, + ), + ) + + +@titled("Preview settings") +@boxed +class PreviewSettings(HorizontalLayout): + def __init__(self, device_id, schema_hash, node_name="preview", extras=None): + super().__init__() + self.children.extend( + [ + VerticalLayout( + EditableRow( + device_id, + schema_hash, + f"{node_name}.replaceNanWith", + 6, + 4, + ), + EditableRow( + device_id, + schema_hash, + f"{node_name}.maxRate", + 6, + 4, + ), + HorizontalLayout( + EditableRow( + device_id, + schema_hash, + f"{node_name}.flipSS", + 3, + 2, + ), + EditableRow( + device_id, + schema_hash, + f"{node_name}.flipFS", + 3, + 2, + ), + padding=0, + ), + ), + Vline(height=3 * BASE_INC), + VerticalLayout( + LabelModel( + text="Image downsampling", + width=8 * BASE_INC, + height=BASE_INC, + ), + EditableRow( + device_id, + schema_hash, + f"{node_name}.downsamplingFactor", + 5, + 5, + ), + EditableRow( + device_id, + schema_hash, + f"{node_name}.downsamplingFunction", + 5, + 5, + ), + ), + ] + ) + if extras is not None: + self.children.append(Vline(height=3 * BASE_INC)) + self.children.extend(extras) + + +@titled("Histogram settings") +@boxed +class HistogramSettings(HorizontalLayout): + def __init__(self, device_id, schema_hash): + super().__init__() + self.children.extend( + [ + VerticalLayout( + LabelModel(text="Bins", width=3 * BASE_INC, height=BASE_INC), + DisplayCommandModel( + keys=[f"{device_id}.histogram.resetBins"], + width=4 * BASE_INC, + height=BASE_INC, + ), + EditableRow( + device_id, schema_hash, "histogram.resetBinsOnRoiChange", 5, 1 + ), + EditableRow(device_id, schema_hash, "histogram.numBins", 3, 3), + padding=0, + ), + VerticalLayout( + LabelModel(text="Range", width=4 * BASE_INC, height=BASE_INC), + EditableRow(device_id, schema_hash, "histogram.rangeMin", 4, 3), + EditableRow(device_id, schema_hash, "histogram.rangeMax", 4, 3), + EditableRow( + device_id, + schema_hash, + "histogram.automaticallyExpandRange", + 4, + 3, + ), + padding=0, + ), + VerticalLayout( + LabelModel(text="Averaging", width=4 * BASE_INC, height=BASE_INC), + EditableRow( + device_id, schema_hash, "histogram.rollingWindowSize", 4, 3 + ), + padding=0, + ), + ] + ) + + # section: generating actual scenes @@ -894,13 +1159,12 @@ def scene_generator(fun): @scene_generator -def correction_device_overview(device_id, schema): +def correction_device_overview(device_id, schema, direct_preview=False): schema_hash = schema_to_hash(schema) - - return HorizontalLayout( - CorrectionDeviceStatus(device_id), + main_overview = HorizontalLayout( + CorrectionDeviceStatus(device_id, schema_hash), VerticalLayout( - recursive_maybe_editable( + recursive_editable( device_id, schema_hash, "constantParameters", @@ -912,7 +1176,7 @@ def correction_device_overview(device_id, schema): ), ), FoundConstantsColumn(device_id, schema_hash), - recursive_maybe_editable( + recursive_editable( device_id, schema_hash, "corrections", @@ -920,6 +1184,67 @@ def correction_device_overview(device_id, schema): ), ) + if direct_preview: + return VerticalLayout( + main_overview, + LabelModel( + text="Preview (corrected):", + width=20 * BASE_INC, + height=BASE_INC, + ), + DetectorGraphModel( + keys=[f"{device_id}.preview.outputCorrected.schema.image.data"], + height=20 * BASE_INC, + width=60 * BASE_INC, + ), + *( + DeviceSceneLinkModel( + text=f"Preview: {channel}", + keys=[f"{device_id}.availableScenes"], + target=f"preview:{channel}", + 
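+                    # assumption: the device's scene provider resolves the
+                    # "preview:<channel>" target via correction_device_preview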
target_window=SceneTargetWindow.Dialog, + width=16 * BASE_INC, + height=BASE_INC, + ) + for channel in schema_hash.get("preview").getKeys() + if schema_hash.hasAttribute(f"preview.{channel}", "classId") + and schema_hash.getAttribute(f"preview.{channel}", "classId") + == "OutputChannel" + ), + ) + else: + return main_overview + + +@scene_generator +def correction_device_preview(device_id, schema, preview_channel): + schema_hash = schema_to_hash(schema) + return VerticalLayout( + LabelModel( + text=f"Preview: {preview_channel}", + width=10 * BASE_INC, + height=BASE_INC, + ), + PreviewSettings( + device_id, + schema_hash, + extras=[ + VerticalLayout( + EditableRow(device_id, schema_hash, "preview.index", 8, 4), + EditableRow(device_id, schema_hash, "preview.selectionMode", 8, 4), + ) + ], + ), + titled("Preview image")(boxed(dummy_wrap(DetectorGraphModel)))( + keys=[f"{device_id}.preview.{preview_channel}.schema.image.data"], + colormap="viridis", + width=60 * BASE_INC, + height=30 * BASE_INC, + x=PADDING, + y=PADDING, + ), + ) + @scene_generator def correction_device_constant_overrides(device_id, schema, prefix="foundConstants"): @@ -958,8 +1283,8 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan VerticalLayout( VerticalLayout( HorizontalLayout( - ColorBoolModel( - keys=[f"{device_id}.{prefix}.{constant}.found"], + LampModel( + keys=[f"{device_id}.{prefix}.{constant}.state"], width=BASE_INC, height=BASE_INC, ), @@ -977,6 +1302,11 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan ), padding=0, ), + DisplayCommandModel( + keys=[f"{device_id}.{prefix}.{constant}.loadMostRecent"], + width=8 * BASE_INC, + height=BASE_INC, + ), VerticalLayout( LineEditModel( klass="EditableLineEdit", @@ -986,7 +1316,7 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan ), DisplayCommandModel( keys=[ - f"{device_id}.{prefix}.{constant}.overrideConstantVersion" + f"{device_id}.{prefix}.{constant}.overrideConstantFromVersion" ], width=8 * BASE_INC, height=BASE_INC, @@ -1021,6 +1351,30 @@ def correction_device_constant_overrides(device_id, schema, prefix="foundConstan ) +@scene_generator +def histogram_overview(device_id, schema): + schema_hash = schema_to_hash(schema) + return VerticalLayout( + HorizontalLayout( + RoiSelection(device_id), + VerticalLayout( + RoiBox(device_id), + StatsBox(device_id, schema_hash), + HistogramSettings(device_id, schema_hash), + ), + ), + VectorXYGraphModel( + keys=[ + f"{device_id}.output.schema.manualHistogram.x", + f"{device_id}.output.schema.manualHistogram.y", + f"{device_id}.output.schema.manualHistogram.yMean", + ], + height=16 * BASE_INC, + width=52 * BASE_INC, + ), + ) + + @scene_generator def manager_device_overview( manager_device_id, @@ -1037,7 +1391,7 @@ def manager_device_overview( HorizontalLayout( ManagerDeviceStatus(manager_device_id), VerticalLayout( - recursive_maybe_editable( + recursive_editable( manager_device_id, mds_hash, "managedKeys.constantParameters", @@ -1047,14 +1401,45 @@ def manager_device_overview( width=10 * BASE_INC, height=BASE_INC, ), - recursive_maybe_editable( + recursive_editable( manager_device_id, mds_hash, "managedKeys.preview", max_depth=2, ), + titled("Data throttling")(boxed(VerticalLayout))( + children=[ + EditableRow( + manager_device_id, + mds_hash, + "managedKeys.daqTrainStride", + 7, + 4, + ), + LabelModel( + text="Frame filter", + width=11 * BASE_INC, + height=BASE_INC, + ), + EditableRow( + manager_device_id, + mds_hash, + 
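+                            # (EditableRow positional args after key_path are
+                            # label_width and edit_width, in BASE_INC units)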
"managedKeys.frameFilter.type", + 7, + 4, + ), + EditableRow( + manager_device_id, + mds_hash, + "managedKeys.frameFilter.spec", + 7, + 4, + ), + ], + padding=0, + ), ), - recursive_maybe_editable( + recursive_editable( manager_device_id, mds_hash, "managedKeys.corrections", @@ -1066,37 +1451,49 @@ def manager_device_overview( children=[Space(width=BASE_INC, height=BASE_INC)] * 2 + [ CompactDaqOverview(device_id) - for device_id in sorted(daq_device_ids) + for device_id in natsort.natsorted(daq_device_ids) ], padding=0, ), titled("Correction devices", width=8 * NARROW_INC)(boxed(VerticalLayout))( children=[ - DeviceSceneLinkModel( - text="Performance dashboard", - keys=[f"{manager_device_id}.availableScenes"], - target="correction_performance_overview", - target_window=SceneTargetWindow.Dialog, - width=10 * BASE_INC, - height=BASE_INC, - ), - DeviceSceneLinkModel( - text="Correction constant overview", - keys=[f"{manager_device_id}.availableScenes"], - target="correction_constant_overview", - target_window=SceneTargetWindow.Dialog, - width=10 * BASE_INC, - height=BASE_INC, - ), + HorizontalLayout( + VerticalLayout( + DeviceSceneLinkModel( + text="Performance dashboard", + keys=[f"{manager_device_id}.availableScenes"], + target="correction_performance_overview", + target_window=SceneTargetWindow.Dialog, + width=10 * BASE_INC, + height=BASE_INC, + ), + DeviceSceneLinkModel( + text="Correction constant overview", + keys=[f"{manager_device_id}.availableScenes"], + target="correction_constant_overview", + target_window=SceneTargetWindow.Dialog, + width=10 * BASE_INC, + height=BASE_INC, + ), + padding=0, + ), + WebLinkModel( + text="Documentation", + width=6 * BASE_INC, + height=BASE_INC, + target="https://rtd.xfel.eu/docs/calng/en/latest/devices/#correction-devices", + ), + padding=0, + ) ] + [ CompactCorrectionDeviceOverview(device_id, cds_hash) - for device_id in sorted(correction_device_ids) + for device_id in natsort.natsorted(correction_device_ids) ], padding=0, ), titled("Other devices managed")(CompactDeviceLinkList)( - sorted( + natsort.natsorted( set(domain_device_ids) - set(correction_device_ids) - {manager_device_id} @@ -1108,7 +1505,7 @@ def manager_device_overview( @scene_generator def correction_device_performance_dashboard(correction_device_ids): - correction_device_ids = sorted(correction_device_ids) + correction_device_ids = natsort.natsorted(correction_device_ids) return HorizontalLayout( titled("Correction device links")(CompactDeviceLinkList)(correction_device_ids), VerticalLayout( @@ -1155,7 +1552,7 @@ def correction_constant_dashboard( ): correction_device_schema = schema_to_hash(correction_device_schema) constant_names = list(correction_device_schema.get(prefix).getKeys()) - correction_device_ids = sorted(correction_device_ids) + correction_device_ids = natsort.natsorted(correction_device_ids) return VerticalLayout( HorizontalLayout( Space(width=6 * BASE_INC, height=BASE_INC), @@ -1190,72 +1587,12 @@ def correction_constant_dashboard( @scene_generator -def detector_assembler_overview(device_id): +def detector_assembler_overview(device_id, schema): + schema_hash = schema_to_hash(schema) return VerticalLayout( HorizontalLayout( AssemblerDeviceStatus(device_id), - titled("Preview settings")(boxed(VerticalLayout))( - HorizontalLayout( - LabelModel( - text="Display NaN values as:", - width=7 * BASE_INC, - height=BASE_INC, - ), - DoubleLineEditModel( - keys=[f"{device_id}.preview.replaceNanWith"], - width=7 * BASE_INC, - height=BASE_INC, - ), - padding=0, - ), - 
HorizontalLayout( - LabelModel( - text="Max preview rate", - width=7 * BASE_INC, - height=BASE_INC, - ), - DoubleLineEditModel( - keys=[f"{device_id}.preview.maxRate"], - width=7 * BASE_INC, - height=BASE_INC, - ), - padding=0, - ), - Hline(width=14 * BASE_INC), - LabelModel( - text="Image downsampling", - width=14 * BASE_INC, - height=BASE_INC, - ), - HorizontalLayout( - LabelModel( - text="Factor", - width=7 * BASE_INC, - height=BASE_INC, - ), - ComboBoxModel( - keys=[f"{device_id}.preview.downsamplingFactor"], - width=7 * BASE_INC, - height=BASE_INC, - klass="EditableComboBox", - ), - padding=0, - ), - HorizontalLayout( - LabelModel( - text="Function", - width=7 * BASE_INC, - height=BASE_INC, - ), - ComboBoxModel( - keys=[f"{device_id}.preview.downsamplingFunction"], - width=7 * BASE_INC, - height=BASE_INC, - klass="EditableComboBox", - ), - padding=0, - ), - ), + PreviewSettings(device_id, schema_hash), ), titled("Preview image")(boxed(dummy_wrap(DetectorGraphModel)))( keys=[f"{device_id}.preview.output.schema.image.data"], @@ -1269,11 +1606,12 @@ def detector_assembler_overview(device_id): @scene_generator -def quadrant_geometry_overview(device_id): +def quadrant_geometry_overview(device_id, schema): + schema_hash = schema_to_hash(schema) return VerticalLayout( HorizontalLayout( ManualQuadrantGeometrySettings(device_id), - GeometryFromFileSettings(device_id), + GeometryFromFileSettings(device_id, schema_hash), TweakCurrentGeometry(device_id), ), GeometryPreview(device_id), @@ -1281,22 +1619,72 @@ def quadrant_geometry_overview(device_id): @scene_generator -def modules_geometry_overview(device_id): +def modules_geometry_overview(device_id, schema): + schema_hash = schema_to_hash(schema) return VerticalLayout( HorizontalLayout( ManualModulesGeometrySettings(device_id), - GeometryFromFileSettings(device_id), + GeometryFromFileSettings(device_id, schema_hash), TweakCurrentGeometry(device_id), ), GeometryPreview(device_id), ) +@scene_generator +def condition_checker_overview(device_id, schema): + schema_hash = schema_to_hash(schema) + return VerticalLayout( + HorizontalLayout( + DisplayCommandModel( + keys=[f"{device_id}.checkConditions"], + width=7 * BASE_INC, + height=BASE_INC, + ), + DisplayCommandModel( + keys=[f"{device_id}.updateConditions"], + width=7 * BASE_INC, + height=BASE_INC, + ), + DisplayStateColorModel( + keys=[f"{device_id}.conditionsMatch"], + width=7 * BASE_INC, + height=BASE_INC, + ), + ), + HorizontalLayout( + DisplayCommandModel( + keys=[f"{device_id}.startMonitoring"], + width=7 * BASE_INC, + height=BASE_INC, + ), + DisplayCommandModel( + keys=[f"{device_id}.stopMonitoring"], + width=7 * BASE_INC, + height=BASE_INC, + ), + ), + EditableRow(device_id, schema_hash, "updateManagerOnMonitor", 13, 1), + EditableRow(device_id, schema_hash, "loadConstantsOnMonitor", 13, 1), + TableElementModel( + keys=[f"{device_id}.keyMapping"], + width=30 * BASE_INC, + height=20 * BASE_INC, + ), + ) + + # section: here be monsters -def recursive_maybe_editable( - device_id, schema_hash, prefix, depth=1, max_depth=3, title=None +def recursive_editable( + device_id, + schema_hash, + prefix, + depth=1, + max_depth=3, + title=None, + row_class=MaybeDisplayMaybeEditableRow, ): schema_hash = schema_to_hash(schema_hash) # note: not just using sets because that loses ordering @@ -1313,7 +1701,7 @@ def recursive_maybe_editable( for key in schema_hash.get(prefix).getKeys(): attrs = schema_hash.getAttributes(f"{prefix}.{key}") - if attrs.get("nodeType") == NODE_TYPE_NODE: + if attrs.get("nodeType") 
== karabo.native.NodeType.Node.value: if "classId" in attrs and attrs.get("classId") == "Slot": slot_keys.append(key) else: @@ -1322,8 +1710,7 @@ def recursive_maybe_editable( value_keys.append(key) res = titled(title)(boxed(VerticalLayout))( children=[ - MaybeEditableRow(device_id, schema_hash, f"{prefix}.{key}") - for key in value_keys + row_class(device_id, schema_hash, f"{prefix}.{key}") for key in value_keys ] + [ DisplayCommandModel( @@ -1339,12 +1726,13 @@ def recursive_maybe_editable( res.children.append( VerticalLayout( children=[ - recursive_maybe_editable( + recursive_editable( device_id, schema_hash, f"{prefix}.{key}", depth=depth + 1, max_depth=max_depth, + row_class=row_class, ) for key in node_keys ] @@ -1376,4 +1764,4 @@ def recursive_subschema_scene( prefix="managedKeys", ): mds_hash = schema_to_hash(device_schema) - return recursive_maybe_editable(device_id, mds_hash, prefix) + return recursive_editable(device_id, mds_hash, prefix) diff --git a/src/calng/schemas.py b/src/calng/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..340f16048fc3b4a17427483d437f52153a0251e6 --- /dev/null +++ b/src/calng/schemas.py @@ -0,0 +1,248 @@ +from karabo.bound import ( + IMAGEDATA_ELEMENT, + INT32_ELEMENT, + INT64_ELEMENT, + NDARRAY_ELEMENT, + NODE_ELEMENT, + STRING_ELEMENT, + VECTOR_STRING_ELEMENT, + Schema, +) + + +def preview_schema(wrap_image_in_imagedata=False): + res = Schema() + ( + NODE_ELEMENT(res) + .key("image") + .commit(), + ) + if wrap_image_in_imagedata: + ( + IMAGEDATA_ELEMENT(res) + .key("image.data") + .commit(), + ) + else: + ( + NDARRAY_ELEMENT(res) + .key("image.data") + .dtype("FLOAT") + .commit(), + ) + return res + + +def xtdf_output_schema(use_shmem_handle=True): + # TODO: trim output schema / adapt to specific detectors + # currently: based on snapshot of actual output reusing AGIPD hash + res = Schema() + ( + NODE_ELEMENT(res) + .key("image") + .commit(), + + NDARRAY_ELEMENT(res) + .key("image.length") + .dtype("UINT32") + .commit(), + + NDARRAY_ELEMENT(res) + .key("image.cellId") + .dtype("UINT16") + .commit(), + + NDARRAY_ELEMENT(res) + .key("image.pulseId") + .dtype("UINT64") + .commit(), + + NDARRAY_ELEMENT(res) + .key("image.status") + .commit(), + + NDARRAY_ELEMENT(res) + .key("image.trainId") + .dtype("UINT64") + .commit(), + + VECTOR_STRING_ELEMENT(res) + .key("calngShmemPaths") + .assignmentOptional() + .defaultValue([]) + .commit(), + + NODE_ELEMENT(res) + .key("metadata") + .commit(), + + STRING_ELEMENT(res) + .key("metadata.source") + .assignmentOptional() + .defaultValue("") + .commit(), + + NODE_ELEMENT(res) + .key("metadata.timestamp") + .commit(), + + INT32_ELEMENT(res) + .key("metadata.timestamp.tid") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NODE_ELEMENT(res) + .key("header") + .commit(), + + INT32_ELEMENT(res) + .key("header.minorTrainFormatVersion") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(res) + .key("header.majorTrainFormatVersion") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(res) + .key("header.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT64_ELEMENT(res) + .key("header.linkId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT64_ELEMENT(res) + .key("header.dataId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT64_ELEMENT(res) + .key("header.pulseCount") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NDARRAY_ELEMENT(res) + .key("header.reserved") + .commit(), + 
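+        # the remaining XTDF header/detector/trailer fields below are passed
+        # through untouched (cf. the trimming TODO at the top of this function)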
+ NDARRAY_ELEMENT(res) + .key("header.magicNumberBegin") + .commit(), + + NODE_ELEMENT(res) + .key("detector") + .commit(), + + INT32_ELEMENT(res) + .key("detector.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NDARRAY_ELEMENT(res) + .key("detector.data") + .commit(), + + NODE_ELEMENT(res) + .key("trailer") + .commit(), + + NDARRAY_ELEMENT(res) + .key("trailer.checksum") + .commit(), + + NDARRAY_ELEMENT(res) + .key("trailer.magicNumberEnd") + .commit(), + + INT32_ELEMENT(res) + .key("trailer.status") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(res) + .key("trailer.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + ) + + if use_shmem_handle: + ( + STRING_ELEMENT(res) + .key("image.data") + .assignmentOptional() + .noDefaultValue() + .commit(), + ) + else: + ( + NDARRAY_ELEMENT(res) + .key("image.data") + .commit(), + ) + return res + + +def jf_output_schema(use_shmem_handle=True): + res = Schema() + ( + NODE_ELEMENT(res) + .key("data") + .commit(), + + INT32_ELEMENT(res) + .key("data.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NDARRAY_ELEMENT(res) + .key("data.frameNumber") + .commit(), + + NDARRAY_ELEMENT(res) + .key("data.gain") + .commit(), + + NDARRAY_ELEMENT(res) + .key("data.memoryCell") + .commit(), + + NDARRAY_ELEMENT(res) + .key("data.timestamp") + .commit(), + + VECTOR_STRING_ELEMENT(res) + .key("calngShmemPaths") + .assignmentOptional() + .defaultValue([]) + .commit(), + ) + if use_shmem_handle: + ( + STRING_ELEMENT(res) + .key("data.adc") + .assignmentOptional() + .noDefaultValue() + .commit(), + ) + else: + ( + NDARRAY_ELEMENT(res) + .key("data.adc") + .commit(), + ) + return res diff --git a/src/calng/utils.py b/src/calng/utils.py index 344752374491c1a4cbaaed6f1f0f03c492f4069f..e3e747af117fa9e32661db77120250e8a1a4ebcb 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -9,62 +9,123 @@ from timeit import default_timer import numpy as np -def pick_frame_index(selection_mode, index, cell_table, pulse_table, warn_func=None): +class ContextWarningLamp: + """Warning model: all warnings are generated within contexts. Each context handles + one type of warning (one lamp aggregates multiple types). If a warning is issued + within a context, the corresponding warning type is set. If not, then the + corresponding warning type is unset.""" + + def __init__(self, device, schema_key): + self._device = device + self._schema_key = schema_key + self._active_warnings = {} + # note: the following two attributes are only sets for future generalization + # warnings issued during "current" context + self._new_warnings = set() + # warning types used in current context (to clear if no warnings issued) + self._new_tested = set() + + def new_context(self, warn_type=None, only_print_once=False): + """Use with "with", will give warning function with appropriate parameters + + warn_type: The warning type which to set or unset based on outcome within + context (must be hashable, should probably be some enum member) + only_print_once: By default, the exact same string will not be printed twice for + any given warn_type. Some errors may, however, generate slightly differing + strings each time. With only_print_once, only the first warning for a given + warn_type is printed as long as the warning remains active. 
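+
+        Hypothetical usage sketch (names here are illustrative, not from this
+        module):
+
+            warn = lamp.new_context(WarnType.TRAIN_ID)
+            if something_looks_wrong:
+                warn("describe the problem")
+            lamp.update_state()  # unsets the type again if no warning was issued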
+ """ + # discard instead of clear in case of nesting (not yet fully supported though) + self._new_warnings.discard(warn_type) + self._new_tested.add(warn_type) + return functools.partial( + self.warn, warn_type=warn_type, only_print_once=only_print_once + ) + + def warn(self, message, warn_type, only_print_once): + # avoid duplicating current warning message for this type + if (warn_type not in self._active_warnings) or ( + not only_print_once and message != self._active_warnings[warn_type] + ): + self._device.log_status_warn(message) + self._new_warnings.add(warn_type) + self._active_warnings[warn_type] = message + + def update_state(self, on_success="NORMAL", on_error="ERROR"): + for now_okay in self._new_tested - self._new_warnings: + self._active_warnings.pop(now_okay, None) + current_state = self._device.unsafe_get(self._schema_key) + if self._active_warnings and current_state != on_error: + self._device.set(self._schema_key, on_error) + elif not self._active_warnings and current_state != on_success: + self._device.set(self._schema_key, on_success) + # TODO: maybe handle nesting / multi-type alarm context + self._new_warnings.clear() + self._new_tested.clear() + + +class PreviewIndexSelectionMode(enum.Enum): + FRAME = "frame" + CELL = "cell" + PULSE = "pulse" + + +def rec_getattr(obj, path): + res = obj + for part in path.split("."): + res = getattr(res, part) + return res + + +def pick_frame_index(selection_mode, index, cell_table, pulse_table): """When selecting a single frame to preview, an obvious question is whether the number the operator provides is a frame index, a cell ID, or a pulse ID. This function allows any of the three, translating into frame index. - As this will be used by correction devices, the warn_func parameter allows the - function to issue warnings via this instead of raising exceptions. - Indices below zero are special values and thus returned directly. - Returns: (found frame index, corresponding cell ID, corresponding pulse ID)""" + Returns: (frame index, cell ID, pulse ID), optional warning""" if index < 0: - return index, index, index - - # TODO: enum - if selection_mode == "frame": - if index >= cell_table.size: - if warn_func is not None: - warn_func( - f"Index {index} out of range for cell table of length " - f"{len(cell_table)}, returning index 0 instead" - ) - frame_index = 0 - else: - frame_index = index + return (index, index, index), None - return frame_index, cell_table[frame_index], pulse_table[frame_index] - elif selection_mode == "cell": - found = np.where(cell_table == index)[0] - if len(found) > 0: - cell = index - frame_index = found[0] + warning = None + selection_mode = PreviewIndexSelectionMode(selection_mode) + + if selection_mode is PreviewIndexSelectionMode.FRAME: + if index < cell_table.size: + frame_index = index else: - cell = cell_table[0] - if warn_func is not None: - warn_func( - f"Cell {index} not found, arbitrary cell {cell} returned instead" - ) + warning = ( + f"Frame index {index} out of range for cell table of length " + f"{len(cell_table)}. Will use frame index 0 instead." 
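+                # note: invalid preview indices degrade gracefully (warning +
+                # frame 0) rather than raising, replacing the old warn_func hook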
+            )
             frame_index = 0
 
-        return frame_index, cell, pulse_table[frame_index]
-    elif selection_mode == "pulse":
-        found = np.where(pulse_table == index)[0]
-        if len(found) > 0:
-            pulse = index
-            frame_index = found[0]
+    else:
+        if selection_mode is PreviewIndexSelectionMode.CELL:
+            index_table = cell_table
         else:
-            pulse = pulse_table[0]
-            if warn_func is not None:
-                warn_func(
-                    f"Pulse {index} not found, arbitrary pulse {pulse} returned instead"
-                )
+            index_table = pulse_table
+        found = np.where(index_table == index)[0]
+        if len(found) == 1:
+            frame_index = found[0]
+        elif len(found) == 0:
             frame_index = 0
-        return frame_index, cell_table[frame_index], pulse
-    else:
-        raise ValueError(f"Invalid selection mode '{selection_mode}'")
+            warning = (
+                f"{selection_mode.value.capitalize()} ID {index} not found in "
+                f"{selection_mode.name} table. Will use frame {frame_index}, which "
+                f"corresponds to {selection_mode.name} {index_table[0]} instead."
+            )
+        elif len(found) > 1:
+            frame_index = found[0]
+            warning = (
+                f"{selection_mode.value.capitalize()} ID {index} was not unique in "
+                f"{selection_mode.name} table. Will use first occurrence out of "
+                f"{len(found)} occurrences (frame {frame_index})."
+            )
+
+    cell_id = cell_table[frame_index]
+    pulse_id = pulse_table[frame_index]
+    return (frame_index, cell_id, pulse_id), warning
 
 
 def threadsafe_cache(fun):
@@ -324,6 +385,10 @@ class StateContext:
         self.device.updateState(self.revert_to)
 
 
+class NonMonotonicTrainIdWarning(Warning):
+    pass
+
+
 class TrainRatioTracker:
     """Measure how many percent of recent train IDs (from contiguous set) were seen
 
@@ -331,15 +396,15 @@ class TrainRatioTracker:
     buffer_size from latest train ID (depending on calls to get). Call update(train_id)
     when you see a new train and call get to get() the ratio of recent trains seen.
 
-    In case warn_callback is given, update can issue a warning in case invalid train
-    IDs are received. The tracker assumes trains are strictly increasing and that they
-    are supposed to be contiguous - hence the ability to infer when some are missing.
+    Updating will raise NonMonotonicTrainIdWarning or LargeTrainIdGapWarning in case
+    the train ID looks like it's from far in the future or from some time in the past.
+    The device using this tracker should decide what to do; maybe call reset.
     """
 
-    def __init__(self, buffer_size=50, warn_callback=None):
+    def __init__(self, buffer_size=50):
         self._train_id_queue = collections.deque(maxlen=buffer_size)
-        self._train_id_range = buffer_size
-        self._warn_callback = warn_callback
+        # should train ID range be explicitly configurable?
+        self._train_id_range = buffer_size * 10
 
     def get(self, current_train=None):
         """Get the ratio of recent trains until current_train or latest updated.
@@ -369,15 +434,14 @@
         # TODO: avoid estimator ramp-up (don't initially divide by full range)
         return len(self._train_id_queue) * 100 / self._train_id_range
 
+    def reset(self):
+        self._train_id_queue.clear()
+
     def update(self, train_id):
-        if (
-            len(self._train_id_queue) > 0
-            and self._train_id_queue[-1] >= train_id
-            and self._warn_callback is not None
-        ):
-            self._warn_callback(
-                f"New train ID {train_id} not greater than last thing in queue, "
-                f"{self._train_id_queue[-1]}, just thought you should know..."
+        if self._train_id_queue and (last_seen := self._train_id_queue[-1]) >= train_id:
+            raise NonMonotonicTrainIdWarning(
+                "New train ID not greater than last train ID seen! "
+                f"New: {train_id}, previous: {last_seen}"
             )
         self._train_id_queue.append(train_id)
 
@@ -437,3 +501,30 @@ class BadPixelValues(enum.IntFlag):
     OVERSCAN = 2 ** 18
     NON_SENSITIVE = 2 ** 19
     NON_LIN_RESPONSE_REGION = 2 ** 20
+    WRONG_GAIN_VALUE = 2 ** 21
+    NON_STANDARD_SIZE = 2 ** 22
+
+
+def downsample_2d(arr, factor, reduction_fun=np.nanmax):
+    """Generalization of downsampling from FemDataAssembler
+
+    Expects first two dimensions of arr to be multiples of factor (a power of two).
+    Useful if you're sitting at home and ssh connection is slow to get full-resolution
+    previews."""
+
+    # halve along both axes once per power of two, so the result is smaller by factor
+    for _ in range(factor.bit_length() - 1):
+        arr = reduction_fun(
+            (
+                arr[:-1:2],
+                arr[1::2],
+            ),
+            axis=0,
+        )
+        arr = reduction_fun(
+            (
+                arr[:, :-1:2],
+                arr[:, 1::2],
+            ),
+            axis=0,
+        )
+    return arr
diff --git a/src/tests/test_agipd_kernels.py b/src/tests/test_agipd_kernels.py
index 8f0f9aca56b29ccfd914b3eda510ca32d5bfca56..125a2a9f531c5c91b788162708f26df12a4c368e 100644
--- a/src/tests/test_agipd_kernels.py
+++ b/src/tests/test_agipd_kernels.py
@@ -2,7 +2,7 @@ import h5py
 import numpy as np
 import pathlib
 
-from calng import AgipdCorrection
+from calng.corrections import AgipdCorrection
 
 input_dtype = np.uint16
 output_dtype = np.float16
diff --git a/src/tests/test_dssc_kernels.py b/src/tests/test_dssc_kernels.py
index 8277ae25e3a62eefe889f365aadb83ed4303a256..6be227775b77d71f78934d1d5a19a98c7934057c 100644
--- a/src/tests/test_dssc_kernels.py
+++ b/src/tests/test_dssc_kernels.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 
-from calng import DsscCorrection
+from calng.corrections import DsscCorrection
 
 input_dtype = np.uint16
 output_dtype = np.float16
@@ -76,7 +76,7 @@ def test_correct_oob_cells():
 def test_reshape():
     kernel_runner.processed_data_gpu.set(corrected_data)
     assert np.allclose(
-        kernel_runner.reshape(output_order="xyc"), corrected_data.transpose()
+        kernel_runner.reshape(output_order="xyf"), corrected_data.transpose()
     )