from functools import lru_cache

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import leastsq
from scipy.optimize import curve_fit
from scipy.signal import fftconvolve
import xarray as xr
import toolbox_scs as tb


# Public names exported by a star-import of this module.
__all__ = [
    'find_curvature',
    'correct_curvature',
    'get_spectrum',
    'energy_calibration',
    'calibrate',
    'gaussian_fit',
    'to_fwhm'
]


# -----------------------------------------------------------------------------
# Curvature

def find_curvature(image, frangex=None, frangey=None,
                   deg=2, axis=1, offset=100):
    """Estimate the curvature of a spectral line by cross-correlation.

    The mean profile of the selected region (averaged along ``axis``) is
    cross-correlated with every line of the image; the position of the
    correlation maximum of each line is then fitted with a polynomial.

    NOTE(review): a second module-level function named ``find_curvature``
    is defined further down in this file and replaces this one at import
    time.

    Parameters
    ----------
    image : 2-D numpy.ndarray
        Detector image.
    frangex, frangey : (start, stop) or None
        Column / row window used to build the reference profile. Each is
        clipped to the image and defaults to the full extent.
    deg : int
        Degree of the fitted polynomial.
    axis : int
        1 to average over columns (profile along rows), 0 for the opposite.
    offset : int
        Number of extra pixels kept on both sides of the window when the
        image is sliced for the correlation.

    Returns
    -------
    numpy.ndarray
        Polynomial coefficients without the constant term, lowest order
        first.
    """
    # Resolve arguments
    x_range = (0, image.shape[1])
    if frangex is not None:
        x_range = (max(frangex[0], x_range[0]), min(frangex[1], x_range[1]))
    y_range = (0, image.shape[0])
    # The original tested ``frangex`` here, which crashed when only
    # ``frangex`` was given and ignored ``frangey`` when given alone.
    if frangey is not None:
        y_range = (max(frangey[0], y_range[0]), min(frangey[1], y_range[1]))

    axis_range = y_range if axis == 1 else x_range
    # Extent of the image along the correlation direction ...
    axis_dim = image.shape[axis - 1]
    # ... and number of lines that are correlated with the reference. The
    # two only coincide for square images.
    n_lines = image.shape[axis]

    # Get kernel: the mean profile replicated once per line
    integral = image[slice(*y_range), slice(*x_range)].mean(axis=axis)
    roi = np.ones([axis_range[1] - axis_range[0], n_lines])
    ref = roi * integral[:, np.newaxis]

    # Get sliced image
    slice_ = [slice(None), slice(None)]
    slice_[axis - 1] = slice(max(axis_range[0] - offset, 0),
                             min(axis_range[1] + offset, axis_dim))
    sliced = image[tuple(slice_)]
    if axis == 0:
        sliced = sliced.T

    # Get curvature factor from cross correlation (convolution with the
    # reversed reference)
    crosscorr = fftconvolve(sliced,
                            ref[::-1, :],
                            axes=0, )
    shifts = np.argmax(crosscorr, axis=0)
    curv = np.polyfit(np.arange(n_lines), shifts, deg=deg)
    return curv[:-1][::-1]


def find_curvature(img, args, plot=False, **kwargs):
    """Least-squares fit of a curved, Gaussian-shaped line to an image.

    The model is ``h * exp(-((row - p(col)) / (2 * s))**2) + o`` with the
    parabola ``p(col) = (a * col + b) * col + c``.

    Parameters
    ----------
    img : 2-D numpy.ndarray
        Image to fit.
    args : sequence
        Initial guess ``(a, b, c, s, h[, o])``.
    plot : bool
        If true, show the image together with the initial and the fitted
        parabola.
    **kwargs
        Forwarded to ``plt.imshow`` when plotting.

    Returns
    -------
    numpy.ndarray
        Fitted parameters, in the same order as ``args``.
    """
    cols = np.arange(img.shape[1])[None, :]
    rows = np.arange(img.shape[0])[:, None]

    def ridge(col, a, b, c, s=0, h=0, o=0):
        # Only a, b, c are used; the rest allows calling with all parameters
        return (a * col + b) * col + c

    def model(a, b, c, s, h, o=0):
        scaled = (rows - ridge(cols, a, b, c)) / (2 * s)
        return h * np.exp(-scaled ** 2) + o

    def residuals(params):
        return (model(*params) - img).ravel()

    if plot:
        plt.figure(figsize=(10,10))
        plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest',
                   **kwargs)
        # initial guess
        plt.plot(cols[0, :], ridge(cols[0, :], *args))

    fitted, _ = leastsq(residuals, args)

    if plot:
        # fit result
        plt.plot(cols[0, :], ridge(cols[0, :], *fitted))
    return fitted


