{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gain Characterization #\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_folder = \"/gpfs/exfel/exp/SPB/202030/p900138/scratch/karnem/r0203_r0204_v01/\" # the folder to read histograms from, required\n",
    "out_folder = \"\"  # the folder to output to, required\n",
    "hist_file_template = \"hists_m{:02d}_sum.h5\" # the template to use to access histograms\n",
    "modules = [10] # modules to correct, set to -1 for all, range allowed\n",
    "\n",
    "raw_folder = \"/gpfs/exfel/exp/MID/202030/p900137/raw\" # Path to raw image data used to create histograms\n",
    "proc_folder = \"\" # Path to corrected image data used to create histograms\n",
    "\n",
    "run = 449 # of the run of image data used to create histograms\n",
    "\n",
    "karabo_id = \"MID_DET_AGIPD1M-1\" # karabo karabo_id\n",
    "karabo_da = ['-1']  # a list of data aggregators names, Default [-1] for selecting all data aggregators\n",
    "ctrl_source_template = '{}/MDL/FPGA_COMP' # path to control information\n",
    "karabo_id_control = \"MID_IRU_AGIPD1M1\" # karabo-id for control device\n",
    "karabo_da_control = 'AGIPD1MCTRL00' # karabo DA for control infromation\n",
    "\n",
    "use_dir_creation_date = True # use the creation data of the input dir for database queries\n",
    "cal_db_interface = \"tcp://max-exfl016:8015#8045\" # the database interface to use\n",
    "cal_db_timeout = 30000 # in milli seconds\n",
    "local_output = True # output constants locally\n",
    "db_output = False # output constants to database\n",
    "\n",
    "# Fit parameters\n",
    "peak_range = [-30, 30, 35, 70, 95, 135, 145, 220] # where to look for the peaks, [a0, b0, a1, b1, ...] exactly 8 elements\n",
    "peak_width_range = [0, 30, 0, 35, 0, 40, 0, 45] # fit limits on the peak widths, [a0, b0, a1, b1, ...] exactly 8 elements\n",
    "peak_norm_range = [0.0, -1, 0, -1, 0, -1, 0, -1] #  \n",
    "\n",
    "# Bad-pixel thresholds (gain evaluation error). Contribute to BadPixel bit \"Gain_Evaluation_Error\"\n",
    "peak_lim = [-30, 30] # Limit of position of noise peak\n",
    "d0_lim = [10, 80] # hard limits for distance between noise and first peak\n",
    "peak_width_lim = [0.9, 1.55, 0.95, 1.65] # hard limits on the peak widths for first and second peak, in units of the noise peak. 4 parameters.\n",
    "chi2_lim = [0, 3.0] # Hard limit on chi2/nDOF value\n",
    "\n",
    "intensity_lim = 15 # Threshold on standard deviation of a histogram in ADU. Contribute to BadPixel bit \"No_Entry\"\n",
    "gain_lim = [0.8, 1.2] # Threshold on gain in relative number. Contribute to BadPixel bit \"Gain_deviation\"\n",
    "\n",
    "cell_range = [1, 3] # range of cell to be considered, [0,0] for all\n",
    "pixel_range = [0, 0, 32, 32] # range of pixels x1,y1,x2,y2 to consider [0,0,512,128] for all\n",
    "max_bins = 0 # Maximum number of bins to consider, 0 for all bins\n",
    "batch_size = [1, 8, 8] # batch size: [cell,x,y]\n",
    "fit_range = [0, 0] # range of a histogram considered for fitting in ADU. Dynamically evaluated in case [0,0]\n",
    "n_peaks_fit = 4 # Number of gaussian peaks to fit including noise peak\n",
    "fix_peaks = False # Fix distance between photon peaks\n",
    "do_minos = False # This is additional feature of minuit to evaluate errors. \n",
    "sigma_limit = 0. # If >0, repeat fit keeping only bins within mu +- sigma_limit*sigma\n",
    "\n",
    "# Detector conditions\n",
    "# NOTE: The below parameters are needed for the summary notebook when running through xfel-calibrate.\n",
    "mem_cells = -1  # number of memory cells used, negative values for auto-detection. \n",
    "bias_voltage = 300  # Bias voltage. \n",
    "acq_rate = 0.  # the detector acquisition rate, use 0 to try to auto-determine.\n",
    "gain_setting = -1  # the gain setting, negative values for auto-detection.\n",
    "photon_energy = 8.05  # photon energy in keV.\n",
    "integration_time = -1  # integration time, negative values for auto-detection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import traceback\n",
    "import warnings\n",
    "from multiprocessing import Pool\n",
    "\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sharedmem\n",
    "import XFELDetAna.xfelpyanatools as xana\n",
    "from cal_tools.agipdutils_ff import (\n",
    "    BadPixelsFF,\n",
    "    any_in,\n",
    "    fit_n_peaks,\n",
    "    gaussian,\n",
    "    gaussian_sum,\n",
    "    get_mask,\n",
    "    get_starting_parameters,\n",
    "    set_par_limits,\n",
    ")\n",
    "from cal_tools.ana_tools import get_range, save_dict_to_hdf5\n",
    "from iminuit import Minuit\n",
    "from XFELDetAna.plotting.heatmap import heatmapPlot\n",
    "from XFELDetAna.plotting.simpleplot import simplePlot\n",
    "\n",
    "# %load_ext autotime\n",
    "%matplotlib inline\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "peak_range = np.reshape(peak_range,(4,2))\n",
    "peak_width_range = np.reshape(peak_width_range,(4,2))\n",
    "peak_width_lim = np.reshape(peak_width_lim,(2,2))\n",
    "peak_norm_range = [None if x == -1 else x for x in peak_norm_range]\n",
    "peak_norm_range = np.reshape(peak_norm_range,(4,2))\n",
    "module = modules[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def idx_gen(batch_start, batch_size):\n",
    "    \"\"\"\n",
    "    This generator iterate across pixels and memory cells starting\n",
    "    from batch_start until batch_start+batch_size\n",
    "    \"\"\"\n",
    "    for c_idx in range(batch_start[0], batch_start[0]+batch_size[0]):\n",
    "        for x_idx in range(batch_start[1], batch_start[1]+batch_size[1]):\n",
    "            for y_idx in range(batch_start[2], batch_start[2]+batch_size[2]):\n",
    "                yield(c_idx, x_idx, y_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_pixels_x = pixel_range[2]-pixel_range[0]\n",
    "n_pixels_y = pixel_range[3]-pixel_range[1]\n",
    "\n",
    "hist_data = {}\n",
    "with h5py.File(f\"{in_folder}/{hist_file_template.format(module)}\", 'r') as hf:\n",
    "    hist_data['cellId'] = np.array(hf['cellId'][()])\n",
    "    hist_data['hRange'] = np.array(hf['hRange'][()])\n",
    "    hist_data['nBins'] = np.array(hf['nBins'][()])\n",
    "    \n",
    "    if cell_range == [0,0]:\n",
    "        cell_range[1] = hist_data['cellId'].shape[0]\n",
    "        \n",
    "    if max_bins == 0:\n",
    "        max_bins = hist_data['nBins']\n",
    "    \n",
    "    hist_data['cellId'] = hist_data['cellId'][cell_range[0]:cell_range[1]]\n",
    "    hist_data['hist'] = np.array(hf['hist'][cell_range[0]:cell_range[1], :max_bins, :])\n",
    "\n",
    "n_cells = cell_range[1]-cell_range[0]\n",
    "hist_data['hist'] = hist_data['hist'].reshape(n_cells, max_bins, 512, 128)\n",
    "hist_data['hist'] = hist_data['hist'][:,:,pixel_range[0]:pixel_range[2],pixel_range[1]:pixel_range[3]]\n",
    "\n",
    "print(f'Data shape {hist_data[\"hist\"].shape}')\n",
    "    \n",
    "bin_edges = np.linspace(hist_data['hRange'][0], hist_data['hRange'][1], int(hist_data['nBins']+1))\n",
    "x = (bin_edges[1:] + bin_edges[:-1])[:max_bins] * 0.5\n",
    "   \n",
    "\n",
    "batches = []\n",
    "for c_idx in range(0, n_cells, batch_size[0]):\n",
    "    for x_idx in range(0, n_pixels_x, batch_size[1]):\n",
    "        for y_idx in range(0, n_pixels_y, batch_size[2]):\n",
    "            batches.append([c_idx,x_idx,y_idx])\n",
    "        \n",
    "print(f'Number of batches {len(batches)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def fit_batch(batch_start):\n",
    "    current_result = {}\n",
    "    prev = None\n",
    "    for c_idx, x_idx, y_idx in idx_gen(batch_start, batch_size):\n",
    "        try:\n",
    "            y = hist_data['hist'][c_idx, :, x_idx, y_idx]\n",
    "\n",
    "            if prev is None:\n",
    "                prev, _ = get_starting_parameters(x, y, peak_range, n_peaks=n_peaks_fit)\n",
    "\n",
    "            if fit_range == [0, 0]:\n",
    "                frange = (prev[f'g0mean']-2*prev[f'g0sigma'],\n",
    "                          prev[f'g{n_peaks_fit-1}mean'] + prev[f'g{n_peaks_fit-1}sigma'])\n",
    "            else:\n",
    "                frange = fit_range\n",
    "\n",
    "            set_par_limits(prev, peak_range, peak_norm_range,\n",
    "                           peak_width_range, n_peaks_fit)\n",
    "            minuit = fit_n_peaks(x, y, prev, frange,\n",
    "                                 do_minos=do_minos, n_peaks=n_peaks_fit,\n",
    "                                 fix_d01=fix_peaks, sigma_limit=sigma_limit,)\n",
    "\n",
    "            ndof = np.rint(frange[1]-frange[0])-len(minuit.args) ## FIXME: this line is wrong if fix_peaks is True\n",
    "            current_result['chi2_ndof'] = minuit.fval/ndof\n",
    "            res = minuit.fitarg\n",
    "            if fix_peaks : ## set g2 and g3 mean correctly\n",
    "                for i in range(2,n_peaks_fit):\n",
    "                    d = res[f'g1mean'] - res[f'g0mean']\n",
    "                    res[f'g{i}mean'] = res[f'g0mean'] + d*i\n",
    "            current_result.update(res)\n",
    "            current_result.update(minuit.get_fmin())\n",
    "\n",
    "            fit_result['chi2_ndof'][c_idx, x_idx, y_idx] = current_result['chi2_ndof']\n",
    "\n",
    "            for key in res.keys():\n",
    "                if key in fit_result:\n",
    "                    fit_result[key][c_idx, x_idx, y_idx] = res[key]\n",
    "\n",
    "            fit_result['mask'][c_idx, x_idx, y_idx] = get_mask(current_result,\n",
    "                                                                    peak_lim,\n",
    "                                                                    d0_lim, chi2_lim,\n",
    "                                                                    peak_width_lim)\n",
    "        except Exception as e:\n",
    "            fit_result['mask'][c_idx, x_idx,\n",
    "                                    y_idx] = BadPixelsFF.FIT_FAILED.value\n",
    "            print(c_idx, x_idx, y_idx, e, traceback.format_exc())\n",
    "\n",
    "        if fit_result['mask'][c_idx, x_idx, y_idx] == 0:\n",
    "            prev = res\n",
    "        else:\n",
    "            prev = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Single fit ##\n",
    "\n",
    "Left plot shows starting parameters for fitting. Right plot shows result of the fit. Errors are evaluated with minos."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hist = hist_data['hist'][1,:,1, 1]\n",
    "prev, shapes = get_starting_parameters(x, hist, peak_range, n_peaks=n_peaks_fit)\n",
    "\n",
    "if fit_range == [0, 0]:\n",
    "    frange = (prev[f'g0mean']-2*prev[f'g0sigma'],\n",
    "              prev[f'g3mean'] + prev[f'g3sigma'])\n",
    "else:\n",
    "    frange = fit_range\n",
    "\n",
    "set_par_limits(prev, peak_range, peak_norm_range,\n",
    "               peak_width_range, n_peaks=n_peaks_fit)\n",
    "minuit = fit_n_peaks(x, hist, prev, frange,\n",
    "                     do_minos=True, n_peaks=n_peaks_fit,\n",
    "                     fix_d01=fix_peaks,\n",
    "                     sigma_limit=sigma_limit,\n",
    "                    )\n",
    "print (minuit.get_fmin())\n",
    "minuit.print_matrix()\n",
    "print(minuit.get_param_states())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = minuit.fitarg\n",
    "if fix_peaks :\n",
    "    for i in range(2,n_peaks_fit):\n",
    "        d = res[f'g1mean'] - res[f'g0mean']\n",
    "        res[f'g{i}mean'] = res[f'g0mean'] + d*i\n",
    "err = minuit.errors\n",
    "p = minuit.args\n",
    "ya = np.arange(0,1e4)\n",
    "y = gaussian_sum(x,n_peaks_fit, *p)\n",
    "peak_colors = ['g', 'y', 'b', 'orange']\n",
    "\n",
    "peak_hist = hist.copy()\n",
    "d=[]\n",
    "if sigma_limit > 0 :\n",
    "    sel2 = (np.abs(x - res['g0mean']) < sigma_limit*res['g0sigma']) | \\\n",
    "           (np.abs(x - res['g1mean']) < sigma_limit*res['g1sigma']) | \\\n",
    "           (np.abs(x - res['g2mean']) < sigma_limit*res['g2sigma']) | \\\n",
    "           (np.abs(x - res['g3mean']) < sigma_limit*res['g3sigma'])\n",
    "    peak_hist[~sel2] = 0\n",
    "    valley_hist = hist.copy()\n",
    "    valley_hist[sel2] = 0\n",
    "    d.append({'x': x,\n",
    "              'y': valley_hist.astype(np.float64),\n",
    "              'y_err': np.sqrt(valley_hist),\n",
    "              'drawstyle': 'bars',\n",
    "              'errorstyle': 'bars',\n",
    "              'transparency': '95%',\n",
    "              'errorcoarsing': 3,\n",
    "              'label': f'X-ray Data)'\n",
    "             })\n",
    "    htitle = f'X-ray Data, (μ±{sigma_limit:0.1f}σ)'\n",
    "else :\n",
    "    htitle = 'X-ray Data'\n",
    "\n",
    "d.append({'x': x,\n",
    "          'y': peak_hist.astype(np.float64),\n",
    "          'y_err': np.sqrt(peak_hist),\n",
    "          'drawstyle': 'bars',\n",
    "          'errorstyle': 'bars',\n",
    "          'errorcoarsing': 3,\n",
    "          'label': htitle,\n",
    "         }\n",
    "        )\n",
    "d.append({'x': x,\n",
    "          'y': y,\n",
    "          'y2': (hist-y)/np.sqrt(hist),\n",
    "          'drawstyle':'line',\n",
    "          'drawstyle2': 'steps-mid',\n",
    "          'label': 'Fit'\n",
    "         }\n",
    "        )\n",
    "\n",
    "for i in range(n_peaks_fit):\n",
    "    d.append({'x': x,\n",
    "             'y': gaussian(x, res[f'g{i}n'], res[f'g{i}mean'], res[f'g{i}sigma']),\n",
    "             'drawstyle':'line',\n",
    "             'color': peak_colors[i],\n",
    "             })\n",
    "    d.append({'x': np.full_like(ya, res[f'g{i}mean']),\n",
    "              'y': ya,\n",
    "              'drawstyle': 'line',\n",
    "              'linestyle': 'dashed',\n",
    "              'color': peak_colors[i],\n",
    "              'label': f'peak {i} = {res[f\"g{i}mean\"]:0.1f} $ \\pm $ {err[f\"g{i}mean\"]:0.2f} ADU' })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax1, ax2) = plt.subplots(1, 2)\n",
    "fig.set_size_inches(16, 7)\n",
    "for i, shape in enumerate(shapes):\n",
    "    idx = shape[3]\n",
    "    ax1.errorbar(\n",
    "        x[idx], hist[idx],\n",
    "        np.sqrt(hist[idx]),\n",
    "        marker='+', ls='',\n",
    "    )\n",
    "    yg = gaussian(x[idx], *shape[:3])\n",
    "    l = f'Peak {i}: {shape[1]:0.1f} $ \\pm $ {shape[2]:0.2f} ADU'\n",
    "    ax1.plot(x[idx], yg, label=l)\n",
    "ax1.grid(True)\n",
    "ax1.set_xlabel(\"Signal [ADU]\")\n",
    "ax1.set_ylabel(\"Counts\")\n",
    "ax1.legend(ncol=2)\n",
    "\n",
    "_ = xana.simplePlot(\n",
    "    d,\n",
    "    use_axis=ax2,\n",
    "    x_label='Signal [ADU]',\n",
    "    y_label='Counts',\n",
    "    secondpanel=True, y_log=False,\n",
    "    x_range=(frange[0], frange[1]),\n",
    "    y_range=(1., np.max(hist)*1.6),\n",
    "    legend='top-left-frame-ncol2',\n",
    ")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## All fits ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allocate memory for fit results\n",
    "fit_result = {}\n",
    "keys = list(minuit.fitarg.keys())\n",
    "keys = [x for x in keys if 'limit_' not in x and 'fix_' not in x]\n",
    "keys += ['chi2_ndof', 'mask', 'gain']\n",
    "for key in keys:\n",
    "    dtype = 'f4'\n",
    "    if key == 'mask':\n",
    "        dtype = 'i4'\n",
    "    fit_result[key] = sharedmem.empty([n_cells, n_pixels_x, n_pixels_y], dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform fitting\n",
    "with Pool() as pool:\n",
    "    const_out = pool.map(fit_batch, batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate bad pixels\n",
    "fit_result['gain'] = (fit_result['g1mean'] - fit_result['g0mean'])/photon_energy\n",
    "\n",
    "# Calculate histogram width and evaluate cut\n",
    "h_sums = np.sum(hist_data['hist'], axis=1)\n",
    "hist_norm = hist_data['hist'] / h_sums[:, None, :, :]\n",
    "hist_mean = np.sum(hist_norm[:, :max_bins, ...] *\n",
    "                   x[None, :, None, None], axis=1)\n",
    "hist_sqr = (x[None, :, None, None] - hist_mean[:, None, ...])**2\n",
    "hist_std = np.sqrt(np.sum(hist_norm[:, :max_bins, ...] * hist_sqr, axis=1))\n",
    "\n",
    "fit_result['mask'][hist_std<intensity_lim] |= BadPixelsFF.NO_ENTRY.value\n",
    "\n",
    "# Bad pixel on gain deviation\n",
    "gains = np.copy(fit_result['gain'])\n",
    "gains[fit_result['mask']>0] = np.nan\n",
    "gain_mean = np.nanmean(gains, axis=(1,2))\n",
    "\n",
    "fit_result['mask'][fit_result['gain'] > gain_mean[:,None,None]*gain_lim[1] ] |=  BadPixelsFF.GAIN_DEVIATION.value\n",
    "fit_result['mask'][fit_result['gain'] < gain_mean[:,None,None]*gain_lim[0] ] |=  BadPixelsFF.GAIN_DEVIATION.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save fit results\n",
    "os.makedirs(out_folder, exist_ok=True)\n",
    "out_name = f'{out_folder}/fits_m{module:02d}.h5'\n",
    "print(f'Save to file: {out_name}')\n",
    "save_dict_to_hdf5({'data': fit_result}, out_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary across cells ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [\n",
    "    \"Noise peak [ADU]\",\n",
    "    \"First photon peak [ADU]\",\n",
    "    f\"gain [ADU/keV] $\\gamma$={photon_energy} [keV]\",\n",
    "    \"$\\chi^2$/nDOF\",\n",
    "    \"Fraction of bad pixels\",\n",
    "]\n",
    "\n",
    "for i, key in enumerate(['g0mean', 'g1mean', 'gain', 'chi2_ndof', 'mask']):\n",
    "    fig = plt.figure(figsize=(20,5))\n",
    "    ax = fig.add_subplot(121)\n",
    "    data = fit_result[key]\n",
    "    if key == 'mask':\n",
    "        data = data > 0\n",
    "        vmin, vmax = [0, 1]\n",
    "    else:\n",
    "        vmin, vmax = get_range(data, 5)\n",
    "    _ = heatmapPlot(\n",
    "        np.mean(data, axis=0).T,\n",
    "        add_panels=False, cmap='viridis', use_axis=ax,\n",
    "        vmin=vmin, vmax=vmax, lut_label=labels[i]\n",
    "    )\n",
    "\n",
    "    if key != 'mask':\n",
    "        vmin, vmax = get_range(data, 7)\n",
    "        ax = fig.add_subplot(122)\n",
    "        _ = xana.histPlot(\n",
    "            ax, data.flatten(),\n",
    "            bins=45,range=[vmin, vmax],\n",
    "            log=True,color='red',histtype='stepfilled'\n",
    "        )\n",
    "        ax.set_xlabel(labels[i])\n",
    "        ax.set_ylabel(\"Counts\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## histograms of fit parameters ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 5))\n",
    "ax = fig.add_subplot(111)\n",
    "a = ax.hist(hist_std.flatten(), bins=100, range=(0,100) )\n",
    "ax.plot([intensity_lim, intensity_lim], [0, np.nanmax(a[0])], linewidth=1.5, color='red' ) \n",
    "ax.set_xlabel('Histogram width [ADU]', fontsize=14)\n",
    "ax.set_ylabel('Number of histograms', fontsize=14)\n",
    "ax.set_title(f'{hist_std[hist_std<intensity_lim].shape[0]} histograms below threshold in {intensity_lim} ADU',\n",
    "              fontsize=14, fontweight='bold')\n",
    "ax.grid()\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_par_distr(par):\n",
    "    fig = plt.figure(figsize=(16, 5))\n",
    "    sel = fit_result['mask'] == 0\n",
    "    \n",
    "    for i in range(n_peaks_fit) :\n",
    "        data=fit_result[f\"g{i}{par}\"]\n",
    "        plt_range=(-1,50)\n",
    "        if par =='mean':\n",
    "            plt_range=[peak_range[i][0] ,peak_range[i][1]]\n",
    "            \n",
    "        num_bins = int(plt_range[1] - plt_range[0])\n",
    "        ax = fig.add_subplot(1,n_peaks_fit,i+1)\n",
    "        _ = xana.histPlot(ax,data.flatten(), \n",
    "                          bins= num_bins,range=plt_range,\n",
    "                          log=True,color='red',\n",
    "                          label='all fits',)\n",
    "\n",
    "        a = ax.hist(data[sel].flatten(), \n",
    "                    bins=num_bins, range=plt_range,\n",
    "                    log=True,color='g',\n",
    "                    label='good fits only',\n",
    "                   )\n",
    "        ax.set_xlabel(f\"g{i} {par} [ADU]\")\n",
    "        ax.legend()\n",
    "        \n",
    "plot_par_distr('mean')\n",
    "plot_par_distr('sigma')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sel = fit_result['mask'] == 0\n",
    "\n",
    "dsets = {'d01 [ADU]':fit_result[f\"g1mean\"]-fit_result[f\"g0mean\"],\n",
    "         'gain [ADU/keV]':fit_result[f\"gain\"],\n",
    "         'gain relative to module mean':fit_result[f\"gain\"]/np.nanmean(gain_mean),\n",
    "        }\n",
    "fig = plt.figure(figsize=(16,5))\n",
    "for i, (par, data) in enumerate(dsets.items()):\n",
    "    ax = fig.add_subplot(1, 3, i+1)\n",
    "    plt_range=get_range(data, 10)\n",
    "    num_bins = 100\n",
    "    _ = xana.histPlot(ax,data.flatten(), \n",
    "                      bins= num_bins,range=plt_range,\n",
    "                      log=True,color='red',\n",
    "                      label='all fits',)\n",
    "\n",
    "    a = ax.hist(data[sel].flatten(), \n",
    "                bins=num_bins, range=plt_range,\n",
    "                log=True,color='g',\n",
    "                label='good fits only',\n",
    "               )\n",
    "    ax.set_xlabel(f\"{par}\")\n",
    "    ax.legend()\n",
    "    if 'd01' in par :\n",
    "        ax.axvline(d0_lim[0])\n",
    "        ax.axvline(d0_lim[1])\n",
    "    if 'rel' in par :\n",
    "        ax.axvline(gain_lim[0])\n",
    "        ax.axvline(gain_lim[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary across pixels ##\n",
    "\n",
    "Mean and median values are calculated across all pixels for each memory cell. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_error_band(key, x, ax):\n",
    "    \n",
    "    cdata = np.copy(fit_result[key])\n",
    "    cdata[fit_result['mask']>0] = np.nan\n",
    "    \n",
    "    mean = np.nanmean(cdata, axis=(1,2))\n",
    "    median = np.nanmedian(cdata, axis=(1,2))\n",
    "    std = np.nanstd(cdata, axis=(1,2))\n",
    "    mad = np.nanmedian(np.abs(cdata - median[:,None,None]), axis=(1,2))\n",
    "\n",
    "    ax.plot(x, mean, 'k', color='#3F7F4C', label=\" mean value \")\n",
    "    ax.plot(x, median, 'o', color='red', label=\" median value \")\n",
    "    ax.fill_between(x, mean-std, mean+std,\n",
    "                     alpha=0.6, edgecolor='#3F7F4C', facecolor='#7EFF99',\n",
    "                     linewidth=1, linestyle='dashdot', antialiased=True,\n",
    "                     label=\" mean value $ \\pm $ std \")\n",
    "\n",
    "    ax.fill_between(x, median-mad, median+mad,\n",
    "                     alpha=0.3, edgecolor='red', facecolor='red',\n",
    "                     linewidth=1, linestyle='dashdot', antialiased=True,\n",
    "                     label=\" median value $ \\pm $ mad \")\n",
    "    \n",
    "    if f'error_{key}' in fit_result:\n",
    "        cerr = np.copy(fit_result[f'error_{key}'])\n",
    "        cerr[fit_result['mask']>0] = np.nan\n",
    "        \n",
    "        meanerr = np.nanmean(cerr, axis=(1,2))\n",
    "        ax.fill_between(x, mean-meanerr, mean+meanerr,\n",
    "                 alpha=0.6, edgecolor='#089FFF', facecolor='#089FFF',\n",
    "                 linewidth=1, linestyle='dashdot', antialiased=True,\n",
    "                 label=\" mean fit error \")\n",
    "    \n",
    "\n",
    "x = np.linspace(*cell_range, n_cells)\n",
    "\n",
    "for i, key in enumerate(['g0mean', 'g1mean', 'gain', 'chi2_ndof']):\n",
    "\n",
    "    fig = plt.figure(figsize=(10, 5))\n",
    "    ax = fig.add_subplot(111)\n",
    "    plot_error_band(key, x, ax)\n",
    "\n",
    "    ax.set_xlabel('Memory Cell ID', fontsize=14)\n",
    "    ax.set_ylabel(labels[i], fontsize=14)\n",
    "    ax.grid()\n",
    "    ax.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cut flow ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "fig.set_size_inches(10, 5)\n",
    "\n",
    "n_bars = 8\n",
    "x = np.arange(n_bars)\n",
    "width = 0.3\n",
    "\n",
    "msk = fit_result['mask']\n",
    "n_fits = np.prod(msk.shape)\n",
    "y = [any_in(msk, BadPixelsFF.FIT_FAILED.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value |\n",
    "           BadPixelsFF.CHI2_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value |\n",
    "           BadPixelsFF.CHI2_THRESHOLD.value | BadPixelsFF.GAIN_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value |\n",
    "           BadPixelsFF.CHI2_THRESHOLD.value | BadPixelsFF.GAIN_THRESHOLD.value |\n",
    "           BadPixelsFF.NOISE_PEAK_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value |\n",
    "           BadPixelsFF.CHI2_THRESHOLD.value | BadPixelsFF.GAIN_THRESHOLD.value |\n",
    "           BadPixelsFF.NOISE_PEAK_THRESHOLD.value | BadPixelsFF.PEAK_WIDTH_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value |\n",
    "           BadPixelsFF.CHI2_THRESHOLD.value | BadPixelsFF.GAIN_THRESHOLD.value |\n",
    "           BadPixelsFF.NOISE_PEAK_THRESHOLD.value | BadPixelsFF.PEAK_WIDTH_THRESHOLD.value\n",
    "           | BadPixelsFF.NO_ENTRY.value),\n",
    "     any_in(msk, BadPixelsFF.FIT_FAILED.value | BadPixelsFF.ACCURATE_COVAR.value |\n",
    "           BadPixelsFF.CHI2_THRESHOLD.value | BadPixelsFF.GAIN_THRESHOLD.value |\n",
    "           BadPixelsFF.NOISE_PEAK_THRESHOLD.value | BadPixelsFF.PEAK_WIDTH_THRESHOLD.value\n",
    "           | BadPixelsFF.NO_ENTRY.value| BadPixelsFF.GAIN_DEVIATION.value)\n",
    "    ]\n",
    "\n",
    "y2 = [any_in(msk, BadPixelsFF.FIT_FAILED.value),\n",
    "     any_in(msk, BadPixelsFF.ACCURATE_COVAR.value),\n",
    "     any_in(msk, BadPixelsFF.CHI2_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.GAIN_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.NOISE_PEAK_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.PEAK_WIDTH_THRESHOLD.value),\n",
    "     any_in(msk, BadPixelsFF.NO_ENTRY.value),\n",
    "     any_in(msk, BadPixelsFF.GAIN_DEVIATION.value)\n",
    "    ]\n",
    "\n",
    "y = (1 - np.sum(y, axis=(1,2,3))/n_fits)*100\n",
    "y2 = (1 - np.sum(y2, axis=(1,2,3))/n_fits)*100\n",
    "\n",
    "labels = ['Fit failes',\n",
    "         'Accurate covar',\n",
    "         'Chi2/nDOF',\n",
    "         'Gain',\n",
    "         'Noise peak',\n",
    "         'Peak width',\n",
    "         'No Entry',\n",
    "         'Gain deviation']\n",
    "\n",
    "ax.bar(x, y2, width, label='Only this cut')\n",
    "ax.bar(x, y, width, label='Cut flow')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels, rotation=90)\n",
    "ax.set_ylim(y[5]-0.5, 100)\n",
    "ax.grid(True)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}