from pathlib import Path
from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
from extra_geom import (
    AGIPD_1MGeometry,
    AGIPD_500K2GGeometry,
    DSSC_1MGeometry,
    LPD_1MGeometry,
)
from extra_geom import tests as eg_tests
from matplotlib import colors
from matplotlib.patches import Patch
from mpl_toolkits.axes_grid1 import AxesGrid


def _outlier_fraction(image, center, half_width):
    """Return the fraction of `image` values outside center +/- half_width."""
    outside = (image < center - half_width) | (image > center + half_width)
    return np.count_nonzero(outside) / image.size


def show_overview(
        d, cell_to_preview, gain_to_preview, out_folder=None, infix=None
):
    """
    Show an overview figure of the constants for every module.

    One figure is created per module. Each constant is drawn as an image
    with its own colour bar on a grid of two rows.

    :param d: A dict mapping each module name to a dict of
        {constant name: constant data}. Constant data with 4 dimensions is
        indexed as [..., memory cell, gain]; any other data is indexed
        as [..., memory cell].
    :param cell_to_preview: Index of the memory cell to preview.
    :param gain_to_preview: Index of the gain stage to preview. For
        constants with "ThresholdsDark" in their name the index is
        shifted by -1.
    :param out_folder: Output folder for saving the plotted .png image.
    :param infix: infix to include in the png file name. The figure is
        only saved when both `out_folder` and `infix` are given.
    :return: None
    """
    for module, data in d.items():
        # Two rows: adapt the number of columns to the number of constants.
        ncols = (len(data) + 1) // 2
        fig = plt.figure(figsize=(20, 20))
        grid = AxesGrid(fig, 111,
                        nrows_ncols=(2, ncols),
                        axes_pad=(0.9, 0.15),
                        label_mode="1",
                        share_all=True,
                        cbar_location="right",
                        cbar_mode="each",
                        cbar_size="7%",
                        cbar_pad="2%",
                        )

        items = list(data.items())
        for ax, cbar_ax, (key, item) in zip(grid, grid.cbar_axes, items):
            # NOTE(review): presumably ThresholdsDark has one gain entry
            # less than the other constants - confirm.
            cf = -1 if "ThresholdsDark" in key else 0

            # Select the previewed image once, so that the median, the
            # colour-scale search and the plotted image all use the same
            # data for both 4-D and lower dimensional constants.
            is_4d = len(item.shape) == 4
            if is_4d:
                preview = item[..., cell_to_preview, gain_to_preview + cf]
            else:
                preview = item[..., cell_to_preview]

            med = np.nanmedian(preview)
            # Avoid a zero-width colour range for a zero median.
            medscale = 0.1 if med == 0 else med

            # Widen the range around the median until at most 1% of the
            # values fall outside of it.
            bound = 0.2
            while _outlier_fraction(
                    preview, med, np.abs(bound * medscale)) > 0.01:
                bound *= 2
            half_width = np.abs(bound * medscale)

            is_badpixels = "BadPixels" in key

            if is_badpixels:
                # Bad pixels are shown as a binary good/bad mask.
                im = ax.imshow(
                    preview != 0,
                    cmap=colors.ListedColormap(["w", "k"]),
                    aspect="auto",
                )
            else:
                im_prev = preview
                # Only for non 4-D constants: move the axis of the image
                # to show horizontally on the output report.
                if not is_4d and im_prev.shape[0] > im_prev.shape[1]:
                    im_prev = np.moveaxis(im_prev, 0, 1)

                # Use the same symmetric range that was used to count the
                # outliers above for both colour limits.
                im = ax.imshow(im_prev, interpolation="nearest",
                               vmin=med - half_width,
                               vmax=med + half_width, aspect='auto')

            cb = cbar_ax.colorbar(im)
            if is_badpixels:
                cb.set_ticks([0.25, 0.75])
                cb.set_ticklabels(["good", "bad"])
            else:
                cb.set_label("ADU")

            ax.text(
                5, 20, key, color="k" if is_badpixels else "w", fontsize=20
            )

        # Write the module name on the first image; black on the white
        # bad-pixel background, red otherwise.
        grid[0].text(5, 50, module, color="k" if "BadPixels" in items[0][0] else "r", fontsize=20)  # noqa

        if out_folder and infix:
            fig.savefig(f"{out_folder}/"
                        f"dark_analysis_{infix}_module_{module}.png")


