diff --git a/src/cal_tools/agipdalgs.pyx b/src/cal_tools/agipdalgs.pyx index 954d63d7055cbccc176c885681cdce354519f672..649511df6ecdf15b3674590d27c29c6b9e56e34b 100644 --- a/src/cal_tools/agipdalgs.pyx +++ b/src/cal_tools/agipdalgs.pyx @@ -120,35 +120,18 @@ def histogram2d(cnp.ndarray[float, ndim=1] data_x, cnp.ndarray[float, ndim=1] da return ret, np.linspace(min_x, max_x, bins_x+1), np.linspace(min_y, max_y, bins_y+1) -@boundscheck(False) -@wraparound(False) -def gain_choose(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[cnp.float32_t, ndim=4] choices): - """Specialised fast equivalent of np.choose(), to select data for a per-pixel gain stage""" - cdef int i, j, k - cdef cnp.uint8_t v - cdef cnp.ndarray[cnp.float32_t, ndim=3] out - out = np.zeros_like(a, dtype=np.float32) - - assert (<object>choices).shape == (3,) + (<object>a).shape - - with nogil: - for i in range(a.shape[0]): - for j in range(a.shape[1]): - for k in range(a.shape[2]): - v = a[i, j, k] - out[i, j, k] = choices[v, i, j, k] - - return out - +ctypedef fused choices_t: + cnp.float32_t + cnp.int32_t @boundscheck(False) @wraparound(False) -def gain_choose_int(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[cnp.int32_t, ndim=4] choices): +def gain_choose(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[choices_t, ndim=4] choices): """Specialised fast equivalent of np.choose(), to select data for a per-pixel gain stage""" cdef int i, j, k cdef cnp.uint8_t v - cdef cnp.ndarray[cnp.int32_t, ndim=3] out - out = np.zeros_like(a, dtype=np.int32) + cdef cnp.ndarray[choices_t, ndim=3] out + out = np.zeros_like(a, dtype=(<object>choices).dtype) assert (<object>choices).shape == (3,) + (<object>a).shape diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py index fafc7fc8ec111df0428a6eaffe82847d199cf971..2e82ca223e30b4b836dec6f93a7a3ed3c5e35f4a 100644 --- a/src/cal_tools/agipdlib.py +++ b/src/cal_tools/agipdlib.py @@ -728,7 +728,7 @@ class AgipdCorrections: cellid = self.shared_dict[i_proc]['cellId'][first:last] # output is saved in sharedmem to pass for correct_agipd() # as this function takes about 3 seconds. - self.shared_dict[i_proc]["msk"][first:last] = calgs.gain_choose_int( + self.shared_dict[i_proc]["msk"][first:last] = calgs.gain_choose( gain, self.mask[module_idx][:, cellid] ) @@ -802,7 +802,7 @@ class AgipdCorrections: # if baseline correction was not requested # msk and rel_corr will still be empty shared_mem arrays if not any(self.blc_bools): - msk = calgs.gain_choose_int(gain, self.mask[module_idx][:, cellid]) + msk = calgs.gain_choose(gain, self.mask[module_idx][:, cellid]) # same for relative gain and then bad pixel mask if hasattr(self, "rel_gain"):