From f740e1371473981ba2c04a8f1a82478ebd6faded Mon Sep 17 00:00:00 2001
From: ahmedk <karim.ahmed@xfel.eu>
Date: Tue, 27 Aug 2024 08:48:25 +0200
Subject: [PATCH] fix: put back np.histogram-based fill_histogram, plus some
 refactors and fixes

---
 ...ngfrau_Create_Fit_Spectra_Histos_NBC.ipynb | 101 ++++++++++--------
 src/cal_tools/jungfrau/jungfrau_ff.py         |  84 ++++++++-------
 src/cal_tools/jungfrau/jungfraualgs.pyx       |  27 -----
 tests/test_jungfrau_ff.py                     |  17 +--
 4 files changed, 111 insertions(+), 118 deletions(-)
 delete mode 100644 src/cal_tools/jungfrau/jungfraualgs.pyx

diff --git a/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb b/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb
index 14d3be4db..65826a299 100644
--- a/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb
+++ b/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb
@@ -93,11 +93,11 @@
     "import pasha as psh\n",
     "from extra_data import RunDirectory\n",
     "from h5py import File as h5file\n",
+    "from tqdm import tqdm\n",
     "\n",
     "import cal_tools.restful_config as rest_cfg\n",
     "from cal_tools.calcat_interface import JUNGFRAU_CalibrationData\n",
     "from cal_tools.jungfrau import jungfrau_ff\n",
-    "from cal_tools.jungfrau.jungfraualgs import fill_histogram\n",
     "from cal_tools.jungfrau.jungfraulib import JungfrauCtrl\n",
     "from cal_tools.step_timing import StepTimer\n",
     "from cal_tools.tools import calcat_creation_time"
@@ -389,13 +389,13 @@
     "        offset_map = const_data[da][\"Offset10Hz\"]\n",
     "        run_folder = in_folder / f'r{run:04d}'\n",
     "        ## Offset subtraction & Histogram filling\n",
-    "        ### performs offset subtraction andchunk_trains fills the histogram\n",
+    "        ### performs offset subtraction and fills the histogram per chunk of trains\n",
     "        ### looping on individual files because single photon runs can be very large\n",
-    "        for dc_chunk in RunDirectory(\n",
-    "            run_folder, include=f\"*{da}*\"\n",
-    "        ).split_trains(trains_per_part=chunked_trains):\n",
+    "        chunks_list = list(RunDirectory(\n",
+    "            run_folder, include=f\"*{da}*\").split_trains(trains_per_part=chunked_trains))\n",
+    "        print(f\"Processing raw data and filling histogram in {len(chunks_list)} chunks\")\n",
     "\n",
-    "            trains = dc_chunk.get_array(det_src, 'data.trainId')\n",
+    "        for dc_chunk in chunks_list:\n",
     "            memcells = dc_chunk.get_array(\n",
     "                det_src,\n",
     "                'data.memoryCell',\n",
@@ -404,31 +404,28 @@
     "            adc = dc_chunk.get_array(det_src, 'data.adc', extra_dims=extra_dims).astype(np.float32)\n",
     "            gain = dc_chunk.get_array(det_src, 'data.gain', extra_dims=extra_dims)\n",
     "\n",
-    "            # gain = gain.where(gain < 2, other = 2).astype(np.int16)\n",
     "            step_timer.start()\n",
-    "\n",
+    "            # gain = gain.where(gain < 2, other = 2).astype(np.int16)\n",
     "            # Allocate shared arrays for corrected data. Used in `correct_train()`\n",
     "            data_corr = context.alloc(shape=adc.shape, dtype=np.float32)\n",
     "            context.map(correct_train, adc)\n",
     "            step_timer.done_step(\"correct_train\")\n",
+    "\n",
     "            step_timer.start()\n",
     "            chunks = jungfrau_ff.chunk_multi(data_corr, block_size)\n",
+    "            ch_inp = [(c, h_bins) for c in chunks]\n",
+    "            with multiprocessing.Pool(processes=min(n_cpus, len(chunks))) as pool:\n",
+    "                results = pool.starmap(jungfrau_ff.fill_histogram, ch_inp)\n",
     "\n",
-    "            with multiprocessing.Pool() as pool:\n",
+    "            for i, (h_chk, e_chk) in enumerate(results):\n",
+    "                if edges[da] is None:\n",
+    "                    edges[da] = e_chk\n",
     "\n",