def rebin(a, *args):
    '''Rebin ndarray data into a smaller ndarray of the same rank whose
    dimensions are factors of the original dimensions. eg. An array with 6
    columns and 4 rows can be reduced to have 6,3,2 or 1 columns and 4,2 or 1
    rows. Based on:
    https://scipy-cookbook.readthedocs.io/items/Rebinning.html

    Every axis is split into bins which are summed up. The sums are then
    divided by the bin size of every axis except the last one, i.e. the
    last axis is summed instead of averaged.
    NOTE(review): the cookbook recipe divides by the bin size of all axes;
    the behaviour found here is kept unchanged. It makes no difference when
    the last axis is not reduced, as done by `plot_badpix_3d`.

    :param a: The ndarray to rebin.
    :param args: The target size of each dimension, one value per
        dimension of `a`.
    :return: A tuple of the rebinned data cast to uint32 (fractional
        parts are truncated) and the index grids, as given by np.indices,
        for a shape that is one element larger than the rebinned data
        along each axis.
    :raises ValueError: If the number of target sizes does not match the
        number of dimensions of `a`.

    >>> a = np.arange(24).reshape(4, 6)
    >>> binned, edges = rebin(a, 2, 3)
    >>> binned.shape, edges.shape
    ((2, 3), (2, 3, 4))
    '''
    n_dims = len(a.shape)
    if len(args) != n_dims:
        raise ValueError(
            f"rebin expects {n_dims} target dimensions, got {len(args)}")

    factor = np.asarray(a.shape) // np.asarray(args)

    # Split every axis i into (args[i], factor[i]) so that each bin gets
    # an axis of its own.
    split_shape = [size for pair in zip(args, factor) for size in pair]
    ta = a.reshape(split_shape)

    # Sum over the bin axes. After each summation the next bin axis is
    # found one position further to the right.
    for i in range(n_dims):
        ta = ta.sum(i + 1)

    # Average over all but the last axis.
    for bin_size in factor[:-1]:
        ta = ta / bin_size

    return ta.astype(np.uint32), np.indices([s + 1 for s in ta.shape])


def plot_badpix_3d(data, definitions, title=None, rebin_fac=2, azim=22.5):
    """
    Plot bad pixel values as coloured voxels in a 3D view.

    The first two axes of the data are rebinned by `rebin_fac` before
    plotting; the third axis is kept at full resolution.

    :param data: 3 dimensional array with the bad pixel values. The axes
        are labelled as pixels, pixels and memory cell in the plot.
    :param definitions: A dict of {value: (label, colour)}. Voxels whose
        rebinned value equals `value` are drawn in `colour`, and `label`
        is used for the legend. Colours are assumed to be strings,
        e.g. '#FF0000'.
    :param title: Optional title of the plot.
    :param rebin_fac: Rebinning factor applied to the first two axes.
    :param azim: Azimuth angle of the 3D view.
    :return: None
    """
    binned, edges = rebin(
        data.astype(np.uint32),
        data.shape[0] // rebin_fac,
        data.shape[1] // rebin_fac,
        data.shape[2],
    )
    xx, yy, zz = edges
    # Plain `bool` is used as the `np.bool` alias is gone in recent NumPy.
    voxels = binned.astype(bool)

    # Size the string dtype for the longest colour given. Otherwise numpy
    # silently truncates colours longer than the default '#FFFFFF',
    # e.g. '#RRGGBBAA' would lose its alpha channel.
    default_color = '#FFFFFF'
    width = max([len(default_color)]
                + [len(c[1]) for c in definitions.values()])
    facecolors = np.full(voxels.shape, default_color, dtype=f'<U{width}')

    for value, definition in definitions.items():
        facecolors[binned == value] = definition[1]

    fig = plt.figure(figsize=(15, 10))
    # Passing keyword arguments to `fig.gca` is not supported anymore by
    # recent Matplotlib versions.
    ax = fig.add_subplot(projection="3d")
    ax.voxels(xx*rebin_fac, yy*rebin_fac, zz, voxels, facecolors=facecolors)
    ax.view_init(elev=25., azim=azim)
    ax.set_xlabel("pixels")
    ax.set_ylabel("pixels")
    ax.set_zlabel("memory cell")
    ax.set_xlim(0, np.max(xx) * rebin_fac)
    ax.set_ylim(0, np.max(yy) * rebin_fac)
    ax.set_zlim(0, np.max(zz))

    # Plot a single point per definition to create the legend entries.
    for definition in definitions.values():
        ax.plot([-1, ], [-1, ], color=definition[1], label=definition[0])
    ax.legend()
    if title:
        ax.set_title(title)


