// gpu-dssc-correct.cpp
#include <cuda_fp16.h>

// Device-side table of memory cell indices; its length and contents are filled in
// when the template is rendered. reshape_4_3 reads pulse_filter[k] to find which
// input memory cell supplies output cell k.
__device__ unsigned short pulse_filter[{{pulse_filter|length}}] = { {{pulse_filter|join(', ')}} };

extern "C" {
	/*
	  Select memory cells and transpose in a single pass.
	  Input layout:  (cell, 1, y, x), x fastest; e.g. (400, 1, 128, 512)
	  Output layout: (x, y, filtered cell), filtered cell fastest; e.g. (512, 128, <=400)
	  Output cell k is copied from input cell pulse_filter[k].
	  numpy equivalent:
	    np.moveaxis(np.squeeze(data, 1)[pulse_filter], (0, 1, 2), (2, 1, 0))
	  One thread per output element; the thread index along x / y / z selects
	  pixel x / pixel y / filtered cell.
	*/
	__global__ void reshape_4_3(const {{input_data_dtype}}* data,
								{{input_data_dtype}}* output) {
		const size_t n_x = {{pixels_x}};
		const size_t n_y = {{pixels_y}};
		const size_t n_squeezed = 1; // length of the extra axis between cell and y in the input
		const size_t n_selected = {{pulse_filter|length}};

		const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
		const size_t cell = blockIdx.z * blockDim.z + threadIdx.z;

		// the grid may overhang the data; surplus threads have nothing to do
		const bool inside = (x < n_x) && (y < n_y) && (cell < n_selected);
		if (!inside) {
			return;
		}

		// input is row-major over (cell, squeezed, y, x); the squeezed index is always 0
		const size_t source_cell = pulse_filter[cell];
		const size_t source = ((source_cell * n_squeezed + 0) * n_y + y) * n_x + x;

		// output is row-major over (x, y, filtered cell)
		const size_t target = (x * n_y + y) * n_selected + cell;

		output[target] = data[source];
	}

	/*
	  Offset correction.
	  data and output share the layout (x, y, filtered cell), filtered cell fastest.
	  offset_map is laid out as (x, y, map cell), map cell fastest.
	  For filtered cell k, cell_table[k] selects the map cell whose offset is subtracted.
	  If cell_table[k] lies outside the map, the value is passed through uncorrected.
	  Arithmetic is done in float; the result is converted to the output dtype on store.
	  One thread per element; the thread index along x / y / z selects
	  pixel x / pixel y / filtered cell.
	*/
	__global__ void correct(const {{input_data_dtype}}* data,
							const unsigned short* cell_table,
							const float* offset_map,
							{{output_data_dtype}}* output) {
		const size_t n_x = {{pixels_x}};
		const size_t n_y = {{pixels_y}};
		// number of memory cells in the (already filtered) data and in the output
		const size_t n_cells = {{pulse_filter|length}};
		// number of memory cells in the correction map; may differ from the data's
		const size_t n_map_cells = {{constant_memory_cells}};

		const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
		const size_t cell = blockIdx.z * blockDim.z + threadIdx.z;

		const bool inside = (x < n_x) && (y < n_y) && (cell < n_cells);
		if (!inside) {
			return;
		}

		// indices count elements, not bytes
		const size_t pixel = x * n_y + y;
		const size_t element = pixel * n_cells + cell;

		float value = (float)data[element];

		// subtract the offset only when the cell table points inside the map
		const size_t map_cell = cell_table[cell];
		if (map_cell < n_map_cells) {
			value = value - offset_map[pixel * n_map_cells + map_cell];
		}

		{% if output_data_dtype == "half" %}
		output[element] = __float2half(value);
		{% else %}
		output[element] = ({{output_data_dtype}})value;
		{% endif %}
	}

	/*
	  Pass-through variant of the correction kernel: no correction is applied,
	  each element is converted to float and then to the output dtype.
	  data and output share the layout (x, y, filtered cell), filtered cell fastest.
	  One thread per element; the thread index along x / y / z selects
	  pixel x / pixel y / filtered cell.
	*/
	__global__ void only_cast(const {{input_data_dtype}}* data,
							  {{output_data_dtype}}* output) {
		const size_t n_x = {{pixels_x}};
		const size_t n_y = {{pixels_y}};
		const size_t n_cells = {{pulse_filter|length}};

		const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;
		const size_t cell = blockIdx.z * blockDim.z + threadIdx.z;

		const bool inside = (x < n_x) && (y < n_y) && (cell < n_cells);
		if (!inside) {
			return;
		}

		// same flat index for input and output
		const size_t element = (x * n_y + y) * n_cells + cell;
		const float value = (float)data[element];
		{% if output_data_dtype == "half" %}
		output[element] = __float2half(value);
		{% else %}
		output[element] = ({{output_data_dtype}})value;
		{% endif %}
	}

	/* Kernels for preview
	   ≥0: just slice desired cell; uses cell_slice_preview_*
	   -1: slice cell with max integrated intensity (hybrid)
	   -2: mean; uses cell_stat_preview_*
	   -3: sum; ditto
	   -4: stdev; ditto
	   Note: for loop in template due to differing dtypes
	   TODO: simplify
	   - [ ] set up C++ compilation chain on ONC
	   - [ ] use C++ templates
	   Or even better:
	   - [ ] switch to cupy, all of this becomes trivial
	*/
	{% for (name_suffix, data_dtype) in (("raw", input_data_dtype), ("corrected", output_data_dtype)) %}
	// Preview of a single memory cell: for every pixel, copies the value of cell
	// cell_to_preview (converted to float) into preview.
	// data is laid out as (x, y, filtered cell), filtered cell fastest; preview as (x, y).
	// One thread per pixel, using the x and y thread indices only.
	// NOTE(review): cell_to_preview is not range-checked here; presumably the caller
	// guarantees 0 <= cell_to_preview < number of filtered cells - confirm.
	__global__ void cell_slice_preview_{{name_suffix}}(const {{data_dtype}}* data,
													   const short cell_to_preview,
													   float* preview) {
		const size_t n_x = {{pixels_x}};
		const size_t n_y = {{pixels_y}};
		const size_t n_cells = {{pulse_filter|length}};

		const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;

		const bool inside = (x < n_x) && (y < n_y);
		if (!inside) {
			return;
		}

		// flat pixel index, shared by the preview and (scaled by n_cells) the data
		const size_t pixel = x * n_y + y;
		const size_t source = pixel * n_cells + (size_t)cell_to_preview;

		preview[pixel] = (float)data[source];
	}

	// Preview of a per-pixel statistic taken over all filtered memory cells.
	// preview_stat selects the statistic: -3 sum, -2 mean, -4 standard deviation
	// (dividing by the number of cells). Any other value leaves preview untouched.
	// data is laid out as (x, y, filtered cell), filtered cell fastest; preview as (x, y).
	// One thread per pixel, using the x and y thread indices only.
	__global__ void cell_stat_preview_{{name_suffix}}(const {{data_dtype}}* data,
													  const short preview_stat,
													  float* preview) {
		const size_t n_x = {{pixels_x}};
		const size_t n_y = {{pixels_y}};
		const size_t n_cells = {{pulse_filter|length}};

		const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
		const size_t y = blockIdx.y * blockDim.y + threadIdx.y;

		const bool inside = (x < n_x) && (y < n_y);
		if (!inside) {
			return;
		}

		const size_t pixel = x * n_y + y;
		// all cells of one pixel are contiguous in data
		const {{data_dtype}}* cells = data + pixel * n_cells;

		// the sum is accumulated in single precision for every statistic
		float total = 0;
		for (size_t c = 0; c < n_cells; ++c) {
			total += (float)cells[c];
		}

		switch (preview_stat) {
		case -3:
			preview[pixel] = total;
			break;
		case -2:
			preview[pixel] = total / n_cells;
			break;
		case -4: {
			// mean comes from the float sum; squared deviations are formed and
			// accumulated in double precision to limit rounding error
			const double mean = total / n_cells;
			double spread = 0;
			for (size_t c = 0; c < n_cells; ++c) {
				spread += pow((double)cells[c] - mean, 2);
			}
			spread /= n_cells;
			preview[pixel] = (float)sqrt(spread);
			break;
		}
		default:
			// unknown statistic: nothing is written
			break;
		}
	}
	{% endfor %}

	// Integrated intensity per memory cell; used to find the max integrated intensity frame.
	// data is laid out as (x, y, filtered cell), filtered cell fastest, and is only read.
	// sums receives one float per filtered memory cell, so it needs room for at least
	// that many elements.
	// Only the z thread index is used: each thread walks over every pixel of one
	// memory cell and accumulates the values in single precision.
	__global__ void sum_frames(const {{output_data_dtype}}* data, float* sums) {
		const size_t X = {{pixels_x}};
		const size_t Y = {{pixels_y}};
		const size_t filtered_memory_cells = {{pulse_filter|length}};

		const size_t memory_cell = blockIdx.z * blockDim.z + threadIdx.z;

		if (memory_cell >= filtered_memory_cells) {
			// in case block size doesn't fit perfectly, some threads do nothing
			return;
		}

		// strides count elements, not bytes
		const size_t data_stride_2 = 1;
		const size_t data_stride_1 = filtered_memory_cells * data_stride_2;
		const size_t data_stride_0 = Y * data_stride_1;

		float my_res = 0;
		// loop counters are size_t to match the type of the bounds and strides
		for (size_t i = 0; i < X; ++i) {
			for (size_t j = 0; j < Y; ++j) {
				const size_t data_index = i * data_stride_0 +
					j * data_stride_1 +
					memory_cell * data_stride_2;
				my_res += (float)data[data_index];
			}
		}
		sums[memory_cell] = my_res;
	}
}