Skip to content
Snippets Groups Projects
Commit e3ebc723 authored by David Hammer's avatar David Hammer
Browse files

Improve dtype conversion for C(++) code templating

parent 44e7fb2f
No related branches found
No related tags found
2 merge requests!12Snapshot: field test deployed version as of end of run 202201,!3Base correction device, CalCat interaction, DSSC and AGIPD devices
......@@ -224,12 +224,8 @@ class DsscGpuRunner:
"pixels_y": self.pixels_y,
"data_memory_cells": self.memory_cells,
"constant_memory_cells": self.constant_memory_cells,
"input_data_dtype": utils.numpy_dtype_to_c_type_str[
self.input_data_dtype
],
"output_data_dtype": utils.numpy_dtype_to_c_type_str[
self.output_data_dtype
],
"input_data_dtype": utils.np_dtype_to_c_type(self.input_data_dtype),
"output_data_dtype": utils.np_dtype_to_c_type(self.output_data_dtype),
}
)
self.source_module = cupy.RawModule(code=kernel_source)
......
......@@ -4,22 +4,40 @@ import timeit
import numpy as np
numpy_dtype_to_c_type_str = {
np.uint16: "unsigned short",
np.uint32: "unsigned short",
np.float16: "half", # warning: only in CUDA with special support
np.float32: "float",
np.float64: "double",
_np_typechar_to_c_typestring = {
"?": "bool",
"B": "unsigned char",
"D": "double complex",
"F": "float complex",
"G": "long double complex",
"H": "unsigned short",
"I": "unsigned int",
"L": "unsigned long",
"Q": "unsigned long long",
"b": "char",
"d": "double",
"e": "half", # warning: only in CUDA with special support
"f": "float",
"g": "long double",
"h": "short",
"i": "int",
"l": "long",
"q": "long long",
}
def np_dtype_to_c_type(dtype):
as_char = np.sctype2char(dtype)
return _np_typechar_to_c_typestring[as_char]
def ceil_div(num, denom):
return (num + denom - 1) // denom
def shape_after_transpose(input_shape, transpose_pattern, squeeze=True):
if squeeze:
input_shape = tuple(dim for dim in input_shape if dim>1)
input_shape = tuple(dim for dim in input_shape if dim > 1)
if transpose_pattern is None:
return input_shape
return tuple(np.array(input_shape)[list(transpose_pattern)].tolist())
......
import numpy as np
from calng import utils
def test_get_c_type():
assert utils.np_dtype_to_c_type(np.float16) == "half"
assert utils.np_dtype_to_c_type(np.float32) == "float"
assert utils.np_dtype_to_c_type(np.float64) == "double"
assert utils.np_dtype_to_c_type(np.uint8) == "unsigned char"
assert utils.np_dtype_to_c_type(np.uint16) == "unsigned short"
assert utils.np_dtype_to_c_type(np.uint32) in ("unsigned", "unsigned int")
assert utils.np_dtype_to_c_type(np.uint64) == "unsigned long"
assert utils.np_dtype_to_c_type(np.int8) == "char"
assert utils.np_dtype_to_c_type(np.int16) == "short"
assert utils.np_dtype_to_c_type(np.int32) == "int"
assert utils.np_dtype_to_c_type(np.int64) == "long"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment