"""
Beam splitting Off-axis Zone plate analysis routines.

Copyright (2021) SCS Team.
"""

import time
import datetime
import json

import numpy as np
import xarray as xr
import dask.array as da
from scipy.optimize import minimize

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm

from extra_data import open_run
from extra_geom import DSSC_1MGeometry

__all__ = [
    'load_dssc_module',
    'inspect_dark',
    'average_module',
    'plane_fitting',
    'compute_flat_field_correction',
    'nl_crit',
    'inspect_correction',
    'find_rois',
    'nl_domain'
]


class parameters():
    """Parameters contains all input parameters for the BOZ corrections.

    This is used in beam splitting off-axis zone plate spectroscopy analysis
    as well as during the determination of the correction parameters
    themselves, to ensure they can be reproduced.

    Inputs
    ------
    proposal: int, proposal number
    darkrun: int, run number for the dark run
    run: int, run number for the data run
    module: int, DSSC module number
    gain: float, number of ph per bin
    """
    def __init__(self, proposal, darkrun, run, module, gain):
        self.proposal = proposal
        self.darkrun = darkrun
        self.run = run
        self.module = module
        self.gain = gain
        self.mask_idx = None
        self.mean_th = (None, None)
        self.std_th = (None, None)
        self.rois = None
        self.rois_th = None
        self.flat_field = None
        self.flat_field_prod_th = (5.0, np.inf)
        self.flat_field_ratio_th = (-np.inf, 1.2)
        self.plane_guess_fit = None
        self.Fnl = None
        self.alpha = None
        self.sat_level = None
        self.max_iter = None


        # temporary data
        self.arr_dark = None
        self.tid_dark = None
        self.arr = None
        self.tid = None

    def dask_load_persistently(self):
        """Load dask data array in memory."""
        self.arr_dark, self.tid_dark = load_dssc_module(self.proposal,
            self.darkrun, self.module, drop_intra_darks=True, persist=True)
        self.arr, self.tid = load_dssc_module(self.proposal, self.run,
            self.module, drop_intra_darks=True, persist=True)

        # make sure to rechunk the arrays
        self.arr = self.arr.rechunk((100, -1, -1, -1))
        self.arr_dark = self.arr_dark.rechunk((100, -1, -1, -1))

    def set_mask(self, arr):
        """Set mask of bad pixels.

        Inputs
        ------
        arr: either a boolean array with the shape of a DSSC module image,
            where False marks bad pixels, or a list of bad pixel indices
        """
        if type(arr) is not list:
            self.mask_idx = np.argwhere(arr == False).tolist()
            self.mask = arr
        else:
            self.mask_idx = arr
            mask = np.ones((128, 512), dtype=bool)
            for k in self.mask_idx:
                mask[k[0], k[1]] = False
            self.mask = mask

    def get_mask(self):
        """Get the boolean array bad pixel of a DSSC module."""
        return self.mask

    def get_mask_idx(self):
        """Get the list of bad pixel indices."""
        return self.mask_idx

    def set_flat_field(self, plane,
            prod_th=None, ratio_th=None):
        """Set the flat field plane definition."""
        if type(plane) is not list:
            self.flat_field = plane.tolist()
        else:
            self.flat_field = plane
        if prod_th is not None:
            self.flat_field_prod_th = prod_th
        if ratio_th is not None:
            self.flat_field_ratio_th = ratio_th

    def get_flat_field(self):
        """Get the flat field plane definition."""
        if self.flat_field is None:
            return None
        else:
            return np.array(self.flat_field)

    def set_Fnl(self, Fnl):
        """Set the non-linear correction function."""
        if isinstance(Fnl, list):
            self.Fnl = Fnl
        else:
            self.Fnl = Fnl.tolist()

    def get_Fnl(self):
        """Get the non-linear correction function."""
        if self.Fnl is None:
            return None
        else:
            return np.array(self.Fnl)

    def save(self, path='./'):
        """Save the parameters as a JSON file.

        Inputs
        ------
        path: str, where to save the file, default to './'
        """
        v = {}
        v['proposal'] = self.proposal
        v['darkrun'] = self.darkrun
        v['run'] = self.run
        v['module'] = self.module
        v['gain'] = self.gain

        v['mask'] = self.mask_idx
        v['mean_th'] = self.mean_th
        v['std_th'] = self.std_th

        v['rois'] = self.rois
        v['rois_th'] = self.rois_th

        v['flat_field'] = self.flat_field
        v['flat_field_prod_th'] = self.flat_field_prod_th
        v['flat_field_ratio_th'] = self.flat_field_ratio_th
        v['plane_guess_fit'] = self.plane_guess_fit

        v['Fnl'] = self.Fnl
        v['alpha'] = self.alpha
        v['sat_level'] = self.sat_level
        v['max_iter'] = self.max_iter

        fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json'

        with open(path + fname, 'w') as f:
            json.dump(v, f)
            print(path + fname)

    @classmethod
    def load(cls, fname):
        """Load parameters from a JSON file.

        Inputs
        ------
        fname: string, name of the JSON file to load
        """
        with open(fname, 'r') as f:
            v = json.load(f)
        c = cls(v['proposal'], v['darkrun'], v['run'], v['module'], v['gain'])

        c.mean_th = v['mean_th']
        c.std_th = v['std_th']
        c.set_mask(v['mask'])

        c.rois = v['rois']
        c.rois_th = v['rois_th']

        c.set_flat_field(v['flat_field'], v['flat_field_prod_th'], v['flat_field_ratio_th'])
        c.plane_guess_fit = v['plane_guess_fit']

        c.set_Fnl(v['Fnl'])
        c.alpha = v['alpha']
        c.sat_level = v['sat_level']
        c.max_iter = v['max_iter']

        return c

    def __str__(self):
        f = f'proposal:{self.proposal} darkrun:{self.darkrun} run:{self.run}'
        f += f' module:{self.module} gain:{self.gain} ph/bin\n'

        if self.mask_idx is not None:
            f += f'mean threshold:{self.mean_th} std threshold:{self.std_th}\n'
            f += f'mask:(#{len(self.mask_idx)}) {self.mask_idx}\n'
        else:
            f += 'mask:None\n'

        f += f'rois threshold: {self.rois_th}\n'
        f += f'rois: {self.rois}\n'

        f += f'flat field p: {self.flat_field} prod:{self.flat_field_prod_th} ratio:{self.flat_field_ratio_th}\n'
        f += f'plane guess fit: {self.plane_guess_fit}\n'

        if self.Fnl is not None:
            f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
            f += f'alpha:{self.alpha}, sat. level:{self.sat_level}, '
            f += f' max. iter.:{self.max_iter}'
        else:
            f += 'Fnl: None'

        return f
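
# Typical calibration workflow with the parameters class (illustrative sketch,
# not part of the original API; proposal, run and threshold values are
# placeholders):
#
#     params = parameters(proposal=1234, darkrun=10, run=11, module=15,
#                         gain=3.0)
#     params.dask_load_persistently()
#     params.mean_th = (None, None)
#     params.std_th = (None, None)
#     params.set_mask(bad_pixel_map(params))
#     params.rois_th = 4
#     params.rois = find_rois_from_params(params)
#     params.save()
#     # later, reload the exact same settings
#     params = parameters.load('parameters_p1234_d10_r11.json')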

def _get_pixel_pos():
    """Compute the pixel position on hexagonal lattice of DSSC module 15"""
    # module pixel position
    dummy_quad_pos = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
    g = DSSC_1MGeometry.from_quad_positions(dummy_quad_pos)

    # keeping only module 15 pixel X,Y position
    return g.get_pixel_positions()[15][:, :, :2]


def _plane_flat_field(p, roi):
    """Compute the p plane over the given roi.

    Given the plane parameters p, compute the plane over the roi
    size.

    Parameters
    ----------
    p: a vector of the a, b, c, d plane parameters, with the plane given by
       a*x + b*y + c*z + d = 0
    roi: a dictionary with keys 'yh', 'yl', 'xh', 'xl'

    Returns
    -------
    the plane field given by p evaluated over the roi extent.
    """
    a, b, c, d = p

    # DSSC pixel position on hexagonal lattice
    pixel_pos = _get_pixel_pos()
    pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]

    Z = -(a*pos[:, :, 0] + b*pos[:, :, 1] + d)/c

    return Z


def compute_flat_field_correction(rois, p, plot=False):
    """Compute the plane field correction on beam rois.

    Inputs
    ------
    rois: dictionary of beam rois, containing at least 'n' and 'p'
    p: the concatenated plane parameters [a_n, b_n, c_n, d_n, a_p, b_p, c_p,
        d_p] of the 'n' and 'p' planes
    plot: boolean, default False, whether to draw a diagnostic plot

    Returns
    -------
    numpy 2D array of the flat field correction evaluated over one DSSC ladder
    (2 sensors)
    """
    flat_field = np.ones((128, 512))

    r = 'n'
    flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \
                       _plane_flat_field(p[:4], rois[r])
    r = 'p'
    flat_field[rois[r]['yl']:rois[r]['yh'], rois[r]['xl']:rois[r]['xh']] = \
                       _plane_flat_field(p[4:], rois[r])

    if plot:
        f, ax = plt.subplots(1, 1, figsize=(6, 2))
        img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
        f.colorbar(img, ax=[ax], label='amplitude')
        ax.set_xlabel('px')
        ax.set_ylabel('px')
        ax.set_aspect('equal')

    return flat_field
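
# Minimal sketch (values are placeholders): with a = b = 0, c = 1 and d = -1
# the plane a*x + b*y + c*z + d = 0 reduces to z = 1, so the returned
# correction is flat.
#
#     rois = {'n': {'xl': 0, 'xh': 64, 'yl': 32, 'yh': 96},
#             'p': {'xl': 128, 'xh': 192, 'yl': 32, 'yh': 96}}
#     p = [0.0, 0.0, 1.0, -1.0,    # plane for the 'n' roi
#          0.0, 0.0, 1.0, -1.0]    # plane for the 'p' roi
#     ff = compute_flat_field_correction(rois, p)
#     assert ff.shape == (128, 512) and np.allclose(ff, 1.0)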


def nl_domain(N, low, high):
    """Create the input domain where the non-linear correction defined.

    Inputs
    ------
    N: integer, number of control points or intervals
    low: input values below or equal to low will not be corrected
    high: input values higher or equal to high will not be corrected

    Returns
    -------
    array of 2**9 integer values with N segments
    """
    x = np.arange(2**9)
    vx = x.copy()
    eps = 1e-5
    vx[(x > low)*(x < high)] = np.linspace(1, N+1-eps, high-low-1)
    vx[x <= low] = 0
    vx[x >= high] = 0

    return vx


def nl_lut(domain, dy):
    """Compute the non-linear correction.

    Inputs
    ------
    domain: input domain where dy is defined. For zero no correction is
        defined. For non-zero value x, dy[x] is applied.
    dy: a vector of deviation from linearity on control point homogeneously
        dispersed over 9 bits.

    Returns
    -------
    F_INL: the non-linear correction function given as a lookup table with
           9 bit integer input
    """
    x = np.arange(2**9)
    ndy = np.insert(dy, 0, 0)  # add zero to dy

    f = x + ndy[domain]

    return f
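
# Minimal sketch of how nl_domain and nl_lut combine (N, low and high are
# placeholder values): an all-zero deviation vector gives back the identity
# lookup table.
#
#     dom = nl_domain(N=10, low=40, high=280)    # 10 segments on ]40, 280[
#     Fnl = nl_lut(dom, np.zeros(10))            # zero deviation everywhere
#     assert np.array_equal(Fnl, np.arange(2**9))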


def find_rois(data_mean, threshold):
    """Find rois from 3 beams configuration.

    Inputs
    ------
    data_mean: dark corrected average image
    threshold: threshold value to find beams

    Returns
    -------
    rois: dictionary of rois
    """
    # compute vertical and horizontal projection
    pX = data_mean.mean(axis=0)
    pX = pX[:256]  # half the ladder since there is a gap in the middle
    pY = data_mean.mean(axis=1)

    # along X
    lowX = int(np.argmax(pX[:64] > threshold))  # 1st occurrence returned
    highX = int(np.argmax(pX[192:] <= threshold) + 192)  # 1st occ. returned

    leftX = int(np.argmin(pX[64:128]) + 64)
    rightX = int(np.argmin(pX[128:192]) + 128)

    # along Y
    lowY = int(np.argmax(pY[:64] > threshold))  # 1st occurrence returned
    highY = int(np.argmax(pY[64:] < threshold) + 64)  # 1st occ. returned

    # define rois
    rois = {}
    # baseline correction rois
    for k in [0, 1, 2, 3]:
        rois[f'b{k}'] = {'xl': k*64, 'xh': (k+1)*64, 'yl': 0, 'yh': lowY}
    for k in [8, 9, 10, 11]:
        rois[f'b{k}'] = {'xl': (k-8)*64, 'xh': (k+1-8)*64,
                         'yl': highY, 'yh': 128}

    # beam roi
    rois['n'] = {'xl': lowX, 'xh': leftX, 'yl': lowY, 'yh': highY}
    rois['0'] = {'xl': leftX, 'xh': rightX, 'yl': lowY, 'yh': highY}
    rois['p'] = {'xl': rightX, 'xh': highX, 'yl': lowY, 'yh': highY}

    # saturation roi
    rois['sat'] = {'xl': lowX, 'xh': highX, 'yl': lowY, 'yh': highY}

    # beam rois split per ASIC
    rois['0X'] = {'xl': lowX, 'xh': 1*64, 'yl': lowY, 'yh': 64}
    rois['1X1'] = {'xl': 64, 'xh': leftX, 'yl': lowY, 'yh': 64}

    rois['1X2'] = {'xl': leftX, 'xh': 2*64, 'yl': lowY, 'yh': 64}
    rois['2X1'] = {'xl': 2*64, 'xh': rightX, 'yl': lowY, 'yh': 64}

    rois['2X2'] = {'xl': rightX, 'xh': 3*64, 'yl': lowY, 'yh': 64}
    rois['3X'] = {'xl': 3*64, 'xh': highX, 'yl': lowY, 'yh': 64}

    rois['8X'] = {'xl': lowX, 'xh': 1*64, 'yl': 64, 'yh': highY}
    rois['9X1'] = {'xl': 64, 'xh': leftX, 'yl': 64, 'yh': highY}

    rois['9X2'] = {'xl': leftX, 'xh': 2*64, 'yl': 64, 'yh': highY}
    rois['10X1'] = {'xl': 2*64, 'xh': rightX, 'yl': 64, 'yh': highY}

    rois['10X2'] = {'xl': rightX, 'xh': 3*64, 'yl': 64, 'yh': highY}
    rois['11X'] = {'xl': 3*64, 'xh': highX, 'yl': 64, 'yh': highY}

    return rois
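
# The returned dictionary maps roi names to pixel bounds (illustrative sketch,
# the threshold value is a placeholder):
#
#     rois = find_rois(data_mean, threshold=3.5)
#     rois['0']   # -> {'xl': ..., 'xh': ..., 'yl': ..., 'yh': ...}
#     # 'n', '0' and 'p' are the three beam rois, 'sat' spans all three,
#     # 'b0'..'b3' and 'b8'..'b11' are baseline rois on either side of the
#     # beams, and '0X'..'11X' split the beams per ASIC.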


def find_rois_from_params(params):
    """Find rois from 3 beams configuration.

    Inputs
    ------
    params: parameters

    Returns
    -------
    rois: dictionary of rois
    """
    assert params.arr_dark is not None, "Data not loaded"
    dark = average_module(params.arr_dark).compute()

    assert params.arr is not None, "Data not loaded"
    data = average_module(params.arr, dark=dark).compute()
    data_mean = data.mean(axis=0)  # mean over pulseId
    threshold = params.rois_th

    return find_rois(data_mean, threshold)


def inspect_rois(data_mean, rois, threshold=None, allrois=False):
    """Find rois from 3 beams configuration from mean module image.

    Inputs
    ------
    data_mean: mean module image
    rois: dictionary of rois to overlay, containing at least 'n', '0' and 'p'
    threshold: float, default None, threshold value used to detect beams
        boundaries
    allrois: boolean, default False, plot all rois defined in rois or only the
        main ones (['n', '0', 'p'])

    Returns
    -------
    matplotlib figure
    """
    # compute vertical and horizontal projection
    pX = data_mean.mean(axis=0)
    pX = pX[:256]  # half the ladder since there is a gap in the middle
    pY = data_mean.mean(axis=1)

    # Set up the axes with gridspec
    fig = plt.figure(figsize=(5, 3))
    grid = plt.GridSpec(2, 2,  width_ratios=(1, 4), height_ratios=(2, 1),
                        # left=0.1, right=0.9, bottom=0.1, top=0.9,
                        wspace=0.05, hspace=0.05)
    main_ax = fig.add_subplot(grid[0, 1])
    y = fig.add_subplot(grid[0, 0], xticklabels=[], sharey=main_ax)
    x = fig.add_subplot(grid[1, 1], yticklabels=[], sharex=main_ax)

    # scatter points on the main axes
    Xs = np.arange(len(pX))
    Ys = np.arange(len(pY))
    main_ax.pcolormesh(Xs, Ys, np.flipud(data_mean[:, :256]),
                       cmap='Greys_r',
                       vmin=0,
                       vmax=np.percentile(data_mean[:, :256], 99))
    main_ax.set_aspect('equal')

    from matplotlib.patches import Rectangle
    roi = rois['n']
    main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']),
        roi['xh'] - roi['xl'],
        roi['yh'] - roi['yl'],
        alpha=0.3, color='b'))
    roi = rois['0']
    main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']),
        roi['xh'] - roi['xl'],
        roi['yh'] - roi['yl'],
        alpha=0.3, color='g'))
    roi = rois['p']
    main_ax.add_patch(Rectangle((roi['xl'], 128-roi['yh']),
        roi['xh'] - roi['xl'],
        roi['yh'] - roi['yl'],
        alpha=0.3, color='r'))

    x.plot(Xs, pX)
    x.invert_yaxis()
    if threshold is not None:
        x.axhline(threshold, c='k', alpha=.5)
    x.axvline(rois['n']['xl'], c='r', alpha=.5)
    x.axvline(rois['0']['xl'], c='r', alpha=.6)

    x.axvline(rois['p']['xl'], c='r', alpha=.7)
    x.axvline(rois['p']['xh'], c='r', alpha=.8)

    y.plot(pY, np.arange(len(pY)-1, -1, -1))
    y.invert_xaxis()
    if threshold is not None:
        y.axvline(threshold, c='k', alpha=.5)
    y.axhline(127-rois['p']['yl'], c='r', alpha=.5)
    y.axhline(127-rois['p']['yh'], c='r', alpha=.6)

    return fig


def histogram_module(arr, mask=None):
    """Compute a histogram of the 9 bits raw pixel values over a module.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    mask: optional bad pixel mask

    Returns
    -------
    histogram
    """
    if mask is not None:
        w = da.repeat(da.repeat(da.array(mask[None, None, :, :]),
                arr.shape[1], axis=1), arr.shape[0], axis=0)
        w = w.rechunk((100, -1, -1, -1))
        return da.bincount(arr.ravel(), w.ravel(), minlength=512).compute()
    else:
        return da.bincount(arr.ravel(), minlength=512).compute()


def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
    """Compute and plot a histogram of the 9 bits raw pixel values.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    arr_dark: dask array of reshaped dssc dark data (trainId, pulseId, x, y)
    mask: optional bad pixel mask
    extra_lines: boolean, default False, plot extra lines at period values

    Returns
    -------
    (h, hd): histogram of arr, arr_dark
    figure
    """
    from matplotlib.ticker import MultipleLocator

    f = plt.figure(figsize=(6, 3))
    ax = plt.gca()
    h = histogram_module(arr, mask=mask)
    Sum_h = np.sum(h)
    ax.plot(np.arange(2**9), h/Sum_h, marker='o',
            ms=3, markerfacecolor='none', lw=1)

    if arr_dark is not None:
        hd = histogram_module(arr_dark, mask=mask)
        Sum_hd = np.sum(hd)
        ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
                ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
    else:
        hd = None

    if extra_lines:
        for k in range(50, 271):
            if not (k - 2) % 8:
                ax.axvline(k, c='k', alpha=0.5, ls='--')
            if not (k - 3) % 16:
                ax.axvline(k, c='g', alpha=0.3, ls='--')
            if not (k - 7) % 32:
                ax.axvline(k, c='r', alpha=0.3, ls='--')

    ax.axvline(271, c='C1', alpha=0.5, ls='--')

    ax.set_xlim([0, 2**9-1])
    ax.set_yscale('log')
    ax.xaxis.set_minor_locator(MultipleLocator(10))
    ax.set_xlabel('DSSC pixel value')
    ax.set_ylabel('count frequency')

    return (h, hd), f
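
# Usage sketch (assumes arr and arr_dark come from load_dssc_module and mask
# is a boolean good-pixel array; the file name is a placeholder):
#
#     (h, hd), fig = inspect_histogram(arr, arr_dark=arr_dark, mask=mask,
#                                      extra_lines=True)
#     fig.savefig('histogram.png')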


def load_dssc_module(proposalNB, runNB, moduleNB=15,
                     subset=slice(None), drop_intra_darks=True, persist=False):
    """Load single module dssc data as dask array.

    Inputs
    ------
    proposalNB: proposal number
    runNB: run number
    moduleNB: default 15, module number
    subset: default slice(None), subset of trains to load
    drop_intra_darks: boolean, default True, remove intra darks from the data
    persist: default False, load all data persistently in memory

    Returns
    -------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    """
    run = open_run(proposal=proposalNB, run=runNB)

    # DSSC
    source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
    key = 'image.data'

    arr = run[source, key][subset].dask_array()
    # fix 256 values that were spuriously recorded as 0
    arr[arr == 0] = 256

    ppt = run[source, key][subset].data_counts()
    # ignore train with no pulses, can happen in burst mode acquisition
    ppt = ppt[ppt > 0]
    tid = ppt.index.to_numpy()

    ppt = np.unique(ppt)
    assert ppt.shape[0] == 1, "number of pulses changed during the run"
    ppt = ppt[0]

    # reshape in trainId, pulseId, 2d-image
    arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])

    # drop intra darks
    if drop_intra_darks:
        arr = arr[:, ::2, :, :]

    # load data in memory
    if persist:
        arr = arr.persist()

    return arr, tid
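
# Usage sketch (proposal and run numbers are placeholders):
#
#     arr, tid = load_dssc_module(1234, 11, moduleNB=15, persist=False)
#     # arr is a dask array shaped (trainId, pulseId, 128, 512)
#     # tid holds the corresponding train ids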


def average_module(arr, dark=None, ret='mean',
                   mask=None, sat_roi=None, sat_level=300, F_INL=None):
    """Compute the average or std over a module.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    dark: default None, dark to be subtracted
    ret: string, either 'mean' to compute the mean or 'std' to compute the
         standard deviation
    mask: default None, mask of bad pixels to ignore
    sat_roi: roi over which to check for pixels with values larger than or
             equal to sat_level, in which case the frame is excluded from the
             average
    sat_level: int, minimum pixel value for a pixel to be considered saturated
    F_INL: default None, non linear correction function given as a
           lookup table with 9 bits integer input

    Returns
    -------
    average or standard deviation image
    """
    # F_INL
    if F_INL is not None:
        narr = arr.map_blocks(lambda x: F_INL[x])
    else:
        narr = arr

    if mask is not None:
        narr = narr*mask

    if sat_roi is not None:
        temp = (da.logical_not(da.any(
                narr[:, :, sat_roi['yl']:sat_roi['yh'],
                           sat_roi['xl']:sat_roi['xh']] >= sat_level,
                               axis=[2, 3], keepdims=True)))
        not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)

    if dark is not None:
        narr = narr - dark

    if ret == 'mean':
        if sat_roi is not None:
            return da.average(narr, axis=0, weights=not_sat)
        else:
            return narr.mean(axis=0)
    elif ret == 'std':
        return narr.std(axis=0)
    else:
        raise ValueError(f'ret={ret} not supported')
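
# Usage sketch chaining dark subtraction and saturation filtering (mask, rois
# and sat_level are placeholders):
#
#     dark = average_module(arr_dark).compute()
#     avg = average_module(arr, dark=dark, mask=mask,
#                          sat_roi=rois['sat'], sat_level=300).compute()
#     data_mean = avg.mean(axis=0)   # also average over pulseId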


def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
    """Add a colobar on a new axes so it match the plot size.

    Inputs
    ------
    im: image plotted
    ax: axes on which the image was plotted
    loc: string, default 'right', location of the colorbar
    size: string, default '5%', proportion of the colorbar with respect to the
          plotted image
    pad: float, default 0.05, pad width between plot and colorbar
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    fig = ax.figure
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(loc, size=size, pad=pad)
    cbar = fig.colorbar(im, cax=cax)

    return cbar


def bad_pixel_map(params):
    """Compute the bad pixels map.

    Inputs
    ------
    params: parameters

    Returns
    -------
    boolean array of good pixels (False marks bad pixels)
    """
    assert params.arr_dark is not None, "Data not loaded"

    # compute mean and std
    dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
    dark_std = params.arr_dark.std(axis=(0, 1)).compute()

    mask = np.ones_like(dark_mean)
    if params.mean_th[0] is not None:
        mask *= dark_mean >= params.mean_th[0]
    if params.mean_th[1] is not None:
        mask *= dark_mean <= params.mean_th[1]
    if params.std_th[0] is not None:
        mask *= dark_std >= params.std_th[0]
    if params.std_th[1] is not None:
        mask *= dark_std <= params.std_th[1]

    print(f'# bad pixel: {int(128*512-mask.sum())}')

    return mask.astype(bool)
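
# Usage sketch (threshold values are placeholders, typically chosen from the
# histograms drawn by inspect_dark):
#
#     fig = inspect_dark(params.arr_dark, mean_th=params.mean_th,
#                        std_th=params.std_th)
#     params.set_mask(bad_pixel_map(params))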


def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
    """Inspect dark run data and plot diagnostic.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    mean_th: tuple of thresholds (low, high), default (None, None), to compute
            a mask of good pixels for which the mean dark value lies inside
            this range
    std_th: tuple of thresholds (low, high), default (None, None), to compute
            a mask of good pixels for which the dark std value lies inside
            this range

    Returns
    -------
    fig: matplotlib figure
    """
    # compute mean and std
    dark_mean = arr.mean(axis=(0, 1)).compute()
    dark_std = arr.std(axis=(0, 1)).compute()

    fig = plt.figure(figsize=(7, 2.7))
    gs = fig.add_gridspec(2, 4)
    ax1 = fig.add_subplot(gs[0, 1:])
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax11 = fig.add_subplot(gs[0, 0])

    ax2 = fig.add_subplot(gs[1, 1:])
    ax2.set_xticklabels([])
    ax2.set_yticklabels([])
    ax22 = fig.add_subplot(gs[1, 0])

    vmin = np.percentile(dark_mean.flatten(), 2)
    vmax = np.percentile(dark_mean.flatten(), 98)
    im1 = ax1.pcolormesh(dark_mean, vmin=vmin, vmax=vmax)
    ax1.invert_yaxis()
    ax1.set_aspect('equal')
    cbar1 = _add_colorbar(im1, ax=ax1, size='2%')
    cbar1.ax.set_ylabel('dark mean')

    ax11.hist(dark_mean.flatten(), bins=int(vmax*2-vmin/2+1),
            range=(vmin/2, vmax*2))
    if mean_th[0] is not None:
        ax11.axvline(mean_th[0], c='k', alpha=0.5, ls='--')
    if mean_th[1] is not None:
        ax11.axvline(mean_th[1], c='k', alpha=0.5, ls='--')
    ax11.set_yscale('log')

    vmin = np.percentile(dark_std.flatten(), 2)
    vmax = np.percentile(dark_std.flatten(), 98)
    im2 = ax2.pcolormesh(dark_std, vmin=vmin, vmax=vmax)
    ax2.invert_yaxis()
    ax2.set_aspect('equal')
    cbar2 = _add_colorbar(im2, ax=ax2, size='2%')
    cbar2.ax.set_ylabel('dark std')

    ax22.hist(dark_std.flatten(), bins=50, range=(vmin/2, vmax*2))
    if std_th[0] is not None:
        ax22.axvline(std_th[0], c='k', alpha=0.5, ls='--')
    if std_th[1] is not None:
        ax22.axvline(std_th[1], c='k', alpha=0.5, ls='--')
    ax22.set_yscale('log')

    return fig


def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None, vmax=None):
    """Extract beams roi from average image and compute the ratio.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat field
         determination
    rois: dictionary of rois
    prod_th, ratio_th: tuple of floats for low and high threshold on
        product and ratio
    vmin: imshow vmin level, default None will use 5 percentile value
    vmax: imshow vmax level, default None will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9))

    img_rois = {}
    centers = {}

    for k, r in enumerate(['n', '0', 'p']):
        roi = rois[r]
        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
                      (roi['xl'] + roi['xh'])//2])

    d = '0'
    roi = rois[d]
    for k, r in enumerate(['n', '0', 'p']):
        img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]),
                              axis=(0, 1))[roi['yl']:roi['yh'],
                                           roi['xl']:roi['xh']]
        im = axs[0, k].imshow(img_rois[r],
                              vmin=vmin,
                              vmax=vmax)

    n, n_m, p, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th)

    prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
    for k, r in enumerate(['n', '0', 'p']):
        v = img_rois[r]*img_rois['0']
        if prod_vmin is None:
            prod_vmin = np.percentile(v, .5)
            prod_vmax = np.percentile(v, 20) # we look for low intensity region
        im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax, cmap='magma')
        axs[1,k].contour(v, prod_th, cmap=cm.get_cmap(cm.cool, 2))

        v = img_rois[r]/img_rois['0']
        if ratio_vmin is None:
            ratio_vmin = np.percentile(v, 5)
            ratio_vmax = np.percentile(v, 99.8)
        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax, cmap='RdBu_r')
        axs[2,k].contour(v, ratio_th, cmap=cm.get_cmap(cm.cool, 2))

    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
    cbar.ax.set_xlabel('data mean')

    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
    cbar.ax.set_xlabel('product')

    cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal")
    cbar.ax.set_xlabel('ratio')

    # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')

    domain = (n_m, p_m)

    return fig, domain

def inspect_plane_fitting(avg, rois, domain, vmin=None, vmax=None):
    """Extract beams roi from average image and compute the ratio.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat field
         determination
    rois: dictionary of rois
    domain: tuple (n_m, p_m) of domain masks for the 'n' and 'p' orders
    vmin: imshow vmin level, default None will use 5 percentile value
    vmax: imshow vmax level, default None will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6))

    img_rois = {}
    centers = {}

    for k, r in enumerate(['n', '0', 'p']):
        roi = rois[r]
        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
                      (roi['xl'] + roi['xh'])//2])

    d = '0'
    roi = rois[d]
    for k, r in enumerate(['n', '0', 'p']):
        img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]),
                              axis=(0, 1))[roi['yl']:roi['yh'],
                                           roi['xl']:roi['xh']]
        im = axs[0, k].imshow(img_rois[r],
                              vmin=vmin,
                              vmax=vmax)

    for k, r in enumerate(['n', '0', 'p']):
        v = img_rois[r]/img_rois['0']
        im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')

    n_m, p_m = domain
    axs[1, 0].contour(n_m)
    axs[1, 2].contour(p_m)

    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
    cbar.ax.set_xlabel('data mean')

    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
    cbar.ax.set_xlabel('ratio')

    # fig.suptitle(f'{proposalNB}-run{runNB}-dark{darkrunNB} sat={sat_level}')

    return fig


def plane_fitting_domain(avg, rois, prod_th, ratio_th):
    """Extract beams roi, compute their ratio and the domain.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat field
         determination
    rois: dictionary of rois containing the 3 beams ['n', '0', 'p'] with '0'
          as the reference beam in the middle
    prod_th: float tuple, low and high threshold level to determine the plane
        fitting domain on the product image of the orders
    ratio_th: float tuple, low and high threshold level to determine the plane
        fitting domain on the ratio image of the orders

    Returns
    -------
    n: img ratio 'n'/'0'
    n_m: mask where the product 'n'*'0' and the ratio 'n'/'0' lie within the
         given thresholds, indicating that the img ratio 'n'/'0' is well
         defined
    p: img ratio 'p'/'0'
    p_m: mask where the product 'p'*'0' and the ratio 'p'/'0' lie within the
         given thresholds, indicating that the img ratio 'p'/'0' is well
         defined
    """
    centers = {}

    for k, r in enumerate(['n', '0', 'p']):
        centers[r] = np.array([(rois[r]['yl'] + rois[r]['yh'])//2,
                      (rois[r]['xl'] + rois[r]['xh'])//2])

    k = 'n'
    num = avg[rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
    d = '0'
    denom = np.roll(avg, tuple(centers[k] - centers[d]),
                    axis=(0, 1))[rois[k]['yl']:rois[k]['yh'],
                                 rois[k]['xl']:rois[k]['xh']]
    n = num/denom
    prod = num*denom
    n_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
            (n > ratio_th[0]) * (n < ratio_th[1]))
    n_m[~np.isfinite(n)] = 0
    n[~np.isfinite(n)] = 0

    k = 'p'
    num = avg[rois[k]['yl']:rois[k]['yh'], rois[k]['xl']:rois[k]['xh']]
    d = '0'
    denom = np.roll(avg, tuple(centers[k] - centers[d]),
                    axis=(0, 1))[rois[k]['yl']:rois[k]['yh'],
                                 rois[k]['xl']:rois[k]['xh']]
    p = num/denom
    prod = num*denom
    p_m = ((prod > prod_th[0]) * (prod < prod_th[1]) *
            (p > ratio_th[0]) * (p < ratio_th[1]))
    p_m[~np.isfinite(p)] = 0
    p[~np.isfinite(p)] = 0

    return n, n_m, p, p_m


def plane_fitting(params):
    """Fit the plane flat field normalization.

    Inputs
    ------
    params: parameters

    Returns
    -------
    res: the minimization result. The fitted vector
        res.x = [a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p] defines one plane
        a*x + b*y + c*z + d = 0 for each of the 'n' and 'p' rois
    """
    assert params.arr_dark is not None, "Data not loaded"
    dark = average_module(params.arr_dark).compute()
    assert params.arr is not None, "Data not loaded"
    data = average_module(params.arr, dark=dark,
        ret='mean', mask=params.mask, sat_roi=params.rois['sat'],
        sat_level=params.sat_level).compute()
    data_mean = data.mean(axis=0)  # mean over pulseId

    n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois,
        params.flat_field_prod_th, params.flat_field_ratio_th)

    def _crit(x):
        """Fitting criteria for the plane field normalization.

        Inputs
        ------
        x: the two plane parameter vectors [a, b, c, d] concatenated, each
                defining a plane a*x + b*y + c*z + d = 0
        """

        a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x

        num_n = a_n**2 + b_n**2 + c_n**2

        roi = params.rois['n']
        pixel_pos = _get_pixel_pos()
        # DSSC pixel position on hexagonal lattice
        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
        d0_2 = np.sum(n_m*(a_n*pos[:, :, 0] + b_n*pos[:, :, 1]
            + c_n*n + d_n)**2)/num_n

        num_p = a_p**2 + b_p**2 + c_p**2

        roi = params.rois['p']
        # DSSC pixel position on hexagonal lattice
        pos = pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], :]
        d2_2 = np.sum(p_m*(a_p*pos[:, :, 0] + b_p*pos[:, :, 1]
            + c_p*p + d_p)**2)/num_p

        return 1e3*(d2_2 + d0_2)

    if params.plane_guess_fit is None:
        p_guess_fit = [-20, 0.0, 1.5, -0.5, -20, 0.0, 1.5, -0.5]
    else:
        p_guess_fit = params.plane_guess_fit

    res = minimize(_crit, p_guess_fit)

    return res
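
# Usage sketch: fit the two flat field planes and store them in the
# parameters object (assumes rois, mask and sat_level are already set):
#
#     res = plane_fitting(params)
#     # res.x = [a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p]
#     params.set_flat_field(res.x)
#     ff = compute_flat_field_correction(params.rois,
#                                        params.get_flat_field(), plot=True)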


def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
                   flat_field=None, F_INL=None):
    """Process one module and extract roi intensity.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    dark: pulse resolved dark image to remove
    rois: dictionary of rois
    mask: default None, mask of ignored pixels
    sat_level: integer, default 511, at which level pixel begin to saturate
    flat_field: default None, flat field correction
    F_INL: default None, non linear correction function given as a
           lookup table with 9 bits integer input

    Returns
    -------
    dataset of extracted pulse and train resolved roi intensities.
    """
    # F_INL
    if F_INL is not None:
        narr = arr.map_blocks(lambda x: F_INL[x])
    else:
        narr = arr

    # apply mask
    if mask is not None:
        narr = narr*mask

    # crop rois
    r = {}
    rd = {}
    for n in rois.keys():
        r[n] = narr[:, :, rois[n]['yl']:rois[n]['yh'],
                    rois[n]['xl']:rois[n]['xh']]
        rd[n] = dark[:, rois[n]['yl']:rois[n]['yh'],
                     rois[n]['xl']:rois[n]['xh']]

    # find saturated shots
    r_sat = {}
    for n in rois.keys():
        r_sat[n] = da.any(r[n] >= sat_level, axis=(2, 3))

    # TODO: flat field should not be applied on intra darks
    # # change flat field dimension to match data
    # if flat_field is not None:
    #     temp = np.ones_like(dark)
    #     temp[::2, :, :] = flat_field[:, :]
    #    flat_field = temp

    # compute dark corrected ROI values
    v = {}
    for n in rois.keys():

        r[n] = r[n] - rd[n]

        if flat_field is not None:
            # TODO:  flat field should not be applied on intra darks
            # ff = flat_field[:, rois[n]['yl']:rois[n]['yh'],
            #                 rois[n]['xl']:rois[n]['xh']]
            ff = flat_field[rois[n]['yl']:rois[n]['yh'],
                            rois[n]['xl']:rois[n]['xh']]
            r[n] = r[n]/ff

        v[n] = r[n].sum(axis=(2, 3))

    res = xr.Dataset()

    dims = ['trainId', 'pulseId']
    r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
    for n in rois.keys():
        res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
                                       coords=r_coords, dims=dims)
        res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)

    for n in rois.keys():
        roi = rois[n]
        res[n + '_area'] = xr.DataArray(np.array([
            (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))

    return res


def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
    """Process dark and run data with corrections.

    Inputs
    ------
    Fmodel: correction lookup table
    arr_dark: dark data
    arr: data
    tid: array of train id number
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat field correction
    sat_level: integer, default 511, at which level pixel begin to saturate

    Returns
    -------
    roi extracted intensities
    """
    # dark process
    res = average_module(arr_dark, F_INL=Fmodel)
    dark = res.compute()

    # data process
    proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
                          flat_field=flat_field, F_INL=Fmodel)
    data = proc.compute()

    return data
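
# Usage sketch with an identity non-linearity lookup table and a unit flat
# field, i.e. dark subtraction only (fitrois is assumed to be a dictionary
# with the 'n', '0', 'p' and 'sat' rois):
#
#     data = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
#                    fitrois, params.get_mask(), np.ones((128, 512)),
#                    sat_level=511)
#     data['0']   # train and pulse resolved intensity of the '0' roi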


def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
    sat_level=511):
    """Criteria for the non linear correction.

    Inputs
    ------
    p: vector of dy non linear correction
    domain: domain over which the non linear correction is defined
    alpha: float, coefficient scaling the cost of the correction function
        in the criterion
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat field correction
    sat_level: integer, default 511, at which level pixel begin to saturate

    Returns
    -------
    (1.0 - alpha)*err1 + alpha*err2, where err1 is 1e8 times the mean of the
    weighted variances of the 'n'/'0' and 'p'/'0' ratios and err2 is the sum
    of squared deviations of the correction function from the ideal (linear)
    detector response.
    """
    Fmodel = nl_lut(domain, p)
    data = process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field,
        sat_level)

    # drop saturated shots
    d = data.where(data['sat_sat'] == False, drop=True)

    v_1 = snr(d['n'].values.flatten(), d['0'].values.flatten(),
            methods=['weighted'])
    err_1 = 1e8*v_1['weighted']['s']**2

    v_2 = snr(d['p'].values.flatten(), d['0'].values.flatten(),
            methods=['weighted'])
    err_2 = 1e8*v_2['weighted']['s']**2

    err_a = np.sum((Fmodel-np.arange(2**9))**2)

    return (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a


def nl_fit(params, domain):
    """Fit non linearities correction function.

    Inputs
    ------
    params: parameters
    domain: domain array of indices, as returned by nl_domain

    Returns
    -------
    res: scipy minimize result. res.x is the optimized parameters

    fitres: list, one entry per fit iteration, of the criterion values
        [J(alpha=0), J(alpha), J(alpha=1)]
    """
    # load data
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"

    # we only need few rois
    fitrois = {}
    for k in ['n', '0', 'p', 'sat']:
        fitrois[k] = params.rois[k]

    # p0
    N = np.unique(domain).shape[0] - 1
    p0 = np.array([0]*N)

    # flat flat_field
    ff = compute_flat_field_correction(params.rois, params.get_flat_field())

    fixed_p = (domain, params.alpha, params.arr_dark, params.arr, params.tid,
        fitrois, params.get_mask(), ff, params.sat_level)

    def fit_callback(x):
        if not hasattr(fit_callback, "counter"):
            fit_callback.counter = 0  # it doesn't exist yet, so initialize it
            fit_callback.start = time.monotonic()
            fit_callback.res = []

        now = time.monotonic()
        time_delta = datetime.timedelta(seconds=now-fit_callback.start)
        fit_callback.counter += 1

        temp = list(fixed_p)
        Jalpha = nl_crit(x, *temp)
        temp[1] = 0
        J0 = nl_crit(x, *temp)
        temp[1] = 1
        J1 = nl_crit(x, *temp)
        fit_callback.res.append([J0, Jalpha, J1])
        print(f'{fit_callback.counter-1}: {time_delta} '
                f'({J0}, {Jalpha}, {J1}), {x}')

        return False

    fit_callback(p0)
    res = minimize(nl_crit, p0, fixed_p,
        options={'disp': True, 'maxiter': params.max_iter},
        callback=fit_callback)

    return res, fit_callback.res
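
# Usage sketch for the non-linearity fit (alpha, max_iter and the domain
# bounds are placeholder values):
#
#     params.alpha = 0.5
#     params.max_iter = 25
#     domain = nl_domain(N=10, low=40, high=280)
#     res, fit_history = nl_fit(params, domain)
#     params.set_Fnl(nl_lut(domain, res.x))
#     fig = inspect_nl_fit(fit_history)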


def inspect_nl_fit(res_fit):
    """Plot the progress of the fit.

    Inputs
    ------
    res_fit: list of criterion values [J(alpha=0), J(alpha), J(alpha=1)] per
        iteration, as returned by nl_fit

    Returns
    -------
    matplotlib figure
    """
    r = np.array(res_fit)
    f = plt.figure(figsize=(6, 4))
    ax = f.gca()
    ax2 = plt.twinx()
    ax.plot(1.0/np.sqrt(1e-8*r[:, 0]), c='C0')
    ax2.plot(r[:, 2], c='C1', ls='-.')
    ax.set_xlabel('# iteration')
    ax.set_ylabel('SNR', color='C0')
    ax2.set_ylabel('correction cost', color='C1')
    ax.set_yscale('log')
    ax2.set_yscale('log')

    return f


def snr(sig, ref, methods=None, verbose=False):
    """ Compute mean, std and SNR from transmitted signal sig and I0 signal ref.

    Inputs
    ------
    sig: 1D signal samples
    ref: 1D reference samples
    methods: None by default or list of strings to select which methods to use.
        Possible values are 'direct', 'weighted', 'diff'. In case of None, all
        methods will be calculated.
    verbose: boolean, if True prints calculated values

    Returns
    -------
    dictionary of [method][value] where value is 'mu' for mean and 's' for
    standard deviation.

    """
    if methods is None:
        methods = ['direct', 'weighted', 'diff']

    w = ref
    x = sig/ref

    mask = np.isfinite(x) & np.isfinite(sig) & np.isfinite(ref)

    w = w[mask]
    sig = sig[mask]
    ref = ref[mask]
    x = x[mask]

    res = {}

    # direct mean and std
    if 'direct' in methods:
        mu = np.mean(x)
        s = np.std(x)
        if verbose:
            print(f'mu: {mu}, s: {s}, snr: {mu/s}')

        res['direct'] = {'mu': mu, 's':s}

    # weighted mean and std
    if 'weighted' in methods:
        wmu = np.sum(sig)/np.sum(ref)
        v1 = np.sum(w)
        v2 = np.sum(w**2)
        ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1))

        if verbose:
            print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}')

        res['weighted'] = {'mu': wmu, 's':ws}

    # noise from diff
    if 'diff' in methods:
        dmu = np.mean(x)
        ds = np.std(np.diff(x))/np.sqrt(2)
        if verbose:
            print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}')

        res['diff'] = {'mu': dmu, 's':ds}

    return res
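
# Self-contained sketch with synthetic data: a noisy reference and a
# transmitted signal at 80% transmission (all values are illustrative).
#
#     rng = np.random.default_rng(0)
#     ref = rng.normal(1e4, 1e2, 10000)
#     sig = 0.8*ref + rng.normal(0.0, 10.0, 10000)
#     v = snr(sig, ref, methods=['direct', 'weighted'])
#     v['weighted']['mu']   # close to 0.8
#     v['weighted']['s']    # noise on the sig/ref ratio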


def inspect_correction(params, gain=None):
    """Criteria for the non linear correction.

    Inputs
    ------
    params: parameters
    gain: float, default None, DSSC gain in ph/bin

    Returns
    -------
    matplotlib figure
    """
    # load data
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"

    # we only need few rois
    fitrois = {}
    for k in ['n', '0', 'p', 'sat']:
        fitrois[k] = params.rois[k]

    # flat flat_field
    plane_ff = params.get_flat_field()
    if plane_ff is None:
        plane_ff = [0.0, 0.0, 1.0, -1.0]
    ff = compute_flat_field_correction(params.rois, plane_ff)

    # non linearities
    Fnl = params.get_Fnl()
    if Fnl is None:
        Fnl = np.arange(2**9)

    # compute all levels of correction
    data = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
        fitrois, params.get_mask(), np.ones_like(ff), params.sat_level)
    data_ff = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
        fitrois, params.get_mask(), ff, params.sat_level)
    data_ff_nl = process(Fnl, params.arr_dark, params.arr,
        params.tid, fitrois, params.get_mask(), ff, params.sat_level)

    # for conversion to nb of photons
    if gain is None:
        g = 1
    else:
        g = gain

    scale = 1e-6

    f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True)

    # nbins = np.linspace(0.01, 1.0, 100)

    photon_scale = None

    for k, d in enumerate([data, data_ff, data_ff_nl]):
        for l, (n, r) in enumerate([('n', '0'), ('p', '0'), ('n', 'p')]):

            if photon_scale is None:
                lower = 0
                upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
                photon_scale = np.linspace(lower, upper, 150)

            good_d = d.where(d['sat_sat'] == False, drop=True)
            sat_d = d.where(d['sat_sat'], drop=True)

            snr_v = snr(good_d[n].values.flatten(),
                        good_d[r].values.flatten(), verbose=True)

            m = snr_v['direct']['mu']
            h, xedges, yedges, img = axs[l, k].hist2d(
                g*scale*good_d[r].values.flatten(),
                good_d[n].values.flatten()/good_d[r].values.flatten(),
                [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                cmap='Blues',
                norm=LogNorm(vmin=0.2, vmax=200),
                # alpha=0.5 # make  the plot looks ugly with lots of white lines
                )
            h, xedges, yedges, img2 = axs[l, k].hist2d(
                g*scale*sat_d[r].values.flatten(),
                sat_d[n].values.flatten()/sat_d[r].values.flatten(),
                [photon_scale, np.linspace(0.95, 1.05, 150)*m],
                cmap='Reds',
                norm=LogNorm(vmin=0.2, vmax=200),
                # alpha=0.5 # make  the plot looks ugly with lots of white lines
                )

            v = snr_v['direct']['mu']/snr_v['direct']['s']
            axs[l, k].text(0.4, 0.15, f'SNR: {v:.0f}',
                            transform = axs[l, k].transAxes)
            v = snr_v['weighted']['mu']/snr_v['weighted']['s']
            axs[l, k].text(0.4, 0.05, f'wSNR: {v:.0f}',
                            transform = axs[l, k].transAxes)

            # axs[l, k].plot(3*nbins, 1+np.sqrt(2/(1e6*nbins)), c='C1', ls='--')
            # axs[l, k].plot(3*nbins, 1-np.sqrt(2/(1e6*nbins)), c='C1', ls='--')

            axs[l, k].set_ylim([0.95*m, 1.05*m])

    for k in range(3):
        #for l in range(3):
        #    axs[l, k].set_ylim([0.95, 1.05])
        if gain:
            axs[2, k].set_xlabel('#ph (10$^6$)')
        else:
            axs[2, k].set_xlabel('ADU (10$^6$)')

    f.colorbar(img, ax=axs, label='counts')

    axs[0, 0].set_title('raw')
    axs[0, 1].set_title('flat field')
    axs[0, 2].set_title('non-linear')

    axs[0, 0].set_ylabel(r'-1$^\mathrm{st}$/0$^\mathrm{th}$ order')
    axs[1, 0].set_ylabel(r'1$^\mathrm{st}$/0$^\mathrm{th}$ order')
    axs[2, 0].set_ylabel(r'-1$^\mathrm{st}$/1$^\mathrm{st}$ order')

    return f


def inspect_Fnl(Fnl):
    """Plot the correction function Fnl.

    Inputs
    ------
    Fnl: non linear correction function lookup table

    Returns
    -------
    matplotlib figure
    """
    x = np.arange(2**9)
    f = plt.figure(figsize=(6, 4))

    plt.plot(x, Fnl - x)
    # plt.axvline(40, c='k', ls='--')
    # plt.axvline(280, c='k', ls='--')
    plt.xlabel('input value')
    plt.ylabel('output correction F(x)-x')
    plt.xlim([0, 511])

    return f

def inspect_saturation(data, gain, Nbins=200):
    """Plot roi integrated histogram of the data with saturation
    
    Inputs
    ------
        data: xarray of roi integrated DSSC data
        gain: nominal DSSC gain in ph/bin
        Nbins: number of bins for the histogram, by default 200
        
    Returns
    -------
        f: handle to the matplotlib figure
        h: xarray of the histogram data
    """
    d = data.where(data['sat_sat'] == False, drop=True)
    s = data.where(data['sat_sat'] == True, drop=True)
    
    # percentage of saturated shots
    N_nonsat = d['n'].count()
    N_all = data.sizes['trainId'] * data.sizes['pulseId']
    sat_percent = ((N_all - N_nonsat)/N_all).values*100.0
    
    # find the bin ranges
    sum_v = {}
    low = 0
    high = 0
    scale = 1e-6
    for k in ['n', '0', 'p']:
        v = data[k].values.ravel()*scale*gain
        sum_v[k] = np.nansum(v)
        v_low, v_high = np.nanmin(v), np.nanmax(v)
        if v_low < low:
            low = v_low
        if v_high > high:
            high = v_high

    # compute bins edges, center and width
    bins = np.linspace(low, high, Nbins+1)
    bins_c = 0.5*(bins[:-1] + bins[1:])
    w = bins[1] - bins[0]

    fig, ax = plt.subplots(figsize=(6,4))

    h = {}
    for kk, k in enumerate(['n', '0', 'p']):
        v_d = d[k].values.ravel()*scale*gain
        v_s = s[k].values.ravel()*scale*gain
    
        h[k+'_nosat'], bin_e = np.histogram(v_d, bins)
        h[k+'_sat'], bin_e = np.histogram(v_s, bins)
    
        # compute density normalization on all data
        norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat']))
    
        ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm, h[k+'_nosat']/norm,
                        facecolor=f"C{kk}", edgecolor='none', alpha=0.2)

        ax.plot(bins_c, h[k+'_nosat']/norm, label=k,
                c=f'C{kk}', alpha=0.4)

    ax.text(0.6, 0.9, f"saturation: {sat_percent:.2f}%",
             color='r', alpha=0.5, transform=plt.gca().transAxes)
    ax.legend()
    ax.set_xlabel(r'10$^6$ ph')
    ax.set_ylabel('density')
    
    # save data as xarray dataset
    dv = {}
    for k in h.keys():
        dv[k] = {"dims": "N", "data": h[k]}
    ds = {
        "coords": {"N": {"dims": "N", "data": bins_c,
                         "attrs": {"units": f"{scale:g} ph"}}},
        "attrs": {"saturation (%)": sat_percent},
        "dims": "N",
        "data_vars": dv}
    
    return fig, xr.Dataset.from_dict(ds)