From 39ab37d34da237992c5ef2c89a94c65bee32d739 Mon Sep 17 00:00:00 2001
From: David Hammer <david.hammer@xfel.eu>
Date: Wed, 5 Apr 2023 21:23:29 +0200
Subject: [PATCH] Enable GPU acceleration of the BOZ correction determination

---
 ...is part I.a Correction determination.ipynb |  32 ++--
 src/toolbox_scs/routines/boz.py               | 137 +++++++++++++-----
 2 files changed, 122 insertions(+), 47 deletions(-)

diff --git a/doc/BOZ analysis part I.a Correction determination.ipynb b/doc/BOZ analysis part I.a Correction determination.ipynb
index 9eec20b..f36c2b5 100644
--- a/doc/BOZ analysis part I.a Correction determination.ipynb
+++ b/doc/BOZ analysis part I.a Correction determination.ipynb
@@ -172,6 +172,16 @@
     "params.dask_load_persistently()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dask.config.set(scheduler=\"single-threaded\")\n",
+    "params.use_gpu()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -212,7 +222,7 @@
     }
    ],
    "source": [
-    "pedestal = np.mean(dark)\n",
+    "pedestal = boz.ensure_on_host(np.mean(dark))\n",
     "pedestal"
    ]
   },
@@ -236,7 +246,8 @@
    ],
    "source": [
     "mean_th = (pedestal-12, pedestal+15)\n",
-    "f = boz.inspect_dark(params.arr_dark, mean_th=mean_th)\n",
+    "f = boz.inspect_dark(boz.ensure_on_host(params.arr_dark),\n",
+    "                     mean_th=mean_th)\n",
     "f.suptitle(f'p:{params.proposal} d:{params.darkrun}')\n",
     "fname = path + prefix + '-inspect_dark.png'\n",
     "f.savefig(fname, dpi=300) "
@@ -308,8 +319,8 @@
    "outputs": [],
    "source": [
     "data = boz.average_module(params.arr, dark=dark).compute()\n",
-    "pp = data.mean(axis=(1,2)) # pulseId resolved mean\n",
-    "dataM = data.mean(axis=0) # mean over pulseId"
+    "pp = boz.ensure_on_host(data.mean(axis=(1,2))) # pulseId resolved mean\n",
+    "dataM = boz.ensure_on_host(data.mean(axis=0)) # mean over pulseId"
    ]
   },
   {
@@ -396,9 +407,10 @@
     }
    ],
    "source": [
-    "h, f = boz.inspect_histogram(params.arr,\n",
-    "                             params.arr_dark,\n",
-    "                             mask=params.get_mask() #, extra_lines=True\n",
+    "h, f = boz.inspect_histogram(boz.ensure_on_host(params.arr),\n",
+    "                             boz.ensure_on_host(params.arr_dark),\n",
+    "                             mask=boz.ensure_on_host(params.get_mask())\n",
+    "                             #, extra_lines=True\n",
     "                             )\n",
     "f.suptitle(f'p:{params.proposal} r:{params.run} d:{params.darkrun}')\n",
     "f.savefig(path+prefix+'-histogram.png', dpi=300)"
@@ -495,7 +507,7 @@
     "res = boz.average_module(params.arr, dark=dark,\n",
     "                         ret='mean', mask=params.get_mask(), sat_roi=params.rois['sat'],\n",
     "                         sat_level=params.sat_level)\n",
-    "avg = res.compute().mean(axis=0)"
+    "avg = res.mean(axis=0).compute()"
    ]
   },
   {
@@ -642,7 +654,7 @@
    ],
    "source": [
     "ff = boz.compute_flat_field_correction(params.rois, params)\n",
-    "f = boz.inspect_plane_fitting(avg/ff, params.rois)\n",
+    "f = boz.inspect_plane_fitting(boz.ensure_on_host(avg)/ff, params.rois)\n",
     "f.savefig(path+prefix+'-inspect-withflatfield-refined.png', dpi=300)"
    ]
   },
@@ -1505,7 +1517,7 @@
     }
    ],
    "source": [
-    "f = boz.inspect_Fnl(params.get_Fnl())"
+    "f = boz.inspect_Fnl(boz.ensure_on_host(params.get_Fnl()))"
    ]
   },
   {
diff --git a/src/toolbox_scs/routines/boz.py b/src/toolbox_scs/routines/boz.py
index 33a936d..856751e 100644
--- a/src/toolbox_scs/routines/boz.py
+++ b/src/toolbox_scs/routines/boz.py
@@ -23,6 +23,12 @@ from extra_geom import DSSC_1MGeometry
 
 from toolbox_scs.routines.XAS import xas
 
+try:
+    import cupy as cp
+    _can_use_gpu = True
+except ModuleNotFoundError:
+    _can_use_gpu = False
+
 __all__ = [
     'parameters',
     'get_roi_pixel_pos',
@@ -79,6 +85,7 @@ class parameters():
         self.module = module
         self.pixel_pos = _get_pixel_pos(self.module)
         self.gain = gain
+        self.mask = None
         self.mask_idx = None
         self.mean_th = (None, None)
         self.std_th = (None, None)
@@ -93,6 +100,8 @@ class parameters():
         self.ff_alpha = None
         self.ff_max_iter = None
 
+        self._using_gpu = False
+
         self.Fnl = None
         self.nl_alpha = None
         self.sat_level = None
@@ -115,6 +124,29 @@ class parameters():
         self.arr = self.arr.rechunk(('auto', -1, -1, -1))
         self.arr_dark = self.arr_dark.rechunk(('auto', -1, -1, -1))
 
+    def use_gpu(self):
+        assert _can_use_gpu, 'Failed to import cupy'
+        gpu_mem_gb = cp.cuda.Device().mem_info[1] / 2**30
+        if gpu_mem_gb < 30:
+            print(f'Warning: GPU memory ({gpu_mem_gb}GB) may be insufficient')
+        if self._using_gpu:
+            return
+        assert (
+            self.arr is not None and
+            self.arr_dark is not None
+        ), "Must load data before switching to GPU"
+        if self.mask is not None:
+            self.mask = cp.array(self.mask)
+        # moving full data to GPU
+        limit = 2**30
+        self.arr = da.array(
+            cp.array(self.arr.compute())
+        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
+        self.arr_dark = da.array(
+            cp.array(self.arr_dark.compute())
+        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
+        self._using_gpu = True
+
     def set_mask(self, arr):
         """Set mask of bad pixels.
 
@@ -133,6 +165,9 @@ class parameters():
             mask[k[0], k[1]] = False
         self.mask = mask
 
+        if self._using_gpu:
+            self.mask = cp.array(self.mask)
+
     def get_mask(self):
         """Get the boolean array bad pixel of a DSSC module."""
         return self.mask
@@ -196,7 +231,10 @@ class parameters():
         if self.Fnl is None:
             return None
         else:
-            return np.array(self.Fnl)
+            if self._using_gpu:
+                return cp.array(self.Fnl)
+            else:
+                return np.array(self.Fnl)
 
     def save(self, path='./'):
         """Save the parameters as a JSON file.
@@ -300,6 +338,14 @@ class parameters():
         return f
 
 
+def ensure_on_host(arr):
+    # load data back from GPU - if it was on GPU
+    if hasattr(arr, "__cuda_array_interface__"):  # avoid importing CuPy
+        return arr.get()
+    elif isinstance(arr, (da.Array,)):
+        return arr.map_blocks(ensure_on_host)
+    return arr
+
 # Hexagonal pixels related function
 
 def _get_pixel_pos(module):
@@ -451,8 +497,8 @@ def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
         fig: matplotlib figure
     """
     # compute mean and std
-    dark_mean = arr.mean(axis=(0, 1)).compute()
-    dark_std = arr.std(axis=(0, 1)).compute()
+    dark_mean = ensure_on_host(arr.mean(axis=(0, 1)).compute())
+    dark_std = ensure_on_host(arr.std(axis=(0, 1)).compute())
 
     fig = plt.figure(figsize=(7, 2.7))
     gs = fig.add_gridspec(2, 4)
@@ -542,13 +588,13 @@ def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
     f = plt.figure(figsize=(6, 3))
     ax = plt.gca()
 
-    h = histogram_module(arr, mask=mask)
+    h = ensure_on_host(histogram_module(arr, mask=mask))
     Sum_h = np.sum(h)
     ax.plot(np.arange(2**9), h/Sum_h, marker='o',
             ms=3, markerfacecolor='none', lw=1)
 
     if arr_dark is not None:
-        hd = histogram_module(arr_dark, mask=mask)
+        hd = ensure_on_host(histogram_module(arr_dark, mask=mask))
         Sum_hd = np.sum(hd)
         ax.plot(np.arange(2**9), hd/Sum_hd, marker='o',
                 ms=3, markerfacecolor='none', lw=1, c='k', alpha=.5)
@@ -805,7 +851,8 @@ def compute_flat_field_correction(rois, params, plot=False):
 
     if plot:
         f, ax = plt.subplots(1, 1, figsize=(6, 2))
-        img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
+        img = ax.pcolormesh(
+            np.flipud(flat_field[:, :256]), cmap='Greys_r')
         f.colorbar(img, ax=[ax], label='amplitude')
         ax.set_xlabel('px')
         ax.set_ylabel('px')
@@ -1086,9 +1133,9 @@ def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
     """
     params.set_flat_field(p)
     ff = compute_flat_field_correction(rois, params)
-
-    data = process(np.arange(2**9), arr_dark, arr, tid, rois, mask, ff,
-                   sat_level)
+
+    data = process(None, arr_dark, arr, tid, rois, mask, ff,
+                   sat_level, params._using_gpu)
 
     # drop saturated shots
     d = data.where(data['sat_sat'] == False, drop=True)
@@ -1215,7 +1262,7 @@ def nl_lut(domain, dy):
 
 
 def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
-            sat_level=511):
+            sat_level=511, use_gpu=False):
     """Criteria for the non linear correction.
 
     Inputs
@@ -1239,8 +1286,8 @@ def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
         of the deviation from the ideal detector response.
     """
     Fmodel = nl_lut(domain, p)
-    data = process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field,
-                   sat_level)
+    data = process(Fmodel if not use_gpu else cp.asarray(Fmodel), arr_dark,
+                   arr, tid, rois, mask, flat_field, sat_level, use_gpu)
 
     # drop saturated shots
     d = data.where(data['sat_sat'] == False, drop=True)
@@ -1290,7 +1337,7 @@ def nl_fit(params, domain):
 
     ff = compute_flat_field_correction(params.rois, params)
     fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr, params.tid,
-               fitrois, params.get_mask(), ff, params.sat_level)
+               fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
 
     def fit_callback(x):
         if not hasattr(fit_callback, "counter"):
@@ -1469,13 +1516,15 @@ def inspect_correction(params, gain=None):
     if Fnl is None:
         Fnl = np.arange(2**9)
 
+    xp = np if not params._using_gpu else cp
     # compute all levels of correction
-    data = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
-                   fitrois, params.get_mask(), np.ones_like(ff), params.sat_level)
-    data_ff = process(np.arange(2**9), params.arr_dark, params.arr, params.tid,
-                      fitrois, params.get_mask(), ff, params.sat_level)
-    data_ff_nl = process(Fnl, params.arr_dark, params.arr,
-                         params.tid, fitrois, params.get_mask(), ff, params.sat_level)
+    data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+                   fitrois, params.get_mask(), xp.ones_like(ff), params.sat_level,
+                   params._using_gpu)
+    data_ff = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
+                      fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
+    data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
+                         fitrois, params.get_mask(), ff, params.sat_level, params._using_gpu)
 
     # for conversion to nb of photons
     if gain is None:
@@ -1632,7 +1681,7 @@ def average_module(arr, dark=None, ret='mean',
     """
     # F_INL
    if F_INL is not None:
-        narr = arr.map_blocks(lambda x: F_INL[x])
+        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
     else:
         narr = arr
 
@@ -1640,11 +1689,25 @@ def average_module(arr, dark=None, ret='mean',
         narr = narr*mask
 
     if sat_roi is not None:
-        temp = (da.logical_not(da.any(
-            narr[:, :, sat_roi['yl']:sat_roi['yh'],
-                 sat_roi['xl']:sat_roi['xh']] >= sat_level,
-            axis=[2, 3], keepdims=True)))
-        not_sat = da.repeat(da.repeat(temp, 128, axis=2), 512, axis=3)
+        not_sat = da.repeat(
+            da.repeat(
+                da.all(
+                    narr[
+                        :,
+                        :,
+                        sat_roi["yl"] : sat_roi["yh"],
+                        sat_roi["xl"] : sat_roi["xh"],
+                    ]
+                    < sat_level,
+                    axis=[2, 3],
+                    keepdims=True,
+                ),
+                128,
+                axis=2,
+            ),
+            512,
+            axis=3,
+        )
 
     if dark is not None:
         narr = narr - dark
@@ -1661,7 +1724,7 @@ def average_module(arr, dark=None, ret='mean',
 
 
 def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
-                   flat_field=None, F_INL=None):
+                   flat_field=None, F_INL=None, use_gpu=False):
     """Process one module and extract roi intensity.
 
     Inputs
@@ -1682,7 +1745,7 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     """
     # F_INL
     if F_INL is not None:
-        narr = arr.map_blocks(lambda x: F_INL[x])
+        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
     else:
         narr = arr
 
@@ -1711,6 +1774,9 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     #     temp[::2, :, :] = flat_field[:, :]
     #     flat_field = temp
 
+    if use_gpu and flat_field is not None:
+        flat_field = cp.asarray(flat_field)
+
     # compute dark corrected ROI values
     v = {}
     for n in rois.keys():
@@ -1732,9 +1798,9 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     dims = ['trainId', 'pulseId']
     r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
     for n in rois.keys():
-        res[n + '_sat'] = xr.DataArray(r_sat[n][:, :],
+        res[n + '_sat'] = xr.DataArray(ensure_on_host(r_sat[n][:, :]),
                                        coords=r_coords, dims=dims)
-        res[n] = xr.DataArray(v[n], coords=r_coords, dims=dims)
+        res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords, dims=dims)
 
     for n in rois.keys():
         roi = rois[n]
@@ -1744,7 +1810,8 @@ def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
     return res
 
 
-def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
+def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511,
+            use_gpu=False):
     """Process dark and run data with corrections.
 
     Inputs
@@ -1762,15 +1829,11 @@ def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511):
         roi extracted intensities
     """
     # dark process
-    res = average_module(arr_dark, F_INL=Fmodel)
-    dark = res.compute()
+    dark = average_module(arr_dark, F_INL=Fmodel).compute()
 
     # data process
-    proc = process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
-                          flat_field=flat_field, F_INL=Fmodel)
-    data = proc.compute()
-
-    return data
+    return process_module(arr, tid, dark, rois, mask, sat_level=sat_level,
+                          flat_field=flat_field, F_INL=Fmodel, use_gpu=use_gpu).compute()
 
 def inspect_saturation(data, gain, Nbins=200):
     """Plot roi integrated histogram of the data with saturation
-- 
GitLab
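
A minimal usage sketch of the GPU path introduced by this patch, mirroring the updated notebook cells rather than introducing new API. It assumes a node with CuPy installed; the proposal, run, module and gain values below are placeholders.

# Sketch only: drive the GPU-accelerated BOZ correction determination.
# Assumes CuPy is available; proposal/run/module/gain values are placeholders.
import dask
import numpy as np
import toolbox_scs.routines.boz as boz

params = boz.parameters(proposal=2937, darkrun=615, run=614, module=15, gain=3)
params.dask_load_persistently()

# The GPU-backed dask arrays are single CuPy chunks on one device, so switch
# to the single-threaded scheduler before moving the data to the GPU.
dask.config.set(scheduler="single-threaded")
params.use_gpu()

# Results computed on the GPU are copied back with ensure_on_host() before
# plotting or NumPy post-processing; NumPy inputs pass through unchanged.
dark = boz.average_module(params.arr_dark).compute()
pedestal = boz.ensure_on_host(np.mean(dark))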