-    "                partial_fill = partial(fill_histogram, h_bins)\n",
-    "                r_maps = pool.map(partial_fill, chunks)\n",
+    "                n_blocks_col = int(h_spectra[da].shape[-1]/block_size[1])\n",
+    "                irow = int(np.floor(i/n_blocks_col)) * block_size[0]\n",
+    "                icol = i%n_blocks_col * block_size[1]\n",
+    "                h_spectra[da][..., irow:irow+block_size[0], icol:icol+block_size[1]] += h_chk\n",
     "\n",
-    "                for i, r in enumerate(r_maps):\n",
-    "                    h_chk, e_chk = r\n",
-    "                    if edges[da] is None:\n",
-    "                        edges[da] = np.array(e_chk)\n",
-    "\n",
-    "                    n_blocks_col = int(h_spectra[da].shape[-1]/block_size[1])\n",
-    "                    irow = int(np.floor(i/n_blocks_col)) * block_size[0]\n",
-    "                    icol = i%n_blocks_col * block_size[1]\n",
-    "\n",
-    "                    h_spectra[da][..., irow:irow+block_size[0], icol:icol+block_size[1]] += h_chk\n",
     "            step_timer.done_step(\"Histogram created\")"
    ]
   },
@@ -458,7 +455,7 @@
     "for da in karabo_da:\n",
     "    # transpose h_spectra for the following cells.\n",
     "    h_spectra[da] = h_spectra[da]\n",
-    "    fout_h_path = f'{out_folder}/R{runs[0]:04d}_{proposal.upper()}_Gain_Spectra_{da}_Histo.h5'\n",
+    "    fout_h_path = out_folder/ f\"R{runs[0]:04d}_{proposal.upper()}_Gain_Spectra_{da}_Histo.h5\"\n",
     "    hists = h_spectra[da]\n",
     "    with h5file(fout_h_path, 'w') as fout_h:\n",
     "        print(f\"Saving histograms at {fout_h_path}.\")\n",
@@ -475,6 +472,39 @@
     "        fout_h.attrs[\"creation_time\"] = str(creation_time)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_and_fill_map(r_maps, shape, dtype=np.float32):\n",
+    "\n",
+    "    g0_map = np.zeros(shape, dtype=dtype)\n",
+    "    sigma_map = np.zeros(shape, dtype=dtype)\n",
+    "    chi2ndf_map = np.zeros(shape, dtype=dtype)\n",
+    "    alpha_map = np.zeros(shape, dtype=dtype)\n",
+    "    n_blocks_col = int(shape[-1] / block_size[1])\n",
+    "\n",
+    "    for i, (g0_chk, sigma_chk, chi2ndf_chk, alpha_chk) in tqdm(enumerate(r_maps)):\n",
+    "\n",
+    "        irow = int(np.floor(i / n_blocks_col)) * block_size[0]\n",
+    "        icol = i % n_blocks_col * block_size[1]\n",
+    "\n",
+    "        slice_obj = (\n",
+    "            ...,\n",
+    "            slice(irow, irow + block_size[0]),\n",
+    "            slice(icol, icol + block_size[1])\n",
+    "        )\n",
+    "\n",
+    "        g0_map[slice_obj] = g0_chk\n",
+    "        sigma_map[slice_obj] = sigma_chk\n",
+    "        chi2ndf_map[slice_obj] = chi2ndf_chk\n",
+    "        alpha_map[slice_obj] = alpha_chk\n",
+    "\n",
+    "    return g0_map, sigma_map, chi2ndf_map, alpha_map"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -511,41 +541,26 @@
     "        const_data[da][\"Noise10Hz\"],\n",
     "    )\n",
     "    print(\"starting spectra fit\")\n",
+    "\n",
     "    step_timer.start()\n",
     "    with multiprocessing.Pool() as pool:\n",
     "        r_maps = pool.map(partial_fit, chunks)\n",
     "    step_timer.done_step(\"r_maps calculation\")\n",
     "\n",
     "    step_timer.start()\n",
