diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index 121c120c81d12c171c6c0be2679cc1a7ade8415f..2c86dc39c86a829c22e398903d66c12fbe8474f7 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -196,6 +196,7 @@ class AgipdCorrections:
         corr_bools: Optional[dict] = None,
         gain_mode: AgipdGainMode = AgipdGainMode.ADAPTIVE_GAIN,
         comp_threads=1,
+        train_ids: Optional[np.ndarray] = None
     ):
         """
         Initialize an AgipdCorrections Class
@@ -211,6 +212,7 @@ class AgipdCorrections:
         :param corr_bools: A dict with all of the correction booleans requested
                            or available
         :param comp_threads: Number of threads to use for compressing gain/mask
+        :param train_ids: train IDs to process, all if omitted.
 
         The following example shows a typical use case:
         .. code-block:: python
@@ -251,6 +253,7 @@ class AgipdCorrections:
         self.max_cells = max_cells
         self.gain_mode = gain_mode
         self.comp_threads = comp_threads
+        self.train_ids = np.array(train_ids) if train_ids is not None else None
 
         self.start, self.last, self.step = self._validate_selected_pulses(
             max_pulses, max_cells)
@@ -339,6 +342,12 @@ class AgipdCorrections:
                                            allpulses, valid_indices,
                                            apply_sel_pulses)
 
+            if firange is None:
+                # gen_valid_range() returns None if there are no cells
+                # to correct.
+                data_dict['nImg'][0] = 0
+                return 0
+
             n_img = firange.shape[0]
             data_dict['nImg'][0] = n_img
             if np.all(np.diff(firange) == 1):
@@ -775,6 +784,15 @@ class AgipdCorrections:
             highok = (idxtrains < medianTrain + 1e4)
             valid &= lowok & highok
 
+            # Filter down to train IDs selected externally, e.g. via
+            # PPU devices.
+            if self.train_ids is not None:
+                valid &= np.in1d(idxtrains, self.train_ids)
+
+            if not valid.any():
+                # Shortcut if no valid trains are left.
+                return valid, 0, 0, idxtrains, np.zeros(0, dtype=np.int32)
+
             # Last index = last valid train + max. number of memory cells
             last_index = int(first[valid][-1] + count[valid][-1])
             first_index = int(first[valid][0])