From b6ff3c1b394ba8bfc69d8c125727f064073d0ae6 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Fri, 8 Oct 2021 13:49:17 +0200
Subject: [PATCH] Optimize loading order for AGIPD data to exit early with
 train filtering

---
 src/cal_tools/agipdlib.py | 35 ++++++++++++++++++++++-------------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index 2c86dc39c..6308245f1 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -329,11 +329,16 @@ class AgipdCorrections:
         data_dict['moduleIdx'][0] = module_idx
         try:
             f = h5py.File(file_name, "r")
-            group = f[agipd_base]["image"]
 
             (_, first_index, last_index,
              _, valid_indices) = self.get_valid_image_idx(idx_base, f)
 
+            if len(valid_indices) == 0:
+                # If there's not a single valid index, exit early.
+                data_dict['nImg'][0] = 0
+                return 0
+
+            group = f[agipd_base]['image']
             allcells = np.squeeze(group['cellId'])
             allpulses = np.squeeze(group['pulseId'])
 
@@ -344,7 +349,7 @@ class AgipdCorrections:
 
             if firange is None:
                 # gen_valid_range() returns None if there are no cells
-                # to correct.
+                # to correct, exit early.
                 data_dict['nImg'][0] = 0
                 return 0
 
@@ -769,26 +774,30 @@ class AgipdCorrections:
     ):
         """Return the indices of valid data"""
         if raw_format_version == 2:
+            idxtrains = np.squeeze(infile['/INDEX/trainId'])
+
+            # Check against train ID filter list, if any
+            if self.train_ids is not None:
+                valid = np.in1d(idxtrains, self.train_ids)
+
+                if not valid.any():
+                    # Shortcut to avoid any further loading.
+                    return valid, 0, 0, idxtrains, np.zeros(0, dtype=np.int32)
+            else:
+                valid = np.ones_like(idxtrains, dtype=bool)
+
+            # Load count and offsets and filter for non-empty trains.
             count = np.squeeze(infile[idx_base + "image/count"])
             first = np.squeeze(infile[idx_base + "image/first"])
-            if np.count_nonzero(count != 0) == 0:
-                raise IOError("File has no valid counts")
-            valid = count != 0
-            idxtrains = np.squeeze(infile["/INDEX/trainId"])
+            valid &= count != 0
 
-            # Train indices are of type=f32
             # Validate that train indices values fall
             # between medianTrain +- 1e4
-            medianTrain = np.nanmedian(idxtrains)
+            medianTrain = np.median(idxtrains)
             lowok = (idxtrains > medianTrain - 1e4)
             highok = (idxtrains < medianTrain + 1e4)
             valid &= lowok & highok
 
-            # Filter down to train IDs selected externally, e.g. via
-            # PPU devices.
-            if self.train_ids is not None:
-                valid &= np.in1d(idxtrains, self.train_ids)
-
             if not valid.any():
                 # Shortcut if no valid trains are left.
                 return valid, 0, 0, idxtrains, np.zeros(0, dtype=np.int32)
-- 
GitLab