From 39ab37d34da237992c5ef2c89a94c65bee32d739 Mon Sep 17 00:00:00 2001
From: David Hammer <david.hammer@xfel.eu>
Date: Wed, 5 Apr 2023 21:23:29 +0200
Subject: [PATCH] Enable GPU acceleration of the BOZ correction determination

---
 ...is part I.a Correction determination.ipynb |  32 ++--
 src/toolbox_scs/routines/boz.py               | 137 +++++++++++++-----
 2 files changed, 122 insertions(+), 47 deletions(-)

diff --git a/doc/BOZ analysis part I.a Correction determination.ipynb b/doc/BOZ analysis part I.a Correction determination.ipynb
index 9eec20b..f36c2b5 100644
--- a/doc/BOZ analysis part I.a Correction determination.ipynb	
+++ b/doc/BOZ analysis part I.a Correction determination.ipynb	
@@ -172,6 +172,16 @@
     "params.dask_load_persistently()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dask.config.set(scheduler=\"single-threaded\")\n",
+    "params.use_gpu()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -212,7 +222,7 @@
     }
    ],
    "source": [
-    "pedestal = np.mean(dark)\n",
+    "pedestal = boz.ensure_on_host(np.mean(dark))\n",
     "pedestal"
    ]
   },
@@ -236,7 +246,8 @@
    ],
    "source": [
     "mean_th = (pedestal-12, pedestal+15)\n",
-    "f = boz.inspect_dark(params.arr_dark, mean_th=mean_th)\n",
+    "f = boz.inspect_dark(boz.ensure_on_host(params.arr_dark),\n",
+    "                     mean_th=mean_th)\n",
     "f.suptitle(f'p:{params.proposal} d:{params.darkrun}')\n",
     "fname = path + prefix + '-inspect_dark.png'\n",
     "f.savefig(fname, dpi=300) "
@@ -308,8 +319,8 @@
    "outputs": [],
    "source": [
     "data = boz.average_module(params.arr, dark=dark).compute()\n",
-    "pp = data.mean(axis=(1,2)) # pulseId resolved mean\n",
-    "dataM = data.mean(axis=0) # mean over pulseId"
+    "pp = boz.ensure_on_host(data.mean(axis=(1,2))) # pulseId resolved mean\n",
+    "dataM = boz.ensure_on_host(data.mean(axis=0)) # mean over pulseId"
    ]
   },
   {
@@ -396,9 +407,10 @@
     }
    ],
    "source": [
-    "h, f = boz.inspect_histogram(params.arr,\n",
-    "                             params.arr_dark,\n",
-    "                             mask=params.get_mask() #, extra_lines=True\n",
+    "h, f = boz.inspect_histogram(boz.ensure_on_host(params.arr),\n",
+    "                             boz.ensure_on_host(params.arr_dark),\n",
+    "                             mask=boz.ensure_on_host(params.get_mask())\n",
+    "                             #, extra_lines=True\n",
     "                         )\n",
     "f.suptitle(f'p:{params.proposal} r:{params.run} d:{params.darkrun}')\n",
     "f.savefig(path+prefix+'-histogram.png', dpi=300)"
@@ -495,7 +507,7 @@
     "res = boz.average_module(params.arr, dark=dark,\n",
     "                         ret='mean', mask=params.get_mask(), sat_roi=params.rois['sat'],\n",
     "                         sat_level=params.sat_level)\n",
-    "avg = res.compute().mean(axis=0)"
+    "avg = res.mean(axis=0).compute()"
    ]
   },
   {
@@ -642,7 +654,7 @@
    ],
    "source": [
     "ff = boz.compute_flat_field_correction(params.rois, params)\n",
-    "f = boz.inspect_plane_fitting(avg/ff, params.rois)\n",
+    "f = boz.inspect_plane_fitting(boz.ensure_on_host(avg)/ff, params.rois)\n",
     "f.savefig(path+prefix+'-inspect-withflatfield-refined.png', dpi=300)"
    ]
   },