-    "    g0_map = np.zeros((memory_cells, sensor_size[0], sensor_size[1]), dtype=np.float32)\n",
-    "    sigma_map = np.zeros((memory_cells, sensor_size[0], sensor_size[1]), dtype=np.float32)\n",
-    "    chi2ndf_map = np.zeros((memory_cells, sensor_size[0], sensor_size[1]), dtype=np.float32)\n",
-    "    alpha_map = np.zeros((memory_cells, sensor_size[0], sensor_size[1]), dtype=np.float32)\n",
-    "    for i, r in enumerate(r_maps):\n",
-    "        g0_chk, sigma_chk, chi2ndf_chk, alpha_chk = r\n",
-    "\n",
-    "        n_blocks_col = int(g0_map.shape[-1] / block_size[1])\n",
-    "        irow = int(np.floor(i / n_blocks_col)) * block_size[0]\n",
-    "        icol = i % n_blocks_col * block_size[1]\n",
-    "\n",
-    "        g0_map[..., irow : irow + block_size[0], icol : icol + block_size[1]] = g0_chk\n",
-    "        sigma_map[..., irow : irow + block_size[0], icol : icol + block_size[1]] = sigma_chk\n",
-    "        chi2ndf_map[\n",
-    "            ..., irow : irow + block_size[0], icol : icol + block_size[1]\n",
-    "        ] = chi2ndf_chk\n",
-    "        alpha_map[..., irow : irow + block_size[0], icol : icol + block_size[1]] = alpha_chk\n",
+    "    map_shape = (memory_cells, sensor_size[0], sensor_size[1])\n",
+    "    g0_map, sigma_map, chi2ndf_map, alpha_map = create_and_fill_map(r_maps, map_shape)\n",
     "    step_timer.done_step(\"Finished fitting\")\n",
-    "    fout_path = f\"{out_folder}/{fout_temp.format(da)}\"\n",
     "\n",
     "    step_timer.start()\n",
+    "    fout_path = out_folder / fout_temp.format(da)\n",
     "    with h5file(fout_path, \"w\") as fout:\n",
     "        dset_chi2 = fout.create_dataset(\"chi2map\", data=np.transpose(chi2ndf_map))\n",
     "        dset_gmap_fit = fout.create_dataset(\"gainMap_fit\", data=np.transpose(g0_map))\n",
     "        dset_std = fout.create_dataset(\"sigmamap\", data=np.transpose(sigma_map))\n",
     "        dset_alpha = fout.create_dataset(\"alphamap\", data=np.transpose(alpha_map))\n",
-    "        dset_noi = fout.create_dataset(\n",
-    "            \"noise_map\",\n",
-    "            data=const_data[da][\"Noise10Hz\"])\n",
+    "        dset_noi = fout.create_dataset(\"noise_map\", data=const_data[da][\"Noise10Hz\"])\n",
+    "\n",
     "        fout.attrs[\"memory_cells\"] = memory_cells\n",
     "        fout.attrs[\"integration_time\"] = integration_time\n",
     "        fout.attrs[\"bias_voltage\"] = bias_voltage\n",
diff --git a/src/cal_tools/jungfrau/jungfrau_ff.py b/src/cal_tools/jungfrau/jungfrau_ff.py
index 101052345..e43b1272d 100644
--- a/src/cal_tools/jungfrau/jungfrau_ff.py
+++ b/src/cal_tools/jungfrau/jungfrau_ff.py
@@ -56,38 +56,41 @@ def chunk_multi(data, block_size):
     return chunks
 
 
-def subtract_offset(offset_map, _inp):
+def fill_histogram(imgs, h_bins):
     """
-    Perform offset subtraction on raw data.
-    Args:
-        offset_map (xarray, float): map with offset constants,
-            with shape (3, memory_cells, row, col).
-        _inp (list): input data as:
-            * _inp[0]: raw images, with shape (trains, memory_cells, row, col).
-            * _inp[1]: gain bit map, with shape (trains, memory_cells, row, col).
-    Return: offset subtracted images.
-    """
-    imgs = _inp[0]
-    gbit = _inp[1]
+    Fill a histogram with shape
+    (len(h_bins)-1, memory_cells, n_rows, n_cols)
+    from input images, one np.histogram per pixel and memory cell.
 
-    n_cells = imgs.shape[1]
+    Args:
+        imgs (np.ndarray): image data to histogram, with shape
+            (trains, memory_cells, row, col).
+        h_bins (list, float): bin edges of the x-axis, passed to np.histogram.
 
