diff --git a/src/calng/base_kernel_runner.py b/src/calng/base_kernel_runner.py index 14932cae9c19639f31d796a7cd51d6c35e0b93e6..933bee7d8433d2722f848163d6855fcb1cbb3a0f 100644 --- a/src/calng/base_kernel_runner.py +++ b/src/calng/base_kernel_runner.py @@ -140,9 +140,11 @@ class BaseKernelRunner: out[:] = self.reshaped_data +kernel_dir = pathlib.Path(__file__).absolute().parent / "kernels" + + def get_kernel_template(kernel_fn): - src_dir = pathlib.Path(__file__).absolute().parent - with (src_dir / "kernels" / kernel_fn).open("r") as fd: + with (kernel_dir / kernel_fn).open("r") as fd: return jinja2.Template(fd.read()) diff --git a/src/calng/corrections/JungfrauCorrection.py b/src/calng/corrections/JungfrauCorrection.py index 64ef130d5330a00793305a09fddde8eacf7d33c8..ab5af5f688fa74c32ab02195fffde302cf305ef9 100644 --- a/src/calng/corrections/JungfrauCorrection.py +++ b/src/calng/corrections/JungfrauCorrection.py @@ -9,7 +9,6 @@ from karabo.bound import ( OUTPUT_CHANNEL, OVERWRITE_ELEMENT, STRING_ELEMENT, - VECTOR_STRING_ELEMENT, Schema, ) @@ -18,7 +17,6 @@ from .. import ( base_correction, base_kernel_runner, schemas, - preview_utils, utils, ) from .._version import version as deviceVersion @@ -81,6 +79,7 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): pixels_y, frames, constant_memory_cells, + config, output_data_dtype=np.float32, bad_pixel_mask_value=np.nan, ): @@ -98,23 +97,42 @@ class JungfrauBaseRunner(base_kernel_runner.BaseKernelRunner): self.bad_pixel_map = self._xp.empty(self.map_shape, dtype=np.uint32) self.bad_pixel_mask_value = bad_pixel_mask_value - # strixel support - self._strixel_out_shape = (frames, 86, 3090) - self._strixel_block = (1, 1, 64) - # note: only executing kernel on lower half of y range, hence 256 - self._strixel_grid = tuple( - utils.ceil_div(a_length, block_length) - for (a_length, block_length) in zip( - (frames, 256, 1024), self._strixel_block - ) - ) + self.output_dtype = output_data_dtype self._processed_data_regular = self._xp.empty( self.processed_shape, dtype=output_data_dtype ) - self._processed_data_strixel = self._xp.empty( - self._strixel_out_shape, dtype=output_data_dtype - ) + self._processed_data_strixel = None self.flush_buffers(set(Constants)) + self.correction_kernel_strixel = None + self.reconfigure(config) + + def reconfigure(self, config): + # note: regular bad pixel masking uses device property (TODO: unify) + if (mask_value := config.get("badPixels.maskingValue")) is not None: + self._bad_pixel_mask_value = self._xp.float32(mask_value) + # this is a functools.partial, can just update the captured parameter + if self.correction_kernel_strixel is not None: + self.correction_kernel_strixel.keywords[ + "missing" + ] = self._bad_pixel_mask_value + + if (strixel_type := config.get("strixel.type")) is not None: + # drop the friendly parenthesized name + strixel_type = strixel_type.partition("(")[0] + strixel_package = np.load( + base_kernel_runner.kernel_dir + / f"strixel_{strixel_type}-lut_mask.npz" + ) + self._strixel_out_shape = tuple(strixel_package["frame_shape"]) + self._processed_data_strixel = None + # TODO: use bad pixel masking config here + self.correction_kernel_strixel = functools.partial( + utils.apply_partial_lut, + lut=self._xp.asarray(strixel_package["lut"]), + mask=self._xp.asarray(strixel_package["mask"]), + missing=self._bad_pixel_mask_value, + ) + # note: always masking unmapped pixels (not respecting NON_STANDARD_SIZE) def load_constant(self, constant, constant_data): if constant_data.shape[0] == self.pixels_x: @@ -195,7 +213,6 @@ class JungfrauGpuRunner(JungfrauBaseRunner): ) ) self.correction_kernel = source_module.get_function("correct") - self.strixel_transform_kernel = source_module.get_function("strixel_transform") def load_data(self, image_data, input_gain_stage, cell_table): self.input_data.set(image_data) @@ -220,14 +237,21 @@ class JungfrauGpuRunner(JungfrauBaseRunner): ), ) if flags & CorrectionFlags.STRIXEL: - self.strixel_transform_kernel( - self._strixel_grid, - self._strixel_block, - ( - self._processed_data_regular, - self._processed_data_strixel, - ), - ) + if ( + self._processed_data_strixel is None + or self.frames != self._processed_data_strixel.shape[0] + ): + self._processed_data_strixel = self._xp.empty( + (self.frames,) + self._strixel_out_shape, + dtype=self.output_dtype, + ) + for pixel_frame, strixel_frame in zip( + self._processed_data_regular, self._processed_data_strixel + ): + self.correction_kernel_strixel( + data=pixel_frame, + out=strixel_frame, + ) self.processed_data = self._processed_data_strixel else: self.processed_data = self._processed_data_regular @@ -244,13 +268,12 @@ class JungfrauCpuRunner(JungfrauBaseRunner): self.input_cell_table = None # for computing previews faster - self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=3) + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16) from ..kernels import jungfrau_cython self.correction_kernel_single = jungfrau_cython.correct_single self.correction_kernel_burst = jungfrau_cython.correct_burst - self.correction_kernel_strixel = jungfrau_cython.strixel_transform def __del__(self): self.thread_pool.shutdown() @@ -286,8 +309,25 @@ class JungfrauCpuRunner(JungfrauBaseRunner): ) if flags & CorrectionFlags.STRIXEL: - self.correction_kernel_strixel( - self._processed_data_regular, self._processed_data_strixel + if ( + self._processed_data_strixel is None + or self.frames != self._processed_data_strixel.shape[0] + ): + self._processed_data_strixel = self._xp.empty( + (self.frames,) + self._strixel_out_shape, + dtype=self.output_dtype, + ) + concurrent.futures.wait( + [ + self.thread_pool.submit( + self.correction_kernel_strixel, + data=pixel_frame, + out=strixel_frame, + ) + for pixel_frame, strixel_frame in zip( + self._processed_data_regular, self._processed_data_strixel + ) + ] ) self.processed_data = self._processed_data_strixel else: @@ -530,6 +570,19 @@ class JungfrauCorrection(base_correction.BaseCorrection): .key("corrections.strixel.preview") .setNewDefaultValue(False) .commit(), + + STRING_ELEMENT(expected) + .key("corrections.strixel.type") + .description( + "Which kind of strixel layout is used for this module? cols_A0123 is " + "the first strixel layout deployed at HED and rows_A1256 is the one " + "later deployed at SCS." + ) + .assignmentOptional() + .defaultValue("cols_A0123(HED-type)") + .options("cols_A0123(HED-type),rows_A1256(SCS-type)") + .reconfigurable() + .commit(), ) base_correction.add_bad_pixel_config_node(expected) JungfrauCalcatFriend.add_schema(expected) @@ -575,6 +628,8 @@ class JungfrauCorrection(base_correction.BaseCorrection): def _kernel_runner_init_args(self): return { "bad_pixel_mask_value": self.bad_pixel_mask_value, + # temporary: will refactor base class to always pass config node + "config": self.get("corrections"), } @property @@ -673,6 +728,11 @@ class JungfrauCorrection(base_correction.BaseCorrection): constant_data &= self._override_bad_pixel_flags self.kernel_runner.load_constant(constant, constant_data) + def preReconfigure(self, config): + super().preReconfigure(config) + if config.has("corrections"): + self.kernel_runner.reconfigure(config["corrections"]) + def postReconfigure(self): super().postReconfigure() diff --git a/src/calng/kernels/jungfrau_cpu.pyx b/src/calng/kernels/jungfrau_cpu.pyx index 9b6072d41fa3c588941e400959b38728dabd8643..2a200063879007b54f92c762378863c0d3e71c22 100644 --- a/src/calng/kernels/jungfrau_cpu.pyx +++ b/src/calng/kernels/jungfrau_cpu.pyx @@ -88,33 +88,3 @@ def correct_single( if (flags & REL_GAIN): corrected = corrected / relgain_map[0, y, x, gain] output[0, y, x] = corrected - - -def strixel_transform( - float[:, :, ::contiguous] image_data, - float[:, :, ::contiguous] output -): - cdef int yin, xin, igap, ichip, xout, yout, frame - - for frame in range(image_data.shape[0]): - for yin in range(256) : - yout = int(yin / 3) - for xin in range(1024) : - ichip = <int>(xin / 256) - xout = (ichip * 774) + (xin % 256) * 3 + yin % 3 - # 774 is the chip period, 256*3+6 - output[frame, yout, xout] = image_data[frame, yin, xin] - # now the gap pixels... - for yin in range(256): - yout = <int>(yin / 6) * 2 - for igap in range(3) : - # first the left side of gap - xin = igap * 256 + 255 - xout = igap * 774 + 765 + yin % 6 - output[frame, yout, xout] = image_data[frame, yin, xin] - output[frame, yout+1, xout] = image_data[frame, yin, xin] - # then the right side is mirrored - xin = igap * 256 + 255 + 1 - xout = igap * 774 + 765 + 11 - yin % 6 - output[frame, yout, xout] = image_data[frame, yin, xin] - output[frame, yout+1, xout] = image_data[frame, yin, xin] diff --git a/src/calng/kernels/jungfrau_gpu.cu b/src/calng/kernels/jungfrau_gpu.cu index fc2bc326e125ddfcbe4884174ca498ff8c7e5a56..a4f148f1c49d172b31f91a6a1a4ae4ec13c61c49 100644 --- a/src/calng/kernels/jungfrau_gpu.cu +++ b/src/calng/kernels/jungfrau_gpu.cu @@ -83,108 +83,4 @@ extern "C" { output[data_index] = ({{output_data_dtype}})res; {% endif %} } - - __global__ void strixel_transform(const {{output_data_dtype}}* data, // shape: memory cell, y, x - {{output_data_dtype}}* output) { - const size_t Xin = {{pixels_x}}; - const size_t Yin = {{pixels_y}}; - const size_t Xout = 3090; - const size_t Yout = 86; - const size_t input_frames = {{frames}}; - - const size_t current_frame = blockIdx.x * blockDim.x + threadIdx.x; - // following naming from cython version - const size_t yin = blockIdx.y * blockDim.y + threadIdx.y; - size_t xin = blockIdx.z * blockDim.z + threadIdx.z; - - // note: hardcoded limits here as only half of y-axis is used - if (current_frame >= input_frames || yin >= 256 || xin >= 1024) { - return; - } - - // avoid race conditions by only writing these once - const size_t overwritten_columns[18] = { - 765, - 766, - 767, - 774, - 775, - 776, - 1539, - 1540, - 1541, - 1548, - 1549, - 1550, - 2313, - 2314, - 2315, - 2322, - 2323, - 2324 - }; - - const size_t data_stride_x = 1; - const size_t data_stride_y = Xin * data_stride_x; - const size_t data_stride_frame = Yin * data_stride_y; - - const size_t output_stride_x = 1; - const size_t output_stride_y = Xout * output_stride_x; - const size_t output_stride_frame = Yout * output_stride_y; - - const size_t ichip = xin / 256; - size_t xout = (ichip * 774) + (xin % 256) * 3 + (yin % 3); - size_t yout = yin / 3; - bool will_be_overwritten = false; - size_t out_index, data_index; - for (int i=0; i<18; ++i) { - if (xout == overwritten_columns[i]) { - will_be_overwritten = true; - } - } - if (!will_be_overwritten) { - out_index = current_frame * output_stride_frame + - yout * output_stride_y + - xout * output_stride_x; - data_index = current_frame * data_stride_frame + - yin * data_stride_y + - xin * data_stride_x; - output[out_index] = data[data_index]; - } - if (xin < 3) { - // reuse for the gap pixel case (see cython version) - const size_t igap = xin; - yout = (yin / 6) * 2; - - // left side - xin = igap * 256 + 255; - xout = igap * 774 + 765 + yin % 6; - data_index = current_frame * data_stride_frame + - yin * data_stride_y + - xin * data_stride_x; - out_index = current_frame * output_stride_frame + - yout * output_stride_y + - xout * output_stride_x; - output[out_index] = data[data_index]; - out_index = current_frame * output_stride_frame + - (yout + 1) * output_stride_y + - xout * output_stride_x; - output[out_index] = data[data_index]; - - // mirror right side - xin = igap * 256 + 255 + 1; - xout = igap * 774 + 765 + 11 - yin % 6; - data_index = current_frame * data_stride_frame + - yin * data_stride_y + - xin * data_stride_x; - out_index = current_frame * output_stride_frame + - yout * output_stride_y + - xout * output_stride_x; - output[out_index] = data[data_index]; - out_index = current_frame * output_stride_frame + - (yout + 1) * output_stride_y + - xout * output_stride_x; - output[out_index] = data[data_index]; - } - } } diff --git a/src/calng/kernels/strixel_cols_A0123-lut_mask.npz b/src/calng/kernels/strixel_cols_A0123-lut_mask.npz new file mode 100644 index 0000000000000000000000000000000000000000..bdef03cb602e67db2c4732733b3053a9afa35940 Binary files /dev/null and b/src/calng/kernels/strixel_cols_A0123-lut_mask.npz differ diff --git a/src/calng/kernels/strixel_rows_A1256-lut_mask.npz b/src/calng/kernels/strixel_rows_A1256-lut_mask.npz new file mode 100644 index 0000000000000000000000000000000000000000..be31b53bb7bb0fe2cf1060468e21438793c33dd3 Binary files /dev/null and b/src/calng/kernels/strixel_rows_A1256-lut_mask.npz differ diff --git a/src/calng/utils.py b/src/calng/utils.py index c97ecb47b7ac7727574c8096090aa3bc95880c65..08815412922d1793f3587cfbe80cba860c1d895c 100644 --- a/src/calng/utils.py +++ b/src/calng/utils.py @@ -569,3 +569,9 @@ def cell_table_to_string(cell_table): def grid_to_cover_shape_with_blocks(full_shape, block_shape): return tuple(itertools.starmap(ceil_div, zip(full_shape, block_shape))) + + +def apply_partial_lut(data, lut, mask, out, missing=np.nan): + tmp = out.ravel() + tmp[~mask] = data.ravel()[lut] + tmp[mask] = missing diff --git a/src/tests/test_strixel.py b/src/tests/test_strixel.py new file mode 100644 index 0000000000000000000000000000000000000000..8160f444ee36894ef212afb539f688e82430d3a0 --- /dev/null +++ b/src/tests/test_strixel.py @@ -0,0 +1,101 @@ +import numpy as np + +from calng import base_kernel_runner, utils + + +pixel_data = np.arange(512 * 1024).reshape(512, 1024).astype(np.float32) + + +def test_cols_A0123(): + def convert_cols_A0123(data, out=None): + # https://redmine.xfel.eu/issues/126444 + if out is None: + out = np.zeros((86, (1024 * 3 + 18)), dtype=np.float32) + ## 256 not divisible by 3, so we round up + ## 18 since we have 6 more pixels in H per gap + # firs we fill the normal pixels, the gap ones will be overwritten later + for yin in range(256): + for xin in range(1024): + ichip = int(xin / 256) + xout = (ichip * 774) + (xin % 256) * 3 + yin % 3 + ## 774 is the chip period, 256*3+6 + yout = int(yin / 3) + out[yout, xout] = data[yin, xin] + # now the gap pixels... + for igap in range(3): + for yin in range(256): + yout = int(yin / 6) * 2 + # first the left side of gap + xin = igap * 256 + 255 + xout = igap * 774 + 765 + yin % 6 + out[yout, xout] = data[yin, xin] + out[yout + 1, xout] = data[yin, xin] + # then the right side is mirrored + xin = igap * 256 + 255 + 1 + xout = igap * 774 + 765 + 11 - yin % 6 + out[yout, xout] = data[yin, xin] + out[yout + 1, xout] = data[yin, xin] + # out[yout, xout] = out[yout, xout] / 2 + # if we want a proper normalization (the area of those pixels is double, so they see 2x the signal) + return out + + strixel_naive = convert_cols_A0123(pixel_data) + lut_package = np.load( + base_kernel_runner.kernel_dir / "strixel_cols_A0123-lut_mask.npz" + ) + assert tuple(lut_package["frame_shape"]) == strixel_naive.shape + strixel_lut = np.zeros_like(strixel_naive) + utils.apply_partial_lut( + data=pixel_data, + lut=lut_package["lut"], + mask=lut_package["mask"], + out=strixel_lut, + missing=0, + ) + assert np.array_equal(strixel_naive, strixel_lut) + + +def test_rows_A1256(): + def convert_rows_A1256(data, out=None): + # https://redmine.xfel.eu/issues/161148 + if out is None: + out = np.zeros([*data.shape[:-2], 1488, 165], dtype=np.float32) + + # Select only 4 center ASICS + data = data[..., 256:768] + + # Offset due to guard ring pixels + offset_y = 9 + offset_x = 11 + + for xin in range( + offset_x, 512 - 9 + ): # on the right side, there are 9 guard ring pixels + # Bottom ASICs + for yin in range(offset_y, 255): + yout = (xin - offset_x) % 3 + (yin - offset_y) * 3 + xout = (xin - offset_x) // 3 + (xin) // 257 + out[..., yout, xout] = data[..., yin, xin] + + # Top ASICs (mirrored on the horizontal axis) + for yin in range(257, 512 - offset_y): + yout = 2 - (xin - offset_x) % 3 + (yin - offset_y) * 3 + 6 + xout = (xin - offset_x) // 3 + (xin) // 257 + out[..., yout, xout] = data[..., yin, xin] + + return out + + strixel_naive = convert_rows_A1256(pixel_data) + lut_package = np.load( + base_kernel_runner.kernel_dir / "strixel_rows_A1256-lut_mask.npz" + ) + assert tuple(lut_package["frame_shape"]) == strixel_naive.shape + strixel_lut = np.zeros_like(strixel_naive) + utils.apply_partial_lut( + data=pixel_data, + lut=lut_package["lut"], + mask=lut_package["mask"], + out=strixel_lut, + missing=0, + ) + assert np.array_equal(strixel_naive, strixel_lut)