@@ -1505,7 +1517,7 @@
     }
    ],
    "source": [
-    "f = boz.inspect_Fnl(params.get_Fnl())"
+    "f = boz.inspect_Fnl(boz.ensure_on_host(params.get_Fnl()))"
    ]
   },
   {
diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 33a936d..856751e 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -23,6 +23,12 @@ from extra_geom import DSSC_1MGeometry
 
 from toolbox_scs.routines.XAS import xas
 
+try:
+    import cupy as cp
+    _can_use_gpu = True
+except ModuleNotFoundError:
+    _can_use_gpu = False
+
 __all__ = [
     'parameters',
     'get_roi_pixel_pos',
@@ -79,6 +85,7 @@ class parameters():
         self.pixel_pos = _get_pixel_pos(self.module)
         self.gain = gain
 
+        self.mask = None
         self.mask_idx = None
         self.mean_th = (None, None)
         self.std_th = (None, None)
@@ -93,6 +100,8 @@ class parameters():
         self.ff_alpha = None
         self.ff_max_iter = None
 
+        self._using_gpu = False
+
         self.Fnl = None
         self.nl_alpha = None
         self.sat_level = None
@@ -115,6 +124,29 @@ class parameters():
         self.arr = self.arr.rechunk(('auto', -1, -1, -1))
         self.arr_dark = self.arr_dark.rechunk(('auto', -1, -1, -1))
 
+    def use_gpu(self):
+        assert _can_use_gpu, 'Failed to import cupy'
+        gpu_mem_gb = cp.cuda.Device().mem_info[1] / 2**30
+        if gpu_mem_gb < 30:
+            print(f'Warning: GPU memory ({gpu_mem_gb}GB) may be insufficient')
+        if self._using_gpu:
+            return
+        assert (
+            self.arr is not None and
+            self.arr_dark is not None
+        ), "Must load data before switching to GPU"
+        if self.mask is not None:
+            self.mask = cp.array(self.mask)
+        # moving full data to GPU
+        limit = 2**30
+        self.arr = da.array(
+            cp.array(self.arr.compute())
+        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
+        self.arr_dark = da.array(
+            cp.array(self.arr_dark.compute())
+        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
+        self._using_gpu = True
+
     def set_mask(self, arr):
         """Set mask of bad pixels.
 
@@ -133,6 +165,9 @@ class parameters():
                 mask[k[0], k[1]] = False
             self.mask = mask
 
+        if self._using_gpu:
+            self.mask = cp.array(self.mask)
+
     def get_mask(self):
         """Get the boolean array bad pixel of a DSSC module."""
         return self.mask
@@ -196,7 +231,10 @@ class parameters():
         if self.Fnl is None:
             return None
         else:
-            return np.array(self.Fnl)
+            if self._using_gpu:
+                return cp.array(self.Fnl)
+            else:
+                return np.array(self.Fnl)
 
     def save(self, path='./'):
         """Save the parameters as a JSON file.
@@ -300,6 +338,14 @@ class parameters():
 
         return f
 
+def ensure_on_host(arr):
+    # load data back from GPU - if it was on GPU
+    if hasattr(arr, "__cuda_array_interface__"):  # avoid importing CuPy
+        return arr.get()
+    elif isinstance(arr, (da.Array,)):
+        return arr.map_blocks(ensure_on_host)
+    return arr
+
 # Hexagonal pixels related function
 
 def _get_pixel_pos(module):
@@ -451,8 +497,8 @@ def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
     fig: matplotlib figure
     """
     # compute mean and std
-    dark_mean = arr.mean(axis=(0, 1)).compute()
-    dark_std = arr.std(axis=(0, 1)).compute()
+    dark_mean = ensure_on_host(arr.mean(axis=(0, 1)).compute())
+    dark_std = ensure_on_host(arr.std(axis=(0, 1)).compute())
 
     fig = plt.figure(figsize=(7, 2.7))
     gs = fig.add_gridspec(2, 4)
@@ -542,13 +588,13 @@ def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
 
     f = plt.figure(figsize=(6, 3))
     ax = plt.gca()
-    h = histogram_module(arr, mask=mask)
+    h = ensure_on_host(histogram_module(arr, mask=mask))
     Sum_h = np.sum(h)
     ax.plot(np.arange(2**9), h/Sum_h, marker='o',
             ms=3, markerfacecolor='none', lw=1)
 
     if arr_dark is not None:
-        hd = histogram_module(arr_dark, mask=mask)
+        hd = ensure_on_host(histogram_module(arr_dark, mask=mask))
         Sum_hd = np.sum(hd)
         ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
                 ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
@@ -805,7 +851,8 @@ def compute_flat_field_correction(rois, params, plot=False):
 
     if plot:
         f, ax = plt.subplots(1, 1, figsize=(6, 2))
-        img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
+        img = ax.pcolormesh(
+            np.flipud(flat_field[:, :256]), cmap='Greys_r')
         f.colorbar(img, ax=[ax], label='amplitude')
         ax.set_xlabel('px')
         ax.set_ylabel('px')
@@ -1086,9 +1133,9 @@ def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
     """
     params.set_flat_field(p)
     ff = compute_flat_field_correction(rois, params)
-    
-    data = process(np.arange(2**9), arr_dark, arr, tid, rois, mask, ff,
-        sat_level)
+
+    data = process(None, arr_dark, arr, tid, rois, mask, ff,
+        sat_level, params._using_gpu)
 
     # drop saturated shots
     d = data.where(data['sat_sat'] == False, drop=True)
@@ -1215,7 +1262,7 @@ def nl_lut(domain, dy):
 
 
 def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
-    sat_level=511):
+    sat_level=511, use_gpu=False):
     """Criteria for the non linear correction.
 
     Inputs
@@ -1239,8 +1286,8 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
     of the deviation from the ideal detector response.
     """
     Fmodel = nl_lut(domain, p)
-    data = process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field,
-        sat_level)
+    data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark,
+        arr, tid, rois, mask, flat_field, sat_level, use_gpu)
 
     # drop saturated shots
     d = data.where(data['sat_sat'] == False, drop=True)
