From 73c7d0d584c0997e65addc8abd3885a4d794fc7f Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Tue, 20 Aug 2024 18:44:08 +0200
Subject: [PATCH] Add MPI version of image averaging (powder-sum)

---
 nb/powdersum.slurm                |  20 ++
 src/geomtools/powder/powdersum.py | 331 +++++++++++++++++++++---------
 2 files changed, 255 insertions(+), 96 deletions(-)
 create mode 100644 nb/powdersum.slurm

diff --git a/nb/powdersum.slurm b/nb/powdersum.slurm
new file mode 100644
index 0000000..553f6e2
--- /dev/null
+++ b/nb/powdersum.slurm
@@ -0,0 +1,20 @@
+#!/bin/sh
+
+#SBATCH --output=powder-sum-%j.out
+#SBATCH --job-name=powder-sum
+
+source /etc/profile.d/modules.sh
+module purge
+module load exfel openmpi-no-python
+
+NODES=${SLURM_JOB_NUM_NODES}
+
+source ~/envs/dev/bin/activate
+
+export OMPI_MCA_btl=^smcuda
+export OMPI_MCA_rcache=^gpusm,rgpusm
+export OMPI_MCA_accelerator=^cuda
+
+MPIOPTS="--map-by NODES -x LD_LIBRARY_PATH -x PATH --bind-to none"
+
+mpirun -np $NODES ${MPIOPTS} geomtools-powder-sum $@
diff --git a/src/geomtools/powder/powdersum.py b/src/geomtools/powder/powdersum.py
index bea57e6..1bee17b 100644
--- a/src/geomtools/powder/powdersum.py
+++ b/src/geomtools/powder/powdersum.py
@@ -13,6 +13,12 @@ from extra_data.read_machinery import find_proposal
 from . import shmemarray as shmem
 from .misc import agipd_pixel_area, jungfrau_pixel_area
 
