diff --git a/nb/powdersum.slurm b/nb/powdersum.slurm new file mode 100644 index 0000000000000000000000000000000000000000..553f6e2cac1419ebe4072f9b387d67c13d9e89db --- /dev/null +++ b/nb/powdersum.slurm @@ -0,0 +1,20 @@ +#!/bin/sh + +#SBATCH --output=powder-sum-%j.out +#SBATCH --job-name=powder-sum + +source /etc/profile.d/modules.sh +module purge +module load exfel openmpi-no-python + +NODES=${SLURM_JOB_NUM_NODES} + +source ~/envs/dev/bin/activate + +export OMPI_MCA_btl=^smcuda +export OMPI_MCA_rcache=^gpusm,rgpusm +export OMPI_MCA_accelerator=^cuda + +MPIOPTS="--map-by NODES -x LD_LIBRARY_PATH -x PATH --bind-to none" + +mpirun -np $NODES ${MPIOPTS} geomtools-powder-sum $@ diff --git a/src/geomtools/powder/powdersum.py b/src/geomtools/powder/powdersum.py index bea57e684e1c5e7b1c1dbab255094cc2dba5b802..1bee17bcfe42ddb658390833c0605240954605c5 100644 --- a/src/geomtools/powder/powdersum.py +++ b/src/geomtools/powder/powdersum.py @@ -13,6 +13,12 @@ from extra_data.read_machinery import find_proposal from . import shmemarray as shmem from .misc import agipd_pixel_area, jungfrau_pixel_area +try: + from mpi4py import MPI + comm = MPI.COMM_WORLD +except ImportError: + comm = None + SASE = { "SA1": "SA1", "FXE": "SA1", "SPB": "SA1", "SA2": "SA1", "HED": "SA2", "MID": "SA2", @@ -40,18 +46,12 @@ PX_AREA = { class ImageAgg: - def __init__( - self, - nproc, - det_source, - shape, - adu_per_unit=1, - rounding_threshold=None, - px_area=None, - mask=None, - ): - self.det_source = det_source + def __init__(self, nproc, detector_id, shape, modules, + adu_per_unit=1, rounding_threshold=None, px_area=None): self.adu_per_unit = adu_per_unit + self.modules = modules + self.num_modules = len(modules) + self.detector_id = detector_id if rounding_threshold is None: self.rounding = False self.rounding_shift = 0.0 @@ -60,10 +60,7 @@ class ImageAgg: self.rounding_shift = rounding_threshold - 0.5 self.shape = shape self.nproc = nproc - if mask is None: - mask = np.zeros(shape, int) - self.mask = mask == 0 arr_shape = (nproc,) + tuple(shape) if px_area is None: px_area = np.ones(shape, np.float32) @@ -75,7 +72,17 @@ class ImageAgg: self.part_count = shmem.empty(arr_shape, int) self.part_num_frames = shmem.empty(nproc, int) - def compute(self, args): + def reset(self, module_ix, det_source, mask=None): + self.det_source = det_source + self.modi = module_ix + + self.r = None + self.count = None + self.mean = None + self.deviation = None + self.num_frames = None + + def _compute_worker_part(self, args): i, dc = args mean = self.part_mean[i] mean[:] = 0 @@ -87,10 +94,11 @@ class ImageAgg: dc[self.det_source].data_counts(index_group="image", labelled=False) ) + inp_mask = self.mask[self.modi - self.mod0][None, ...] for tid, data in dc.trains(): a = data[self.det_source]["image.data"] / self.adu_per_unit - msk = (data[self.det_source]["image.mask"] == 0) & self.mask + msk = (data[self.det_source]["image.mask"] == 0) & inp_mask a[~msk] = 0.0 if self.rounding: @@ -103,6 +111,36 @@ class ImageAgg: count += np.sum(msk.astype(int), 0) return i + def compute(self, dc_img, source_pattern): + pool = mp.Pool(self.nproc) + for modi, modno in self.iter_modules(): + tm0 = time.perf_counter() + print(f"[{self._rank:2d}] {modno:3d}: ", end="") + + source_id = source_pattern.format( + detector_id=self.detector_id, modno=modno) + dc = dc_img.select( + [(source_id, "image.data"), (source_id, "image.mask")], + require_all=True + ) + self.reset(modi, source_id) + result_iter = pool.imap_unordered( + self._compute_worker_part, + enumerate(dc.split_trains(self.nproc)) + ) + + for r in result_iter: + pass + + self.finish() + self.write_module() + + tm1 = time.perf_counter() + print(f"{tm1 - tm0:.1f} s") + + pool.terminate() + pool.join() + def finish(self): self.r = type("Moments", (), {}) self.count = np.sum(self.part_count, 0) @@ -125,8 +163,129 @@ class ImageAgg: self.num_frames = np.sum(self.part_num_frames) + def _create_h5datasets(self, h5f): + s = (self.num_modules,) + self.shape + self._ds_mean = h5f.create_dataset( + "powderSum/image/mean", s, dtype="f4") + self._ds_std = h5f.create_dataset( + "powderSum/image/std", s, dtype="f4") + self._ds_count = h5f.create_dataset( + "powderSum/image/count", s, dtype="i8") + self._ds_nfrm = h5f.create_dataset( + "powderSum/image/numFrames", self.num_modules, dtype="i8") + self._ds_mask = h5f.create_dataset( + "powderSum/image/mask", s, dtype="u1") + + def _create_output_arrays(self): + n = self.modN - self.mod0 + s = (n,) + self.shape + self._ds_mean = np.zeros(s, dtype=np.float32) + self._ds_std = np.zeros(s, dtype=np.float32) + self._ds_count = np.zeros(s, dtype=int) + self._ds_nfrm = np.zeros(n, dtype=int) + self._ds_mask = np.zeros(s, dtype=np.uint8) + + def prepare(self, output_fn, mask=None, comm=None, conditions={}): + if comm is None: + size = 1 + rank = 0 + else: + size = comm.Get_size() + rank = comm.Get_rank() + + self.mod0 = rank * self.num_modules // size + self.modN = (rank + 1) * self.num_modules // size + + self._rank = rank + self._size = size + self._comm = comm + + n = self.modN - self.mod0 + if mask is None: + s = (n,) + self.shape + self.mask = np.ones(s, bool) + else: + self.mask = mask[self.mod0:self.modN] == 0 + + if rank == 0: + h5f = h5py.File(output_fn, "w") + # modules + h5f["powderSum/image/modules"] = self.modules + # conditions + conditions_grp = h5f.create_group("powderSum/conditions") + conditions_grp["detectorId"] = self.detector_id.encode("ascii") + for key, value in conditions.items(): + conditions_grp[key] = value + else: + h5f = None + + self._h5f = h5f + + # average image + if size > 1: + self._create_output_arrays() + else: + self._create_h5datasets(h5f) + + def iter_modules(self): + return enumerate(self.modules[self.mod0:self.modN], start=self.mod0) + + def write_module(self): + i = self.modi - self.mod0 + self._ds_mean[i] = self.mean + self._ds_std[i] = self.deviation + self._ds_count[i] = self.count + self._ds_nfrm[i] = self.num_frames + self._ds_mask[i] = ~self.mask | (self.count == 0) + + def _send(self): + self._comm.Gatherv(self._ds_mean, None, root=0) + self._comm.Gatherv(self._ds_std, None, root=0) + self._comm.Gatherv(self._ds_count, None, root=0) + self._comm.Gatherv(self._ds_nfrm, None, root=0) + self._comm.Gatherv(self._ds_mask, None, root=0) + + def _recv_and_write(self): + s = (self.num_modules,) + self.shape + buf = np.zeros(s, dtype=np.float32) + self._comm.Gatherv(self._ds_mean, buf, root=0) + self._h5f["powderSum/image/mean"] = buf + self._comm.Gatherv(self._ds_std, buf, root=0) + self._h5f["powderSum/image/std"] = buf + buf = np.zeros(s, dtype=int) + self._comm.Gatherv(self._ds_count, buf, root=0) + self._h5f["powderSum/image/count"] = buf + buf = np.zeros(self.num_modules, dtype=int) + self._comm.Gatherv(self._ds_nfrm, buf, root=0) + self._h5f["powderSum/image/numFrames"] = buf + buf = np.zeros(s, dtype=np.uint8) + self._comm.Gatherv(self._ds_mask, buf, root=0) + self._h5f["powderSum/image/mask"] = buf + + def flush(self): + if self._size <= 1: + return + if self._rank == 0: + self._recv_and_write() + self._h5f.close() + else: + self._send() + + def __getstate__(self): + state = self.__dict__.copy() + state["_ds_mean"] = None + state["_ds_std"] = None + state["_ds_count"] = None + state["_ds_nfrm"] = None + state["_ds_mask"] = None + state["_h5f"] = None + state["_comm"] = None + return state + def main(argv=None): + tm0 = time.perf_counter() + parser = ArgumentParser( description="The program sum up images over a given run") parser.add_argument("-n", "--num-proc", type=int, @@ -170,13 +329,31 @@ def main(argv=None): propdir = find_proposal(f"p{args.proposal:06d}") + # --- MPI initialization begin if args.num_proc is None: num_proc = min(psutil.cpu_count(logical=False), 32) + if comm is not None: + mpi_size = comm.Get_size() + rank = comm.Get_rank() + else: + mpi_size = 1 + rank = 0 + + if rank == 0: + mpi_print = print + else: + def mpi_print(*args, **kwargs): + pass + + mpi_print("Num proc:", num_proc) + mpi_print("MPI size:", mpi_size) + # --- MPI initialization end + detector_id = args.detector_id inst, _, _ = detector_id.partition("_") - print("Instrument:", inst) + mpi_print("Instrument:", inst) run = open_run(args.proposal, args.run, data="all") # photon energy and wave length @@ -189,7 +366,7 @@ def main(argv=None): xgm_id = args.xgm_id if not xgm_id: xgm_id = XGM[SASE[inst]] - print("XGM:", xgm_id) + mpi_print("XGM:", xgm_id) # wave length lmd = run[xgm_id, "pulseEnergy.wavelengthUsed"].as_single_value() @@ -206,16 +383,16 @@ def main(argv=None): else: z_motor_id, z_motor_key = ZMOTOR[detector_id] - print(f"Detector Z motor: {z_motor_id}.{z_motor_key}") + mpi_print(f"Detector Z motor: {z_motor_id}.{z_motor_key}") z_motor = run[z_motor_id] clen0 = MIN_CLEN.get(detector_id) clen = 1e-3 * z_motor.run_value(z_motor_key) + clen0 + args.spacing_len - print() - print(f"Photon energy: {photon_en:g} (keV)") - print(f"Wave length: {lmd:.4g} (nm)") - print(f"Camera length: {clen:.3f} (m)") + mpi_print() + mpi_print(f"Photon energy: {photon_en:g} (keV)") + mpi_print(f"Wave length: {lmd:.4g} (nm)") + mpi_print(f"Camera length: {clen:.3f} (m)") if args.images_dir: images_dir = args.images_dir @@ -263,93 +440,55 @@ def main(argv=None): mask_fn = None mask0 = [None] * num_modules - print() - print("Detector source pattern:", det_source_pattern) - print("Num modules:", num_modules) - print("Modules:", modules) - print("Mask:", mask_fn) - print() - print("Num proc:", num_proc) + mpi_print() + mpi_print("Detector source pattern:", det_source_pattern) + mpi_print("Num modules:", num_modules) + mpi_print("Modules:", modules) + mpi_print("Mask:", mask_fn) if args.round_to_photons: adu_per_unit = photon_en rounding_threshold = args.rounding_threshold - print("Rounding threshold:", rounding_threshold) else: adu_per_unit = 1 rounding_threshold = None - pw_aggs = [] - for modi, modno in enumerate(modules): - tm0 = time.perf_counter() - print(f"{modno: 3d}: ", end="") - - source_id = source_pattern.format(detector_id=detector_id, modno=modno) - dc = dc_img.select( - [(source_id, "image.data"), (source_id, "image.mask")], - require_all=True - ) - - shape = dc[source_id, "image.data"].shape[1:] - if args.layout is None: - px_area = None - else: - px_area = PX_AREA[args.layout]() - - agg = ImageAgg( - num_proc, - source_id, - shape, - adu_per_unit, - rounding_threshold, - px_area=px_area, - mask=mask0[modi], - ) - - with mp.Pool(num_proc) as pool: - result_iter = pool.imap_unordered( - agg.compute, enumerate(dc.split_trains(num_proc)) - ) - - for r in result_iter: - pass - - agg.finish() - pw_aggs.append(agg) + mpi_print() + mpi_print("Rounding threshold:", rounding_threshold) + + if args.layout is None: + px_area = None + for modno in modules: + source_id = source_pattern.format( + detector_id=detector_id, modno=modno) + if source_id in dc_img.all_sources: + break + shape = dc_img[source_id, "image.data"].shape[1:] + else: + px_area = PX_AREA[args.layout]() + shape = px_area.shape - tm1 = time.perf_counter() - print(f"{tm1 - tm0:.1f} s") + conditions = { + "cameraLen": clen, + "photonEnergy": photon_en, + "waveLength": lmd, + "run": args.run, + "proposal": args.proposal, + } output_fn = f"powder_sum_p{args.proposal:06d}_r{args.run:04d}.h5" if args.output_dir: output_fn = os.path.join(args.output_dir, output_fn) - with h5py.File(output_fn, "w") as f: - # conditions - f["powderSum/conditions/cameraLen"] = clen - f["powderSum/conditions/photonEnergy"] = photon_en - f["powderSum/conditions/waveLength"] = lmd - f["powderSum/conditions/detectorId"] = detector_id.encode("ascii") - f["powderSum/conditions/run"] = args.run - f["powderSum/conditions/proposal"] = args.proposal - # average image - f["powderSum/image/mean"] = np.stack( - [pw_aggs[i].mean for i in range(num_modules)] - ) - f["powderSum/image/std"] = np.stack( - [pw_aggs[i].deviation for i in range(num_modules)] - ) - f["powderSum/image/count"] = np.stack( - [pw_aggs[i].count for i in range(num_modules)] - ) - f["powderSum/image/numFrames"] = np.stack( - [pw_aggs[i].num_frames for i in range(num_modules)] - ) - f["powderSum/image/mask"] = np.stack( - [~pw_aggs[i].mask | (pw_aggs[i].count == 0) - for i in range(num_modules)] - ).astype(np.uint8) - f["powderSum/image/modules"] = modules + agg = ImageAgg(num_proc, detector_id, shape, modules, adu_per_unit, + rounding_threshold, px_area=px_area) + + agg.prepare(output_fn, mask=mask0, comm=comm, conditions=conditions) + agg.compute(dc_img, source_pattern) + agg.flush() + + tm1 = time.perf_counter() + mpi_print(f"Walltime: {tm1 - tm0:.0f} s") if __name__ == "__main__":