From d59052c40b8ff68cdce1a3ae9ea058a7099abdf1 Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Thu, 23 May 2024 14:10:53 +0200
Subject: [PATCH] fix: drop the keep_data_dims function and use the
 functionality already available in extra

---
 ...Jungfrau_Gain_Correct_and_Verify_NBC.ipynb | 83 +++++++------------
 1 file changed, 28 insertions(+), 55 deletions(-)

diff --git a/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb b/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb
index 668a7a165..059e49b4b 100644
--- a/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb
+++ b/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb
@@ -100,7 +100,6 @@
     "from cal_tools.step_timing import StepTimer\n",
     "from cal_tools.tools import (\n",
     "    calcat_creation_time,\n",
-    "    keep_data_dims,\n",
     "    map_seq_files,\n",
     "    write_constants_fragment,\n",
     ")\n",
@@ -780,51 +779,45 @@
     "step_timer.start()\n",
     "first_seq = 0 if sequences == [-1] else sequences[0]\n",
     "\n",
-    "corrected_files = [\n",
+    "seq_corrected_files = [\n",
     "    out_folder / f for f in fnmatch.filter(corrected_files, f\"*{run}*S{first_seq:05d}*\")\n",
     "]\n",
-    "with DataCollection.from_paths(corrected_files) as corr_dc:\n",
+    "\n",
+    "# TODO: replace with CALCAT value.\n",
+    "if \"1M\" in karabo_id:\n",
+    "    nmods = 2\n",
+    "elif \"4M\" in karabo_id:\n",
+    "    nmods = 8\n",
+    "else:  # 500K\n",
+    "    nmods = 1\n",
+    "\n",
+    "with DataCollection.from_paths(seq_corrected_files) as corr_dc:\n",
     "    # Reading CORR data for plotting.\n",
     "    jf_corr = components.JUNGFRAU(\n",
     "        corr_dc,\n",
     "        detector_name=karabo_id,\n",
+    "        n_modules=nmods,\n",
     "    ).select_trains(np.s_[:plot_trains])\n",
     "    tid, jf_corr_data = next(iter(jf_corr.trains(require_all=True)))\n",
-    "det_mod_start = jf_corr._modnos_start_at  # TODO: replace with CALCAT value.\n",
-    "available_modules_indices = list(jf_corr.modno_to_source.keys())\n",
+    "\n",
     "# Shape = [modules, trains, cells, x, y]\n",
-    "corrected = jf_corr.get_array(\"data.adc\")[:, :, cell_idx_preview, ...].values\n",
-    "corrected_train = keep_data_dims(  # loose the train axis.\n",
-    "    jf_corr_data[\"data.adc\"][:, cell_idx_preview, ...].values,\n",
-    "    geom.expected_data_shape,\n",
-    "    available_modules_indices,\n",
-    "    det_mod_start\n",
-    ")\n",
+    "corrected = jf_corr[\"data.adc\"].ndarray(module_gaps=True)[:, :, cell_idx_preview, ...]\n",
+    "corrected_train = jf_corr_data[\"data.adc\"][:, cell_idx_preview, ...]  # lose the train axis.\n",
     "\n",
-    "mask = jf_corr.get_array(\"data.mask\")[:, :, cell_idx_preview, ...].values\n",
-    "mask_train = keep_data_dims(\n",
-    "    jf_corr_data[\"data.mask\"][:, cell_idx_preview, ...].values,\n",
-    "    geom.expected_data_shape,\n",
-    "    available_modules_indices,\n",
-    "    det_mod_start\n",
-    ")\n",
+    "mask = jf_corr[\"data.mask\"].ndarray(module_gaps=True)[:, :, cell_idx_preview, ...]\n",
+    "mask_train = jf_corr_data[\"data.mask\"][:, cell_idx_preview, ...]\n",
     "\n",
     "with RunDirectory(f\"{in_folder}/r{run:04d}/\", f\"*S{first_seq:05d}*\", _use_voview=False) as raw_dc:\n",
     "    # Reading RAW data for plotting.\n",
-    "    jf_raw = components.JUNGFRAU(raw_dc, detector_name=karabo_id).select_trains(\n",
-    "            np.s_[:plot_trains]\n",
-    "    )\n",
+    "    jf_raw = components.JUNGFRAU(\n",
+    "        raw_dc, detector_name=karabo_id, n_modules=nmods\n",
+    "        ).select_trains(np.s_[:plot_trains])\n",
     "\n",