def correct_curvature(image, factor=None, axis=1):
    """Re-bin an image so that a curved line becomes straight.

    Every pixel centre is moved by ``factor[0] * t + factor[1] * t**2``
    (``t`` being the pixel-centre coordinate along the other direction) and
    the pixel values are histogrammed onto the original pixel grid.

    Parameters
    ----------
    image : 2-D numpy.ndarray
        Image to correct.
    factor : sequence of two floats or None
        Linear and quadratic curvature coefficients. With ``None`` nothing
        is done and ``None`` is returned.
    axis : int
        1 to work on the transposed image, anything else to work on the
        image as given.

    Returns
    -------
    numpy.ndarray or None
        Corrected image with the shape of ``image``.
    """
    if factor is None:
        return None

    transposed = axis == 1
    work = image.T if transposed else image

    n_rows, n_cols = work.shape
    col_edges = np.arange(n_cols + 1)
    row_edges = np.arange(n_rows + 1)
    # pixel centres sit half a pixel after the lower edge
    col_centres, row_centres = np.meshgrid(col_edges[:-1] + 0.5,
                                           row_edges[:-1] + 0.5)
    shifted = (col_centres
               - factor[0] * row_centres
               - factor[1] * row_centres ** 2)
    rebinned = np.histogramdd((shifted.flatten(), row_centres.flatten()),
                              bins=[col_edges, row_edges],
                              weights=work.flatten())[0]

    # histogramdd returns (column, row) order
    if transposed:
        return rebinned
    return rebinned.T


