From 8aad680915d2ad8dd837f372c7d0fe9987ca034f Mon Sep 17 00:00:00 2001
From: Karim Ahmed <karim.ahmed@xfel.eu>
Date: Fri, 1 Apr 2022 12:43:28 +0200
Subject: [PATCH] [AGIPD][CORRECT] Workaround for fixing plots after correcting
 one cellId

---
 .../AGIPD/AGIPD_Correct_and_Verify.ipynb      | 101 +++++++++++-------
 1 file changed, 60 insertions(+), 41 deletions(-)

diff --git a/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb b/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb
index 932269330..a9471792d 100644
--- a/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb
+++ b/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb
@@ -814,7 +814,23 @@
     "    else:\n",
     "        tid, data = next(iter(run_data.select(f'{detector_id}/DET/*', source).trains(require_all=True)))\n",
     "\n",
-    "    return tid, stack_detector_data(train=data, data=source, fillvalue=fillvalue, modules=modules)"
+    "    # TODO: remove and use the keep_dims version after updating Extra-data.\n",
+    "    # Avoid using default axis with sources of an expected scalar value per train.\n",
+    "    if len(range(*cell_sel.crange)) == 1 and source in ['image.blShift', 'image.cellId', 'image.pulseId']:\n",
+    "        axis = 0\n",
+    "    else:\n",
+    "        axis = -3\n",
+    "\n",
+    "    stacked_data = stack_detector_data(\n",
+    "        train=data, data=source, fillvalue=fillvalue, modules=modules, axis=axis)\n",
+    "    # Add cellId dimension when correcting one cellId only.\n",
+    "    if (\n",
+    "        len(range(*cell_sel.crange)) == 1 and\n",
+    "        data_folder != run_folder  # avoid adding pulse dims for raw data.\n",
+    "    ):\n",
+    "        stacked_data = stacked_data[np.newaxis, ...]\n",
+    "\n",
+    "    return tid, stacked_data"
    ]
   },
   {
@@ -931,6 +947,10 @@
    "source": [
     "pulse_range = [np.min(pulseId[pulseId>=0]), np.max(pulseId[pulseId>=0])]\n",
     "\n",
+    "# Modify pulse_range, if only one pulse is selected.\n",
+    "if pulse_range[0] == pulse_range[1]:\n",
+    "    pulse_range = [0, pulse_range[1]+int(acq_rate)]\n",
+    "\n",
     "mean_data = np.nanmean(corrected, axis=(2, 3))\n",
     "hist, bins_x, bins_y = calgs.histogram2d(mean_data.flatten().astype(np.float32),\n",
     "                                      pulseId.flatten().astype(np.float32),\n",
@@ -993,8 +1013,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "display(Markdown('### Raw preview ###\\n'))\n",
-    "display(Markdown(f'Mean over images of the RAW data\\n'))"
+    "if cell_id_preview not in cellId[:, 0]:\n",
+    "    print(f\"WARNING: The selected cell_id_preview value {cell_id_preview} is not available in the corrected data.\")\n",
+    "    cell_id_preview = cellId[:, 0][0]\n",
+    "    cell_idx_preview = 0\n",
+    "    print(f\"Previewing the first available cellId: {cell_id_preview}.\")\n",
+    "else:\n",
+    "    cell_idx_preview = np.where(cellId[:, 0] == cell_id_preview)[0][0] "
    ]
   },
   {
@@ -1003,11 +1028,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "fig = plt.figure(figsize=(20, 10))\n",
-    "ax = fig.add_subplot(111)\n",
-    "data = np.mean(raw[:, 0, ...], axis=0)\n",
-    "vmin, vmax = get_range(data, 5)\n",
-    "ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
+    "display(Markdown('### Raw preview ###\\n'))\n",
+    "if cellId.shape[0] != 1:\n",
+    "    display(Markdown(f'Mean over images of the RAW data\\n'))\n",
+    "    fig = plt.figure(figsize=(20, 10))\n",
+    "    ax = fig.add_subplot(111)\n",
+    "    data = np.mean(raw[slice(*cell_sel.crange), 0, ...], axis=0)\n",
+    "    vmin, vmax = get_range(data, 5)\n",
+    "    ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)\n",
+    "else:\n",
+    "    print(\"Skipping mean RAW preview for single memory cell, \"\n",
+    "          f\"see single shot image for selected cell ID {cell_id_preview}.\")"
    ]
   },
   {
@@ -1016,11 +1047,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "display(Markdown(f'Single shot of the RAW data from cell {np.max(cellId[cell_id_preview])} \\n'))\n",
+    "display(Markdown(f'Single shot of the RAW data from cell {cell_id_preview} \\n'))\n",
     "fig = plt.figure(figsize=(20, 10))\n",
     "ax = fig.add_subplot(111)\n",
-    "vmin, vmax = get_range(raw[cell_id_preview, 0, ...], 5)\n",
-    "ax = geom.plot_data_fast(raw[cell_id_preview, 0, ...], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
+    "vmin, vmax = get_range(raw[cell_idx_preview, 0, ...], 5)\n",
+    "ax = geom.plot_data_fast(raw[cell_idx_preview, 0, ...], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
    ]
   },
   {
@@ -1030,7 +1061,17 @@
    "outputs": [],
    "source": [
     "display(Markdown('### Corrected preview ###\\n'))\n",
-    "display(Markdown(f'A single shot image from cell {np.max(cellId[cell_id_preview])} \\n'))"
+    "if cellId.shape[0] != 1:\n",
+    "    display(Markdown('### Mean CORRECTED Preview ###\\n'))\n",
+    "    display(Markdown(f'A mean across train: {tid}\\n'))\n",
+    "    fig = plt.figure(figsize=(20, 10))\n",
+    "    ax = fig.add_subplot(111)\n",
+    "    data = np.mean(corrected, axis=0)\n",
+    "    vmin, vmax = get_range(data, 7)\n",
+    "    ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=-50, vmax=vmax)\n",
+    "else:\n",
+    "    print(\"Skipping mean CORRECTED preview for single memory cell, \"\n",
+    "          f\"see single shot image for selected cell ID {cell_id_preview}.\")"
    ]
   },
   {
@@ -1039,11 +1080,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "display(Markdown(f'A single shot of the CORRECTED image from cell {cell_id_preview} \\n'))\n",
     "fig = plt.figure(figsize=(20, 10))\n",
     "ax = fig.add_subplot(111)\n",
-    "vmin, vmax = get_range(corrected[cell_id_preview], 7, -50)\n",
+    "vmin, vmax = get_range(corrected[cell_idx_preview], 7, -50)\n",
     "vmin = - 50\n",
-    "ax = geom.plot_data_fast(corrected[cell_id_preview], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
+    "ax = geom.plot_data_fast(corrected[cell_idx_preview], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
    ]
   },
   {
@@ -1054,9 +1096,9 @@
    "source": [
     "fig = plt.figure(figsize=(20, 10))\n",
     "ax = fig.add_subplot(111)\n",
-    "vmin, vmax = get_range(corrected[cell_id_preview], 5, -50)\n",
+    "vmin, vmax = get_range(corrected[cell_idx_preview], 5, -50)\n",
     "nbins = np.int((vmax + 50) / 2)\n",
-    "h = ax.hist(corrected[cell_id_preview].flatten(),\n",
+    "h = ax.hist(corrected[cell_idx_preview].flatten(),\n",
     "            bins=nbins, range=(-50, vmax),\n",
     "            histtype='stepfilled', log=True)\n",
     "plt.xlabel('[ADU]')\n",
@@ -1064,29 +1106,6 @@
     "ax.grid()"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "display(Markdown('### Mean CORRECTED Preview ###\\n'))\n",
-    "display(Markdown(f'A mean across one train\\n'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "fig = plt.figure(figsize=(20, 10))\n",
-    "ax = fig.add_subplot(111)\n",
-    "data = np.mean(corrected, axis=0)\n",
-    "vmin, vmax = get_range(data, 7)\n",
-    "ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=-50, vmax=vmax)"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -1164,7 +1183,7 @@
    "outputs": [],
    "source": [
     "display(Markdown(f'### Single Shot Bad Pixels ### \\n'))\n",
-    "display(Markdown(f'A single shot bad pixel map from cell {np.max(cellId[cell_id_preview])} \\n'))"
+    "display(Markdown(f'A single shot bad pixel map from cell {cell_id_preview} \\n'))"
    ]
   },
   {
@@ -1175,7 +1194,7 @@
    "source": [
     "fig = plt.figure(figsize=(20, 10))\n",
     "ax = fig.add_subplot(111)\n",
-    "geom.plot_data_fast(np.log2(mask[cell_id_preview]), ax=ax, vmin=0, vmax=32, cmap=\"jet\")"
+    "geom.plot_data_fast(np.log2(mask[cell_idx_preview]), ax=ax, vmin=0, vmax=32, cmap=\"jet\")"
    ]
   },
   {
-- 
GitLab