diff --git a/src/cal_tools/agipdalgs.pyx b/src/cal_tools/agipdalgs.pyx
index 954d63d7055cbccc176c885681cdce354519f672..649511df6ecdf15b3674590d27c29c6b9e56e34b 100644
--- a/src/cal_tools/agipdalgs.pyx
+++ b/src/cal_tools/agipdalgs.pyx
@@ -120,35 +120,18 @@ def histogram2d(cnp.ndarray[float, ndim=1] data_x, cnp.ndarray[float, ndim=1] da
     return ret, np.linspace(min_x, max_x, bins_x+1), np.linspace(min_y, max_y, bins_y+1)
 
 
-@boundscheck(False)
-@wraparound(False)
-def gain_choose(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[cnp.float32_t, ndim=4] choices):
-    """Specialised fast equivalent of np.choose(), to select data for a per-pixel gain stage"""
-    cdef int i, j, k
-    cdef cnp.uint8_t v
-    cdef cnp.ndarray[cnp.float32_t, ndim=3] out
-    out = np.zeros_like(a, dtype=np.float32)
-
-    assert (<object>choices).shape == (3,) + (<object>a).shape
-
-    with nogil:
-        for i in range(a.shape[0]):
-            for j in range(a.shape[1]):
-                for k in range(a.shape[2]):
-                    v = a[i, j, k]
-                    out[i, j, k] = choices[v, i, j, k]
-
-    return out
-
+ctypedef fused choices_t:
+    cnp.float32_t
+    cnp.int32_t
 
 @boundscheck(False)
 @wraparound(False)
-def gain_choose_int(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[cnp.int32_t, ndim=4] choices):
+def gain_choose(cnp.ndarray[cnp.uint8_t, ndim=3] a, cnp.ndarray[choices_t, ndim=4] choices):
     """Specialised fast equivalent of np.choose(), to select data for a per-pixel gain stage"""
     cdef int i, j, k
     cdef cnp.uint8_t v
-    cdef cnp.ndarray[cnp.int32_t, ndim=3] out
-    out = np.zeros_like(a, dtype=np.int32)
+    cdef cnp.ndarray[choices_t, ndim=3] out
+    out = np.zeros_like(a, dtype=(<object>choices).dtype)
 
     assert (<object>choices).shape == (3,) + (<object>a).shape
 
diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index fafc7fc8ec111df0428a6eaffe82847d199cf971..2e82ca223e30b4b836dec6f93a7a3ed3c5e35f4a 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -728,7 +728,7 @@ class AgipdCorrections:
         cellid = self.shared_dict[i_proc]['cellId'][first:last]
         # output is saved in sharedmem to pass for correct_agipd()
         # as this function takes about 3 seconds.
-        self.shared_dict[i_proc]["msk"][first:last] = calgs.gain_choose_int(
+        self.shared_dict[i_proc]["msk"][first:last] = calgs.gain_choose(
             gain, self.mask[module_idx][:, cellid]
         )
 
@@ -802,7 +802,7 @@ class AgipdCorrections:
         # if baseline correction was not requested
         # msk and rel_corr will still be empty shared_mem arrays
         if not any(self.blc_bools):
-            msk = calgs.gain_choose_int(gain, self.mask[module_idx][:, cellid])
+            msk = calgs.gain_choose(gain, self.mask[module_idx][:, cellid])
 
             # same for relative gain and then bad pixel mask
             if hasattr(self, "rel_gain"):