From 7f8bfdce205e64cbb16e40aa51d6f6d226b29ff6 Mon Sep 17 00:00:00 2001
From: Nuno Duarte <duarten@max-exfl001.desy.de>
Date: Wed, 23 Nov 2022 17:10:06 +0100
Subject: [PATCH] use pasha to speed up data correction steps
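
Replace the serial chunked correction loop with pasha's thread-parallel
map over trains. Output arrays are pre-allocated with psh.alloc() and
each worker writes its train's result to a distinct index; the shared
histogram accumulators are still updated from all worker threads. The
same pattern is applied to the validation-train corrections.

The general pattern, as a minimal sketch (the source name, key and run
path below are placeholders, not the notebook's actual values):

    import numpy as np
    import pasha as psh
    from extra_data import RunDirectory

    src = 'DET/SOURCE:daqOutput'  # placeholder source
    key = 'data.image.pixels'     # placeholder key

    psh.set_default_context('threads', num_workers=35)

    run = RunDirectory('/path/to/run').select(src, key, require_all=True)
    out = psh.alloc(shape=run[src, key].shape, dtype=np.float32)

    def process_train(worker_id, index, train_id, data):
        # each train lands at a unique index, so workers never write
        # to the same slice and no locking is needed
        out[index] = data[src][key].astype(np.float32)

    psh.map(process_train, run)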

---
 .../Characterize_FlatFields_ePix100_NBC.ipynb | 375 ++++++++++--------
 1 file changed, 208 insertions(+), 167 deletions(-)

diff --git a/notebooks/ePix100/Characterize_FlatFields_ePix100_NBC.ipynb b/notebooks/ePix100/Characterize_FlatFields_ePix100_NBC.ipynb
index 893e4cf56..589d940b8 100644
--- a/notebooks/ePix100/Characterize_FlatFields_ePix100_NBC.ipynb
+++ b/notebooks/ePix100/Characterize_FlatFields_ePix100_NBC.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "7cf33cf7",
+   "id": "1c7e1c06",
    "metadata": {},
    "source": [
     "#  ePix100 Flat Field Characterization\n",
@@ -15,19 +15,19 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c96be80c",
+   "id": "632e811d",
    "metadata": {},
    "outputs": [],
    "source": [
     "in_folder = '/gpfs/exfel/exp/MID/202231/p900310/raw' # input folder, required\n",
-    "out_folder = '' # output folder, required\n",
+    "out_folder = '/gpfs/exfel/data/scratch/duarten/outputs' # output folder, required\n",
     "metadata_folder = ''  # Directory containing calibration_metadata.yml when run by xfel-calibrate\n",
     "sequences = [-1] # sequences to process, set to -1 for all, range allowed\n",
-    "run = 33 # which run to read data from, required\n",
+    "run = 29 # which run to read data from, required\n",
     "\n",
     "# Parameters for accessing the raw data.\n",
-    "karabo_id = \"MID_EXP_EPIX-1\" # karabo karabo_id\n",
-    "karabo_da = \"EPIX01\"  # data aggregators\n",
+    "karabo_id = \"MID_EXP_EPIX-2\" # karabo karabo_id\n",
+    "karabo_da = \"EPIX02\"  # data aggregators\n",
     "receiver_template = \"RECEIVER\" # detector receiver template for accessing raw data files\n",
     "path_template = 'RAW-R{:04d}-{}-S{{:05d}}.h5' # the template to use to access data\n",
     "instrument_source_template = '{}/DET/{}:daqOutput' # instrument detector data source in h5files\n",
@@ -74,7 +74,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "6e89ae94",
+   "id": "a7e217c5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -85,6 +85,7 @@
     "import matplotlib.pyplot as plt\n",
     "from matplotlib.colors import LogNorm\n",
     "import numpy as np\n",
+    "import pasha as psh\n",
     "from extra_data import RunDirectory, H5File\n",
     "from pathlib import Path\n",
     "from prettytable import PrettyTable\n",
@@ -112,7 +113,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b410e39c",
+   "id": "04dcdabd",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -123,12 +124,14 @@
     "prettyPlotting = True\n",
     "\n",
     "profiler = xprof.Profiler()\n",
-    "profiler.disable()"
+    "profiler.disable()\n",
+    "\n",
+    "step_timer = StepTimer()"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "7f2f9f6c",
+   "id": "1dca330a",
    "metadata": {},
    "source": [
     "## Load Data"
@@ -137,7 +140,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f93b80e6",
+   "id": "04be011d",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -164,7 +167,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ccf801b4",
+   "id": "41591fdc",
    "metadata": {
     "slideshow": {
      "slide_type": "-"
@@ -201,15 +204,15 @@
     "    raise IndexError(\"No sequence files available for the selected sequences.\")\n",
     "\n",
     "# Trains to be processed\n",
-    "trains = np.ndarray(0,dtype=int)\n",
-    "for seq in seq_files:\n",
-    "    seq_str = str(seq) \n",
-    "    n = int(seq_str[seq_str.rfind('-S0')+len('-S0'):seq_str.rfind('.h5')])\n",
-    "    seq_size = H5File(seq).get_data_counts(*pixels_src).size # last sequence might be smaller than seq0_size\n",
-    "    t = np.arange(n*seq0_size,n*seq0_size+seq_size)\n",
-    "    trains = np.append(trains,t)\n",
-    "    \n",
-    "trains = trains[:n_trains]\n",
+    "n_trains = run_dir.get_data_counts(*pixels_src).shape[0]\n",
+    "dshape = run_dir.select(*pixels_src)[pixels_src].shape\n",
+    "\n",
+    "if n_trains != dshape[0]:\n",
+    "    print(f\"Warning: {n_trains - dshape[0]} trains with empty data.\")\n",
+    "    n_trains = dshape[0]\n",
+    "\n",
+    "trains = np.arange(0,n_trains)\n",
+    "data_dc = run_dir.select(*pixels_src,require_all=True).select_trains(trains)\n",
     "\n",
     "print(f\"Reading from: \")\n",
     "[print(f'\\t{seq}') for seq in seq_files]\n",
@@ -221,7 +224,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "aa4b9002",
+   "id": "ad572c1e",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -263,7 +266,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "400b0e3d",
+   "id": "2f68865b",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -279,7 +282,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "7985697d",
+   "id": "50247fff",
    "metadata": {
     "tags": []
    },
@@ -290,7 +293,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "15bb48da",
+   "id": "dcc980a6",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -318,7 +321,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "0be1ba68",
+   "id": "83c6ee4a",
    "metadata": {},
    "source": [
     "## Instantiate calculators"
@@ -327,7 +330,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d1913cc2",
+   "id": "d2a20082",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -376,7 +379,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "4681991f",
+   "id": "d2a61263",
    "metadata": {},
    "source": [
     "## Correct data"
@@ -385,68 +388,86 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "85f34676",
+   "id": "6298f8d7",
    "metadata": {},
    "outputs": [],
    "source": [
-    "step_timer.start()\n",
-    "bins = np.arange(ADU_range[0],ADU_range[1],1)\n",
+    "bin_min = -50\n",
+    "bin_max = 800\n",
+    "bin_width = 1\n",
     "\n",
-    "hist_O=hist_CM=hist_CS=0\n",
-    "timer_O=timer_CM=timer_CS=0\n",
-    "data_singles = np.empty(data.swapaxes(0,-1).shape,dtype=int)\n",
-    "\n",
-    "chunk_size = 100 # Data is processed by chunks to avoid memory overload\n",
-    "chunk = 0\n",
-    "while chunk < data.shape[0]-1:\n",
-    "    \n",
-    "    prev_chunk = chunk\n",
-    "    chunk+=chunk_size\n",
-    "    if chunk > data.shape[0]: # last chunk may have different size\n",
-    "        chunk = data.shape[0]-1\n",
+    "bins = np.arange(bin_min,bin_max,bin_width)\n",
+    "hist = {'O': 0,'CM': 0,'CS': 0, 'S': 0}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "be0e8b03",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def correct_train(worker_id, index, train_id, dc):\n",
     "\n",
-    "    d = data[prev_chunk:chunk]\n",
+    "    d = dc[pixels_src[0]][pixels_src[1]].astype(np.float32)\n",
     "\n",
     "    # Offset correction\n",
-    "    t = time.time()\n",
     "    d -= const_data['Offset'].squeeze()\n",
-    "    hist_O += np.histogram(d.flatten(),bins=bins)[0]\n",
-    "    timer_O += time.time()-t\n",
-    "\n",
+    "    hist['O'] += np.histogram(d.flatten(),bins=bins)[0]\n",
+    "     \n",
     "    # Common Mode correction\n",
-    "    t = time.time()\n",
     "    d = d.swapaxes(0,-1)\n",
     "    d = cmCorrection_block.correct(d)\n",
     "    d = cmCorrection_col.correct(d)\n",
     "    d = cmCorrection_row.correct(d)\n",
-    "    hist_CM += np.histogram(d.flatten(),bins=bins)[0]\n",
-    "    timer_CM += time.time()-t\n",
+    "    d = d.swapaxes(0,-1)\n",
+    "    hist['CM'] += np.histogram(d.flatten(),bins=bins)[0]\n",
     "    \n",
     "    # Charge Sharing correction\n",
-    "    t = time.time()\n",
+    "    d = d.swapaxes(0,-1)\n",
     "    d, patterns = patternClassifier.classify(d)\n",
-    "    data_singles[:,:,prev_chunk:chunk],fs = patternSelector.select(d,patterns)\n",
-    "    hist_CS += np.histogram(d[d>0].flatten(),bins=bins)[0]\n",
-    "    timer_CS += time.time()-t\n",
-    "    data[prev_chunk:chunk] = d.swapaxes(0,-1)\n",
+    "    sing,fs = patternSelector.select(d,patterns)\n",
+    "    d = d.swapaxes(0,-1)\n",
+    "    hist['CS'] += np.histogram(d[d>0].flatten(),bins=bins)[0]\n",
+    "    hist['S'] += np.histogram(sing[sing>0].flatten(),bins=bins)[0]\n",
+    "    \n",
+    "    data_corr[index+prev_chunk] = d\n",
+    "    data_singles[index+prev_chunk] = sing.swapaxes(0,-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a80abef3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "step_timer.start()\n",
     "\n",
-    "    print(f'Corrected trains: {chunk} ({int(chunk/data.shape[0]*100)}%)',end='\\r')\n",
+    "chunk_size = 1000\n",
     "\n",
-    "data_singles = data_singles.swapaxes(0,-1)\n",
-    "hist_S = np.histogram(data_singles[data_singles>0].flatten(),bins=bins)[0] # histogram of single events\n",
+    "psh.set_default_context('threads', num_workers=35) # num_workers=35 was found to be optimal\n",
+    "data_corr = psh.alloc(shape=dshape, dtype=np.float32)\n",
+    "data_singles = psh.alloc(shape=dshape, dtype=int)\n",
     "\n",
-    "print('',end='\\r')\n",
-    "print(f'\\nCorrections applied:')\n",
-    "print(f'  Offset: {int(timer_O)} s')\n",
-    "print(f'  Common Mode: {int(timer_CM)} s')\n",
-    "print(f'  Charge Sharing: {int(timer_CS)} s')\n",
+    "chunk = 0\n",
+    "while chunk < dshape[0]-1:\n",
+    "    \n",
+    "    prev_chunk = chunk\n",
+    "    chunk+=chunk_size\n",
+    "    if chunk > dshape[0]: # last chunk may have different size\n",
+    "        chunk = dshape[0]-1\n",
+    "        \n",
+    "    psh.map(correct_train, data_dc.select_trains(np.arange(prev_chunk,chunk)))\n",
+    "        \n",
+    "    print(f'Corrected trains: {chunk} ({round(chunk/dshape[0]*100)}%)',end='\\r')\n",
     "\n",
     "step_timer.done_step('Corrected data. Elapsed Time')"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "1e7cb152",
+   "id": "53e766d3",
    "metadata": {},
    "source": [
     "## Plot histograms"
@@ -455,33 +476,29 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "14648d1a",
+   "id": "d40f9942",
    "metadata": {},
    "outputs": [],
    "source": [
-    "step_timer.start()\n",
-    "\n",
     "bins_c = bins[:-1]+np.diff(bins)[0]/2 # center of bins\n",
     "\n",
     "plt.figure(figsize=(12,8))\n",
-    "plt.plot(bins_c,hist_O, label='Offset corrected')\n",
-    "plt.plot(bins_c,hist_CM, label='Common Mode corrected')\n",
-    "plt.plot(bins_c,hist_CS, label='Charge Sharing corrected')\n",
-    "plt.plot(bins_c,hist_S, label='Singles')\n",
+    "plt.plot(bins_c,hist['O'], label='Offset corrected')\n",
+    "plt.plot(bins_c,hist['CM'], label='Common Mode corrected')\n",
+    "plt.plot(bins_c,hist['CS'], label='Charge Sharing corrected')\n",
+    "plt.plot(bins_c,hist['S'], label='Singles')\n",
     "plt.xlim(ADU_range)\n",
     "plt.yscale('log')\n",
     "plt.xlabel('ADU',fontsize=12)\n",
     "plt.title(f'{karabo_id} | {proposal} - r{run}', fontsize=14)\n",
     "plt.legend(fontsize=12);\n",
-    "plt.grid(ls=':')\n",
-    "\n",
-    "step_timer.done_step('Calculated histograms. Elapsed Time')"
+    "plt.grid(ls=':')"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "996ec108",
+   "id": "1fcb9f4c",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -511,7 +528,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "e5beb142",
+   "id": "4034fc52",
    "metadata": {},
    "source": [
     "## Flat-Field Statistics"
@@ -520,7 +537,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f6dd62a0",
+   "id": "2a8322be",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -530,19 +547,19 @@
     "    return A*np.exp(-(x-mu)**2/(2.*sigma**2))\n",
     "\n",
     "# rough initial estimate of fit parameters\n",
-    "fit_estimates = [np.max(hist_S),           # amplitude\n",
-    "                 bins[np.argmax(hist_S)],  # centroid\n",
-    "                 10]                       # sigma"
+    "fit_estimates = [np.max(hist['S']),           # amplitude\n",
+    "                 bins[np.argmax(hist['S'])],  # centroid\n",
+    "                 10]                          # sigma"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e23abd09",
+   "id": "870bcfb3",
    "metadata": {},
    "outputs": [],
    "source": [
-    "coeff, var_matrix = curve_fit(gauss, bins_c, hist_S, p0=fit_estimates)\n",
+    "coeff, var_matrix = curve_fit(gauss, bins_c, hist['S'], p0=fit_estimates)\n",
     "singles_mu = coeff[1]\n",
     "singles_sig = abs(coeff[2])\n",
     "ROI = np.round([singles_mu-N_sigma_interval*singles_sig, # region of interest to find first photopeak per pixel\n",
@@ -550,9 +567,9 @@
     "y_fit = gauss(bins_c, *coeff)\n",
     "\n",
     "plt.figure(figsize=(9,6))\n",
-    "plt.plot(bins_c,hist_S, 'k' , label = 'singles')\n",
-    "plt.plot(bins_c,y_fit, 'g--', label = 'gauss fit') \n",
-    "plt.ylim(1,max(hist_S)*1.5);\n",
+    "plt.plot(bins_c,hist['S'],'k',label = 'singles')\n",
+    "plt.plot(bins_c,y_fit,'g--',label = 'gauss fit') \n",
+    "plt.ylim(1,max(hist['S'])*1.5);\n",
     "plt.xlim(ADU_range)\n",
     "plt.vlines(coeff[1],0,plt.gca().get_ylim()[1],color='g',ls=':')\n",
     "\n",
@@ -578,7 +595,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ba77bc13",
+   "id": "6d822905",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -597,15 +614,14 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "888adca9",
+   "id": "a9ee83ef",
    "metadata": {
     "scrolled": false
    },
    "outputs": [],
    "source": [
-    "# Ignore bins that have less than 1% of counts compared to the bin with more counts\n",
     "mask_bins = np.unique(singles_per_pixel,return_counts=True)[1] > np.max(np.unique(singles_per_pixel,return_counts=True)[1])*.01\n",
-    "last_bin = np.max(np.unique(singles_per_pixel)[mask_bins])\n",
+    "last_bin = np.max(np.unique(singles_per_pixel)[mask_bins]) # xlim on bin that has less than 1% of max counts\n",
     "\n",
     "# Plot singles distribution\n",
     "fig = xana.heatmapPlot(\n",
@@ -633,7 +649,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "e8a591cf",
+   "id": "94762322",
    "metadata": {},
    "source": [
     "## Plot random sample pixels "
@@ -642,7 +658,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e02c7217",
+   "id": "a09d5627",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -692,7 +708,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "7ed693b6",
+   "id": "e3117832",
    "metadata": {},
    "source": [
     "## Fit single photon peaks per pixel"
@@ -701,10 +717,8 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "4ec68114",
-   "metadata": {
-    "scrolled": false
-   },
+   "id": "e6d1134f",
+   "metadata": {},
    "outputs": [],
    "source": [
     "step_timer.start()\n",
@@ -734,7 +748,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "aaedd476",
+   "id": "31465fa6",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -768,7 +782,6 @@
     "plt.ylim((1, coeff[0]*1.2))\n",
     "plt.show()\n",
     "\n",
-    "\n",
     "print('--------------------')\n",
     "print('Fit parameters:')\n",
     "print(f'  centroid = {np.round(coeff[1],3)}')\n",
@@ -778,7 +791,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "97a62e5a",
+   "id": "58329334",
    "metadata": {},
    "source": [
     "## Flat-Field Bad Pixels"
@@ -787,7 +800,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ba92087e",
+   "id": "4391a940",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -834,7 +847,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "f4efeefe",
+   "id": "50a4ceb9",
    "metadata": {},
    "source": [
     "## Relative Gain Map"
@@ -843,7 +856,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "24bdbef2",
+   "id": "a036c8a1",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -867,7 +880,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "6485378b",
+   "id": "cb0f0679",
    "metadata": {},
    "source": [
     "## Absolute Gain Conversion Constant"
@@ -876,14 +889,14 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ee55bccc",
+   "id": "c71b22d2",
    "metadata": {},
    "outputs": [],
    "source": [
     "step_timer.start()\n",
     "\n",
     "# Correct data with calculated gain map\n",
-    "data_gain_corrected = data*rel_gain_map\n",
+    "data_gain_corrected = data_corr*rel_gain_map\n",
     "\n",
     "h,ADU = np.histogram(data_gain_corrected.flatten(),\n",
     "                     bins=np.arange(BP_fit_threshold[0],BP_fit_threshold[1]).astype(int))\n",
@@ -902,7 +915,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ef97ebf2",
+   "id": "0bde8503",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -938,7 +951,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "cf95ebad",
+   "id": "81662158",
    "metadata": {},
    "source": [
     "## Gain Map Validation\n",
@@ -951,7 +964,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "13a0408f",
+   "id": "186ebc07",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -984,11 +997,15 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "16383412",
+   "id": "243de317",
    "metadata": {},
    "outputs": [],
    "source": [
     "if db_gain_map is not None:\n",
+    "    \n",
+    "    # Calculate gain conversion constant of DB gain map\n",
+    "    gain_conv_const_db = 1/np.median(db_gain_map[const_data['BadPixelsDark'].squeeze()>0])\n",
+    "    \n",
     "    # Correlate new and DB gain maps\n",
     "    plt.figure(figsize=(7,7))\n",
     "\n",
@@ -1011,36 +1028,59 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a2c1e6cb",
+   "id": "10599c7f",
    "metadata": {},
    "outputs": [],
    "source": [
+    "def correct_validation_train(worker_id, index, train_id, dc):\n",
+    "\n",
+    "    d = dc[pixels_src[0]][pixels_src[1]].astype(np.float32)\n",
+    "\n",
+    "    # Offset correction\n",
+    "    d -= const_data['Offset'].squeeze()\n",
+    "\n",
+    "    # Common Mode correction\n",
+    "    d = d.swapaxes(0,-1)\n",
+    "    d = cmCorrection_block.correct(d)\n",
+    "    d = cmCorrection_col.correct(d)\n",
+    "    d = cmCorrection_row.correct(d)\n",
+    "    d = d.swapaxes(0,-1)\n",
+    "\n",
+    "    # Relative Gain correction\n",
+    "    d_new_map = d*rel_gain_map\n",
+    "    if db_gain_map is not None:\n",
+    "        d_db_map  = d*db_gain_map*gain_conv_const_db\n",
+    "\n",
+    "    # Charge Sharing correction\n",
+    "    d, patterns = patternClassifier.classify(d.swapaxes(0,-1))\n",
+    "    FF_data[index] = d.swapaxes(0,-1) # no gain correction\n",
+    "    \n",
+    "    d_new_map, patterns = patternClassifier.classify(d_new_map.swapaxes(0,-1))\n",
+    "    FF_data_new_map[index] = d_new_map.swapaxes(0,-1) # gain correction with new gain map\n",
+    "    \n",
+    "    if db_gain_map is not None:\n",
+    "        d_db_map, patterns = patternClassifier.classify(d_db_map.swapaxes(0,-1))\n",
+    "        FF_data_db_map[index] = d_db_map.swapaxes(0,-1) # gain correction with DB gain map"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4d02f5ff",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Correct validation trains\n",
     "step_timer.start()\n",
     "\n",
     "N_validation_trains = 1000\n",
-    "data_dc = run_dir.select(*pixels_src, require_all=True).select_trains(trains[:N_validation_trains])\n",
-    "FF_data = data_dc[pixels_src].ndarray().astype(np.float16)\n",
-    "\n",
-    "# Offset_correction\n",
-    "FF_data -= const_data['Offset'].squeeze()\n",
-    "\n",
-    "# Common Mode correction\n",
-    "FF_data = FF_data.swapaxes(0,-1)\n",
-    "FF_data = cmCorrection_block.correct(FF_data)\n",
-    "FF_data = cmCorrection_col.correct(FF_data)\n",
-    "FF_data = cmCorrection_row.correct(FF_data)\n",
-    "FF_data = FF_data.swapaxes(0,-1)\n",
-    "\n",
-    "# Relative Gain & Charge Sharing correction\n",
-    "FF_data_new_map = FF_data*abs_gain_map*gain_conv_const\n",
-    "FF_data_new_map,pat = patternClassifier.classify(FF_data_new_map.swapaxes(0,-1))\n",
+    "\n",
+    "FF_data = psh.alloc(shape=(N_validation_trains,dshape[1],dshape[2]), dtype=np.float32)\n",
+    "FF_data_new_map = psh.alloc(shape=(N_validation_trains,dshape[1],dshape[2]), dtype=np.float32)\n",
     "if db_gain_map is not None:\n",
-    "    gain_conv_const_db = 1/np.median(db_gain_map[const_data['BadPixelsDark'].squeeze()>0])\n",
-    "    FF_data_db_map = FF_data*db_gain_map*gain_conv_const_db\n",
-    "    FF_data_db_map,pat = patternClassifier.classify(FF_data_db_map.swapaxes(0,-1))\n",
-    "    \n",
-    "# Charge Sharing correction (without gain correction)    \n",
-    "FF_data,pat = patternClassifier.classify(FF_data.swapaxes(0,-1))\n",
+    "    FF_data_db_map = psh.alloc(shape=(N_validation_trains,dshape[1],dshape[2]), dtype=np.float32)\n",
+    "\n",
+    "psh.map(correct_validation_train, data_dc.select_trains(trains[:N_validation_trains]))\n",
     "\n",
     "step_timer.done_step('Corrected evaluation data. Elapsed Time')"
    ]
@@ -1048,10 +1088,11 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "bc812883",
+   "id": "45c26d8e",
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Calculate histograms\n",
     "bins_FF = np.arange(-50,800)\n",
     "FF_hist_CS = np.histogram(FF_data[FF_data>0].flatten(),bins=bins_FF)[0]\n",
     "\n",
@@ -1082,13 +1123,13 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "73fe5ba3",
+   "id": "0c424437",
    "metadata": {
     "scrolled": false
    },
    "outputs": [],
    "source": [
-    "N_peaks = 6\n",
+    "N_peaks = 4\n",
     "sigma_tol = 2 # sigma tolerance to show in gauss fit\n",
     "\n",
     "# Ignore split events below primary energy threshold\n",
@@ -1156,7 +1197,7 @@
     "            plt.yscale('log')    \n",
     "            plt.xlabel('keV',fontsize=12)\n",
     "            plt.xlim(left=0)\n",
-    "            plt.ylim(1,ylim_top)\n",
+    "            plt.ylim(.1,ylim_top)\n",
     "\n",
     "            # Remove repeated entries from legend\n",
     "            handles, labels = plt.gca().get_legend_handles_labels()\n",
@@ -1185,7 +1226,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "efade482",
+   "id": "26e07868",
    "metadata": {},
    "source": [
     "## Linearity Analysis"
@@ -1194,7 +1235,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d0098fd3",
+   "id": "462dbee2",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1203,16 +1244,16 @@
     "plt.subplot(1,2,1)\n",
     "plt.plot(peaks,[peak_energy*p for p in peaks], '-', c='k', label='expected')\n",
     "plt.plot(peaks,[peak_energy*p for p in peaks]+np.array(abs_dev[:N_peaks]), 'o', c='b')\n",
-    "new_fit= np.polyfit(peaks,[peak_energy*p for p in peaks]+np.array(abs_dev[:N_peaks]),1)\n",
+    "fit_coeffs= np.polyfit(peaks,[peak_energy*p for p in peaks]+np.array(abs_dev[:N_peaks]),1)\n",
     "\n",
-    "plt.plot(peaks,new_fit[0]*peaks+new_fit[1], '--', c='b', label='New gain map')\n",
+    "plt.plot(peaks,fit_coeffs[0]*peaks+fit_coeffs[1], '--', c='b', label='New gain map')\n",
     "str_theo  = f'$a_1$={peak_energy :,.4f}, $a_0$=0'\n",
-    "str_new = f'$a_1$={new_fit[0]:,.4f}, $a_0$={new_fit[1]:,.4f}'\n",
+    "str_new = f'$a_1$={fit_coeffs[0]:,.4f}, $a_0$={fit_coeffs[1]:,.4f}'\n",
     "plt.annotate(s=str_theo,xy=(.36,.94),xycoords='axes fraction',fontsize=11,bbox=dict(facecolor='k',alpha=.2,pad=1))\n",
     "plt.annotate(s=str_new ,xy=(.36,.88),xycoords='axes fraction',fontsize=11,bbox=dict(facecolor='b',alpha=.2,pad=1))\n",
     "\n",
-    "xx = np.arange(0,101)\n",
-    "y_fit_new = new_fit[0]*xx+new_fit[1] # extrapolation for 100 photons\n",
+    "xx = np.arange(1,100,.1) # in photons\n",
+    "y_fit_new = fit_coeffs[0]*xx+fit_coeffs[1] # extrapolation for 100 photons\n",
     "\n",
     "plt.xticks(peaks)\n",
     "plt.title(f'Linearity ({karabo_id} | {proposal} - r{run})')\n",
@@ -1222,11 +1263,11 @@
     "plt.grid(ls=':')\n",
     "\n",
     "plt.subplot(1,2,2)\n",
-    "dev_new = abs(y_fit_new-(peak_energy*xx))/(peak_energy*xx)*100\n",
-    "plt.plot(xx,dev_new,c='b', label='New gain map')\n",
+    "dev_new = (y_fit_new-(peak_energy*xx))/(peak_energy*xx)*100\n",
+    "plt.plot(xx*peak_energy,dev_new,c='b', label='New gain map')\n",
     "plt.xscale('log')\n",
-    "plt.xlim(1,100)\n",
-    "plt.xlabel('# Photons')\n",
+    "plt.xlim(right=100)\n",
+    "plt.xlabel('Energy (keV)')\n",
     "plt.ylabel('Linearity Deviation (%)')\n",
     "plt.title(f'Linearity extrapolation ({karabo_id} | {proposal} - r{run})')\n",
     "plt.grid(ls=':',which='both')\n",
@@ -1243,8 +1284,8 @@
     "    plt.annotate(s=str_db  ,xy=(.36,.82),xycoords='axes fraction',fontsize=11,bbox=dict(facecolor='r',alpha=.2,pad=1))\n",
     "\n",
     "    plt.subplot(1,2,2)\n",
-    "    dev_db = abs(y_fit_db-(peak_energy*xx))/(peak_energy*xx)*100\n",
-    "    plt.plot(xx,dev_db,c='r', label='DB gain map')\n",
+    "    dev_db = (y_fit_db-(peak_energy*xx))/(peak_energy*xx)*100\n",
+    "    plt.plot(xx*peak_energy,dev_db,c='r', label='DB gain map')\n",
     "\n",
     "plt.subplot(1,2,1)\n",
     "leg = plt.legend(fontsize=12)\n",
@@ -1254,7 +1295,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "39554980",
+   "id": "fbf45dac",
    "metadata": {},
    "source": [
     "## Energy Resolution Analysis"
@@ -1263,14 +1304,13 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "cc661463",
+   "id": "6d19c0b6",
    "metadata": {},
    "outputs": [],
    "source": [
     "def power_function(x,*p):\n",
     "    a,b,c = p\n",
     "    return a*x**b + c\n",
-    "\n",
     "# rough initial estimate of fit parameters\n",
     "fit_estimates = [20,-.5,0]"
    ]
@@ -1278,36 +1318,37 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f29ff449",
+   "id": "0bff8c08",
    "metadata": {},
    "outputs": [],
    "source": [
     "# Linearity of the visualized peaks\n",
     "plt.figure(figsize=(8,6))\n",
-    "plt.plot(peaks,E_res[:N_peaks], 'o', c='b', label='New gain map')\n",
-    "\n",
-    "xx = np.arange(0,10,.1)\n",
-    "coeff,param = curve_fit(power_function,peaks,E_res[:N_peaks],p0=fit_estimates)\n",
-    "new_fit = power_function(xx,*coeff)\n",
-    "\n",
-    "plt.plot(xx,new_fit, '--', c='b')\n",
     "\n",
+    "xx = np.arange(0,50,.1)\n",
     "if db_gain_map is not None:\n",
-    "    plt.plot(peaks,E_res[N_peaks:], 'o', c='r', label='DB gain map')\n",
+    "    plt.plot(peaks*peak_energy,E_res[N_peaks:], 'o', c='r', label='DB gain map')\n",
+    "    coeff,param = curve_fit(power_function,peaks*peak_energy,E_res[N_peaks:],p0=fit_estimates)\n",
+    "    power_fit = power_function(xx,*coeff)\n",
+    "    plt.plot(xx,power_fit, '--', c='r')\n",
+    "\n",
+    "plt.plot(peaks*peak_energy,E_res[:N_peaks], 'o', c='b', label='New gain map')\n",
+    "coeff,param = curve_fit(power_function,peaks*peak_energy,E_res[:N_peaks],p0=fit_estimates)\n",
+    "power_fit = power_function(xx,*coeff)\n",
+    "plt.plot(xx,power_fit, '--', c='b')\n",
     "\n",
     "plt.title(f'Energy Resolution ({karabo_id} | {proposal} - r{run})')\n",
-    "plt.xlabel('# Photons')\n",
+    "plt.xlabel('Energy (keV)')\n",
     "plt.ylabel('Energy Resolution (%)')\n",
     "plt.legend(fontsize=12)\n",
-    "plt.xticks(xx*10)\n",
-    "plt.xlim(1,10)\n",
+    "plt.xlim(1,np.ceil(xx[-1]))\n",
     "plt.ylim(0,30)\n",
     "plt.grid(ls=':')"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "43336fd9",
+   "id": "a0d24089",
    "metadata": {},
    "source": [
     "## Calibration Constants DB\n",
@@ -1317,7 +1358,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "78f37290",
+   "id": "1756684b",
    "metadata": {},
    "outputs": [],
    "source": [
-- 
GitLab