diff --git a/DEPENDS b/DEPENDS new file mode 100644 index 0000000000000000000000000000000000000000..adead8b96d81197f033bdccf55a7921887c14b23 --- /dev/null +++ b/DEPENDS @@ -0,0 +1,4 @@ +TrainMatcher, 1.2.0-2.10.2 +PipeToZeroMQ, 3.2.6-2.11.0 +calngDeps, 0.0.3-2.10.0 +calibrationClient, 9.0.6 diff --git a/README.md b/README.md index 4a09a144da1115a598849e46e4c73125e8112c2d..5933bd80822722eb9565d323d09ca6bda7bf3ab0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,20 @@ # calng calng is a collection of Karabo devices to perform online processing of 2D X-ray detector data at runtime. It is the successor of the calPy package. + +# CalCat secrets and deployment +Correction devices each run their own `calibration_client.CalibrationClient`, so they need to have credentials for CalCat. +They expect to be able to load these from a JSON file; by default, this will be in `$KARABO/var/data/calibration-client-secrets.json` (`var/data` is CWD of Karabo devices). +The file should look something like: + +```json +{ + "base_url": "https://in.xfel.eu/test_calibration", + "client_id": "[sort of secret]", + "client_secret": "[actual secret]", + "user_email": "[eh, not that secret]", + "caldb_store_path": "/gpfs/exfel/d/cal/caldb_store" +} +``` + +For deployment, you'll want `/calibration` instead of `/test_calibration` and the caldb store as seen from ONC will be `/common/cal/caldb_store`. diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff0f3c85f9d5fdefa7af2d7435e8d0c5e8f8f94 --- /dev/null +++ b/setup.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +from os.path import dirname, join, realpath +from setuptools import setup, find_packages + +from karabo.packaging.versioning import device_scm_version + + +ROOT_FOLDER = dirname(realpath(__file__)) +scm_version = device_scm_version( + ROOT_FOLDER, + join(ROOT_FOLDER, 'src', 'calng', '_version.py') +) + + +setup(name='calng', + use_scm_version=scm_version, + author='CAL team', + author_email='da-support@xfel.eu', + description='', + long_description='', + url='', + package_dir={'': 'src'}, + packages=find_packages('src'), + entry_points={ + 'karabo.bound_device': [ + 'AgipdCorrection = calng.AgipdCorrection:AgipdCorrection', + 'DsscCorrection = calng.DsscCorrection:DsscCorrection', + 'ModuleStacker = calng.ModuleStacker:ModuleStacker', + 'ShmemToZMQ = calng.ShmemToZMQ:ShmemToZMQ', + ], + + 'karabo.middlelayer_device': [ + 'CalibrationManager = calng.CalibrationManager:CalibrationManager' + ], + }, + package_data={'': ['kernels/*']}, + requires=[], +) diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py new file mode 100644 index 0000000000000000000000000000000000000000..36edc80d124a71e09aeaab51454b656bc42888bc --- /dev/null +++ b/src/calng/AgipdCorrection.py @@ -0,0 +1,831 @@ +import enum + +import cupy +import numpy as np +from karabo.bound import ( + BOOL_ELEMENT, + DOUBLE_ELEMENT, + FLOAT_ELEMENT, + KARABO_CLASSINFO, + NODE_ELEMENT, + OUTPUT_CHANNEL, + OVERWRITE_ELEMENT, + STRING_ELEMENT, + VECTOR_STRING_ELEMENT, + State, +) + +from . 
import base_gpu, calcat_utils, utils +from ._version import version as deviceVersion +from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema + + +class AgipdConstants(enum.Enum): + SlopesFF = enum.auto() + ThresholdsDark = enum.auto() + Offset = enum.auto() + SlopesPC = enum.auto() + BadPixelsDark = enum.auto() + BadPixelsPC = enum.auto() + BadPixelsFF = enum.auto() + + +# from pycalibration's enum.py +class AgipdGainMode(enum.IntEnum): + ADAPTIVE_GAIN = 0 + FIXED_HIGH_GAIN = 1 + FIXED_MEDIUM_GAIN = 2 + FIXED_LOW_GAIN = 3 + + +class BadPixelValues(enum.IntFlag): + """The European XFEL Bad Pixel Encoding + + Straight from pycalibration's enum.py""" + + OFFSET_OUT_OF_THRESHOLD = 2 ** 0 + NOISE_OUT_OF_THRESHOLD = 2 ** 1 + OFFSET_NOISE_EVAL_ERROR = 2 ** 2 + NO_DARK_DATA = 2 ** 3 + CI_GAIN_OUT_OF_THRESHOLD = 2 ** 4 + CI_LINEAR_DEVIATION = 2 ** 5 + CI_EVAL_ERROR = 2 ** 6 + FF_GAIN_EVAL_ERROR = 2 ** 7 + FF_GAIN_DEVIATION = 2 ** 8 + FF_NO_ENTRIES = 2 ** 9 + CI2_EVAL_ERROR = 2 ** 10 + VALUE_IS_NAN = 2 ** 11 + VALUE_OUT_OF_RANGE = 2 ** 12 + GAIN_THRESHOLDING_ERROR = 2 ** 13 + DATA_STD_IS_ZERO = 2 ** 14 + ASIC_STD_BELOW_NOISE = 2 ** 15 + INTERPOLATED = 2 ** 16 + NOISY_ADC = 2 ** 17 + OVERSCAN = 2 ** 18 + NON_SENSITIVE = 2 ** 19 + NON_LIN_RESPONSE_REGION = 2 ** 20 + + +class CorrectionFlags(enum.IntFlag): + NONE = 0 + THRESHOLD = 1 + OFFSET = 2 + BLSHIFT = 4 + REL_GAIN_PC = 8 + GAIN_XRAY = 16 + BPMASK = 32 + + +class AgipdGpuRunner(base_gpu.BaseGpuRunner): + _kernel_source_filename = "agipd_gpu.cu" + _corrected_axis_order = "cxy" + + def __init__( + self, + pixels_x, + pixels_y, + memory_cells, + constant_memory_cells, + input_data_dtype=cupy.uint16, + output_data_dtype=cupy.float32, + bad_pixel_mask_value=cupy.nan, + gain_mode=AgipdGainMode.ADAPTIVE_GAIN, + g_gain_value=1, + ): + self.gain_mode = gain_mode + # default gain only matters when not thresholding (missing constant or fixed) + # note: gain stage (result of thresholding) is 0, 1, or 2 + if self.gain_mode is AgipdGainMode.ADAPTIVE_GAIN: + self.default_gain = cupy.uint8(gain_mode) + else: + self.default_gain = cupy.uint8(gain_mode - 1) + self.input_shape = (memory_cells, 2, pixels_x, pixels_y) + self.processed_shape = (memory_cells, pixels_x, pixels_y) + super().__init__( + pixels_x, + pixels_y, + memory_cells, + constant_memory_cells, + input_data_dtype, + output_data_dtype, + ) + self.gain_map_gpu = cupy.empty(self.processed_shape, dtype=cupy.float32) + self.preview_buffer_getters.extend( + [self._get_raw_gain_for_preview, self._get_gain_map_for_preview] + ) + + self.map_shape = (self.constant_memory_cells, self.pixels_x, self.pixels_y) + self.gm_map_shape = self.map_shape + (3,) # for gain-mapped constants + self.threshold_map_shape = self.map_shape + (2,) + # constants + self.gain_thresholds_gpu = cupy.empty( + self.threshold_map_shape, dtype=cupy.float32 + ) + self.offset_map_gpu = cupy.zeros(self.gm_map_shape, dtype=cupy.float32) + self.rel_gain_pc_map_gpu = cupy.ones(self.gm_map_shape, dtype=cupy.float32) + # not gm_map_shape because it only applies to medium gain pixels + self.md_additional_offset_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32) + self.rel_gain_xray_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) + self.bad_pixel_map_gpu = cupy.zeros(self.gm_map_shape, dtype=cupy.uint32) + self.set_bad_pixel_mask_value(bad_pixel_mask_value) + self.set_g_gain_value(g_gain_value) + + self.update_block_size((1, 1, 64)) + + def _get_raw_for_preview(self): + return self.input_data_gpu[:, 0] + + 
def _get_corrected_for_preview(self): + return self.processed_data_gpu + + # special to AGIPD + def _get_raw_gain_for_preview(self): + return self.input_data_gpu[:, 1] + + def _get_gain_map_for_preview(self): + return self.gain_map_gpu + + def load_thresholds(self, threshold_map): + # shape: y, x, memory cell, thresholds and gain values + # note: the gain values are something like means used to derive thresholds + self.gain_thresholds_gpu.set( + np.transpose(threshold_map[..., :2], (2, 1, 0, 3)).astype(np.float32) + ) + + def load_offset_map(self, offset_map): + # shape: y, x, memory cell, gain stage + self.offset_map_gpu.set( + np.transpose(offset_map, (2, 1, 0, 3)).astype(np.float32) + ) + + def load_rel_gain_pc_map(self, slopes_pc_map, override_md_additional_offset=None): + # pc has funny shape (11, 352, 128, 512) from file + # this is (fi, memory cell, y, x) + slopes_pc_map = slopes_pc_map.astype(np.float32) + # the following may contain NaNs, though... + hg_slope = slopes_pc_map[0] + hg_intercept = slopes_pc_map[1] + mg_slope = slopes_pc_map[3] + mg_intercept = slopes_pc_map[4] + # TODO: remove sanitization (should happen in constant preparation notebook) + # from agipdlib.py: replace NaN with median (per memory cell) + # note: suffixes in agipdlib are "_m" and "_l", should probably be "_I" + for naughty_array in (hg_slope, hg_intercept, mg_slope, mg_intercept): + medians = np.nanmedian(naughty_array, axis=(1, 2)) + nan_bool = np.isnan(naughty_array) + nan_cell, _, _ = np.where(nan_bool) + naughty_array[nan_bool] = medians[nan_cell] + + too_low_bool = naughty_array < 0.8 * medians[:, np.newaxis, np.newaxis] + too_low_cell, _, _ = np.where(too_low_bool) + naughty_array[too_low_bool] = medians[too_low_cell] + + too_high_bool = naughty_array > 1.2 * medians[:, np.newaxis, np.newaxis] + too_high_cell, _, _ = np.where(too_high_bool) + naughty_array[too_high_bool] = medians[too_high_cell] + + frac_hg_mg = hg_slope / mg_slope + rel_gain_map = np.ones( + (3, self.constant_memory_cells, self.pixels_y, self.pixels_x), + dtype=np.float32, + ) + rel_gain_map[1] = rel_gain_map[0] * frac_hg_mg + rel_gain_map[2] = rel_gain_map[1] * 4.48 + self.rel_gain_pc_map_gpu.set(np.transpose(rel_gain_map, (1, 3, 2, 0))) + if override_md_additional_offset is None: + md_additional_offset = (hg_intercept - mg_intercept * frac_hg_mg).astype( + np.float32 + ) + self.md_additional_offset_gpu.set( + np.transpose(md_additional_offset, (0, 2, 1)) + ) + else: + self.override_md_additional_offset(override_md_additional_offset) + + def override_md_additional_offset(self, override_value): + self.md_additional_offset_gpu.fill(override_value) + + def load_rel_gain_ff_map(self, slopes_ff_map): + # constant shape: y, x, memory cell + if slopes_ff_map.shape[2] == 2: + # TODO: remove support for old format + # old format, is per pixel only (shape is y, x, 2) + # note: we should not support this in online + slopes_ff_map = np.broadcast_to( + slopes_ff_map[..., 0][..., np.newaxis], + (self.pixels_y, self.pixels_x, self.constant_memory_cells), + ) + self.rel_gain_xray_map_gpu.set(np.transpose(slopes_ff_map).astype(np.float32)) + + def set_g_gain_value(self, override_value): + self.g_gain_value = cupy.float32(override_value) + + def load_bad_pixels_map(self, bad_pixels_map, override_flags_to_use=None): + # will simply OR with already loaded, does not take into account which ones + # TODO: inquire what "mask for double size pixels" means + if len(bad_pixels_map.shape) == 3: + if bad_pixels_map.shape == ( + self.pixels_y, + 
self.pixels_x, + self.constant_memory_cells, + ): + # BadPixelsFF is not per gain stage - broadcasting along gain dimension + self.bad_pixel_map_gpu |= cupy.asarray( + np.broadcast_to( + np.transpose(bad_pixels_map)[..., np.newaxis], + self.gm_map_shape, + ), + dtype=np.uint32, + ) + elif bad_pixels_map.shape == ( + self.constant_memory_cells, + self.pixels_y, + self.pixels_x, + ): + # oh, can also be old bad pixels pc? + self.bad_pixel_map_gpu |= cupy.asarray( + np.broadcast_to( + np.transpose(bad_pixels_map, (0, 2, 1))[..., np.newaxis], + self.gm_map_shape, + ), + dtype=np.uint32, + ) + else: + raise ValueError( + f"Unsupported bad pixel map shape: {bad_pixels_map.shape}" + ) + else: + self.bad_pixel_map_gpu |= cupy.asarray( + np.transpose(bad_pixels_map, (2, 1, 0, 3)), dtype=np.uint32 + ) + + if override_flags_to_use is not None: + self.override_bad_pixel_flags_to_use(override_flags_to_use) + + def override_bad_pixel_flags_to_use(self, override_value): + self.bad_pixel_map_gpu &= cupy.uint32(override_value) + + def set_bad_pixel_mask_value(self, mask_value): + self.bad_pixel_mask_value = cupy.float32(mask_value) + + def flush_buffers(self): + self.offset_map_gpu.fill(0) + self.rel_gain_pc_map_gpu.fill(1) + self.md_additional_offset_gpu.fill(0) + self.rel_gain_xray_map_gpu.fill(1) + self.bad_pixel_map_gpu.fill(0) + + # TODO: baseline shift + + def correct(self, flags): + if flags & CorrectionFlags.BLSHIFT: + raise NotImplementedError("Baseline shift not implemented yet") + + self.correction_kernel( + self.full_grid, + self.full_block, + ( + self.input_data_gpu, + self.cell_table_gpu, + cupy.uint8(flags), + self.default_gain, + self.gain_thresholds_gpu, + self.offset_map_gpu, + self.rel_gain_pc_map_gpu, + self.md_additional_offset_gpu, + self.rel_gain_xray_map_gpu, + self.g_gain_value, + self.bad_pixel_map_gpu, + self.bad_pixel_mask_value, + self.gain_map_gpu, + self.processed_data_gpu, + ), + ) + + def _init_kernels(self): + kernel_source = self._kernel_template.render( + { + "pixels_x": self.pixels_x, + "pixels_y": self.pixels_y, + "data_memory_cells": self.memory_cells, + "constant_memory_cells": self.constant_memory_cells, + "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), + "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), + "corr_enum": utils.enum_to_c_template(CorrectionFlags), + } + ) + self.source_module = cupy.RawModule(code=kernel_source) + self.correction_kernel = self.source_module.get_function("correct") + + +class AgipdCalcatFriend(calcat_utils.BaseCalcatFriend): + _constant_enum_class = AgipdConstants + + def __init__(self, device, *args, **kwargs): + super().__init__(device, *args, **kwargs) + self._constants_need_conditions = { + AgipdConstants.ThresholdsDark: self.dark_condition, + AgipdConstants.Offset: self.dark_condition, + AgipdConstants.SlopesPC: self.dark_condition, + AgipdConstants.SlopesFF: self.illuminated_condition, + AgipdConstants.BadPixelsDark: self.dark_condition, + AgipdConstants.BadPixelsPC: self.dark_condition, + AgipdConstants.BadPixelsFF: self.illuminated_condition, + } + + @staticmethod + def add_schema( + schema, + managed_keys, + param_prefix="constantParameters", + status_prefix="foundConstants", + ): + super(AgipdCalcatFriend, AgipdCalcatFriend).add_schema( + schema, managed_keys, "AGIPD-Type", param_prefix, status_prefix + ) + + ( + OVERWRITE_ELEMENT(schema) + .key(f"{param_prefix}.memoryCells") + .setNewDefaultValue(352) + .commit(), + + OVERWRITE_ELEMENT(schema) + .key(f"{param_prefix}.biasVoltage") + 
.setNewDefaultValue(300) + .commit() + ) + + ( + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.acquisitionRate") + .assignmentOptional() + .defaultValue(1.1) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.gainSetting") + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.photonEnergy") + .assignmentOptional() + .defaultValue(9.2) + .reconfigurable() + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.gainMode") + .assignmentOptional() + .defaultValue("ADAPTIVE_GAIN") + .options(",".join(gain_mode.name for gain_mode in AgipdGainMode)) + .reconfigurable() + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.integrationTime") + .assignmentOptional() + .defaultValue(12) + .reconfigurable() + .commit(), + ) + managed_keys.add(f"{param_prefix}.acquisitionRate") + managed_keys.add(f"{param_prefix}.gainSetting") + managed_keys.add(f"{param_prefix}.photonEnergy") + managed_keys.add(f"{param_prefix}.gainMode") + managed_keys.add(f"{param_prefix}.integrationTime") + + calcat_utils.add_status_schema_from_enum(schema, status_prefix, AgipdConstants) + + def dark_condition(self): + res = calcat_utils.OperatingConditions() + res["Memory cells"] = self._get_param("memoryCells") + res["Sensor Bias Voltage"] = self._get_param("biasVoltage") + res["Pixels X"] = self._get_param("pixelsX") + res["Pixels Y"] = self._get_param("pixelsY") + res["Acquisition rate"] = self._get_param("acquisitionRate") + + # TODO: remove this workaround after CalCat update + integration_time = self._get_param("integrationTime") + if integration_time != 12: + res["Integration Time"] = integration_time + + gain_mode = AgipdGainMode[self._get_param("gainMode")] + if gain_mode is not AgipdGainMode.ADAPTIVE_GAIN: + res["Gain Mode"] = 1 + + # TODO: make configurable whether or not to include gain setting? 
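+ # (for now it is always added to the dark condition, regardless of gain mode)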
+ res["Gain Setting"] = self._get_param("gainSetting") + + return res + + def illuminated_condition(self): + res = self.dark_condition() + + # note: can consider always setting memory cells to 352 for FF + # (deviation on constants in database should remove need for this, though) + + # for now, FF constants don't care about gain mode + if "Gain Mode" in res: + del res["Gain Mode"] + + res["Source Energy"] = self._get_param("photonEnergy") + + return res + + +@KARABO_CLASSINFO("AgipdCorrection", deviceVersion) +class AgipdCorrection(BaseCorrection): + # subclass *must* set these attributes + _correction_flag_class = CorrectionFlags + _correction_field_names = ( + ("thresholding", CorrectionFlags.THRESHOLD), + ("offset", CorrectionFlags.OFFSET), + ("relGainPc", CorrectionFlags.REL_GAIN_PC), + ("gainXray", CorrectionFlags.GAIN_XRAY), + ("badPixels", CorrectionFlags.BPMASK), + ) + _kernel_runner_class = AgipdGpuRunner + _calcat_friend_class = AgipdCalcatFriend + _constant_enum_class = AgipdConstants + _managed_keys = BaseCorrection._managed_keys.copy() + + @staticmethod + def expectedParameters(expected): + ( + OVERWRITE_ELEMENT(expected) + .key("dataFormat.memoryCells") + .setNewDefaultValue(352) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("preview.selectionMode") + .setNewDefaultValue("cell") + .commit(), + ) + + ( + OUTPUT_CHANNEL(expected) + .key("preview.outputRawGain") + .dataSchema(preview_schema) + .commit(), + + OUTPUT_CHANNEL(expected) + .key("preview.outputGainMap") + .dataSchema(preview_schema) + .commit(), + ) + + AgipdCalcatFriend.add_schema(expected, AgipdCorrection._managed_keys) + # this is not automatically done by superclass for complicated class reasons + add_correction_step_schema( + expected, + AgipdCorrection._managed_keys, + AgipdCorrection._correction_field_names, + ) + + # additional settings specific to AGIPD correction steps + ( + BOOL_ELEMENT(expected) + .key("corrections.relGainPc.overrideMdAdditionalOffset") + .displayedName("Override md_additional_offset") + .description( + "Toggling this on will use the value in the next field globally for " + "md_additional_offset. Note that the correction map on GPU gets " + "overwritten as long as this boolean is True, so reload constants " + "after turning off." + ) + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(expected) + .key("corrections.relGainPc.mdAdditionalOffset") + .displayedName("Value for md_additional_offset (if overriding)") + .description( + "Normally, md_additional_offset (part of relative gain correction) is " + "computed when loading SlopesPC. In case you want to use a different " + "value (global for all medium gain pixels), you can specify it here " + "and set corrections.overrideMdAdditionalOffset to True." + ) + .assignmentOptional() + .defaultValue(0) + .reconfigurable() + .commit(), + + FLOAT_ELEMENT(expected) + .key("corrections.gainXray.gGainValue") + .displayedName("G_gain_value") + .description( + "Newer X-ray gain correction constants are absolute. The default " + "G_gain_value of 1 means that output is expected to be in keV. If " + "this is not desired, one can here specify the mean X-ray gain value " + "over all modules to get ADU values out - operator must manually " + "find this mean value." 
+ ) + .assignmentOptional() + .defaultValue(1) + .reconfigurable() + .commit(), + + STRING_ELEMENT(expected) + .key("corrections.badPixels.maskingValue") + .displayedName("Bad pixel masking value") + .description( + "Any pixels masked by the bad pixel mask will have their value " + "replaced with this. Note that this parameter is to be interpreted as " + "a numpy.float32; use 'nan' to get NaN value." + ) + .assignmentOptional() + .defaultValue("nan") + .reconfigurable() + .commit(), + + NODE_ELEMENT(expected) + .key("corrections.badPixels.subsetToUse") + .displayedName("Bad pixel flags to use") + .description( + "The booleans under this node allow for selecting a subset of bad " + "pixel types to take into account when doing bad pixel masking. " + "Upon updating these flags, the map used for bad pixel masking will " + "be ANDed with this selection. Turning disabled flags back on causes " + "reloading of cached constants." + ) + .commit(), + ) + AgipdCorrection._managed_keys.add( + "corrections.relGainPc.overrideMdAdditionalOffset" + ) + AgipdCorrection._managed_keys.add("corrections.relGainPc.mdAdditionalOffset") + AgipdCorrection._managed_keys.add("corrections.gainXray.gGainValue") + AgipdCorrection._managed_keys.add("corrections.badPixels.maskingValue") + # TODO: DRY / encapsulate + for field in BadPixelValues: + ( + BOOL_ELEMENT(expected) + .key(f"corrections.badPixels.subsetToUse.{field.name}") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit() + ) + AgipdCorrection._managed_keys.add( + f"corrections.badPixels.subsetToUse.{field.name}" + ) + + # mandatory: manager needs this in schema + ( + VECTOR_STRING_ELEMENT(expected) + .key("managedKeys") + .assignmentOptional() + .defaultValue(list(AgipdCorrection._managed_keys)) + .commit() + ) + + @property + def input_data_shape(self): + return ( + self.unsafe_get("dataFormat.memoryCells"), + 2, + self.unsafe_get("dataFormat.pixelsX"), + self.unsafe_get("dataFormat.pixelsY"), + ) + + def __init__(self, config): + super().__init__(config) + # note: gain mode single sourced from constant retrieval node + self.gain_mode = AgipdGainMode[config.get("constantParameters.gainMode")] + + try: + self.bad_pixel_mask_value = np.float32( + config.get("corrections.badPixels.maskingValue") + ) + except ValueError: + self.bad_pixel_mask_value = np.float32("nan") + + self._kernel_runner_init_args = { + "gain_mode": self.gain_mode, + "bad_pixel_mask_value": self.bad_pixel_mask_value, + "g_gain_value": config.get("corrections.gainXray.gGainValue"), + } + + # configurability: overriding md_additional_offset + if config.get("corrections.relGainPc.overrideMdAdditionalOffset"): + self._override_md_additional_offset = config.get( + "corrections.relGainPc.mdAdditionalOffset" + ) + else: + self._override_md_additional_offset = None + + self._has_updated_bad_pixel_selection = False + + + def _initialization(self): + self._update_bad_pixel_selection() + super()._initialization() + + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ): + """Called by input_handler for each data hash. 
Should correct data, optionally + compute preview, write data output, and optionally write preview outputs.""" + # original shape: memory_cell, data/raw_gain, x, y + + pulse_table = np.squeeze(data_hash.get("image.pulseId")) + if self._frame_filter is not None: + try: + cell_table = cell_table[self._frame_filter] + pulse_table = pulse_table[self._frame_filter] + image_data = image_data[self._frame_filter] + except IndexError: + self.log_status_warn( + "Failed to apply frame filter, please check that it is valid!" + ) + return + + try: + self.kernel_runner.load_data(image_data) + except ValueError as e: + self.log_status_warn(f"Failed to load data: {e}") + return + except Exception as e: + self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") + + buffer_handle, buffer_array = self._shmem_buffer.next_slot() + self.kernel_runner.load_cell_table(cell_table) + self.kernel_runner.correct(self._correction_flag_enabled) + self.kernel_runner.reshape( + output_order=self.unsafe_get("dataFormat.outputAxisOrder"), + out=buffer_array, + ) + # after reshape, data for dataOutput is now safe in its own buffer + if do_generate_preview: + if self._correction_flag_enabled != self._correction_flag_preview: + self.kernel_runner.correct(self._correction_flag_preview) + ( + preview_slice_index, + preview_cell, + preview_pulse, + ) = utils.pick_frame_index( + self.unsafe_get("preview.selectionMode"), + self.unsafe_get("preview.index"), + cell_table, + pulse_table, + warn_func=self.log_status_warn, + ) + ( + preview_raw, + preview_corrected, + preview_raw_gain, + preview_gain_map, + ) = self.kernel_runner.compute_previews(preview_slice_index) + + # reusing input data hash for sending + data_hash.set("image.data", buffer_handle) + data_hash.set("calngShmemPaths", ["image.data"]) + + data_hash.set("image.cellId", cell_table[:, np.newaxis]) + data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) + + self._write_output(data_hash, metadata) + if do_generate_preview: + self._write_combiner_previews( + ( + ("preview.outputRaw", preview_raw), + ("preview.outputCorrected", preview_corrected), + ("preview.outputRawGain", preview_raw_gain), + ("preview.outputGainMap", preview_gain_map), + ), + train_id, + source, + ) + + def _load_constant_to_runner(self, constant, constant_data): + # TODO: encode correction / constant dependencies in a clever way + if constant is AgipdConstants.ThresholdsDark: + field_name = "thresholding" # TODO: (reverse) mapping, DRY + if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN: + self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode") + return + self.kernel_runner.load_thresholds(constant_data) + elif constant is AgipdConstants.Offset: + field_name = "offset" + self.kernel_runner.load_offset_map(constant_data) + elif constant is AgipdConstants.SlopesPC: + field_name = "relGainPc" + self.kernel_runner.load_rel_gain_pc_map(constant_data) + if self._override_md_additional_offset is not None: + self.kernel_runner.md_additional_offset_gpu.fill( + self._override_md_additional_offset + ) + elif constant is AgipdConstants.SlopesFF: + field_name = "gainXray" + self.kernel_runner.load_rel_gain_ff_map(constant_data) + elif "BadPixels" in constant.name: + field_name = "badPixels" + self.kernel_runner.load_bad_pixels_map( + constant_data, override_flags_to_use=self._override_bad_pixel_flags + ) + + # switch relevant correction on if it just now became available + if not self.get(f"corrections.{field_name}.available"): + # TODO: turn off again when flushing + 
self.set(f"corrections.{field_name}.available", True) + + self._update_correction_flags() + self.log_status_info(f"Done loading {constant.name} to GPU") + + def _update_bad_pixel_selection(self): + selection = 0 + for field in BadPixelValues: + if self.get(f"corrections.badPixels.subsetToUse.{field.name}"): + selection |= field + self._override_bad_pixel_flags = selection + + def preReconfigure(self, config): + super().preReconfigure(config) + if config.has("corrections.badPixels.maskingValue"): + # only check if it is valid; postReconfigure will use it + try: + np.float32(config.get("corrections.badPixels.maskingValue")) + except ValueError: + self.log_status_warn("Invalid masking value, ignoring.") + config.erase("corrections.badPixels.maskingValue") + + def postReconfigure(self): + super().postReconfigure() + + # TODO: move after getting cached update, check if necessary + if self.get("corrections.relGainPc.overrideMdAdditionalOffset"): + self._override_md_additional_offset = self.get( + "corrections.relGainPc.mdAdditionalOffset" + ) + self.kernel_runner.override_md_additional_offset( + self._override_md_additional_offset + ) + else: + self._override_md_additional_offset = None + + if not hasattr(self, "_prereconfigure_update_hash"): + return + + update = self._prereconfigure_update_hash + + if update.has("constantParameters.gainMode"): + self.gain_mode = AgipdGainMode[update["constantParameters.gainMode"]] + self._update_buffers() + + if update.has("corrections.gainXray.gGainValue"): + self.kernel_runner.set_g_gain_value( + self.get("corrections.gainXray.gGainValue") + ) + self._kernel_runner_init_args["g_gain_value"] = self.get( + "corrections.gainXray.gGainValue" + ) + + if update.has("corrections.badPixels.maskingValue"): + self.bad_pixel_mask_value = np.float32( + self.get("corrections.badPixels.maskingValue") + ) + self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) + self._kernel_runner_init_args[ + "bad_pixel_mask_value" + ] = self.bad_pixel_mask_value + + if any( + path.startswith("corrections.badPixels.subsetToUse") + for path in update.getPaths() + ): + self.log_status_info("Updating bad pixel maps based on subset specified") + if any( + update.get( + f"corrections.badPixels.subsetToUse.{field.name}", default=False + ) + for field in BadPixelValues + ): + self.log_status_info( + "Some fields reenabled, reloading cached bad pixel constants" + ) + with self.calcat_friend.cached_constants_lock: + for ( + constant, + data, + ) in self.calcat_friend.cached_constants.items(): + if "BadPixels" in constant.name: + self._load_constant_to_runner(constant, data) + self._update_bad_pixel_selection() + self.kernel_runner.override_bad_pixel_flags_to_use( + self._override_bad_pixel_flags + ) diff --git a/src/calng/CalibrationManager.py b/src/calng/CalibrationManager.py new file mode 100644 index 0000000000000000000000000000000000000000..bd519b395bdd93340064383b8214f31e0978d16e --- /dev/null +++ b/src/calng/CalibrationManager.py @@ -0,0 +1,1323 @@ +############################################################################# +# Author: schmidtp +# Created on August 06, 2021, 05:06 PM +# Copyright (C) European XFEL GmbH Hamburg. All rights reserved. 
+############################################################################# + +from asyncio import gather, wait_for, TimeoutError as AsyncTimeoutError +from collections import defaultdict +from collections.abc import Hashable +from datetime import datetime +from inspect import ismethod +from itertools import chain, repeat +from traceback import format_exc +from urllib.parse import urlparse +import json +import logging + +from tornado.httpclient import AsyncHTTPClient, HTTPError +from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future +from pkg_resources import parse_version + +from karabo.middlelayer import ( + KaraboError, Device, DeviceClientBase, Descriptor, Hash, Configurable, + Slot, Node, Type, + AccessMode, AccessLevel, Assignment, DaqPolicy, State, Unit, + UInt16, UInt32, Bool, Double, Schema, String, VectorString, VectorHash, + background, call, callNoWait, setNoWait, sleep, instantiate, slot, coslot, + getDevice, getTopology, getConfiguration, getConfigurationFromPast, + get_property) +from karabo.middlelayer_api.proxy import ProxyFactory + +from karabo import version as karaboVersion +from ._version import version as deviceVersion +from . import scenes + + +''' +Device states: + - INIT: When the device is starting up + - ACTIVE: When the device is ready to manage a calibration pipeline + - CHANGING: When the device is actively changing the pipeline configuration + - ERROR: Recoverable error, only allows server restart + - UNKNOWN: Unrecoverable error +''' + + +# Copied from karabo MDL source (location depending on version) +# Will be part of MDL's public API in 2.12 +def get_instance_parent(instance): + """Find the parent of the instance""" + parent = instance + while True: + try: + parent = next(iter(parent._parents)) + except StopIteration: + break + + return parent + + +class ClassIdsNode(Configurable): + correctionClass = String( + displayedName='Correction class', + description='Device class to use for corrections.', + defaultValue='{}Correction', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + groupMatcherClass = String( + displayedName='Group matcher class', + description='Device class to use for matching the stream output of a ' + 'module group.', + defaultValue='TrainMatcher', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + bridgeClass = String( + displayedName='Bridge class', + description='Device class to use for bridging the stream output out ' + 'of Karabo.', + defaultValue='ShmemToZmq', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + previewMatcherClass = String( + displayedName='Preview matcher class', + description='Device class to use for matching the output of a preview ' + 'layer.', + defaultValue='ModuleStacker', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + assemblerClass = String( + displayedName='Assembler class', + description='Device class to use for assembling the matched output of ' + 'a preview layer.', + defaultValue='FemDataAssembler', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + +class DeviceIdsNode(Configurable): + correctionSuffix = String( + displayedName='Correction suffix', + description='Suffix for correction device IDs. 
The formatting ' + 'placeholders \'virtualName\', \'index\' and \'group\' ' + 'may be used.', + defaultValue='CORRECT{index:02d}_{virtualName}', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + groupMatcherSuffix = String( + displayedName='Group matcher suffix', + description='Suffix for group matching device IDs. The formatting ' + 'placeholder \'group\' may be used.', + defaultValue='MATCH_G{group}', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + bridgeSuffix = String( + displayedName='Bridge suffix', + description='Suffix for group bridge device IDs. The formatting ' + 'placeholder \'group\' may be used.', + defaultValue='BRIDGE_G{group}', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + previewMatcherSuffix = String( + displayedName='Preview matcher suffix', + description='Suffix for preview layer matching device IDs. The ' + 'formatting placeholder \'layer\' may be used.', + defaultValue='MATCH_{layer}', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + assemblerSuffix = String( + displayedName='Assembler suffix', + description='Suffix for assembler device IDs. The formatting ' + 'placeholder \'layer\' may be used.', + defaultValue='ASSEMBLE_{layer}', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + +class ModuleRow(Configurable): + virtualName = String( + displayedName='Virtual name') + + group = UInt32( + displayedName='Group') + + aggregator = String( + displayedName='Aggregator name') + + inputChannel = String( + displayedName='Input channel') + + inputSource = String( + displayedName='Input source') + + +class ModuleGroupRow(Configurable): + group = UInt32( + displayedName='Group') + + deviceServer = String( + displayedName='Device server') + + withMatcher = Bool( + displayedName='Matcher?') + + withBridge = Bool( + displayedName='Bridge?') + + bridgePort = UInt16( + displayedName='Bridge port', + defaultValue=47000, + minInc=1024, + maxInc=65353) + + bridgePattern = String( + displayedName='Bridge pattern', + options=['PUSH', 'REP', 'PUBLISH'], + defaultValue='PUSH') + + +class PreviewLayerRow(Configurable): + layer = String( + displayedName='Preview layer') + + outputPipeline = String( + displayedName='Output pipeline') + + deviceServer = String( + displayedName='Device server') + + +class DeviceServerRow(Configurable): + deviceServer = String( + displayedName='Device server') + + webserverHost = String( + displayedName='Webserver host') + + +class WebserverApiNode(Configurable): + statePollInterval = Double( + displayedName='Status poll interval', + description='Time between subsequent polls of the webserver API when ' + 'a device server is expected to reach a particular state.', + defaultValue=0.5, + minInc=0.2, + unitSymbol=Unit.SECOND, + accessMode=AccessMode.RECONFIGURABLE) + + downTimeout = Double( + displayedName='Down timeout', + description='Time to wait for a device server to stop gracefully ' + 'after a \'down\' command and until the \'kill\' command ' + 'is sent.', + defaultValue=5.0, + minInc=1.0, + unitSymbol=Unit.SECOND, + accessMode=AccessMode.RECONFIGURABLE) + + killTimeout = Double( + displayedName='Kill timeout', + description='Time to wait for a device server to stop after a ' + '\'kill\' command and until the server is deemed ' + 'unrecoverable.', + defaultValue=5.0, + minInc=1.0, + unitSymbol=Unit.SECOND, + accessMode=AccessMode.RECONFIGURABLE) + + upTimeout = Double( + displayedName='Up timeout', + description='Time to wait for a device 
server to start successfully ' + 'after being down and until the server i deemed ' + 'broken.', + defaultValue=5.0, + minInc=1.0, + unitSymbol=Unit.SECOND, + accessMode=AccessMode.RECONFIGURABLE) + + +class InstantiationOptionsNode(Configurable): + restoreMatcherSources = Bool( + displayedName='Restore matcher sources', + description='Attempt to retrieve and restore the last known ' + 'configuration for slow and fast sources of matcher ' + 'devices when the pipeline is instantiated.', + defaultValue=False, + accessMode=AccessMode.RECONFIGURABLE) + + autoActivateGroupBridges = Bool( + displayedName='Activate bridges automatically', + description='Whether to activate all group bridges immediately after ' + 'instantation.', + defaultValue=False, + accessMode=AccessMode.RECONFIGURABLE) + + autoActivateGroupMatchers = Bool( + displayedName='Activate group matchers automatically', + description='Whether to activate all group matchers immediately after ' + 'instantation.', + defaultValue=False, + accessMode=AccessMode.RECONFIGURABLE) + + +class ManagedKeysNode(Configurable): + # Keys managed on detector DAQ devices. + DAQ_KEYS = {'DataDispatcher.trainStride': 'daqTrainStride'} + + @UInt32( + displayedName='DAQ train stride', + unitSymbol=Unit.COUNT, + defaultValue=5, + allowedStates=[State.ACTIVE], + minInc=1) + async def daqTrainStride(self, value): + self.daqTrainStride = value + background(get_instance_parent(self)._set_on_daq( + 'DataDispatcher.trainStride', value)) + + +class ManagedKeysCloneFactory(ProxyFactory): + Proxy = ManagedKeysNode + SubProxy = Configurable + ProxyNode = Node + node_factories = dict(Slot=Slot) + + +class CalibrationManager(DeviceClientBase, Device): + __version__ = deviceVersion + + interfaces = VectorString( + displayedName='Device interfaces', + description='Interfaces implemented by this device.', + defaultValue=( + ['DeviceInstantiator'] + if parse_version(karaboVersion) >= parse_version('2.11') + else []), + accessMode=AccessMode.READONLY) + + availableScenes = VectorString( + displayedName='Available scenes', + displayType='Scenes', + requiredAccessLevel=AccessLevel.OBSERVER, + accessMode=AccessMode.READONLY, + defaultValue=['overview', 'managed_keys'], + daqPolicy=DaqPolicy.OMIT) + + @slot + def requestScene(self, params): + name = params.get('name', default='overview') + if name == 'overview': + # Assumes there are correction devices known to manager + scene_data = scenes.manager_device_overview_scene( + self.deviceId, + self.getDeviceSchema(), + self._correction_device_schema, + self._correction_device_ids, + self._domain_device_ids, + ) + payload = Hash('success', True, 'name', name, 'data', scene_data) + elif name.startswith('browse_schema'): + if ':' in name: + prefix = name[len('browse_schema:'):] + else: + prefix = 'managed' + scene_data = scenes.recursive_subschema_scene( + self.deviceId, + self.getDeviceSchema(), + prefix, + ) + payload = Hash('success', True, 'name', name, 'data', scene_data) + else: + payload = Hash('success', False, 'name', name) + + return Hash('type', 'deviceScene', + 'origin', self.deviceId, + 'payload', payload) + + detectorType = String( + displayedName='Detector type', + description='Type of the detector to manage.', + options=['AGIPD', 'LPD', 'DSSC', 'Jungfrau', 'ePix100', 'pnCCD', + 'FastCCD'], + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + detectorIdentifier = String( + displayedName='Detector identifier', + description='Name of this detector in CalCat and device ID domain ' + 'for the 
corresponding data aggregators, if applicable. ', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + classIds = Node( + ClassIdsNode, + displayedName='Class IDs', + description='class IDs to instantiate for each role. A formatting ' + 'placeholer may be used to indicate the capitalized ' + 'detector type, e.g. Dssc or Jungfrau') + + deviceIds = Node( + DeviceIdsNode, + displayedName='Device IDs', + description='Templates for the member field of the device IDs in the ' + 'pipeline, which is appended to the same device ID root ' + '(DOMAIN/TYPE) as the manager instance itself.') + + modules = VectorHash( + displayedName='Modules', + description='Individual modules constituting this detector. A module ' + 'is the boundary on which correction constants are ' + 'calculated, and in general equivalent to PDUs. Only the ' + 'virtual name and group columns are obligatory, if left ' + 'blank the aggregator name is set to ' + 'f\'{detector_type.upper()}{index:02d}\', the input ' + 'channel is set to ' + 'f\'{detector_identifier}/DET/{index}CH0:output\' and the ' + 'input source ' + 'f\'{detector_identifier}/DET/{index}CH0:xtdf\'.', + rows=ModuleRow, + accessMode=AccessMode.RECONFIGURABLE, + assignment=Assignment.MANDATORY) + + moduleGroups = VectorHash( + displayedName='Module groups', + rows=ModuleGroupRow, + accessMode=AccessMode.RECONFIGURABLE, + assignment=Assignment.MANDATORY) + + previewLayers = VectorHash( + displayedName='Preview layers', + rows=PreviewLayerRow, + accessMode=AccessMode.RECONFIGURABLE, + assignment=Assignment.MANDATORY) + + @VectorHash( + displayedName='Device servers', + rows=DeviceServerRow, + accessMode=AccessMode.RECONFIGURABLE, + assignment=Assignment.MANDATORY) + async def deviceServers(self, value): + self.deviceServers = value + self._servers_changed = True + + geometryDevice = String( + displayedName='Geometry device', + description='[NYI] Device ID for a geometry device defining the ' + 'detector layout and module positions.', + accessMode=AccessMode.INITONLY, + assignment=Assignment.MANDATORY) + + calcatUrl = String( + displayedName='CalCat URL', + description='[NYI] URL to CalCat API to use for constant retrieval, ' + 'set by local secrets file', + accessMode=AccessMode.READONLY) + + webserverApi = Node( + WebserverApiNode, + displayedName='Webserver API', + description='Configurations for the webserver API to control device ' + 'servers.') + + instantiationOptions = Node( + InstantiationOptionsNode, + displayedName='Instantiation options', + description='Optional flags controlling the pipeline instantiation.') + + doNotCompressEvents = Bool( + requiredAccessLevel=AccessLevel.GOD, + accessMode=AccessMode.READONLY, + defaultValue=False, + daqPolicy=DaqPolicy.OMIT) + + @Slot( + displayedName='Restart servers', + allowedStates=[State.ACTIVE, State.ERROR]) + async def restartServers(self): + self.state = State.CHANGING + background(self._restart_servers()) + + @Slot( + displayedName='Instantiate pipeline', + allowedStates=[State.ACTIVE]) + async def startInstantiate(self): + # Slot name is mandated by DeviceInstantiator interface. 
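+ # The pipeline is built asynchronously by _instantiate_pipeline; this slot only switches the state and schedules that background task.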
+ self.state = State.CHANGING + background(self._instantiate_pipeline()) + + @Slot( + displayedName='Apply managed values', + description='Set all managed keys to the values currently active on ' + 'the manager, replacing any manual change.', + allowedStates=[State.ACTIVE]) + async def applyManagedValues(self): + background(self._apply_managed_values()) + + managed = Node( + ManagedKeysNode, + displayedName='Managed keys', + description='Properties and slots managed on devices in the pipeline.') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Take the manager's device ID root (DOMAIN/TYPE) as the root + # for all devices it manages. + self._device_id_root = self.deviceId[:self.deviceId.rfind('/')].value + + # Concretized class ID for correction devices. + self._correction_class_id = self.classIds.correctionClass.format( + self.detectorType.value.lower().capitalize()) + + # Set of data aggregators associated with the managed detector. + self._daq_device_ids = set() + + # Set of devices in the same domain as the manager, i.e. having + # the same device ID root. + self._domain_device_ids = set() + + # Set of correction devices for the managed detector, i.e. being + # in the same domain and having the specified class ID. + self._correction_device_ids = set() + + async def onInitialization(self): + self.state = State.INIT + + if parse_version(karaboVersion) < parse_version('2.11'): + # Compatibility with 2.10 and before, where onInitialization + # is awaited during device instantation. + background(self._async_init()) + else: + await self._async_init() + + @coslot + async def slotInstanceNew(self, instance_id, info): + # A new instance has appeared in the topology. + + await super().slotInstanceNew(instance_id, info) + + if info['type'] == 'device': + self._check_new_device(instance_id, info['classId']) + + @slot + def slotInstanceGone(self, instance_id, info): + # An instance is gone from the topology. + + super().slotInstanceGone(instance_id, info) + + if info['type'] == 'device': + self._daq_device_ids.discard(instance_id) + self._domain_device_ids.discard(instance_id) + self._correction_device_ids.discard(instance_id) + + async def _async_init(self): + # Populate the device ID sets with what's out there right now. + await self._check_topology() + + # Set-up Tornado. + if hasattr(AsyncIOMainLoop, 'initialized'): + if not AsyncIOMainLoop.initialized(): + AsyncIOMainLoop().install() + + self._http_client = AsyncHTTPClient() + + # Check device servers and initialize webserver access. + await self._check_servers() + + # Inject schema for configuration of managed devices. + await self._inject_managed_keys() + + if self.state == State.INIT: + self._set_status('Calibration manager ready') + self.state = State.ACTIVE + + def _check_new_device(self, device_id, class_id): + if class_id == 'DataAggregator' and \ + device_id.startswith(self.detectorIdentifier.value): + # This device is a data aggregator belonging to the detector + # installation + self._daq_device_ids.add(device_id) + + elif device_id.startswith(self._device_id_root): + # This device lives under the same device ID root as this + # manager instance. + self._domain_device_ids.add(device_id) + + if class_id == self._correction_class_id: + # This device is also a correction device. 
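+ # Correction device IDs are tracked separately: they are used for the overview scene and when reading back managed key values.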
+ self._correction_device_ids.add(device_id) + + async def _check_topology(self): + for i in range(10): + try: + devices = getTopology()['device'] + except AttributeError: + # Not ready yet + await sleep(1.0) + else: + if not devices: + await sleep(1.0) + else: + break + else: + self._set_fatal(f'Topology not available after {i+1} tries') + return + + self.logger.debug(f'Topology arrived after {i+1} tries') + + for device_id in devices: + self._check_new_device(device_id, devices[device_id, 'classId']) + + async def _get_shared_keys(self, device_ids, keys): + """Find the most common property values on devices.""" + + key_values = defaultdict(list) + + for device_id in device_ids: + try: + config = await wait_for(getConfiguration(device_id), + timeout=2.0) + except AsyncTimeoutError: + # Ignore this device if the configuration can no longer + # be obtained. + continue + + for key in keys: + value = config[key] + + if isinstance(value, Hashable): + # Value must be hashable to determine the most + # common one below. + key_values[key].append(value) + + return {key: max(set(values), key=values.count) for key, values + in key_values.items()} + + async def _inject_managed_keys(self): + """Attempt to retrieve the correction device's schema and insert + part of it as managed keys. + """ + + correction_device_servers = [ + server for _, server, _, _, _, _ in self.moduleGroups.value] + + up_corr_servers = await self._get_servers_in_state( + 'up', servers=correction_device_servers) + + if up_corr_servers: + # At least one correction server is up. + corr_server = next(iter(up_corr_servers)) + else: + # None of the correction servers is up, try to start the + # first one. + corr_server = correction_device_servers[0] + + try: + await wait_for(self._ensure_server_state(corr_server, 'up'), + timeout=self.webserverApi.upTimeout.value) + except AsyncTimeoutError: + self._set_fatal(f'Could not bring up correction device server ' + f'`{corr_server}` for schema retrieval') + return + + # Obtain the device schema from a correction device server. + managed_schema, _, _ = await call(corr_server, 'slotGetClassSchema', + self._correction_class_id) + # saving this for later + self._correction_device_schema = Schema() + self._correction_device_schema.copy(managed_schema) + + if managed_schema.name != self._correction_class_id: + self._set_fatal( + f'Correction class ID `{self._correction_class_id}` not known ' + f'or loadable by device server `{corr_server}`') + return + + # Collect the keys to be managed and build a nested hash + # expressing its hierarchy, leafs are set to None. + managed_keys = set(managed_schema.hash['managedKeys', 'defaultValue']) + managed_tree = Hash(*chain.from_iterable( + zip(managed_keys, repeat(None, len(managed_keys))))) + managed_paths = set(managed_tree.paths()) + + # Reduce the correction schema to the managed paths. + managed_hash = managed_schema.hash + for path in managed_hash.paths(): + if path not in managed_paths: + del managed_hash[path] + + # Retrieve any previous values already on running devices in + # order to update the defaultValue attribute in the schema just + # before injection. + prev_vals = await self._get_shared_keys( + self._correction_device_ids, managed_keys) + + if self._daq_device_ids: + prev_vals.update(await self._get_shared_keys( + self._daq_device_ids, ManagedKeysNode.DAQ_KEYS.keys())) + + # Retrieve the attributes on the current managed node. 
The + # original implementation of toSchemaAndAttrs in the Node's + # superclass Descriptor is used to avoid Node-specific + # attributes in the attributes that are not valid in the + # property definition. + # The value are then obtained from the Node object again since + # enums are converted to their values by toSchemaAndAttrs, which + # in turn is not valid for property definition. + _, attrs = Descriptor.toSchemaAndAttrs(self.__class__.managed, + None, None) + managed_node_attrs = {key: getattr(self.__class__.managed, key) + for key in attrs.keys()} + + # Build a proxy from the managed schema, and create a new node + # based on it with the original node attributes. This code is + # heavily inspired by deviceClone. + managed_node = Node( + ManagedKeysCloneFactory.createProxy(managed_schema), + **managed_node_attrs) + + # Walk the managed tree to and sanitize all descriptors to our + # specifications. + def _sanitize_node(parent, tree, prefix=''): + for key, value in tree.items(): + # Fetch the descriptor class, not its instance! + descr = getattr(parent.cls, key) + + full_key = f'{prefix}.{key}' if prefix else key + + if isinstance(descr, Node): + _sanitize_node(descr, value, full_key) + + elif isinstance(descr, Slot): + async def _managed_slot_called(parent, fk=full_key): + background(self._call_on_corrections(fk)) + + _managed_slot_called.__name__ = f'managed.{full_key}' + descr.__call__(_managed_slot_called) + + # Managed slots can only be called in the ACTIVE + # state. + descr.allowedStates = [State.ACTIVE] + + elif isinstance(descr, Type): + # Regular property. + + if descr.accessMode == AccessMode.RECONFIGURABLE: + # Add a callback only if the original descriptor + # is reconfigurable. + + async def _managed_prop_changed(parent, v, k=key, + fk=full_key): + setattr(parent, k, v) + + if self.state != State.INIT: + # Do not propagate updates during injection. + background(self._set_on_corrections(fk, v)) + + descr.__call__(_managed_prop_changed) + + # Managed properties are always optional, + # reconfigurable and may only be changed in the + # ACTIVE state. + descr.assignment = Assignment.OPTIONAL + descr.accessMode = AccessMode.RECONFIGURABLE + descr.allowedStates = [State.ACTIVE] + + try: + # If there's been a previous value before + # injection, use it. + descr.defaultValue = prev_vals[full_key] + except KeyError: + pass + else: + self.logger.warn(f'Encountered unknown descriptor type ' + f'{type(descr)}') + + _sanitize_node(managed_node, managed_tree) + + # Inject the newly prepared node for managed keys. + self.__class__.managed = managed_node + await self.publishInjectedParameters() + self._managed_keys = managed_keys + + self.logger.debug('Managed schema injected') + + def _set_status(self, text, level=logging.INFO): + """Add and log a status message. + + Suppresses throttling from the gui server. 
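+ (doNotCompressEvents is toggled alongside each message to achieve this.)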
+ """ + + self.status = text + self.doNotCompressEvents = not self.doNotCompressEvents + self.logger.log(level, text) + + def _set_error(self, text): + """Set the device into error state and log an error message.""" + + self.state = State.ERROR + self._set_status(text, level=logging.ERROR) + + def _set_fatal(self, text): + """Set the device into unknown state and log an error message.""" + + self.state = State.UNKNOWN + self._set_status(text, level=logging.CRITICAL) + + def _set_exception(self, text, e): + """Set the device into error upon an exception.""" + + return self._set_error(f'{text}\n{format_exc()[:-1]}') + + def _get_server_api_url(self, name): + """Get API URL for a device server.""" + + return '{host}/api/servers/{api_name}.json'.format( + host=self._server_hosts[name], + api_name=self._server_api_names[name]) + + async def _get_server_info(self, name): + """Get info for a device server. + + Args: + name (str): Device server (Karabo) name. + + Returns: + (dict) Information returned by webserver for this server. + """ + + try: + reply = await to_asyncio_future(self._http_client.fetch( + self._get_server_api_url(name), method='GET')) + except (ConnectionError, HTTPError) as e: + raise RuntimeError( + f'Failed retrieving server info of {name} on ' + f'{urlparse(self._server_hosts[name]).hostname}: {e}') + + body = json.loads(reply.body) + + if not body['success']: + raise RuntimeError(f'Request to retrieve server info for {name} ' + f'failed') + + self.logger.debug(f'Obtained server info for {name}: ' + f'{body["servers"][0]}') + + return body['servers'][0] + + async def _get_servers_in_state(self, state, servers=None): + """List all servers in a particular state. + + Args: + state (str): State to filter for, e.g. 'up' or 'down'. + + Returns: + (list[str]) List of server names found in the passed state. + """ + + servers = await gather(*[self._get_server_info(name) + for name in self._server_hosts.keys() + if servers is None or name in servers]) + + return {server['karabo_name'] for server in servers + if server['status'].startswith(state)} + + async def _ensure_server_state(self, name, state, command=None): + """Sets and waits for a a server state to be reached. + + Args: + name (str): Device server (Karabo) name. + state (str): State to set and wait for, e.g. 'up' or 'down'. + command (str, optional): Command to use for reaching the + state, same as state of omitted. + """ + + if command is None: + command = state + + try: + reply = await to_asyncio_future(self._http_client.fetch( + self._get_server_api_url(name), method='PUT', + body=json.dumps({'server': {'command': command}}))) + except (ConnectionError, HTTPError) as e: + raise RuntimeError( + f'Failed sending `{command}` to {name} on ' + f'{urlparse(self._server_hosts[name]).hostname}: {e}') + + if not json.loads(reply.body)['success']: + raise RuntimeError('Command `{command}` for {name} failed') + + while True: + await sleep(self.webserverApi.statePollInterval.value) + info = await self._get_server_info(name) + + if info['status'].startswith(state): + return + + async def _check_servers(self): + """Validate device server configuration.""" + + if not self._servers_changed: + return True + + # Mapping of servers to device servers. + self._server_hosts = {server: host for server, host + in self.deviceServers.value} + + # Mapping of "Karabo names" (with capitalization and slashes) + # to "API names" (all lower case and only underscores). + self._server_api_names = {} + + # Build a mapping of hosts to server names. 
+ hosts = defaultdict(set) + for name, host in self._server_hosts.items(): + hosts[host].add(name) + + # List of errors found during device server check. + err = [] + + # Query all servers on each host and check whether the ones we + # need are in there, and obtain their API names. + for host, req_names in hosts.items(): + # Retrieve hostname for nice error messages. + hostname = urlparse(self._server_hosts[name]).hostname + + try: + reply = await to_asyncio_future(self._http_client.fetch( + f'{host}/api/servers.json', method='GET')) + except (ConnectionError, HTTPError) as e: + err.append(f'- {e.__class__.__name__} when retrieving server ' + f'list on {hostname}: {e}') + continue + + body = json.loads(reply.body) + + if not body['success']: + err.append(f'- Request to retrieve server list failed on ' + f'{hostname}') + continue + + servers = {s['karabo_name']: s for s in body['servers'] + if s['karabo_name'] in req_names} + + if len(servers) != len(req_names): + err.append(f'- Device servers missing on {hostname}: ' + + ', '.join(req_names - servers.keys())) + + servers_in_error = {s['karabo_name'] for s in servers.values() + if 'error' in s['status']} + + if servers_in_error: + err.append(f'- Device servers on {hostname} are in error state' + + ', '.join(servers_in_error)) + continue + + servers_disabled = {s['karabo_name'] for s in servers.values() + if not s['control_allowed']} + + if servers_disabled: + err.append(f'- Device servers on {hostname} not controllable ' + f'via webserver: ' + ', '.join(servers_disabled)) + continue + + for name, info in servers.items(): + self._server_api_names[name] = info['name'] + + if err: + self._set_error('One or more device server problems were ' + 'detected:\n' + '\n'.join(err)) + return False + else: + return True + + async def _restart_servers(self): + """Restart all managed device servers.""" + + if not await self._check_servers(): + return + + try: + # Find all servers in need of being brought down. + up_names = await self._get_servers_in_state('up') + + # Bring down all servers which are up. + if up_names: + try: + await wait_for(gather( + *[self._ensure_server_state(name, 'down') + for name in up_names]), + timeout=self.webserverApi.downTimeout.value) + except AsyncTimeoutError: + # Narrow down to the list of servers STILL up. + up_names = await self._get_servers_in_state('up') + else: + up_names.clear() + + # If some servers are still up, go for the kill. + if up_names: + self.logger.warn('Some servers still up after waiting for ' + f'downTimeout: {up_names}') + + try: + await wait_for(gather( + *[self._ensure_server_state(name, 'down', 'kill') + for name in up_names]), + timeout=self.webserverApi.killTimeout.value) + except AsyncTimeoutError: + # All hope is lost. + up_names = await self._get_servers_in_state('up') + return self._set_error( + 'One or more device servers could not be brought ' + 'down, even via kill:' + ', '.join(up_names)) + + # Bring all servers up again. + try: + await wait_for(gather( + *[self._ensure_server_state(name, 'up') + for name in self._server_hosts.keys()]), + timeout=self.webserverApi.upTimeout.value) + except AsyncTimeoutError: + down_names = await self._get_servers_in_state('down') + return self._set_error('One or more device servers could not ' + 'brought up: ' + ', '.join(down_names)) + + except RuntimeError as e: + return self._set_exception( + 'Request unexpectedly failed during restart procedure', e) + + # Wait a moment for good measure until the servers are + # actually up. 
+        # servers reported running, instantiate() blocks forever.
+        await sleep(5.0)
+
+        if self._domain_device_ids:
+            # If there are still devices in this set, log a warning.
+            self.logger.warn('Some devices left in manager domain after '
+                             'device servers have been restarted: '
+                             + ', '.join(self._domain_device_ids))
+
+        self._set_status('All device servers restarted')
+        self.state = State.ACTIVE
+
+    async def _instantiate_device(self, server, class_id, device_id, config):
+        """Instantiate a single device.
+
+        Small wrapper around karabo.middlelayer.instantiate for error
+        handling and logging.
+        """
+
+        if device_id in self._domain_device_ids:
+            self.logger.debug(f'Skipped instantiation of already existing '
+                              f'device {device_id}')
+            return True
+
+        try:
+            msg = await wait_for(instantiate(
+                server, class_id, device_id, config), 5.0)
+        except AsyncTimeoutError:
+            self._set_error(f'Instantiation timeout on {device_id}')
+            return False
+        except KaraboError as e:
+            self._set_exception(f'Instantiation error on {device_id}', e)
+            return False
+
+        self.logger.debug(f'Instantiation result for {device_id}: {msg}')
+        return True
+
+    async def _instantiate_pipeline(self):
+        """Instantiate all managed devices."""
+
+        if not await self._check_servers():
+            return
+
+        # Make sure all servers are up.
+        try:
+            up_servers = await self._get_servers_in_state('up')
+
+            if up_servers != self._server_api_names.keys():
+                self.state = State.ACTIVE
+                return self._set_status('One or more device servers are not '
+                                        'up, restart servers first.')
+        except RuntimeError as e:
+            return self._set_exception('Request unexpectedly failed while '
+                                       'checking device server state', e)
+
+        # Class and device ID templates per role.
+        class_ids = {}
+        device_id_templates = {}
+
+        class_args = (self.detectorType.value.lower().capitalize(),)
+        for role in ['correction', 'groupMatcher', 'bridge', 'previewMatcher',
+                     'assembler']:
+            class_ids[role] = getattr(
+                self.classIds, f'{role}Class').value.format(*class_args)
+            device_id_templates[role] = f'{self._device_id_root}/' + \
+                getattr(self.deviceIds, f'{role}Suffix')
+
+        # Servers by group and layer.
+        server_by_group = {group: server for group, server, _, _, _, _
+                           in self.moduleGroups.value}
+        server_by_layer = {layer: server for layer, _, server
+                           in self.previewLayers.value}
+
+        all_req_servers = set(server_by_group.values()).union(
+            server_by_layer.values())
+
+        if all_req_servers != up_servers:
+            return self._set_error('One or more device servers are not '
+                                   'listed in the device servers '
+                                   'configuration')
+
+        # Instantiate modules.
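+        # Each row of the modules table yields one correction device. A
+        # hypothetical example row: ('Q1M1', 'group1', 'AGIPD00',
+        # 'SPB_DET_AGIPD1M-1/DET/0CH0:output', 'SPB_DET_AGIPD1M-1/DET/0CH0:xtdf');
+        # empty aggregator / channel / source fields are filled in below from
+        # detectorType / detectorIdentifier and the row index.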
+ modules_by_group = defaultdict(list) + correct_device_id_by_module = {} + input_source_by_module = {} + + for index, row in enumerate(self.modules.value): + vname, group, aggregator, input_channel, input_source = row + + modules_by_group[group].append(vname) + device_id = device_id_templates['correction'].format( + virtualName=vname, index=index, group=group) + correct_device_id_by_module[vname] = device_id + + if not aggregator: + aggregator = f'{self.detectorType.upper()}{index:02d}' + + daq_device_id = f'{self.detectorIdentifier}/DET/{index}CH0' + + if not input_channel: + input_channel = f'{daq_device_id}:output' + + if not input_source: + input_source = f'{daq_device_id}:xtdf' + + input_source_by_module[vname] = input_source + + config = Hash() + + config['constantParameters.detectorName'] = self.detectorIdentifier.value + config['constantParameters.karaboDa'] = aggregator + config['dataInput.connectedOutputChannels'] = [input_channel] + config['fastSources'] = [input_source] + + # Add managed keys. + for key in self._managed_keys: + value = get_property(self, f'managed.{key}') + + if not ismethod(value): + config[key] = value + + if not await self._instantiate_device( + server_by_group[group], class_ids['correction'], device_id, + config + ): + return + + # Instantiate group matchers and bridges. + for row in self.moduleGroups.value: + group, server, with_matcher, with_bridge, bridge_port, \ + bridge_pattern = row + + # Group matcher, if applicable. + if with_matcher: + matcher_device_id = device_id_templates['groupMatcher'].format( + group=group) + + config = Hash() + config['channels'] = [ + f'{correct_device_id_by_module[vname]}:dataOutput' + for vname in modules_by_group[group]] + config['fastSources'] = [ + Hash('fsSelect', True, + 'fsSource', input_source_by_module[vname]) + for vname in modules_by_group[group]] + + if self.instantiationOptions.restoreMatcherSources: + try: + old_config = await getConfigurationFromPast( + matcher_device_id, datetime.now().isoformat()) + except KaraboError: + pass # Ignore configuration on error + else: + config['channels'] = old_config['channels'] + config['slowSources'] = old_config['slowSources'] + config['fastSources'] = old_config['fastSources'] + + if not await self._instantiate_device( + server, class_ids['groupMatcher'], + matcher_device_id, config + ): + return + elif self.instantiationOptions.autoActivateGroupMatchers: + async def _activate_matcher(device_id): + with await getDevice(device_id) as device: + await sleep(3) + if device.state == State.PASSIVE: + await device.start() + + background(_activate_matcher(matcher_device_id)) + + # Group bridge, if applicable. + if with_bridge: + bridge_device_id = device_id_templates['bridge'].format( + group=group) + + config = Hash() + config['outputsConfig'] = [Hash( + 'pattern', bridge_pattern, 'hwm', 1, 'port', bridge_port)] + + config['input.connectedOutputChannels'] = [ + f'{matcher_device_id}:output'] + + if not await self._instantiate_device( + server, class_ids['bridge'], bridge_device_id, config + ): + return + elif self.instantiationOptions.autoActivateGroupBridges: + # Delay the slot a bit since it will get lost during + # instantation. + + async def _activate_bridge(device_id): + with await getDevice(device_id) as device: + await sleep(3) + if device.state == State.PASSIVE: + await device.activate() + + background(_activate_bridge(bridge_device_id)) + + # Instantiate preview layer matchers and assemblers. 
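+        # Each previewLayers row is (layer name, output pipeline on the
+        # correction devices, device server), e.g. the purely illustrative row
+        # ('raw', 'preview.outputRaw', 'preview_server_1').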
+ for layer, output_pipeline, server in self.previewLayers.value: + # Preview matcher. + matcher_device_id = device_id_templates['previewMatcher'].format( + layer=layer) + + config = Hash() + config['channels'] = [ + f'{device_id}:{output_pipeline}' + for device_id in correct_device_id_by_module.values()] + config['fastSources'] = [ + Hash('fsSelect', True, + 'fsSource', + f'{input_source_by_module[virtual_id]}') + for (virtual_id, device_id) + in correct_device_id_by_module.items()] + config['pathToStack'] = 'data.adc' + + if not await self._instantiate_device( + server, class_ids['previewMatcher'], matcher_device_id, config + ): + return + + # Preview assembler. + assembler_device_id = device_id_templates['assembler'].format( + layer=layer) + + config = Hash() + config['input.connectedOutputChannels'] = [ + f'{matcher_device_id}:output'] + config['modules'] = [ + Hash('source', input_source_by_module.get('Q1M1', ''), + 'offX', 474, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q1M2', ''), + 'offX', 316, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q1M3', ''), + 'offX', 158, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q1M4', ''), + 'offX', 0, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q2M1', ''), + 'offX', 1136, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q2M2', ''), + 'offX', 978, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q2M3', ''), + 'offX', 820, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q2M4', ''), + 'offX', 662, 'offY', 612, 'rot', 90), + Hash('source', input_source_by_module.get('Q3M1', ''), + 'offX', 712, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q3M2', ''), + 'offX', 870, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q3M3', ''), + 'offX', 1028, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q3M4', ''), + 'offX', 1186, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q4M1', ''), + 'offX', 50, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q4M2', ''), + 'offX', 208, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q4M3', ''), + 'offX', 366, 'offY', 0, 'rot', 270), + Hash('source', input_source_by_module.get('Q4M4', ''), + 'offX', 524, 'offY', 0, 'rot', 270), + ] + config['pathsToCombine'] = ['data.adc'] + config['trainIdPath'] = 'image.trainId' + config['pulseIdPath'] = 'image.pulseId' + config['preview.enablePreview'] = True + config['preview.pathToPreview'] = 'data.adc' + config['preview.downSample'] = 2 + config['badpixelPath'] = 'image.bad_pixels' + config['rotated90Grad'] = True + + if not await self._instantiate_device( + server, class_ids['assembler'], assembler_device_id, config + ): + return + + self._set_status('All devices instantiated') + self.state = State.ACTIVE + + async def _apply_managed_values(self): + """Apply all managed keys to local values.""" + + for daq_key, local_key in ManagedKeysNode.DAQ_KEYS.items(): + await self._set_on_daq( + daq_key, get_property(self, f'managed.{local_key}')) + + for key in self._managed_keys: + value = get_property(self, f'managed.{key}') + if not ismethod(value): + await self._set_on_corrections(key, value) + + def _call(self, device_ids, slot, *args): + """Call the same slot on a list of devices.""" + + for device_id in device_ids: + callNoWait(device_id, slot, *args) + self.logger.debug(f'Called 
{device_id}.{slot}{args}') + + async def _call_on_corrections(self, slot, *args): + """Call a slot on all correction devices.""" + + if self._correction_device_ids: + self._call(self._correction_device_ids, slot, *args) + self.logger.info(f'Called <CORR>.{slot}{args}') + + def _set(self, device_ids, key, value): + """Set the same property on a list of devices.""" + + for device_id in device_ids: + setNoWait(device_id, key, value) + self.logger.debug(f'Set {device_id}.{key} to {value}') + + async def _set_on_daq(self, key, value): + """Set a property on all DAQ devices.""" + + if self._daq_device_ids: + self._set(self._daq_device_ids, key, value) + self.logger.info(f'Set <DAQ>.{key} to {value}') + + async def _set_on_corrections(self, key, value): + """Set a property on all correction devices.""" + + if self._correction_device_ids: + self._set(self._correction_device_ids, key, value) + self.logger.info(f'Set <CORR>.{key} to {value}') diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py new file mode 100644 index 0000000000000000000000000000000000000000..4003b36cda9a6655db5781be892da869bc2fc0f2 --- /dev/null +++ b/src/calng/DsscCorrection.py @@ -0,0 +1,290 @@ +import enum + +import cupy +import numpy as np +from karabo.bound import ( + DOUBLE_ELEMENT, + KARABO_CLASSINFO, + OVERWRITE_ELEMENT, + VECTOR_STRING_ELEMENT, + State, +) + +from . import base_gpu, calcat_utils, utils +from ._version import version as deviceVersion +from .base_correction import BaseCorrection, add_correction_step_schema + + +class CorrectionFlags(enum.IntFlag): + NONE = 0 + OFFSET = 1 + + +class DsscConstants(enum.Enum): + Offset = enum.auto() + + +class DsscGpuRunner(base_gpu.BaseGpuRunner): + _kernel_source_filename = "dssc_gpu.cu" + _corrected_axis_order = "cyx" + + def __init__( + self, + pixels_x, + pixels_y, + memory_cells, + constant_memory_cells, + input_data_dtype=np.uint16, + output_data_dtype=np.float32, + ): + self.input_shape = (memory_cells, pixels_y, pixels_x) + self.processed_shape = self.input_shape + super().__init__( + pixels_x, + pixels_y, + memory_cells, + constant_memory_cells, + input_data_dtype, + output_data_dtype, + ) + + self.map_shape = (self.constant_memory_cells, self.pixels_y, self.pixels_x) + self.offset_map_gpu = cupy.empty(self.map_shape, dtype=np.float32) + + self._init_kernels() + + self.offset_map_gpu = cupy.empty(self.map_shape, dtype=np.float32) + + self.update_block_size((1, 1, 64)) + + def _get_raw_for_preview(self): + return self.input_data_gpu + + def _get_corrected_for_preview(self): + return self.processed_data_gpu + + def load_offset_map(self, offset_map): + # can have an extra dimension for some reason + if len(offset_map.shape) == 4: # old format (see offsetcorrection_dssc.py)? 
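+            # keep only the first entry along the extra trailing axis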
+ offset_map = offset_map[..., 0] + # shape (now): x, y, memory cell + offset_map = np.transpose(offset_map).astype(np.float32) + self.offset_map_gpu.set(offset_map) + + def correct(self, flags): + self.correction_kernel( + self.full_grid, + self.full_block, + ( + self.input_data_gpu, + self.cell_table_gpu, + np.uint8(flags), + self.offset_map_gpu, + self.processed_data_gpu, + ), + ) + + def _init_kernels(self): + kernel_source = self._kernel_template.render( + { + "pixels_x": self.pixels_x, + "pixels_y": self.pixels_y, + "data_memory_cells": self.memory_cells, + "constant_memory_cells": self.constant_memory_cells, + "input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype), + "output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype), + "corr_enum": utils.enum_to_c_template(CorrectionFlags), + } + ) + self.source_module = cupy.RawModule(code=kernel_source) + self.correction_kernel = self.source_module.get_function("correct") + + +class DsscCalcatFriend(calcat_utils.BaseCalcatFriend): + _constant_enum_class = DsscConstants + + def __init__(self, device, *args, **kwargs): + super().__init__(device, *args, **kwargs) + self._constants_need_conditions = { + DsscConstants.Offset: self.dark_condition, + } + + @staticmethod + def add_schema( + schema, + managed_keys, + param_prefix="constantParameters", + status_prefix="foundConstants", + ): + super(DsscCalcatFriend, DsscCalcatFriend).add_schema( + schema, managed_keys, "DSSC-Type", param_prefix, status_prefix + ) + ( + OVERWRITE_ELEMENT(schema) + .key(f"{param_prefix}.memoryCells") + .setNewDefaultValue(400) + .commit(), + + OVERWRITE_ELEMENT(schema) + .key(f"{param_prefix}.biasVoltage") + .setNewDefaultValue(100) # TODO: proper + .commit() + ) + ( + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.pulseIdChecksum") + .assignmentOptional() + .defaultValue(2.8866323107820637e-36) + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.acquisitionRate") + .assignmentOptional() + .defaultValue(4.5) + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.encodedGain") + .assignmentOptional() + .defaultValue(67328) + .commit(), + ) + + calcat_utils.add_status_schema_from_enum(schema, status_prefix, DsscConstants) + + def dark_condition(self): + res = calcat_utils.OperatingConditions() + res["Memory cells"] = self._get_param("memoryCells") + res["Sensor Bias Voltage"] = self._get_param("biasVoltage") + res["Pixels X"] = self._get_param("pixelsX") + res["Pixels Y"] = self._get_param("pixelsY") + # res["Pulse id checksum"] = self._get_param("pulseIdChecksum") + # res["Acquisition rate"] = self._get_param("acquisitionRate") + # res["Encoded gain"] = self._get_param("encodedGain") + return res + + +@KARABO_CLASSINFO("DsscCorrection", deviceVersion) +class DsscCorrection(BaseCorrection): + # subclass *must* set these attributes + _correction_flag_class = CorrectionFlags + _correction_field_names = (("offset", CorrectionFlags.OFFSET),) + _kernel_runner_class = DsscGpuRunner + _calcat_friend_class = DsscCalcatFriend + _constant_enum_class = DsscConstants + _managed_keys = BaseCorrection._managed_keys.copy() + + @staticmethod + def expectedParameters(expected): + ( + OVERWRITE_ELEMENT(expected) + .key("dataFormat.memoryCells") + .setNewDefaultValue(400) + .commit(), + + OVERWRITE_ELEMENT(expected) + .key("preview.selectionMode") + .setNewDefaultValue("pulse") + .commit(), + ) + DsscCalcatFriend.add_schema(expected, DsscCorrection._managed_keys) + add_correction_step_schema( + expected, + DsscCorrection._managed_keys, + 
DsscCorrection._correction_field_names, + ) + ( + VECTOR_STRING_ELEMENT(expected) + .key("managedKeys") + .assignmentOptional() + .defaultValue(list(DsscCorrection._managed_keys)) + .commit() + ) + + @property + def input_data_shape(self): + return ( + self.get("dataFormat.memoryCells"), + 1, + self.get("dataFormat.pixelsY"), + self.get("dataFormat.pixelsX"), + ) + + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ): + pulse_table = np.squeeze(data_hash.get("image.pulseId")) + if self._frame_filter is not None: + try: + cell_table = cell_table[self._frame_filter] + pulse_table = pulse_table[self._frame_filter] + image_data = image_data[self._frame_filter] + except IndexError: + self.log_status_warn( + "Failed to apply frame filter, please check that it is valid!" + ) + return + + try: + self.kernel_runner.load_data(image_data) + except ValueError as e: + self.log_status_warn(f"Failed to load data: {e}") + return + except Exception as e: + self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") + + buffer_handle, buffer_array = self._shmem_buffer.next_slot() + self.kernel_runner.load_cell_table(cell_table) + self.kernel_runner.correct(self._correction_flag_enabled) + self.kernel_runner.reshape( + output_order=self.unsafe_get("dataFormat.outputAxisOrder"), + out=buffer_array, + ) + if do_generate_preview: + if self._correction_flag_enabled != self._correction_flag_preview: + self.kernel_runner.correct(self._correction_flag_preview) + ( + preview_slice_index, + preview_cell, + preview_pulse, + ) = utils.pick_frame_index( + self.unsafe_get("preview.selectionMode"), + self.unsafe_get("preview.index"), + cell_table, + pulse_table, + warn_func=self.log_status_warn, + ) + preview_raw, preview_corrected = self.kernel_runner.compute_previews( + preview_slice_index, + ) + + data_hash.set("image.data", buffer_handle) + data_hash.set("image.cellId", cell_table[:, np.newaxis]) + data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) + data_hash.set("calngShmemPaths", ["image.data"]) + self._write_output(data_hash, metadata) + if do_generate_preview: + self._write_combiner_previews( + ( + ("preview.outputRaw", preview_raw), + ("preview.outputCorrected", preview_corrected), + ), + train_id, + source, + ) + + def _load_constant_to_runner(self, constant, constant_data): + assert constant is DsscConstants.Offset + self.kernel_runner.load_offset_map(constant_data) + if not self.get("corrections.offset.available"): + self.set("corrections.offset.available", True) + + self._update_correction_flags() + self.log_status_info(f"Done loading {constant.name} to GPU") diff --git a/src/calng/ModuleStacker.py b/src/calng/ModuleStacker.py new file mode 100644 index 0000000000000000000000000000000000000000..95abe95ee79f05007fffe98bb4c73b1fad6c6f08 --- /dev/null +++ b/src/calng/ModuleStacker.py @@ -0,0 +1,142 @@ +import numpy as np +from karabo.bound import ( + FLOAT_ELEMENT, + KARABO_CLASSINFO, + NODE_ELEMENT, + STRING_ELEMENT, + ChannelMetaData, + Epochstamp, + Hash, + MetricPrefix, + Schema, + Timestamp, + Trainstamp, + Unit, +) +from TrainMatcher import TrainMatcher + +from ._version import version as deviceVersion + + +@KARABO_CLASSINFO("ModuleStacker", deviceVersion) +class ModuleStacker(TrainMatcher.TrainMatcher): + """This will be deprecated: now just equivalent to ModuleMatcher except this + stacks `image.data` instead of `data.adc` + + """ + + def __init__(self, conf): + super().__init__(conf) + 
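+        # the info hash is maintained by the TrainMatcher base class; merge in
+        # our own timeOfFlight field so _send can update it per train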
+        self.info.merge(Hash("timeOfFlight", 0))
+
+    @staticmethod
+    def expectedParameters(expected):
+        (
+            FLOAT_ELEMENT(expected)
+            .key("timeOfFlight")
+            .displayedName("Time of flight")
+            .description(
+                "Time elapsed from when the DAQ sent the data until the train was "
+                "matched and ready to send from here. Measured for the latest train "
+                "matched. Maximum over all sources included in said train."
+            )
+            .unit(Unit.SECOND)
+            .metricPrefix(MetricPrefix.MILLI)
+            .readOnly()
+            .commit(),
+
+            STRING_ELEMENT(expected)
+            .key("pathToStack")
+            .displayedName("Data path to stack")
+            .description(
+                "Typically, image.data will be used for full data going through the "
+                "pipeline. Set this when input is dataOutput from a correction device "
+                "and output goes to a bridge. For previews, data.adc is used as part "
+                "of the combiner format. Set to data.adc when input is a preview "
+                "output and output goes to a femDataAssembler."
+            )
+            .options("image.data,data.adc")
+            .assignmentOptional()
+            .defaultValue("image.data")
+            .commit(),
+        )
+
+    def initialization(self):
+        """Apply configuration and automatically start.
+
+        Upon instantiation, the device will automatically start matching data
+        from the specified data sources.
+        """
+        super().initialization()
+
+        # Disable the start and stop slots.
+        # It's not possible to pop items from the full schema. The slots are
+        # therefore converted to unused nodes.
+        desc = (
+            "Disable slots from the parent class. The acquisition starts "
+            "automatically on instantiation. Check that the `fastSources` table is "
+            "populated and its booleans set to True in the project configuration "
+            "(not the instantiated device). These nodes are not used."
+        )
+        schema = Schema()
+        (
+            NODE_ELEMENT(schema).key("start").description(desc).commit(),
+
+            NODE_ELEMENT(schema).key("stop").description(desc).commit(),
+        )
+        self.path_to_stack = self.get("pathToStack")
+        self.updateSchema(schema)
+
+        super().start()
+
+    def _send(self, tid, sources):
+        # Add control data
+        timestamp = Timestamp(Epochstamp(), Trainstamp(tid))
+
+        # Reuse an arbitrary hash from the existing ones (among the ones to stack)
+        try:
+            out_hash = next(
+                data
+                for (data, _) in iter(sources.values())
+                if data.has(self.path_to_stack)
+            )
+        except StopIteration:
+            out_hash = Hash()
+
+        # TODO: handle missing modules properly (track all sources)
+        stacked_data = []
+        stacked_sources = []
+        stacked_present = []
+        # TODO: should this be threaded?
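+        # stacked_data becomes an ndarray of shape (modules, ...) below, while
+        # stacked_sources / stacked_present stay as parallel per-module lists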
+        time_of_flight = 0
+        for source, (data, metadata) in sources.items():
+            if not data.has(self.path_to_stack):
+                # may add sources not for stacking
+                # TODO: make stacking or not part of the source configuration
+                out_hash[f"unstacked.{source}"] = data
+                continue
+            old_ts = metadata.getTimestamp()
+            elapsed = timestamp.toTimestamp() - old_ts.toTimestamp()
+            time_of_flight = max(time_of_flight, elapsed)
+            image_data = data.get(self.path_to_stack)
+            stacked_data.append(image_data)
+            stacked_sources.append(source)
+            stacked_present.append(True)
+
+        for source, data in self.ctrlmon.get(tid):
+            out_hash[f"unstacked.{source}"] = data
+
+        if stacked_data:
+            if not isinstance(stacked_data[0], str):
+                # strings (like shmem handles) should stay list
+                stacked_data = np.stack(stacked_data, axis=0)
+        # TODO: merge with super().update_info (throttled updates)
+        self.info["timeOfFlight"] = time_of_flight * 1000
+
+        out_hash[self.path_to_stack] = stacked_data
+        out_hash["sources"] = stacked_sources
+        out_hash["modulesPresent"] = stacked_present
+        channel = self.signalSlotable.getOutputChannel("output")
+        channel.write(out_hash, ChannelMetaData(self.getInstanceId(), timestamp))
+        channel.update()
+        self.rate_out.update()
diff --git a/src/calng/ShmemToZMQ.py b/src/calng/ShmemToZMQ.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0af8e14f5b3b71f1fa29db231e95370ac714a0
--- /dev/null
+++ b/src/calng/ShmemToZMQ.py
@@ -0,0 +1,75 @@
+from time import time
+
+from karabo.bound import KARABO_CLASSINFO
+from PipeToZeroMQ import PipeToZeroMQ, conversion, device_schema
+
+from . import shmem_utils
+from ._version import version as deviceVersion
+
+
+@KARABO_CLASSINFO("ShmemToZMQ", deviceVersion)
+class ShmemToZMQ(PipeToZeroMQ.PipeToZeroMQ):
+    def initialization(self):
+        super().initialization()
+        self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
+
+    def onInput(self, input_channel):
+        actual = self.getActualTimestamp()
+        input_tic = time()
+        self.info["inputUpdated"] += 1
+        self.info["dataRecv"] = input_channel.size()
+        all_meta = input_channel.getMetaData()
+
+        data = {}
+        meta = {}
+
+        for idx in range(input_channel.size()):
+            # Read metadata
+            metadata = self._extract_metadata(all_meta, idx)
+            source = metadata["source"]
+            if source not in self.sources:
+                self.appendSchema(device_schema.timestamp_schema(source))
+                self.sources.add(source)
+            self._time_info(metadata, actual, self.info)
+
+            # Read data
+            hash_data = input_channel.read(idx)
+
+            # filters
+            if self.allowed_sources and source not in self.allowed_sources:
+                continue
+            forward, ignore = self._filter_properties(hash_data, source)
+
+            meta[source] = conversion.meta_to_dict(metadata)
+            meta[source]["ignored_keys"] = ignore
+            # only this bit differs from PipeToZeroMQ.onInput
+            dic, arr = conversion.hash_to_dict(
+                hash_data, paths=forward, version=self.version
+            )
+            for shmem_handle_path in dic.pop("calngShmemPaths", []):
+                shmem_handle = dic.pop(shmem_handle_path, None)
+                if shmem_handle is None:
+                    self.log.INFO(
+                        f"Hash from {source} did not have {shmem_handle_path}"
+                    )
+                    continue
+                elif shmem_handle == "":
+                    self.log.INFO(
+                        f"Hash from {source} had empty {shmem_handle_path}"
+                    )
+                    continue
+                actual_data = self._shmem_handler.get(shmem_handle)
+                arr[shmem_handle_path] = actual_data
+
+            data[source] = (dic, arr)
+
+        # forward data to all connected ZMQ sockets
+        self._send(data, meta)
+
+        output_tic = time()
+        self.info["onInputTotal"] = 1000 * (output_tic - input_tic)
+
+        # update properties on device
self._updateProperties(output_tic) + # block if device is in passive state + self.monitoring.wait() diff --git a/src/calng/__init__.py b/src/calng/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py new file mode 100644 index 0000000000000000000000000000000000000000..fae7950d7e836f78aec2047331d00b7a41f76f79 --- /dev/null +++ b/src/calng/base_correction.py @@ -0,0 +1,1079 @@ +import enum +import pathlib +import threading +from timeit import default_timer + +import dateutil.parser +import numpy as np +from karabo.bound import ( + BOOL_ELEMENT, + DOUBLE_ELEMENT, + INPUT_CHANNEL, + INT32_ELEMENT, + INT64_ELEMENT, + KARABO_CLASSINFO, + NDARRAY_ELEMENT, + NODE_ELEMENT, + OUTPUT_CHANNEL, + OVERWRITE_ELEMENT, + SLOT_ELEMENT, + STRING_ELEMENT, + UINT32_ELEMENT, + UINT64_ELEMENT, + VECTOR_STRING_ELEMENT, + VECTOR_UINT32_ELEMENT, + ChannelMetaData, + Epochstamp, + Hash, + MetricPrefix, + PythonDevice, + Schema, + State, + Timestamp, + Trainstamp, + Unit, +) +from karabo.common.api import KARABO_SCHEMA_DISPLAY_TYPE_SCENES as DT_SCENES +from karabo import version as karaboVersion +from pkg_resources import parse_version + +from . import scenes, shmem_utils, utils +from ._version import version as deviceVersion + +PROCESSING_STATE_TIMEOUT = 10 + + +class FramefilterSpecType(enum.Enum): + NONE = "none" + RANGE = "range" + COMMASEPARATED = "commaseparated" + + +preview_schema = Schema() +( + NODE_ELEMENT(preview_schema).key("image").commit(), + + NDARRAY_ELEMENT(preview_schema).key("image.data").dtype("FLOAT").commit(), + + UINT64_ELEMENT(preview_schema) + .key("image.trainId") + .displayedName("Train ID") + .assignmentOptional() + .defaultValue(0) + .commit(), +) + +# TODO: trim output schema / adapt to specific detectors +# currently: based on snapshot of actual output reusing AGIPD hash +output_schema = Schema() +( + NODE_ELEMENT(output_schema).key("image").commit(), + + STRING_ELEMENT(output_schema) + .key("image.data") + .assignmentOptional() + .defaultValue("") + .commit(), + + NDARRAY_ELEMENT(output_schema).key("image.length").dtype("UINT32").commit(), + + NDARRAY_ELEMENT(output_schema).key("image.cellId").dtype("UINT16").commit(), + + NDARRAY_ELEMENT(output_schema).key("image.pulseId").dtype("UINT64").commit(), + + NDARRAY_ELEMENT(output_schema).key("image.status").commit(), + + NDARRAY_ELEMENT(output_schema).key("image.trainId").dtype("UINT64").commit(), + + VECTOR_STRING_ELEMENT(output_schema) + .key("calngShmemPaths") + .assignmentOptional() + .defaultValue(["image.data"]) + .commit(), + + NODE_ELEMENT(output_schema).key("metadata").commit(), + + STRING_ELEMENT(output_schema) + .key("metadata.source") + .assignmentOptional() + .defaultValue("") + .commit(), + + NODE_ELEMENT(output_schema).key("metadata.timestamp").commit(), + + INT32_ELEMENT(output_schema) + .key("metadata.timestamp.tid") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NODE_ELEMENT(output_schema).key("header").commit(), + + INT32_ELEMENT(output_schema) + .key("header.minorTrainFormatVersion") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(output_schema) + .key("header.majorTrainFormatVersion") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(output_schema) + .key("header.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT64_ELEMENT(output_schema) + .key("header.linkId") + .assignmentOptional() + 
.defaultValue(0) + .commit(), + + INT64_ELEMENT(output_schema) + .key("header.dataId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT64_ELEMENT(output_schema) + .key("header.pulseCount") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NDARRAY_ELEMENT(output_schema).key("header.reserved").commit(), + + NDARRAY_ELEMENT(output_schema).key("header.magicNumberBegin").commit(), + + NODE_ELEMENT(output_schema).key("detector").commit(), + + INT32_ELEMENT(output_schema) + .key("detector.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + + NDARRAY_ELEMENT(output_schema).key("detector.data").commit(), + + NODE_ELEMENT(output_schema).key("trailer").commit(), + + NDARRAY_ELEMENT(output_schema).key("trailer.checksum").commit(), + + NDARRAY_ELEMENT(output_schema).key("trailer.magicNumberEnd").commit(), + + INT32_ELEMENT(output_schema) + .key("trailer.status") + .assignmentOptional() + .defaultValue(0) + .commit(), + + INT32_ELEMENT(output_schema) + .key("trailer.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), +) + + +@KARABO_CLASSINFO("BaseCorrection", deviceVersion) +class BaseCorrection(PythonDevice): + _correction_flag_class = None # subclass must set (ex.: dssc_gpu.CorrectionFlags) + _kernel_runner_class = None # subclass must set (ex.: dssc_gpu.DsscGpuRunner) + _kernel_runner_init_args = {} # optional extra args for runner + _managed_keys = { + "outputShmemBufferSize", + "dataFormat.outputAxisOrder", + "dataFormat.outputImageDtype", + "dataFormat.overrideInputAxisOrder", + "frameFilter.type", + "frameFilter.spec", + "preview.enable", + "preview.index", + "preview.selectionMode", + "preview.trainIdModulo", + "loadMostRecentConstants", + } # subclass can extend this, /must/ put it in schema as managedKeys + _image_data_path = "image.data" + _cell_table_path = "image.cellId" + + def _load_constant_to_runner(constant_name, constant_data): + """Subclass must define how to process constants into correction maps and store + into appropriate buffers in (GPU or main) memory.""" + raise NotImplementedError() + + @property + def input_data_shape(self): + """Subclass must define expected input data shape in terms of dataFormat.{ + memoryCells,pixelsX,pixelsY} and any other axes.""" + raise NotImplementedError() + + @property + def output_data_shape(self): + """Shape of corrected image data sent on dataOutput. Depends on data format + parameters pixels x / y, and number of cells (optionally after frame filter).""" + axis_lengths = { + "x": self.unsafe_get("dataFormat.pixelsX"), + "y": self.unsafe_get("dataFormat.pixelsY"), + "c": self.unsafe_get("dataFormat.filteredFrames"), + } + return tuple( + axis_lengths[axis] + for axis in self.unsafe_get("dataFormat.outputAxisOrder") + ) + + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ): + """Subclass must define data processing (presumably using the kernel runner). 
+ Will be called by input_handler, which will take care of some common checks and + extracting the parameters given to process_data.""" + raise NotImplementedError() + + @staticmethod + def expectedParameters(expected): + ( + OVERWRITE_ELEMENT(expected) + .key("state") + .setNewDefaultValue(State.INIT) + .commit(), + + INPUT_CHANNEL(expected).key("dataInput").commit(), + + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(output_schema) + .commit(), + + VECTOR_STRING_ELEMENT(expected) + .key("fastSources") + .displayedName("Fast data sources") + .description( + "Sources to get data from. Only incoming hashes from these sources " + "will be processed. This will typically be a single entry of the form: " + "'[instrument]_DET_[detector]/DET/[channel]:xtdf'." + ) + .assignmentOptional() + .defaultValue([]) + .commit(), + + NODE_ELEMENT(expected) + .key("frameFilter") + .displayedName("Frame filter") + .description( + "The frame filter - if set - slices the input data. Frames not in the " + "filter will be discarded before any processing happens and will not " + "get to dataOutput or preview. Note that this filter goes by frame " + "index rather than cell ID or pulse ID; set accordingly. Handle with " + "care - an invalid filter can prevent all processing. How the filter " + "is specified depends on frameFilter.type. See frameFilter.current to " + "inspect the currently set frame filter array (if any)." + ) + .commit(), + + STRING_ELEMENT(expected) + .key("frameFilter.type") + .displayedName("Filter definition type") + .description( + "Controls how frameFilter.spec is used. The default value of 'none' " + "means that no filter is set (regardless of frameFilter.spec). " + "'arange' allows between one and three integers separated by ',' which " + "are parsed and passed directly to numpy.arange. 'commaseparated' " + "reads a list of integers separated by commas." + ) + .options(",".join(spectype.value for spectype in FramefilterSpecType)) + .assignmentOptional() + .defaultValue("none") + .reconfigurable() + .commit(), + + STRING_ELEMENT(expected) + .key("frameFilter.spec") + .assignmentOptional() + .defaultValue("") + .reconfigurable() + .commit(), + + VECTOR_UINT32_ELEMENT(expected) + .key("frameFilter.current") + .displayedName("Current filter") + .description( + "This read-only value is used to display the contents of the current " + "frame filter. An empty array means no filtering is done." + ) + .readOnly() + .initialValue([]) + .commit(), + + UINT32_ELEMENT(expected) + .key("outputShmemBufferSize") + .displayedName("Output buffer size limit") + .unit(Unit.BYTE) + .metricPrefix(MetricPrefix.GIGA) + .description( + "Corrected trains are written to shared memory locations. These are " + "pre-allocated and re-used (circular buffer). This parameter " + "determines how much memory to set aside for that buffer." + ) + .assignmentOptional() + .defaultValue(10) + .commit(), + + VECTOR_STRING_ELEMENT(expected) + .key("availableScenes") + .setSpecialDisplayType(DT_SCENES) + .readOnly() + .initialValue(["overview"]) + .commit(), + ) + + ( + NODE_ELEMENT(expected) + .key("dataFormat") + .displayedName("Data format (in/out)") + .commit(), + + BOOL_ELEMENT(expected) + .key("dataFormat.overrideInputAxisOrder") + .displayedName("Override input axis order") + .description( + "The shape of the image data ndarray as received from the " + "DataAggregator is sometimes wrong - the axes are actually in a " + "different order than the ndarray shape suggests. 
If this flag is on, " + "the shape of the ndarray will be overridden with the axis order which " + "was expected." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + STRING_ELEMENT(expected) + .key("dataFormat.inputImageDtype") + .displayedName("Input image data dtype") + .description("The (numpy) dtype to expect for incoming image data.") + .options("uint16,float32") + .assignmentOptional() + .defaultValue("uint16") + .commit(), + + STRING_ELEMENT(expected) + .key("dataFormat.outputImageDtype") + .displayedName("Output image data dtype") + .description( + "The (numpy) dtype to use for outgoing image data. Input is cast to " + "float32, corrections are applied, and only then will the result be " + "cast to outputImageDtype. Be aware that casting to integer type " + "causes truncation rather than rounding." + ) + # TODO: consider adding rounding / binning for integer output + .options("float16,float32,uint16") + .assignmentOptional() + .defaultValue("float32") + .commit(), + + # important: determines shape of data as going into correction + UINT32_ELEMENT(expected) + .key("dataFormat.pixelsX") + .displayedName("Pixels x") + .description("Number of pixels of image data along X axis") + .assignmentOptional() + .defaultValue(512) + .commit(), + + UINT32_ELEMENT(expected) + .key("dataFormat.pixelsY") + .displayedName("Pixels y") + .description("Number of pixels of image data along Y axis") + .assignmentOptional() + .defaultValue(128) + .commit(), + + UINT32_ELEMENT(expected) + .key("dataFormat.memoryCells") + .displayedName("Memory cells") + .description("Full number of memory cells in incoming data") + .assignmentOptional() + .defaultValue(1) # subclass will want to set a default value + .commit(), + + UINT32_ELEMENT(expected) + .key("dataFormat.filteredFrames") + .displayedName("Frames after filter") + .description("Number of frames left after applying frame filter") + .readOnly() + .initialValue(0) + .commit(), + + STRING_ELEMENT(expected) + .key("dataFormat.outputAxisOrder") + .displayedName("Output axis order") + .description( + "Axes of main data output can be reordered after correction. Axis " + "order is specified as string consisting of 'x', 'y', and 'c', with " + "the latter indicating the memory cell axis. The default value of " + "'cxy' puts pixels on the fast axes." + ) + .options("cxy,cyx,xcy,xyc,ycx,yxc") + .assignmentOptional() + .defaultValue("cxy") + .commit(), + + VECTOR_UINT32_ELEMENT(expected) + .key("dataFormat.inputDataShape") + .displayedName("Input data shape") + .description( + "Image data shape in incoming data (from reader / DAQ). This value is " + "computed from pixelsX, pixelsY, and memoryCells - this field just " + "shows what is currently expected." + ) + .readOnly() + .initialValue([]) + .commit(), + + VECTOR_UINT32_ELEMENT(expected) + .key("dataFormat.outputDataShape") + .displayedName("Output data shape") + .description( + "Image data shape for data output from this device. This value is " + "computed from pixelsX, pixelsY, and the size of the frame filter - " + "this field just shows what is currently expected." 
+ ) + .readOnly() + .initialValue([]) + .commit(), + ) + + ( + SLOT_ELEMENT(expected) + .key("loadMostRecentConstants") + .displayedName("Load most recent constants") + .description( + "Calling this slot will flush all constant buffers and cause the " + "device to start querying CalCat for the most recent constants - all " + "constants applicable for this device - available with the currently " + "set constant parameters. This is typically called after " + "instantiating pipeline, after changing parameters, or after " + "generating new constants." + ) + .commit() + ) + + ( + NODE_ELEMENT(expected).key("preview").displayedName("Preview").commit(), + + OUTPUT_CHANNEL(expected) + .key("preview.outputRaw") + .dataSchema(preview_schema) + .commit(), + + OUTPUT_CHANNEL(expected) + .key("preview.outputCorrected") + .dataSchema(preview_schema) + .commit(), + + BOOL_ELEMENT(expected) + .key("preview.enable") + .displayedName("Enable preview") + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + INT32_ELEMENT(expected) + .key("preview.index") + .displayedName("Index (or stat) for preview") + .description( + "If this value is ≥ 0, the corresponding index (frame, cell, or pulse) " + "will be sliced for the preview output. If this value is < 0, preview " + "will be one of the following stats: -1: max, -2: mean, -3: sum, -4: " + "stdev. These stats are computed across memory cells." + ) + .assignmentOptional() + .defaultValue(0) + .minInc(-4) + .reconfigurable() + .commit(), + + STRING_ELEMENT(expected) + .key("preview.selectionMode") + .displayedName("Index selection mode") + .description( + "The value of preview.index can be used in multiple ways, controlled " + "by this value. If this is set to 'frame', preview.index is sliced " + "directly from data. If 'cell' (or 'pulse') is selected, I will look " + "at cell (or pulse) table for the requested cell (or pulse ID). " + "Special (stat) index values <0 are not affected by this." + ) + .options("frame,cell,pulse") + .assignmentOptional() + .defaultValue("frame") + .reconfigurable() + .commit(), + + UINT32_ELEMENT(expected) + .key("preview.trainIdModulo") + .displayedName("Preview train stride") + .description( + "Preview will only be generated for trains whose ID modulo this " + "number is zero. Higher values means less frequent preview updates. " + "Keep in mind that the GUI has limited refresh rate. Extra care must " + "be taken if DAQ train stride is >1." + ) + .assignmentOptional() + .defaultValue(6) + .reconfigurable() + .commit(), + ) + + # just measurements and counters to display + ( + UINT64_ELEMENT(expected) + .key("trainId") + .displayedName("Train ID") + .description("ID of latest train processed by this device.") + .readOnly() + .initialValue(0) + .commit(), + + NODE_ELEMENT(expected) + .key("performance") + .displayedName("Performance measures") + .commit(), + + DOUBLE_ELEMENT(expected) + .key("performance.processingTime") + .displayedName("Processing time") + .unit(Unit.SECOND) + .metricPrefix(MetricPrefix.MILLI) + .readOnly() + .initialValue(0) + .warnHigh(100) + .info("Processing too slow to reach 10 Hz") + .needsAcknowledging(False) + .commit(), + + DOUBLE_ELEMENT(expected) + .key("performance.rate") + .displayedName("Rate") + .description( + "Actual rate with which this device gets, processes, and sends trains. " + "This is a simple windowed moving average." 
+ ) + .unit(Unit.HERTZ) + .readOnly() + .initialValue(0) + .commit(), + ) + + # this node will be filled out by subclass + ( + NODE_ELEMENT(expected) + .key("corrections") + .displayedName("Correction steps") + .commit(), + ) + + def __init__(self, config): + super().__init__(config) + + self.input_data_dtype = np.dtype(config["dataFormat.inputImageDtype"]) + self.output_data_dtype = np.dtype(config["dataFormat.outputImageDtype"]) + + self.sources = set(config.get("fastSources")) + + self.kernel_runner = None # must call _update_buffers to initialize + self._shmem_buffer = None # ditto + + self._correction_flag_enabled = self._correction_flag_class.NONE + self._correction_flag_preview = self._correction_flag_class.NONE + self._buffer_lock = threading.Lock() + self._last_processing_started = 0 # used for processing time and timeout + + # register slots + if parse_version(karaboVersion) >= parse_version("2.11"): + # TODO: the CalCatFriend could add these for us + # note: overly complicated for closure to work + def make_wrapper_capturing_constant(constant): + def aux(): + self.calcat_friend.get_specific_constant_version_and_call_me_back( + constant, self._load_constant_to_runner + ) + + return aux + + for constant in self._constant_enum_class: + slot_name = f"foundConstants.{constant.name}.overrideConstantVersion" + meth_name = slot_name.replace(".", "_") + self.KARABO_SLOT( + make_wrapper_capturing_constant(constant), + slotName=meth_name, + ) + + self.KARABO_SLOT(self.loadMostRecentConstants) + self.KARABO_SLOT(self.requestScene) + + self.registerInitialFunction(self._initialization) + + def _initialization(self): + self.calcat_friend = self._calcat_friend_class( + self, pathlib.Path.cwd() / "calibration-client-secrets.json" + ) + self._update_frame_filter() + + self._buffered_status_update = Hash( + "trainId", + 0, + "performance.rate", + 0, + "performance.processingTime", + 0, + ) + self._processing_time_ema = utils.ExponentialMovingAverage(alpha=0.3) + self._rate_tracker = utils.WindowRateTracker() + self._rate_update_timer = utils.RepeatingTimer( + interval=1, + callback=self._update_rate_and_state, + ) + + self.KARABO_ON_INPUT("dataInput", self.input_handler) + self.KARABO_ON_EOS("dataInput", self.handle_eos) + + self.updateState(State.ON) + + def __del__(self): + del self._shmem_buffer + super().__del__() + + def preReconfigure(self, config): + for ts_path in ( + "constantParameters.deviceMappingSnapshotAt", + "constantParameters.constantVersionEventAt", + ): + if config.has(ts_path): + ts_string = config.get(ts_path) + if ts_string.strip() == "": + config.set(ts_path, "") + else: + try: + timestamp = dateutil.parser.isoparse(ts_string) + except ValueError as error: + self.log_status_warn(f"Failed to parse {ts_path}; {error}") + config.erase(ts_path) + else: + config.set(ts_path, timestamp.isoformat()) + if config.has("constantParameters.deviceMappingSnapshotAt"): + self.calcat_friend.flush_pdu_mapping() + self._prereconfigure_update_hash = config + + def postReconfigure(self): + if not hasattr(self, "_prereconfigure_update_hash"): + self.log_status_warn("postReconfigure without knowing update hash") + return + + update = self._prereconfigure_update_hash + + if update.has("frameFilter"): + with self._buffer_lock: + self._update_frame_filter() + if any( + update.has(shape_param) + for shape_param in ( + "dataFormat.pixelsX", + "dataFormat.pixelsY", + "dataFormat.memoryCells", + "constantParameters.memoryCells", + "frameFilter", + ) + ): + with self._buffer_lock: + 
self._update_buffers() + # TODO: only call this if they are changed (is cheap, though) + self._update_correction_flags() + + def loadMostRecentConstants(self): + self.flush_constants() + self.calcat_friend.flush_constants() + for constant in self._constant_enum_class: + self.calcat_friend.get_constant_version_and_call_me_back( + constant, self._load_constant_to_runner + ) + + def flush_constants(self): + """Reset constant buffers and disable corresponding correction steps""" + for correction_step, _ in self._correction_field_names: + self.set(f"corrections.{correction_step}.available", False) + self.kernel_runner.flush_buffers() + self._update_correction_flags() + + def log_status_info(self, msg): + self.log.INFO(msg) + self.set("status", msg) + + def log_status_warn(self, msg): + self.log.WARN(msg) + self.set("status", msg) + + def log_status_error(self, msg): + self.set("status", msg) + self.log.ERROR(msg) + self.updateState(State.ERROR) + + def requestScene(self, params): + payload = Hash() + name = params.get("name", default="") + payload["name"] = name + payload["success"] = True + if name == "overview": + payload["data"] = scenes.correction_device_overview_scene( + device_id=self.getInstanceId(), + schema=self.getFullSchema(), + ) + elif name.startswith("browse_schema"): + if ":" in name: + prefix = name[len("browse_schema:") :] + else: + prefix = "managed" + payload["data"] = scenes.recursive_subschema_scene( + self.getInstanceId(), + self.getFullSchema(), + prefix, + ) + else: + payload["success"] = False + response = Hash() + response["type"] = "deviceScene" + response["origin"] = self.getInstanceId() + response["payload"] = payload + self.reply(response) + + def _write_output(self, data, old_metadata): + """For dataOutput: reusing incoming data hash and setting source and timestamp + to be same as input""" + metadata = ChannelMetaData( + old_metadata.get("source"), + Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")), + ) + + channel = self.signalSlotable.getOutputChannel("dataOutput") + channel.write(data, metadata, False) + channel.update() + + def _write_combiner_previews(self, channel_data_pairs, train_id, source): + # TODO: send as ImageData (requires updated assembler) + # TODO: allow sending *all* frames for commissioning (request: Jola) + preview_hash = Hash() + preview_hash.set("image.trainId", train_id) + + # note: have to construct because setting .tid after init is broken + timestamp = Timestamp(Epochstamp(), Trainstamp(train_id)) + metadata = ChannelMetaData(source, timestamp) + for channel_name, data in channel_data_pairs: + preview_hash.set("image.data", data) + channel = self.signalSlotable.getOutputChannel(channel_name) + channel.write(preview_hash, metadata, False) + channel.update() + + def _update_correction_flags(self): + """Based on constants loaded and settings, update bit mask flags for kernel""" + available = self._correction_flag_class.NONE + enabled = self._correction_flag_class.NONE + preview = self._correction_flag_class.NONE + for field_name, flag in self._correction_field_names: + if self.get(f"corrections.{field_name}.available"): + available |= flag + if self.get(f"corrections.{field_name}.enable"): + enabled |= flag + if self.get(f"corrections.{field_name}.preview"): + preview |= flag + enabled &= available + preview &= available + self._correction_flag_enabled = enabled + self._correction_flag_preview = preview + self.log.DEBUG(f"Corrections for dataOutput: {str(enabled)}") + self.log.DEBUG(f"Corrections for preview: 
{str(preview)}") + + def _update_frame_filter(self, update_buffers=True): + """Parse frameFilter string (if set) and update cached filter array. May update + dataFormat.filteredFrames - will therefore by default call _update_buffers + afterwards.""" + # TODO: add some validation to preReconfigure + self.log.DEBUG("Updating frame filter") + filter_type = FramefilterSpecType(self.get("frameFilter.type")) + filter_string = self.get("frameFilter.spec") + + if filter_type is FramefilterSpecType.NONE or filter_string.strip() == "": + self._frame_filter = None + elif filter_type is FramefilterSpecType.RANGE: + try: + numbers = tuple(int(part) for part in filter_string.split(",")) + except (ValueError, TypeError): + self.log_status_warn( + f"Invalid frame filter specification: {filter_string}" + ) + else: + self._frame_filter = np.arange(*numbers, dtype=np.uint16) + elif filter_type is FramefilterSpecType.COMMASEPARATED: + try: + self._frame_filter = np.fromstring( + filter_string, sep=",", dtype=np.uint16 + ) + except ValueError: + # note: only in the future will numpy actually give ValueError + self.log_status_warn( + f"Invalid frame filter specification: {filter_string}" + ) + else: + self.log_status_warn(f"Invalid frame filter type '{filter_type}'") + + if self._frame_filter is None: + self.set("dataFormat.filteredFrames", self.get("dataFormat.memoryCells")) + self.set("frameFilter.current", []) + else: + self.set("dataFormat.filteredFrames", self._frame_filter.size) + self.set("frameFilter.current", list(map(int, self._frame_filter))) + + if self._frame_filter is not None and ( + self._frame_filter.min() < 0 + or self._frame_filter.max() >= self.get("dataFormat.memoryCells") + ): + self.log_status_warn("Invalid frame filter set, expect exceptions!") + + if update_buffers: + self._update_buffers() + + def _update_buffers(self): + """(Re)initialize buffers / kernel runner according to expected data shapes""" + self.log.INFO("Updating buffers according to data shapes") + # reflect the axis reordering in the expected output shape + self.set("dataFormat.inputDataShape", list(self.input_data_shape)) + self.set("dataFormat.outputDataShape", list(self.output_data_shape)) + self.log.INFO(f"Input shape: {self.input_data_shape}") + self.log.INFO(f"Output shape: {self.output_data_shape}") + + if self._shmem_buffer is None: + shmem_buffer_name = self.getInstanceId() + ":dataOutput" + memory_budget = self.get("outputShmemBufferSize") * 2 ** 30 + self.log.INFO(f"Opening new shmem buffer: {shmem_buffer_name}") + self._shmem_buffer = shmem_utils.ShmemCircularBuffer( + memory_budget, + self.output_data_shape, + self.output_data_dtype, + shmem_buffer_name, + ) + self.log.INFO("Trying to pin the shmem buffer memory") + self._shmem_buffer.cuda_pin() + self.log.INFO("Done, shmem buffer is ready") + else: + self._shmem_buffer.change_shape(self.output_data_shape) + + self.kernel_runner = self._kernel_runner_class( + self.get("dataFormat.pixelsX"), + self.get("dataFormat.pixelsY"), + self.get("dataFormat.filteredFrames"), + int(self.get("constantParameters.memoryCells")), + input_data_dtype=self.input_data_dtype, + output_data_dtype=self.output_data_dtype, + **self._kernel_runner_init_args, + ) + + # TODO: gracefully handle change in constantParameters.memoryCells + with self.calcat_friend.cached_constants_lock: + for ( + constant, + data, + ) in self.calcat_friend.cached_constants.items(): + self.log_status_info(f"Reload constant {constant}") + self._load_constant_to_runner(constant, data) + + def 
input_handler(self, input_channel): + """Main handler for data input: Do a few simple checks to determine whether to + even try processing. If yes, will pass data and information to process_data + method provided by subclass.""" + + # Is device even ready for this? + state = State[self.unsafe_get("state")] + if state is State.ERROR: + # in this case, we should have already issued warning + return + elif self.kernel_runner is None: + self.log_status_warn("Received data, but have not initialized kernels yet") + return + + all_metadata = input_channel.getMetaData() + for input_index in range(input_channel.size()): + self._last_processing_started = default_timer() + data_hash = input_channel.read(input_index) + metadata = all_metadata[input_index] + source = metadata.get("source") + + if source not in self.sources: + self.log_status_info(f"Ignoring hash with unknown source {source}") + return + elif not data_hash.has("image"): + self.log_status_info("Ignoring hash without image node") + return + + train_id = metadata.getAttribute("timestamp", "tid") + cell_table = data_hash.get(self._cell_table_path) + if ( + (isinstance(cell_table, np.ndarray) and cell_table.size == 0) + or len(cell_table) == 0 + ): + self.log_status_warn( + "Empty cell table, DAQ probably not sending data." + ) + return + cell_table = np.squeeze(cell_table) + + # no more common reasons to skip input, so go to processing + if state is State.ON: + self.updateState(State.PROCESSING) + self.log_status_info("Processing data") + + correction_cell_num = self.unsafe_get("constantParameters.memoryCells") + cell_table_max = np.max(cell_table) + + image_data = data_hash.get(self._image_data_path) + if cell_table.size != self.unsafe_get("dataFormat.memoryCells"): + self.log_status_info( + f"Updating new input shape {image_data.shape}, updating buffers" + ) + self.set("dataFormat.memoryCells", cell_table.size) + with self._buffer_lock: + self._update_frame_filter() + + # DataAggregator typically tells us the wrong axis order + if self.unsafe_get("dataFormat.overrideInputAxisOrder"): + expected_shape = self.input_data_shape + if expected_shape != image_data.shape: + image_data.shape = expected_shape + + do_generate_preview = ( + train_id % self.unsafe_get("preview.trainIdModulo") == 0 + and self.unsafe_get("preview.enable") + ) + + with self._buffer_lock: + self.process_data( + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ) + self._buffered_status_update.set("trainId", train_id) + self._processing_time_ema.update( + default_timer() - self._last_processing_started + ) + self._rate_tracker.update() + + def _update_rate_and_state(self): + self._buffered_status_update.set("performance.rate", self._rate_tracker.get()) + self._buffered_status_update.set( + "performance.processingTime", self._processing_time_ema.get() * 1000 + ) + # trainId in _buffered_status_update should be updated in input handler + + self.set(self._buffered_status_update) + + if default_timer() - self._last_processing_started > PROCESSING_STATE_TIMEOUT: + if self.get("state") is State.PROCESSING: + self.updateState(State.ON) + self.log_status_info( + f"No new train in {PROCESSING_STATE_TIMEOUT} s, switching state." 
+ ) + + def handle_eos(self, channel): + self.updateState(State.ON) + self.signalEndOfStream("dataOutput") + + +# forward-compatible unsafe_get proposed by @haufs +if not hasattr(BaseCorrection, "unsafe_get"): + def unsafe_get(self, key): + """Look up key in device schema quickly, but without consistency locks + + This is only relevant for use in hot path (input handler). Circumvents the + locking done by PythonDevice.get. Note that PythonDevice.get does handle some + special types (by looking at full schema for type information). In particular, + device state enum: `self.get("state")` will return a State whereas + `self.unsafe_get("state")` will return a string. Handle with care!""" + + # at least until Karabo 2.14, self._parameters is maintained by PythonDevice + return self._parameters.get(key) + + setattr(BaseCorrection, "unsafe_get", unsafe_get) + + +def add_correction_step_schema(schema, managed_keys, field_flag_mapping): + """Using the fields in the provided mapping, will add nodes to schema + + field_flag_mapping is assumed to be iterable of pairs where first entry in each + pair is the name of a correction step as it will appear in device schema (second + entry - typically an enum field - is ignored). For correction step, a node and some + booleans are added to the schema and the toggleable booleans are added to + managed_keys. Subclass can customize / add additional keys under node later. + + This method should be called in expectedParameters of subclass after the same for + BaseCorrection has been called. Would be nice to include in BaseCorrection instead, + but that is tricky: static method of superclass will need _correction_field_names + of subclass or device server gets mad. A nice solution with classmethods would be + welcome. + """ + + for field_name, _ in field_flag_mapping: + node_name = f"corrections.{field_name}" + ( + NODE_ELEMENT(schema).key(node_name).commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_name}.available") + .displayedName("Available") + .description( + "This boolean indicates whether the necessary constants have been " + "loaded for this correction step to be applied. Enabling the " + "correction will have no effect unless this is True." + ) + .readOnly() + .initialValue(False) + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_name}.enable") + .displayedName("Enable") + .description( + "Controls whether to apply this correction step for main data " + "output - subject to availability." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + + BOOL_ELEMENT(schema) + .key(f"{node_name}.preview") + .displayedName("Preview") + .description( + "Whether to apply this correction step for corrected preview " + "output - subject to availability." + ) + .assignmentOptional() + .defaultValue(True) + .reconfigurable() + .commit(), + ) + managed_keys.add(f"{node_name}.enable") + managed_keys.add(f"{node_name}.preview") diff --git a/src/calng/base_gpu.py b/src/calng/base_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..c0619d0ae253b7cb736dbc2fe7275092c7493a3c --- /dev/null +++ b/src/calng/base_gpu.py @@ -0,0 +1,195 @@ +import pathlib + +import cupy +import jinja2 +import numpy as np + +from . import utils + + +class BaseGpuRunner: + """Class to handle GPU buffers and execution of CUDA kernels on image data + + All GPU buffers are kept within this class and it is intentionally very stateful. + This generally means that you will want to load data into it and then do something. 
+ Typical usage in correct order: + + 1. instantiate + 2. load constants + 3. load_data + 4. load_cell_table + 5. correct + 6a. reshape (only here does data transfer back to host) + 6b. compute_preview (optional) + + repeat from 2. or 3. + + In case no constants are available / correction is not desired, can skip 3 and 4 and + pass CorrectionFlags.NONE to correct(...). Generally, user must handle which + correction steps are appropriate given the constants loaded so far. + """ + + # These must be set by subclass + _kernel_source_filename = None + _corrected_axis_order = None + + def __init__( + self, + pixels_x, + pixels_y, + memory_cells, + constant_memory_cells, + input_data_dtype=np.uint16, + output_data_dtype=np.float32, + ): + _src_dir = pathlib.Path(__file__).absolute().parent + # subclass must define _kernel_source_filename + with (_src_dir / "kernels" / self._kernel_source_filename).open("r") as fd: + self._kernel_template = jinja2.Template(fd.read()) + + self.pixels_x = pixels_x + self.pixels_y = pixels_y + self.memory_cells = memory_cells + if constant_memory_cells == 0: + # if not set, guess same as input; may save one recompilation + self.constant_memory_cells = memory_cells + else: + self.constant_memory_cells = constant_memory_cells + # preview will only be single memory cell + self.preview_shape = (self.pixels_x, self.pixels_y) + self.input_data_dtype = input_data_dtype + self.output_data_dtype = output_data_dtype + + self._init_kernels() + + # reuse buffers for input / output + self.cell_table_gpu = cupy.empty(self.memory_cells, dtype=np.uint16) + self.input_data_gpu = cupy.empty(self.input_shape, dtype=input_data_dtype) + self.processed_data_gpu = cupy.empty( + self.processed_shape, dtype=output_data_dtype + ) + self.reshaped_data_gpu = None # currently not reusing buffer + + # default preview layers: raw and corrected (subclass can extend) + self.preview_buffer_getters = [ + self._get_raw_for_preview, + self._get_corrected_for_preview, + ] + + # to get data from respective buffers to cell, x, y shape for preview computation + def _get_raw_for_preview(self): + """Should return view of self.input_data_gpu with shape (cell, x/y, x/y)""" + raise NotImplementedError() + + def _get_corrected_for_preview(self): + """Should return view of self.processed_data_gpu with shape (cell, x/y, x/y)""" + raise NotImplementedError() + + def flush_buffers(self): + """Optional reset GPU buffers (implement in subclasses which need this)""" + pass + + def correct(self, flags): + """Correct (already loaded) image data according to flags + + Subclass must define this method. It should assume that image data, cell table, + and other data (including constants) has already been loaded. It should + probably run some GPU kernel and output should go into self.processed_data_gpu. + + Keep in mind that user only gets output from compute_preview or reshape + (either of these should come after correct). + + The submodules providing subclasses should have some IntFlag enums defining + which flags are available to pass along to the kernel. A zero flag should allow + the kernel to do no actual correction - but still copy the data between buffers + and cast it to desired output type. + + """ + raise NotImplementedError() + + def reshape(self, output_order, out=None): + """Move axes to desired output order and copy to host memory + + The out parameter is passed directly to the get function of GPU array: if + None, then a new ndarray (in host memory) is returned. 
If not None, then data + will be loaded into the provided array, which must match shape / dtype. + """ + # TODO: avoid copy + if output_order == self._corrected_axis_order: + self.reshaped_data_gpu = self.processed_data_gpu + else: + self.reshaped_data_gpu = cupy.transpose( + self.processed_data_gpu, + utils.transpose_order(self._corrected_axis_order, output_order), + ) + + return self.reshaped_data_gpu.get(out=out) + + def load_data(self, raw_data): + self.input_data_gpu.set(np.squeeze(raw_data)) + + def load_cell_table(self, cell_table): + self.cell_table_gpu.set(cell_table) + + def compute_previews(self, preview_index): + """Generate single slice or reduction preview of raw and corrected data + + Special values of preview_index are -1 for max, -2 for mean, -3 for sum, and + -4 for stdev (across cells). + + Note that preview_index is taken from data without checking cell table. + Caller has to figure out which index along memory cell dimension they + actually want to preview in case it needs to be a specific pulse. + + Will reuse data from corrected output buffer. Therefore, correct(...) must have + been called with the appropriate flags before compute_preview(...). + """ + + if preview_index < -4: + raise ValueError(f"No statistic with code {preview_index} defined") + elif preview_index >= self.memory_cells: + raise ValueError(f"Memory cell index {preview_index} out of range") + + # TODO: enum around reduction type + return tuple( + self._compute_a_preview(image_data=getter(), preview_index=preview_index) + for getter in self.preview_buffer_getters + ) + + def _compute_a_preview(self, image_data, preview_index): + """image_data must have cells on first axis; X and Y order is not important + here for now (and can differ between AGIPD and DSSC)""" + if preview_index >= 0: + # TODO: reuse pinned buffers for this + return image_data[preview_index].astype(np.float32).get() + elif preview_index == -1: + # TODO: confirm that max is pixel and not integrated intensity + # separate from next case because dtype not applicable here + return cupy.nanmax(image_data, axis=0).astype(cupy.float32).get() + elif preview_index in (-2, -3, -4): + stat_fun = { + -2: cupy.nanmean, + -3: cupy.nansum, + -4: cupy.nanstd, + }[preview_index] + return stat_fun(image_data, axis=0, dtype=cupy.float32).get() + + def update_block_size(self, full_block): + """Set execution grid such that it covers processed_shape with full_blocks + + Execution is scheduled with 3d "blocks" of CUDA threads. Tuning can affect + performance. Correction kernels are "monolithic" for simplicity (i.e. each + logical thread handles one entry in output data), so in each dimension we + parallelize, grid * block >= length to cover all entries. + + Note that individual kernels must themselves check whether they go out of + bounds; grid dimensions get rounded up in case ndarray size is not multiple of + block size. 
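+ For example, with processed_shape (352, 512, 128) and full_block (1, 1, 64), full_grid
+ becomes (ceil_div(352, 1), ceil_div(512, 1), ceil_div(128, 64)) = (352, 512, 2).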
+ + """ + assert len(full_block) == 3 + self.full_block = tuple(full_block) + self.full_grid = tuple( + utils.ceil_div(a_length, block_length) + for (a_length, block_length) in zip(self.processed_shape, full_block) + ) diff --git a/src/calng/calcat_utils.py b/src/calng/calcat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76b27adf0d18e3129daf5e331bd5492caef59ce1 --- /dev/null +++ b/src/calng/calcat_utils.py @@ -0,0 +1,558 @@ +import copy +import functools +import json +import pathlib +import threading + +import calibration_client +import h5py +import numpy as np +from calibration_client.modules import ( + Calibration, + CalibrationConstant, + CalibrationConstantVersion, + Detector, + DetectorType, + PhysicalDetectorUnit, +) +from karabo.bound import ( + BOOL_ELEMENT, + DOUBLE_ELEMENT, + NODE_ELEMENT, + SLOT_ELEMENT, + STRING_ELEMENT, + UINT32_ELEMENT, + VECTOR_UINT32_ELEMENT, +) +from karabo import version as karaboVersion +from pkg_resources import parse_version + +from . import utils + + +class ConditionNotFound(Exception): + pass + + +class DetectorNotFound(Exception): + pass + + +class ModuleNotFound(Exception): + pass + + +class CalibrationNotFound(Exception): + pass + + +class CalibrationClientConfigError(Exception): + pass + + +def add_status_schema_from_enum(schema, prefix, enum_class): + for constant in enum_class: + constant_node = f"{prefix}.{constant.name}" + ( + NODE_ELEMENT(schema).key(constant_node).commit(), + + BOOL_ELEMENT(schema) + .key(f"{constant_node}.found") + .readOnly() + .initialValue(False) + .commit(), + + STRING_ELEMENT(schema) + .key(f"{constant_node}.validFrom") + .readOnly() + .initialValue("") + .commit(), + + STRING_ELEMENT(schema) + .key(f"{constant_node}.calibrationId") + .readOnly() + .initialValue("") + .commit(), + + VECTOR_UINT32_ELEMENT(schema) + .key(f"{constant_node}.conditionIds") + .readOnly() + .initialValue([]) + .commit(), + + VECTOR_UINT32_ELEMENT(schema) + .key(f"{constant_node}.constantIds") + .readOnly() + .initialValue([]) + .commit(), + + STRING_ELEMENT(schema) + .key(f"{constant_node}.constantVersionId") + .description( + "This field is editable - if for any reason a specific constant " + "version is desired, the constant version ID (as used in CalCat) can " + "be set here and the slot below can be called to load this particular " + "version, overriding the automatic loading of latest constants." + ) + .assignmentOptional() + .defaultValue("") + .reconfigurable() + .commit(), + ) + if parse_version(karaboVersion) >= parse_version("2.11"): + ( + SLOT_ELEMENT(schema) + .key(f"{constant_node}.overrideConstantVersion") + .displayedName("Override constant version") + .commit(), + ) + + +class OperatingConditions(dict): + # TODO: support deviation? + def encode(self): + return { + "parameters": [ + { + "parameter_name": key, + "lower_deviation_value": 0.0, + "upper_deviation_value": 0.0, + "flg_available": False, + "value": value, + } + for (key, value) in self.items() + ] + } + + def __hash__(self): + # this takes me back to pre-screening interview time... + return hash(tuple(sorted(self.items()))) + + +class BaseCalcatFriend: + """Base class for CalCat friends - handles interacting with CalCat for the device + + A CalCat friend uses the device schema to build up parameters for CalCat queries. + It focuses on two nodes (added by static method add_schema): param_prefix and + status_prefix. 
The former is primarily used to get parameters which are (via + condition methods - see for example dark_condition of DsscCalcatFriend) used + to look for constants. The latter is primarily used to give user information + about what was found. + """ + + _constant_enum_class = None # subclass should set + _constants_need_conditions = None # subclass should set + + @staticmethod + def add_schema( + schema, + managed_keys, + detector_type, + param_prefix="constantParameters", + status_prefix="foundConstants", + ): + """Add elements needed by this object to device's schema (expectedSchema) + + All elements added to schema go under prefixes which should end with name of + node which does not exist yet. To change default values and add more fields, + extend this method in subclass. + + The param_prefix node will hold all the parameters needed to build constant + condition dicts for querying CalCat. These values are set either directly on + the device or via manager and this class gets them from the device using helper + function _get_param. See for example AgipdCalcatFriend.dark_condition. + + The status_prefix node is used to report information about what was found in + CalCat. This class will update the values on the device using the helper + function _set_status. This should not need to happen in subclass methods. + """ + + ( + NODE_ELEMENT(schema) + .key(param_prefix) + .displayedName("Constant retrieval parameters") + .commit(), + + NODE_ELEMENT(schema) + .key(status_prefix) + .displayedName("Constants retrieved") + .commit(), + ) + + # Parameters which any detector would probably have (extend this in subclass) + # TODO: probably switch to floating point for everything, including mem cells + ( + STRING_ELEMENT(schema) + .key(f"{param_prefix}.deviceMappingSnapshotAt") + .displayedName("Snapshot timestamp (for device mapping)") + .description( + "CalCat supports querying with a specific snapshot of the database. " + "When playing back a run from the file system, this feature is useful " + "to look up the device mapping at the time of the run. If this field " + "is left empty, the latest device mapping is used. Date format should " + "be 'YYYY-MM-DD' with optional time of day starting with 'T' followed " + "by 'hh:mm:ss.mil+02:00'." 
+ ) + .assignmentOptional() + .defaultValue("") + .reconfigurable() + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.constantVersionEventAt") + .displayedName("Event at timestamp (for constant version)") + .description("TODO") + .assignmentOptional() + .defaultValue("") + .reconfigurable() + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.detectorType") + .displayedName("Detector type name") + .description( + "Name of detector type in CalCat; typically has suffix '-Type'" + ) + .readOnly() + .initialValue(detector_type) + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.detectorTypeId") + .readOnly() + .initialValue("") + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.detectorName") + .assignmentOptional() + .defaultValue("") + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.detectorId") + .readOnly() + .initialValue("") + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.karaboDa") + .assignmentOptional() + .defaultValue("") + .commit(), + + STRING_ELEMENT(schema) + .key(f"{param_prefix}.moduleId") + .readOnly() + .initialValue("") + .commit(), + + UINT32_ELEMENT(schema) + .key(f"{param_prefix}.memoryCells") + .displayedName("Memory cells") + .description( + "Number of memory cells / frames per train. Relevant for burst mode." + ) + .assignmentOptional() + .defaultValue(1) + .reconfigurable() + .commit(), + + UINT32_ELEMENT(schema) + .key(f"{param_prefix}.pixelsX") + .displayedName("Pixels X") + .assignmentOptional() + .defaultValue(512) + .commit(), + + UINT32_ELEMENT(schema) + .key(f"{param_prefix}.pixelsY") + .displayedName("Pixels Y") + .assignmentOptional() + .defaultValue(128) + .commit(), + + DOUBLE_ELEMENT(schema) + .key(f"{param_prefix}.biasVoltage") + .displayedName("Bias voltage") + .description("Sensor bias voltage") + .assignmentOptional() + .defaultValue(300) + .reconfigurable() + .commit(), + ) + managed_keys.add(f"{param_prefix}.deviceMappingSnapshotAt") + managed_keys.add(f"{param_prefix}.constantVersionEventAt") + managed_keys.add(f"{param_prefix}.memoryCells") + managed_keys.add(f"{param_prefix}.pixelsX") + managed_keys.add(f"{param_prefix}.pixelsY") + managed_keys.add(f"{param_prefix}.biasVoltage") + + def __init__( + self, + device, + secrets_fn: pathlib.Path, + param_prefix="constantParameters", + status_prefix="foundConstants", + ): + self.device = device + self.param_prefix = param_prefix + self.status_prefix = status_prefix + self.cached_constants = {} + self.cached_constants_lock = threading.Lock() + # api lock used to force queries to be sequential (SSL issue on ONC) + self.api_lock = threading.Lock() + + if not secrets_fn.is_file(): + self.device.log_status_warn( + f"Missing CalCat secrets file (expected {secrets_fn})" + ) + with secrets_fn.open("r") as fd: + calcat_secrets = json.load(fd) + + self.caldb_store = pathlib.Path(calcat_secrets["caldb_store_path"]) + if not self.caldb_store.is_dir(): + raise ValueError(f"caldb_store location '{self.caldb_store}' is not dir") + + self.device.log.INFO(f"Connecting to CalCat at {calcat_secrets['base_url']}") + base_url = calcat_secrets["base_url"] + self.client = calibration_client.CalibrationClient( + client_id=calcat_secrets["client_id"], + client_secret=calcat_secrets["client_secret"], + user_email=calcat_secrets["user_email"], + base_api_url=f"{base_url}/api/", + token_url=f"{base_url}/oauth/token", + refresh_url=f"{base_url}/oauth/token", + auth_url=f"{base_url}/oauth/authorize", + scope="public", + session_token=None, + ) + 
self.device.log_status_info("CalCat connection established") + + def _get_param(self, key): + """Helper to get value from attached device schema""" + return self.device.get(f"{self.param_prefix}.{key}") + + def _set_param(self, key, value): + self.device.set(f"{self.param_prefix}.{key}", value) + + def _get_status(self, constant, key): + return self.device.get(f"{self.status_prefix}.{constant.name}.{key}") + + def _set_status(self, constant, key, value): + """Helper to update information about found constants on device""" + self.device.set(f"{self.status_prefix}.{constant.name}.{key}", value) + + # Python 3.6 does not have functools.cached_property or even functools.cache + @property + @functools.lru_cache() + def detector_id(self): + detector_name = self._get_param("detectorName") + resp = Detector.get_by_identifier(self.client, detector_name) + self._check_resp(resp, DetectorNotFound, f"Detector {detector_name} not found") + res = resp["data"]["id"] + self._set_param("detectorId", str(res)) + return res + + @property + @functools.lru_cache() + def detector_type_id(self): + detector_type = self._get_param("detectorType") + resp = DetectorType.get_by_name(self.client, detector_type) + self._check_resp( + resp, DetectorNotFound, f"Detector type {detector_type} not found" + ) + res = resp["data"]["id"] + self._set_param("detectorTypeId", str(res)) + return res + + @property + @functools.lru_cache() + def pdus(self): + resp = PhysicalDetectorUnit.get_all_by_detector( + self.client, self.detector_id, self._get_param("deviceMappingSnapshotAt") + ) + self._check_resp(resp, warning="Failed to retrieve module mapping") + for irrelevant_key in ("detector", "detector_type", "flg_available"): + for pdu in resp["data"]: + del pdu[irrelevant_key] + return resp["data"] + + @property + @functools.lru_cache() + def _karabo_da_to_float_uuid(self): + return {pdu["karabo_da"]: pdu["float_uuid"] for pdu in self.pdus} + + @property + @functools.lru_cache() + def _karabo_da_to_id(self): + return {pdu["karabo_da"]: pdu["id"] for pdu in self.pdus} + + def flush_pdu_mapping(self): + for attr in ("pdus", "_karabo_da_to_float_uuid", "_karabo_da_to_id"): + if hasattr(self, attr): + delattr(self, attr) + self._set_param("moduleId", "") + + @utils.threadsafe_cache + def calibration_id(self, calibration_name: str): + resp = Calibration.get_by_name(self.client, calibration_name) + self._check_resp( + resp, CalibrationNotFound, f"Calibration type {calibration_name} not found!" 
+ ) + return resp["data"]["id"] + + @utils.threadsafe_cache + def condition_ids(self, pdu, condition): + # modifying condition parameter messes with cache + condition_with_detector = copy.copy(condition) + condition_with_detector["Detector UUID"] = pdu + self.device.log.DEBUG(f"Look for condition: {condition_with_detector}") + resp = self.client.search_possible_conditions_from_dict( + "", condition_with_detector.encode() + ) + self._check_resp( + resp, + ConditionNotFound, + f"Failed to find condition {condition} for pdu {pdu}", + ) + return [d["id"] for d in resp["data"]] + + def constant_ids(self, calibration_id, condition_ids): + resp = CalibrationConstant.get_all_by_conditions( + self.client, + calibration_id=calibration_id, + detector_type_id=self.detector_type_id, + condition_ids=condition_ids, + ) + self._check_resp(resp, warning="Failed to retrieve constant ID") + return [d["id"] for d in resp["data"]] + + def get_constant_version(self, constant): + # TODO: catch exceptions, give warnings appropriately + karabo_da = self._get_param("karaboDa") + self.device.log_status_info(f"Attempting to find {constant} for {karabo_da}") + + if karabo_da not in self._karabo_da_to_float_uuid: + self.device.log_status_warn( + f"Module {karabo_da} not found in mapping, check configuration!" + ) + raise ModuleNotFound(f"Module map did not include {karabo_da}") + self._set_param("moduleId", str(self._karabo_da_to_id[karabo_da])) + + if isinstance(constant, str): + constant = self._constant_enum_class[constant] + + calibration_id = self.calibration_id(constant.name) + self._set_status(constant, "calibrationId", calibration_id) + + condition = self._constants_need_conditions[constant]() + condition_ids = self.condition_ids( + self._karabo_da_to_float_uuid[karabo_da], condition + ) + self._set_status(constant, "conditionIds", condition_ids) + + constant_ids = self.constant_ids( + calibration_id=calibration_id, condition_ids=condition_ids + ) + self._set_status(constant, "constantIds", constant_ids) + + resp = CalibrationConstantVersion.get_closest_by_time( + self.client, + calibration_constant_ids=constant_ids, + physical_detector_unit_id=self._karabo_da_to_id[karabo_da], + event_at=self._get_param("constantVersionEventAt"), + snapshot_at=None, + ) + self._check_resp(resp, warning="Failed to find calibration constant version") + # TODO: replace with start date and end date + timestamp = resp["data"]["begin_validity_at"] + self._set_status(constant, "validFrom", timestamp) + self._set_status(constant, "constantVersionId", resp["data"]["id"]) + + file_path = ( + self.caldb_store / resp["data"]["path_to_file"] / resp["data"]["file_name"] + ) + # TODO: handle FileNotFoundError if we are led astray + with h5py.File(file_path, "r") as fd: + constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) + with self.cached_constants_lock: + self.cached_constants[constant] = constant_data + self._set_status(constant, "found", True) + self.device.log_status_info(f"Done finding {constant} for {karabo_da}") + + return constant_data + + def get_specific_constant_version(self, constant): + # TODO: warn if PDU or constant type does not match + # TODO: warn if result is list (happens for empty version ID) + constant_version_id = self.device.get( + f"{self.status_prefix}.{constant.name}.constantVersionId" + ) + + resp = CalibrationConstantVersion.get_by_id(self.client, constant_version_id) + self._check_resp(resp, warning="Failed to find calibration constant version") + file_path = ( + self.caldb_store / 
resp["data"]["path_to_file"] / resp["data"]["file_name"] + ) + with h5py.File(file_path, "r") as fd: + constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) + with self.cached_constants_lock: + self.cached_constants[constant] = constant_data + self._set_status(constant, "validFrom", resp["data"]["begin_at"]) + self._set_status(constant, "calibrationId", "manual override") + self._set_status(constant, "conditionId", "manual override") + self._set_status(constant, "constantId", "manual override") + self._set_status(constant, "constantVersionId", constant_version_id) + self._set_status(constant, "found", True) + return constant_data + + def get_constant_version_and_call_me_back(self, constant, callback): + """Runs get_constant_version in thread, will call callback on completion""" + # TODO: do we want to use asyncio / "modern" async? + # TODO: consider moving out of this class, closer to correction device + def aux(): + with self.api_lock: + data = self.get_constant_version(constant) + callback(constant, data) + + thread = threading.Thread(target=aux) + thread.start() + return thread + + def get_specific_constant_version_and_call_me_back(self, constant, callback): + """Blindly load whatever CalCat points to for CCV - user must be confident that + this CCV corresponds to correct kind of constant.""" + + # TODO: warn user about all the things that go wrong + def aux(): + with self.api_lock: + data = self.get_specific_constant_version(constant) + callback(constant, data) + + thread = threading.Thread(target=aux) + thread.start() + return thread + + def flush_constants(self): + for constant in self._constant_enum_class: + self._set_status(constant, "validFrom", "") + self._set_status(constant, "found", False) + + def _check_resp(self, resp, exception=Exception, warning=None): + # TODO: probably verify using "info" that exception is the right one + to_raise = None + if not resp["success"]: + # TODO: probably more types of app_info errors? 
+ if resp["app_info"]: + if "not found" in resp["info"]: + # this was likely the exception exception + to_raise = exception(resp["info"]) + else: + # but could also be authorization or similar issue + to_raise = CalibrationClientConfigError(resp["app_info"]) + to_raise = exception(resp["info"]) + if to_raise is not None: + if warning is not None: + self.device.log_status_warn(warning) + raise to_raise diff --git a/src/calng/kernels/agipd_gpu.cu b/src/calng/kernels/agipd_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..20b08d43d26b5f37d25c8633adb60ed3179db648 --- /dev/null +++ b/src/calng/kernels/agipd_gpu.cu @@ -0,0 +1,148 @@ +#include <cuda_fp16.h> + +{{corr_enum}} + +extern "C" { + /* + Perform corrections; see agipd_gpu.CorrectionFlags + Note that THRESHOLD and OFFSET should for any later corrections to make sense + Will take cell_table into account when getting correction values + Will convert from input dtype to float for correction + Will convert to output dtype for output + */ + __global__ void correct(const {{input_data_dtype}}* data, + const unsigned short* cell_table, + const unsigned char corr_flags, + // default_gain can be 0, 1, or 2, and is relevant for fixed gain mode (no THRESHOLD) + const unsigned char default_gain, + const float* threshold_map, + const float* offset_map, + const float* rel_gain_pc_map, + const float* md_additional_offset, + const float* rel_gain_xray_map, + const float g_gain_value, + const unsigned int* bad_pixel_map, + const float bad_pixel_mask_value, + float* gain_map, // TODO: more compact yet plottable representation + {{output_data_dtype}}* output) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t input_cells = {{data_memory_cells}}; + const size_t map_cells = {{constant_memory_cells}}; + + const size_t cell = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = blockIdx.y * blockDim.y + threadIdx.y; + const size_t y = blockIdx.z * blockDim.z + threadIdx.z; + + if (cell >= input_cells || y >= Y || x >= X) { + return; + } + + // data shape: memory cell, data/raw_gain (dim size 2), x, y + const size_t data_stride_y = 1; + const size_t data_stride_x = Y * data_stride_y; + const size_t data_stride_raw_gain = X * data_stride_x; + const size_t data_stride_cell = 2 * data_stride_raw_gain; + const size_t data_index = cell * data_stride_cell + + 0 * data_stride_raw_gain + + y * data_stride_y + + x * data_stride_x; + const size_t raw_gain_index = cell * data_stride_cell + + 1 * data_stride_raw_gain + + y * data_stride_y + + x * data_stride_x; + float corrected = (float)data[data_index]; + const float raw_gain_val = (float)data[raw_gain_index]; + + const size_t output_stride_y = 1; + const size_t output_stride_x = output_stride_y * Y; + const size_t output_stride_cell = output_stride_x * X; + const size_t output_index = cell * output_stride_cell + x * output_stride_x + y * output_stride_y; + + // per-pixel only constant: cell, x, y + const size_t map_stride_y = 1; + const size_t map_stride_x = Y * map_stride_y; + const size_t map_stride_cell = X * map_stride_x; + + // threshold constant shape: cell, x, y, threshold (dim size 2) + const size_t threshold_map_stride_threshold = 1; + const size_t threshold_map_stride_y = 2 * threshold_map_stride_threshold; + const size_t threshold_map_stride_x = Y * threshold_map_stride_y; + const size_t threshold_map_stride_cell = X * threshold_map_stride_x; + + // gain mapped constant shape: cell, x, y, gain_level (dim size 3) + const size_t gm_map_stride_gain = 
1; + const size_t gm_map_stride_y = 3 * gm_map_stride_gain; + const size_t gm_map_stride_x = Y * gm_map_stride_y; + const size_t gm_map_stride_cell = X * gm_map_stride_x; + // note: assuming all maps have same shape (in terms of cells / x / y) + + const size_t map_cell = cell_table[cell]; + + if (map_cell < map_cells) { + unsigned char gain = default_gain; + if (corr_flags & THRESHOLD) { + const float threshold_0 = threshold_map[0 * threshold_map_stride_threshold + + map_cell * threshold_map_stride_cell + + y * threshold_map_stride_y + + x * threshold_map_stride_x]; + const float threshold_1 = threshold_map[1 * threshold_map_stride_threshold + + map_cell * threshold_map_stride_cell + + y * threshold_map_stride_y + + x * threshold_map_stride_x]; + // could consider making this const using ternaries / tiny function + if (raw_gain_val <= threshold_0) { + gain = 0; + } else if (raw_gain_val <= threshold_1) { + gain = 1; + } else { + gain = 2; + } + } + gain_map[output_index] = (float)gain; + + const size_t map_index = map_cell * map_stride_cell + + y * map_stride_y + + x * map_stride_x; + + const size_t gm_map_index = gain * gm_map_stride_gain + + map_cell * gm_map_stride_cell + + y * gm_map_stride_y + + x * gm_map_stride_x; + + if ((corr_flags & BPMASK) && bad_pixel_map[gm_map_index]) { + corrected = bad_pixel_mask_value; + gain_map[output_index] = bad_pixel_mask_value; + } else { + if (corr_flags & OFFSET) { + corrected -= offset_map[gm_map_index]; + // TODO: optionally reassign gain stage for this pixel based on new value + } + // TODO: baseline shift + if (corr_flags & REL_GAIN_PC) { + corrected *= rel_gain_pc_map[gm_map_index]; + if (gain == 1) { + corrected += md_additional_offset[map_index]; + } + } + if (corr_flags & GAIN_XRAY) { + corrected = (corrected / rel_gain_xray_map[map_index]) * g_gain_value; + } + } + {% if output_data_dtype == "half" %} + output[output_index] = __float2half(corrected); + {% else %} + output[output_index] = ({{output_data_dtype}})corrected; + {% endif %} + } else { + // TODO: decide what to do when we cannot threshold + {% if output_data_dtype == "half" %} + output[data_index] = __float2half(corrected); + {% else %} + output[data_index] = ({{output_data_dtype}})corrected; + {% endif %} + + gain_map[data_index] = 255; + } + } +} diff --git a/src/calng/kernels/dssc_gpu.cu b/src/calng/kernels/dssc_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..a35eed986a4483e0b84ca84e35d2cc3d56d11cb5 --- /dev/null +++ b/src/calng/kernels/dssc_gpu.cu @@ -0,0 +1,62 @@ +#include <cuda_fp16.h> + +{{corr_enum}} + +extern "C" { + /* + Perform corrections: NONE or OFFSET + Take cell_table into account when getting correction values + Converting to float while correcting + Converting to output dtype at the end + Shape of input data: memory cell, 1, y, x + Shape of offset constant: x, y, memory cell + */ + __global__ void correct(const {{input_data_dtype}}* data, // shape: memory cell, 1, y, x + const unsigned short* cell_table, + const unsigned char corr_flags, + const float* offset_map, + {{output_data_dtype}}* output) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t memory_cells = {{data_memory_cells}}; + const size_t map_memory_cells = {{constant_memory_cells}}; + + const size_t memory_cell = blockIdx.x * blockDim.x + threadIdx.x; + const size_t y = blockIdx.y * blockDim.y + threadIdx.y; + const size_t x = blockIdx.z * blockDim.z + threadIdx.z; + + if (memory_cell >= memory_cells || y >= Y || x >= X) { + return; + } + + 
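// monolithic kernel: each remaining thread handles exactly one (memory_cell, y, x) entry of the output +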
// note: strides differ from numpy strides because unit here is sizeof(...), not byte + const size_t data_stride_x = 1; + const size_t data_stride_y = X * data_stride_x; + const size_t data_stride_cell = Y * data_stride_y; + const size_t data_index = memory_cell * data_stride_cell + y * data_stride_y + x * data_stride_x; + const float raw = (float)data[data_index]; + + const size_t map_stride_x = 1; + const size_t map_stride_y = X * map_stride_x; + const size_t map_stride_cell = Y * map_stride_y; + const size_t map_cell = cell_table[memory_cell]; + if (map_cell < map_memory_cells) { + const size_t map_index = map_cell * map_stride_cell + y * map_stride_y + x * map_stride_x; + float corrected = raw; + if (corr_flags & OFFSET) { + corrected -= offset_map[map_index]; + } + {% if output_data_dtype == "half" %} + output[data_index] = __float2half(corrected); + {% else %} + output[data_index] = ({{output_data_dtype}})corrected; + {% endif %} + } else { + {% if output_data_dtype == "half" %} + output[data_index] = __float2half(raw); + {% else %} + output[data_index] = ({{output_data_dtype}})raw; + {% endif %} + } + } +} diff --git a/src/calng/scenes.py b/src/calng/scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..1373e022f35e6d6d09bf87ece9bfe50f51fe6090 --- /dev/null +++ b/src/calng/scenes.py @@ -0,0 +1,664 @@ +import enum + +import karathon +from karabo.common.scenemodel.api import ( + CheckBoxModel, + ColorBoolModel, + DeviceSceneLinkModel, + DisplayCommandModel, + DisplayLabelModel, + DisplayStateColorModel, + DisplayTextLogModel, + DoubleLineEditModel, + EvaluatorModel, + IntLineEditModel, + LabelModel, + LineEditModel, + RectangleModel, + SceneModel, + SceneTargetWindow, + write_scene, +) + +# section: common setup + +BASE_INC = 25 +NARROW_INC = 20 +PADDING = 5 +RECONFIGURABLE = 4 # TODO: look up proper enum +NODE_TYPE_NODE = 1 + +_type_to_line_editable = { + "BOOL": (CheckBoxModel, {"klass": "EditableCheckBox"}), + "DOUBLE": (DoubleLineEditModel, {}), + "FLOAT": (DoubleLineEditModel, {}), + "INT32": (IntLineEditModel, {}), + "UINT32": (IntLineEditModel, {}), + "INT64": (IntLineEditModel, {}), + "UINT64": (IntLineEditModel, {}), + "STRING": (LineEditModel, {"klass": "EditableLineEdit"}), +} + + +def safe_render(obj, x, y): + if hasattr(obj, "render"): + return obj.render(x, y) + else: + obj.x = x + obj.y = y + return [obj] + + +class Align(enum.Enum): + CENTER = enum.auto() + TOP = enum.auto() + BOTTOM = enum.auto() + LEFT = enum.auto() + RIGHT = enum.auto() + + +# section: nice component decorators + + +def titled(title, width=8 * NARROW_INC): + def actual_decorator(component_class): + class new_class(component_class): + def render(self, x, y, *args, **kwargs): + return [ + LabelModel( + frame_width=1, + text=title, + width=width, + height=NARROW_INC, + x=x, + y=y, + ) + ] + component_class.render(self, x, y + NARROW_INC, *args, **kwargs) + + @property + def width(self): + return max(component_class.width.fget(self), width) + + @property + def height(self): + return component_class.height.fget(self) + NARROW_INC + + return new_class + + return actual_decorator + + +def boxed(component_class): + class new_class(component_class): + def render(self, x, y, *args, **kwargs): + return [ + RectangleModel( + x=x, + y=y, + width=component_class.width.fget(self) + 2 * PADDING, + height=component_class.height.fget(self) + 2 * PADDING, + stroke="#000000", + ) + ] + component_class.render(self, x + PADDING, y + PADDING, *args, **kwargs) + + @property + def width(self): + 
return component_class.width.fget(self) + 2 * PADDING + + @property + def height(self): + return component_class.height.fget(self) + 2 * PADDING + + return new_class + + +# section: useful layout and utility classes + + +class Space: + def __init__(self, width, height): + self.width = width + self.height = height + + def render(self, x, y): + return [] + + +class HorizontalLayout: + def __init__(self, *arg_children, children=None, padding=PADDING): + self.children = list(arg_children) + if children is not None: + self.children.extend(children) + self.padding = padding + + def render(self, x, y, align=Align.TOP): + if align is not Align.TOP: + height = self.height + res = [] + for child in self.children: + if align is Align.TOP: + y_ = y + elif align is Align.CENTER: + y_ = y + (height - child.height) / 2 + elif align is Align.BOTTOM: + y_ = y + (height - child.height) + else: + raise ValueError(f"Invalid align {align} for HorizontalLayout") + res.extend(safe_render(child, x, y_)) + x += child.width + self.padding + return res + + @property + def width(self): + if not self.children: + return 0 + return self.padding * (len(self.children) - 1) + sum( + c.width for c in self.children + ) + + @property + def height(self): + if not self.children: + return 0 + return max(c.height for c in self.children) + + +class VerticalLayout: + def __init__(self, *arg_children, children=None, padding=PADDING): + self.children = list(arg_children) + if children is not None: + self.children.extend(children) + self.padding = padding + + def render(self, x, y): + res = [] + for child in self.children: + res.extend(safe_render(child, x, y)) + y += child.height + self.padding + return res + + @property + def width(self): + if not self.children: + return 0 + return max(c.width for c in self.children) + + @property + def height(self): + if not self.children: + return 0 + return self.padding * (len(self.children) - 1) + sum( + c.height for c in self.children + ) + + +class MaybeEditableRow(HorizontalLayout): + def __init__( + self, + device_id, + schema_hash, + key_path, + label_width=7 * NARROW_INC, + display_width=5 * NARROW_INC, + edit_width=5 * NARROW_INC, + height=NARROW_INC, + ): + super().__init__(padding=0) + key_attr = schema_hash.getAttributes(key_path) + label_text = ( + key_attr["displayedName"] + if "displayedName" in key_attr + else key_path.split(".")[-1] + ) + self.children.extend( + [ + LabelModel( + text=label_text, + width=label_width, + height=height, + ), + DisplayLabelModel( + keys=[f"{device_id}.{key_path}"], + width=display_width, + height=height, + ), + ] + ) + if key_attr["accessMode"] == RECONFIGURABLE: + if "valueType" not in key_attr: + return + value_type = key_attr["valueType"] + if value_type in _type_to_line_editable: + line_editable_class, extra_args = _type_to_line_editable[value_type] + self.children.append( + line_editable_class( + keys=[f"{device_id}.{key_path}"], + width=edit_width, + height=height, + **extra_args, + ) + ) + else: + self.children.append( + LabelModel( + text=f"Not implemented: editing {value_type} ({key_path})", + width=edit_width, + height=height, + ) + ) + + +# section: specific handcrafted components for device classes + + +@titled("Found constants", width=6 * NARROW_INC) +@boxed +class FoundConstantsColumn(VerticalLayout): + def __init__(self, device_id, schema_hash, prefix="foundConstants"): + super().__init__(padding=0) + self.children.extend( + [ + HorizontalLayout( + LabelModel( + text=constant_name, + width=6 * NARROW_INC, + height=NARROW_INC, + ), + 
ColorBoolModel( + width=NARROW_INC, + height=NARROW_INC, + keys=[f"{device_id}.{prefix}.{constant_name}.found"], + ), + DisplayLabelModel( + keys=[f"{device_id}.{prefix}.{constant_name}.validFrom"], + width=8 * BASE_INC, + height=BASE_INC, + ), + padding=0, + ) + for constant_name in schema_hash.get(prefix).getKeys() + ] + ) + + +class ConstantLoadedAmpeln(HorizontalLayout): + def __init__(self, device_id, schema_hash, prefix="foundConstants"): + super().__init__(padding=0) + self.children.extend( + [ + ColorBoolModel( + keys=[f"{device_id}.{prefix}.{key}.found"], + height=BASE_INC, + width=BASE_INC, + ) + for key in schema_hash.get(prefix).getKeys() + ] + ) + + +@titled("Manager status", width=6 * NARROW_INC) +@boxed +class ManagerDeviceStatus(VerticalLayout): + def __init__(self, device_id): + super().__init__(padding=0) + name = DisplayLabelModel( + keys=[f"{device_id}.deviceId"], + width=14 * BASE_INC, + height=BASE_INC, + ) + state = DisplayStateColorModel( + show_string=True, + keys=[f"{device_id}.state"], + width=7 * BASE_INC, + height=BASE_INC, + ) + restart_button = DisplayCommandModel( + keys=[f"{device_id}.restartServers"], + width=7 * BASE_INC, + height=BASE_INC, + ) + instantiate_button = DisplayCommandModel( + keys=[f"{device_id}.startInstantiate"], + width=7 * BASE_INC, + height=BASE_INC, + ) + apply_button = DisplayCommandModel( + keys=[f"{device_id}.applyManagedValues"], + width=7 * BASE_INC, + height=BASE_INC, + ) + status_log = DisplayTextLogModel( + keys=[f"{device_id}.status"], + width=14 * BASE_INC, + height=14 * BASE_INC, + ) + self.children.extend( + [ + name, + HorizontalLayout( + state, + restart_button, + padding=0, + ), + HorizontalLayout( + instantiate_button, + apply_button, + padding=0, + ), + DeviceSceneLinkModel( + text="All managed properties", + keys=[f"{device_id}.availableScenes"], + target="browse_schema", + target_window=SceneTargetWindow.Dialog, + width=7 * BASE_INC, + height=BASE_INC, + ), + status_log, + ] + ) + + +@titled("Device status", width=6 * NARROW_INC) +@boxed +class CorrectionDeviceStatus(VerticalLayout): + def __init__(self, device_id): + super().__init__(padding=0) + name = DisplayLabelModel( + keys=[f"{device_id}.deviceId"], + width=14 * BASE_INC, + height=BASE_INC, + ) + state = DisplayStateColorModel( + show_string=True, + keys=[f"{device_id}.state"], + width=7 * BASE_INC, + height=BASE_INC, + ) + rate = EvaluatorModel( + expression="f'{x:.02f}'", + keys=[f"{device_id}.performance.rate"], + width=7 * BASE_INC, + height=BASE_INC, + ) + processing_time = EvaluatorModel( + expression="f'{x:.02f}'", + keys=[f"{device_id}.performance.processingTime"], + width=7 * BASE_INC, + height=BASE_INC, + ) + tid = DisplayLabelModel( + keys=[f"{device_id}.trainId"], + width=7 * BASE_INC, + height=BASE_INC, + ) + status_log = DisplayTextLogModel( + keys=[f"{device_id}.status"], + width=14 * BASE_INC, + height=14 * BASE_INC, + ) + self.children.extend( + [ + name, + HorizontalLayout( + state, + tid, + padding=0, + ), + HorizontalLayout( + rate, + processing_time, + padding=0, + ), + status_log, + ] + ) + + +class CompactCorrectionDeviceOverview(HorizontalLayout): + def __init__(self, device_id, schema_hash): + super().__init__(padding=0) + self.children.extend( + [ + DeviceSceneLinkModel( + text=device_id.split("/")[-1], + keys=[f"{device_id}.availableScenes"], + target="overview", + target_window=SceneTargetWindow.Dialog, + width=5 * BASE_INC, + height=BASE_INC, + ), + DisplayStateColorModel( + show_string=True, + keys=[f"{device_id}.state"], + 
width=6 * BASE_INC, + height=BASE_INC, + ), + EvaluatorModel( + expression="f'{x:.02f}'", + keys=[f"{device_id}.performance.rate"], + width=4 * BASE_INC, + height=BASE_INC, + ), + DisplayLabelModel( + keys=[f"{device_id}.trainId"], + width=4 * BASE_INC, + height=BASE_INC, + ), + ConstantLoadedAmpeln(device_id, schema_hash), + ] + ) + + +@titled("Other devices managed") +@boxed +class CompactDeviceLinkList(VerticalLayout): + def __init__(self, device_ids): + super().__init__() + self.children.extend( + [ + HorizontalLayout( + DeviceSceneLinkModel( + text=device_id.split("/")[-1], + keys=[f"{device_id}.availableScenes"], + width=7 * BASE_INC, + height=BASE_INC, + ), + DisplayStateColorModel( + show_string=True, + keys=[f"{device_id}.state"], + width=7 * BASE_INC, + height=BASE_INC, + ), + padding=0, + ) + for device_id in device_ids + ] + ) + + +# section: generating actual scenes + + +def schema_to_hash(schema): + if isinstance(schema, karathon.Schema): + return schema.getParameterHash() + else: + return schema.hash + + +def scene_generator(fun): + # TODO: pretty decorator + def aux(*args, **kwargs): + content = fun(*args, **kwargs) + + scene = SceneModel( + children=content.render(PADDING, PADDING), + width=content.width + 2 * PADDING, + height=content.height + 2 * PADDING, + ) + return write_scene(scene) + + return aux + + +@scene_generator +def correction_device_overview_scene(device_id, schema): + schema_hash = schema_to_hash(schema) + + return HorizontalLayout( + CorrectionDeviceStatus(device_id), + VerticalLayout( + recursive_maybe_editable( + device_id, + schema_hash, + "constantParameters", + title="Parameters used for CalCat queries", + ), + DisplayCommandModel( + keys=[f"{device_id}.loadMostRecentConstants"], + width=10 * BASE_INC, + height=BASE_INC, + ), + ), + FoundConstantsColumn(device_id, schema_hash), + recursive_maybe_editable( + device_id, + schema_hash, + "corrections", + max_depth=2, + title="Correction steps", + ), + ) + + +@scene_generator +def manager_device_overview_scene( + manager_device_id, + manager_device_schema, + correction_device_schema, + correction_device_ids, + domain_device_ids, +): + mds_hash = schema_to_hash(manager_device_schema) + cds_hash = schema_to_hash(correction_device_schema) + + return VerticalLayout( + HorizontalLayout( + ManagerDeviceStatus(manager_device_id), + VerticalLayout( + recursive_maybe_editable( + manager_device_id, + mds_hash, + "managed.constantParameters", + title="Parameters used for CalCat queries", + ), + DisplayCommandModel( + keys=[f"{manager_device_id}.managed.loadMostRecentConstants"], + width=10 * BASE_INC, + height=BASE_INC, + ), + ), + recursive_maybe_editable( + manager_device_id, + mds_hash, + "managed.corrections", + max_depth=2, + ), + ), + HorizontalLayout( + titled("Correction devices", width=8 * NARROW_INC)(boxed(VerticalLayout))( + children=[ + CompactCorrectionDeviceOverview(device_id, cds_hash) + for device_id in sorted(correction_device_ids) + ], + padding=0, + ), + CompactDeviceLinkList( + sorted( + set(domain_device_ids) + - set(correction_device_ids) + - {manager_device_id} + ) + ), + ), + ) + + +# section: here be monsters + + +def recursive_maybe_editable( + device_id, schema_hash, prefix, depth=1, max_depth=3, title=None +): + if title is None: + title = prefix.split(".")[-1] + # note: not just using sets because that loses ordering + node_keys = [] + value_keys = [] + slot_keys = [] + for key in schema_hash.get(prefix).getKeys(): + attrs = schema_hash.getAttributes(f"{prefix}.{key}") + if 
attrs.get("nodeType") == NODE_TYPE_NODE: + if "classId" in attrs and attrs.get("classId") == "Slot": + slot_keys.append(key) + else: + node_keys.append(key) + else: + value_keys.append(key) + res = titled(title)(boxed(VerticalLayout))( + children=[ + MaybeEditableRow(device_id, schema_hash, f"{prefix}.{key}") + for key in value_keys + ] + + [ + DisplayCommandModel( + keys=[f"{device_id}.{prefix}.{key}"], + width=10 * BASE_INC, + height=BASE_INC, + ) + for key in slot_keys + ], + padding=0, + ) + if depth < max_depth: + res.children.append( + VerticalLayout( + children=[ + recursive_maybe_editable( + device_id, + schema_hash, + f"{prefix}.{key}", + depth=depth + 1, + max_depth=max_depth, + ) + for key in node_keys + ] + ) + ) + else: + res.children.extend( + [ + VerticalLayout( + DeviceSceneLinkModel( + text=key, + keys=[f"{device_id}.availableScenes"], + target=f"browse_schema:{prefix}.{key}", + target_window=SceneTargetWindow.Dialog, + width=5 * BASE_INC, + height=BASE_INC, + ), + ) + for key in node_keys + ] + ) + return res + + +@scene_generator +def recursive_subschema_scene( + device_id, + device_schema, + prefix="managed", +): + mds_hash = schema_to_hash(device_schema) + return recursive_maybe_editable(device_id, mds_hash, prefix) diff --git a/src/calng/shmem_utils.py b/src/calng/shmem_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a82cdf6eaa6a5e0bf065e6eda1812597db1f3d0 --- /dev/null +++ b/src/calng/shmem_utils.py @@ -0,0 +1,127 @@ +import numpy as np +import posixshmem + + +def parse_shmem_handle(handle_string): + buffer_name, dtype, shape, index = handle_string.split("$") + dtype = getattr(np, dtype) + shape = tuple(int(n) for n in shape.split(",")) + index = int(index) + return buffer_name, dtype, shape, index + + +def open_shmem_from_handle(handle_string): + """Conveniently open readonly SharedMemory with ndarray view from a handle.""" + buffer_name, dtype, shape, _ = parse_shmem_handle(handle_string) + buffer_mem = posixshmem.SharedMemory(name=buffer_name, rw=False) + + array = buffer_mem.ndarray( + shape=shape, + dtype=dtype, + ) + + return buffer_mem, array + + +class ShmemCircularBufferReceiver: + def __init__(self): + self._name_to_mem = {} + self._name_to_ary = {} + + def get(self, handle_string): + name, dtype, shape, index = parse_shmem_handle(handle_string) + if name not in self._name_to_mem: + mem = posixshmem.SharedMemory(name=name, rw=False) + self._name_to_mem[name] = mem + ary = mem.ndarray(shape=shape, dtype=dtype) + self._name_to_ary[name] = ary + return ary[index] + + ary = self._name_to_ary[name] + if ary.shape != shape or ary.dtype != dtype: + del ary + mem = self._name_to_mem[name] + ary = mem.ndarray(shape=shape, dtype=dtype) + self._name_to_ary[name] = ary + + return ary[index] + + +class ShmemCircularBuffer: + """Convenience wrapper around posixshmem-backed ndarray buffers + + The underlying memory will be opened as an ndarray with shape (buffer_size, ) + + array_shape where buffer_size is memory_budget // dtype * array size. Each call + to next_slot will return the next entry along the first dimension of this array + (both a handle for IPC usage and the ndarray view). 
+ """ + + def __init__(self, memory_budget, array_shape, dtype, shmem_name): + # for portable use: name has leading slash and no other slashes + self.shmem_name = "/" + shmem_name.lstrip("/").replace("/", "_") + self._shared_memory = posixshmem.SharedMemory( + name=self.shmem_name, + size=memory_budget, + rw=True, + ) + self._buffer_ary = None + self._update_shape(array_shape, dtype) + self._cuda_pinned = False + # important for performance and pinning: touch memory to actually allocate + self._buffer_ary.fill(0) + + def _update_shape(self, array_shape, dtype): + array_shape = tuple(array_shape) + array_bytes = np.dtype(dtype).itemsize * np.product(array_shape) + num_slots = self._shared_memory.size // array_bytes + if num_slots == 0: + raise ValueError("Array size exceeds size of allocated memory block") + full_shape = (num_slots,) + array_shape + + if self._buffer_ary is not None: + del self._buffer_ary + self._buffer_ary = self._shared_memory.ndarray( + shape=full_shape, + dtype=dtype, + ) + shape_str = ",".join(str(n) for n in full_shape) + self.shmem_handle_template = ( + f"{self.shmem_name}${np.dtype(dtype)}${shape_str}${{index}}" + ) + self._next_slot_index = 0 + + def change_shape(self, array_shape, dtype=None): + """Set new array shape to buffer. Note that the existing SharedMemory object is + still used. Old data in there will be mangled and number of slots will depend + upon new array shape and original memory budget. + """ + if dtype is None: + dtype = self._buffer_ary.dtype + self._update_shape(array_shape, dtype) + + def cuda_pin(self): + import cupy + self._memory_pointer = self._buffer_ary.ctypes.get_data() + cupy.cuda.runtime.hostRegister( + self._memory_pointer, + self._shared_memory.size, + 0 + ) + + def __del__(self): + if self._cuda_pinned: + import cupy + cupy.cuda.runtime.hostUnregister(self._memory_pointer) + del self._buffer_ary + del self._shared_memory + + @property + def num_slots(self): + return self._buffer_ary.shape[0] + + def next_slot(self): + current_index = self._next_slot_index + self._next_slot_index = (self._next_slot_index + 1) % self.num_slots + shmem_handle = self.shmem_handle_template.format(index=current_index) + data = self._buffer_ary[current_index] + return shmem_handle, data diff --git a/src/calng/utils.py b/src/calng/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..291c2d976ae4458495d351337e20f95fdcf7ebce --- /dev/null +++ b/src/calng/utils.py @@ -0,0 +1,286 @@ +import collections +import functools +import inspect +import threading +import time +from timeit import default_timer + +import numpy as np + + +def pick_frame_index(selection_mode, index, cell_table, pulse_table, warn_func=None): + """When selecting a single frame to preview, an obvious question is whether the + number the operator provides is a frame index, a cell ID, or a pulse ID. This + function allows any of the three, translating into frame index. + + As this will be used by correction devices, the warn_func parameter allows the + function to issue warnings via this instead of raising exceptions. + + Indices below zero are special values and thus returned directly. 
+ + Returns: (found frame index, corresponding cell ID, corresponding pulse ID)""" + + if index < 0: + return index, index, index + + # TODO: enum + if selection_mode == "frame": + if index >= len(cell_table): + if warn_func is not None: + warn_func( + f"Index {index} out of range for cell table of length " + f"{len(cell_table)}, returning index 0 instead" + ) + frame_index = 0 + else: + frame_index = index + + return frame_index, cell_table[frame_index], pulse_table[frame_index] + elif selection_mode == "cell": + found = np.where(cell_table == index)[0] + if len(found) > 0: + cell = index + frame_index = found[0] + else: + cell = cell_table[0] + if warn_func is not None: + warn_func( + f"Cell {index} not found, arbitrary cell {cell} returned instead" + ) + frame_index = 0 + return frame_index, cell, pulse_table[frame_index] + elif selection_mode == "pulse": + found = np.where(pulse_table == index)[0] + if len(found) > 0: + pulse = index + frame_index = found[0] + else: + pulse = pulse_table[0] + if warn_func is not None: + warn_func( + f"Pulse {index} not found, arbitrary pulse {pulse} returned instead" + ) + frame_index = 0 + return frame_index, cell_table[frame_index], pulse + else: + raise ValueError(f"Invalid selection mode '{selection_mode}'") + + +def threadsafe_cache(fun): + """This decorator imitates functools.cache, but threadsafer + + With multiple threads hitting a function cached by functools.cache, it is possible + to trigger recomputation. This decorator adds granular locking: each key in the + cache (derived from arguments) has its own lock. + """ + + locks = {} + results = {} + fun_sig = inspect.signature(fun) + + @functools.wraps(fun) + def aux(*args, **kwargs): + bound_args = fun_sig.bind(*args, **kwargs) + bound_args.apply_defaults() + key = bound_args.args + tuple(bound_args.kwargs.items()) + if key in results: + return results[key] + with locks.setdefault(key, threading.Lock()): + if key in results: + # someone else did this - may still be processing + return results[key] + else: + res = fun(*args, **kwargs) + results[key] = res + return res + + return aux + + +@functools.lru_cache() +def transpose_order(axes_in, axes_out): + """Computes the order of axes_out relative to axes_in for transposition purposes + + Both axes_in and axes_out are assumed to be strings in which each letter represents + an axis (duck typing accepts: any iterable of hashable elements). They should + probably be of the same length and have no repetitions, but this is not enforced. + Off-label use voids warranty. 
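+ + For example, transpose_order("cxy", "xyc") returns (1, 2, 0), which can be passed as the
+ axes argument of numpy/cupy transpose to reorder an array from "cxy" to "xyc".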
+ """ + axis_order = {axis: index for index, axis in enumerate(axes_in)} + return tuple(axis_order[axis] for axis in axes_out) + + +_np_typechar_to_c_typestring = { + "?": "bool", + "B": "unsigned char", + "D": "double complex", + "F": "float complex", + "G": "long double complex", + "H": "unsigned short", + "I": "unsigned int", + "L": "unsigned long", + "Q": "unsigned long long", + "b": "char", + "d": "double", + "e": "half", # warning: only in CUDA with special support + "f": "float", + "g": "long double", + "h": "short", + "i": "int", + "l": "long", + "q": "long long", +} + + +def np_dtype_to_c_type(dtype): + as_char = np.sctype2char(dtype) + return _np_typechar_to_c_typestring[as_char] + + +def enum_to_c_template(enum_class): + res = [f"enum {enum_class.__name__} {{"] + for field in enum_class: + res.append(f"\t{field.name} = {field.value},") + res.append("};") + return "\n".join(res) + + +def ceil_div(num, denom): + return (num + denom - 1) // denom + + +def shape_after_transpose(input_shape, transpose_pattern, squeeze=True): + if squeeze: + input_shape = tuple(dim for dim in input_shape if dim > 1) + if transpose_pattern is None: + return input_shape + return tuple(np.array(input_shape)[list(transpose_pattern)].tolist()) + + +class RepeatingTimer: + """A timer which will call callback every interval seconds""" + + def __init__( + self, + interval, + callback, + start_now=True, + daemon=True, + ): + self.stopped = True + self.interval = interval + self.callback = callback + self.daemonize = daemon + if start_now: + self.start() + + def start(self): + self.stopped = False + self.wakeup_time = default_timer() + self.interval + + def runner(): + while not self.stopped: + now = default_timer() + while now < self.wakeup_time: + diff = self.wakeup_time - now + time.sleep(diff) + if self.stopped: + return + now = default_timer() + self.callback() + self.wakeup_time = default_timer() + self.interval + + self.thread = threading.Thread(target=runner, daemon=self.daemonize) + self.thread.start() + + def stop(self): + self.stopped = True + + +class ExponentialMovingAverage: + def __init__(self, alpha, use_first_value=True): + self.alpha = alpha + self.initialised = not use_first_value + self.mean = 0 + + def update(self, value): + if self.initialised: + self.mean += self.alpha * (value - self.mean) + else: + self.mean = value + self.initialised = True + + def get(self): + return self.mean + + +class WindowRateTracker: + def __init__(self, buffer_size=20, time_window=10): + self.time_window = time_window + self.buffer_size = buffer_size + self.deque = collections.deque(maxlen=self.buffer_size) + + def update(self): + self.deque.append(default_timer()) + + def get(self): + now = default_timer() + cutoff = now - self.time_window + try: + while self.deque[0] < cutoff: + self.deque.popleft() + except IndexError: + return 0 + if len(self.deque) < 2: + return 0 + if len(self.deque) < self.buffer_size: + # TODO: estimator avoiding ramp-up of when starting anew + return len(self.deque) / self.time_window + else: + # if going faster than buffer size per time window, look at timestamps + oldest, newest = self.deque[0], self.deque[-1] + buffer_span = newest - oldest + period = buffer_span / (self.buffer_size - 1) + if (now - newest) < period: + # no new estimate yet, expecting new event after period + return 1 / period + else: + return self.buffer_size / (now - oldest) + + +class Stopwatch: + """Context manager measuring time spent in context. 
+
+    Keyword arguments:
+    name: if not None, it will appear in the string representation, and the
+          stopwatch will print itself when the context exits
+    """
+
+    def __init__(self, name=None):
+        self.stop_time = None
+        self.name = name
+
+    def __enter__(self):
+        self.start_time = default_timer()
+        return self
+
+    def __exit__(self, t, v, tb):  # type, value and traceback irrelevant
+        self.stop_time = default_timer()
+        if self.name is not None:
+            print(repr(self))
+
+    @property
+    def elapsed(self):
+        if self.stop_time is not None:
+            return self.stop_time - self.start_time
+        else:
+            return default_timer() - self.start_time
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        if self.name is None:
+            return f"{self.elapsed:.3f} s"
+        else:
+            return f"{self.name}: {self.elapsed:.3f} s"
diff --git a/src/tests/problem.py b/src/tests/problem.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c2ca2cbb25310492561d267a482147c0751ce9
--- /dev/null
+++ b/src/tests/problem.py
@@ -0,0 +1,20 @@
+from calng import utils
+
+
+calls = 0
+
+@utils.threadsafe_cache
+def will_raise_once(argument):
+    global calls
+    calls += 1
+    if calls == 1:
+        raise Exception("That's just what I do")
+    return argument + 1
+
+try:
+    will_raise_once(0)
+except Exception as ex:
+    print("As expected, first call raised:", ex)
+
+print("Now calling again:")
+print(will_raise_once(0))
diff --git a/src/tests/test_agipd_kernels.py b/src/tests/test_agipd_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b776aba85acf5e34bb5a671f3b484e3c57b8b9
--- /dev/null
+++ b/src/tests/test_agipd_kernels.py
@@ -0,0 +1,119 @@
+import h5py
+import numpy as np
+import pathlib
+import pytest
+
+from calng import AgipdCorrection
+
+input_dtype = np.uint16
+output_dtype = np.float16
+corr_dtype = np.float32
+pixels_x = 512
+pixels_y = 128
+memory_cells = 352
+
+raw_data = np.random.randint(
+    low=0, high=2000, size=(memory_cells, 2, pixels_x, pixels_y), dtype=input_dtype
+)
+image_data = raw_data[:, 0]
+raw_gain = raw_data[:, 1]
+cell_table = np.arange(memory_cells, dtype=np.uint16)
+np.random.shuffle(cell_table)
+
+caldb_store = pathlib.Path("/gpfs/exfel/d/cal/caldb_store/xfel/cal")
+caldb_prefix = caldb_store / "agipd-type/agipd_siv1_agipdv11_m305"
+
+with h5py.File(caldb_prefix / "cal.1619543695.4679213.h5", "r") as fd:
+    thresholds = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/ThresholdsDark/0/data"])
+with h5py.File(caldb_prefix / "cal.1619543664.1545036.h5", "r") as fd:
+    offset_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/Offset/0/data"])
+with h5py.File(caldb_prefix / "cal.1615377705.8904035.h5", "r") as fd:
+    slopes_pc_map = np.array(fd["/AGIPD_SIV1_AGIPDV11_M305/SlopesPC/0/data"])
+
+kernel_runner = AgipdCorrection.AgipdGpuRunner(
+    pixels_x,
+    pixels_y,
+    memory_cells,
+    constant_memory_cells=memory_cells,
+    input_data_dtype=input_dtype,
+    output_data_dtype=output_dtype,
+)
+
+
+def thresholding_cpu(data, cell_table, thresholds):
+    # get to memory_cell, x, y
+    raw_gain = data[:, 1, ...].astype(corr_dtype)
+    # get to threshold, memory_cell, x, y
+    thresholds = np.transpose(thresholds)[:, cell_table]
+    res = np.zeros((memory_cells, pixels_x, pixels_y), dtype=np.uint8)
+    res[raw_gain > thresholds[0]] = 1
+    res[raw_gain > thresholds[1]] = 2
+    return res
+
+
+gain_map_cpu = thresholding_cpu(raw_data, cell_table, thresholds)
+
+
+def corr_offset_cpu(data, cell_table, gain_map, offset):
+    image_data = data[:, 0].astype(corr_dtype)
+    offset = np.transpose(offset)[:, cell_table]
+    return (image_data - np.choose(gain_map, offset)).astype(output_dtype)
+
+
+def corr_rel_gain_pc_cpu(data, cell_table, gain_map, slopes_pc):
+    slopes_pc = slopes_pc.astype(np.float32)
+    pc_high_m = slopes_pc[0]
+    pc_high_I = slopes_pc[1]
+    pc_med_m = slopes_pc[3]
+    pc_med_I = slopes_pc[4]
+    frac_high_med = pc_high_m / pc_med_m
+    md_additional_offset = pc_high_I - pc_med_I * frac_high_med
+    rel_gain_map = np.ones((3, pixels_x, pixels_y, memory_cells), dtype=np.float32)
+    rel_gain_map[0] = 1  # rel xray gain can come after
+    rel_gain_map[1] = rel_gain_map[0] * np.transpose(frac_high_med)
+    rel_gain_map[2] = rel_gain_map[1] * 4.48
+    res = data[:, 0].astype(corr_dtype, copy=True)
+    res *= np.choose(gain_map, np.transpose(rel_gain_map, (0, 3, 1, 2)))
+    pixels_in_medium_gain = gain_map == 1
+    res[pixels_in_medium_gain] += np.transpose(md_additional_offset, (0, 2, 1))[
+        pixels_in_medium_gain
+    ]
+    return res
+
+
+def test_thresholding():
+    kernel_runner.load_cell_table(cell_table)
+    kernel_runner.load_data(raw_data)
+    kernel_runner.load_thresholds(thresholds)
+    kernel_runner.correct(AgipdCorrection.CorrectionFlags.THRESHOLD)
+    gpu_res = kernel_runner.gain_map_gpu.get()
+    assert np.allclose(gpu_res, gain_map_cpu)
+
+
+def test_offset():
+    kernel_runner.load_cell_table(cell_table)
+    kernel_runner.load_data(raw_data)
+    kernel_runner.load_thresholds(thresholds)
+    kernel_runner.load_offset_map(offset_map)
+    # have to do thresholding, otherwise all is treated as high gain
+    kernel_runner.correct(
+        AgipdCorrection.CorrectionFlags.THRESHOLD
+        | AgipdCorrection.CorrectionFlags.OFFSET
+    )
+    cpu_res = corr_offset_cpu(raw_data, cell_table, gain_map_cpu, offset_map)
+    gpu_res = kernel_runner.processed_data_gpu.get()
+    assert np.allclose(gpu_res, cpu_res)
+
+
+def test_rel_gain_pc():
+    kernel_runner.load_cell_table(cell_table)
+    kernel_runner.load_data(raw_data)
+    kernel_runner.load_thresholds(thresholds)
+    kernel_runner.load_rel_gain_pc_map(slopes_pc_map)
+    kernel_runner.correct(
+        AgipdCorrection.CorrectionFlags.THRESHOLD
+        | AgipdCorrection.CorrectionFlags.REL_GAIN_PC
+    )
+    cpu_res = corr_rel_gain_pc_cpu(raw_data, cell_table, gain_map_cpu, slopes_pc_map)
+    gpu_res = kernel_runner.processed_data_gpu.get()
+    assert np.allclose(gpu_res, cpu_res)
diff --git a/src/tests/test_calcat_utils.py b/src/tests/test_calcat_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b647c805876766ccc9fc66a36576788f3f216e5
--- /dev/null
+++ b/src/tests/test_calcat_utils.py
@@ -0,0 +1,154 @@
+import pathlib
+import timeit
+
+from calng import AgipdCorrection, DsscCorrection
+from calng.utils import Stopwatch
+from karabo.bound import Hash, Schema
+import pytest
+
+# TODO: secrets management
+_test_dir = pathlib.Path(__file__).absolute().parent
+_test_calcat_secrets_fn = _test_dir / "calibration-client-secrets.json"
+
+
+class DummyLogger:
+    DEBUG = print
+    INFO = print
+    WARN = print
+
+
+class DummyBaseDevice:
+    log = DummyLogger()
+
+    def log_status_info(self, msg):
+        self.log.INFO(msg)
+
+    def log_status_warn(self, msg):
+        self.log.WARN(msg)
+
+    def get(self, key):
+        return self.schema.get(key)
+
+    def set(self, key, value):
+        print(f'Would set "{key}" = {value}')
+
+
+# TODO: consider testing by attaching to real karabo.bound.PythonDevice
+class DummyAgipdDevice(DummyBaseDevice):
+    device_class_schema = Schema()
+    managed_keys = set()
+
+    @staticmethod
+    def expectedParameters(expected):
+        AgipdCorrection.AgipdCalcatFriend.add_schema(
+            expected, DummyAgipdDevice.managed_keys
+        )
+
+    def __init__(self, config):
+        self.schema = config
+        self.calibration_constant_manager = AgipdCorrection.AgipdCalcatFriend(
+            self,
+            _test_calcat_secrets_fn,
+        )
+        print(self.managed_keys)
+
+
+DummyAgipdDevice.expectedParameters(DummyAgipdDevice.device_class_schema)
+
+
+class DummyDsscDevice(DummyBaseDevice):
+    device_class_schema = Schema()
+    managed_keys = set()
+
+    @staticmethod
+    def expectedParameters(expected):
+        DsscCorrection.DsscCalcatFriend.add_schema(
+            expected, DummyDsscDevice.managed_keys
+        )
+
+    def __init__(self, config):
+        # TODO: check config against schema (as Karabo would)
+        self.schema = config
+        self.calibration_constant_manager = DsscCorrection.DsscCalcatFriend(
+            self,
+            _test_calcat_secrets_fn,
+        )
+
+
+DummyDsscDevice.expectedParameters(DummyDsscDevice.device_class_schema)
+
+
+@pytest.mark.skip(reason="Async currently behind lock, so no concurrent fun")
+def test_agipd_constants_and_caching_and_async():
+    # def test_agipd_constants():
+    conf = Hash()
+    conf["constantParameters.detectorType"] = "AGIPD-Type"
+    conf["constantParameters.detectorName"] = "SPB_DET_AGIPD1M-1"
+    conf["constantParameters.karaboDa"] = "AGIPD00"
+    conf["constantParameters.pixelsX"] = 512
+    conf["constantParameters.pixelsY"] = 128
+    conf["constantParameters.memoryCells"] = 352
+    conf["constantParameters.acquisitionRate"] = 1.1
+    conf["constantParameters.biasVoltage"] = 300
+    conf["constantParameters.gainSetting"] = 0
+    conf["constantParameters.photonEnergy"] = 9.2
+    device = DummyAgipdDevice(conf)
+
+    def backcall(constant_name, metadata_and_data):
+        # TODO: think of something reasonable to check
+        data = metadata_and_data
+        assert data.nbytes > 1000
+
+    with Stopwatch() as timer_async_cold:
+        # TODO: put this sort of thing in BaseCalcatFriend
+        threads = []
+        for constant in AgipdCorrection.AgipdConstants:
+            thread = device.calibration_constant_manager.get_constant_version_and_call_me_back(
+                constant, backcall
+            )
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+
+    with Stopwatch() as timer_async_warm:
+        threads = []
+        for constant in AgipdCorrection.AgipdConstants:
+            thread = device.calibration_constant_manager.get_constant_version_and_call_me_back(
+                constant, backcall
+            )
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+
+    with Stopwatch() as timer_sync_warm:
+        for constant in AgipdCorrection.AgipdConstants:
+            data = device.calibration_constant_manager.get_constant_version(
+                constant,
+            )
+            assert data.nbytes > 1000, "Should find some constant data"
+
+    print(f"Cold async took {timer_async_cold.elapsed} s")
+    print(f"Warm async took {timer_async_warm.elapsed} s")
+    print(f"Warm sync took {timer_sync_warm.elapsed} s")
+    assert (
+        timer_async_cold.elapsed > timer_async_warm.elapsed
+    ), "Caching should make second go faster"
+    assert timer_sync_warm.elapsed > timer_async_warm.elapsed, "Async should be faster"
+
+
+def test_dssc_constants():
+    conf = Hash()
+    conf["constantParameters.detectorType"] = "DSSC-Type"
+    conf["constantParameters.detectorName"] = "SCS_DET_DSSC1M-1"
+    conf["constantParameters.karaboDa"] = "DSSC00"
+    conf["constantParameters.memoryCells"] = 400
+    conf["constantParameters.biasVoltage"] = 100
+    conf["constantParameters.pixelsX"] = 512
+    conf["constantParameters.pixelsY"] = 128
+    # conf["constantParameters.pulseIdChecksum"] = 2.8866323107820637e-36
+    # conf["constantParameters.acquisitionRate"] = 4.5
+    # conf["constantParameters.encodedGain"] = 67328
+    device = DummyDsscDevice(conf)
+    offset_map = 
device.calibration_constant_manager.get_constant_version("Offset") + + assert offset_map is not None diff --git a/src/tests/test_dssc_kernels.py b/src/tests/test_dssc_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..95c3ec55c580aa216eef871006955109c0a31ee4 --- /dev/null +++ b/src/tests/test_dssc_kernels.py @@ -0,0 +1,141 @@ +import numpy as np +import pytest + +from calng import DsscCorrection + +input_dtype = np.uint16 +output_dtype = np.float16 +corr_dtype = np.float32 +pixels_x = 512 +pixels_y = 128 +memory_cells = 400 +offset_map = ( + np.random.random(size=(pixels_x, pixels_y, memory_cells)).astype(corr_dtype) * 20 +) +cell_table = np.arange(memory_cells, dtype=np.uint16) +np.random.shuffle(cell_table) +raw_data = np.random.randint( + low=0, high=2000, size=(memory_cells, pixels_y, pixels_x), dtype=input_dtype +) + +# TODO: gather CPU implementations elsewhere +def correct_cpu(data, cell_table, offset_map): + corr = np.squeeze(data).astype(corr_dtype, copy=True) + safe_cell_bool = cell_table < offset_map.shape[-1] + safe_cell_index = cell_table[safe_cell_bool] + corr[safe_cell_bool] -= offset_map.transpose()[safe_cell_index] + return corr.astype(output_dtype, copy=False) + + +corrected_data = correct_cpu(raw_data, cell_table, offset_map) +only_cast_data = np.squeeze(raw_data).astype(output_dtype) + + +kernel_runner = DsscCorrection.DsscGpuRunner( + pixels_x, + pixels_y, + memory_cells, + constant_memory_cells=memory_cells, + input_data_dtype=input_dtype, + output_data_dtype=output_dtype, +) + + +def test_only_cast(): + kernel_runner.load_data(raw_data) + kernel_runner.correct(DsscCorrection.CorrectionFlags.NONE) + assert np.allclose( + kernel_runner.processed_data_gpu.get(), raw_data.astype(output_dtype) + ) + + +def test_correct(): + kernel_runner.load_offset_map(offset_map) + kernel_runner.load_data(raw_data) + kernel_runner.load_cell_table(cell_table) + kernel_runner.correct(DsscCorrection.CorrectionFlags.OFFSET) + assert np.allclose(kernel_runner.processed_data_gpu.get(), corrected_data) + + +def test_correct_oob_cells(): + kernel_runner.load_offset_map(offset_map) + kernel_runner.load_data(raw_data) + # here, half the cell IDs will be out of bounds + wild_cell_table = cell_table * 2 + kernel_runner.load_cell_table(wild_cell_table) + # should not crash + kernel_runner.correct(DsscCorrection.CorrectionFlags.OFFSET) + # should correct as much as possible + assert np.allclose( + kernel_runner.processed_data_gpu.get(), + correct_cpu(raw_data, wild_cell_table, offset_map), + ) + + +def test_reshape(): + kernel_runner.processed_data_gpu.set(corrected_data) + assert np.allclose( + kernel_runner.reshape(output_order="xyc"), corrected_data.transpose() + ) + + +def test_preview_slice(): + kernel_runner.load_data(raw_data) + kernel_runner.processed_data_gpu.set(corrected_data) + preview_raw, preview_corrected = kernel_runner.compute_previews(42) + assert np.allclose( + preview_raw, + raw_data[42].astype(np.float32), + ) + assert np.allclose( + preview_corrected, + corrected_data[42].astype(np.float32), + ) + + +def test_preview_max(): + # note: in case correction failed, still test this separately + kernel_runner.load_data(raw_data) + kernel_runner.processed_data_gpu.set(corrected_data) + preview_raw, preview_corrected = kernel_runner.compute_previews(-1) + assert np.allclose(preview_raw, np.max(raw_data, axis=0).astype(np.float32)) + assert np.allclose( + preview_corrected, np.max(corrected_data, axis=0).astype(np.float32) + ) + + +def test_preview_mean(): + 
kernel_runner.load_data(raw_data) + kernel_runner.processed_data_gpu.set(corrected_data) + preview_raw, preview_corrected = kernel_runner.compute_previews(-2) + assert np.allclose(preview_raw, np.nanmean(raw_data, axis=0, dtype=np.float32)) + assert np.allclose( + preview_corrected, np.nanmean(corrected_data, axis=0, dtype=np.float32) + ) + + +def test_preview_sum(): + kernel_runner.load_data(raw_data) + kernel_runner.processed_data_gpu.set(corrected_data) + preview_raw, preview_corrected = kernel_runner.compute_previews(-3) + assert np.allclose(preview_raw, np.nansum(raw_data, axis=0, dtype=np.float32)) + assert np.allclose( + preview_corrected, np.nansum(corrected_data, axis=0, dtype=np.float32) + ) + + +def test_preview_std(): + kernel_runner.load_data(raw_data) + kernel_runner.processed_data_gpu.set(corrected_data) + preview_raw, preview_corrected = kernel_runner.compute_previews(-4) + assert np.allclose(preview_raw, np.nanstd(raw_data, axis=0, dtype=np.float32)) + assert np.allclose( + preview_corrected, np.nanstd(corrected_data, axis=0, dtype=np.float32) + ) + + +def test_preview_valid_index(): + with pytest.raises(ValueError): + kernel_runner.compute_previews(-5) + with pytest.raises(ValueError): + kernel_runner.compute_previews(memory_cells) diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91b1f280f63bb61f54e0325d94216ecfe5d90a44 --- /dev/null +++ b/src/tests/test_utils.py @@ -0,0 +1,102 @@ +import random +import threading +import time +import timeit + +import numpy as np +from calng import utils + + +def test_get_c_type(): + assert utils.np_dtype_to_c_type(np.float16) == "half" + assert utils.np_dtype_to_c_type(np.float32) == "float" + assert utils.np_dtype_to_c_type(np.float64) == "double" + + assert utils.np_dtype_to_c_type(np.uint8) == "unsigned char" + assert utils.np_dtype_to_c_type(np.uint16) == "unsigned short" + assert utils.np_dtype_to_c_type(np.uint32) in ("unsigned", "unsigned int") + assert utils.np_dtype_to_c_type(np.uint64) == "unsigned long" + + assert utils.np_dtype_to_c_type(np.int8) == "char" + assert utils.np_dtype_to_c_type(np.int16) == "short" + assert utils.np_dtype_to_c_type(np.int32) == "int" + assert utils.np_dtype_to_c_type(np.int64) == "long" + + +class TestThreadsafeCache: + def test_arg_key_wrap(self): + calls = [] + + @utils.threadsafe_cache + def fun(a, b, c=1, d=2, *args, **kwargs): + calls.append((a, b, c, d, args, kwargs)) + + # reordering kwargs /does/ matter because dicts are ordered now + # (note: functools.lru_cache doesn't sort, claims because of speed) + fun(1, 2, 3, 4, 5, six=6, seven=7) + fun(1, 2, 3, 4, 5, seven=7, six=6) + assert len(calls) == 2, "kwargs order matters" + calls.clear() + + # reordering kw-style positional args does not matter + fun(1, 2, 1, 2) + fun(a=1, c=1, b=2, d=2) + assert len(calls) == 1, "reordering regular args as kws doesn't matter" + # and omitting default values does not matter + fun(b=2, a=1) + fun(1, 2) + assert len(calls) == 1, "omitting default args doesn't matter" + + def test_threadsafeness(self): + # wow, synchronization (presumably) makes this take forever *without* the decorator... 
+ from_was_called = [] + + base_sleep = 1 + random_sleep = 0.1 + + @utils.threadsafe_cache + def was_called(x): + time.sleep(random.random() * random_sleep + base_sleep) + from_was_called.append(x) + + threads = [] + num_threads = 1000 + letters = "abcd" + start_ts = timeit.default_timer() + for i in range(num_threads): + for l in letters: + thread = threading.Thread(target=was_called, args=(l,)) + thread.start() + threads.append(thread) + submitted_ts = timeit.default_timer() + print(f"Right after: {len(from_was_called)}") + for thread in threads: + thread.join() + stop_ts = timeit.default_timer() + total_time = stop_ts - start_ts + print(f"After join: {len(from_was_called)}") + print(f"Time to submit: {submitted_ts - start_ts}") + print(f"Wait for join: {stop_ts - submitted_ts}") + print(f"Total: {total_time}") + + # check that function was only called with each letter once + # this is where the decorator from functools will fail + assert len(from_was_called) == len( + letters + ), "Caching prevents recomputation due to threading" + + # check that the function was not locked too broadly (should run faster than sequential lower bound) + reasonable_time_to_spawn_thread = 0.45 / 1000 + cutoff = ( + len(letters) * base_sleep + reasonable_time_to_spawn_thread * num_threads + ) + print(f"Cutoff (sequential lower bound): {cutoff}") + assert ( + total_time < cutoff + ), "Locking should not be so broad as to make sequential" + print( + f"Each thread would have slept [{base_sleep}, {base_sleep + random_sleep})" + ) + + # check that time doesn't go backwards suddenly + assert total_time >= base_sleep, "These tests should measure time correctly"
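
As a quick usage sketch (illustrative only, not taken from the patch itself): the axis-ordering helpers `transpose_order` and `shape_after_transpose` added to `src/calng/utils.py` are meant to be combined roughly as follows, e.g. going from a (cells, x, y) layout to (x, y, cells); the array shape mirrors the AGIPD module dimensions used in the kernel tests above.

```python
import numpy as np

from calng import utils

# hypothetical corrected-data buffer: 352 cells of 512 x 128 pixels
data = np.zeros((352, 512, 128), dtype=np.float32)

# compute the transposition pattern from one axis string to another
order = utils.transpose_order("cxy", "xyc")  # -> (1, 2, 0)

# shape_after_transpose predicts the resulting shape without moving any data
assert utils.shape_after_transpose(data.shape, order) == (512, 128, 352)
assert np.transpose(data, order).shape == (512, 128, 352)
```

Since `transpose_order` is wrapped in `functools.lru_cache`, repeated lookups for the same pair of axis strings are cheap.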