diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..39bac723a05dee5bc1facbf5dab2f900077adc8f --- /dev/null +++ b/setup.py @@ -0,0 +1,16 @@ +from setuptools import find_packages, setup + +setup( + name="calng", + version="0.0.0", + author="CAL team", + package_dir={"": "src"}, + packages=find_packages("src"), + entry_points={ + "karabo.bound_device": [ + "DsscCombinedCorrection = DsscCorrection.combined_correction_dssc:DsscCombinedCorrection", + ], + }, + package_data={}, + requires=[], +) diff --git a/src/DsscCorrection/combined_correction_dssc.py b/src/DsscCorrection/combined_correction_dssc.py new file mode 100644 index 0000000000000000000000000000000000000000..429798a7ce0f95e3b5d948a7a6821a71029a4d20 --- /dev/null +++ b/src/DsscCorrection/combined_correction_dssc.py @@ -0,0 +1,746 @@ +import copy +import re +import threading +import timeit + +import calibrationBase +import gpu_utils +import hashToSchema +import karabo.bound as bound +import numpy as np +import pycuda.compiler +import pycuda.driver +import pycuda.gpuarray +import pycuda.tools +import utils +from karabo.common.states import State + +from .cuda_pipeline import PyCudaPipeline + + +@bound.KARABO_CLASSINFO("DsscCombinedCorrection", "0.0.0") +class DsscCombinedCorrection(calibrationBase.CalibrationReceiverBaseDevice): + _dict_cache_slots = { + "applyCorrection", + "doAnything", + "dataFormat.memoryCells", + "dataFormat.pixelsX", + "dataFormat.pixelsY", + "preview.enable", + "preview.pulse", + "preview.trainIdModulo", + "processingStateTimeout", + "performance.rateUpdateOnEachInput", + "state", + } + + @staticmethod + def expectedParameters(expected): + DsscCombinedCorrection.addConstant( + "Offset", "Dark", expected, optional=True, mandatoryForIteration=True + ) + + bound.SLOT_ELEMENT(expected).key( + "askConnectedReadersToSendMySources" + ).displayedName("Request sources from connected RunToPipe").description( + "Only relevant for development environment. When running without a " + "CAL_MANAGER, we need to tell RunToPipe instances which sources " + "to send us." + ).commit() + + bound.BOOL_ELEMENT(expected).key("doAnything").displayedName( + "Enable input processing" + ).description( + "Toggle handling of input (at all). If False, the input handler " + "of this device will be skipped. Useful to decrease logspam if " + "device is misconfigured." + ).assignmentOptional().defaultValue( + True + ).reconfigurable().commit() + + bound.BOOL_ELEMENT(expected).key("applyCorrection").displayedName( + "Enable correction(s)" + ).description( + "Toggle whether or not correction(s) are applied to image data. " + "If false, this device still reshapes data to output shape, " + "applies the pulse filter, and casts to output dtype. Useful for " + "inspecting the raw data in the same format as corrected data." + ).assignmentOptional().defaultValue( + True + ).reconfigurable().commit() + + bound.INPUT_CHANNEL(expected).key("dataInput").commit() + # note: output schema not set, will be updated to match data later + bound.OUTPUT_CHANNEL(expected).key("dataOutput").commit() + + bound.VECTOR_STRING_ELEMENT(expected).key("fastSources").displayedName( + "Fast data sources" + ).description( + "Sources to fast data as provided in channel metadata. 
" + "Provide in the form source@path.in.hash to identify both source " + "and path in the source data hash.\n" + "Currently ignores the path.in.hash part (for hardcoded image.data)" + ).assignmentMandatory().commit() + + bound.STRING_ELEMENT(expected).key("pulseFilter").displayedName( + "Pulse filter" + ).description( + "Filter pulses: will be evaluated as array of indices to keep from data. " + "Can be anything which can be turned into numpy uint16 array. " + "Numpy is available as np. " + "Take care not to include duplicates. " + "If empty, will not filter at all." + ).assignmentOptional().defaultValue( + "" + ).reconfigurable().commit() + + bound.NODE_ELEMENT(expected).key("dataFormat").displayedName( + "Data format (in/out)" + ).commit() + bound.STRING_ELEMENT(expected).key("dataFormat.inputImageDtype").displayedName( + "Input image data dtype" + ).description("The (numpy) dtype to expect for incoming image data.").options( + "uint16,float32" + ).assignmentOptional().defaultValue( + "uint16" + ).commit() + bound.STRING_ELEMENT(expected).key("dataFormat.outputImageDtype").displayedName( + "Output image data dtype" + ).description( + "The (numpy) dtype to use for outgoing image data. " + "Input is cast to float32, corrections are applied, and only then " + "will the result be cast back to outputImageDtype (all on GPU)." + ).options( + "float16,float32,uint16" + ).assignmentOptional().defaultValue( + "float32" + ).commit() + # important: shape of data as going into correction + bound.UINT32_ELEMENT(expected).key("dataFormat.pixelsX").displayedName( + "Pixels x" + ).description( + "Number of pixels of image data along X axis" + ).assignmentMandatory().commit() + bound.UINT32_ELEMENT(expected).key("dataFormat.pixelsY").displayedName( + "Pixels y" + ).description( + "Number of pixels of image data along Y axis" + ).assignmentMandatory().commit() + bound.UINT32_ELEMENT(expected).key("dataFormat.memoryCells").displayedName( + "Memory cells" + ).description( + "Full number of memory cells in incoming data" + ).assignmentMandatory().commit() + bound.VECTOR_UINT32_ELEMENT(expected).key( + "dataFormat.inputDataShape" + ).displayedName("Input data shape").description( + "Image data shape in incoming data (from reader / DAQ). " + "Value computed from pixelsX, pixelsY, and memoryCells - " + "this slot is just showing you what is currently expected." + ).readOnly().initialValue( + [] + ).commit() + bound.VECTOR_UINT32_ELEMENT(expected).key( + "dataFormat.outputDataShape" + ).displayedName("Output data shape").description( + "Image data shape for data output from this device. " + "Value computed from pixelsX, pixelsY, and the size of the pulse filter - " + "this slot is just showing what is currently expected." + ).readOnly().initialValue( + [] + ).commit() + + bound.UINT32_ELEMENT(expected).key("outputShmemBufferLength").displayedName( + "Output buffer length" + ).description( + "Corrected trains are written to shared memory locations. These are " + "pre-allocated and re-used. This parameter determines how big " + "(number of trains) the circular buffer will be." 
+ ).assignmentOptional().defaultValue( + 50 + ).commit() + + # preview schema (WIP) + bound.NODE_ELEMENT(expected).key("preview").displayedName("Preview").commit() + preview_schema = bound.Schema() + bound.NODE_ELEMENT(preview_schema).key("data").commit() + bound.NDARRAY_ELEMENT(preview_schema).key("data.adc").dtype("FLOAT").commit() + bound.OUTPUT_CHANNEL(expected).key("preview.outputRaw").dataSchema( + preview_schema + ).commit() + bound.OUTPUT_CHANNEL(expected).key("preview.outputCorrected").dataSchema( + preview_schema + ).commit() + bound.BOOL_ELEMENT(expected).key("preview.enable").displayedName( + "Enable preview data generation" + ).assignmentOptional().defaultValue(True).reconfigurable().commit() + bound.INT32_ELEMENT(expected).key("preview.pulse").displayedName( + "Pulse (or stat) for preview" + ).description( + "If this value is ≥ 0, the corresponding index from data will be " + "sliced for the preview. If this value is < 0, preview will be one " + "of the following stats:\n" + "-1: max\n" + "-2: mean\n" + "-3: sum\n" + "-4: stdev\n" + "Max means selecting the pulse with the maximum integrated value. " + "The others are computed across all filtered pulses in the train. " + "Note that index slicing (≥ 0 case) currently does not take " + "image.pulseId into account, so certain pulse filters may cause an " + "unexpected pulse to be shown in the preview." + ).assignmentOptional().defaultValue( + 0 + ).reconfigurable().commit() + bound.UINT32_ELEMENT(expected).key("preview.trainIdModulo").displayedName( + "Train modulo for throttling" + ).description( + "Preview will only be sent for trains whose ID modulo this number " + "is zero. Higher values mean fewer preview updates. Should be " + "adjusted based on input rate. Keep in mind that the GUI has " + "limited refresh rate anyway and that network is precious." + ).assignmentOptional().defaultValue( + 6 + ).reconfigurable().commit() + + # timer-related settings + bound.NODE_ELEMENT(expected).key("performance").displayedName( + "Performance measures" + ).commit() + bound.FLOAT_ELEMENT(expected).key( + "performance.rateUpdateInterval" + ).displayedName("Rate update interval").description( + "Maximum interval (seconds) between updates of the rate. " + "Mostly relevant if rateUpdateOnEachInput is disabled or if input is slow." + ).assignmentOptional().defaultValue( + 1 + ).reconfigurable().commit() + bound.FLOAT_ELEMENT(expected).key("performance.rateBufferSpan").displayedName( + "Rate measurement buffer span" + ).description( + "Event buffer timespan (in seconds) for measuring rate" + ).assignmentOptional().defaultValue( + 20 + ).reconfigurable().commit() + bound.BOOL_ELEMENT(expected).key( + "performance.rateUpdateOnEachInput" + ).displayedName("Update rate on each input").description( + "Whether or not to update the device rate for each input " + "(otherwise only based on rateUpdateInterval). " + "Note that processed trains are always registered - this just " + "impacts when the rate is computed from them." + ).assignmentOptional().defaultValue( + False + ).reconfigurable().commit() + + bound.FLOAT_ELEMENT(expected).key("processingStateTimeout").description( + "Timeout (seconds) after which the device goes from PROCESSING back " + "to ON if no new input is processed" + ).assignmentOptional().defaultValue(10).reconfigurable().commit() + + # just measurements and counters to display + bound.UINT64_ELEMENT(expected).key("trainId").displayedName( + "Train ID" + ).description( + "ID of latest train processed by this device."
+ ).readOnly().initialValue( + 0 + ).commit() + bound.FLOAT_ELEMENT(expected).key( + "performance.lastProcessingDuration" + ).displayedName("Processing time").description( + "Amount of time spent in processing latest train. " + "Time includes generating preview and sending data." + ).unit( + bound.Unit.SECOND + ).metricPrefix( + bound.MetricPrefix.MILLI + ).readOnly().initialValue( + 0 + ).commit() + bound.FLOAT_ELEMENT(expected).key("performance.rate").displayedName( + "Rate" + ).description( + "Actual rate with which this device gets / processes / sends trains" + ).unit( + bound.Unit.HERTZ + ).readOnly().initialValue( + 0 + ).commit() + bound.FLOAT_ELEMENT(expected).key("performance.theoreticalRate").displayedName( + "Processing rate (hypothetical)" + ).description( + "Rate with which this device could hypothetically process trains. " + "Based on lastProcessingDuration." + ).unit( + bound.Unit.HERTZ + ).readOnly().initialValue( + float("NaN") + ).warnLow( + 10 + ).info( + "Processing not fast enough for full speed" + ).needsAcknowledging( + False + ).commit() + + # stuff from typical calPy that we don't use right now + # Included to avoid errors due to unexpected configuration from init device + bound.STRING_ELEMENT(expected).key("sourceInfix").displayedName( + "[Disabled]" + ).assignmentOptional().defaultValue("").commit() + + bound.UINT32_ELEMENT(expected).key("maxGPUHandlesInFlight").displayedName( + "[Disabled]" + ).assignmentOptional().defaultValue(35).commit() + + bound.VECTOR_UINT32_ELEMENT(expected).key("gainMapping").displayedName( + "[Disabled]" + ).assignmentOptional().defaultValue([]).commit() + + bound.BOOL_ELEMENT(expected).key("dontProcess").displayedName( + "[Disabled]" + ).assignmentOptional().defaultValue(False).reconfigurable().commit() + + def __init__(self, config): + self._dict_cache = {k: config.get(k) for k in self._dict_cache_slots} + super().__init__(config) + pycuda.driver.init() + self.gpu_context = pycuda.tools.make_default_context() + + # very sneaky debugging + self.KARABO_SLOT(self.askConnectedReadersToSendMySources) + + self.KARABO_ON_DATA("dataInput", self.process_input) + self.KARABO_ON_EOS("dataInput", self.handle_eos) + + self.sources = set(config.get("fastSources")) + + self.input_data_dtype = getattr(np, config.get("dataFormat.inputImageDtype")) + self.output_data_dtype = getattr(np, config.get("dataFormat.outputImageDtype")) + self._update_pulse_filter(config.get("pulseFilter")) + self._update_shapes( + config.get("dataFormat.pixelsX"), + config.get("dataFormat.pixelsY"), + config.get("dataFormat.memoryCells"), + self.pulse_filter, + ) + self._has_set_output_schema = False + self._has_set_preview_output_schema = False + self._rate_tracker = calibrationBase.utils.UpdateRate( + interval=config.get("performance.rateBufferSpan") + ) + self.managerInstance = None + self.KARABO_SLOT(self.registerManager) + self._reset_timer = None + self.updateState(State.ON) + + self._buffered_status_update = bound.Hash( + "trainId", + 0, + "performance.rate", + 0, + "performance.theoreticalRate", + float("NaN"), + "performance.lastProcessingDuration", + 0, + ) + self._rate_update_timer = utils.RepeatingTimer( + interval=config.get("performance.rateUpdateInterval"), + callback=self._update_actual_rate, + ) + self._buffer_lock = threading.Lock() + + def get(self, key): + if key in self._dict_cache_slots: + return self._dict_cache.get(key) + else: + return super().get(key) + + def set(self, *args): + if len(args) == 2: + key, value = args + if key in 
self._dict_cache_slots: + self._dict_cache[key] = value + super().set(*args) + + def __del__(self): + self.gpu_context.detach() + + def preReconfigure(self, config): + if config.has("pulseFilter"): + with self._buffer_lock: + # apply new pulse filter + self._update_pulse_filter(config.get("pulseFilter")) + # but existing shapes (not reconfigurable) + self._update_shapes( + self.get("dataFormat.pixelsX"), + self.get("dataFormat.pixelsY"), + self.get("dataFormat.memoryCells"), + self.pulse_filter, + ) + + if config.has("performance.rateUpdateInterval"): + self._rate_update_timer.stop() + self._rate_update_timer = utils.RepeatingTimer( + interval=config.get("performance.rateUpdateInterval"), + callback=self._update_actual_rate, + ) + + if config.has("performance.rateBufferSpan"): + self._rate_tracker = calibrationBase.utils.UpdateRate( + interval=config.get("performance.rateBufferSpan") + ) + + for path in config.getPaths(): + if path in self._dict_cache_slots: + self._dict_cache[path] = config.get(path) + + def process_input(self, data, metadata): + """Registered for dataInput, handles all processing and sending + + Comparable to StreamBase.onInput but hopefully faster + + """ + + if not self.get("doAnything"): + if self.get("state") is State.PROCESSING: + self.updateState(State.ACTIVE) + return + + # TODO: compare KARABO_ON_INPUT (old) against KARABO_ON_DATA (current) + source = metadata.get("source") + + if source not in self.sources: + self.log.INFO(f"Ignoring unknown source {source}") + return + + # TODO: what are these empty things for? + if not data.has("image"): + self.log.INFO("Ignoring hash without image node") + return + + time_start = timeit.default_timer() + + train_id = metadata.getAttribute("timestamp", "tid") + cell_table = np.squeeze(data.get("image.cellId")) + assert isinstance(cell_table, np.ndarray), "image.cellId should be ndarray" + if not len(cell_table.shape) == 1: + self.set( + "status", f"Failed to process, cell table had shape {cell_table.shape}" + ) + return + # original shape: 400, 1, 128, 512 (memory cells, something, y, x) + # TODO: consider making paths configurable + image_data = data.get("image.data") + if image_data.shape[0] != self.get("dataFormat.memoryCells"): + self.set( + "status", f"Updating input shapes based on received {image_data.shape}" + ) + # TODO: truncate if > 800 + self.set("dataFormat.memoryCells", image_data.shape[0]) + with self._buffer_lock: + self._update_shapes( + self.get("dataFormat.pixelsX"), + self.get("dataFormat.pixelsY"), + self.get("dataFormat.memoryCells"), + self.pulse_filter, + ) + # TODO: check shape (DAQ fake data and RunToPipe don't agree) + # TODO: consider just updating shapes based on whatever comes in + + do_generate_preview = train_id % self.get( + "preview.trainIdModulo" + ) == 0 and self.get("preview.enable") + + if not self.get("state") is State.PROCESSING: + self.updateState(State.PROCESSING) + self.set("status", "Processing data") + if self._reset_timer is None: + self._reset_timer = utils.DelayableTimer( + timeout=self.get("processingStateTimeout"), + callback=self._reset_state_from_processing, + ) + else: + self._reset_timer.set_timeout(self.get("processingStateTimeout")) + + with self._buffer_lock: + cell_table = cell_table[self.pulse_filter] + pulse_table = np.squeeze(data.get("image.pulseId"))[self.pulse_filter] + + with gpu_utils.GPUContextContext(self.gpu_context): + self.gpu_buffer_cell_table.set(cell_table) + self.gpu_buffer_input_image_data.set(image_data) + self.pipeline.reshape( + 
self.gpu_buffer_input_image_data, + self.gpu_buffer_reshaped_image_data, + ) + if self.get("applyCorrection"): + buffer_handle, result = self.pipeline.correct( + self.gpu_buffer_reshaped_image_data, + self.gpu_buffer_cell_table, + ) + else: + buffer_handle, result = self.pipeline.only_cast( + self.gpu_buffer_reshaped_image_data + ) + if do_generate_preview: + preview_raw, preview_corrected = self.pipeline.compute_preview( + self.gpu_buffer_reshaped_image_data, + self.get("preview.pulse"), + ) + + data.set("image.data", buffer_handle) + data.set("image.cellId", cell_table[:, np.newaxis]) + data.set("image.pulseId", pulse_table[:, np.newaxis]) + self.write_output(data, metadata) + if do_generate_preview: + self.write_combiner_preview( + preview_raw, preview_corrected, train_id, source + ) + + # update rate etc. + self._buffered_status_update.set("trainId", train_id) + self._rate_tracker.update() + time_spent = timeit.default_timer() - time_start + self._buffered_status_update.set( + "performance.lastProcessingDuration", time_spent * 1000 + ) + if self.get("performance.rateUpdateOnEachInput"): + self._update_actual_rate() + + def handle_eos(self, channel): + self._has_set_output_schema = False + self.updateState(State.ON) + self.signalEndOfStream("dataOutput") + + def write_output(self, data, old_metadata): + metadata = bound.ChannelMetaData( + old_metadata.get("source"), + bound.Timestamp.fromHashAttributes(old_metadata.getAttributes("timestamp")), + ) + + if "image.passport" not in data: + data["image.passport"] = [] + data["image.passport"].append(self.getInstanceId()) + + if not self._has_set_output_schema: + self.updateState(State.CHANGING) + self._update_output_schema(data) + self.updateState(State.PROCESSING) + + channel = self.signalSlotable.getOutputChannel("dataOutput") + channel.write(data, metadata, False) + channel.update() + + def write_combiner_preview(self, data_raw, data_corrected, train_id, source): + # TODO: take into account updated pulse table after pulse filter + preview_hash = bound.Hash() + preview_hash.set("image.passport", [self.getInstanceId()]) + preview_hash.set("image.trainId", train_id) + preview_hash.set("image.pulseId", self.get("preview.pulse")) + + # note: have to construct because setting .tid after init is broken + timestamp = bound.Timestamp(bound.Epochstamp(), bound.Trainstamp(train_id)) + metadata = bound.ChannelMetaData(source, timestamp) + for channel_name, data in ( + ("preview.outputRaw", data_raw), + ("preview.outputCorrected", data_corrected), + ): + preview_hash.set("data.adc", data[..., np.newaxis]) + channel = self.signalSlotable.getOutputChannel(channel_name) + channel.write(preview_hash, metadata, False) + channel.update() + + def getConstant(self, name): + """Hacky override of getConstant to actually return None on failure + + Full function is from CalibrationReceiverBaseDevice + + """ + + const = super().getConstant(name) + if const is not None and len(const.shape) == 1: + self.log.WARN( + f"Constant {name} should probably be None, but is array" + f" of size {const.size}, shape {const.shape}" + ) + const = None + return const + + def constantLoaded(self): + """Hook from CalibrationReceiverBaseDevice called after each getConstant + + Here, used to load the received constants (or correction maps derived + from them) onto GPU.
+ + TODO: call after receiving *all* constants instead of calling once per + new constant (will cause some overhead for bigger devices) + + """ + + self._update_maps_on_gpu() + + def registerManager(self, instance_id): + """A hook from stream.py for Manager devices to register themselves + + instance_id should be the instance id of the manager device. The + registration is currently not actually used for anything. + + """ + + self.managerInstance = instance_id + self.log.INFO(f"Registered calibration manager {instance_id}") + + def askConnectedReadersToSendMySources(self): + """For all connected outputs, set my sources to True + + This is really something the cal_manager should handle. Will be removed + after the prototyping phase. + + """ + + connected_friends = [ + # TODO: full set of possible device / channel separators + re.split(r"[@:]", channel)[0] + for channel in self.get("dataInput.connectedOutputChannels") + ] + + send_me_the_thing_hash = bound.Hash() + for source in self.get("fastSources"): + send_me_the_thing_hash[f"sources.{source}"] = True + + # ask them nicely to send it + for friend in connected_friends: + # TODO: check that device is a reader + self.signalSlotable.call(friend, "slotReconfigure", send_me_the_thing_hash) + # who needs middlelayer? we can roll our own setNowait 😎 + + def _update_output_schema(self, data): + """Updates the schema of dataOutput based on parameter data (a Hash) + + This should only be called once: when handling output for the first + time, we update the schema to match the modified data we'd send. + + """ + + self.log.INFO("Updating the output schema based on actual outgoing data") + my_schema = self.getFullSchema() + data_schema = hashToSchema.HashToSchema(data).schema + bound.OUTPUT_CHANNEL(my_schema).key("dataOutput").dataSchema( + data_schema + ).commit() + my_config = copy.copy(self.getCurrentConfiguration()) + self.updateSchema(my_schema) + self.log.INFO("Re-applying backed up config") + self.set(my_config) + self._has_set_output_schema = True + self.log.INFO("Ready to continue") + + def _update_pulse_filter(self, filter_string): + """Called whenever the pulse filter changes - typically followed by _update_shapes""" + + if filter_string.strip() == "": + new_filter = np.arange(self.get("dataFormat.memoryCells"), dtype=np.uint16) + else: + new_filter = np.array(eval(filter_string), dtype=np.uint16) + assert np.max(new_filter) < self.get("dataFormat.memoryCells") + self.pulse_filter = new_filter + + def _update_shapes(self, pixels_x, pixels_y, memory_cells, pulse_filter): + """(Re)initialize (GPU) buffers according to expected data shapes""" + + input_data_shape = (memory_cells, 1, pixels_y, pixels_x) + output_data_shape = (pixels_x, pixels_y, pulse_filter.size) + num_buffered_trains = self.get("outputShmemBufferLength") + self.set("dataFormat.inputDataShape", list(input_data_shape)) + self.set("dataFormat.outputDataShape", list(output_data_shape)) + + shmem_buffer_name = self.getInstanceId().replace("/", "_") + ":dataOutput" + + with gpu_utils.GPUContextContext(self.gpu_context): + self.pipeline = PyCudaPipeline( + pixels_x, + pixels_y, + memory_cells, + pulse_filter, + output_buffer_name=shmem_buffer_name, + output_buffer_size=num_buffered_trains, + input_data_dtype=self.input_data_dtype, + output_data_dtype=self.output_data_dtype, + ) + self.gpu_buffer_cell_table = pycuda.gpuarray.empty( + pulse_filter.size, dtype=np.uint16 + ) + self.gpu_buffer_input_image_data = pycuda.gpuarray.empty( + input_data_shape, dtype=self.input_data_dtype + )
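+ # note: the reshaped buffer below deliberately keeps the input dtype; + # casting to the output dtype only happens inside the pipeline's + # correction / cast kernels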
+ self.gpu_buffer_reshaped_image_data = pycuda.gpuarray.empty( + output_data_shape, dtype=self.input_data_dtype + ) + + self._update_maps_on_gpu() + + def _update_maps_on_gpu(self): + """Updates the correction maps stored on GPU based on known constants + + This only does something useful if constants have been retrieved from + CalCat. Should be called automatically upon retrieval and after + changing the data shape. + + """ + + self.set("status", "Updating constants on GPU using known constants") + self.updateState(State.CHANGING) + + offset_map = self.getConstant("Offset") + memory_cells = self.get("dataFormat.memoryCells") + if offset_map is None: + msg = "Warning: Did not find offset constant, offset correction will not be applied" + self.set("status", msg) + self.log.WARN(msg) + else: + if len(offset_map.shape) in (3, 4): + self.log.INFO(f"Known offset map has shape {offset_map.shape}") + # this is from offsetcorrection_dssc.py + if len(offset_map.shape) == 4: # old format? + offset_map = np.squeeze(offset_map[..., 0]) + constant_memory_cells = offset_map.shape[-1] + if memory_cells > constant_memory_cells: + msg = ( + f"Warning: Memory cells in input ({memory_cells}) exceeded " + f"memory cells in constant ({constant_memory_cells}), " + "offset correction will not be applied" + ) + self.set("status", msg) + self.log.WARN(msg) + else: + offset_map = offset_map[..., :memory_cells].astype(np.float32) + with gpu_utils.GPUContextContext(self.gpu_context): + self.pipeline.offset_map.set(offset_map) + msg = "Done transferring known constant(s) to GPU" + self.set("status", msg) + self.log.INFO(msg) + else: + msg = f"Offset map had unexpected shape {offset_map.shape}, offset correction will not be applied" + self.set("status", msg) + self.log.WARN(msg) + + self.updateState(State.ON) + + def _reset_state_from_processing(self): + if self.get("state") is State.PROCESSING: + self.updateState(State.ON) + self._reset_timer = None + + def _update_actual_rate(self): + if self.get("state") is not State.PROCESSING: + self._rate_update_timer.delay() + return + self._buffered_status_update.set("performance.rate", self._rate_tracker.rate()) + theoretical_rate = 1000 / self._buffered_status_update.get( + "performance.lastProcessingDuration" + ) + self._buffered_status_update.set( + "performance.theoreticalRate", theoretical_rate + ) + self.set(self._buffered_status_update) + self._rate_update_timer.delay() diff --git a/src/DsscCorrection/cuda_pipeline.py b/src/DsscCorrection/cuda_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a4e53356a21a37759bce105bb01e20b1d44b96 --- /dev/null +++ b/src/DsscCorrection/cuda_pipeline.py @@ -0,0 +1,309 @@ +import pathlib + +import jinja2 +import numpy as np +import posixshmem +import pycuda.compiler +import pycuda.gpuarray +import shmem_utils +import utils + + +class PyCudaPipeline: + """Class to handle instantiation and execution of CUDA kernels on trains + + Objects of this class will also maintain their own circular buffers of + ndarrays in shared memory to allow zero-copy handover of corrected data.
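+ + Each call to correct or only_cast returns a string handle pointing into + that buffer; the handle follows the shmem_utils template + "name$dtype$shape$index" (illustrative example: + "MY_DOMAIN_DEVICE:dataOutput$float32$50,512,128,400$3").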
+ + """ + + _src_dir = pathlib.Path(__file__).absolute().parent + with (_src_dir / "gpu-dssc-correct.cpp").open("r") as fd: + _kernel_template = jinja2.Template(fd.read()) + + def __init__( + self, + pixels_x, + pixels_y, + memory_cells, + pulse_filter, + output_buffer_size=20, + output_buffer_name=None, + input_data_dtype=np.uint16, + output_data_dtype=np.float32, + ): + self.pixels_x = pixels_x + self.pixels_y = pixels_y + self.memory_cells = memory_cells + self.pulse_filter = pulse_filter + self.output_shape = (self.pixels_x, self.pixels_y, self.pulse_filter.size) + self.map_shape = (self.pixels_x, self.pixels_y, self.memory_cells) + # preview will only be single memory cell + self.preview_shape = self.output_shape[:-1] + + kernel_source = self._kernel_template.render( + { + "pixels_x": self.pixels_x, + "pixels_y": self.pixels_y, + "memory_cells": self.memory_cells, + "input_data_dtype": utils.numpy_dtype_to_c_type_str[input_data_dtype], + "output_data_dtype": utils.numpy_dtype_to_c_type_str[output_data_dtype], + "pulse_filter": pulse_filter, + } + ) + self.source_module = pycuda.compiler.SourceModule( + kernel_source, no_extern_c=True + ) + self.reshaping_kernel = self.source_module.get_function("reshape_4_3") + self.correction_kernel = self.source_module.get_function("correct") + self.casting_kernel = self.source_module.get_function("only_cast") + self.preview_slice_raw_kernel = self.source_module.get_function( + "cell_slice_preview_raw" + ) + self.preview_slice_corrected_kernel = self.source_module.get_function( + "cell_slice_preview_corrected" + ) + self.preview_stat_raw_kernel = self.source_module.get_function( + "cell_stat_preview_raw" + ) + self.preview_stat_corrected_kernel = self.source_module.get_function( + "cell_stat_preview_corrected" + ) + self.frame_sum_kernel = self.source_module.get_function("sum_frames") + + self.offset_map = pycuda.gpuarray.zeros(self.map_shape, dtype=np.float32) + + # reuse output arrays + self.gpu_result = pycuda.gpuarray.empty( + self.output_shape, dtype=output_data_dtype + ) + self.gpu_frame_sums = pycuda.gpuarray.empty( + self.pulse_filter.size, dtype=np.float32 + ) + self.gpu_preview_raw = pycuda.gpuarray.empty( + self.preview_shape, dtype=np.float32 + ) + self.gpu_preview_corrected = pycuda.gpuarray.empty( + self.preview_shape, dtype=np.float32 + ) + self.preview_raw = np.empty(self.preview_shape, dtype=np.float32) + self.preview_corrected = np.empty(self.preview_shape, dtype=np.float32) + self.output_buffer_mem = posixshmem.SharedMemory( + name=output_buffer_name, + size=self.gpu_result.nbytes * output_buffer_size, + rw=True, + ) + self.output_buffer_ary = self.output_buffer_mem.ndarray( + shape=(output_buffer_size,) + self.gpu_result.shape, + dtype=self.gpu_result.dtype, + ) + self.output_buffer_handle_template = ( + shmem_utils.handle_template_from_shmem_array( + self.output_buffer_mem, self.output_buffer_ary + ) + ) + self.output_buffer_next_index = 0 + + self.update_block_size(full_block=(1, 1, 64), preview_block=(1, 64, 1)) + + def update_block_size(self, full_block=None, preview_block=None): + """Execution is scheduled with 3d "blocks" of CUDA threads, tuning can + affect performance + + Grid size is automatically computed based on block size. Note that + individual kernels must themselves check whether they go out of bounds; + grid dimensions get rounded up in case ndarray size is not multiple of + block size. 
+ + """ + if full_block is not None: + assert len(full_block) == 3 + self.full_block = tuple(full_block) + self.full_grid = tuple( + utils.ceil_div(a_length, block_length) + for (a_length, block_length) in zip(self.output_shape, full_block) + ) + if preview_block is not None: + self.preview_block = tuple(preview_block) + self.preview_grid = ( + utils.ceil_div(self.output_shape[0], preview_block[0]), + utils.ceil_div(self.output_shape[1], preview_block[1]), + 1, + ) + # TODO: make configurable + self.cell_reduction_block = (1, 1, 32) + self.cell_reduction_grid = ( + 1, + 1, + utils.ceil_div(self.output_shape[-1], self.cell_reduction_block[-1]), + ) + + def reshape(self, input_data, output_data): + """Do the reshaping and pulse filtering that the splitter would have done + + equivalent to: + output_data[:] = np.moveaxis( + np.squeeze(input_data), (0, 1, 2), (2, 1, 0) + )[..., pulse_filter] + """ + # TODO: Move to somewhere else + self.reshaping_kernel( + input_data, output_data, block=self.full_block, grid=self.full_grid + ) + + def correct(self, data, cell_table): + """Apply corrections to data + + Applies corrections to input data and casts to desired output dtype. + Parameter cell_table allows out of order or non-contiguous memory cells + in input data. Both input ndarrays are assumed to be on GPU already, + preferably wrapped in GPU arrays (pycuda.gpuarray). + + Will return string encoded handle to shared memory output buffer and + (view of) said buffer as an ndarray. Keep in mind that the output + buffers will get overwritten eventually (circular buffer). + + """ + self.correction_kernel( + data, + cell_table, + self.offset_map, + self.gpu_result, + block=self.full_block, + grid=self.full_grid, + ) + buffer_index = self.output_buffer_next_index + output_buffer = self.output_buffer_ary[buffer_index] + handle = self.output_buffer_handle_template.format(index=buffer_index) + self.gpu_result.get(ary=output_buffer) + self.output_buffer_next_index = ( + self.output_buffer_next_index + 1 + ) % self.output_buffer_ary.shape[0] + return handle, output_buffer + + def only_cast(self, data): + """Like correct without the correction + + This currently means just casting to output dtype. + """ + self.casting_kernel( + data, + self.gpu_result, + block=self.full_block, + grid=self.full_grid, + ) + buffer_index = self.output_buffer_next_index + output_buffer = self.output_buffer_ary[buffer_index] + handle = self.output_buffer_handle_template.format(index=buffer_index) + self.gpu_result.get(ary=output_buffer) + self.output_buffer_next_index = ( + self.output_buffer_next_index + 1 + ) % self.output_buffer_ary.shape[0] + return handle, output_buffer + + def compute_preview( + self, raw_data, cell_to_preview, has_just_corrected=True, verify=False + ): + """Generate single slice or reduction preview of raw and corrected data + + Special values of cell_to_preview are -1 for max, -2 for mean, -3 for + sum, and -4 for stdev (across cells). + + Note that cell_to_preview is taken from data without checking cell + table, so if a pulse filter not contiguous from 0 has been applied + first, the resulting cell will be offset. Cell table is only used to + get the correct slice of the correction map. + + raw_data should be a gpuarray + + Assumes that correction has just happened - meaning self.gpu_result + contains corrected data (corrected from raw_data). 
+ + """ + + if cell_to_preview < -4: + raise ValueError(f"No statistic with code {cell_to_preview} defined") + elif cell_to_preview >= self.memory_cells: + raise ValueError(f"Memory cell index {cell_to_preview} out of range") + + # TODO: lift this restriction. + assert has_just_corrected + # TODO: enum around reduction type + if cell_to_preview >= 0: + self.preview_slice_raw_kernel( + raw_data, + np.int16(cell_to_preview), + self.gpu_preview_raw, + block=self.preview_block, + grid=self.preview_grid, + ) + self.preview_slice_corrected_kernel( + self.gpu_result, + np.int16(cell_to_preview), + self.gpu_preview_corrected, + block=self.preview_block, + grid=self.preview_grid, + ) + if verify: + assert np.allclose( + self.gpu_preview_raw.get(), + raw_data.get()[..., cell_to_preview], + ) + elif cell_to_preview == -1: + # TODO: select argmax independently for raw and corrected? + # TODO: send frame sums somewhere to compute global max frame + self.frame_sum_kernel( + self.gpu_result, + self.gpu_frame_sums, + block=self.cell_reduction_block, + grid=self.cell_reduction_grid, + ) + max_index = np.argmax(self.gpu_frame_sums.get()) + self.preview_slice_raw_kernel( + raw_data, + np.int16(max_index), + self.gpu_preview_raw, + block=self.preview_block, + grid=self.preview_grid, + ) + self.preview_slice_corrected_kernel( + self.gpu_result, + np.int16(max_index), + self.gpu_preview_corrected, + block=self.preview_block, + grid=self.preview_grid, + ) + if verify: + assert np.allclose( + self.gpu_preview_raw.get(), + raw_data.get()[ + ..., + np.argmax( + np.sum(raw_data.get(), axis=(0, 1), dtype=np.float32) + ), + ], + ) + elif cell_to_preview in (-2, -3, -4): + self.preview_stat_raw_kernel( + raw_data, # this is input_data_dtype + np.int16(cell_to_preview), + self.gpu_preview_raw, + block=self.preview_block, + grid=self.preview_grid, + ) + self.preview_stat_corrected_kernel( + self.gpu_result, # this is output_data_dtype + np.int16(cell_to_preview), + self.gpu_preview_corrected, + block=self.preview_block, + grid=self.preview_grid, + ) + if verify: + assert np.allclose( + self.gpu_preview_raw.get(), + {-2: np.mean, -3: np.sum, -4: np.std}[cell_to_preview]( + raw_data.get(), axis=2 + ), + ) + self.gpu_preview_raw.get(ary=self.preview_raw) + self.gpu_preview_corrected.get(ary=self.preview_corrected) + return self.preview_raw, self.preview_corrected diff --git a/src/DsscCorrection/gpu-dssc-correct.cpp b/src/DsscCorrection/gpu-dssc-correct.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a86d9707dbd4fcbc6baf7f28d31ffb6c8f8d38a4 --- /dev/null +++ b/src/DsscCorrection/gpu-dssc-correct.cpp @@ -0,0 +1,239 @@ +#include <cuda_fp16.h> + +__device__ unsigned short pulse_filter[{{pulse_filter|length}}] = { {{pulse_filter|join(', ')}} }; + +extern "C" { + /* + Reshuffle data from shape like (400, 1, 128, 512) to shape like (512, 128, <=400) + That is, (cell, ???, y, x) to (x, y, cell) + Applies pulse filter; essentially taking subset of indices along memory cell axis + Equivalent to np.moveaxis(np.squeeze(data, (0, 1, 2), (2, 1, 0))) + */ + __global__ void reshape_4_3(const {{input_data_dtype}}* data, + {{input_data_dtype}}* output) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t extra_dim = 1; // mysterious extra dimension in incoming data + const size_t pulse_filter_size = {{pulse_filter|length}}; + + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t j = blockIdx.y * blockDim.y + threadIdx.y; + const size_t k = blockIdx.z * blockDim.z + 
threadIdx.z; + + if (i >= X || j >= Y || k >= pulse_filter_size) { + // in case block size doesn't fit perfectly, some threads do nothing + return; + } + + const size_t in_stride_3 = 1; + const size_t in_stride_2 = in_stride_3 * X; + const size_t in_stride_1 = in_stride_2 * Y; + const size_t in_stride_0 = in_stride_1 * extra_dim; + const size_t in_index = pulse_filter[k] * in_stride_0 // k is cell + + 0 * in_stride_1 // for completeness, the squeezed dimension + + j * in_stride_2 // j is y + + i * in_stride_3; // i is x + + const size_t out_stride_2 = 1; + const size_t out_stride_1 = out_stride_2 * pulse_filter_size; + const size_t out_stride_0 = out_stride_1 * Y; + const size_t out_index = i * out_stride_0 + j * out_stride_1 + k * out_stride_2; + output[out_index] = data[in_index]; + } + + /* + Perform correction: offset + Take cell_table into account when getting correction values + Converting to float for doing the correction + Converting to output dtype at the end + */ + __global__ void correct(const {{input_data_dtype}}* data, + const unsigned short* cell_table, + const float* offset_map, + {{output_data_dtype}}* output) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + // reshaped and output data have pulse filter length memory cells dim + const size_t filtered_memory_cells = {{pulse_filter|length}}; + // but correction map has "full size" + const size_t full_memory_cells = {{memory_cells}}; + + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t j = blockIdx.y * blockDim.y + threadIdx.y; + const size_t k = blockIdx.z * blockDim.z + threadIdx.z; + + if (i >= X || j >= Y || k >= filtered_memory_cells) { + return; + } + + // note: strides differ from numpy strides because unit here is sizeof(...), not byte + const size_t data_stride_2 = 1; + const size_t data_stride_1 = filtered_memory_cells * data_stride_2; + const size_t data_stride_0 = Y * data_stride_1; + const size_t data_index = i * data_stride_0 + j * data_stride_1 + k * data_stride_2; + + const size_t map_stride_2 = 1; + const size_t map_stride_1 = full_memory_cells * map_stride_2; + const size_t map_stride_0 = Y * map_stride_1; + const size_t map_cell = cell_table[k]; + const size_t map_index = i * map_stride_0 + j * map_stride_1 + map_cell * map_stride_2; + + const float raw = (float)data[data_index]; + const float corrected = raw - offset_map[map_index]; + {% if output_data_dtype == "half" %} + output[data_index] = __float2half(corrected); + {% else %} + output[data_index] = ({{output_data_dtype}})corrected; + {% endif %} + } + + /* + Same as correction, except don't do any correction + */ + __global__ void only_cast(const {{input_data_dtype}}* data, + {{output_data_dtype}}* output) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t memory_cells = {{pulse_filter|length}}; + + const size_t data_stride_2 = 1; + const size_t data_stride_1 = memory_cells * data_stride_2; + const size_t data_stride_0 = Y * data_stride_1; + + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t j = blockIdx.y * blockDim.y + threadIdx.y; + const size_t k = blockIdx.z * blockDim.z + threadIdx.z; + + if (i >= X || j >= Y || k >= memory_cells) { + return; + } + + const size_t data_index = i * data_stride_0 + j * data_stride_1 + k * data_stride_2; + const float raw = (float)data[data_index]; + {% if output_data_dtype == "half" %} + output[data_index] = __float2half(raw); + {% else %} + output[data_index] = ({{output_data_dtype}})raw; + {% endif %} + } + + /* 
Kernels for preview + ≥0: just slice desired cell; uses cell_slice_preview_* + -1: slice cell with max integrated intensity (hybrid) + -2: mean; uses cell_stat_preview_* + -3: sum; ditto + -4: stdev; ditto + Note: for loop in template due to differing dtypes + TODO: simplify + - [ ] set up C++ compilation chain on ONC + - [ ] use C++ templates + Or even better: + - [ ] switch to cupy, all of this becomes trivial + */ + {% for (name_suffix, data_dtype) in (("raw", input_data_dtype), ("corrected", output_data_dtype)) %} + __global__ void cell_slice_preview_{{name_suffix}}(const {{data_dtype}}* data, + const short cell_to_preview, + float* preview) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t filtered_memory_cells = {{pulse_filter|length}}; + + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t j = blockIdx.y * blockDim.y + threadIdx.y; + + if (i >= X || j >= Y) { + return; + } + + const size_t preview_stride_1 = 1; + const size_t preview_stride_0 = Y * preview_stride_1; + const size_t preview_index = i * preview_stride_0 + j * preview_stride_1; + + const size_t data_stride_2 = 1; + const size_t data_stride_1 = filtered_memory_cells * data_stride_2; + const size_t data_stride_0 = Y * data_stride_1; + + const size_t data_index = i * data_stride_0 + j * data_stride_1 + cell_to_preview * data_stride_2; + + preview[preview_index] = (float)data[data_index]; + } + + __global__ void cell_stat_preview_{{name_suffix}}(const {{data_dtype}}* data, + const short preview_stat, + float* preview) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t filtered_memory_cells = {{pulse_filter|length}}; + + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t j = blockIdx.y * blockDim.y + threadIdx.y; + + if (i >= X || j >= Y) { + return; + } + + const size_t preview_stride_1 = 1; + const size_t preview_stride_0 = Y * preview_stride_1; + const size_t preview_index = i * preview_stride_0 + j * preview_stride_1; + + const size_t data_stride_2 = 1; + const size_t data_stride_1 = filtered_memory_cells * data_stride_2; + const size_t data_stride_0 = Y * data_stride_1; + + float sum = 0; + for (int k=0; k<filtered_memory_cells; ++k) { + const size_t data_index = i * data_stride_0 + j * data_stride_1 + k * data_stride_2; + sum += (float)data[data_index]; + } + + if (preview_stat == -3) { + // just sum + preview[preview_index] = sum; + } else if (preview_stat == -2) { + // mean + preview[preview_index] = sum / filtered_memory_cells; + } else if (preview_stat == -4) { + // standard deviation + const double mean = sum / filtered_memory_cells; + // try to reduce error by increasing precision on accumulator + double var = 0; + for (int k=0; k<filtered_memory_cells; ++k) { + const size_t data_index = i * data_stride_0 + j * data_stride_1 + k * data_stride_2; + // but "compute" values the same (floats) + var += pow((double)data[data_index] - mean, 2); + } + var /= filtered_memory_cells; + preview[preview_index] = (float)sqrt(var); + } + } + {% endfor %} + + // used to find max integrated intensity frame + __global__ void sum_frames({{output_data_dtype}}* data, float* sums) { + const size_t X = {{pixels_x}}; + const size_t Y = {{pixels_y}}; + const size_t filtered_memory_cells = {{pulse_filter|length}}; + + const size_t memory_cell = blockIdx.z * blockDim.z + threadIdx.z; + + if (memory_cell >= filtered_memory_cells) { + return; + } + + const size_t data_stride_2 = 1; + const size_t data_stride_1 = filtered_memory_cells * 
data_stride_2; + const size_t data_stride_0 = Y * data_stride_1; + + float my_res = 0; + for (int i=0; i<X; ++i) { + for (int j=0; j<Y; ++j) { + const size_t data_index = i * data_stride_0 + + j * data_stride_1 + + memory_cell * data_stride_2; + const float raw = (float)data[data_index]; + my_res += raw; + } + } + sums[memory_cell] = my_res; + } +} diff --git a/src/gpu_utils.py b/src/gpu_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c945aee1a5bf8b64fa2621de49cf19511cabd41b --- /dev/null +++ b/src/gpu_utils.py @@ -0,0 +1,59 @@ +import re + +import pycuda.driver +import pycuda.gpuarray + +_gpuptr_re = re.compile( + r"GPUPTR:(?P<gpu_pointer>\w+)" r"DEVID:(?P<device_id>.+)" r"SHAPE:(?P<shape>.+)" +) + + +def get_shape_from_ipc_handle(handle_string): + match = _gpuptr_re.match(handle_string) + return tuple(int(num) for num in match.group("shape").split(",")) + + +class IPCGPUArray: + """Context manager providing a GPUArray opened from string encoding IPC handle + + Arguments: + handle_string: String encoding a "GPU pointer" (IPC handle) plus device + id and array shape. This is parsed using _gpuptr_re. + dtype: self-explanatory (but make sure it is correct) + array shape is parsed from the handle_string + """ + + def __init__(self, handle_string, dtype, gpu_pointer_re=None): + match = _gpuptr_re.match(handle_string) + assert match is not None + + self.dtype = dtype + self.handle_address = bytearray.fromhex(match.group("gpu_pointer")) + self.shape = tuple(int(num) for num in match.group("shape").split(",")) + # assuming contiguous C-order strides probably + # TODO: smarter + + self.open_handle = None + self.gpu_array = None + + def __enter__(self): + self.open_handle = pycuda.driver.IPCMemoryHandle(self.handle_address) + self.gpu_array = pycuda.gpuarray.GPUArray( + self.shape, dtype=self.dtype, gpudata=self.open_handle + ) + return self.gpu_array + + def __exit__(self, t, v, tb): + self.open_handle.close() + + +class GPUContextContext: + def __init__(self, gpu_context): + self.gpu_context = gpu_context + + def __enter__(self): + self.gpu_context.push() + return self.gpu_context + + def __exit__(self, t, v, tb): + self.gpu_context.pop() diff --git a/src/shmem_utils.py b/src/shmem_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2498a9a5492886d9be16d46ceb62544917b1943b --- /dev/null +++ b/src/shmem_utils.py @@ -0,0 +1,30 @@ +import numpy as np +import posixshmem + + +def parse_shmem_handle(handle_string): + buffer_name, dtype, shape, index = handle_string.split("$") + dtype = getattr(np, dtype) + shape = tuple(int(n) for n in shape.split(",")) + index = int(index) + return buffer_name, dtype, shape, index + + +def open_shmem_from_handle(handle_string, rw=False): + buffer_name, dtype, shape, _ = parse_shmem_handle(handle_string) + buffer_mem_size = np.dtype(dtype).itemsize * np.prod(shape) + + buffer_mem = posixshmem.SharedMemory(name=buffer_name, size=buffer_mem_size, rw=rw) + + array = buffer_mem.ndarray( + shape=shape, + dtype=dtype, + ) + + # return both; the caller is responsible for closing the shared memory, + # especially when opened read-write + return buffer_mem, array + + +def handle_template_from_shmem_array(mem, ary): + shape_str = ",".join(str(n) for n in ary.shape) + return f"{mem.name}${ary.dtype}${shape_str}${{index}}"
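+ + + # Usage sketch - a producer formats a handle for one slot of its circular + # buffer and a consumer re-opens that slot from the handle (variable names + # illustrative): + # + # template = handle_template_from_shmem_array(mem, ary) + # handle = template.format(index=3) + # buffer_mem, full_buffer = open_shmem_from_handle(handle) + # *_, index = parse_shmem_handle(handle) + # slot = full_buffer[index] # ndarray view of one buffered train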
diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f294f1ecbed450157defb934b49b2b05637dcf8c --- /dev/null +++ b/src/utils.py @@ -0,0 +1,115 @@ +import threading +import time +import timeit + +import numpy as np + +numpy_dtype_to_c_type_str = { + np.uint16: "unsigned short", + np.uint32: "unsigned int", + np.float16: "half", # warning: only in CUDA with special support + np.float32: "float", + np.float64: "double", +} + + +def ceil_div(num, denom): + return (num + denom - 1) // denom + + +class DelayableTimer: + """Start a timer which can be extended + + Useful for reverting state after inactivity, for instance. + + timer defaults to timeit.default_timer - it should be a timer returning + globally increasing number of seconds. + """ + + def __init__(self, timeout, callback, timer=timeit.default_timer): + self.timer = timer + self.stop_time = self.timer() + timeout + + def runner(): + now = self.timer() + while now < self.stop_time: + diff = self.stop_time - now + time.sleep(diff) + now = self.timer() + callback() + + self.thread = threading.Thread(target=runner) + self.thread.start() + + def set_timeout(self, timeout): + """Set stop time to now + timeout + + Note that the timer thread only wakes at the previously scheduled stop + time, so shortening the timeout will not make the callback fire earlier.""" + self.stop_time = self.timer() + timeout + + def add_timeout(self, timeout): + """Simply add timeout to current stop time""" + self.stop_time += timeout + + +class RepeatingTimer: + """Similar to DelayableTimer, but will keep running with pre-set intervals""" + + def __init__(self, interval, callback, timer=timeit.default_timer, start_now=True): + self.timer = timer + self.stopped = True + self.interval = interval + self.callback = callback + if start_now: + self.start() + + def delay(self): + self.stop_time = self.timer() + self.interval + + def start(self): + self.stopped = False + self.stop_time = self.timer() + self.interval + + def runner(): + while not self.stopped: + now = self.timer() + while now < self.stop_time: + diff = self.stop_time - now + time.sleep(diff) + if self.stopped: + return + now = self.timer() + self.callback() + self.stop_time = self.timer() + self.interval + + self.thread = threading.Thread(target=runner) + self.thread.start() + + def stop(self): + self.stopped = True + + +class Throttler: + """Rate limiter: tracks the latest call to decide whether enough time has passed""" + + def __init__(self, interval, timer=timeit.default_timer): + self.timer = timer + self.interval = interval + self.latest_call = None + + def ready(self): + if self.latest_call is None: + return True + else: + return self.latest_call + self.interval <= self.timer() + + def update(self): + self.latest_call = self.timer() + + def ready_update(self): + now = self.timer() + if self.latest_call is None or self.latest_call + self.interval <= now: + self.latest_call = now + return True + else: + return False
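+ + + # Usage sketch (mirrors how the correction device uses these helpers; the + # callbacks are hypothetical): + # + # timer = DelayableTimer(timeout=10, callback=revert_state) + # timer.set_timeout(10) # call on each processed train to postpone revert_state + # throttler = Throttler(interval=0.5) + # if throttler.ready_update(): + # send_preview() # runs at most every 0.5 s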