def get_spectrum(image, factor=None, axis=0,
                 pixel_range=None, energy_range=None, ):
    """Project an image onto one axis and return the resulting spectrum.

    Parameters
    ----------
    image : 2-D numpy.ndarray
        Image to sum along ``axis``.
    factor : sequence or None
        Polynomial coefficients handed to ``calibrate``; with ``None`` the
        abscissa stays in pixel units.
    axis : int
        Axis that is summed over.
    pixel_range : (start, stop) or None
        Pixel window of the projection; ``None`` / 0 entries mean "open".
        The window is clipped to the image.
    energy_range : (low, high) or None
        Passed to ``calibrate`` as ``range_`` when ``factor`` is given.

    Returns
    -------
    (centers, edge)
        Pixel centres (or calibrated values) and the summed intensities.
    """
    n_pixels = image.shape[axis - 1]
    first, last = 0, n_pixels
    if pixel_range is not None:
        first = max(pixel_range[0] or 0, 0)
        last = min(pixel_range[1] or n_pixels, n_pixels)

    edge = image.sum(axis=axis)[first:last]
    # centre of pixel k is k + 0.5
    edges = np.arange(first, last + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    if factor is None:
        return centers, edge

    centers, edge = calibrate(centers, edge,
                              factor=factor,
                              range_=energy_range)
    return centers, edge


# -----------------------------------------------------------------------------
# Energy calibration


def energy_calibration(channels, energies):
    """Return ``[slope, intercept]`` of a straight-line fit
    ``energies = slope * channels + intercept``."""
    linear_fit = np.polyfit(channels, energies, deg=1)
    return linear_fit


def calibrate(x, y=None, factor=None, range_=None):
    """Apply a polynomial calibration to ``x`` and optionally crop.

    Parameters
    ----------
    x : array-like
        Abscissa values.
    y : array-like or None
        Ordinate values; cropping only happens when they are given.
    factor : sequence or None
        Polynomial coefficients (highest order first) applied to ``x``.
    range_ : (low, high) or None
        Window in calibrated units. The crop runs from the sample closest
        to ``high`` up to (excluding) the sample closest to ``low``.

    Returns
    -------
    (x, y)
        Calibrated (and possibly cropped) data.
    """
    if factor is not None:
        x = np.polyval(factor, x)

    if y is None or range_ is None:
        return x, y

    idx_low = np.argmin(np.abs(x - range_[0]))
    idx_high = np.argmin(np.abs(x - range_[1]))
    # Calibrated energies have a different direction than the pixel index,
    # so the high end of the window comes first
    window = slice(idx_high, idx_low)
    return x[window], y[window]


# -----------------------------------------------------------------------------
# Gaussian-related functions


# Conversion factor from the standard deviation of a Gaussian to its full
# width at half maximum: FWHM = 2 * sqrt(2 * ln 2) * sigma.
FWHM_COEFF = 2 * np.sqrt(2 * np.log(2))


def gaussian_fit(x_data, y_data, offset=0):
    """Fit a Gaussian with constant baseline to ``y_data(x_data)``.

    The start values are taken from the data: maximum as height,
    intensity-weighted mean (plus ``offset``) as centre, intensity-weighted
    standard deviation as width and the minimum as baseline.

    Parameters
    ----------
    x_data, y_data : numpy.ndarray
        Abscissa and ordinate of the curve to fit.
    offset : float
        Added to the centre-of-mass start value of the peak position.

    Returns
    -------
    numpy.ndarray or None
        ``(height, x0, sigma, offset)`` as used by ``gauss1d``, or ``None``
        if the fit failed (the error is printed).
    """
    centre = np.average(x_data, weights=y_data)
    spread = np.sqrt(np.average((x_data - centre) ** 2, weights=y_data))

    # start values: height, position, width, baseline
    start_values = (y_data.max(), centre + offset, spread, y_data.min())
    try:
        fitted, _ = curve_fit(gauss1d, x_data, y_data, start_values,
                              maxfev=10000)
    except (RuntimeError, TypeError) as e:
        print(e)
        return None
    return fitted


def gauss1d(x, height, x0, sigma, offset):
    """Gaussian of amplitude ``height`` centred at ``x0`` with standard
    deviation ``sigma`` on top of a constant ``offset``."""
    normalized = (x - x0) / sigma
    return height * np.exp(-0.5 * normalized ** 2) + offset


def to_fwhm(sigma):
    """Convert a Gaussian standard deviation into a (non-negative) full
    width at half maximum."""
    width = sigma * FWHM_COEFF
    return abs(width)


# -----------------------------------------------------------------------------
# Centroid

# Default arguments of the centroiding routines below.
THRESHOLD = 510  # pixel counts above which a hit candidate is assumed
# The centroiding code subtracts (CURVE_A + CURVE_B * x) * x from the
# y coordinate of every hit.
CURVE_A = 2.19042931e-02  # curvature parameters as determined elsewhere
CURVE_B = -3.02191568e-07


def _esrf_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)):
    """Find photon hits as centre of mass of a 3-by-3 neighbourhood.

    Every pixel (excluding the outermost border) above ``threshold`` is a
    candidate. The mean of the image is subtracted from the 3-by-3 spot
    around it, negative values are set to zero, and the candidate is kept
    if no value of that spot exceeds the raw value of the centre pixel.

    Returns a list of ``(y, x)`` tuples, where ``y`` has been corrected by
    ``(curvature[0] + curvature[1] * x) * x``.

    Raises RuntimeError if more than 100000 candidates are found.
    """
    gs = 2
    half = gs // 2
    background = image.mean()

    # candidates are searched in the interior, then shifted back to image
    # coordinates
    interior = image[half: -gs // 2, half: -gs // 2]
    candidates = np.argwhere(interior > threshold) + np.array([half, half])
    if len(candidates) > 100000:
        raise RuntimeError('Threshold too low or acquisition time too long')

    hits = []
    for cy, cx in candidates:
        rows = slice(cy - half, cy + half + 1)
        cols = slice(cx - half, cx + half + 1)
        spot = image[rows, cols] - background
        spot[spot < 0] = 0
        if np.any(spot > image[cy, cx]):
            continue
        mx = np.average(np.arange(cols.start, cols.stop),
                        weights=spot.sum(axis=0))
        my = np.average(np.arange(rows.start, rows.stop),
                        weights=spot.sum(axis=1))
        my -= (curvature[0] + curvature[1] * mx) * mx
        hits.append((my, mx))
    return hits


def _new_centroid(image, threshold=THRESHOLD, std_threshold=3.5, curvature=(CURVE_A, CURVE_B)):
    """Find the position of photons with sub-pixel precision.

    The intensity of every 2-by-2 pixel square is summed up. A photon is
    assumed where such a sum exceeds ``threshold`` and is a local maximum
    with respect to its eight neighbouring 2-by-2 sums. The photon position
    is the centre of mass of the 4-by-4 square around that 2-by-2 square,
    after subtracting the image mean.

    If ``threshold`` is ``None``, it is set to the mean of the 2-by-2 sums
    plus ``std_threshold`` times their standard deviation.

    Returns a list of ``(y, x)`` tuples, where ``y`` has been corrected by
    ``(curvature[0] + curvature[1] * x) * x``.

    Raises RuntimeError if more than 10000 candidates are found.
    """
    background = image.mean()
    corners = (image[1:, 1:] + image[:-1, 1:]
               + image[1:, :-1] + image[:-1, :-1])
    if threshold is None:
        threshold = corners.mean() + std_threshold * corners.std()

    middle = corners[1:-1, 1:-1]
    # Neighbours at lower row index (and the left one) are compared with
    # ">=", the others with ">", so a plateau yields a single candidate.
    weak_neighbours = (corners[:-2, 1:-1], corners[1:-1, :-2],
                       corners[:-2, :-2], corners[:-2, 2:])
    strict_neighbours = (corners[2:, 1:-1], corners[1:-1, 2:],
                         corners[2:, :-2], corners[2:, 2:])
    candidates = middle > threshold
    for neighbour in weak_neighbours:
        candidates = candidates & (middle >= neighbour)
    for neighbour in strict_neighbours:
        candidates = candidates & (middle > neighbour)

    positions = np.argwhere(candidates)
    if len(positions) > 10000:
        raise RuntimeError(
            "too many peaks, threshold too low or acquisition time too high")

    hits = []
    for cy, cx in positions:
        spot = image[cy: cy + 4, cx: cx + 4] - background
        mx = np.average(np.arange(cx, cx + 4), weights=spot.sum(axis=0))
        my = np.average(np.arange(cy, cy + 4), weights=spot.sum(axis=1))
        my -= (curvature[0] + curvature[1] * mx) * mx
        hits.append((my, mx))
    return hits


# Centroiding implementation called by hRIXS.centroid().
centroid = _new_centroid


def decentroid(res):
    """Turn a list of ``(y, x)`` hit positions back into an image.

    The image is just large enough to hold the largest coordinates. Each
    hit with strictly positive ``y`` and ``x`` increments the pixel it
    falls into by one; all other hits are ignored.
    """
    hits = np.array(res)
    image = np.zeros(shape=(hits.max(axis=0) + 1).astype(int))
    inside = (hits[:, 1] > 0) & (hits[:, 0] > 0)
    pixels = hits[inside].astype(int)
    # unbuffered add, so several hits in one pixel are all counted
    np.add.at(image, (pixels[:, 0], pixels[:, 1]), 1)
    return image


# -----------------------------------------------------------------------------
# Integral

# Defaults of integrate(): FACTOR scales the curvature-corrected row
# coordinate before binning, RANGE is the (start, stop) window in unscaled
# rows. BINS is the number of bins that window spans after scaling.
FACTOR = 3
RANGE = [300, 400]
BINS = abs(np.subtract(*RANGE)) * FACTOR


def parabola(x, a, b, c=0):
    """Evaluate ``a * x**2 + b * x + c`` using Horner's scheme."""
    linear_part = a * x + b
    return linear_part * x + c


def integrate(image, factor=FACTOR, range=RANGE, curvature=(CURVE_A, CURVE_B), ):
    """Integrate an image along curved lines into a 1-D spectrum.

    The image mean is subtracted, then every pixel gets the coordinate
    ``ys = factor * (row - (curvature[1] * col + curvature[0]) * col)``.
    Each pixel contributes to the bin at ``floor(ys)`` and to the following
    bin; the result is the weighted mean intensity per bin.

    Parameters
    ----------
    image : 2-D numpy.ndarray
        Detector image.
    factor : number
        Scaling of the row coordinate, i.e. bins per row.
    range : (start, stop)
        Window in unscaled row coordinates.
    curvature : (a, b)
        Linear and quadratic curvature coefficients.

    Returns
    -------
    numpy.ndarray
        Spectrum with ``factor * (range[1] - range[0])`` bins.
    """
    image = image - image.mean()
    cols = np.arange(image.shape[1])[None, :]
    rows = np.arange(image.shape[0])[:, None]
    ys = factor * (rows - parabola(cols, curvature[1], curvature[0]))

    lower = np.floor(ys)
    upper = lower + 1
    # NOTE(review): the lower bin is weighted with the distance to the lower
    # bin and vice versa - looks swapped for a linear interpolation, confirm
    # that this is intended.
    lower_weight = ys - lower
    upper_weight = upper - ys

    span = (factor * range[0], factor * range[1])
    n_bins = span[1] - span[0]

    def binned(position, weight):
        counts, _ = np.histogram(position.ravel(), bins=n_bins,
                                 weights=weight.ravel(), range=span)
        return counts

    signal = (binned(lower, lower_weight * image)
              + binned(upper, upper_weight * image))
    norm = binned(lower, lower_weight) + binned(upper, upper_weight)
    return signal / norm


# -----------------------------------------------------------------------------
# hRIXS class


class hRIXS:
    # run
    # Default proposal number; from_run() falls back to it when no proposal
    # is passed.
    PROPOSAL = 2769

    # image range
    # Slices applied to the first and second axis of each detector image.
    # integrate() and align_readouts() do arithmetic on Y_RANGE.start and
    # Y_RANGE.stop, so explicit bounds have to be set before using them.
    X_RANGE = np.s_[:]
    Y_RANGE = np.s_[:]

    # centroid
    THRESHOLD = None  # pixel counts above which a hit candidate is assumed
    STD_THRESHOLD = 3.5  # same as THRESHOLD, in standard deviations
    CURVE_A = CURVE_A  # curvature parameters as determined elsewhere
    CURVE_B = CURVE_B

    # integral
    FACTOR = FACTOR
    RANGE = RANGE
    BINS = abs(np.subtract(*RANGE)) * FACTOR

    METHOD = 'centroid'  # ['centroid', 'integral']
    
    # Dark image and mask treatment
    # USE_DARK / USE_DARK_MASK are switched on by load_dark() and read by
    # centroid() and integrate().
    USE_DARK = False
    USE_DARK_MASK = False
    # A dark pixel counts as hot if it exceeds the dark average by more
    # than this value
    DARK_MASK_THRESHOLD = 100
    # Region of the image over which the dark average is calculated
    MASK_AVG_X = np.s_[1850:2000]
    MASK_AVG_Y = np.s_[500:1500]
    
    # Early energy calibration from runs 31-36 (to estimate mono drift)
    # Polynomial coefficients, highest order first, as used by np.polyval
    PIX2EV_POLY = [6.31196512e-02, 6.05502748e+02]
    # Width of fit window to calculate hv (estimated mono ± this)
    HV_FIT_DELTA = 30
    # Gaussian fit parameters for elastic line finding
    HV_FIT_SIGMA = 2

    # Linear mapping applied to the pixel axis to get the "energy"
    # coordinate of the spectra (energy = pixel * slope + intercept)
    ENERGY_INTERCEPT = 0
    ENERGY_SLOPE = 1

    # Fields that from_run() always loads
    FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm']

    def set_params(self, **params):
        """Set parameters on this instance.

        Every keyword is stored under its upper-case name, e.g.
        ``set_params(threshold=500)`` sets ``self.THRESHOLD``.
        """
        for name in params:
            setattr(self, name.upper(), params[name])

    def get_params(self, *params):
        """Return the requested parameters as a dictionary.

        The names are looked up in upper case on the instance; the keys of
        the result are the names as given. Without arguments a default set
        of parameters is returned.
        """
        names = params or ('proposal', 'x_range', 'y_range',
                           'threshold', 'curve_a', 'curve_b',
                           'factor', 'range', 'bins',
                           'method', 'fields')
        values = {}
        for name in names:
            values[name] = getattr(self, name.upper())
        return values
    
    def pixel2energy(self, pixel):
        """Return the energy of a pixel position by evaluating the
        PIX2EV_POLY polynomial (re-fit as needed)."""
        coefficients = self.PIX2EV_POLY
        return np.polyval(coefficients, pixel)

    def energy2pixel(self, ev):
        """Return the expected pixel position of a given energy.

        This is the last root of ``PIX2EV_POLY(pixel) - ev``. Only works
        (somewhat) predictably for 1st and 2nd order polynomials.

        Raises ValueError if the polynomial has more than three
        coefficients or if the root is complex.
        """
        coefficients = self.PIX2EV_POLY
        # The number of coefficients is order + 1
        if len(coefficients) > 3:
            raise ValueError('Too high polynomial order!')

        shifted = np.poly1d(coefficients) - ev
        pixel = shifted.roots[-1]
        if np.imag(pixel) != 0:
            raise ValueError('Complex root! (Out of fit bounds?)')
        return pixel

    def from_run(self, runNB, proposal=None, extra_fields=()):
        """Load the data of one run via ``toolbox_scs.load``.

        ``proposal`` defaults to ``self.PROPOSAL``; ``extra_fields`` are
        loaded in addition to ``self.FIELDS``. Only the data part of what
        ``tb.load`` returns is handed back.
        """
        chosen_proposal = self.PROPOSAL if proposal is None else proposal
        fields = self.FIELDS + list(extra_fields)
        _, data = tb.load(chosen_proposal, runNB=runNB, fields=fields)
        return data

    
    def load_dark(self, runNB, proposal=None, use_dark=True, mask=True,
                  mask_threshold=None):
        """Load a dark image and store it on this instance.

        Parameters
        ----------
        runNB : int or list/tuple of int
            Dark run(s). Several runs are concatenated along ``trainId``.
        proposal : int or None
            Proposal number, handed to ``from_run``.
        use_dark : bool
            Whether the dark image is to be subtracted later on
            (stored in ``self.USE_DARK``).
        mask : bool
            Whether hot pixels are to be identified and masked out later
            on (stored in ``self.USE_DARK_MASK``).
        mask_threshold : number or None
            A pixel is hot if it exceeds the dark average by more than this
            value. Defaults to ``self.DARK_MASK_THRESHOLD``.

        Sets ``dark_image`` (xarray), ``dark_im_array`` (numpy) and, if
        ``mask`` is true, ``dark_mask`` and ``dark_im_array_m`` (dark image
        with the hot pixels replaced by the dark average).

        Raises
        ------
        TypeError
            If ``runNB`` is neither an integer nor a list/tuple.
        """
        if mask_threshold is None:
            mask_threshold = self.DARK_MASK_THRESHOLD

        # If given a sequence of runs, iterate over them, otherwise read one.
        if isinstance(runNB, (list, tuple)):
            data = xr.concat(
                [self.from_run(run, proposal) for run in runNB],
                dim='trainId')
        elif isinstance(runNB, (int, np.integer)):
            data = self.from_run(runNB, proposal)
        else:
            raise TypeError(
                'load_dark() expects a list of indeces or an integer.')

        # Store the dark image (mean over acquisitions) in two formats
        self.dark_image = data['hRIXS_det'].mean(dim='trainId')
        self.dark_im_array = self.dark_image.to_numpy()

        # Remember whether the dark image is to be used later. The flag is
        # set in both directions so that use_dark=False also takes effect
        # after an earlier call with use_dark=True.
        self.USE_DARK = bool(use_dark)

        # If hot pixel masking is requested, find the mask and keep a copy
        # of the dark image with the masked values set to the average.
        if mask:
            dark_avg = np.mean(
                self.dark_im_array[self.MASK_AVG_Y, self.MASK_AVG_X], (0, 1))
            self.dark_mask = self.dark_im_array > dark_avg + mask_threshold
            self.dark_im_array_m = np.array(self.dark_im_array)
            self.dark_im_array_m[self.dark_mask] = dark_avg
        # Without a fresh mask the flag is cleared, so a mask of a
        # previously loaded dark image is not applied by accident.
        self.USE_DARK_MASK = bool(mask)
        return
    

    def find_curvature(self, runNB, proposal=None, plot=True, args=None, **kwargs):
        """Determine the curvature from a run and store it.

        The images of the run are summed up, cropped to X_RANGE / Y_RANGE
        and transposed, then handed to the module-level ``find_curvature``.
        If no start values ``args`` are given, they are estimated from the
        image. The result is stored in ``CURVE_A`` and ``CURVE_B`` and
        returned as ``(CURVE_A, CURVE_B)``.
        """
        data = self.from_run(runNB, proposal)
        summed = data['hRIXS_det'].sum(dim='trainId').to_numpy()
        image = summed[self.X_RANGE, self.Y_RANGE].T

        if args is None:
            # row profile, with the mean of the first ten rows as baseline
            profile = (image - image[:10, :].mean()).mean(axis=1)
            centre = np.average(np.arange(len(profile)), weights=profile)
            args = (-2e-7, 0.02, centre - 0.02 * image.shape[1] / 2, 3,
                    profile.max(), image.mean())

        fitted = find_curvature(image, args, plot=plot, **kwargs)
        # the fit returns the quadratic coefficient first
        self.CURVE_B, self.CURVE_A, *_ = fitted
        return self.CURVE_A, self.CURVE_B

    
    def centroid(self, data, bins=None, return_hits=False, fit_elastic=False,
                 hv=None, fit_delta=None):
        """Carry out hit finding on every image and bin the hits.

        Dark image subtraction and hot pixel masking are applied as given
        by ``self.USE_DARK`` and ``self.USE_DARK_MASK``.

        Parameters
        ----------
        data : xarray.Dataset
            Dataset with the images in ``hRIXS_det``.
        bins : int or None
            Number of bins of the spectrum, defaults to ``self.BINS``.
        return_hits : bool
            If true, the hit positions of every image are added to the
            result as ``hits``, ``xhits`` and ``yhits``.
        fit_elastic : bool
            If true, a Gaussian is fitted to the elastic line of every
            image; pixel position and energy are added as
            ``elastic_pixel`` and ``fitted_hv``.
        hv : number, sequence or None
            Expected photon energy, as one value or one value per image.
            With ``None`` the strongest peak is taken as elastic line.
        fit_delta : int or None
            Half width (in bins) of the fit window, defaults to
            ``self.HV_FIT_DELTA``.

        Returns
        -------
        xarray.Dataset
            ``data`` with the coordinate ``energy`` and the variable
            ``spectrum`` (plus the optional variables above). Images
            without hits keep an all-zero spectrum and, when fitting, zero
            elastic values.

        Raises
        ------
        ValueError
            If ``hv`` is a sequence whose length differs from the number of
            images.
        """
        def as_object_array(items):
            # The hit arrays differ in length from image to image. Filling
            # an object array element by element keeps one entry per image
            # without relying on implicit ragged-array creation.
            out = np.empty(len(items), dtype=object)
            for k, item in enumerate(items):
                out[k] = item
            return out

        if bins is None:
            bins = self.BINS

        images = data["hRIXS_det"]
        n_images = len(images)

        # Resolve open slice bounds (the default Y_RANGE is ``[:]``);
        # Y_RANGE is applied to the last image axis.
        y_start = 0 if self.Y_RANGE.start is None else self.Y_RANGE.start
        y_stop = (images.shape[-1] if self.Y_RANGE.stop is None
                  else self.Y_RANGE.stop)

        # Output containers
        hit_x = []
        hit_y = []
        hits = []
        ret = np.zeros((n_images, bins))

        # If the 'actual' photon energy is desired, we need to know where
        # to start looking for the elastic line
        if fit_elastic:
            elastic_pixel = np.zeros(n_images)
            found_hv = np.zeros(n_images)
            if fit_delta is None:
                fit_delta = self.HV_FIT_DELTA
            if hv is None:
                # -1 marks "no guess": the strongest peak is used
                hv = -1 * np.ones(n_images)
            elif np.ndim(hv) == 0:
                # one value for all images
                hv = hv * np.ones(n_images)
            elif len(hv) != n_images:
                raise ValueError("Size of hv vector doesn't match trainID")

        # Handle each acquired image separately
        for i, (image, r) in enumerate(zip(images, ret)):
            use_image = image.to_numpy()
            # Background treatment, depending on the two flags
            if self.USE_DARK:
                # subtract the dark image ...
                use_image = use_image - self.dark_im_array
                if self.USE_DARK_MASK:
                    # ... and set masked pixels to 0
                    use_image[self.dark_mask] = 0
            elif self.USE_DARK_MASK:
                # don't subtract the dark image, but set hot pixels to a
                # dark baseline value
                # NOTE(review): this writes into the array returned by
                # to_numpy(), which may share memory with ``data`` - confirm
                # whether modifying the input is intended.
                use_image[self.dark_mask] = np.mean(
                    use_image[self.MASK_AVG_Y, self.MASK_AVG_X], (0, 1))

            # Run centroiding on the preprocessed image
            c = centroid(
                use_image[self.X_RANGE, self.Y_RANGE].T,
                threshold=self.THRESHOLD,
                std_threshold=self.STD_THRESHOLD,
                curvature=(self.CURVE_A, self.CURVE_B))
            rc = np.array(c, dtype=float).reshape(-1, 2)

            # Hits are stored for every image, also for images without
            # hits, so that the lists stay aligned with trainId
            if return_hits:
                hit_x.append(rc[:, 0])
                hit_y.append(rc[:, 1])
                hits.append(rc)
            if not len(rc):
                continue

            # Assign the spectrum to the spectrum matrix; ``r`` is a view
            # of the corresponding row of ``ret``
            hy, hx = np.histogram(
                rc[:, 0], bins=bins,
                range=(0, y_stop - y_start))
            r[:] = hy

            # Fit the elastic peak, if requested
            if fit_elastic:
                if hv[i] == -1:
                    # Assume elastic line to be the strongest peak
                    gx = int(np.argmax(hy))
                else:
                    # Use given guess for where it is
                    gx = np.where(
                        hx + y_start > self.energy2pixel(hv[i]))[0][0]
                # Keep the window inside the histogram; hx has one element
                # more than hy, so the upper limit is given by hy
                lo = max(gx - fit_delta, 0)
                hi = min(gx + fit_delta, len(hy))
                popt, pcov = curve_fit(
                    gauss1d, hx[lo:hi], hy[lo:hi],
                    # start values: height (counts), position, width, offset
                    p0=[max(hy[lo:hi]), hx[gx], self.HV_FIT_SIGMA, 0],
                    bounds=([0, 0, 0, -np.inf],
                            [np.inf, np.inf, 100, np.inf]))

                elastic_pixel[i] = popt[1] + y_start
                found_hv[i] = self.pixel2energy(popt[1] + y_start)

        # Set up and assign a linear energy grid
        data = data.assign_coords(
            energy=np.linspace(y_start, y_stop, bins)
            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        # If hits were requested, assign them to data
        if return_hits:
            data = data.assign(hits=(("trainId"), as_object_array(hits)),
                               xhits=(("trainId"), as_object_array(hit_x)),
                               yhits=(("trainId"), as_object_array(hit_y)))
        # If hv was fitted, return it
        if fit_elastic:
            data = data.assign(fitted_hv=(("trainId"), found_hv),
                               elastic_pixel=(("trainId"), elastic_pixel))
        # Always assign the spectrum to data
        data = data.assign(spectrum=(("trainId", "energy"), ret))
        return data

    
    def align_readouts(self, data_list, start, stop, method='max_value',
                       ImageGrouping=1, fit_tol=None):
        """Align the spectra of a list of datasets on a common peak.

        Parameters
        ----------
        data_list : list of xarray.Dataset
            Datasets as returned by ``centroid()``. The methods 'centroid'
            and 'gauss_fit' re-bin the hit lists and therefore need data
            from ``centroid(data, return_hits=True)``.
        start, stop : number
            Values of ``data.energy`` that limit the peak search.
        method : str
            'max_value', 'centroid', 'autocorrelation' or 'gauss_fit'.
            'autocorrelation' is a relative method, where the first readout
            defines the absolute scale.
        ImageGrouping : int or str
            An integer groups that many consecutive images (1 means each
            image separately; a leftover tail is kept if it is longer than
            half a group). 'run' or 'runs' groups by dataset.
        fit_tol : float or None
            Tolerance of the minimizer, required for 'gauss_fit'.

        Returns
        -------
        (energy_aligned, spectrum_aligned)
            Energy grid and the array of shifted spectra, one row per
            group.

        Raises
        ------
        Exception
            For an unknown grouping or method, a missing ``fit_tol`` or a
            failed Gaussian fit.
        """
        from scipy import optimize, signal

        method_name = method.lower()
        # Only these methods re-bin the hit lists in part 2
        needs_hits = method_name in ('centroid', 'gauss_fit')

        energy = data_list[0].energy.to_numpy()
        spectra = []
        xhits = []
        if isinstance(ImageGrouping, (int, np.integer)):
            # Group data to bunches of a given number of images
            data = xr.concat(data_list, dim='trainId')
            n_images = data.spectrum.shape[0]
            n_groups = n_images // ImageGrouping
            # First the groups with full length
            limits = [(g * ImageGrouping, (g + 1) * ImageGrouping)
                      for g in range(n_groups)]
            # If the leftover tail is more than half of the full length,
            # add it too
            if n_images % ImageGrouping > ImageGrouping / 2:
                limits.append((n_groups * ImageGrouping, n_images))
            for first, last in limits:
                spectra.append(
                    np.sum(data.spectrum[first:last].to_numpy(), axis=0))
                if needs_hits:
                    xhits.append(
                        np.hstack(data.xhits[first:last].to_numpy()))
        elif (isinstance(ImageGrouping, str)
                and ImageGrouping.lower() in ('run', 'runs')):
            # Group data by runs
            for rundata in data_list:
                runspectrum = np.zeros(energy.shape)
                for spec in rundata.spectrum.to_numpy():
                    runspectrum += spec
                spectra.append(runspectrum)
                if needs_hits:
                    xhits.append(np.hstack(list(rundata.xhits.to_numpy())))
        else:
            raise Exception(
                'align_readouts() needs a reasonable grouping argument')

        # ------------------------------------------------------------
        # PART 1: find peak positions
        # ------------------------------------------------------------
        searchinds = (energy >= start) & (energy <= stop)
        x = energy[searchinds]
        peak_posis = []
        if method_name == 'max_value':
            # Position of the maximum of each spectrum
            for spec in spectra:
                peak_posis.append(x[np.argmax(spec[searchinds])])
        elif method_name == 'centroid':
            # Centre of mass of each spectrum
            for spec in spectra:
                peak_posis.append(np.average(x, weights=spec[searchinds]))
        elif method_name == 'autocorrelation':
            # Shift of each spectrum relative to the first one
            for ind, spec in enumerate(spectra):
                y = spec[searchinds]
                if ind == 0:
                    y0 = y
                    maxipos0 = np.argmax(y0)
                    peak_posis.append(x[maxipos0])
                else:
                    corr = signal.correlate(y, y0, mode='full')
                    # In 'full' mode zero lag sits at index len(y0) - 1
                    shift = np.argmax(corr) - (len(y0) - 1)
                    # Keep the index inside the search window instead of
                    # wrapping around for negative values
                    posind = min(max(maxipos0 + shift, 0), len(x) - 1)
                    peak_posis.append(x[posind])
        elif method_name == 'gauss_fit':
            if fit_tol is None:
                raise Exception('Gauss fit requires a tolerance value.')

            def Gauss(grid, x0, sigma):
                # Returns a normalized bell curve with center at x0 and
                # std sigma on grid
                return (1.0 / (sigma * np.sqrt(2.0 * np.pi))
                        * np.exp(-0.5 * (grid - x0) ** 2 / sigma ** 2))

            def Cost(p, grid, spec):
                # Returns squared error of spec and a bell curve (area
                # p[0], offset p[3]) presented on grid
                return np.sum(np.square(
                    p[0] * Gauss(grid, p[1], p[2]) + p[3] - spec))

            for spec in spectra:
                y = spec[searchinds]
                # Initial guess
                area = np.sum(y)
                mean = np.average(x, weights=y)
                std = np.sqrt(np.average((x - mean) ** 2, weights=y))
                p0 = [area / 2, mean, std, 0]
                # Bounds
                bnds = [[0, None], [start, stop],
                        [0, 0.5 * (stop - start)], [0, np.max(y)]]
                # Fit by minimizing least squares error
                p = optimize.minimize(Cost, p0, args=(x, y), bounds=bnds,
                                      method='L-BFGS-B', tol=fit_tol,
                                      options={'disp': 0,
                                               'maxiter': 1000000})
                if p.success:
                    peak_posis.append(p.x[1])
                else:
                    # If fitting fails, plot data, initial guess and result
                    # and quit
                    plt.figure()
                    plt.plot(x, y, '.')
                    plt.plot(x, p0[0] * Gauss(x, p0[1], p0[2]) + p0[3])
                    plt.plot(x, p.x[0] * Gauss(x, p.x[1], p.x[2]) + p.x[3])
                    raise Exception(
                        'align_readouts(): can not fit a gaussian to the data.')
        else:
            raise Exception('align_readouts() did not recognize the method.')

        # ------------------------------------------------------------
        # PART 2: shift acquisitions
        # For max_value and autocorrelation the shift is an integer number
        # of bins and no rebinning is needed. For centroid and gauss_fit
        # the given hit lists are re-binned.
        # ------------------------------------------------------------
        if method_name in ('max_value', 'autocorrelation'):
            # Since peak_posis[0] matches with one grid point, the shift in
            # energy scale will be an integer multiple of grid spacing.
            # Thus we can subtract and interpolation will not cause a
            # problem.
            energy_aligned = energy - peak_posis[0]
            spectrum_aligned = np.array(
                [np.interp(energy_aligned, energy - position, spec)
                 for position, spec in zip(peak_posis, spectra)])
        elif method_name in ('centroid', 'gauss_fit'):
            e0 = int(peak_posis[0])
            energy_aligned = np.linspace(
                self.Y_RANGE.start, self.Y_RANGE.stop, self.BINS) - e0
            ret = []
            for position, hits in zip(peak_posis, xhits):
                spectrum = np.histogram(
                    hits - position, bins=self.BINS,
                    range=(0 - e0,
                           self.Y_RANGE.stop - self.Y_RANGE.start - e0))[0]
                ret.append(spectrum)
            spectrum_aligned = np.array(ret)
        else:
            raise Exception('align_readouts() did not recognize the method.')

        return energy_aligned, spectrum_aligned
    
    
    def integrate(self, data):
        """Integrate every image of ``data`` along the curved lines.

        The dark image is subtracted if ``self.USE_DARK`` is set. The
        spectrum leaves out ten rows at each end of Y_RANGE. Returns
        ``data`` with the coordinate ``energy`` and the variable
        ``spectrum``.
        """
        margin = 10
        n_rows = self.Y_RANGE.stop - self.Y_RANGE.start
        images = data["hRIXS_det"]
        spectra = np.zeros((len(images), n_rows - 2 * margin))

        for row, image in zip(spectra, images):
            if self.USE_DARK:
                image = image - self.dark_image
            cropped = image.to_numpy()[self.X_RANGE, self.Y_RANGE].T
            row[:] = integrate(cropped, factor=1,
                               range=(margin, n_rows - margin),
                               curvature=(self.CURVE_A, self.CURVE_B))

        pixels = np.arange(self.Y_RANGE.start + margin,
                           self.Y_RANGE.stop - margin)
        data = data.assign_coords(
            energy=pixels * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        return data.assign(spectrum=(("trainId", "energy"), spectra))

    # Reduction applied to each variable by aggregate(), keyed by variable
    # name: images, normalization and spectra are summed up, delays are
    # averaged. Variables without an entry are dropped by aggregate().
    aggregators = dict(
        hRIXS_det=lambda x, dim: x.sum(dim=dim),
        Delay=lambda x, dim: x.mean(dim=dim),
        hRIXS_delay=lambda x, dim: x.mean(dim=dim),
        hRIXS_norm=lambda x, dim: x.sum(dim=dim),
        spectrum=lambda x, dim: x.sum(dim=dim),
    )

    def aggregator(self, da, dim):
        """Reduce one variable along ``dim`` with the function registered
        for its name in ``self.aggregators``; return ``None`` if there is
        none."""
        reduce_ = self.aggregators.get(da.name)
        return None if reduce_ is None else reduce_(da, dim=dim)

    def aggregate(self, ds, dim="trainId"):
        """Reduce a dataset along ``dim``.

        Every variable is mapped through ``self.aggregator``; variables
        without an entry in ``self.aggregators`` are dropped from the
        result.
        """
        reduced = ds.map(self.aggregator, dim=dim)
        unknown = [name for name in reduced if name not in self.aggregators]
        return reduced.drop_vars(unknown)

    def normalize(self, data):
        """Return ``data`` with the additional variable ``normalized``,
        the spectrum divided by ``hRIXS_norm``."""
        ratio = data["spectrum"] / data["hRIXS_norm"]
        return data.assign(normalized=ratio)