diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py index badc2cee42b9da08e85cb82e3037846811a1cb04..14012742a24d9870efc0d368425b23e3b01eadca 100644 --- a/src/calng/AgipdCorrection.py +++ b/src/calng/AgipdCorrection.py @@ -10,7 +10,7 @@ from karabo.bound import ( ) from karabo.common.states import State -from . import shmem_utils, utils +from . import utils from ._version import version as deviceVersion from .agipd_gpu import AgipdGainMode, AgipdGpuRunner, BadPixelValues, CorrectionFlags from .base_correction import BaseCorrection, add_correction_step_schema, preview_schema @@ -28,7 +28,7 @@ class AgipdCorrection(BaseCorrection): ("relGainXray", CorrectionFlags.REL_GAIN_XRAY), ("badPixels", CorrectionFlags.BPMASK), ) - _gpu_runner_class = AgipdGpuRunner + _kernel_runner_class = AgipdGpuRunner _calcat_friend_class = AgipdCalcatFriend _constant_enum_class = AgipdConstants @@ -181,12 +181,12 @@ class AgipdCorrection(BaseCorrection): def __init__(self, config): super().__init__(config) - # TODO: different gpu runner for fixed gain mode + # TODO: consider different gpu runner for fixed gain mode self.gain_mode = AgipdGainMode[config.get("gainMode")] self.bad_pixel_mask_value = eval( config.get("corrections.badPixels.maskingValue") ) - self._gpu_runner_init_args = { + self._kernel_runner_init_args = { "gain_mode": self.gain_mode, "bad_pixel_mask_value": self.bad_pixel_mask_value, "g_gain_value": config.get("corrections.relGainXray.gGainValue"), @@ -235,7 +235,7 @@ class AgipdCorrection(BaseCorrection): return try: - self.gpu_runner.load_data(image_data) + self.kernel_runner.load_data(image_data) except ValueError as e: self.log_status_warn(f"Failed to load data: {e}") return @@ -243,16 +243,16 @@ class AgipdCorrection(BaseCorrection): self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.gpu_runner.load_cell_table(cell_table) - self.gpu_runner.correct(self._correction_flag_enabled) - self.gpu_runner.reshape( + self.kernel_runner.load_cell_table(cell_table) + self.kernel_runner.correct(self._correction_flag_enabled) + self.kernel_runner.reshape( output_order=self._schema_cache["dataFormat.outputAxisOrder"], out=buffer_array, ) # after reshape, data for dataOutput is now safe in its own buffer if do_generate_preview: if self._correction_flag_enabled != self._correction_flag_preview: - self.gpu_runner.correct(self._correction_flag_preview) + self.kernel_runner.correct(self._correction_flag_preview) ( preview_slice_index, preview_cell, @@ -269,7 +269,7 @@ class AgipdCorrection(BaseCorrection): preview_corrected, preview_raw_gain, preview_gain_map, - ) = self.gpu_runner.compute_previews(preview_slice_index) + ) = self.kernel_runner.compute_previews(preview_slice_index) # reusing input data hash for sending data_hash.set("image.data", buffer_handle) @@ -291,30 +291,30 @@ class AgipdCorrection(BaseCorrection): source, ) - def _load_constant_to_gpu(self, constant, constant_data): + def _load_constant_to_runner(self, constant, constant_data): # TODO: encode correction / constant dependencies in a clever way if constant is AgipdConstants.ThresholdsDark: field_name = "thresholding" # TODO: (reverse) mapping, DRY if self.gain_mode is not AgipdGainMode.ADAPTIVE_GAIN: self.log.INFO("Loaded ThresholdsDark ignored due to fixed gain mode") return - self.gpu_runner.load_thresholds(constant_data) + self.kernel_runner.load_thresholds(constant_data) elif constant is AgipdConstants.Offset: field_name = "offset" - self.gpu_runner.load_offset_map(constant_data) + self.kernel_runner.load_offset_map(constant_data) elif constant is AgipdConstants.SlopesPC: field_name = "relGainPc" - self.gpu_runner.load_rel_gain_pc_map(constant_data) + self.kernel_runner.load_rel_gain_pc_map(constant_data) if self._override_md_additional_offset is not None: - self.gpu_runner.md_additional_offset_gpu.fill( + self.kernel_runner.md_additional_offset_gpu.fill( self._override_md_additional_offset ) elif constant is AgipdConstants.SlopesFF: field_name = "relGainXray" - self.gpu_runner.load_rel_gain_ff_map(constant_data) + self.kernel_runner.load_rel_gain_ff_map(constant_data) elif "BadPixels" in constant.name: field_name = "badPixels" - self.gpu_runner.load_bad_pixels_map( + self.kernel_runner.load_bad_pixels_map( constant_data, override_flags_to_use=self._override_bad_pixel_flags ) @@ -340,7 +340,7 @@ class AgipdCorrection(BaseCorrection): self._override_md_additional_offset = self.get( "corrections.relGainPc.mdAdditionalOffset" ) - self.gpu_runner.override_md_additional_offset( + self.kernel_runner.override_md_additional_offset( self._override_md_additional_offset ) else: @@ -352,10 +352,10 @@ class AgipdCorrection(BaseCorrection): update = self._prereconfigure_update_hash if update.has("corrections.relGainXray.gGainValue"): - self.gpu_runner.set_g_gain_value( + self.kernel_runner.set_g_gain_value( self.get("corrections.relGainXray.gGainValue") ) - self._gpu_runner_init_args["g_gain_value"] = self.get( + self._kernel_runner_init_args["g_gain_value"] = self.get( "corrections.relGainXray.gGainValue" ) @@ -363,8 +363,8 @@ class AgipdCorrection(BaseCorrection): self.bad_pixel_mask_value = eval( self.get("corrections.badPixels.maskingValue") ) - self.gpu_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) - self._gpu_runner_init_args[ + self.kernel_runner.set_bad_pixel_mask_value(self.bad_pixel_mask_value) + self._kernel_runner_init_args[ "bad_pixel_mask_value" ] = self.bad_pixel_mask_value @@ -388,8 +388,8 @@ class AgipdCorrection(BaseCorrection): data, ) in self.calcat_friend.cached_constants.items(): if "BadPixels" in constant.name: - self._load_constant_to_gpu(constant, data) + self._load_constant_to_runner(constant, data) self._update_bad_pixel_selection() - self.gpu_runner.override_bad_pixel_flags_to_use( + self.kernel_runner.override_bad_pixel_flags_to_use( self._override_bad_pixel_flags ) diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py index 30c99f7eaebf0923a589bf6cbbe2353638e97243..8698cf68ecf1bc9b9c9aab457c32ef7a2588e370 100644 --- a/src/calng/DsscCorrection.py +++ b/src/calng/DsscCorrection.py @@ -14,7 +14,7 @@ class DsscCorrection(BaseCorrection): # subclass *must* set these attributes _correction_flag_class = CorrectionFlags _correction_field_names = (("offset", CorrectionFlags.OFFSET),) - _gpu_runner_class = DsscGpuRunner + _kernel_runner_class = DsscGpuRunner _calcat_friend_class = DsscCalcatFriend _constant_enum_class = DsscConstants _managed_keys = BaseCorrection._managed_keys.copy() @@ -74,7 +74,7 @@ class DsscCorrection(BaseCorrection): return try: - self.gpu_runner.load_data(image_data) + self.kernel_runner.load_data(image_data) except ValueError as e: self.log_status_warn(f"Failed to load data: {e}") return @@ -82,15 +82,15 @@ class DsscCorrection(BaseCorrection): self.log_status_warn(f"Unknown exception when loading data to GPU: {e}") buffer_handle, buffer_array = self._shmem_buffer.next_slot() - self.gpu_runner.load_cell_table(cell_table) - self.gpu_runner.correct(self._correction_flag_enabled) - self.gpu_runner.reshape( + self.kernel_runner.load_cell_table(cell_table) + self.kernel_runner.correct(self._correction_flag_enabled) + self.kernel_runner.reshape( output_order=self._schema_cache["dataFormat.outputAxisOrder"], out=buffer_array, ) if do_generate_preview: if self._correction_flag_enabled != self._correction_flag_preview: - self.gpu_runner.correct(self._correction_flag_preview) + self.kernel_runner.correct(self._correction_flag_preview) ( preview_slice_index, preview_cell, @@ -102,7 +102,7 @@ class DsscCorrection(BaseCorrection): pulse_table, warn_func=self.log_status_warn, ) - preview_raw, preview_corrected = self.gpu_runner.compute_previews( + preview_raw, preview_corrected = self.kernel_runner.compute_previews( preview_slice_index, ) @@ -121,9 +121,9 @@ class DsscCorrection(BaseCorrection): source, ) - def _load_constant_to_gpu(self, constant, constant_data): + def _load_constant_to_runner(self, constant, constant_data): assert constant is DsscConstants.Offset - self.gpu_runner.load_offset_map(constant_data) + self.kernel_runner.load_offset_map(constant_data) if not self.get("corrections.offset.available"): self.set("corrections.offset.available", True) diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index 49279fde4806810c1ab059a7c7eb4aba5f3a4f36..7e2386643eb1450d58d8870cd67d6961468bd6f0 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -52,12 +52,99 @@ preview_schema = Schema() .commit(), ) +# TODO: trim output schema / adapt to specific detectors +# currently: based on snapshot of actual output reusing AGIPD hash +output_schema = Schema() +( + NODE_ELEMENT(output_schema).key("image").commit(), + STRING_ELEMENT(output_schema) + .key("image.data") + .assignmentOptional() + .defaultValue("") + .commit(), + NDARRAY_ELEMENT(output_schema).key("image.length").dtype("UINT32").commit(), + NDARRAY_ELEMENT(output_schema).key("image.cellId").dtype("UINT16").commit(), + NDARRAY_ELEMENT(output_schema).key("image.pulseId").dtype("UINT64").commit(), + NDARRAY_ELEMENT(output_schema).key("image.status").commit(), + NDARRAY_ELEMENT(output_schema).key("image.trainId").dtype("UINT64").commit(), + VECTOR_STRING_ELEMENT(output_schema) + .key("calngShmemPaths") + .assignmentOptional() + .defaultValue(["image.data"]) + .commit(), + NODE_ELEMENT(output_schema).key("metadata").commit(), + STRING_ELEMENT(output_schema) + .key("metadata.source") + .assignmentOptional() + .defaultValue("") + .commit(), + NODE_ELEMENT(output_schema).key("metadata.timestamp").commit(), + INT32_ELEMENT(output_schema) + .key("metadata.timestamp.tid") + .assignmentOptional() + .defaultValue(0) + .commit(), + NODE_ELEMENT(output_schema).key("header").commit(), + INT32_ELEMENT(output_schema) + .key("header.minorTrainFormatVersion") + .assignmentOptional() + .defaultValue(0) + .commit(), + INT32_ELEMENT(output_schema) + .key("header.majorTrainFormatVersion") + .assignmentOptional() + .defaultValue(0) + .commit(), + INT32_ELEMENT(output_schema) + .key("header.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + INT64_ELEMENT(output_schema) + .key("header.linkId") + .assignmentOptional() + .defaultValue(0) + .commit(), + INT64_ELEMENT(output_schema) + .key("header.dataId") + .assignmentOptional() + .defaultValue(0) + .commit(), + INT64_ELEMENT(output_schema) + .key("header.pulseCount") + .assignmentOptional() + .defaultValue(0) + .commit(), + NDARRAY_ELEMENT(output_schema).key("header.reserved").commit(), + NDARRAY_ELEMENT(output_schema).key("header.magicNumberBegin").commit(), + NODE_ELEMENT(output_schema).key("detector").commit(), + INT32_ELEMENT(output_schema) + .key("detector.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), + NDARRAY_ELEMENT(output_schema).key("detector.data").commit(), + NODE_ELEMENT(output_schema).key("trailer").commit(), + NDARRAY_ELEMENT(output_schema).key("trailer.checksum").commit(), + NDARRAY_ELEMENT(output_schema).key("trailer.magicNumberEnd").commit(), + INT32_ELEMENT(output_schema) + .key("trailer.status") + .assignmentOptional() + .defaultValue(0) + .commit(), + INT32_ELEMENT(output_schema) + .key("trailer.trainId") + .assignmentOptional() + .defaultValue(0) + .commit(), +) + @KARABO_CLASSINFO("BaseCorrection", deviceVersion) class BaseCorrection(PythonDevice): - _correction_flag_class = None # subclass must set to some enum class - _gpu_runner_class = None # subclass must set this - _gpu_runner_init_args = {} # subclass can set this (TODO: remove, design better) + _correction_flag_class = None # subclass must set (ex.: dssc_gpu.CorrectionFlags) + _kernel_runner_class = None # subclass must set (ex.: dssc_gpu.DsscGpuRunner) + _kernel_runner_init_args = {} # optional extra args for runner _managed_keys = { "outputShmemBufferSize", "dataFormat.outputAxisOrder", @@ -69,7 +156,7 @@ class BaseCorrection(PythonDevice): "preview.selectionMode", "preview.trainIdModulo", "loadMostRecentConstants", - } # subclass must extend this and put it in schema + } # subclass can extend this, /must/ put it in schema as managedKeys _schema_cache_fields = { "doAnything", "constantParameters.memoryCells", @@ -87,15 +174,21 @@ class BaseCorrection(PythonDevice): "state", } # subclass should be aware of cache, but does not need to extend - def _load_constant_to_gpu(constant_name, constant_data): + def _load_constant_to_runner(constant_name, constant_data): + """Subclass must define how to process constants into correction maps and store + into appropriate buffers in (GPU or main) memory.""" raise NotImplementedError() @property def input_data_shape(self): + """Subclass must define expected input data shape in terms of dataFormat.{ + memoryCells,pixelsX,pixelsY} and any other axes.""" raise NotImplementedError() @property def output_data_shape(self): + """Shape of corrected image data sent on dataOutput. Depends on data format + parameters pixels x / y, and number of cells (optionally after frame filter).""" axis_lengths = { "x": self._schema_cache["dataFormat.pixelsX"], "y": self._schema_cache["dataFormat.pixelsY"], @@ -106,109 +199,36 @@ class BaseCorrection(PythonDevice): for axis in self._schema_cache["dataFormat.outputAxisOrder"] ) + def process_data( + self, + data_hash, + metadata, + source, + train_id, + image_data, + cell_table, + do_generate_preview, + ): + """Subclass must define data processing (presumably using the kernel runner). + Will be called by input_handler, which will take care of some common checks " + "and extracting the parameters given to process_data.""" + raise NotImplementedError() + @staticmethod def expectedParameters(expected): - output_schema = Schema() - ( - NODE_ELEMENT(output_schema).key("image").commit(), - STRING_ELEMENT(output_schema) - .key("image.data") - .assignmentOptional() - .defaultValue("") - .commit(), - NDARRAY_ELEMENT(output_schema).key("image.length").dtype("UINT32").commit(), - NDARRAY_ELEMENT(output_schema).key("image.cellId").dtype("UINT16").commit(), - NDARRAY_ELEMENT(output_schema) - .key("image.pulseId") - .dtype("UINT64") - .commit(), - NDARRAY_ELEMENT(output_schema).key("image.status").commit(), - NDARRAY_ELEMENT(output_schema) - .key("image.trainId") - .dtype("UINT64") - .commit(), - VECTOR_STRING_ELEMENT(output_schema) - .key("calngShmemPaths") - .assignmentOptional() - .defaultValue(["image.data"]) - .commit(), - NODE_ELEMENT(output_schema).key("metadata").commit(), - STRING_ELEMENT(output_schema) - .key("metadata.source") - .assignmentOptional() - .defaultValue("") - .commit(), - NODE_ELEMENT(output_schema).key("metadata.timestamp").commit(), - INT32_ELEMENT(output_schema) - .key("metadata.timestamp.tid") - .assignmentOptional() - .defaultValue(0) - .commit(), - NODE_ELEMENT(output_schema).key("header").commit(), - INT32_ELEMENT(output_schema) - .key("header.minorTrainFormatVersion") - .assignmentOptional() - .defaultValue(0) - .commit(), - INT32_ELEMENT(output_schema) - .key("header.majorTrainFormatVersion") - .assignmentOptional() - .defaultValue(0) - .commit(), - INT32_ELEMENT(output_schema) - .key("header.trainId") - .assignmentOptional() - .defaultValue(0) - .commit(), - INT64_ELEMENT(output_schema) - .key("header.linkId") - .assignmentOptional() - .defaultValue(0) - .commit(), - INT64_ELEMENT(output_schema) - .key("header.dataId") - .assignmentOptional() - .defaultValue(0) - .commit(), - INT64_ELEMENT(output_schema) - .key("header.pulseCount") - .assignmentOptional() - .defaultValue(0) - .commit(), - NDARRAY_ELEMENT(output_schema).key("header.reserved").commit(), - NDARRAY_ELEMENT(output_schema).key("header.magicNumberBegin").commit(), - NODE_ELEMENT(output_schema).key("detector").commit(), - INT32_ELEMENT(output_schema) - .key("detector.trainId") - .assignmentOptional() - .defaultValue(0) - .commit(), - NDARRAY_ELEMENT(output_schema).key("detector.data").commit(), - NODE_ELEMENT(output_schema).key("trailer").commit(), - NDARRAY_ELEMENT(output_schema).key("trailer.checksum").commit(), - NDARRAY_ELEMENT(output_schema).key("trailer.magicNumberEnd").commit(), - INT32_ELEMENT(output_schema) - .key("trailer.status") - .assignmentOptional() - .defaultValue(0) - .commit(), - INT32_ELEMENT(output_schema) - .key("trailer.trainId") - .assignmentOptional() - .defaultValue(0) - .commit(), - ) - (OUTPUT_CHANNEL(expected).key("dataOutput").dataSchema(output_schema).commit(),) - ( INPUT_CHANNEL(expected).key("dataInput").commit(), - # note: output schema not set, will be updated to match data later + OUTPUT_CHANNEL(expected) + .key("dataOutput") + .dataSchema(output_schema) + .commit(), VECTOR_STRING_ELEMENT(expected) .key("fastSources") .displayedName("Fast data sources") .description( "Sources to get data from. Only incoming hashes from these sources " - "will be processed." + "will be processed. This will typically be a single entry of the form: " + "'[instrument]_DET_[detector]/DET/[channel]:xtdf'." ) .assignmentOptional() .defaultValue([]) @@ -218,12 +238,12 @@ class BaseCorrection(PythonDevice): .displayedName("Frame filter") .description( "The frame filter - if set - slices the input data. Frames not in the " - "filter will be discarded before any processing happens (will not get " - "to dataOutput or preview. Note that this filter goes by frame index " - "rather than cell ID or pulse ID; set accordingly (and only if you " - "know what you are doing as it can break everything). Filter is " - "evaluated into numpy uint16 array; a valid filter could be " - "'np.arange(0, 352, 2)'." + "filter will be discarded before any processing happens and will not " + "get to dataOutput or preview. Note that this filter goes by frame " + "index rather than cell ID or pulse ID; set accordingly. Handle with " + "care - an invalid filter can prevent all processing. The filter is " + "specified as a string which is evaluated into numpy uint16 array. A " + "valid filter could for eaxmple be 'np.arange(0, 352, 2)'." ) .assignmentOptional() .defaultValue("") @@ -237,7 +257,7 @@ class BaseCorrection(PythonDevice): .description( "Corrected trains are written to shared memory locations. These are " "pre-allocated and re-used (circular buffer). This parameter " - "determines how much memory to set aside for the buffer." + "determines how much memory to set aside for that buffer." ) .assignmentOptional() .defaultValue(10) @@ -262,8 +282,8 @@ class BaseCorrection(PythonDevice): "The shape of the image data ndarray as received from the " "DataAggregator is sometimes wrong - the axes are actually in a " "different order than the ndarray shape suggests. If this flag is on, " - "the shape of the ndarray will be overridden with the axis order we " - "expect." + "the shape of the ndarray will be overridden with the axis order which " + "was expected." ) .assignmentOptional() .defaultValue(True) @@ -281,15 +301,17 @@ class BaseCorrection(PythonDevice): .key("dataFormat.outputImageDtype") .displayedName("Output image data dtype") .description( - "The (numpy) dtype to use for outgoing image data. Input is " - "cast to float32, corrections are applied, and only then will " - "the result be cast back to outputImageDtype (all on GPU)." + "The (numpy) dtype to use for outgoing image data. Input is cast to " + "float32, corrections are applied, and only then will the result be " + "cast to outputImageDtype. Be aware that casting to integer type " + "causes truncation rather than rounding." ) + # TODO: consider adding rounding / binning for integer output .options("float16,float32,uint16") .assignmentOptional() .defaultValue("float32") .commit(), - # important: shape of data as going into correction + # important: determines shape of data as going into correction UINT32_ELEMENT(expected) .key("dataFormat.pixelsX") .displayedName("Pixels x") @@ -309,7 +331,7 @@ class BaseCorrection(PythonDevice): .displayedName("Memory cells") .description("Full number of memory cells in incoming data") .assignmentOptional() - .noDefaultValue() + .noDefaultValue() # subclass will want to set a default value .commit(), UINT32_ELEMENT(expected) .key("dataFormat.filteredFrames") @@ -336,7 +358,7 @@ class BaseCorrection(PythonDevice): .description( "Image data shape in incoming data (from reader / DAQ). This value is " "computed from pixelsX, pixelsY, and memoryCells - this field just " - "shows you what is currently expected." + "shows what is currently expected." ) .readOnly() .initialValue([]) @@ -346,7 +368,7 @@ class BaseCorrection(PythonDevice): .displayedName("Output data shape") .description( "Image data shape for data output from this device. This value is " - "computed from pixelsX, pixelsY, and the size of the pulse filter - " + "computed from pixelsX, pixelsY, and the size of the frame filter - " "this field just shows what is currently expected." ) .readOnly() @@ -358,6 +380,14 @@ class BaseCorrection(PythonDevice): SLOT_ELEMENT(expected) .key("loadMostRecentConstants") .displayedName("Load most recent constants") + .description( + "Calling this slot will flush all constant buffers and cause the " + "device to start querying CalCat for the most recent constants - all " + "constants applicable for this device - available with the currently " + "set constant parameters. This is typically called after " + "instantiating pipeline, after changing parameters, or after " + "generating new constants." + ) .commit() ) @@ -398,7 +428,8 @@ class BaseCorrection(PythonDevice): "The value of preview.index can be used in multiple ways, controlled " "by this value. If this is set to 'frame', preview.index is sliced " "directly from data. If 'cell' (or 'pulse') is selected, I will look " - "at cell (or pulse) table for the requested cell (or pulse ID)." + "at cell (or pulse) table for the requested cell (or pulse ID). " + "Special (stat) index values <0 are not affected by this." ) .options("frame,cell,pulse") .assignmentOptional() @@ -407,12 +438,12 @@ class BaseCorrection(PythonDevice): .commit(), UINT32_ELEMENT(expected) .key("preview.trainIdModulo") - .displayedName("Train modulo for throttling") + .displayedName("Preview train stride") .description( "Preview will only be generated for trains whose ID modulo this " - "number is zero. Higher values means fewer preview updates. Should be " - "adjusted based on input rate. Keep in mind that the GUI has limited " - "refresh rate anyway and that network is precious." + "number is zero. Higher values means less frequent preview updates. " + "Keep in mind that the GUI has limited refresh rate of 2 Hz. Should " + "take extra care if DAQ train stride is >1." ) .assignmentOptional() .defaultValue(6) @@ -449,7 +480,7 @@ class BaseCorrection(PythonDevice): .displayedName("Rate") .description( "Actual rate with which this device gets, processes, and sends trains. " - "This is a simple moving average." + "This is a simple windowed moving average." ) .unit(Unit.HERTZ) .readOnly() @@ -457,7 +488,7 @@ class BaseCorrection(PythonDevice): .commit(), ) - # this node will be filled out later + # this node will be filled out by subclass ( NODE_ELEMENT(expected) .key("corrections") @@ -484,7 +515,7 @@ class BaseCorrection(PythonDevice): self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype")) self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype")) - self.gpu_runner = None # must call _update_buffers() + self.kernel_runner = None # must call _update_buffers to initialize self.registerInitialFunction(self._update_frame_filter) self.registerInitialFunction(self._update_buffers) @@ -507,7 +538,7 @@ class BaseCorrection(PythonDevice): "performance.processingTime", 0, ) - self._last_processing_started = 0 # input handler should put timestamp + self._last_processing_started = 0 # used for processing time and timeout self._rate_update_timer = utils.RepeatingTimer( interval=1, callback=self._update_rate_and_state, @@ -522,7 +553,7 @@ class BaseCorrection(PythonDevice): def make_wrapper_capturing_constant(constant): def aux(): self.calcat_friend.get_specific_constant_version_and_call_me_back( - constant, self._load_constant_to_gpu + constant, self._load_constant_to_runner ) return aux @@ -569,7 +600,6 @@ class BaseCorrection(PythonDevice): if path in self._schema_cache_fields: self._schema_cache[path] = update.get(path) - # TODO: pulse filter (after reimplementing) if update.has("frameFilter"): with self._buffer_lock: self._update_frame_filter() @@ -593,14 +623,14 @@ class BaseCorrection(PythonDevice): self.calcat_friend.flush_constants() for constant in self._constant_enum_class: self.calcat_friend.get_constant_version_and_call_me_back( - constant, self._load_constant_to_gpu + constant, self._load_constant_to_runner ) def flush_constants(self): - """Override from CalibrationReceiverBaseDevice to also flush GPU buffers""" + """Reset constant buffers and disable corresponding correction steps""" for correction_step, _ in self._correction_field_names: self.set(f"corrections.{correction_step}.available", False) - self.gpu_runner.flush_buffers() + self.kernel_runner.flush_buffers() self._update_correction_flags() def log_status_info(self, msg): @@ -618,7 +648,8 @@ class BaseCorrection(PythonDevice): def set(self, *args): """Wrapper around PythonDevice.set to enable caching "hot" schema elements""" - if len(args) == 2: + # TODO: handle other cases of PythonDevice.set arguments + if len(args) == 2 and not isinstance(args[0], Hash): key, value = args if key in self._schema_cache_fields: self._schema_cache[key] = value @@ -653,6 +684,8 @@ class BaseCorrection(PythonDevice): self.reply(response) def _write_output(self, data, old_metadata): + """For dataOutput: reusing incoming data hash and setting source and timestamp + to be same as input""" metadata = ChannelMetaData( old_metadata.get("source"), Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")), @@ -697,6 +730,9 @@ class BaseCorrection(PythonDevice): self.log.DEBUG(f"Corrections for preview: {str(preview)}") def _update_frame_filter(self): + """Parse frameFilter string (if set) and update cached filter array. Will set + dataFormat.filteredFrames, so one will typically want to call _update_buffers + afterwards.""" filter_string = self.get("frameFilter") if filter_string.strip() == "": self._frame_filter = None @@ -710,7 +746,7 @@ class BaseCorrection(PythonDevice): self.log_status_warn("Invalid frame filter set, expect exceptions!") def _update_buffers(self): - """(Re)initialize buffers according to expected data shapes""" + """(Re)initialize buffers / kernel runner according to expected data shapes""" self.log.INFO("Updating buffers according to data shapes") # reflect the axis reordering in the expected output shape self.set("dataFormat.inputDataShape", list(self.input_data_shape)) @@ -733,36 +769,36 @@ class BaseCorrection(PythonDevice): else: self._shmem_buffer.change_shape(self.output_data_shape) - self.gpu_runner = self._gpu_runner_class( + self.kernel_runner = self._kernel_runner_class( self.get("dataFormat.pixelsX"), self.get("dataFormat.pixelsY"), self.get("dataFormat.filteredFrames"), int(self.get("constantParameters.memoryCells")), input_data_dtype=self.input_data_dtype, output_data_dtype=self.output_data_dtype, - **self._gpu_runner_init_args, + **self._kernel_runner_init_args, ) + # TODO: gracefully handle change in constantParameters.memoryCells with self.calcat_friend.cached_constants_lock: for ( constant, data, ) in self.calcat_friend.cached_constants.items(): self.log_status_info(f"Reload constant {constant}") - self._load_constant_to_gpu(constant, data) + self._load_constant_to_runner(constant, data) def input_handler(self, input_channel): """Main handler for data input: Do a few simple checks to determine whether to - even try processing. If yes, will pass data and information to subclass' - process_data function. - """ + even try processing. If yes, will pass data and information to process_data + method provided by subclass.""" # Is device even ready for this? state = self._schema_cache["state"] if state is State.ERROR: # in this case, we should have already issued warning return - elif self.gpu_runner is None: + elif self.kernel_runner is None: self.log_status_warn("Received data, but have not initialized kernels yet") return @@ -844,7 +880,8 @@ class BaseCorrection(PythonDevice): self._buffered_status_update.set( "performance.processingTime", self._processing_time_ema.get() * 1000 ) - # trainId should be set on _buffered_status_update in input handler + # trainId in _buffered_status_update should be updated in input handler + self.set(self._buffered_status_update) if default_timer() - self._last_processing_started > PROCESSING_STATE_TIMEOUT: @@ -860,6 +897,21 @@ class BaseCorrection(PythonDevice): def add_correction_step_schema(schema, managed_keys, field_flag_mapping): + """Using the fields in the provided mapping, will add nodes to schema + + field_flag_mapping is assumed to be iterable of pairs where first entry in each + pair is the name of a correction step as it will appear in device schema (second + entry - typically an enum field - is ignored). For correction step, a node and some + booleans are added to the schema and the toggleable booleans are added to + managed_keys. Subclass can customize / add additional keys under node later. + + This method should be called in expectedParameters of subclass after the same for + BaseCorrection has been called. Would be nice to include in BaseCorrection instead, + but that is tricky: static method of superclass will need _correction_field_names + of subclass or device server gets mad. A nice solution with classmethods would be + welcome. + """ + for field_name, _ in field_flag_mapping: node_name = f"corrections.{field_name}" ( diff --git a/src/calng/calcat_utils.py b/src/calng/calcat_utils.py index 57f955a595fd94a3e79dca1ff297a10cc5cee8c5..49403228e64b6279de8ade0c75c391709cc9d4c3 100644 --- a/src/calng/calcat_utils.py +++ b/src/calng/calcat_utils.py @@ -103,13 +103,6 @@ def _add_status_schema_from_enum(schema, prefix, enum_class): ) -class DetectorStandin(typing.NamedTuple): - detector_name: str - modno_to_source: dict - frames_per_train: int - module_shape: tuple - - class OperatingConditions(dict): # TODO: support deviation? def encode(self): @@ -132,6 +125,16 @@ class OperatingConditions(dict): class BaseCalcatFriend: + """Base class for CalCat friends - handles interacting with CalCat for the device + + A CalCat friend uses the device schema to build up parameters for CalCat queries. + It focuses on two nodes (added by static method add_schema): param_prefix and + status_prefix. The former is primarily used to get parameters which are (via + condition methods - see for example dark_condition of DsscCalcatFriend) used + to look for constants. The latter is primarily used to give user information + about what was found. + """ + _constant_enum_class = None # subclass should set _constants_need_conditions = None # subclass should set @@ -272,6 +275,7 @@ class BaseCalcatFriend: self.status_prefix = status_prefix self.cached_constants = {} self.cached_constants_lock = threading.Lock() + # api lock used to force queries to be sequential (SSL issue on ONC) self.api_lock = threading.Lock() if not secrets_fn.is_file():