def create_constant_overview(constant, name, cells, vmin=None, vmax=None,
                             entries=3, badpixels=None, gmap=None,
                             marker=None):
    """
    Create a step plot for constant data across memory cells for requested
    gain entries

    For every gain entry and module, the mean over the first two axes of
    the constant is plotted against the memory cell.

    :param constant: dict with constants for each module. Constants with
        4 dimensions are indexed as [..., gain]; other constants are
        used as they are (e.g. DSSC).
    :param name: Name to be used for the y-axis label and the title.
    :param cells: Number of memory cells
    :param vmin: plot minimum value boundary (y-axis). None leaves the
        lower limit unchanged.
    :param vmax: plot maximum value boundary (y-axis). None leaves the
        upper limit unchanged.
    :param entries: (int)number of gain entries.
        A tuple specifying the range can also be used, it is passed
        to `range` e.g. (1, 3).
    :param badpixels: A list of 2 elements.
        badpixels[0] has the dict with badpixels constant for each module
        and badpixels[1] has the value to apply for bad pixels.
    :param gmap: A list with len equal to number of gain entries.
        if not given, a default would be used for 3 entries.
        ['High gain', 'Medium gain', 'Low gain']
    :param marker: A list of line markers for each gain entry.
        default: [''] for each gain entry
    :return: None
    """
    if gmap is None:
        gmap = ['High gain', 'Medium gain', 'Low gain']

    # `entries` is either the number of gains or a range specification.
    gains = range(*entries) if isinstance(entries, tuple) else range(entries)
    if marker is None:
        # `marker` is indexed by the gain, it has to cover the highest one.
        marker = [''] * (max(gains, default=-1) + 1)

    has_badpixels = isinstance(badpixels, (list, tuple))
    cell_ids = np.arange(cells)

    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(111)
    for g in gains:
        for qm, const in constant.items():
            is_4d = len(const.shape) == 4
            # Constants without a gain axis (e.g. DSSC) are used as is.
            d = const[..., g] if is_4d else const

            ax.step(cell_ids, np.nanmean(d, axis=(0, 1)),
                    label=gmap[g], color=f'C{g}', marker=marker[g])

            # Plotting good pixels only if bad-pixels were given
            if has_badpixels:
                bad = badpixels[0][qm]
                if is_4d:
                    bad = bad[..., g]
                dbp = np.copy(d)
                dbp[bad > 0] = badpixels[1]
                ax.step(cell_ids, np.nanmean(dbp, axis=(0, 1)),
                        label=f'Good pixels {gmap[g]}', color=f'C{g}',
                        linestyle='--')

    ax.set_xlabel("Memory cell")
    ax.set_ylabel(name)
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    # NOTE(review): the plotted curves are np.nanmean values although the
    # title says "Median"; the title is kept unchanged - confirm which
    # one is intended.
    ax.set_title(f"{name} Median per Cell")
    # Compare against None, a limit of 0 is a valid boundary.
    if vmin is not None or vmax is not None:
        ax.set_ylim(vmin, vmax)


