import numpy as np

class azimuthal_integrator(object):
    '''
    Reusable integrator for repeated azimuthal integration of similar images.

    All pixel index lookups are computed once in the constructor, so that
    integrating an image only requires summing pre-selected pixels.

    Usage
    =====
    > ai = azimuthal_integrator(image.shape, (cx, cy), (0, 180))
    > az_intensity = ai(image)

    Attributes
    ==========
    shape : tuple of ints
        Image shape the integrator was built for.

    polar_mask : bool array of shape ``shape``
        True for pixels inside the selected angular wedge.

    maxdist : number
        ``min(sx - cx, sy - cy)``; radial bins stop below this distance.

    index_array : int array of shape ``shape``
        Flat (row-major) index of every pixel.

    distance : float array
        Outer radius of each radial bin.

    flat_indices : list of int arrays
        For each radial bin, the flat indices of the contributing pixels.
    '''

    def __init__(self, imageshape, center, polar_range, dr=2, aspect=None):
        '''
        Precompute the pixel indices of every radial bin.

        Parameters
        ==========
        imageshape : sequence of two ints
            The shape of the images to be integrated over.

        center : sequence of two numbers
            Center coordinates in pixels (first axis, second axis). May be
            fractional.

        polar_range : sequence of two numbers
            Start and stop polar angle (in degrees) to restrict integration
            to wedges. Angles are evaluated modulo 180 degrees, so a wedge
            always includes its point reflection through the center. A span
            of 180 degrees or more selects every pixel; a range crossing a
            multiple of 180 degrees (e.g. ``(-10, 10)``) wraps around.

        dr : int, default 2
            Radial width of the integration slices.

        aspect : float, optional
            Scale factor applied to first-axis offsets before computing
            distances. Defaults to 204 / 236 (non-square DSSC pixels).
        '''
        sx, sy = imageshape
        cx, cy = center
        # Stored as a tuple so that the comparison with ``ndarray.shape`` in
        # __call__ also works when the shape was passed as a list.
        self.shape = tuple(imageshape)

        # np.ogrid yields integer arrays; subtract out of place so that a
        # fractional center does not trigger an in-place casting error.
        xcoord, ycoord = np.ogrid[:sx, :sy]
        xcoord = xcoord - cx
        ycoord = ycoord - cy

        # Distance from center with the pixel aspect ratio taken into account.
        # The default branch keeps the multiply-then-divide order so results
        # are identical to earlier versions of this class.
        if aspect is None:
            dist_array = np.hypot(xcoord * 204 / 236, ycoord)
        else:
            dist_array = np.hypot(xcoord * aspect, ycoord)

        self.polar_mask = self._wedge_mask(xcoord, ycoord, polar_range)

        self.maxdist = min(sx - cx, sy - cy)

        # Flat (row-major) index of every pixel.
        self.index_array = np.arange(sx * sy, dtype=np.intp).reshape(sx, sy)

        # Outer radii of the bins; ceil keeps range() happy (and inclusive of
        # every integer radius below maxdist) when the center is fractional.
        ring_edges = range(dr, int(np.ceil(self.maxdist)), dr)

        # Only pixels inside the wedge can ever contribute, so select them
        # once instead of re-masking the full image for every ring. Boolean
        # indexing keeps row-major order, so each ring's indices come out in
        # the same order as with a full-image ring mask.
        wedge_indices = self.index_array[self.polar_mask]
        wedge_dist = dist_array[self.polar_mask]
        self.flat_indices = [
            wedge_indices[(wedge_dist >= (dist - dr)) & (wedge_dist < dist)]
            for dist in ring_edges
        ]
        self.distance = np.array(list(ring_edges), dtype=float)

    def _wedge_mask(self, xcoord, ycoord, polar_range):
        '''Return the boolean mask of pixels inside the angular wedge.'''
        lower, upper = np.sort(polar_range)
        if upper - lower >= 180:
            # Angles are folded into [0, 180) degrees, so this covers all
            # directions. Reducing the limits modulo pi would collapse them
            # onto each other and select nothing.
            return np.ones(self.shape, dtype=bool)

        tmin, tmax = np.deg2rad(np.sort(polar_range)) % np.pi
        # Angle measured from the second image axis towards the first one,
        # folded into [0, pi).
        polar_array = np.mod(np.arctan2(xcoord, ycoord), np.pi)
        if tmin > tmax:
            # The range crosses a multiple of 180 degrees: after folding, the
            # wedge consists of the two ends of the [0, pi) interval.
            return (polar_array > tmin) | (polar_array < tmax)
        return (polar_array > tmin) & (polar_array < tmax)

    def __call__(self, image):
        '''
        Integrate an image over the precomputed radial bins.

        Parameters
        ==========
        image : array-like
            Image whose shape equals the shape given at construction.

        Returns
        =======
        az_intensity : array
            Sum of the pixel values in each radial bin (NaNs are ignored),
            in the same order as ``self.distance``.

        Raises
        ======
        AssertionError
            If the image shape does not match. Raised explicitly, so the
            check also runs under ``python -O``.
        '''
        image = np.asanyarray(image)
        if self.shape != image.shape:
            raise AssertionError('image shape does not match')
        # ravel() avoids the unconditional copy made by flatten().
        image_flat = image.ravel()
        return np.array([np.nansum(image_flat[indices]) for indices in self.flat_indices])