Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <cuda_fp16.h>
// Memory cell indices kept by the pulse filter, rendered into the source by the template.
// reshape_4_3 copies input memory cell pulse_filter[k] into filtered output cell k.
__device__ unsigned short pulse_filter[{{pulse_filter|length}}] = { {{pulse_filter|join(', ')}} };
extern "C" {
/*
Reshuffle data from shape like (400, 1, 128, 512) to shape like (512, 128, <=400).
That is, (cell, singleton axis, y, x) becomes (x, y, filtered cell).
Applies the pulse filter: output cell k is copied from input cell pulse_filter[k],
so only a subset of indices along the memory cell axis is kept.
Equivalent to np.moveaxis(np.squeeze(data[pulse_filter], 1), (0, 1, 2), (2, 1, 0))
Expected launch: 3D grid whose x / y / z extents cover pixels_x / pixels_y /
pulse filter length; threads beyond those extents do nothing.
*/
__global__ void reshape_4_3(const {{input_data_dtype}}* data,
                            {{input_data_dtype}}* output) {
    const size_t X = {{pixels_x}};
    const size_t Y = {{pixels_y}};
    const size_t extra_dim = 1;  // size of the singleton axis 1 in the incoming data
    const size_t pulse_filter_size = {{pulse_filter|length}};

    const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
    const size_t cell = blockIdx.z * blockDim.z + threadIdx.z;

    // the grid may overshoot the data extents; such threads have nothing to copy
    const bool inside = x < X && y < Y && cell < pulse_filter_size;
    if (!inside) {
        return;
    }

    // input layout (cell, extra_dim, Y, X), x contiguous; strides in elements
    const size_t in_row_step = X;                            // between consecutive y
    const size_t in_extra_step = in_row_step * Y;            // along the squeezed axis (always index 0)
    const size_t in_cell_step = in_extra_step * extra_dim;   // between memory cells
    const size_t src = (size_t)pulse_filter[cell] * in_cell_step + y * in_row_step + x;

    // output layout (X, Y, filtered cell), cell contiguous
    const size_t dst = (x * Y + y) * pulse_filter_size + cell;

    output[dst] = data[src];
}
/*
Perform correction: offset subtraction.
The correction map cell for filtered cell k is looked up through cell_table[k].
Values are converted to float for the correction and converted to the output
dtype at the end. If the looked-up map cell lies outside the correction map
(cell_table[k] >= constant_memory_cells), the raw value is passed through uncorrected.
data / output layout: (pixels_x, pixels_y, pulse filter length), cell contiguous.
offset_map layout: (pixels_x, pixels_y, constant_memory_cells), cell contiguous.
Expected launch: 3D grid whose x / y / z extents cover pixels_x / pixels_y /
pulse filter length; threads beyond those extents do nothing.
*/
__global__ void correct(const {{input_data_dtype}}* data,
                        const unsigned short* cell_table,
                        const float* offset_map,
                        {{output_data_dtype}}* output) {
    const size_t X = {{pixels_x}};
    const size_t Y = {{pixels_y}};
    // reshaped and output data have pulse filter length memory cells dim
    const size_t filtered_memory_cells = {{pulse_filter|length}};
    // but correction map has some number which may even exceed input data's (due to veto pattern)
    const size_t map_memory_cells = {{constant_memory_cells}};
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t j = blockIdx.y * blockDim.y + threadIdx.y;
    const size_t k = blockIdx.z * blockDim.z + threadIdx.z;
    if (i >= X || j >= Y || k >= filtered_memory_cells) {
        return;
    }

    // note: strides differ from numpy strides because unit here is sizeof(...), not byte
    const size_t data_stride_2 = 1;
    const size_t data_stride_1 = filtered_memory_cells * data_stride_2;
    const size_t data_stride_0 = Y * data_stride_1;
    const size_t data_index = i * data_stride_0 + j * data_stride_1 + k * data_stride_2;
    const float raw = (float)data[data_index];

    // correction map strides (in elements); the cell axis is the contiguous one
    const size_t map_stride_2 = 1;
    const size_t map_stride_1 = map_memory_cells * map_stride_2;
    const size_t map_stride_0 = Y * map_stride_1;
    const size_t map_cell = cell_table[k];

    float result = raw;
    if (map_cell < map_memory_cells) {
        const size_t map_index = i * map_stride_0 + j * map_stride_1 + map_cell * map_stride_2;
        result = raw - offset_map[map_index];
    }

    // single conversion point for both corrected and passed-through values
    {% if output_data_dtype == "half" %}
    output[data_index] = __float2half(result);
    {% else %}
    output[data_index] = ({{output_data_dtype}})result;
    {% endif %}
}
/*
Same as correct, except no correction is applied: each element is converted
from the input dtype to float and then to the output dtype.
data / output layout: (pixels_x, pixels_y, pulse filter length), cell contiguous.
Expected launch: 3D grid whose x / y / z extents cover pixels_x / pixels_y /
pulse filter length; threads beyond those extents do nothing.
*/
__global__ void only_cast(const {{input_data_dtype}}* data,
                          {{output_data_dtype}}* output) {
    const size_t X = {{pixels_x}};
    const size_t Y = {{pixels_y}};
    const size_t memory_cells = {{pulse_filter|length}};

    const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
    const size_t cell = blockIdx.z * blockDim.z + threadIdx.z;

    if (x < X && y < Y && cell < memory_cells) {
        // (x, y, cell) with cell contiguous, strides in elements
        const size_t idx = (x * Y + y) * memory_cells + cell;
        const float value = (float)data[idx];
        {% if output_data_dtype == "half" %}
        output[idx] = __float2half(value);
        {% else %}
        output[idx] = ({{output_data_dtype}})value;
        {% endif %}
    }
}
/* Kernels for preview
≥0: just slice desired cell; uses cell_slice_preview_*
-1: slice cell with max integrated intensity (hybrid)
-2: mean; uses cell_stat_preview_*
-3: sum; ditto
-4: stdev; ditto
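Note: the template for loop below emits each preview kernel twice, once for raw data (input dtype) and once for corrected data (output dtype), because the two dtypes differ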
TODO: simplify
- [ ] set up C++ compilation chain on ONC
- [ ] use C++ templates
Or even better:
- [ ] switch to cupy, all of this becomes trivial
*/
{% for (name_suffix, data_dtype) in (("raw", input_data_dtype), ("corrected", output_data_dtype)) %}
/*
Copy a single memory cell of every pixel into the 2D preview.
data layout: (pixels_x, pixels_y, pulse filter length), cell contiguous.
preview layout: (pixels_x, pixels_y), y contiguous.
cell_to_preview indexes the filtered (reshaped) cell axis. A value outside
[0, pulse filter length) would read outside data, so the preview is filled with NaN instead.
Expected launch: 2D grid whose x / y extents cover pixels_x / pixels_y.
*/
__global__ void cell_slice_preview_{{name_suffix}}(const {{data_dtype}}* data,
                                                   const short cell_to_preview,
                                                   float* preview) {
    const size_t X = {{pixels_x}};
    const size_t Y = {{pixels_y}};
    const size_t filtered_memory_cells = {{pulse_filter|length}};
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= X || j >= Y) {
        return;
    }
    const size_t preview_stride_1 = 1;
    const size_t preview_stride_0 = Y * preview_stride_1;
    const size_t preview_index = i * preview_stride_0 + j * preview_stride_1;

    // reject out-of-range cells explicitly instead of reading past the buffer
    if (cell_to_preview < 0 || (size_t)cell_to_preview >= filtered_memory_cells) {
        preview[preview_index] = nanf("");
        return;
    }

    const size_t data_stride_2 = 1;
    const size_t data_stride_1 = filtered_memory_cells * data_stride_2;
    const size_t data_stride_0 = Y * data_stride_1;
    const size_t data_index = i * data_stride_0 + j * data_stride_1 + (size_t)cell_to_preview * data_stride_2;
    preview[preview_index] = (float)data[data_index];
}
/*
Reduce the memory cell axis of every pixel into the 2D preview.
preview_stat selects the statistic:
  -3: sum over cells
  -2: mean over cells
  -4: population standard deviation over cells (divides by the cell count)
Any other value leaves preview unwritten.
data layout: (pixels_x, pixels_y, pulse filter length), cell contiguous.
preview layout: (pixels_x, pixels_y), y contiguous.
Expected launch: 2D grid whose x / y extents cover pixels_x / pixels_y; each
thread loops over all filtered cells of its pixel.
*/
__global__ void cell_stat_preview_{{name_suffix}}(const {{data_dtype}}* data,
                                                  const short preview_stat,
                                                  float* preview) {
    const size_t X = {{pixels_x}};
    const size_t Y = {{pixels_y}};
    const size_t filtered_memory_cells = {{pulse_filter|length}};
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= X || j >= Y) {
        return;
    }
    const size_t preview_stride_1 = 1;
    const size_t preview_stride_0 = Y * preview_stride_1;
    const size_t preview_index = i * preview_stride_0 + j * preview_stride_1;
    const size_t data_stride_2 = 1;
    const size_t data_stride_1 = filtered_memory_cells * data_stride_2;
    const size_t data_stride_0 = Y * data_stride_1;
    // start of this pixel's run of memory cells
    const size_t pixel_offset = i * data_stride_0 + j * data_stride_1;

    float sum = 0;
    for (size_t k = 0; k < filtered_memory_cells; ++k) {
        sum += (float)data[pixel_offset + k * data_stride_2];
    }

    if (preview_stat == -3) {
        // just sum
        preview[preview_index] = sum;
    } else if (preview_stat == -2) {
        // mean
        preview[preview_index] = sum / filtered_memory_cells;
    } else if (preview_stat == -4) {
        // standard deviation; squared deviations are accumulated in double to reduce rounding error
        const double mean = sum / filtered_memory_cells;
        double var = 0;
        for (size_t k = 0; k < filtered_memory_cells; ++k) {
            const double deviation = (double)data[pixel_offset + k * data_stride_2] - mean;
            // plain multiplication instead of pow(x, 2): cheaper and correctly rounded
            var += deviation * deviation;
        }
        var /= filtered_memory_cells;
        preview[preview_index] = (float)sqrt(var);
    }
}
{% endfor %}
// Sum all pixels of each filtered memory cell frame into sums[memory_cell];
// used to find the frame with max integrated intensity.
// data layout: (pixels_x, pixels_y, pulse filter length), cell contiguous; data is only read.
// Only the z launch dimension selects work: each thread sums one whole frame.
// Threads that share a z index (if x/y extents > 1) write the same value.
__global__ void sum_frames(const {{output_data_dtype}}* data, float* sums) {
    const size_t X = {{pixels_x}};
    const size_t Y = {{pixels_y}};
    const size_t filtered_memory_cells = {{pulse_filter|length}};
    const size_t memory_cell = blockIdx.z * blockDim.z + threadIdx.z;
    if (memory_cell >= filtered_memory_cells) {
        return;
    }
    const size_t data_stride_2 = 1;
    const size_t data_stride_1 = filtered_memory_cells * data_stride_2;
    const size_t data_stride_0 = Y * data_stride_1;
    float my_res = 0;
    // size_t indices match the size_t bounds (no signed/unsigned comparison)
    for (size_t i = 0; i < X; ++i) {
        for (size_t j = 0; j < Y; ++j) {
            const size_t data_index = i * data_stride_0 +
                                      j * data_stride_1 +
                                      memory_cell * data_stride_2;
            my_res += (float)data[data_index];
        }
    }
    sums[memory_cell] = my_res;
}
}