+try:
+    from mpi4py import MPI
+    comm = MPI.COMM_WORLD
+except ImportError:
+    comm = None
+
 SASE = {
     "SA1": "SA1", "FXE": "SA1", "SPB": "SA1",
     "SA2": "SA1", "HED": "SA2", "MID": "SA2",
@@ -40,18 +46,12 @@ PX_AREA = {
 
 
 class ImageAgg:
-    def __init__(
-        self,
-        nproc,
-        det_source,
-        shape,
-        adu_per_unit=1,
-        rounding_threshold=None,
-        px_area=None,
-        mask=None,
-    ):
-        self.det_source = det_source
+    def __init__(self, nproc, detector_id, shape, modules,
+                 adu_per_unit=1, rounding_threshold=None, px_area=None):
         self.adu_per_unit = adu_per_unit
+        self.modules = modules
+        self.num_modules = len(modules)
+        self.detector_id = detector_id
         if rounding_threshold is None:
             self.rounding = False
             self.rounding_shift = 0.0
@@ -60,10 +60,7 @@ class ImageAgg:
             self.rounding_shift = rounding_threshold - 0.5
         self.shape = shape
         self.nproc = nproc
-        if mask is None:
-            mask = np.zeros(shape, int)
 
-        self.mask = mask == 0
         arr_shape = (nproc,) + tuple(shape)
         if px_area is None:
             px_area = np.ones(shape, np.float32)
@@ -75,7 +72,17 @@ class ImageAgg:
         self.part_count = shmem.empty(arr_shape, int)
         self.part_num_frames = shmem.empty(nproc, int)
 
-    def compute(self, args):
+    def reset(self, module_ix, det_source, mask=None):
+        self.det_source = det_source
+        self.modi = module_ix
+
+        self.r = None
+        self.count = None
+        self.mean = None
+        self.deviation = None
+        self.num_frames = None
+
+    def _compute_worker_part(self, args):
         i, dc = args
         mean = self.part_mean[i]
         mean[:] = 0
@@ -87,10 +94,11 @@ class ImageAgg:
             dc[self.det_source].data_counts(index_group="image",
                                             labelled=False)
         )
+        inp_mask = self.mask[self.modi - self.mod0][None, ...]
 
         for tid, data in dc.trains():
             a = data[self.det_source]["image.data"] / self.adu_per_unit
-            msk = (data[self.det_source]["image.mask"] == 0) & self.mask
+            msk = (data[self.det_source]["image.mask"] == 0) & inp_mask
 
             a[~msk] = 0.0
             if self.rounding:
@@ -103,6 +111,36 @@ class ImageAgg:
             count += np.sum(msk.astype(int), 0)
         return i
 
+    def compute(self, dc_img, source_pattern):
+        pool = mp.Pool(self.nproc)
+        for modi, modno in self.iter_modules():
+            tm0 = time.perf_counter()
+            print(f"[{self._rank:2d}] {modno:3d}: ", end="")
+
+            source_id = source_pattern.format(
+                detector_id=self.detector_id, modno=modno)
+            dc = dc_img.select(
+                [(source_id, "image.data"), (source_id, "image.mask")],
+                require_all=True
+            )
+            self.reset(modi, source_id)
+            result_iter = pool.imap_unordered(
+                self._compute_worker_part,
+                enumerate(dc.split_trains(self.nproc))
+            )
+
+            for r in result_iter:
+                pass
+
+            self.finish()
+            self.write_module()
+
+            tm1 = time.perf_counter()
+            print(f"{tm1 - tm0:.1f} s")
+
+        pool.terminate()
+        pool.join()
+
     def finish(self):
         self.r = type("Moments", (), {})
         self.count = np.sum(self.part_count, 0)
@@ -125,8 +163,129 @@ class ImageAgg:
 
         self.num_frames = np.sum(self.part_num_frames)
 
+    def _create_h5datasets(self, h5f):
+        s = (self.num_modules,) + self.shape
+        self._ds_mean = h5f.create_dataset(
+            "powderSum/image/mean", s, dtype="f4")
+        self._ds_std = h5f.create_dataset(
+            "powderSum/image/std", s, dtype="f4")
+        self._ds_count = h5f.create_dataset(
+            "powderSum/image/count", s, dtype="i8")
+        self._ds_nfrm = h5f.create_dataset(
+            "powderSum/image/numFrames", self.num_modules, dtype="i8")
+        self._ds_mask = h5f.create_dataset(
+            "powderSum/image/mask", s, dtype="u1")
+
+    def _create_output_arrays(self):
+        n = self.modN - self.mod0
+        s = (n,) + self.shape
+        self._ds_mean = np.zeros(s, dtype=np.float32)
+        self._ds_std = np.zeros(s, dtype=np.float32)
+        self._ds_count = np.zeros(s, dtype=int)
+        self._ds_nfrm = np.zeros(n, dtype=int)
+        self._ds_mask = np.zeros(s, dtype=np.uint8)
+
+    def prepare(self, output_fn, mask=None, comm=None, conditions={}):
+        if comm is None:
+            size = 1
+            rank = 0
+        else:
+            size = comm.Get_size()
+            rank = comm.Get_rank()
+
+        self.mod0 = rank * self.num_modules // size
+        self.modN = (rank + 1) * self.num_modules // size
+
+        self._rank = rank
+        self._size = size
+        self._comm = comm
+
+        n = self.modN - self.mod0
+        if mask is None:
+            s = (n,) + self.shape
+            self.mask = np.ones(s, bool)
+        else:
+            self.mask = mask[self.mod0:self.modN] == 0
+
+        if rank == 0:
+            h5f = h5py.File(output_fn, "w")
+            # modules
+            h5f["powderSum/image/modules"] = self.modules
+            # conditions
+            conditions_grp = h5f.create_group("powderSum/conditions")
+            conditions_grp["detectorId"] = self.detector_id.encode("ascii")
+            for key, value in conditions.items():
+                conditions_grp[key] = value
+        else:
+            h5f = None
+
+        self._h5f = h5f
+
+        # average image
+        if size > 1:
+            self._create_output_arrays()
+        else:
+            self._create_h5datasets(h5f)
+
+    def iter_modules(self):
+        return enumerate(self.modules[self.mod0:self.modN], start=self.mod0)
+
+    def write_module(self):
+        i = self.modi - self.mod0
+        self._ds_mean[i] = self.mean
+        self._ds_std[i] = self.deviation
+        self._ds_count[i] = self.count
+        self._ds_nfrm[i] = self.num_frames
+        self._ds_mask[i] = ~self.mask | (self.count == 0)
+
+    def _send(self):
+        self._comm.Gatherv(self._ds_mean, None, root=0)
+        self._comm.Gatherv(self._ds_std, None, root=0)
+        self._comm.Gatherv(self._ds_count, None, root=0)
+        self._comm.Gatherv(self._ds_nfrm, None, root=0)
+        self._comm.Gatherv(self._ds_mask, None, root=0)
+
+    def _recv_and_write(self):
+        s = (self.num_modules,) + self.shape
+        buf = np.zeros(s, dtype=np.float32)
+        self._comm.Gatherv(self._ds_mean, buf, root=0)
+        self._h5f["powderSum/image/mean"] = buf
+        self._comm.Gatherv(self._ds_std, buf, root=0)
+        self._h5f["powderSum/image/std"] = buf
+        buf = np.zeros(s, dtype=int)
+        self._comm.Gatherv(self._ds_count, buf, root=0)
+        self._h5f["powderSum/image/count"] = buf
+        buf = np.zeros(self.num_modules, dtype=int)
+        self._comm.Gatherv(self._ds_nfrm, buf, root=0)
+        self._h5f["powderSum/image/numFrames"] = buf
+        buf = np.zeros(s, dtype=np.uint8)
+        self._comm.Gatherv(self._ds_mask, buf, root=0)
+        self._h5f["powderSum/image/mask"] = buf
+
+    def flush(self):
+        if self._size <= 1:
+            return
+        if self._rank == 0:
+            self._recv_and_write()
+            self._h5f.close()
+        else:
+            self._send()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_ds_mean"] = None
+        state["_ds_std"] = None
+        state["_ds_count"] = None
+        state["_ds_nfrm"] = None
+        state["_ds_mask"] = None
+        state["_h5f"] = None
+        state["_comm"] = None
+        return state
+
 
 def main(argv=None):
+    tm0 = time.perf_counter()
+
     parser = ArgumentParser(
         description="The program sum up images over a given run")
     parser.add_argument("-n", "--num-proc", type=int,
@@ -170,13 +329,31 @@ def main(argv=None):
 
     propdir = find_proposal(f"p{args.proposal:06d}")
 
+    # --- MPI initialization begin
     if args.num_proc is None:
         num_proc = min(psutil.cpu_count(logical=False), 32)
 
+    if comm is not None:
+        mpi_size = comm.Get_size()
+        rank = comm.Get_rank()
+    else:
+        mpi_size = 1
+        rank = 0
+
+    if rank == 0:
+        mpi_print = print
+    else:
+        def mpi_print(*args, **kwargs):
+            pass
+
+    mpi_print("Num proc:", num_proc)
+    mpi_print("MPI size:", mpi_size)
+    # --- MPI initialization end
+
     detector_id = args.detector_id
     inst, _, _ = detector_id.partition("_")
-    print("Instrument:", inst)
 
+    mpi_print("Instrument:", inst)
     run = open_run(args.proposal, args.run, data="all")
 
     # photon energy and wave length
@@ -189,7 +366,7 @@ def main(argv=None):
         xgm_id = args.xgm_id
         if not xgm_id:
             xgm_id = XGM[SASE[inst]]
-        print("XGM:", xgm_id)
+        mpi_print("XGM:", xgm_id)
 
         # wave length
         lmd = run[xgm_id, "pulseEnergy.wavelengthUsed"].as_single_value()
@@ -206,16 +383,16 @@ def main(argv=None):
         else:
             z_motor_id, z_motor_key = ZMOTOR[detector_id]
 
-        print(f"Detector Z motor: {z_motor_id}.{z_motor_key}")
+        mpi_print(f"Detector Z motor: {z_motor_id}.{z_motor_key}")
 
         z_motor = run[z_motor_id]
         clen0 = MIN_CLEN.get(detector_id)
         clen = 1e-3 * z_motor.run_value(z_motor_key) + clen0 + args.spacing_len
 
-    print()
-    print(f"Photon energy: {photon_en:g} (keV)")
-    print(f"Wave length: {lmd:.4g} (nm)")
-    print(f"Camera length: {clen:.3f} (m)")
+    mpi_print()
+    mpi_print(f"Photon energy: {photon_en:g} (keV)")
+    mpi_print(f"Wave length: {lmd:.4g} (nm)")
+    mpi_print(f"Camera length: {clen:.3f} (m)")
 
     if args.images_dir:
         images_dir = args.images_dir
@@ -263,93 +440,55 @@ def main(argv=None):
         mask_fn = None
         mask0 = [None] * num_modules
 
-    print()
-    print("Detector source pattern:", det_source_pattern)
-    print("Num modules:", num_modules)
-    print("Modules:", modules)
-    print("Mask:", mask_fn)
-    print()
-    print("Num proc:", num_proc)
+    mpi_print()
+    mpi_print("Detector source pattern:", det_source_pattern)
+    mpi_print("Num modules:", num_modules)
+    mpi_print("Modules:", modules)
+    mpi_print("Mask:", mask_fn)
 
     if args.round_to_photons:
         adu_per_unit = photon_en
         rounding_threshold = args.rounding_threshold
-        print("Rounding threshold:", rounding_threshold)
     else:
         adu_per_unit = 1
         rounding_threshold = None
 
-    pw_aggs = []
-    for modi, modno in enumerate(modules):
-        tm0 = time.perf_counter()
-        print(f"{modno: 3d}: ", end="")
-
-        source_id = source_pattern.format(detector_id=detector_id, modno=modno)
-        dc = dc_img.select(
-            [(source_id, "image.data"), (source_id, "image.mask")],
-            require_all=True
-        )
-
-        shape = dc[source_id, "image.data"].shape[1:]
-        if args.layout is None:
-            px_area = None
-        else:
-            px_area = PX_AREA[args.layout]()
-
-        agg = ImageAgg(
-            num_proc,
-            source_id,
-            shape,
-            adu_per_unit,
-            rounding_threshold,
-            px_area=px_area,
-            mask=mask0[modi],
-        )
-
-        with mp.Pool(num_proc) as pool:
-            result_iter = pool.imap_unordered(
-                agg.compute, enumerate(dc.split_trains(num_proc))
-            )
-
-            for r in result_iter:
-                pass
-
-        agg.finish()
-        pw_aggs.append(agg)
+    mpi_print()
+    mpi_print("Rounding threshold:", rounding_threshold)
+
+    if args.layout is None:
+        px_area = None
+        for modno in modules:
+            source_id = source_pattern.format(
+                detector_id=detector_id, modno=modno)
+            if source_id in dc_img.all_sources:
+                break
+        shape = dc_img[source_id, "image.data"].shape[1:]
+    else:
+        px_area = PX_AREA[args.layout]()
+        shape = px_area.shape
 
-        tm1 = time.perf_counter()
-        print(f"{tm1 - tm0:.1f} s")
+    conditions = {
+        "cameraLen": clen,
+        "photonEnergy": photon_en,
+        "waveLength": lmd,
+        "run": args.run,
+        "proposal": args.proposal,
+    }
 
     output_fn = f"powder_sum_p{args.proposal:06d}_r{args.run:04d}.h5"
     if args.output_dir:
         output_fn = os.path.join(args.output_dir, output_fn)
-    with h5py.File(output_fn, "w") as f:
-        # conditions
-        f["powderSum/conditions/cameraLen"] = clen
-        f["powderSum/conditions/photonEnergy"] = photon_en
-        f["powderSum/conditions/waveLength"] = lmd
-        f["powderSum/conditions/detectorId"] = detector_id.encode("ascii")
-        f["powderSum/conditions/run"] = args.run
-        f["powderSum/conditions/proposal"] = args.proposal
 
-        # average image
-        f["powderSum/image/mean"] = np.stack(
-            [pw_aggs[i].mean for i in range(num_modules)]
-        )
-        f["powderSum/image/std"] = np.stack(
-            [pw_aggs[i].deviation for i in range(num_modules)]
-        )
-        f["powderSum/image/count"] = np.stack(
-            [pw_aggs[i].count for i in range(num_modules)]
-        )
-        f["powderSum/image/numFrames"] = np.stack(
-            [pw_aggs[i].num_frames for i in range(num_modules)]
-        )
-        f["powderSum/image/mask"] = np.stack(
-            [~pw_aggs[i].mask | (pw_aggs[i].count == 0)
-             for i in range(num_modules)]
-        ).astype(np.uint8)
-        f["powderSum/image/modules"] = modules
+    agg = ImageAgg(num_proc, detector_id, shape, modules, adu_per_unit,
+                   rounding_threshold, px_area=px_area)
+
+    agg.prepare(output_fn, mask=mask0, comm=comm, conditions=conditions)
+    agg.compute(dc_img, source_pattern)
+    agg.flush()
+
+    tm1 = time.perf_counter()
+    mpi_print(f"Walltime: {tm1 - tm0:.0f} s")
 
 
 if __name__ == "__main__":
-- 
GitLab