@@ -1290,7 +1337,7 @@ def nl_fit(params, domain):
     ff = compute_flat_field_correction(params.rois, params)
 
     fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
-        fitrois, params.get_mask(), ff, params.sat_level)
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
 
     def fit_callback(x):
         if not hasattr(fit_callback, "counter"):
@@ -1469,13 +1516,15 @@ def inspect_correction(params, gain=None):
     if Fnl is None:
         Fnl = np.arange(2**9)
 
+    xp = np if not params._using_gpu else cp
     # compute all levels of correction
-    data = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
-        fitrois, params.get_mask(), np.ones_like(ff), params.sat_level)
-    data_ff = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
-        fitrois, params.get_mask(), ff, params.sat_level)
-    data_ff_nl = process(Fnl, params.arr_dark, params.arr,
-        params.tid, fitrois, params.get_mask(), ff, params.sat_level)
+    data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level,
+        params._using_gpu)
+    data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+    data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
+        fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
 
     # for conversion to nb of photons
     if gain is None:
@@ -1632,7 +1681,7 @@ def average_module(arr, dark=None, ret='mean',
     """
     # F_INL
     if F_INL is not None:
-        narr = arr.map_blocks(lambda x: F_INL[x])
+        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
     else:
         narr = arr
 
@@ -1640,11 +1689,25 @@ def average_module(arr, dark=None, ret='mean',
         narr = narr*mask
 
     if sat_roi is not None:
-        temp = (da.logical_not(da.any(
-                narr[:, :, sat_roi['yl']:sat_roi['yh'],
-                           sat_roi['xl']:sat_roi['xh']] >= sat_level,
-                               axis=[2, 3], keepdims=True)))
-        not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)
+        not_sat = da.repeat(
+            da.repeat(
+                da.all(
+                    narr[
+                        :,
+                        :,
+                        sat_roi["yl"] : sat_roi["yh"],
+                        sat_roi["xl"] : sat_roi["xh"],
+                    ]
+                    < sat_level,
+                    axis=[2, 3],
+                    keepdims=True,
+                ),
+                128,
+                axis=2,
+            ),
+            512,
+            axis=3,
+        )
 
     if dark is not None:
         narr = narr - dark
@@ -1661,7 +1724,7 @@ def average_module(arr, dark=None, ret='mean',
 
 
 def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
-                   flat_field=None, F_INL=None):
+                   flat_field=None, F_INL=None, use_gpu=False):
     """Process one module and extract roi intensity.
 
     Inputs
@@ -1682,7 +1745,7 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     """
     # F_INL
     if F_INL is not None:
-        narr = arr.map_blocks(lambda x: F_INL[x])
+        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
     else:
         narr = arr
 
@@ -1711,6 +1774,9 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     #     temp[::2, :, :] = flat_field[:, :]
     #    flat_field = temp
 
+    if use_gpu and flat_field is not None:
+        flat_field = cp.asarray(flat_field)
+
     # compute dark corrected ROI values
     v = {}
     for n in rois.keys():
@@ -1732,9 +1798,9 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     dims = ['trainId', 'pulseId']
     r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
     for n in rois.keys():
-        res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
+        res[n + '_sat'] = xr.DataArray(ensure_on_host(r_sat[n][:, :]),
                                        coords=r_coords, dims=dims)
-        res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)
+        res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims)
 
     for n in rois.keys():
         roi = rois[n]
@@ -1744,7 +1810,8 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     return res
 
 
-def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
+def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511,
+        use_gpu=False):
     """Process dark and run data with corrections.
 
     Inputs
@@ -1762,15 +1829,11 @@ def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
     roi extracted intensities
     """
     # dark process
-    res = average_module(arr_dark, F_INL=Fmodel)
-    dark = res.compute()
+    dark = average_module(arr_dark, F_INL=Fmodel).compute()
 
     # data process
-    proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
-                          flat_field=flat_field, F_INL=Fmodel)
-    data = proc.compute()
-
-    return data
+    return process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
+      flat_field=flat_field, F_INL=Fmodel, use_gpu=use_gpu).compute()
 
 def inspect_saturation(data, gain, Nbins=200):
     """Plot roi integrated histogram of the data with saturation
-- 
GitLab