def show_processed_modules(dinstance: str, constants: Optional[Dict[str, Any]],
                           mnames: list, mode: str) -> None:
    """
    Show the status of the processed modules.
    Green: Processed. Gray: Not Processed. Red: No data available.
    :param dinstance: The detector instance (e.g. AGIPD1M1 or LPD1M1)
    :param constants: A dict of the plotted constants data.
        {"const_name":constant_data}. Can be None in case of position mode.
        Only the "Noise" constant is looked at. If it is missing, all
        selected modules are shown as having no data.
    :param mnames: A list of available module names (e.g. "Q1M1").
    :param mode: String selecting one of the two modes of operation.
         "position": To just show the position of the processed modules.
         "processed": To show the modules successfully processed.
         Any other value colours the modules like "processed", but
         without showing the legend.
    :raises ValueError: If the detector instance is not supported.
    :return: None
    """

    # Create the geometry figure for each detector

    if dinstance in ('AGIPD1M1', 'AGIPD1M2'):
        quadrants = 4
        modules = 4
        tiles_per_module = 8
        quad_pos = [(-525, 625), (-550, -10), (520, -160), (542.5, 475)]
        geom = AGIPD_1MGeometry.from_quad_positions(quad_pos)

    elif dinstance == 'AGIPD500K':
        quadrants = 2
        modules = 4
        tiles_per_module = 8
        geom = AGIPD_500K2GGeometry.from_origin()

    elif 'LPD' in dinstance:
        quadrants = 4
        modules = 4
        tiles_per_module = 16
        quad_pos = [(11.4, 299), (-11.5, 8), (254.5, -16), (278.5, 275)]
        geom = LPD_1MGeometry.from_quad_positions(quad_pos)

    elif 'DSSC' in dinstance:
        quadrants = 4
        modules = 4
        tiles_per_module = 2
        quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]

        geom = DSSC_1MGeometry.from_h5_file_and_quad_positions(
            Path(eg_tests.__file__).parent / 'dssc_geo_june19.h5',
            quad_pos)

    else:
        raise ValueError(f'{dinstance} is not a real detector')

    # Create a dict that contains the range of tiles, in the figure,
    # that belong to a module.
    ranges = {}
    tile_count = 0
    for quadrant in range(1, quadrants+1):
        for module in range(1, modules+1):
            ranges[f'Q{quadrant}M{module}'] = [
                tile_count, tile_count + tiles_per_module]
            tile_count += tiles_per_module

    # Create the figure
    ax = geom.inspect()
    ax.set_title('')  # Cannot remove title
    ax.set_axis_off()
    legend = ax.get_legend()
    if legend is not None:
        legend.set_visible(False)

    # Remove non-tiles markings from figure. Only the first collection is
    # kept. The other artists are removed one by one, as `ax.collections`
    # cannot be assigned to with recent Matplotlib versions.
    tile_patches = ax.collections[0]
    for extra in list(ax.collections)[1:]:
        extra.remove()

    # Set each tile colour individually, extra_geom provides a single color
    # for all tiles.
    facecolors = np.repeat(tile_patches.get_facecolor(), tile_count, axis=0)

    # Set module name fonts, the selected modules are emphasized.
    for text in ax.texts:
        if text.get_text() in mnames:
            text.set_fontweight('extra bold')
            text.set_fontsize(14)
        else:
            text.set_fontweight('regular')

    if mode == 'position':  # Highlight selected modules
        for module in mnames:
            start, stop = ranges[module]
            facecolors[start:stop] = colors.to_rgba('pink')

    else:  # mode == 'processed': Highlight processed modules
        # Without constants or without `Noise` there is no data to show.
        has_noise = constants is not None and 'Noise' in constants
        counter = 0  # Used as index within the `Noise` matrix
        for module, (start, stop) in ranges.items():
            if module not in mnames:
                color = 'grey'  # Unprocessed modules are grey
            else:
                if (has_noise and
                        np.nanmean(constants['Noise'][counter, ..., 0]) != 0):  # noqa
                    color = 'green'
                else:
                    color = 'red'
                counter += 1

            facecolors[start:stop] = colors.to_rgba(color)  # Set the colours

    tile_patches.set_facecolors(facecolors)  # Update colours in figure

    if mode == "processed":
        _ = ax.legend(handles=[Patch(facecolor='red', label='No data'),
                               Patch(facecolor='gray', label='Not processed'),
                               Patch(facecolor='green', label='Processed')],
                      loc='best', ncol=3, bbox_to_anchor=(0.1, 0.25, 0.7, 0.8))
    plt.show()