From 9bd703432e15b8a952cbc8420a9236705cbabadc Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Tue, 28 Nov 2023 14:39:08 +0000
Subject: [PATCH] Try to simplify & speed up AGIPD file reading code

---
 src/cal_tools/agipdlib.py | 41 +++++++++++++++++----------------------
 1 file changed, 18 insertions(+), 23 deletions(-)

diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index 4d8cb0e77..1b87470d3 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -12,7 +12,7 @@ import h5py
 import numpy as np
 import sharedmem
 from dateutil import parser
-from extra_data import DataCollection, H5File, RunDirectory, by_id, components
+from extra_data import DataCollection, H5File, RunDirectory, by_id
 
 from cal_tools import agipdalgs as calgs
 from cal_tools.agipdutils import (
@@ -682,50 +682,45 @@ class AgipdCorrections:
         valid_train_ids = im_dc.train_ids
         # Get a count of images in each train
         nimg_in_trains = im_dc[agipd_base, "image.trainId"].data_counts(False)
-        nimg_in_trains = nimg_in_trains.astype(int)
+        nimg_in_trains = nimg_in_trains.astype(np.int64)
 
         # store valid trains in shared memory
         n_valid_trains = len(valid_train_ids)
         data_dict["n_valid_trains"][0] = n_valid_trains
         data_dict["valid_trains"][:n_valid_trains] = valid_train_ids
 
-        if "AGIPD500K" in agipd_base:
-            agipd_comp = components.AGIPD500K(im_dc)
-        else:
-            agipd_comp = components.AGIPD1M(im_dc)
-
-        kw = {
-            "unstack_pulses": False,
-        }
-
         # get selection for the images in this file
         cm = (self.cell_sel.CM_NONE if apply_sel_pulses
               else self.cell_sel.CM_PRESEL)
 
-        cellid = np.squeeze(im_dc[agipd_base, "image.cellId"].ndarray())
+        agipd_src = im_dc[agipd_base]
+
+        cellid = agipd_src["image.cellId"].ndarray()[:, 0]
 
         img_selected, nimg_in_trains = self.cell_sel.get_cells_on_trains(
             np.array(valid_train_ids), nimg_in_trains, cellid, cm=cm)
-        data_dict["nimg_in_trains"][:n_valid_trains] = nimg_in_trains
 
-        frm_ix = np.flatnonzero(img_selected)
+        data_dict["nimg_in_trains"][:n_valid_trains] = nimg_in_trains
         data_dict["cm_presel"][0] = (cm == self.cell_sel.CM_PRESEL)
-        n_img = len(frm_ix)
+
+        n_img = img_selected.sum()
+        if img_selected.all():
+            # All frames selected - use slice to skip unnecessary copy
+            frm_ix = np.s_[:]
+        else:
+            frm_ix = np.flatnonzero(img_selected)
 
         # read raw data
-        # [n_modules, n_imgs, 2, x, y]
-        raw_data = agipd_comp.get_array("image.data", **kw)[0]
+        # [n_imgs, 2, x, y]
+        raw_data = agipd_src['image.data'].ndarray()
 
         # store in shmem only selected images
         data_dict['nImg'][0] = n_img
         data_dict['data'][:n_img] = raw_data[frm_ix, 0]
         data_dict['rawgain'][:n_img] = raw_data[frm_ix, 1]
-        data_dict['cellId'][:n_img] = agipd_comp.get_array(
-            "image.cellId", **kw)[0, frm_ix]
-        data_dict['pulseId'][:n_img] = agipd_comp.get_array(
-            "image.pulseId", **kw)[0, frm_ix]
-        data_dict['trainId'][:n_img] = agipd_comp.get_array(
-            "image.trainId", **kw)[0, frm_ix]
+        data_dict['cellId'][:n_img] = cellid[frm_ix]
+        data_dict['pulseId'][:n_img] = agipd_src['image.pulseId'].ndarray()[frm_ix, 0]
+        data_dict['trainId'][:n_img] = agipd_src['image.trainId'].ndarray()[frm_ix, 0]
 
         return n_img
 
-- 
GitLab