From 1e64d429c46687db2f0fc5cfe37fe8f4edc07b76 Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Thu, 29 Aug 2024 12:41:09 +0200
Subject: [PATCH] Feat: Update Fit Summary notebook with new plots

---
 ...ungfrau_gain_Spectra_Fit_Summary_NBC.ipynb | 295 ++++++++++++++----
 1 file changed, 238 insertions(+), 57 deletions(-)

diff --git a/notebooks/Jungfrau/Jungfrau_gain_Spectra_Fit_Summary_NBC.ipynb b/notebooks/Jungfrau/Jungfrau_gain_Spectra_Fit_Summary_NBC.ipynb
index 0faa46cd4..3df1df4fe 100644
--- a/notebooks/Jungfrau/Jungfrau_gain_Spectra_Fit_Summary_NBC.ipynb
+++ b/notebooks/Jungfrau/Jungfrau_gain_Spectra_Fit_Summary_NBC.ipynb
@@ -38,22 +38,25 @@
    "outputs": [],
    "source": [
     "import math\n",
+    "import multiprocessing as mp\n",
     "import warnings\n",
+    "from IPython.display import Markdown, display\n",
+    "from logging import warning\n",
     "from pathlib import Path\n",
     "\n",
     "warnings.filterwarnings('ignore')\n",
     "\n",
-    "from h5py import File as h5file\n",
     "import matplotlib\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
+    "from h5py import File as h5file\n",
     "\n",
     "matplotlib.use(\"agg\")\n",
     "%matplotlib inline\n",
     "\n",
+    "from cal_tools.calcat_interface import CalCatApi\n",
     "from cal_tools.plotting import init_jungfrau_geom\n",
-    "from cal_tools.restful_config import calibration_client\n",
-    "from cal_tools.calcat_interface import CalCatApi"
+    "from cal_tools.restful_config import calibration_client"
    ]
   },
   {
@@ -76,7 +79,41 @@
     "calcat = CalCatApi(client=calcat_client)\n",
     "detector_id = calcat.detector(karabo_id)['id']\n",
     "da_mapping = calcat.physical_detector_units(detector_id, pdu_snapshot_at=creation_time)\n",
-    "da_to_pdu = {k: v[\"physical_name\"] for k, v in da_mapping.items()}"
+    "da_to_pdu = {k: v[\"physical_name\"] for k, v in da_mapping.items()}\n",
+    "run = runs[0]  # TODO: Update for multiple runs\n",
+    "proposal = list(filter(None, in_folder.strip('/').split('/')))[-2]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Histograms for all cells for each module"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def process_histogram(file_path, shared_dict):\n",
+    "    try:\n",
+    "        with h5file(file_path, 'r') as f:\n",
+    "            histos = f[\"histos\"][:]\n",
+    "            edges = f[\"edges\"][:]\n",
+    "    except Exception as e:\n",
+    "        warning(f\"Error while loading Histogram file {file_path}: {e}\")\n",
+    "\n",
+    "        shared_dict['bin_centers'] = None\n",
+    "        shared_dict['mean_histos'] = None\n",
+    "        return\n",
+    "\n",
+    "    bin_centers = (edges[1:] + edges[:-1]) / 2\n",
+    "    mean_histos = histos.mean(axis=(2, 3))  # Shape: (bins, cells)\n",
+    "    \n",
+    "    shared_dict['bin_centers'] = bin_centers\n",
+    "    shared_dict['mean_histos'] = mean_histos"
    ]
   },
   {
@@ -85,37 +122,176 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "proposal = list(filter(None, in_folder.strip('/').split('/')))[-2]\n",
-    "run = runs[0]  # TODO this will need to be fixed when I start implementing multiple runs.\n",
-    "stacked_constants = np.full(geom.expected_data_shape, np.nan)  # nmods, 512, 1024\n",
-    "\n",
-    "for i, da in enumerate (da_to_pdu.keys()):\n",
-    "   with h5file(\n",
-    "      Path(out_folder) / spectra_fit_temp.format(run, proposal.upper(), da, fit_func),\n",
-    "      'r'\n",
-    "   ) as f:\n",
-    "      # f[g0_fit_dataset] shape is 1024, 512, mem_cells\n",
-    "      stacked_constants[i] = np.moveaxis(\n",
-    "         np.mean(np.array(f[g0_fit_dataset]), axis=-1), 0, 1)\n",
-    "      \n",
-    "fig, ax = plt.subplots(figsize=(18, 10))\n",
-    "vmin, vmax = np.percentile(stacked_constants, [5, 95])\n",
-    "geom.plot_data_fast(\n",
-    "   stacked_constants,\n",
-    "   ax=ax,\n",
-    "   vmin=vmin,\n",
-    "   vmax=vmax,\n",
-    "   colorbar={'shrink': 1, 'pad': 0.01},\n",
-    ")\n",
-    "ax.set_title(f'{karabo_id} - One photon peak position', size=18)\n",
+    "ncols = min(4, nmods)\n",
+    "nrows = math.ceil(nmods / ncols)\n",
+    "\n",
+    "fig, axs = plt.subplots(\n",
+    "    nrows, ncols, figsize=((5*ncols)//1.8, (5*nrows + 1)//1.8), squeeze=False)\n",
+    "axs = axs.flatten()\n",
+    "\n",
+    "# Use a consistent color cycle for all subplots\n",
+    "custom_colors = [\n",
+    "    '#0000FF', '#FF0000', '#00FF00', '#FFFF00', '#FF00FF', '#FFA500', \n",
+    "    '#800080', '#008000', '#000080', '#FFC0CB', '#A52A2A', '#808080', \n",
+    "    '#FFD700', '#8B4513', '#FF4500', '#2E8B57'\n",
+    "]\n",
+    "\n",
+    "# Prepare shared memory and multiprocessing\n",
+    "manager = mp.Manager()\n",
+    "shared_dict_list = [manager.dict() for _ in range(nmods)]\n",
+    "\n",
+    "with mp.Pool(processes=min(mp.cpu_count(), nmods)) as pool:\n",
+    "    file_paths = [Path(out_folder) / histo_temp.format(run, proposal.upper(), da) for da in da_to_pdu.keys()]\n",
+    "    pool.starmap(process_histogram, zip(file_paths, shared_dict_list))\n",
+    "\n",
+    "for i, ((da, pdu), shared_dict) in enumerate(zip(da_to_pdu.items(), shared_dict_list)):\n",
+    "    ax = axs[i]\n",
+    "    bin_centers = shared_dict['bin_centers']\n",
+    "    mean_histos = shared_dict['mean_histos']\n",
+    "\n",
+    "    # Missing histogram data for module\n",
+    "    if bin_centers is None:\n",
+    "        continue\n",
+    "\n",
+    "    for m in range(mean_histos.shape[1]):  # Iterate over cells\n",
+    "        ax.semilogy(\n",
+    "            bin_centers, mean_histos[:, m],\n",
+    "            color=custom_colors[m],\n",
+    "            # Create only for the first subplot as legend is shared.\n",
+    "            label=f'Cell {m}' if i == 0 else \"\"\n",
+    "        )\n",
+    "\n",
+    "    ax.set_title(f\"{da} ({pdu})\")\n",
+    "    ax.set_xlabel(\"ADC value\")\n",
+    "    ax.set_ylabel(\"Counts\")\n",
+    "\n",
+    "# Hide unused subplots\n",
+    "for i in range(nmods, len(axs)):\n",
+    "    axs[i].set_visible(False)\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "\n",
+    "# Add a single legend at the top of the figure\n",
+    "handles, labels = axs[0].get_legend_handles_labels()\n",
+    "fig.legend(\n",
+    "    handles, labels, loc='upper center', ncol=min(16, 8), \n",
+    "    bbox_to_anchor=(0.5, 1.05), fontsize='small')\n",
+    "plt.subplots_adjust(top=0.9)\n",
     "plt.show()"
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "## Histogram data for all cells for each module"
+    "display(Markdown(f\"## Display fitting results using {fit_func} model\"))\n",
+    "\n",
+    "def plot_stacked_heatmap(geom, stacked_constants, title):\n",
+    "    _, ax = plt.subplots(figsize=(8, 6))\n",
+    "    vmin, vmax = np.nanpercentile(stacked_constants, [5, 95])\n",
+    "    geom.plot_data_fast(\n",
+    "        stacked_constants,\n",
+    "        ax=ax,\n",
+    "        vmin=vmin,\n",
+    "        vmax=vmax,\n",
+    "        colorbar={'shrink': 1, 'pad': 0.01, 'label': title},\n",
+    "    )\n",
+    "    ax.set_title(title, size=12)\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "def plot_cells_comparison(data, title):\n",
+    "    n_cells = data.shape[-1]\n",
+    "\n",
+    "    if n_cells == 1:  # no need to plot for single cell\n",
+    "        return\n",
+    "\n",
+    "    # For multiple cells, create a grid layout\n",
+    "    n_cols = min(4, n_cells)  # Max 4 columns\n",
+    "    n_rows = math.ceil(n_cells / n_cols)\n",
+    "\n",
+    "    fig, axes = plt.subplots(\n",
+    "        n_rows, n_cols, figsize=(5*n_cols//2, 5*n_rows//2))\n",
+    "    fig.suptitle(title, fontsize=12)\n",
+    "\n",
+    "    # Flatten axes array for easy indexing\n",
+    "    axes = axes.flatten() if n_cells > 1 else [axes]\n",
+    "    vmin, vmax = np.nanpercentile(data, [5, 95])\n",
+    "\n",
+    "    for i in range(n_cells):\n",
+    "        ax = axes[i]\n",
+    "        im = ax.imshow(\n",
+    "            data[..., i],\n",
+    "            cmap='viridis',\n",
+    "            vmin=vmin,\n",
+    "            vmax=vmax\n",
+    "        )\n",
+    "        ax.set_title(f'Cell {i}', size=10)\n",
+    "        plt.colorbar(im, ax=ax)\n",
+    "\n",
+    "    # Hide unused subplots\n",
+    "    for i in range(n_cells, len(axes)):\n",
+    "        axes[i].set_visible(False)\n",
+    "\n",
+    "    plt.tight_layout()\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "def plot_histogram_of_values(data, title, bins=100):\n",
+    "    data = data.flatten()\n",
+    "\n",
+    "    # Count failed fittings. -1000 and -1 used for failed fitting.\n",
+    "    failed_percentage = (np.sum((data == -1000) | (data == -1)) / len(data)) * 100\n",
+    "    \n",
+    "    # Separate valid data and failed fittings\n",
+    "    valid_data = data[(data != -1000) & (data != -1)]\n",
+    "    plt.figure(figsize=(6, 3))\n",
+    "\n",
+    "    # Use Interquartile Range (IQR) for defining histogram range.\n",
+    "    q1, q3 = np.percentile(valid_data, [25, 75])\n",
+    "    iqr = q3 - q1\n",
+    "    lower = max(q1 - 1.5 * iqr, np.min(valid_data))\n",
+    "    upper = min(q3 + 1.5 * iqr, np.max(valid_data))\n",
+    "\n",
+    "    plt.hist(data, bins=bins, range=(lower, upper), edgecolor='black')\n",
+    "    plt.xlabel('Value')\n",
+    "    plt.ylabel('Count')\n",
+    "    plt.title(title, size=12)\n",
+    "\n",
+    "    mean = np.mean(valid_data)\n",
+    "    median = np.median(valid_data)\n",
+    "    plt.axvline(\n",
+    "        mean,\n",
+    "        color='r',\n",
+    "        linestyle='dashed',\n",
+    "        linewidth=2,\n",
+    "        label=f'Mean: {mean:.2f}'\n",
+    "    )\n",
+    "    plt.axvline(\n",
+    "        median,\n",
+    "        color='g',\n",
+    "        linestyle='dashed',\n",
+    "        linewidth=2,\n",
+    "        label=f'Median: {median:.2f}'\n",
+    "    )\n",
+    "\n",
+    "    # Add text box with failed fitting percentage.\n",
+    "    plt.text(\n",
+    "        0.05, 0.95,\n",
+    "        f\"Failed Fittings:\\n    {failed_percentage:.2f}%\",\n",
+    "        transform=plt.gca().transAxes, \n",
+    "        verticalalignment='top',\n",
+    "        bbox=dict(\n",
+    "            boxstyle='round',\n",
+    "            facecolor='white',\n",
+    "            alpha=0.8),\n",
+    "        fontsize=10)\n",
+    "\n",
+    "    plt.legend()\n",
+    "    plt.tight_layout()\n",
+    "    plt.show()"
    ]
   },
   {
@@ -124,34 +300,39 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "if nmods > 4:\n",
-    "    fixed_cols = 4\n",
-    "    row, col = math.ceil(nmods / fixed_cols), 4\n",
-    "else:\n",
-    "    row, col = 1, nmods\n",
-    "\n",
-    "fig, axs = plt.subplots(row, col, figsize=(20, 10))  # Adjust for spatial histograms\n",
-    "axs = axs.ravel()\n",
-    "for i, da in enumerate (da_to_pdu.keys()):\n",
-    "    with h5file(\n",
-    "        Path(out_folder) / histo_temp.format(run, proposal.upper(), da),\n",
-    "        'r'\n",
-    "    ) as f:\n",
-    "        histos = f[\"histos\"][:]\n",
-    "    row, col = divmod(i, 4)\n",
-    "    for m in range(histos.shape[1]):\n",
-    "        cell_hist = histos[m]\n",
-    "        axs[i].plot(cell_hist.ravel(), label=f'Cell {m}')\n",
-    "        axs[i].set_title(f\"{da} ({da_to_pdu[da]})\")\n",
-    "\n",
-    "for i, ax in enumerate(axs):\n",
-    "    if i > nmods-1:\n",
-    "        ax.set_visible(False)  # Hide unused subplots\n",
-    "# Create a legend for the whole figure\n",
-    "handles, labels = axs[0].get_legend_handles_labels()\n",
-    "fig.legend(handles, labels, loc=\"center right\" if nmods > 4 else \"upper right\")\n",
-    "plt.tight_layout(pad=3)\n",
-    "plt.show()"
+    "# Define datasets to plot\n",
+    "datasets = {\n",
+    "    \"gainMap_fit\": 'Single Photon Peak Position (ADU)',\n",
+    "    'sigmamap': 'Peak Width (σ)',\n",
+    "    'alphamap': 'Charge Sharing Probability (α)',  # 'Charge sharing parameter'\n",
+    "    'chi2map': 'Goodness of Fit (χ²/ndf)'  # 'Chi-square of fit'\n",
+    "}\n",
+    "\n",
+    "for dataset, title in datasets.items():\n",
+    "    stacked_constants = np.full(geom.expected_data_shape, np.nan)\n",
+    "    all_data = []\n",
+    "    av_modules = []\n",
+    "\n",
+    "    for i, da in enumerate(da_to_pdu.keys()):\n",
+    "        file_path = Path(out_folder) / spectra_fit_temp.format(run, proposal.upper(), da, fit_func)\n",
+    "        try:\n",
+    "            with h5file(file_path, 'r') as f:\n",
+    "                data = np.array(f[dataset])\n",
+    "                stacked_constants[i] = np.moveaxis(np.mean(data, axis=-1), 0, 1)\n",
+    "                all_data.append(data)\n",
+    "                av_modules.append(da)\n",
+    "        except Exception as e:\n",
+    "            warning(f\"Error while loading Fitting file {file_path}: {e}\")\n",
+    "            # Help to avoid plotting missing module.\n",
+    "\n",
+    "    # Plot stacked heatmap\n",
+    "    plot_stacked_heatmap(geom, stacked_constants, title)\n",
+    "\n",
+    "    plot_histogram_of_values(\n",
+    "        np.concatenate(all_data), f'Distribution of {title}')\n",
+    "\n",
+    "    # Plot cell comparison for the the last module module\n",
+    "    plot_cells_comparison(all_data[-1], f'{title} - Module {av_modules[-1]}({da_to_pdu[av_modules[-1]]})')"
    ]
   }
  ],
-- 
GitLab