import pathlib

import jinja2
import numpy as np
import posixshmem
import pycuda.compiler
import pycuda.gpuarray
import shmem_utils
import utils


class PyCudaPipeline:
    """Class to handle instantiation and execution of CUDA kernels on trains

    Objects of this class will also maintain their own circular buffers of
    ndarrays in shared memory to allow zero-copy handover of corrected data.

    Kernel source is produced by rendering the Jinja2 template
    ``gpu-dssc-correct.cpp`` (located next to this file) with the array
    dimensions and dtypes, and is compiled with pycuda.  Kernels are
    recompiled when :meth:`load_constants` sees a different number of
    constant memory cells than the kernels were compiled for.
    """

    _src_dir = pathlib.Path(__file__).absolute().parent
    # Template is read once at class-definition time; each instance renders
    # and compiles its own kernel source from it.
    with (_src_dir / "gpu-dssc-correct.cpp").open(
        "r", encoding="utf-8"
    ) as fd:
        _kernel_template = jinja2.Template(fd.read())

    def __init__(
        self,
        pixels_x,
        pixels_y,
        memory_cells,
        pulse_filter,
        output_buffer_size=20,
        output_buffer_name=None,
        input_data_dtype=np.uint16,
        output_data_dtype=np.float32,
    ):
        """Compile kernels, allocate GPU buffers and the shared memory output

        Parameters
        ----------
        pixels_x, pixels_y : int
            First two dimensions of the output (and preview) arrays.
        memory_cells : int
            Passed to the kernel template; also the (exclusive) upper bound
            accepted for ``preview_index`` in :meth:`compute_preview`.
        pulse_filter : array-like with ``.size``
            Passed to the kernel template; its ``size`` is the length of the
            last (cell) axis of the output arrays.
        output_buffer_size : int
            Number of slots in the circular shared memory output buffer.
        output_buffer_name : str, optional
            Name passed on to ``posixshmem.SharedMemory``.
        input_data_dtype, output_data_dtype : numpy dtype
            Element types used by the kernels; must be keys of
            ``utils.numpy_dtype_to_c_type_str``.
        """
        self.pixels_x = pixels_x
        self.pixels_y = pixels_y
        self.memory_cells = memory_cells
        # No constants loaded yet; load_constants updates this and recompiles.
        self.constant_memory_cells = 0
        self.pulse_filter = pulse_filter
        self.output_shape = (self.pixels_x, self.pixels_y, self.pulse_filter.size)
        self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
        # preview will only be single memory cell
        self.preview_shape = self.output_shape[:-1]
        self.input_data_dtype = input_data_dtype
        self.output_data_dtype = output_data_dtype

        self._init_kernels()

        # Empty (size 0) until load_constants is called.
        self.offset_map = pycuda.gpuarray.empty(self.map_shape, dtype=np.float32)

        # reuse output arrays
        self.gpu_result = pycuda.gpuarray.empty(
            self.output_shape, dtype=output_data_dtype
        )
        self.gpu_frame_sums = pycuda.gpuarray.empty(
            self.pulse_filter.size, dtype=np.float32
        )
        self.gpu_preview_raw = pycuda.gpuarray.empty(
            self.preview_shape, dtype=np.float32
        )
        self.gpu_preview_corrected = pycuda.gpuarray.empty(
            self.preview_shape, dtype=np.float32
        )
        self.preview_raw = np.empty(self.preview_shape, dtype=np.float32)
        self.preview_corrected = np.empty(self.preview_shape, dtype=np.float32)

        # Circular output buffer in shared memory: output_buffer_size slots,
        # each shaped and typed like gpu_result.
        self.output_buffer_mem = posixshmem.SharedMemory(
            name=output_buffer_name,
            size=self.gpu_result.nbytes * output_buffer_size,
            rw=True,
        )
        self.output_buffer_ary = self.output_buffer_mem.ndarray(
            shape=(output_buffer_size,) + self.gpu_result.shape,
            dtype=self.gpu_result.dtype,
        )
        self.output_buffer_handle_template = (
            shmem_utils.handle_template_from_shmem_array(
                self.output_buffer_mem, self.output_buffer_ary
            )
        )
        self.output_buffer_next_index = 0

        self.update_block_size(full_block=(1, 1, 64), preview_block=(1, 64, 1))

    def load_constants(self, offset_map_host):
        """Upload an offset map from host memory to the GPU

        The length of the last axis of ``offset_map_host`` is taken as the
        number of constant memory cells.  If it differs from the current
        value, the GPU offset map is reallocated with shape
        ``(pixels_x, pixels_y, cells)`` and the kernels are recompiled
        before the data is copied over.
        """
        constant_memory_cells = offset_map_host.shape[-1]
        if constant_memory_cells != self.constant_memory_cells:
            self.constant_memory_cells = constant_memory_cells
            self.map_shape = (self.pixels_x, self.pixels_y, self.constant_memory_cells)
            self.offset_map = pycuda.gpuarray.empty(self.map_shape, dtype=np.float32)
            self._init_kernels()
        self.offset_map.set(offset_map_host)

    def _init_kernels(self):
        """Render the kernel template with current settings and compile it"""
        kernel_source = self._kernel_template.render(
            {
                "pixels_x": self.pixels_x,
                "pixels_y": self.pixels_y,
                "memory_cells": self.memory_cells,
                "constant_memory_cells": self.constant_memory_cells,
                "input_data_dtype": utils.numpy_dtype_to_c_type_str[
                    self.input_data_dtype
                ],
                "output_data_dtype": utils.numpy_dtype_to_c_type_str[
                    self.output_data_dtype
                ],
                "pulse_filter": self.pulse_filter,
            }
        )
        # no_extern_c: the template is expected to manage its own linkage
        # (function names below are looked up unmangled).
        self.source_module = pycuda.compiler.SourceModule(
            kernel_source, no_extern_c=True
        )
        self.reshaping_kernel = self.source_module.get_function("reshape_4_3")
        self.correction_kernel = self.source_module.get_function("correct")
        self.casting_kernel = self.source_module.get_function("only_cast")
        self.preview_slice_raw_kernel = self.source_module.get_function(
            "cell_slice_preview_raw"
        )
        self.preview_slice_corrected_kernel = self.source_module.get_function(
            "cell_slice_preview_corrected"
        )
        self.preview_stat_raw_kernel = self.source_module.get_function(
            "cell_stat_preview_raw"
        )
        self.preview_stat_corrected_kernel = self.source_module.get_function(
            "cell_stat_preview_corrected"
        )
        self.frame_sum_kernel = self.source_module.get_function("sum_frames")

    def update_block_size(self, full_block=None, preview_block=None):
        """Execution is scheduled with 3d "blocks" of CUDA threads, tuning can
        affect performance

        Grid size is automatically computed based on block size. Note that
        individual kernels must themselves check whether they go out of bounds;
        grid dimensions get rounded up in case ndarray size is not multiple of
        block size.

        Parameters
        ----------
        full_block : sequence of 3 ints, optional
            Block shape for kernels over the full output array.
        preview_block : sequence of 3 ints, optional
            Block shape for preview kernels (grid covers the first two output
            axes only).

        Raises
        ------
        ValueError
            If a given block does not have exactly three dimensions.
        """
        if full_block is not None:
            if len(full_block) != 3:
                raise ValueError(f"full_block must be 3d, got {full_block!r}")
            self.full_block = tuple(full_block)
            self.full_grid = tuple(
                utils.ceil_div(a_length, block_length)
                for (a_length, block_length) in zip(self.output_shape, full_block)
            )
        if preview_block is not None:
            if len(preview_block) != 3:
                raise ValueError(
                    f"preview_block must be 3d, got {preview_block!r}"
                )
            self.preview_block = tuple(preview_block)
            self.preview_grid = (
                utils.ceil_div(self.output_shape[0], preview_block[0]),
                utils.ceil_div(self.output_shape[1], preview_block[1]),
                1,
            )
        # TODO: make configurable
        self.cell_reduction_block = (1, 1, 32)
        self.cell_reduction_grid = (
            1,
            1,
            utils.ceil_div(self.output_shape[-1], self.cell_reduction_block[-1]),
        )

    def reshape(self, input_data, output_data):
        """Do the reshaping and pulse filtering that the splitter would have done

        equivalent to:
        output_data[:] = np.moveaxis(
            np.squeeze(input_data), (0, 1, 2), (2, 1, 0)
        )[..., pulse_filter]
        """
        # TODO: Move to somewhere else
        self.reshaping_kernel(
            input_data, output_data, block=self.full_block, grid=self.full_grid
        )

    def _run_correction(self, data, cell_table):
        """Launch the correction kernel, writing into self.gpu_result"""
        self.correction_kernel(
            data,
            cell_table,
            self.offset_map,
            self.gpu_result,
            block=self.full_block,
            grid=self.full_grid,
        )

    def _run_casting(self, data):
        """Launch the cast-only kernel, writing into self.gpu_result"""
        self.casting_kernel(
            data,
            self.gpu_result,
            block=self.full_block,
            grid=self.full_grid,
        )

    def _hand_over_result(self):
        """Copy self.gpu_result into the next circular output buffer slot

        Returns the string handle for that slot and the ndarray view of it,
        then advances the slot index (wrapping at the buffer length).
        """
        buffer_index = self.output_buffer_next_index
        output_buffer = self.output_buffer_ary[buffer_index]
        handle = self.output_buffer_handle_template.format(index=buffer_index)
        self.gpu_result.get(ary=output_buffer)
        self.output_buffer_next_index = (
            buffer_index + 1
        ) % self.output_buffer_ary.shape[0]
        return handle, output_buffer

    def correct(self, data, cell_table):
        """Apply corrections to data

        Applies corrections to input data and casts to desired output dtype.
        Parameter cell_table allows out of order or non-contiguous memory cells
        in input data.  Both input ndarrays are assumed to be on GPU already,
        preferably wrapped in GPU arrays (pycuda.gpuarray).

        Will return string encoded handle to shared memory output buffer and
        (view of) said buffer as an ndarray.  Keep in mind that the output
        buffers will get overwritten eventually (circular buffer).
        """
        self._run_correction(data, cell_table)
        return self._hand_over_result()

    def only_cast(self, data):
        """Like correct without the correction

        This currently means just casting to output dtype.  Returns the same
        (handle, buffer view) pair as :meth:`correct`.
        """
        self._run_casting(data)
        return self._hand_over_result()

    def _slice_previews(self, raw_data, cell_index):
        """Extract one cell of raw_data and self.gpu_result into preview buffers"""
        self.preview_slice_raw_kernel(
            raw_data,
            np.int16(cell_index),
            self.gpu_preview_raw,
            block=self.preview_block,
            grid=self.preview_grid,
        )
        self.preview_slice_corrected_kernel(
            self.gpu_result,
            np.int16(cell_index),
            self.gpu_preview_corrected,
            block=self.preview_block,
            grid=self.preview_grid,
        )

    def compute_preview(
        self,
        raw_data,
        preview_index,
        reuse_corrected=True,
        cell_table=None,
    ):
        """Generate single slice or reduction preview of raw and corrected data

        Special values of preview_index are -1 for max, -2 for mean, -3 for
        sum, and -4 for stdev (across cells).

        Note that preview_index is taken from data without checking cell table.
        Caller has to figure out which index along memory cell dimension they
        actually want to preview.

        raw_data should be a gpuarray

        Assumes that correction has just happened - meaning self.gpu_result
        contains corrected data (corrected from raw_data).  If
        reuse_corrected is False, correction (or, lacking an offset map or
        cell_table, plain casting) is run here first.

        Returns the host-side (preview_raw, preview_corrected) arrays; these
        are reused and overwritten by the next call.

        Raises ValueError for preview_index < -4 or >= memory_cells.
        """

        if preview_index < -4:
            raise ValueError(f"No statistic with code {preview_index} defined")
        elif preview_index >= self.memory_cells:
            # NOTE(review): slicing indexes gpu_result, whose cell axis has
            # length pulse_filter.size, which may be smaller than
            # memory_cells - confirm the kernels bounds-check this.
            raise ValueError(f"Memory cell index {preview_index} out of range")

        if not reuse_corrected:
            # if we didn't already correct, need to do so to get corrected data in buffer
            if self.offset_map.size == 0 or cell_table is None:
                self._run_casting(raw_data)
                if self.offset_map.size == 0:
                    print(
                        "Warning: no offset map loaded, corrected preview "
                        "will not actually have correction applied."
                    )
                if cell_table is None:
                    print(
                        "Warning: missing parameter cell_table for applying "
                        "correction for preview."
                    )
            else:
                self._run_correction(raw_data, cell_table)

        # TODO: enum around reduction type
        if preview_index >= 0:
            self._slice_previews(raw_data, preview_index)
        elif preview_index == -1:
            # TODO: select argmax independently for raw and corrected?
            # TODO: send frame sums somewhere to compute global max frame
            self.frame_sum_kernel(
                self.gpu_result,
                self.gpu_frame_sums,
                block=self.cell_reduction_block,
                grid=self.cell_reduction_grid,
            )
            # Preview the cell whose corrected frame has the largest sum.
            max_index = np.argmax(self.gpu_frame_sums.get())
            self._slice_previews(raw_data, max_index)
        elif preview_index in (-2, -3, -4):
            self.preview_stat_raw_kernel(
                raw_data,  # this is input_data_dtype
                np.int16(preview_index),
                self.gpu_preview_raw,
                block=self.preview_block,
                grid=self.preview_grid,
            )
            self.preview_stat_corrected_kernel(
                self.gpu_result,  # this is output_data_dtype
                np.int16(preview_index),
                self.gpu_preview_corrected,
                block=self.preview_block,
                grid=self.preview_grid,
            )
        self.gpu_preview_raw.get(ary=self.preview_raw)
        self.gpu_preview_corrected.get(ary=self.preview_corrected)
        return self.preview_raw, self.preview_corrected