-    for cell in range(n_cells):
-        this_cell_gbit = gbit.sel(cell=cell)
+    Returns: histogram bin counts, bin edges.
+    """
+    if not isinstance(imgs, np.ndarray):
+        raise TypeError("Expected imgs numpy ndarray type.")
 
-        this_cell_off = offset_map[:, cell]
+    if imgs.ndim < 4:
+        raise ValueError("Expected 4D imgs array.")
 
-        _o = np.choose(
-            this_cell_gbit, (
-                np.expand_dims(this_cell_off[0], axis=0),
-                np.expand_dims(this_cell_off[1], axis=0),
-                np.expand_dims(this_cell_off[2], axis=0)
-            )
-        )
+    n_cells, n_rows, n_cols = imgs.shape[1:]
 
-        imgs.loc[dict(cell=cell)] -= _o
+    h_chk = np.zeros(
+        (len(h_bins)-1, n_cells, n_rows, n_cols),
+        dtype=np.int32)
 
-    return imgs
+    e_chk = None
+    for cell in range(n_cells):
+        for row in range(n_rows):
+            for col in range(n_cols):
+                this_pix = imgs[:, cell, row, col]
+                _h, _e = np.histogram(this_pix, bins=h_bins)
+                h_chk[..., cell, row, col] += _h
+                if e_chk is None:
+                    e_chk = np.array(_e)
+    return h_chk, e_chk
 
 
 # peak finder to find the first photon peak and/or the pedestal peak
@@ -391,23 +394,25 @@ def rebin_histo(h, x, r):
     return h_out, x_out
 
 
-def fit_histogram(x, _fit_func, n_sigma, rebin, ratio, noise, _inp):
+def fit_histogram(x, _fit_func, n_sigma, rebin, ratio, noise, histo):
     """
     wrap around function for fitting of histogram
     
