diff --git a/src/calng/JungfrauCorrection.py b/src/calng/JungfrauCorrection.py index bf7b981560b7f2439a232350ac3ca42b77ad313d..028e2678d3984ade9a6fb3a15b80612f176e7219 100644 --- a/src/calng/JungfrauCorrection.py +++ b/src/calng/JungfrauCorrection.py @@ -20,6 +20,7 @@ _pretend_pulse_table = np.arange(16, dtype=np.uint8) class JungfrauConstants(enum.Enum): Offset10Hz = enum.auto() BadPixelsDark10Hz = enum.auto() + RelativeGain10Hz = enum.auto() class JungfrauGainMode(enum.IntEnum): @@ -75,6 +76,9 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner): self.offset_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.float32) self.rel_gain_map_gpu = cupy.ones(self.map_shape, dtype=cupy.float32) self.bad_pixel_map_gpu = cupy.zeros(self.map_shape, dtype=cupy.uint32) + self.bad_pixel_mask_value = bad_pixel_mask_value + + self.update_block_size((1, 1, 64)) def _init_kernels(self): kernel_source = self._kernel_template.render( @@ -106,7 +110,17 @@ class JungfrauGpuRunner(base_gpu.BaseGpuRunner): """Experiment: loading all three in one function as they are tied""" self.input_data_gpu.set(image_data) self.input_gain_map_gpu.set(input_gain_map) - self.cell_table_gpu.set(cell_table) + if self.burst_mode: + self.cell_table_gpu.set(cell_table) + + def load_constant(self, constant, constant_data): + if constant is JungfrauConstants.Offset10Hz: + self.offset_map_gpu.set(constant_data.astype(np.float32)) + elif constant is JungfrauConstants.RelativeGain10Hz: + self.rel_gain_map_gpu.set(constant_data.astype(np.float32)) + elif constant is JungfrauConstants.BadPixelsDark10Hz: + self.bad_pixel_map_gpu.set(constant_data) + def correct(self, flags): self.correction_kernel( @@ -134,6 +148,7 @@ class JungfrauCalcatFriend(calcat_utils.BaseCalcatFriend): self._constants_need_conditions = { JungfrauConstants.Offset10Hz: self.dark_condition, JungfrauConstants.BadPixelsDark10Hz: self.dark_condition, + JungfrauConstants.RelativeGain10Hz: self.dark_condition, } @staticmethod @@ -243,6 +258,8 @@ class JungfrauCorrection(BaseCorrection): _calcat_friend_class = JungfrauCalcatFriend _constant_enum_class = JungfrauConstants _managed_keys = BaseCorrection._managed_keys.copy() + _image_data_path = "data.adc" + _cell_table_path = "data.memoryCell" @staticmethod def expectedParameters(expected): @@ -321,16 +338,8 @@ class JungfrauCorrection(BaseCorrection): cell_table, do_generate_preview, ): - if self._frame_filter is not None: - try: - cell_table = cell_table[self._frame_filter] - image_data = image_data[self._frame_filter] - except IndexError: - self.log_status_warn( - "Failed to apply frame filter, please check that it is valid!" - ) - return - + if len(cell_table.shape) == 0: + cell_table = cell_table[np.newaxis] try: self.kernel_runner.load_data( image_data, data_hash.get("data.gain"), cell_table @@ -370,9 +379,6 @@ class JungfrauCorrection(BaseCorrection): data_hash.set(self._image_data_path, buffer_handle) data_hash.set("calngShmemPaths", [self._image_data_path]) - data_hash.set(self._cell_table_path, cell_table) - data_hash.set("image.pulseId", pulse_table[:, np.newaxis]) - self._write_output(data_hash, metadata) if do_generate_preview: @@ -384,3 +390,8 @@ class JungfrauCorrection(BaseCorrection): train_id, source, ) + + def _load_constant_to_runner(self, constant, constant_data): + self.kernel_runner.load_constant( + constant, np.transpose(constant_data, (2, 0, 1, 3)) + ) diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py index eadf0a815b46debe4dd7297c43b96bbc63bd86b8..1697f568ce0c1fecea2ecb6db3ed4097725608e6 100644 --- a/src/calng/base_correction.py +++ b/src/calng/base_correction.py @@ -216,7 +216,7 @@ class BaseCorrection(PythonDevice): _image_data_path = "image.data" # customize for *some* subclasses _cell_table_path = "image.cellId" - def _load_constant_to_runner(constant_name, constant_data): + def _load_constant_to_runner(self, constant_name, constant_data): """Subclass must define how to process constants into correction maps and store into appropriate buffers in (GPU or main) memory.""" raise NotImplementedError() @@ -951,13 +951,13 @@ class BaseCorrection(PythonDevice): if source not in self.sources: self.log_status_info(f"Ignoring hash with unknown source {source}") return - elif not data_hash.has("image"): + elif not data_hash.has(self._image_data_path): self.log_status_info("Ignoring hash without image node") return train_id = metadata.getAttribute("timestamp", "tid") cell_table = np.squeeze(data_hash.get(self._cell_table_path)) - if len(cell_table.shape) == 0: + if len(cell_table.shape) == 0 and int(cell_table) == 0: self.log_status_warn( "cellId had 0 dimensions. DAQ may not be sending data." ) diff --git a/src/calng/utils.py b/src/calng/utils.py index 291c2d976ae4458495d351337e20f95fdcf7ebce..0006ab9422c9d9806497774603ee90c03a21b267 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -25,7 +25,7 @@ def pick_frame_index(selection_mode, index, cell_table, pulse_table, warn_func=N # TODO: enum if selection_mode == "frame": - if index >= len(cell_table): + if index >= cell_table.size: if warn_func is not None: warn_func( f"Index {index} out of range for cell table of length "