{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Jungfrau Dark Summary\n",
    "\n",
    "Author: European XFEL Detector Department, Version: 1.0\n",
    "\n",
    "Summary of the processed dark constants and a comparison with the previously injected constants for the same conditions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_folder = \"/gpfs/exfel/data/scratch/ahmedk/test/jungfrau_assembeled_dark\"  # path to output to, required\n",
    "metadata_folder = \"\"  # Directory containing calibration_metadata.yml when run by xfel-calibrate.\n",
    "\n",
    "# Parameters used to access raw data.\n",
    "karabo_da = []  # list of data aggregators, which corresponds to different JF modules. This is only needed for the detectors of one module.\n",
    "karabo_id = \"FXE_XAD_JF1M\"  # detector identifier.\n",
    "\n",
    "# Parameters to be used for injecting dark calibration constants.\n",
    "local_output = True  # Boolean indicating that local constants were stored in the out_folder\n",
    "\n",
    "# Skip the whole notebook if local_output is false in the preceding notebooks:\n",
    "# this summary reads the new constants from the files stored locally in out_folder.\n",
    "if not local_output:\n",
    "    print('No local constants saved. Skipping summary plots')\n",
    "    import sys\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "# Silence all warnings for the rest of the notebook, e.g. numpy\n",
    "# RuntimeWarnings from nan-reductions over NaN-filled modules.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import h5py\n",
    "import matplotlib\n",
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import yaml\n",
    "\n",
    "matplotlib.use(\"agg\")\n",
    "%matplotlib inline\n",
    "\n",
    "import tabulate\n",
    "from IPython.display import Latex, Markdown, display\n",
    "from XFELDetAna.plotting.simpleplot import simplePlot\n",
    "\n",
    "from cal_tools.enums import BadPixels\n",
    "from cal_tools.plotting import init_jungfrau_geom, show_processed_modules_jungfrau\n",
    "from cal_tools.tools import CalibrationMetadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare paths and load the metadata of the previous constants.\n",
    "out_folder = Path(out_folder)\n",
    "metadata = CalibrationMetadata(metadata_folder or out_folder)\n",
    "mod_mapping = metadata.setdefault(\"modules-mapping\", {})\n",
    "dark_constants = [\"Offset\", \"Noise\", \"BadPixelsDark\"]\n",
    "\n",
    "# Each module_metadata_*.yml file provides a module name, its PDU and\n",
    "# the metadata of that module's previous (\"old\") constants.\n",
    "metadata_dir = Path(metadata_folder or out_folder)\n",
    "prev_const_metadata = {}\n",
    "for module_file in metadata_dir.glob(\"module_metadata_*.yml\"):\n",
    "    module_info = yaml.safe_load(module_file.read_text())\n",
    "    module_name = module_info[\"module\"]\n",
    "    mod_mapping[module_name] = module_info[\"pdu\"]\n",
    "    prev_const_metadata[module_name] = module_info[\"old-constants\"]\n",
    "\n",
    "# Persist the (possibly extended) modules mapping.\n",
    "metadata.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detector geometry and the list of modules expected for karabo_id / karabo_da.\n",
    "expected_modules, geom = init_jungfrau_geom(karabo_id=karabo_id, karabo_da=karabo_da)\n",
    "nmods = len(expected_modules)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the newly produced constants, stored locally in `out_folder`, and the previously injected constants for comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_gain = False  # constant is adaptive by default.\n",
    "# Get the constant shape from one of the local constants.\n",
    "# This is one way to realize the number of memory cells.\n",
    "offset_files = sorted(out_folder.glob(\"const_Offset_*\"))\n",
    "if not offset_files:\n",
    "    raise FileNotFoundError(f\"No local Offset constants found in {out_folder}.\")\n",
    "with h5py.File(offset_files[0], 'r') as f:\n",
    "    # Only the shape is needed, so the dataset itself is not read.\n",
    "    const_shape = f[\"data\"].shape\n",
    "    # Read the gain mode condition if it was stored with the constant.\n",
    "    # NOTE(review): fixed_gain is not used by the plots of this notebook.\n",
    "    gain_mode = \"condition/Gain mode/value\"\n",
    "    if gain_mode in f:\n",
    "        fixed_gain = f[gain_mode][()]\n",
    "\n",
    "# Stacked arrays with the module index first; NaN marks data that is not available.\n",
    "curr_constants = {c: np.full((nmods,) + const_shape, np.nan) for c in dark_constants}\n",
    "prev_constants = {c: np.full((nmods,) + const_shape, np.nan) for c in dark_constants}\n",
    "\n",
    "exculded_constants = []  # constants excluded from comparison plots.\n",
    "\n",
    "# Loop over the expected dark constants.\n",
    "for cname in dark_constants:\n",
    "    missing_modules = []  # modules without a new constant in out_folder.\n",
    "    excluded_modules = []  # modules with no previous constants.\n",
    "    # Row i of the stacked arrays belongs to the i-th module of sorted(expected_modules).\n",
    "    for i, mod in enumerate(sorted(expected_modules)):\n",
    "        # A module can be missing from the detector dark processing,\n",
    "        # so its constants can be missing in out_folder. Keep NaN for it.\n",
    "        pdu = mod_mapping.get(mod)\n",
    "        fpath = out_folder / f\"const_{cname}_{pdu}.h5\"\n",
    "        if pdu is None or not fpath.is_file():\n",
    "            missing_modules.append(mod)\n",
    "            continue\n",
    "\n",
    "        # First load the new constant.\n",
    "        with h5py.File(fpath, 'r') as f:\n",
    "            curr_constants[cname][i, ...] = f[\"data\"][()]\n",
    "\n",
    "        # Load the previous constant, if its file and dataset are known.\n",
    "        old_mod_mdata = prev_const_metadata.get(mod) or {}\n",
    "        old_const_mdata = old_mod_mdata.get(cname) or {}\n",
    "        filepath = old_const_mdata.get(\"filepath\")\n",
    "        h5path = old_const_mdata.get(\"h5path\")\n",
    "        if filepath and h5path:\n",
    "            with h5py.File(filepath, \"r\") as fd:\n",
    "                prev_constants[cname][i, ...] = fd[f\"{h5path}/data\"][()]\n",
    "        else:\n",
    "            excluded_modules.append(mod)\n",
    "\n",
    "    if missing_modules:\n",
    "        print(f\"New {cname} constants for {missing_modules} are not available in {out_folder}.\\n\")\n",
    "    if excluded_modules:\n",
    "        print(f\"Previous {cname} constants for {excluded_modules} are not available.\\n\")\n",
    "    # Exclude constants from comparison plots, if the corresponding\n",
    "    # previous constants are not available for any module.\n",
    "    if len(missing_modules) + len(excluded_modules) == nmods:\n",
    "        exculded_constants.append(cname)\n",
    "        print(f\"No comparison plots for {cname}.\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown('## Processed modules'))\n",
    "\n",
    "# mod_mapping maps each processed module name to its PDU name.\n",
    "processed_modules, processed_pdus = list(mod_mapping), list(mod_mapping.values())\n",
    "\n",
    "show_processed_modules_jungfrau(\n",
    "    jungfrau_geom=geom,\n",
    "    constants=curr_constants,\n",
    "    processed_modules=processed_modules,\n",
    "    expected_modules=expected_modules,\n",
    "    display_module_names=processed_pdus,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gainstages = 3\n",
    "gain_names = [\"High Gain\", \"Medium Gain\", \"Low Gain\"]\n",
    "\n",
    "# 2x4 grid: constant map centred on the top row,\n",
    "# absolute (left) and relative (right) differences on the bottom row.\n",
    "gs = gridspec.GridSpec(2, 4)\n",
    "\n",
    "# Per-panel settings; \"title\" is formatted with the constant and gain names.\n",
    "axes = dict(\n",
    "    map=dict(\n",
    "        gs=gs[0, 1:3], shrink=0.7, pad=0.05, label=\"ADCu\",\n",
    "        title=\"{}\", location=\"right\"),\n",
    "    diff=dict(\n",
    "        gs=gs[1, :2], shrink=0.7, pad=0.02, label=\"ADCu\",\n",
    "        location=\"left\", title=\"Difference with previous {}\"),\n",
    "    diff_frac=dict(\n",
    "        gs=gs[1, 2:], shrink=0.7, pad=0.02, label=\"%\",\n",
    "        location=\"right\", title=\"Difference with previous {} %\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def badpx(constant_name):\n",
    "    \"\"\"Tell whether constant_name names a bad pixel constant (case-insensitive).\"\"\"\n",
    "    return constant_name.lower().find(\"bad\") != -1\n",
    "\n",
    "\n",
    "def bp_entry(bp):\n",
    "    \"\"\"Return one bad pixel table row: name, 32-bit binary value, integer value.\"\"\"\n",
    "    name_col = f\"{bp.name:<30s}\"\n",
    "    bits_col = f\"{bp.value:032b}\"\n",
    "    int_col = f\"{int(bp.value)}\"\n",
    "    return [name_col, bits_col, int_col]\n",
    "\n",
    "\n",
    "# Bad pixel flags listed in the summary table; they also define\n",
    "# the colour range of the bad pixel maps.\n",
    "badpixels = [\n",
    "    BadPixels.OFFSET_OUT_OF_THRESHOLD,\n",
    "    BadPixels.NOISE_OUT_OF_THRESHOLD,\n",
    "    BadPixels.OFFSET_NOISE_EVAL_ERROR,\n",
    "    BadPixels.NO_DARK_DATA,\n",
    "    BadPixels.WRONG_GAIN_VALUE,\n",
    "]"
   ]
  },
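  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary figures across pixels and memory cells\n",
    "\n",
    "The following plots give an overview of the calibration constants averaged over memory cells and shown per pixel for all modules, together with the absolute and relative (%) differences to the previously injected constants. No bad pixel mask is applied to these maps."
   ]
  },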
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour range upper limit for bad pixel maps: the second largest listed flag.\n",
    "bp_vmax = sorted(bp.value for bp in badpixels)[-2]\n",
    "\n",
    "for cname, const in curr_constants.items():\n",
    "    if badpx(cname):\n",
    "        table = [bp_entry(bp) for bp in badpixels]\n",
    "        display(Markdown(\"\"\"**The bad pixel** mask is encoded as a bit mask.\"\"\"))\n",
    "        display(Latex(\n",
    "            tabulate.tabulate(\n",
    "                table,\n",
    "                tablefmt='latex',\n",
    "                headers=[\"Name\", \"bit value\", \"integer value\"]\n",
    "            )))\n",
    "\n",
    "    # Mean of the constant over axis 3 (memory cells, assuming a stacked\n",
    "    # (module, pixel, pixel, cell, gain) layout), the absolute difference\n",
    "    # with the previous constant's mean and that difference in percent.\n",
    "    mean_const = np.nanmean(const, axis=3)\n",
    "    mean_diff = np.abs(mean_const - np.nanmean(prev_constants[cname], axis=3))\n",
    "    mean_frac = np.divide(\n",
    "        mean_diff,\n",
    "        mean_const,\n",
    "        out=np.zeros_like(mean_const),\n",
    "        where=(mean_const != 0)\n",
    "    ) * 100\n",
    "\n",
    "    for gain in range(gainstages):\n",
    "        data_to_plot = {\n",
    "            'map': mean_const[..., gain],\n",
    "            'diff': mean_diff[..., gain],\n",
    "            'diff_frac': mean_frac[..., gain],\n",
    "        }\n",
    "\n",
    "        # Plot the constant over all modules.\n",
    "        display(Markdown(f'### {cname} - {gain_names[gain]}'))\n",
    "        fig = plt.figure(figsize=(20, 20) if nmods > 1 else (20, 10))\n",
    "\n",
    "        for axname, axv in axes.items():\n",
    "            # Avoid difference plots if previous constants\n",
    "            # are missing for the detector.\n",
    "            if cname in exculded_constants and axname != \"map\":\n",
    "                continue\n",
    "            ax = fig.add_subplot(axv[\"gs\"])\n",
    "\n",
    "            if badpx(cname):\n",
    "                vmin, vmax = 0, bp_vmax\n",
    "            else:\n",
    "                # NaN-aware: unavailable modules/previous constants are NaN,\n",
    "                # which would make np.percentile return NaN limits.\n",
    "                vmin, vmax = np.nanpercentile(data_to_plot[axname], [5, 95])\n",
    "\n",
    "            geom.plot_data(\n",
    "                data_to_plot[axname],\n",
    "                vmin=vmin,\n",
    "                vmax=vmax,\n",
    "                ax=ax,\n",
    "                colorbar={\n",
    "                    \"shrink\": axv[\"shrink\"],\n",
    "                    \"pad\": axv[\"pad\"],\n",
    "                    \"location\": axv[\"location\"],\n",
    "                },\n",
    "            )\n",
    "\n",
    "            colorbar = ax.images[0].colorbar\n",
    "            colorbar.set_label(axv[\"label\"], fontsize=15)\n",
    "            colorbar.ax.tick_params(labelsize=15)\n",
    "            ax.set_title(axv[\"title\"].format(\n",
    "                f\"{cname} {gain_names[gain]}\"), fontsize=15)\n",
    "\n",
    "            if axname == \"map\":\n",
    "                ax.set_xlabel('Columns', fontsize=15)\n",
    "                ax.set_ylabel('Rows', fontsize=15)\n",
    "                ax.tick_params(labelsize=15)\n",
    "            else:\n",
    "                # Remove tick and axes labels for comparison plots.\n",
    "                ax.tick_params(labelsize=0)\n",
    "                ax.set_xlabel('')\n",
    "                ax.set_ylabel('')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _memcell_curves(per_module_values, module_names):\n",
    "    \"\"\"Build simplePlot curve dicts: one step curve over memory cells per module.\"\"\"\n",
    "    return [\n",
    "        {\n",
    "            'x': np.arange(values.shape[0]),\n",
    "            'y': values,\n",
    "            'drawstyle': 'steps-pre',\n",
    "            'label': module_names[i],\n",
    "        }\n",
    "        for i, values in enumerate(per_module_values)\n",
    "    ]\n",
    "\n",
    "\n",
    "n_cells = curr_constants[\"Offset\"].shape[-2]  # memory cells: second-to-last axis.\n",
    "if n_cells > 1:\n",
    "    display(Markdown(\"## Summary across pixels per memory cells\"))\n",
    "\n",
    "    # Row i of the stacked constants belongs to the i-th module of\n",
    "    # sorted(expected_modules), the order used when loading the constants.\n",
    "    module_names = sorted(expected_modules)\n",
    "    # Show all memory cells with a margin of one cell on each side.\n",
    "    cell_range = (-1, n_cells)\n",
    "\n",
    "    # Plot mean and std over pixels per memory cell\n",
    "    # for each module, gain, and constant.\n",
    "    for const_name, const in curr_constants.items():\n",
    "        display(Markdown(f'### {const_name}'))\n",
    "\n",
    "        for gain in range(gainstages):\n",
    "            data = np.copy(const[..., gain])\n",
    "            fig = plt.figure(\n",
    "                figsize=(15, 6),\n",
    "                tight_layout={'pad': 0.2, 'w_pad': 1.3, 'h_pad': 1.3})\n",
    "\n",
    "            if const_name == 'BadPixelsDark':\n",
    "                # Any set flag counts as bad, so the mean is the bad pixel fraction.\n",
    "                data[data > 0] = 1.0\n",
    "                datamean = np.nanmean(data, axis=(1, 2))\n",
    "                # A fraction of exactly 1 (all pixels flagged) is not drawn.\n",
    "                datamean[datamean == 1.0] = np.nan\n",
    "                label = 'Fraction of bad pixels'\n",
    "                ax = fig.add_subplot(1, 1, 1)\n",
    "            else:\n",
    "                # Exclude bad pixels so that the statistics use good pixels only.\n",
    "                data[curr_constants[\"BadPixelsDark\"][..., gain] > 0] = np.nan\n",
    "                datamean = np.nanmean(data, axis=(1, 2))\n",
    "                label = f'{const_name} value [ADU], good pixels only'\n",
    "                ax = fig.add_subplot(1, 2, 1)\n",
    "\n",
    "            simplePlot(\n",
    "                _memcell_curves(datamean, module_names),\n",
    "                figsize=(10, 10), xrange=cell_range,\n",
    "                x_label='Memory Cell ID',\n",
    "                y_label=label,\n",
    "                use_axis=ax,\n",
    "                title=f'{gain_names[gain]}',\n",
    "                title_position=[0.5, 1.18],\n",
    "                legend='outside-top-ncol6-frame',\n",
    "                legend_size='18%',\n",
    "                legend_pad=0.00,\n",
    "                )\n",
    "\n",
    "            # Extra Sigma plot for Offset and Noise constants.\n",
    "            if const_name != 'BadPixelsDark':\n",
    "                ax = fig.add_subplot(1, 2, 2)\n",
    "                label = rf'$\\sigma$ {const_name} [ADU], good pixels only'\n",
    "                simplePlot(\n",
    "                    _memcell_curves(np.nanstd(data, axis=(1, 2)), module_names),\n",
    "                    figsize=(10, 10), xrange=cell_range,\n",
    "                    x_label='Memory Cell ID',\n",
    "                    y_label=label,\n",
    "                    use_axis=ax,\n",
    "                    title=rf'{gain_names[gain]} $\\sigma$',\n",
    "                    title_position=[0.5, 1.18],\n",
    "                    legend='outside-top-ncol6-frame',\n",
    "                    legend_size='18%',\n",
    "                    legend_pad=0.00,\n",
    "                    )\n",
    "            plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.11 ('.cal2_venv')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "ebdaec9fd6e243fab93e119377baafbbbd6671bf32db5f77705286047fa40d99"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}