diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index 276c223a6702e95a25e92aa06736c8c7bc0417ae..d54f984d950db06cc79bd56c5de50cb74461d8ae 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -1,4 +1,3 @@ -import pathlib import timeit import numpy as np @@ -15,7 +14,7 @@ from karabo.common.states import State from . import utils from ._version import version as deviceVersion from .agipd_gpu import AgipdGainMode, AgipdGpuRunner, BadPixelValues, CorrectionFlags -from .base_correction import BaseCorrection +from .base_correction import BaseCorrection, add_correction_step_schema from .calcat_utils import AgipdCalcatFriend, AgipdConstants @@ -23,7 +22,7 @@ from .calcat_utils import AgipdCalcatFriend, AgipdConstants class AgipdCorrection(BaseCorrection): # subclass *must* set these attributes _correction_flag_class = CorrectionFlags - _correction_slot_names = ( + _correction_field_names = ( ("thresholding", CorrectionFlags.THRESHOLD), ("offset", CorrectionFlags.OFFSET), ("relGainPc", CorrectionFlags.REL_GAIN_PC), @@ -31,44 +30,52 @@ class AgipdCorrection(BaseCorrection): ("badPixels", CorrectionFlags.BPMASK), ) _gpu_runner_class = AgipdGpuRunner - _schema_cache_slots = BaseCorrection._schema_cache_slots | {"sendGainMap"} + _calcat_friend_class = AgipdCalcatFriend + _constant_enum_class = AgipdConstants + + # this is just extending (not mandatory) + _schema_cache_fields = BaseCorrection._schema_cache_fields | {"sendGainMap"} @staticmethod def expectedParameters(expected): super(AgipdCorrection, AgipdCorrection).expectedParameters(expected) + ( + STRING_ELEMENT(expected) + .key("gainMode") + .displayedName("Gain mode") + .assignmentOptional() + .defaultValue("ADAPTIVE_GAIN") + .options("ADAPTIVE_GAIN,FIXED_HIGH_GAIN,FIXED_MEDIUM_GAIN,FIXED_LOW_GAIN") + .commit(), + BOOL_ELEMENT(expected) + .key("sendGainMap") + .displayedName("Send gain map on dataOutput") + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + ) AgipdCalcatFriend.add_schema(expected) - # TODO: encapsulate correction configuration subschema - for slot_name, _ in AgipdCorrection._correction_slot_names: - ( - BOOL_ELEMENT(expected) - .key(f"corrections.available.{slot_name}") - .readOnly() - .initialValue(False) - .commit(), - BOOL_ELEMENT(expected) - .key(f"corrections.enabled.{slot_name}") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - BOOL_ELEMENT(expected) - .key(f"corrections.preview.{slot_name}") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - ) + # this is not automatically done by superclass for complicated class reasons + add_correction_step_schema(expected, AgipdCorrection._correction_field_names) + # additional settings specific to AGIPD correction steps ( BOOL_ELEMENT(expected) - .key("corrections.overrideMdAdditionalOffset") + .key("corrections.relGainPc.overrideMdAdditionalOffset") .displayedName("Override md_additional_offset") + .description( + "Toggling this on will use the value in the next field globally for " + "md_additional_offset. Note that the correction map on GPU gets " + "overwritten as long as this boolean is True, so reload constants " + "after turning off." + ) .assignmentOptional() .defaultValue(False) .reconfigurable() .commit(), FLOAT_ELEMENT(expected) - .key("corrections.mdAdditionalOffset") + .key("corrections.relGainPc.mdAdditionalOffset") .displayedName("Value for md_additional_offset (if overriding)") .description( "Normally, md_additional_offset (part of relative gain correction) is " @@ -80,43 +87,24 @@ class AgipdCorrection(BaseCorrection): .defaultValue(0) .reconfigurable() .commit(), - ) - ( - STRING_ELEMENT(expected) - .key("gainMode") - .displayedName("Gain mode") - .assignmentOptional() - .defaultValue("ADAPTIVE_GAIN") - .options("ADAPTIVE_GAIN,FIXED_HIGH_GAIN,FIXED_MEDIUM_GAIN,FIXED_LOW_GAIN") - .commit(), - BOOL_ELEMENT(expected) - .key("sendGainMap") - .displayedName("Send gain map on dataOutput") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - ) - # TODO: hook this up to actual correction done - # NOTE: wanted as table, could not make readonly rows of reconfigurable bools - ( NODE_ELEMENT(expected) - .key("corrections.badPixelFlagsToUse") + .key("corrections.badPixels.subsetToUse") .displayedName("Bad pixel flags to use") .description( "The booleans under this node allow for selecting a subset of bad " - "pixel values to take into account when doing bad pixel masking. " + "pixel types to take into account when doing bad pixel masking. " "Upon updating these flags, the map used for bad pixel masking will " - "be ANDed with this selection. TEMPORARY NOTE: if you want to toggle " - "a disabled flag back on, please reload constants for this to take " + "be ANDed with this selection. Therefore, if you want to toggle a " + "disabled flag back on, please reload constants for this to take " "effect (will be triggered automatically in future version)." ) .commit(), ) + # TODO: DRY / encapsulate for field in BadPixelValues: ( BOOL_ELEMENT(expected) - .key(f"corrections.badPixelFlagsToUse.{field.name}") + .key(f"corrections.badPixels.subsetToUse.{field.name}") .assignmentOptional() .defaultValue(True) .reconfigurable() @@ -124,7 +112,7 @@ class AgipdCorrection(BaseCorrection): ) ( STRING_ELEMENT(expected) - .key("corrections.badPixelMaskValue") + .key("corrections.badPixels.maskingValue") .displayedName("Bad pixel masking value") .description( "Any pixels masked by the bad pixel mask will have their value " @@ -159,13 +147,11 @@ class AgipdCorrection(BaseCorrection): def __init__(self, config): super().__init__(config) - # TODO: consider putting this initialization in base class - self.calibration_constant_manager = AgipdCalcatFriend( - self, pathlib.Path.cwd() / "calibration-client-secrets.json" - ) # TODO: different gpu runner for fixed gain mode self.gain_mode = AgipdGainMode[config.get("gainMode")] - self.bad_pixel_mask_value = eval(config.get("corrections.badPixelMaskValue")) + self.bad_pixel_mask_value = eval( + config.get("corrections.badPixels.maskingValue") + ) self._gpu_runner_init_args = { "gain_mode": self.gain_mode, "bad_pixel_mask_value": self.bad_pixel_mask_value, @@ -179,9 +165,9 @@ class AgipdCorrection(BaseCorrection): self._update_shapes() # configurability: overriding md_additional_offset - if config.get("corrections.overrideMdAdditionalOffset"): + if config.get("corrections.relGainPc.overrideMdAdditionalOffset"): self._override_md_additional_offset = config.get( - "corrections.mdAdditionalOffset" + "corrections.relGainPc.mdAdditionalOffset" ) else: self._override_md_additional_offset = None @@ -191,25 +177,18 @@ class AgipdCorrection(BaseCorrection): self._update_bad_pixel_selection() self.updateState(State.ON) - self.KARABO_SLOT(self.loadMostRecentConstants) def process_input(self, data, metadata): """Registered for dataInput, handles all processing and sending""" - if not self._schema_cache["doAnything"]: - if self._schema_cache["state"] is State.PROCESSING: - self.updateState(State.ACTIVE) - return - source = metadata.get("source") if source not in self.sources: - self.log.INFO(f"Ignoring unknown source {source}") + self.log_status_info(f"Ignoring hash with unknown source {source}") return - # TODO: what are these empty things for? if not data.has("image"): - # self.log.INFO("Ignoring hash without image node") + self.log_status_info("Ignoring hash without image node") return time_start = timeit.default_timer() @@ -217,17 +196,16 @@ class AgipdCorrection(BaseCorrection): train_id = metadata.getAttribute("timestamp", "tid") cell_table = np.squeeze(data.get("image.cellId")) - assert isinstance(cell_table, np.ndarray), "image.cellId should be ndarray" if len(cell_table.shape) == 0: - msg = "cellId had 0 dimensions. DAQ may not be sending data." - self.set("status", msg) - self.log.WARN(msg) + self.log_status_warn( + "cellId had 0 dimensions. DAQ may not be sending data." + ) return # original shape: memory_cell, data/raw_gain, x, y image_data = data.get("image.data") if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: - self.set( - "status", f"Updating input shapes based on received {image_data.shape}" + self.log_status_info( + f"Updating input shapes based on received {image_data.shape}" ) self.set("dataFormat.memoryCells", image_data.shape[0]) with self._buffer_lock: @@ -236,7 +214,7 @@ class AgipdCorrection(BaseCorrection): if not self._schema_cache["state"] is State.PROCESSING: self.updateState(State.PROCESSING) - self.set("status", "Processing data") + self.log_status_info("Processing data") correction_cell_num = self._schema_cache["constantParameters.memoryCells"] do_generate_preview = ( @@ -250,13 +228,11 @@ class AgipdCorrection(BaseCorrection): cell_table_max = np.max(cell_table) # TODO: all this checking and warning can go into GPU runner if cell_table_max >= correction_cell_num: - msg = ( + self.log_status_info( f"Max cell ID ({cell_table_max}) exceeds range for loaded " f"constants ({correction_cell_num} cells). Some frames will not be " "corrected." ) - self.log.WARN(msg) - self.set("status", msg) self.gpu_runner.load_data(image_data) buffer_handle, buffer_array = self._shmem_buffer.next_slot() @@ -274,13 +250,10 @@ class AgipdCorrection(BaseCorrection): pulse_id_found = np.where(pulse_table == preview_slice_index)[0] if len(pulse_id_found) == 0: pulse_found_instead = pulse_table[0] - msg = ( - f"Pulse {preview_slice_index} not found in " - f"image.pulseId, arbitrary pulse " - f"{pulse_found_instead} will be shown." + self.log_status_info( + f"Pulse {preview_slice_index} not found, arbitrary pulse " + f"{pulse_found_instead} will be shown instead." ) - self.log.WARN(msg) - self.set("status", msg) preview_slice_index = 0 else: preview_slice_index = pulse_id_found[0] @@ -288,6 +261,9 @@ class AgipdCorrection(BaseCorrection): preview_slice_index ) + if self._schema_cache["sendGainMap"]: + data.set("image.gainMap", ImageData(self.gpu_runner.get_gain_map())) + data.set("image.data", buffer_handle) data.set("image.cellId", cell_table[:, np.newaxis]) data.set("image.pulseId", pulse_table[:, np.newaxis]) @@ -304,55 +280,38 @@ class AgipdCorrection(BaseCorrection): time_spent = timeit.default_timer() - time_start self._processing_time_ema.update(time_spent) - def loadMostRecentConstants(self): - self.flush_constants() - self.calibration_constant_manager.flush_constants() - for constant in AgipdConstants: - self.calibration_constant_manager.get_constant_version_and_call_me_back( - constant, self._load_constant_to_gpu - ) - def _load_constant_to_gpu(self, constant, constant_data): # TODO: encode correction / constant dependencies in a clever way if constant is AgipdConstants.ThresholdsDark: + field_name = "thresholding" # TODO: (reverse) mapping, DRY if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN: self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode") return self.gpu_runner.load_thresholds(constant_data) - if not self.get("corrections.available.thresholding"): - self.set("corrections.available.thresholding", True) - self.set("corrections.enabled.thresholding", True) - self.set("corrections.preview.thresholding", True) elif constant is AgipdConstants.Offset: + field_name = "offset" self.gpu_runner.load_offset_map(constant_data) - if not self.get("corrections.available.offset"): - self.set("corrections.available.offset", True) - self.set("corrections.enabled.offset", True) - self.set("corrections.preview.offset", True) elif constant is AgipdConstants.SlopesPC: + field_name = "relGainPc" self.gpu_runner.load_rel_gain_pc_map(constant_data) - if not self.get("corrections.available.relGainPc"): - self.set("corrections.available.relGainPc", True) - self.set("corrections.enabled.relGainPc", True) - self.set("corrections.preview.relGainPc", True) if self._override_md_additional_offset is not None: self.gpu_runner.md_additional_offset_gpu.fill( self._override_md_additional_offset ) elif constant is AgipdConstants.SlopesFF: + field_name = "relGainXray" self.gpu_runner.load_rel_gain_ff_map(constant_data) - if not self.get("corrections.available.relGainXray"): - self.set("corrections.available.relGainXray", True) - self.set("corrections.enabled.relGainXray", True) - self.set("corrections.preview.relGainXray", True) elif "BadPixels" in constant.name: + field_name = "badPixels" self.gpu_runner.load_bad_pixels_map( constant_data, override_flags_to_use=self._override_bad_pixel_flags ) - if not self.get("corrections.available.badPixels"): - self.set("corrections.available.badPixels", True) - self.set("corrections.enabled.badPixels", True) - self.set("corrections.preview.badPixels", True) + + # switch relevant correction on if it just now became available + if not self.get(f"corrections.{field_name}.available"): + self.set(f"corrections.{field_name}.available", True) + self.set(f"corrections.{field_name}.enable", True) + self.set(f"corrections.{field_name}.preview", True) self._update_correction_flags() @@ -370,20 +329,20 @@ class AgipdCorrection(BaseCorrection): def _update_bad_pixel_selection(self): selection = 0 for field in BadPixelValues: - if self.get(f"corrections.badPixelFlagsToUse.{field.name}"): + if self.get(f"corrections.badPixels.subsetToUse.{field.name}"): selection |= field self._override_bad_pixel_flags = selection def preReconfigure(self, config): super().preReconfigure(config) if any( - path.startswith("corrections.badPixelFlagsToUse") + path.startswith("corrections.badPixels.subsetToUse") for path in config.getPaths() ): self._has_updated_bad_pixel_selection = False - if config.has("corrections.badPixelMaskValue"): + if config.has("corrections.badPixels.maskingValue"): self.bad_pixel_mask_value = eval( - config.get("corrections.badPixelMaskValue") + config.get("corrections.badPixels.maskingValue") ) self.gpu_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index b1a616aba9558b7e92da79977ee9457faaaad807..1687328cd0ebacaac0e7e1ea41ae1d09784da504 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -1,12 +1,12 @@ import timeit import numpy as np -from karabo.bound import BOOL_ELEMENT, KARABO_CLASSINFO +from karabo.bound import KARABO_CLASSINFO from karabo.common.states import State from . import utils from ._version import version as deviceVersion -from .base_correction import BaseCorrection +from .base_correction import BaseCorrection, add_correction_step_schema from .calcat_utils import DsscCalcatFriend, DsscConstants from .dssc_gpu import DsscGpuRunner, CorrectionFlags @@ -15,32 +15,16 @@ from .dssc_gpu import DsscGpuRunner, CorrectionFlags class DsscCorrection(BaseCorrection): # subclass *must* set these attributes _correction_flag_class = CorrectionFlags - _correction_slot_names = (("offset", CorrectionFlags.OFFSET),) + _correction_field_names = (("offset", CorrectionFlags.OFFSET),) _gpu_runner_class = DsscGpuRunner + _calcat_friend_class = DsscCalcatFriend + _constant_enum_class = DsscConstants @staticmethod def expectedParameters(expected): super(DsscCorrection, DsscCorrection).expectedParameters(expected) - for slot_name, _ in DsscCorrection._correction_slot_names: - ( - BOOL_ELEMENT(expected) - .key(f"corrections.available.{slot_name}") - .readOnly() - .initialValue(False) - .commit(), - BOOL_ELEMENT(expected) - .key(f"corrections.enabled.{slot_name}") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - BOOL_ELEMENT(expected) - .key(f"corrections.preview.{slot_name}") - .assignmentOptional() - .defaultValue(False) - .reconfigurable() - .commit(), - ) + DsscCalcatFriend.add_schema(expected) + add_correction_step_schema(expected, DsscCorrection._correction_field_names) @property def input_data_shape(self): @@ -59,52 +43,45 @@ class DsscCorrection(BaseCorrection): def __init__(self, config): super().__init__(config) - output_axis_order = config.get("dataFormat.outputAxisOrder") - if output_axis_order == "pixels-fast": - self._output_transpose = (0, 2, 1) - elif output_axis_order == "memorycells-fast": - self._output_transpose = (2, 1, 0) - else: - self._output_transpose = None + self._output_transpose = { + "pixels-fast": (0, 2, 1), + "memorycells-fast": (2, 1, 0), + "no-reshape": None, + }[config.get("dataFormat.outputAxisOrder")] self._update_shapes() self.updateState(State.ON) def process_input(self, data, metadata): """Registered for dataInput, handles all processing and sending""" - if not self._schema_cache["doAnything"]: - if self._schema_cache["state"] is State.PROCESSING: - self.updateState(State.ACTIVE) - return - # TODO: compare KARABO_ON_INPUT (old) against KARABO_ON_DATA (current) source = metadata.get("source") if source not in self.sources: - self.log.INFO(f"Ignoring unknown source {source}") + self.log_status_info(f"Ignoring hash with unknown source {source}") return if not data.has("image"): - self.log.INFO("Ignoring hash without image node") + self.log_status_info("Ignoring hash without image node") return time_start = timeit.default_timer() + self._last_processing_started = time_start train_id = metadata.getAttribute("timestamp", "tid") cell_table = np.squeeze(data.get("image.cellId")) assert isinstance(cell_table, np.ndarray), "image.cellId should be ndarray" if len(cell_table.shape) == 0: - msg = "cellId had 0 dimensions. DAQ may not be sending data." - self.set("status", msg) - self.log.WARN(msg) + self.log_status_warn( + "cellId had 0 dimensions. DAQ may not be sending data." + ) return # original shape: 400, 1, 128, 512 (memory cells, something, y, x) image_data = data.get("image.data") if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]: - self.set( - "status", f"Updating input shapes based on received {image_data.shape}" + self.log_status_info( + f"Updating input shapes based on received {image_data.shape}" ) - # TODO: truncate if > 800 self.set("dataFormat.memoryCells", image_data.shape[0]) with self._buffer_lock: # TODO: pulse filter update after reimplementation @@ -112,9 +89,9 @@ class DsscCorrection(BaseCorrection): if not self._schema_cache["state"] is State.PROCESSING: self.updateState(State.PROCESSING) - self.set("status", "Processing data") + self.log_status_info("Processing data") - correction_cell_num = self._schema_cache["dataFormat.constantMemoryCells"] + correction_cell_num = self._schema_cache["constantParameters.memoryCells"] do_generate_preview = ( train_id % self._schema_cache["preview.trainIdModulo"] == 0 and self._schema_cache["preview.enable"] @@ -123,16 +100,13 @@ class DsscCorrection(BaseCorrection): with self._buffer_lock: # cell_table = cell_table[self.pulse_filter] pulse_table = np.squeeze(data.get("image.pulseId")) # [self.pulse_filter] - cell_table_max = np.max(cell_table) if cell_table_max >= correction_cell_num: - msg = ( + self.log_status_info( f"Max cell ID ({cell_table_max}) exceeds range for loaded " f"constant ({correction_cell_num} cells). Some frames will not be " "corrected." ) - self.log.WARN(msg) - self.set("status", msg) self.gpu_runner.load_data(image_data) buffer_handle, buffer_array = self._shmem_buffer.next_slot() @@ -189,10 +163,12 @@ class DsscCorrection(BaseCorrection): assert np.max(new_filter) < self.get("dataFormat.memoryCells") self.pulse_filter = new_filter - def _load_constant_to_gpu(self, constant_name, constant_data): - assert constant_name == "Offset" + def _load_constant_to_gpu(self, constant, constant_data): + assert constant is DsscConstants.Offset self.gpu_runner.load_offset_map(constant_data) - if not self.get("corrections.available.offset"): - self.set("corrections.available.offset", True) - self.set("corrections.enabled.offset", True) - self.set("corrections.preview.offset", True) + if not self.get("corrections.offset.available"): + self.set("corrections.offset.available", True) + self.set("corrections.offset.enable", True) + self.set("corrections.offset.preview", True) + + self._update_correction_flags() diff --git a/src/calng/agipd_gpu.py b/src/calng/agipd_gpu.py index b291401344d7f9daba2d8cb8e26a15db5c82de57..c1d63382b4a0ad4fbfbbe9f3f01491be8c442eb3 100644 --- a/src/calng/agipd_gpu.py +++ b/src/calng/agipd_gpu.py @@ -81,8 +81,8 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): return self.processed_data_gpu def load_thresholds(self, threshold_map): - # shape: y, x, memory cell, threshold 0 / threshold 1 / 3 gain values - # TODO: do we need the gain values (in the constant) for anything? + # shape: y, x, memory cell, thresholds and gain values + # note: the gain values are something like means used to derive thresholds self.gain_thresholds_gpu.set( np.transpose(threshold_map[..., :2], (2, 1, 0, 3)).astype(np.float32) ) @@ -102,6 +102,7 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): hg_intercept = slopes_pc_map[1] mg_slope = slopes_pc_map[3] mg_intercept = slopes_pc_map[4] + # TODO: remove sanitization (should happen in constant preparation notebook) # from agipdlib.py: replace NaN with median (per memory cell) # note: suffixes in agipdlib are "_m" and "_l", should probably be "_I" for naughty_array in (hg_slope, hg_intercept, mg_slope, mg_intercept): @@ -110,7 +111,6 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): nan_cell, _, _ = np.where(nan_bool) naughty_array[nan_bool] = medians[nan_cell] - # TODO: verify that this clamping should be done (as in agipdlib.py) too_low_bool = naughty_array < 0.8 * medians[:, np.newaxis, np.newaxis] too_low_cell, _, _ = np.where(too_low_bool) naughty_array[too_low_bool] = medians[too_low_cell] @@ -155,14 +155,11 @@ class AgipdGpuRunner(base_gpu.BaseGpuRunner): # TODO: maybe extend in case constant has too few memory cells # TODO: maybe divide by something because new constants are absolute ... - # TODO: maybe clamp and replace NaNs like slopes_pc self.rel_gain_xray_map_gpu.set(np.transpose(slopes_ff_map).astype(np.float32)) def load_bad_pixels_map(self, bad_pixels_map, override_flags_to_use=None): print(f"Loading bad pixels with shape: {bad_pixels_map.shape}") # will simply OR with already loaded, does not take into account which ones - # TODO: allow configuring subset of bad pixels to care about - # TODO: allow configuring value for masked pixels # TODO: inquire what "mask for double size pixels" means if len(bad_pixels_map.shape) == 3: if bad_pixels_map.shape == ( diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 3f12c65744e04559907c719e85a5e99a9fdb9777..8d22d024830c05f742a4c00df5851a1b6126fe0f 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -1,3 +1,4 @@ +import pathlib import threading import timeit @@ -43,7 +44,7 @@ class BaseCorrection(PythonDevice): _correction_flag_class = None # subclass must override this with some enum class _gpu_runner_class = None # subclass must set this _gpu_runner_init_args = {} # subclass can set this (TODO: remove, design better) - _schema_cache_slots = { + _schema_cache_fields = { "doAnything", "dataFormat.memoryCells", "constantParameters.memoryCells", @@ -70,18 +71,6 @@ class BaseCorrection(PythonDevice): @staticmethod def expectedParameters(expected): ( - BOOL_ELEMENT(expected) - .key("doAnything") - .displayedName("Enable input processing") - .description( - "Toggle handling of input (at all). If False, the input handler of " - "this device will be skipped. Useful to decrease logspam if device is " - "misconfigured." - ) - .assignmentOptional() - .defaultValue(True) - .reconfigurable() - .commit(), INPUT_CHANNEL(expected).key("dataInput").commit(), # note: output schema not set, will be updated to match data later OUTPUT_CHANNEL(expected).key("dataOutput").commit(), @@ -313,27 +302,7 @@ class BaseCorrection(PythonDevice): ( NODE_ELEMENT(expected) .key("corrections") - .displayedName("Corrections to apply") - .commit(), - NODE_ELEMENT(expected) - .key("corrections.available") - .displayedName("Available given loaded constants") - .description( - "Corrections typically require some constants to be loaded. These " - "flags indicate which corrections can be done, given the constants " - "found and loaded so far. Corrections - for dataOutput or preview - " - "only happen if selected AND available." - ) - .commit(), - NODE_ELEMENT(expected) - .key("corrections.enabled") - .displayedName("Enabled (if available)") - .description("Corrections applied to data for dataOutput (main output)") - .commit(), - NODE_ELEMENT(expected) - .key("corrections.preview") - .displayedName("Preview (if available)") - .description("Corrections applied for corrected preview output") + .displayedName("Correction steps") .commit(), BOOL_ELEMENT(expected) .key("corrections.disableAll") @@ -355,7 +324,7 @@ class BaseCorrection(PythonDevice): def __init__(self, config): self._schema_cache = { - k: config.get(k) for k in self._schema_cache_slots if config.has(k) + k: config.get(k) for k in self._schema_cache_fields if config.has(k) } super().__init__(config) @@ -368,9 +337,12 @@ class BaseCorrection(PythonDevice): self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype")) self.gpu_runner = None # must call _update_shapes() in subclass init + self.calcat_friend = self._calcat_friend_class( + self, pathlib.Path.cwd() / "calibration-client-secrets.json" + ) + self._correction_flag_enabled = self._correction_flag_class.NONE self._correction_flag_preview = self._correction_flag_class.NONE - self._cached_constants = {} self._shmem_buffer = None self._has_set_output_schema = False @@ -392,11 +364,12 @@ class BaseCorrection(PythonDevice): callback=self._update_rate_and_state, ) self._buffer_lock = threading.Lock() + self.KARABO_SLOT(self.loadMostRecentConstants) self.KARABO_SLOT(self.requestScene) def preReconfigure(self, config): for path in config.getPaths(): - if path in self._schema_cache_slots: + if path in self._schema_cache_fields: self._schema_cache[path] = config.get(path) # TODO: pulse filter (after reimplementing) @@ -412,6 +385,14 @@ class BaseCorrection(PythonDevice): # will make postReconfigure handle shape update after merging schema self._has_updated_shapes = False + def log_status_info(self, msg): + self.log.INFO(msg) + self.set("status", msg) + + def log_status_warn(self, msg): + self.log.WARN(msg) + self.set("status", msg) + def postReconfigure(self): if not self._has_updated_shapes: self._update_shapes() @@ -422,15 +403,22 @@ class BaseCorrection(PythonDevice): """Wrapper around PythonDevice.set to enable caching "hot" schema elements""" if len(args) == 2: key, value = args - if key in self._schema_cache_slots: + if key in self._schema_cache_fields: self._schema_cache[key] = value super().set(*args) + def loadMostRecentConstants(self): + self.flush_constants() + self.calcat_friend.flush_constants() + for constant in self._constant_enum_class: + self.calcat_friend.get_constant_version_and_call_me_back( + constant, self._load_constant_to_gpu + ) + def flush_constants(self): """Override from CalibrationReceiverBaseDevice to also flush GPU buffers""" - # TODO: update when revamping constant retrieval - for correction_step, _ in self._correction_slot_names: - self.set(f"corrections.available.{correction_step}", False) + for correction_step, _ in self._correction_field_names: + self.set(f"corrections.{correction_step}.available", False) self.gpu_runner.flush_buffers() self._update_correction_flags() @@ -498,12 +486,12 @@ class BaseCorrection(PythonDevice): available = self._correction_flag_class.NONE enabled = self._correction_flag_class.NONE preview = self._correction_flag_class.NONE - for slot_name, flag in self._correction_slot_names: - if self.get(f"corrections.available.{slot_name}"): + for field_name, flag in self._correction_field_names: + if self.get(f"corrections.{field_name}.available"): available |= flag - if self.get(f"corrections.enabled.{slot_name}"): + if self.get(f"corrections.{field_name}.enable"): enabled |= flag - if self.get(f"corrections.preview.{slot_name}"): + if self.get(f"corrections.{field_name}.preview"): preview |= flag enabled &= available preview &= available @@ -565,10 +553,13 @@ class BaseCorrection(PythonDevice): **self._gpu_runner_init_args, ) - # TODO: put this under lock so dictionary doesn't change shape underneath us - for constant_name, constant_data in self._cached_constants.items(): - self.log.INFO(f"Reload constant {constant_name}") - self._load_constant_to_gpu(constant_name, constant_data) + with self._buffer_lock: + for ( + constant, + data, + ) in self.calcat_friend.cached_constants.items(): + self.log_status_info(f"Reload constant {constant}") + self._load_constant_to_gpu(constant, data) self._has_updated_shapes = True @@ -594,3 +585,45 @@ class BaseCorrection(PythonDevice): self._has_set_output_schema = False self.updateState(State.ON) self.signalEndOfStream("dataOutput") + + +def add_correction_step_schema(schema, field_flag_mapping): + for field_name, _ in field_flag_mapping: + node_name = f"corrections.{field_name}" + ( + NODE_ELEMENT(schema).key(node_name).commit(), + BOOL_ELEMENT(schema) + .key(f"{node_name}.available") + .displayedName("Available") + .description( + "This boolean indicates whether the necessary constants have been " + "loaded for this correction step to be applied. Enabling the " + "correction will have no effect unless this is True." + ) + .readOnly() + .initialValue(False) + .commit(), + BOOL_ELEMENT(schema) + .key(f"{node_name}.enable") + .displayedName("Enable") + .description( + "Controls whether to apply this correction step for main data " + "output - subject to availability. When constants are first loaded " + "and availability toggles on, this, too, will toggle on by default." + ) + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + BOOL_ELEMENT(schema) + .key(f"{node_name}.preview") + .displayedName("Preview") + .description( + "Whether to apply this correction step for corrected preview " + "output. Notes in description of 'Enable' apply here, too." + ) + .assignmentOptional() + .defaultValue(False) + .reconfigurable() + .commit(), + ) diff --git a/src/calng/calcat_utils.py b/src/calng/calcat_utils.py index 9f002407cb6cad68416727a8092d74ede14a89f8..2943b05f6d29c854f377113fef12f32838e0465b 100644 --- a/src/calng/calcat_utils.py +++ b/src/calng/calcat_utils.py @@ -218,6 +218,7 @@ class BaseCalcatFriend: self.device = device self.param_prefix = param_prefix self.status_prefix = status_prefix + self.cached_constants = {} if not secrets_fn.is_file(): self.device.log.WARN(f"Missing CalCat secrets file (expected {secrets_fn})") @@ -359,6 +360,7 @@ class BaseCalcatFriend: # TODO: handle FileNotFoundError if we are led astray with h5py.File(file_path, "r") as fd: constant_data = np.array(fd[resp["data"]["data_set_name"]]["data"]) + self.cached_constants[constant] = constant_data self._set_status(f"{constant.name}.found", True) return constant_data diff --git a/src/tests/test_calcat_utils.py b/src/tests/test_calcat_utils.py index c0c69fbe6191c730b6aa633ebae61670b6d69ee9..61beb37b0ecd497bdb5e05c4cea3f4570ef9214f 100644 --- a/src/tests/test_calcat_utils.py +++ b/src/tests/test_calcat_utils.py @@ -51,6 +51,9 @@ class DummyAgipdDevice: def get(self, key): return self.schema.get(key) + def set(self, key, value): + print(f'Would set "{key}" = {value}') + DummyAgipdDevice.expectedParameters(DummyAgipdDevice.device_class_schema) @@ -75,6 +78,9 @@ class DummyDsscDevice: def get(self, key): return self.schema.get(key) + def set(self, key, value): + print(f'Would set "{key}" = {value}') + DummyDsscDevice.expectedParameters(DummyDsscDevice.device_class_schema) @@ -96,7 +102,7 @@ def test_agipd_constants_and_caching_and_async(): def backcall(constant_name, metadata_and_data): # TODO: think of something reasonable to check - timestamp, data = metadata_and_data + data = metadata_and_data assert data.nbytes > 1000 with Stopwatch() as timer_async_cold: @@ -122,10 +128,10 @@ def test_agipd_constants_and_caching_and_async(): with Stopwatch() as timer_sync_warm: for constant in calcat_utils.AgipdConstants: - ts, ary = device.calibration_constant_manager.get_constant_version( + data = device.calibration_constant_manager.get_constant_version( constant, ) - assert ts is not None, "Some constants should be found" + assert data.nbytes > 1000, "Should find some constant data" print(f"Cold async took {timer_async_cold.elapsed} s") print(f"Warm async took {timer_async_warm.elapsed} s") @@ -149,6 +155,6 @@ def test_dssc_constants(): # conf["constantParameters.acquisitionRate"] = 4.5 # conf["constantParameters.encodedGain"] = 67328 device = DummyDsscDevice(conf) - ts, offset_map = device.calibration_constant_manager.get_constant_version("Offset") + offset_map = device.calibration_constant_manager.get_constant_version("Offset") - assert ts is not None + assert offset_map is not None