-    arguments:
-    * x (list, float): - bin centers along x
-    * _fit_func (string): which function to use for fit
-         * CHARGE_SHARING: single peak with charge sharing tail
-         * CHARGE_SHARING_2: sum of two CHARGE_SHARING
-         * GAUSS: gaussian function
-    * n_sigma (int): to calculate threshold of the peak finder as thr = n_sigma * sigma0
-    * ratio (float): ratio parameter of the peak finder
-    * _input (list): contains the data to preform the fit
-         * _input[0]: histogram bin counts with shape (n_bins, memory_cells, row, col)
-         * _input[1]: noise map with shape (3, memory_cells, row, col)
-         
-    returns: map of peak values, map of peak variances, map of chi2/ndf, map of charge sharing parameter values
+    Args:
+        x (list, float): bin centers along x
+        _fit_func (string): which function to use for fit
+            - CHARGE_SHARING: single peak with charge sharing tail
+            - CHARGE_SHARING_2: sum of two CHARGE_SHARING
+            - GAUSS: gaussian function
+        n_sigma (int): to calculate threshold of the peak finder as thr = n_sigma * sigma0
+        ratio (float): ratio parameter of the peak finder
+        histo (ndarray): histogram bin counts with shape (n_bins, memory_cells, row, col)
+
+    Returns:
+        - map of peak values
+        - map of peak variances
+        - map of chi2/ndf
+        - map of charge sharing parameter values
     """
     _funcs = dict(
         CHARGE_SHARING=fit_charge_sharing, 
@@ -416,7 +421,6 @@ def fit_histogram(x, _fit_func, n_sigma, rebin, ratio, noise, _inp):
     )
     fit_func = _funcs[_fit_func]
 
-    histo = _inp[0]
     n_cells, n_rows, n_cols = histo.shape[1:]
 
     sigma0 = 15.
diff --git a/src/cal_tools/jungfrau/jungfraualgs.pyx b/src/cal_tools/jungfrau/jungfraualgs.pyx
deleted file mode 100644
index 65ffa1fc9..000000000
--- a/src/cal_tools/jungfrau/jungfraualgs.pyx
+++ /dev/null
@@ -1,27 +0,0 @@
-import numpy as np
-cimport numpy as np
-cimport cython
-
-@cython.boundscheck(False)
-@cython.wraparound(False)
-def fill_histogram(np.ndarray[np.float32_t, ndim=1] h_bins, np.ndarray[np.float32_t, ndim=4] imgs):
-    cdef int n_trains = imgs.shape[0]
-    cdef int n_cells = imgs.shape[1]
-    cdef int n_rows = imgs.shape[2]
-    cdef int n_cols = imgs.shape[3]
-    cdef int n_bins = len(h_bins) - 1
-
-    cdef np.ndarray[np.int32_t, ndim=4] h_out = np.zeros((n_bins, n_cells, n_rows, n_cols), dtype=np.int32)
-    cdef int train, cell, row, col, bin
-    cdef double value
-
-    for cell in range(n_cells):
-        for row in range(n_rows):
-            for col in range(n_cols):
-                for train in range(n_trains):
-                    value = imgs[train, cell, row, col]
-                    for bin in range(n_bins):
-                        if h_bins[bin] <= value < h_bins[bin + 1]:
-                            h_out[bin, cell, row, col] += 1
-                            break
-    return h_out, h_bins
\ No newline at end of file
diff --git a/tests/test_jungfrau_ff.py b/tests/test_jungfrau_ff.py
index 9159086ca..b464e9e99 100644
--- a/tests/test_jungfrau_ff.py
+++ b/tests/test_jungfrau_ff.py
@@ -5,9 +5,10 @@ import pytest
 
 from cal_tools.jungfrau.jungfrau_ff import (
     chunk_multi,
+    fill_histogram,
     _peak_position,
 )
-from cal_tools.jungfrau.jungfraualgs import fill_histogram
+
 
 def test_peak_detection_correctness():
     x = np.array([10, 20, 30, 40, 50, 60, 70, 80])
@@ -54,7 +55,7 @@ def sample_chunk_16_cells():
 
 def test_fill_histogram_basic(sample_chunk):
     h_bins = np.linspace(0, 1000, 101, dtype=np.float32)  # 100 bins
-    hist, edges = fill_histogram(h_bins, sample_chunk)
+    hist, edges = fill_histogram(sample_chunk, h_bins)
     
     assert hist.shape == (100, 1, 256, 64)
     assert edges.shape == (101,)
@@ -64,7 +65,7 @@ def test_fill_histogram_basic(sample_chunk):
 
 def test_fill_histogram_16_cells(sample_chunk_16_cells):
     h_bins = np.linspace(0, 1000, 101, dtype=np.float32)
-    hist, _ = fill_histogram(h_bins, sample_chunk_16_cells)
+    hist, _ = fill_histogram(sample_chunk_16_cells, h_bins)
     
     assert hist.shape == (100, 16, 256, 64)
     assert np.sum(hist) == np.prod(sample_chunk_16_cells.shape)
@@ -73,7 +74,7 @@ def test_fill_histogram_16_cells(sample_chunk_16_cells):
 def test_fill_histogram_single_train():
     chunk = np.random.rand(1, 1, 256, 64).astype(np.float32)
     h_bins = np.linspace(0, 100, 11, dtype=np.float32)
-    hist, _ = fill_histogram(h_bins, chunk)
+    hist, _ = fill_histogram(chunk, h_bins)
     
     assert hist.shape == (10, 1, 256, 64)
     assert np.sum(hist) == np.prod(chunk.shape)
@@ -82,7 +83,7 @@ def test_fill_histogram_single_train():
 def test_fill_histogram_single_bin():
     chunk = np.ones((10, 1, 256, 64), dtype=np.float32)
     h_bins = np.array([0, 2], dtype=np.float32)
-    hist, _ = fill_histogram(h_bins, chunk)
+    hist, _ = fill_histogram(chunk, h_bins)
     
     assert hist.shape == (1, 1, 256, 64)
     assert np.all(hist == 10)
@@ -91,7 +92,7 @@ def test_fill_histogram_single_bin():
 def test_fill_histogram_float_data():
     chunk = np.random.rand(50, 1, 256, 64).astype(np.float32)
     h_bins = np.linspace(0, 1, 11, dtype=np.float32)
-    hist, _ = fill_histogram(h_bins, chunk)
+    hist, _ = fill_histogram(chunk, h_bins)
     
     assert hist.shape == (10, 1, 256, 64)
     assert np.sum(hist) == np.prod(chunk.shape)
@@ -100,7 +101,7 @@ def test_fill_histogram_float_data():
 def test_fill_histogram_out_of_range():
     chunk = np.random.rand(100, 1, 256, 64).astype(np.float32)
     h_bins = np.linspace(0, 100, 11, dtype=np.float32)
-    hist, _ = fill_histogram(h_bins, chunk)
+    hist, _ = fill_histogram(chunk, h_bins)
     
     assert hist.shape == (10, 1, 256, 64)
     assert np.sum(hist) <= np.prod(chunk.shape)
@@ -114,7 +115,7 @@ def test_fill_histogram_out_of_range():
 def test_fill_histogram_various_shapes(shape):
     chunk = np.random.rand(*shape).astype(np.float32)
     h_bins = np.linspace(0, 1000, 101, dtype=np.float32)
-    hist, _ = fill_histogram(h_bins, chunk)
+    hist, _ = fill_histogram(chunk, h_bins)
     
     assert hist.shape == (100, *shape[1:])
     assert np.sum(hist) == np.prod(shape)
-- 
GitLab