From 078673bc09b22a25427a58eba9580dd641be8ebf Mon Sep 17 00:00:00 2001 From: ahmedk <karim.ahmed@xfel.eu> Date: Fri, 16 Aug 2024 22:05:15 +0200 Subject: [PATCH] feat: Improve fill_histogram into a cython function with 13.83x speed up for burst mode - Original function Single Cell: Average time: 1.879472 seconds Min time: 1.874208 seconds Max time: 1.882767 seconds - Cython function Single Cell: Average time: 0.603006 seconds Min time: 0.602453 seconds Max time: 0.603634 seconds Speedup: 3.12x - Original function Burst Mode: Average time: 17.192991 seconds Min time: 17.174598 seconds Max time: 17.211236 seconds - Cython function Burst Mode: Average time: 1.242232 seconds Min time: 1.230240 seconds Max time: 1.258492 seconds Speedup: 13.84x --- ...ngfrau_Create_Fit_Spectra_Histos_NBC.ipynb | 5 +- setup.py | 5 ++ src/cal_tools/jungfrau/jungfrau_ff.py | 38 -------------- src/cal_tools/jungfrau/jungfraualgs.pyx | 27 ++++++++++ tests/test_jungfrau_ff.py | 50 ++++++++----------- 5 files changed, 56 insertions(+), 69 deletions(-) create mode 100644 src/cal_tools/jungfrau/jungfraualgs.pyx diff --git a/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb b/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb index 9cbc44689..14d3be4db 100644 --- a/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb +++ b/notebooks/Jungfrau/Jungfrau_Create_Fit_Spectra_Histos_NBC.ipynb @@ -97,6 +97,7 @@ "import cal_tools.restful_config as rest_cfg\n", "from cal_tools.calcat_interface import JUNGFRAU_CalibrationData\n", "from cal_tools.jungfrau import jungfrau_ff\n", + "from cal_tools.jungfrau.jungfraualgs import fill_histogram\n", "from cal_tools.jungfrau.jungfraulib import JungfrauCtrl\n", "from cal_tools.step_timing import StepTimer\n", "from cal_tools.tools import calcat_creation_time" @@ -310,7 +311,7 @@ "outputs": [], "source": [ "h_n_bins = int((h_range[1] - h_range[0])/h_bins_s)\n", - "h_bins = np.linspace(h_range[0], h_range[1], h_n_bins)\n", + "h_bins = np.linspace(h_range[0], h_range[1], h_n_bins, dtype=np.float32)\n", "\n", "h_spectra = dict()\n", "edges = dict()\n", @@ -415,7 +416,7 @@ "\n", " with multiprocessing.Pool() as pool:\n", "\n", - " partial_fill = partial(jungfrau_ff.fill_histogram, h_bins)\n", + " partial_fill = partial(fill_histogram, h_bins)\n", " r_maps = pool.map(partial_fill, chunks)\n", "\n", " for i, r in enumerate(r_maps):\n", diff --git a/setup.py b/setup.py index b8ce52d6d..705fef5e8 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,11 @@ ext_modules = [ ["src/cal_tools/gotthard2/gotthard2algs.pyx"], include_dirs=[numpy.get_include()], ), + Extension( + "cal_tools.jungfrau.jungfraualgs", + ["src/cal_tools/jungfrau/jungfraualgs.pyx"], + include_dirs=[numpy.get_include()], + ), ] diff --git a/src/cal_tools/jungfrau/jungfrau_ff.py b/src/cal_tools/jungfrau/jungfrau_ff.py index b442166ef..101052345 100644 --- a/src/cal_tools/jungfrau/jungfrau_ff.py +++ b/src/cal_tools/jungfrau/jungfrau_ff.py @@ -90,44 +90,6 @@ def subtract_offset(offset_map, _inp): return imgs -def fill_histogram(h_bins, imgs): - """ - Fills an histogram with shape - (n_bins-1, memory_cells, n_rows, n_cols) - from input images. - - Args: - h_bins (list, float): the binning of the x-axis - imgs (np.ndarray): image data to histogram - (trains, memory_cells, row, col). - - Returns: histogram bin counts, bin edges. - """ - if not isinstance(imgs, np.ndarray): - raise TypeError("Expected imgs numpy ndarray type.") - - if imgs.ndim < 4: - raise ValueError("Expected 4D imgs array.") - - n_cells = imgs.shape[1] - n_rows = imgs.shape[2] - n_cols = imgs.shape[3] - - h_chk = np.zeros((len(h_bins)-1, n_cells, n_rows, n_cols), dtype=np.int32) - - e_chk = None - for cell in range(n_cells): - for row in range(n_rows): - for col in range(n_cols): - this_cell = imgs[:, cell, ...] - this_pix = this_cell[:, row, col] - _h, _e = np.histogram(this_pix, bins=h_bins) - h_chk[..., cell, row, col] += _h - if e_chk is None: - e_chk = np.array(_e) - return h_chk, e_chk - - # peak finder to find the first photon peak and/or the pedestal peak # can/should be replaced with smthg more efficient diff --git a/src/cal_tools/jungfrau/jungfraualgs.pyx b/src/cal_tools/jungfrau/jungfraualgs.pyx new file mode 100644 index 000000000..65ffa1fc9 --- /dev/null +++ b/src/cal_tools/jungfrau/jungfraualgs.pyx @@ -0,0 +1,27 @@ +import numpy as np +cimport numpy as np +cimport cython + +@cython.boundscheck(False) +@cython.wraparound(False) +def fill_histogram(np.ndarray[np.float32_t, ndim=1] h_bins, np.ndarray[np.float32_t, ndim=4] imgs): + cdef int n_trains = imgs.shape[0] + cdef int n_cells = imgs.shape[1] + cdef int n_rows = imgs.shape[2] + cdef int n_cols = imgs.shape[3] + cdef int n_bins = len(h_bins) - 1 + + cdef np.ndarray[np.int32_t, ndim=4] h_out = np.zeros((n_bins, n_cells, n_rows, n_cols), dtype=np.int32) + cdef int train, cell, row, col, bin + cdef double value + + for cell in range(n_cells): + for row in range(n_rows): + for col in range(n_cols): + for train in range(n_trains): + value = imgs[train, cell, row, col] + for bin in range(n_bins): + if h_bins[bin] <= value < h_bins[bin + 1]: + h_out[bin, cell, row, col] += 1 + break + return h_out, h_bins \ No newline at end of file diff --git a/tests/test_jungfrau_ff.py b/tests/test_jungfrau_ff.py index 7f67b2325..9159086ca 100644 --- a/tests/test_jungfrau_ff.py +++ b/tests/test_jungfrau_ff.py @@ -5,10 +5,9 @@ import pytest from cal_tools.jungfrau.jungfrau_ff import ( chunk_multi, - fill_histogram, _peak_position, ) - +from cal_tools.jungfrau.jungfraualgs import fill_histogram def test_peak_detection_correctness(): x = np.array([10, 20, 30, 40, 50, 60, 70, 80]) @@ -45,16 +44,16 @@ def test_distance_constraint(): @pytest.fixture def sample_chunk(): - return np.random.randint(0, 1000, size=(100, 1, 256, 64)) # 100 trains, 1 cell, 256x64 chunk + return np.random.rand(100, 1, 256, 64).astype(np.float32) # 100 trains, 1 cell, 256x64 chunk @pytest.fixture def sample_chunk_16_cells(): - return np.random.randint(0, 1000, size=(100, 16, 256, 64)) # 100 trains, 16 cells, 256x64 chunk + return np.random.rand(100, 16, 256, 64).astype(np.float32) # 100 trains, 16 cells, 256x64 chunk def test_fill_histogram_basic(sample_chunk): - h_bins = np.linspace(0, 1000, 101) # 100 bins + h_bins = np.linspace(0, 1000, 101, dtype=np.float32) # 100 bins hist, edges = fill_histogram(h_bins, sample_chunk) assert hist.shape == (100, 1, 256, 64) @@ -64,44 +63,44 @@ def test_fill_histogram_basic(sample_chunk): def test_fill_histogram_16_cells(sample_chunk_16_cells): - h_bins = np.linspace(0, 1000, 101) - hist, edges = fill_histogram(h_bins, sample_chunk_16_cells) + h_bins = np.linspace(0, 1000, 101, dtype=np.float32) + hist, _ = fill_histogram(h_bins, sample_chunk_16_cells) assert hist.shape == (100, 16, 256, 64) assert np.sum(hist) == np.prod(sample_chunk_16_cells.shape) def test_fill_histogram_single_train(): - chunk = np.random.randint(0, 100, size=(1, 1, 256, 64)) - h_bins = np.linspace(0, 100, 11) - hist, edges = fill_histogram(h_bins, chunk) + chunk = np.random.rand(1, 1, 256, 64).astype(np.float32) + h_bins = np.linspace(0, 100, 11, dtype=np.float32) + hist, _ = fill_histogram(h_bins, chunk) assert hist.shape == (10, 1, 256, 64) assert np.sum(hist) == np.prod(chunk.shape) def test_fill_histogram_single_bin(): - chunk = np.ones((10, 1, 256, 64)) - h_bins = np.array([0, 2]) - hist, edges = fill_histogram(h_bins, chunk) + chunk = np.ones((10, 1, 256, 64), dtype=np.float32) + h_bins = np.array([0, 2], dtype=np.float32) + hist, _ = fill_histogram(h_bins, chunk) assert hist.shape == (1, 1, 256, 64) assert np.all(hist == 10) def test_fill_histogram_float_data(): - chunk = np.random.rand(50, 1, 256, 64) - h_bins = np.linspace(0, 1, 11) - hist, edges = fill_histogram(h_bins, chunk) + chunk = np.random.rand(50, 1, 256, 64).astype(np.float32) + h_bins = np.linspace(0, 1, 11, dtype=np.float32) + hist, _ = fill_histogram(h_bins, chunk) assert hist.shape == (10, 1, 256, 64) assert np.sum(hist) == np.prod(chunk.shape) def test_fill_histogram_out_of_range(): - chunk = np.random.randint(-10, 110, size=(100, 1, 256, 64)) - h_bins = np.linspace(0, 100, 11) - hist, edges = fill_histogram(h_bins, chunk) + chunk = np.random.rand(100, 1, 256, 64).astype(np.float32) + h_bins = np.linspace(0, 100, 11, dtype=np.float32) + hist, _ = fill_histogram(h_bins, chunk) assert hist.shape == (10, 1, 256, 64) assert np.sum(hist) <= np.prod(chunk.shape) @@ -113,21 +112,14 @@ def test_fill_histogram_out_of_range(): (200, 1, 128, 32) ]) def test_fill_histogram_various_shapes(shape): - chunk = np.random.randint(0, 1000, size=shape) - h_bins = np.linspace(0, 1000, 101) - hist, edges = fill_histogram(h_bins, chunk) + chunk = np.random.rand(*shape).astype(np.float32) + h_bins = np.linspace(0, 1000, 101, dtype=np.float32) + hist, _ = fill_histogram(h_bins, chunk) assert hist.shape == (100, *shape[1:]) assert np.sum(hist) == np.prod(shape) -def test_fill_histogram_input_validation(): - - # Test invalid chunk shape - with pytest.raises(ValueError): - fill_histogram(np.linspace(0, 100, 11), np.zeros((10, 256, 64))) - - @pytest.fixture def sample_data(): # 100 trains, 1 memory cell, 1024x512 image -- GitLab