-    "raw = jf_raw.get_array(\"data.adc\")[:, :, cell_idx_preview, ...].values\n",
-    "raw_train = (\n",
-    "    jf_raw.select_trains(by_id[[tid]])\n",
-    "    .get_array(\"data.adc\")[:, 0, cell_idx_preview, ...]\n",
-    "    .values\n",
-    ")\n",
+    "raw = jf_raw[\"data.adc\"].ndarray(module_gaps=True)[:, :, cell_idx_preview, ...]\n",
     "\n",
-    "gain = jf_raw.get_array(\"data.gain\")[:, :, cell_idx_preview, ...].values\n",
+    "gain = jf_raw[\"data.gain\"].ndarray(module_gaps=True)[:, :, cell_idx_preview, ...]\n",
     "gain_train_cells = (\n",
-    "    jf_raw.select_trains(by_id[[tid]]).get_array(\"data.gain\")[:, :, :, ...].values\n",
+    "    jf_raw.select_trains(by_id[[tid]], )[\"data.gain\"].ndarray(module_gaps=True)[:, :, :, ...]\n",
     ")\n",
     "step_timer.done_step(\"Prepared data for plotting\")"
    ]
@@ -845,13 +838,8 @@
     "print(f\"The per pixel mean of the first {raw.shape[1]} trains of the first sequence file\")\n",
     "\n",
     "fig, ax = plt.subplots(figsize=(18, 10))\n",
-    "raw_mean = keep_data_dims(\n",
-    "    np.mean(raw, axis=1),\n",
-    "    geom.expected_data_shape,\n",
-    "    available_modules_indices,\n",
-    "    det_mod_start\n",
-    ")\n",
-    "vmin, vmax = np.percentile(raw_mean, [5, 95])\n",
+    "raw_mean = np.nanmean(raw, axis=1)\n",
+    "vmin, vmax = np.nanpercentile(raw_mean, [5, 95])\n",
     "geom.plot_data_fast(\n",
     "    raw_mean,\n",
     "    ax=ax,\n",
@@ -879,12 +867,7 @@
     "print(f\"The per pixel mean of the first {corrected.shape[1]} trains of the first sequence file\")\n",
     "\n",
     "fig, ax = plt.subplots(figsize=(18, 10))\n",
-    "corrected_mean = keep_data_dims(\n",
-    "    np.nanmean(corrected, axis=1),\n",
-    "    geom.expected_data_shape,\n",
-    "    available_modules_indices,\n",
-    "    det_mod_start\n",
-    ")\n",
+    "corrected_mean = np.nanmean(corrected, axis=1)\n",
     "vmin, vmax = np.nanpercentile(corrected_mean, [5, 95])\n",
     "\n",
     "mean_plot_kwargs = dict(vmin=vmin, vmax=vmax)\n",
@@ -920,12 +903,7 @@
     "fig, ax = plt.subplots(figsize=(18, 10))\n",
     "corrected_masked = corrected.copy()\n",
     "corrected_masked[mask != 0] = np.nan\n",
-    "corrected_masked_mean = keep_data_dims(\n",
-    "    np.nanmean(corrected_masked, axis=1),\n",
-    "    geom.expected_data_shape,\n",
-    "    available_modules_indices,\n",
-    "    det_mod_start\n",
-    ")\n",
+    "corrected_masked_mean = np.nanmean(corrected_masked, axis=1)\n",
     "del corrected_masked\n",
     "\n",
     "if not strixel_sensor:\n",
@@ -1080,12 +1058,7 @@
     "display(Markdown((f\"#### The per pixel maximum of train {tid} of the GAIN data\")))\n",
     "\n",
     "fig, ax = plt.subplots(figsize=(18, 10))\n",
-    "gain_max = keep_data_dims(\n",
-    "    np.max(gain_train_cells, axis=(1, 2)),\n",
-    "    geom.expected_data_shape,\n",
-    "    available_modules_indices,\n",
-    "    det_mod_start\n",
-    ")\n",
+    "gain_max = np.max(gain_train_cells, axis=(1, 2))\n",
     "\n",
     "geom.plot_data_fast(\n",
     "    gain_max,\n",
-- 
GitLab