from functools import lru_cache

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import leastsq
from scipy.optimize import curve_fit
from scipy.signal import fftconvolve
import xarray as xr
import toolbox_scs as tb


__all__ = [
    'find_curvature',
    'correct_curvature',
    'get_spectrum',
    'energy_calibration',
    'calibrate',
    'gaussian_fit',
    'to_fwhm'
]


# -----------------------------------------------------------------------------
# Curvature

def find_curvature(image, frangex=None, frangey=None,
                   deg=2, axis=1, offset=100):
    """Estimate curvature coefficients from the cross-correlation of the
    image with a reference stripe.

    Note: this definition is shadowed by the least-squares find_curvature
    defined below and is kept for reference.
    """
    # Resolve arguments
    x_range = (0, image.shape[1])
    if frangex is not None:
        x_range = (max(frangex[0], x_range[0]), min(frangex[1], x_range[1]))
    y_range = (0, image.shape[0])
    if frangey is not None:
        y_range = (max(frangey[0], y_range[0]), min(frangey[1], y_range[1]))

    axis_range = y_range if axis == 1 else x_range
    axis_dim = image.shape[axis - 1]

    # Get kernel
    integral = image[slice(*y_range), slice(*x_range)].mean(axis=axis)
    roi = np.ones([axis_range[1] - axis_range[0], axis_dim])
    ref = roi * integral[:, np.newaxis]

    # Get sliced image
    slice_ = [slice(None), slice(None)]
    slice_[axis - 1] = slice(max(axis_range[0] - offset, 0),
                             min(axis_range[1] + offset, axis_dim))
    sliced = image[tuple(slice_)]
    if axis == 0:
        sliced = sliced.T

    # Get curvature factor from cross correlation
    crosscorr = fftconvolve(sliced,
                            ref[::-1, :],
                            axes=0, )
    shifts = np.argmax(crosscorr, axis=0)
    curv = np.polyfit(np.arange(axis_dim), shifts, deg=deg)
    return curv[:-1][::-1]


def find_curvature(img, args, plot=False, **kwargs):
    """Fit curvature parameters by least-squares fitting a Gaussian ridge
    that follows a parabola to the image.

    args is the initial guess (a, b, c, sigma, height, offset); the fitted
    parameters are returned in the same order.
    """
    def parabola(x, a, b, c, s=0, h=0, o=0):
        return (a*x + b)*x + c
    def gauss(y, x, a, b, c, s, h, o=0):
        return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o
    x = np.arange(img.shape[1])[None, :]
    y = np.arange(img.shape[0])[:, None]

    if plot:
        plt.figure(figsize=(10,10))
        plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs)
        plt.plot(x[0, :], parabola(x[0, :], *args))

    args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), args)

    if plot:
        plt.plot(x[0, :], parabola(x[0, :], *args))
    return args
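
# Illustrative sketch (not part of the module): fitting the curvature of a
# synthetic image with the least-squares find_curvature above. All values
# below are made up for demonstration.
#
#     yy, xx = np.mgrid[0:512, 0:1024]
#     ridge = 200 + 0.02 * xx - 2e-6 * xx**2
#     img = 50 * np.exp(-0.5 * ((yy - ridge) / 3)**2) + np.random.rand(512, 1024)
#     # initial guess in the order (a, b, c, sigma, height, offset)
#     a, b, c, *_ = find_curvature(img, (-2e-6, 0.02, 200.0, 3.0, 50.0, 0.0))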


def correct_curvature(image, factor=None, axis=1):
    """Undo the detector curvature by rebinning the image along the
    dispersive axis using the polynomial coefficients in factor."""
    if factor is None:
        return

    if axis == 1:
        image = image.T

    ydim, xdim = image.shape
    x = np.arange(xdim + 1)
    y = np.arange(ydim + 1)
    xx, yy = np.meshgrid(x[:-1] + 0.5, y[:-1] + 0.5)
    xxn = xx - factor[0] * yy - factor[1] * yy ** 2
    ret = np.histogramdd((xxn.flatten(), yy.flatten()),
                         bins=[x, y],
                         weights=image.flatten())[0]

    return ret if axis == 1 else ret.T
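
# Illustrative sketch: undoing the curvature of an image. The factor is the
# (linear, quadratic) coefficient pair, e.g. the module constants
# (CURVE_A, CURVE_B); the random image below is a placeholder.
#
#     img = np.random.rand(2048, 2048)
#     flat = correct_curvature(img, factor=(CURVE_A, CURVE_B), axis=1)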


def get_spectrum(image, factor=None, axis=0,
                 pixel_range=None, energy_range=None, ):
    """Project the image onto the dispersive axis and optionally calibrate
    the pixel centres to energies."""
    start, stop = (0, image.shape[axis - 1])
    if pixel_range is not None:
        start = max(pixel_range[0] or start, start)
        stop = min(pixel_range[1] or stop, stop)

    edge = image.sum(axis=axis)[start:stop]
    bins = np.arange(start, stop + 1)
    centers = (bins[1:] + bins[:-1]) * 0.5
    if factor is not None:
        centers, edge = calibrate(centers, edge,
                                  factor=factor,
                                  range_=energy_range)

    return centers, edge
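
# Illustrative sketch: projecting a corrected image onto the dispersive axis,
# optionally with an energy calibration. The pixel range, calibration factor
# and energy range below are placeholders (a negative slope matches the
# direction assumed by calibrate()).
#
#     centers, spec = get_spectrum(flat, axis=0, pixel_range=(300, 1700))
#     energies, spec = get_spectrum(flat, factor=(-0.01, 540.0), axis=0,
#                                   energy_range=(525.0, 535.0))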


# -----------------------------------------------------------------------------
# Energy calibration


def energy_calibration(channels, energies):
    """Return the linear calibration (slope, intercept) mapping channels to
    energies, suitable for np.polyval."""
    return np.polyfit(channels, energies, deg=1)


def calibrate(x, y=None, factor=None, range_=None):
    """Apply an energy calibration factor to x and optionally restrict the
    result (and y) to the given energy range."""
    if factor is not None:
        x = np.polyval(factor, x)

    if y is not None and range_ is not None:
        start = np.argmin(np.abs((x - range_[0])))
        stop = np.argmin(np.abs((x - range_[1])))
        # Calibrated energies have a different direction
        x, y = x[stop:start], y[stop:start]

    return x, y
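
# Illustrative sketch: build a pixel-to-energy calibration from two reference
# lines and apply it to a spectrum. Channel and energy values are placeholders;
# centers/spec are a pixel grid and spectrum as returned by get_spectrum().
#
#     factor = energy_calibration(channels=[850, 1240],
#                                 energies=[531.8, 528.3])
#     energies, spec = calibrate(centers, spec, factor=factor,
#                                range_=(528.0, 532.0))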


# -----------------------------------------------------------------------------
# Gaussian-related functions


FWHM_COEFF = 2 * np.sqrt(2 * np.log(2))


def gaussian_fit(x_data, y_data, offset=0):
    """Fit a 1-D Gaussian to the data.

    The centre of mass and RMS width serve as the initial guess (lifted from
    image_processing.imageCentreofMass()). Returns the fitted parameters
    (height, x0, sigma, offset), or None if the fit fails.
    """

    x0 = np.average(x_data, weights=y_data)
    sx = np.sqrt(np.average((x_data - x0) ** 2, weights=y_data))

    # Gaussian fit
    baseline = y_data.min()
    p_0 = (y_data.max(), x0 + offset, sx, baseline)
    try:
        p_f, _ = curve_fit(gauss1d, x_data, y_data, p_0, maxfev=10000)
        return p_f
    except (RuntimeError, TypeError) as e:
        print(e)
        return None


def gauss1d(x, height, x0, sigma, offset):
    return height * np.exp(-0.5 * ((x - x0) / sigma) ** 2) + offset


def to_fwhm(sigma):
    return abs(sigma * FWHM_COEFF)
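
# Illustrative sketch: fit a synthetic Gaussian peak and convert its width to
# FWHM.
#
#     x = np.linspace(-5, 5, 201)
#     y = gauss1d(x, height=10.0, x0=0.3, sigma=0.8, offset=1.0)
#     height, x0, sigma, offset = gaussian_fit(x, y)
#     to_fwhm(sigma)  # ~ 0.8 * FWHM_COEFF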


# -----------------------------------------------------------------------------
# Centroid

THRESHOLD = 510  # pixel counts above which a hit candidate is assumed
CURVE_A = 2.19042931e-02  # curvature parameters as determined elsewhere
CURVE_B = -3.02191568e-07


def _esrf_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)):
    """Centre-of-mass centroiding around pixels above a fixed threshold,
    with curvature correction (legacy implementation)."""
    gs = 2
    base = image.mean()
    cp = np.argwhere(image[gs // 2: -gs // 2, gs // 2: -gs // 2] > threshold) + np.array([gs // 2, gs // 2])
    if len(cp) > 100000:
        raise RuntimeError('Threshold too low or acquisition time too long')

    res = []
    for cy, cx in cp:
        spot = image[cy - gs // 2: cy + gs // 2 + 1, cx - gs // 2: cx + gs // 2 + 1] - base
        spot[spot < 0] = 0
        if (spot > image[cy, cx]).sum() == 0:
            mx = np.average(np.arange(cx - gs // 2, cx + gs // 2 + 1), weights=spot.sum(axis=0))
            my = np.average(np.arange(cy - gs // 2, cy + gs // 2 + 1), weights=spot.sum(axis=1))
            my -= (curvature[0] + curvature[1] * mx) * mx
            res.append((my, mx))
    return res


def _new_centroid(image, threshold=THRESHOLD, std_threshold=3.5, curvature=(CURVE_A, CURVE_B)):
    """Find the position of photons with sub-pixel precision.

    A photon is assumed to have hit the detector if the intensity within a
    2-by-2 pixel square exceeds a threshold. In this case the position of
    the photon is calculated as the centre of mass of a 4-by-4 square.

    Return the list of (y, x) coordinate pairs, with y corrected by the
    curvature.
    """
    base = image.mean()
    corners = image[1:, 1:] + image[:-1, 1:] + image[1:, :-1] + image[:-1, :-1]
    if threshold is None:
        threshold = corners.mean() + std_threshold * corners.std()
    middle = corners[1:-1, 1:-1]
    candidates = (
            (middle > threshold)
            * (middle >= corners[:-2, 1:-1]) * (middle > corners[2:, 1:-1])
            * (middle >= corners[1:-1, :-2]) * (middle > corners[1:-1, 2:])
            * (middle >= corners[:-2, :-2]) * (middle > corners[2:, :-2])
            * (middle >= corners[:-2, 2:]) * (middle > corners[2:, 2:]))
    cp = np.argwhere(candidates)
    if len(cp) > 10000:
        raise RuntimeError(
            "too many peaks, threshold too low or acquisition time too high")

    res = []
    for cy, cx in cp:
        spot = image[cy: cy + 4, cx: cx + 4] - base
        mx = np.average(np.arange(cx, cx + 4), weights=spot.sum(axis=0))
        my = np.average(np.arange(cy, cy + 4), weights=spot.sum(axis=1))
        my -= (curvature[0] + curvature[1] * mx) * mx
        res.append((my, mx))
    return res


centroid = _new_centroid
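
# Illustrative sketch: centroiding a single synthetic image. With
# threshold=None the threshold is derived from the 2-by-2 corner sums
# (mean + std_threshold * std); the image below is a placeholder.
#
#     img = np.random.poisson(1.0, (2048, 2048)).astype(float)
#     img[100:102, 200:202] += 400.0  # one artificial photon hit
#     hits = centroid(img, threshold=None, std_threshold=3.5)
#     # hits is a list of curvature-corrected (y, x) pairs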


def decentroid(res):
    res = np.array(res)
    ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int))
    for cy, cx in res:
        if cx > 0 and cy > 0:
            ret[int(cy), int(cx)] += 1
    return ret


# -----------------------------------------------------------------------------
# Integral

FACTOR = 3
RANGE = [300, 400]
BINS = abs(np.subtract(*RANGE)) * FACTOR


def parabola(x, a, b, c=0):
    return (a * x + b) * x + c


def integrate(image, factor=FACTOR, range=RANGE, curvature=(CURVE_A, CURVE_B), ):
    """Build a spectrum by sub-pixel histogramming of the whole image along
    the curvature-corrected dispersive axis (integration method)."""
    image = image - image.mean()
    x = np.arange(image.shape[1])[None, :]
    y = np.arange(image.shape[0])[:, None]
    ys = factor * (y - parabola(x, curvature[1], curvature[0]))
    ysf = np.floor(ys)
    rang = (factor * range[0], factor * range[1])
    bins = rang[1] - rang[0]
    lhy, lhx = np.histogram(ysf.ravel(), bins=bins, weights=((ys - ysf) * image).ravel(), range=rang)
    rhy, rhx = np.histogram((ysf + 1).ravel(), bins=bins, weights=(((ysf + 1) - ys) * image).ravel(), range=rang)
    lvy, lvx = np.histogram(ysf.ravel(), bins=bins, weights=(ys - ysf).ravel(), range=rang)
    rvy, rvx = np.histogram((ysf + 1).ravel(), bins=bins, weights=((ysf + 1) - ys).ravel(), range=rang)
    return (lhy + rhy) / (lvy + rvy)
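
# Illustrative sketch: integrating a single image into a spectrum over the
# default RANGE with the module-level curvature constants. The image is a
# placeholder.
#
#     img = np.random.rand(2048, 2048)
#     spec = integrate(img, factor=FACTOR, range=RANGE,
#                      curvature=(CURVE_A, CURVE_B))
#     spec.shape  # (abs(RANGE[1] - RANGE[0]) * FACTOR,) == (BINS,)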


# -----------------------------------------------------------------------------
# hRIXS class


class hRIXS:
    # run
    PROPOSAL = 2769

    # image range
    X_RANGE = np.s_[:]
    Y_RANGE = np.s_[:]

    # centroid
    THRESHOLD = None  # pixel counts above which a hit candidate is assumed
    STD_THRESHOLD = 3.5  # same as THRESHOLD, in standard deviations
    CURVE_A = CURVE_A  # curvature parameters as determined elsewhere
    CURVE_B = CURVE_B

    # integral
    FACTOR = FACTOR
    RANGE = RANGE
    BINS = abs(np.subtract(*RANGE)) * FACTOR

    METHOD = 'centroid'  # ['centroid', 'integral']
    
    # Dark image and mask treatment
    USE_DARK = False
    USE_DARK_MASK = False
    DARK_MASK_THRESHOLD = 100
    MASK_AVG_X = np.s_[1850:2000]
    MASK_AVG_Y = np.s_[500:1500]

    ENERGY_INTERCEPT = 0
    ENERGY_SLOPE = 1

    FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm']

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key.upper(), value)

    def get_params(self, *params):
        if not params:
            params = ('proposal', 'x_range', 'y_range',
                      'threshold', 'curve_a', 'curve_b',
                      'factor', 'range', 'bins',
                      'method', 'fields')
        return {param: getattr(self, param.upper()) for param in params}

    def from_run(self, runNB, proposal=None, extra_fields=()):
        if proposal is None:
            proposal = self.PROPOSAL
        run, data = tb.load(proposal, runNB=runNB,
                            fields=self.FIELDS + list(extra_fields))

        return data

    
    def load_dark(self, runNB, proposal=None, use_dark=True, mask=True,
                  mask_threshold=None):
        """Load a dark image and attach it to the hRIXS instance.

        Also sets attributes controlling whether
        - hot pixels are identified and masked out
        - the dark image is used in background subtraction
        A threshold for hot-pixel mask generation can be given; otherwise
        DARK_MASK_THRESHOLD is used.
        """
        if mask_threshold is None:
            mask_threshold = self.DARK_MASK_THRESHOLD
        # If given a list of runs, iterate over them and concatenate along
        # trainId; otherwise read a single run.
        if isinstance(runNB, list):
            data_list = []
            for run in runNB:
                data_list.append(self.from_run(run, proposal))
            data = xr.concat(data_list, dim='trainId')
        elif isinstance(runNB, int):
            data = self.from_run(runNB, proposal)
        else:
            raise TypeError(
                'load_dark() expects a list of run numbers or an integer.')
        # Store the dark image (mean over acquisitions) in two formats
        self.dark_image = data['hRIXS_det'].mean(dim='trainId')
        self.dark_im_array = self.dark_image.to_numpy()
        # Flag whether the dark image is to be used later
        if use_dark:
            self.USE_DARK = True
        # If hot/dead pixel masking is requested, build the mask, set the
        # corresponding flag and replace masked dark values by the mean
        # intensity of the reference region.
        if mask:
            dark_avg = np.mean(self.dark_im_array[self.MASK_AVG_Y,
                                                  self.MASK_AVG_X], (0, 1))
            self.dark_mask = self.dark_im_array > dark_avg + mask_threshold
            self.dark_im_array_m = np.array(self.dark_im_array)
            self.dark_im_array_m[self.dark_mask] = dark_avg
            self.USE_DARK_MASK = True
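
    # Illustrative sketch: load darks from two runs and enable hot-pixel
    # masking. Run and proposal numbers are placeholders.
    #
    #     h = hRIXS()
    #     h.load_dark([180, 181], proposal=2769, mask=True)
    #     h.USE_DARK, h.dark_im_array.shape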
    

    def find_curvature(self, runNB, proposal=None, plot=True, args=None, **kwargs):
        data = self.from_run(runNB, proposal)
        image = data['hRIXS_det'].sum(dim='trainId') \
                .to_numpy()[self.X_RANGE, self.Y_RANGE].T
        if args is None:
            spec = (image - image[:10, :].mean()).mean(axis=1)
            mean = np.average(np.arange(len(spec)), weights=spec)
            args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3,
                    spec.max(), image.mean())
        args = find_curvature(image, args, plot=plot, **kwargs)
        self.CURVE_B, self.CURVE_A, *_ = args
        return self.CURVE_A, self.CURVE_B

    
    def centroid(self, data, bins=None, return_hits=False):
        """Run hit finding on the detector images and bin the hits into
        spectra.

        Allows for
        - redefining the number of bins
        - extraction of the hit positions from each image (return_hits)
        Dark-image subtraction and hot-pixel masking are controlled by
        self.USE_DARK and self.USE_DARK_MASK.
        """
        # Use the default number of bins unless overridden
        if bins is None:
            bins = self.BINS
        # Output containers: hit lists and one spectrum row per acquisition
        hit_x = []
        hit_y = []
        hits = []
        ret = np.zeros((len(data["hRIXS_det"]), bins))
        # Handle each acquisition image separately
        for image, r in zip(data["hRIXS_det"], ret):
            use_image = image.to_numpy()
            # Background treatment: optionally subtract the dark image
            # (self.USE_DARK) and/or mask hot pixels (self.USE_DARK_MASK);
            # different combinations of the two flags result in different
            # actions.
            if self.USE_DARK:
                # subtract the dark image
                use_image = use_image - self.dark_im_array
                if self.USE_DARK_MASK:
                    # set masked pixels to 0
                    use_image[self.dark_mask] = 0
            else:
                # don't subtract the dark image, but set hot pixels to a
                # dark baseline value
                if self.USE_DARK_MASK:
                    use_image[self.dark_mask] = np.mean(
                        use_image[self.MASK_AVG_Y, self.MASK_AVG_X], (0, 1))
            # Run centroiding on the preprocessed image
            c = centroid(
                use_image[self.X_RANGE, self.Y_RANGE].T,
                threshold=self.THRESHOLD,
                std_threshold=self.STD_THRESHOLD,
                curvature=(self.CURVE_A, self.CURVE_B))
            if not len(c):
                continue
            rc = np.array(c)
            # If hits have been requested, append the hit data of this
            # image to the lists of hit lists
            if return_hits:
                hit_x.append(rc[:, 0])
                hit_y.append(rc[:, 1])
                hits.append(rc)
            # Histogram the hit positions into the spectrum row r, which
            # views the proper location of ret
            hy, hx = np.histogram(
                rc[:, 0], bins=bins,
                range=(0, self.Y_RANGE.stop - self.Y_RANGE.start))
            r[:] = hy
        # Set up and assign a linear energy grid
        data = data.assign_coords(
            energy=np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, bins)
            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        # If hits were requested, assign them to data
        if return_hits:
            data = data.assign(hits=(("trainId"), hits),
                               xhits=(("trainId"), hit_x),
                               yhits=(("trainId"), hit_y))
        # Always assign the spectrum to data
        data = data.assign(spectrum=(("trainId", "energy"), ret))
        return data

    
    def align_readouts(self, data, method, start, stop):
        """Align the spectra in a given dataset.

        method is one of
        - 'max_value'
        - 'autocorrelation'
        - 'gauss_fit'
        start and stop are values of data.energy that define the range of
        these operations.

        Returns the line centre positions (and in the future perhaps the
        shifted spectra).
        """
        from scipy import optimize, signal

        searchinds = (data.energy >= start) * (data.energy <= stop)
        peak_posis = []
        # Simple maximum alignment
        if method.lower() == 'max_value':
            # Find the maximum of each spectrum within the search range
            for spec in data.spectrum:
                x = data.energy.to_numpy()[searchinds]
                y = spec.to_numpy()[searchinds]
                maxipos = np.argmax(y)
                peak_posis.append(x[maxipos])
        # Alignment based on cross-correlation with the first readout: a
        # relative method where the first readout defines the absolute scale
        elif method.lower() == 'autocorrelation':
            for ind, spec in enumerate(data.spectrum):
                if ind == 0:
                    # The first spectrum sets the reference peak position
                    x0 = data.energy.to_numpy()[searchinds]
                    y0 = spec.to_numpy()[searchinds]
                    maxipos0 = np.argmax(y0)
                    peak_posis.append(x0[maxipos0])
                else:
                    x = data.energy.to_numpy()[searchinds]
                    y = spec.to_numpy()[searchinds]
                    corr_len = int(np.sum(searchinds))
                    corr = signal.correlate(y, y0, mode='full')
                    maxpos = np.argmax(corr)
                    # zero lag sits at index corr_len - 1 in 'full' mode
                    shift = maxpos - (corr_len - 1)
                    peak_posis.append(x[maxipos0 + shift])
        elif method.lower() == 'gauss_fit':
            def Gauss(grid, x0, sigma):
                # Normalized bell curve centred at x0 with width sigma on grid
                return (1.0 / (sigma * np.sqrt(2.0 * np.pi))
                        * np.exp(-0.5 * (grid - x0)**2 / sigma**2))

            def Cost(p, grid, spec):
                return np.sum(np.square(p[0] * Gauss(grid, p[1], p[2]) - spec))

            # Fit a Gaussian to each spectrum within the search range
            for spec in data.spectrum:
                x = data.energy.to_numpy()[searchinds]
                y = spec.to_numpy()[searchinds]
                # Initial guess: area, centre of mass and standard deviation
                area = np.sum(y)
                mean = np.average(x, weights=y)
                std = np.sqrt(np.average((x - mean)**2, weights=y / area))
                p0 = [area, mean, std]
                # Bounds on area, centre and width
                bnds = [[0, None], [start, stop], [0, 2 * (stop - start)]]
                # Fit by minimizing the least-squares error
                p = optimize.minimize(Cost, p0, args=(x, y), bounds=bnds,
                                      method='L-BFGS-B', tol=1e-6,
                                      options={'disp': 0, 'maxiter': 1000000})
                if p.success:
                    peak_posis.append(p.x[1])
                else:
                    plt.figure()
                    plt.plot(x, y, '.')
                    plt.plot(x, p.x[0] * Gauss(x, p.x[1], p.x[2]))
                    raise RuntimeError(
                        'align_readouts(): cannot fit a Gaussian to the data.')
        else:
            raise ValueError('align_readouts() did not recognize the method.')
        return peak_posis
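
    # Illustrative sketch: align centroided spectra on a line between 525 and
    # 535 (in the units of data.energy), where data is the dataset returned
    # by centroid(); the numbers are placeholders.
    #
    #     positions = h.align_readouts(data, 'gauss_fit', 525.0, 535.0)
    #     shifts = np.asarray(positions) - positions[0]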
    
    
    def integrate(self, data):
        bins = self.Y_RANGE.stop - self.Y_RANGE.start
        ret = np.zeros((len(data["hRIXS_det"]), bins - 20))
        for image, r in zip(data["hRIXS_det"], ret):
            if self.USE_DARK:
                image = image - self.dark_image
            r[:] = integrate(image.to_numpy()[self.X_RANGE, self.Y_RANGE].T, factor=1,
                             range=(10, bins - 10),
                             curvature=(self.CURVE_A, self.CURVE_B))
        data = data.assign_coords(
            energy=np.arange(self.Y_RANGE.start + 10, self.Y_RANGE.stop - 10)
            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        return data.assign(spectrum=(("trainId", "energy"), ret))

    aggregators = dict(
        hRIXS_det=lambda x, dim: x.sum(dim=dim),
        Delay=lambda x, dim: x.mean(dim=dim),
        hRIXS_delay=lambda x, dim: x.mean(dim=dim),
        hRIXS_norm=lambda x, dim: x.sum(dim=dim),
        spectrum=lambda x, dim: x.sum(dim=dim),
    )

    def aggregator(self, da, dim):
        agg = self.aggregators.get(da.name)
        if agg is None:
            return None
        return agg(da, dim=dim)

    def aggregate(self, ds, dim="trainId"):
        ret = ds.map(self.aggregator, dim=dim)
        ret = ret.drop_vars([n for n in ret if n not in self.aggregators])
        return ret

    def normalize(self, data):
        return data.assign(normalized=data["spectrum"] / data["hRIXS_norm"])
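
    # Illustrative end-to-end sketch: run, proposal and calibration numbers
    # are placeholders.
    #
    #     h = hRIXS()
    #     h.set_params(proposal=2769, threshold=None, method='centroid')
    #     h.find_curvature(runNB=200)   # updates CURVE_A / CURVE_B
    #     data = h.from_run(runNB=201)
    #     data = h.centroid(data)       # or h.integrate(data)
    #     avg = h.aggregate(data)       # sum/mean over trainId
    #     avg = h.normalize(avg)        # spectrum / hRIXS_norm
    #     avg['normalized'].plot()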