diff --git a/src/geomtools/powder/image_agg.py b/src/geomtools/powder/image_agg.py
new file mode 100644
index 0000000000000000000000000000000000000000..d31e474aff42eac5c3715eefc74ffa7ba7c4a2bf
--- /dev/null
+++ b/src/geomtools/powder/image_agg.py
@@ -0,0 +1,245 @@
+import multiprocessing as mp
+import time
+
+import h5py
+import numpy as np
+
+from . import shmemarray as shmem
+
+
+class ImageAgg:
+    def __init__(self, nproc, detector_id, shape, modules,
+                 adu_per_unit=1, rounding_threshold=None, px_area=None):
+        self.adu_per_unit = adu_per_unit
+        self.modules = modules
+        self.num_modules = len(modules)
+        self.detector_id = detector_id
+        if rounding_threshold is None:
+            self.rounding = False
+            self.rounding_shift = 0.0
+        else:
+            self.rounding = True
+            self.rounding_shift = rounding_threshold - 0.5
+        self.shape = shape
+        self.nproc = nproc
+
+        arr_shape = (nproc,) + tuple(shape)
+        if px_area is None:
+            px_area = np.ones(shape, np.float32)
+
+        self.px_area = px_area
+
+        self.part_mean = shmem.empty(arr_shape, float)
+        self.part_deviation = shmem.empty(arr_shape, float)
+        self.part_count = shmem.empty(arr_shape, int)
+        self.part_num_frames = shmem.empty(nproc, int)
+
+    def reset(self, module_ix, det_source, mask=None):
+        self.det_source = det_source
+        self.modi = module_ix
+
+        self.r = None
+        self.count = None
+        self.mean = None
+        self.deviation = None
+        self.num_frames = None
+
+    def _compute_worker_part(self, args):
+        i, dc = args
+        mean = self.part_mean[i]
+        mean[:] = 0
+        deviation = self.part_deviation[i]
+        deviation[:] = 0
+        count = self.part_count[i]
+        count[:] = 0
+        self.part_num_frames[i] = np.sum(
+            dc[self.det_source].data_counts(index_group="image",
+                                            labelled=False)
+        )
+        inp_mask = self.mask[self.modi - self.mod0][None, ...]
+
+        for tid, data in dc.trains():
+            a = data[self.det_source]["image.data"] / self.adu_per_unit
+            msk = (data[self.det_source]["image.mask"] == 0) & inp_mask
+
+            a[~msk] = 0.0
+            if self.rounding:
+                msk &= a > -self.rounding_shift - 0.5
+                np.round(a - self.rounding_shift, out=a)
+                np.clip(a, 0.0, None, out=a)
+
+            mean += np.sum(a, 0)
+            deviation += np.sum(a * a, 0)
+            count += np.sum(msk.astype(int), 0)
+        return i
+
+    def compute(self, dc_img, source_pattern):
+        pool = mp.Pool(self.nproc)
+        for modi, modno in self.iter_modules():
+            tm0 = time.perf_counter()
+            print(f"[{self._rank:2d}] {modno:3d}: ", end="")
+
+            source_id = source_pattern.format(
+                detector_id=self.detector_id, modno=modno)
+            dc = dc_img.select(
+                [(source_id, "image.data"), (source_id, "image.mask")],
+                require_all=True
+            )
+            self.reset(modi, source_id)
+            result_iter = pool.imap_unordered(
+                self._compute_worker_part,
+                enumerate(dc.split_trains(self.nproc))
+            )
+
+            for r in result_iter:
+                pass
+
+            self.finish()
+            self.write_module()
+
+            tm1 = time.perf_counter()
+            print(f"{tm1 - tm0:.1f} s")
+
+        pool.terminate()
+        pool.join()
+
+    def finish(self):
+        self.r = type("Moments", (), {})
+        self.count = np.sum(self.part_count, 0)
+        nz = self.count > 0
+
+        self.mean = np.zeros(self.shape, np.float32)
+        self.mean = np.divide(
+            np.sum(self.part_mean, 0), self.count,
+            out=self.mean, where=nz
+        )
+        self.mean /= self.px_area
+
+        self.deviation = np.zeros(self.shape, np.float32)
+        self.deviation = np.divide(
+            np.sum(self.part_deviation, 0), self.count,
+            out=self.deviation, where=nz
+        )
+        self.deviation = np.sqrt(self.deviation - self.mean * self.mean)
+        self.deviation /= self.px_area
+
+        self.num_frames = np.sum(self.part_num_frames)
+
+    def _create_h5datasets(self, h5f):
+        s = (self.num_modules,) + self.shape
+        self._ds_mean = h5f.create_dataset(
+            "powderSum/image/mean", s, dtype="f4")
+        self._ds_std = h5f.create_dataset(
+            "powderSum/image/std", s, dtype="f4")
+        self._ds_count = h5f.create_dataset(
+            "powderSum/image/count", s, dtype="i8")
+        self._ds_nfrm = h5f.create_dataset(
+            "powderSum/image/numFrames", self.num_modules, dtype="i8")
+        self._ds_mask = h5f.create_dataset(
+            "powderSum/image/mask", s, dtype="u1")
+
+    def _create_output_arrays(self):
+        n = self.modN - self.mod0
+        s = (n,) + self.shape
+        self._ds_mean = np.zeros(s, dtype=np.float32)
+        self._ds_std = np.zeros(s, dtype=np.float32)
+        self._ds_count = np.zeros(s, dtype=int)
+        self._ds_nfrm = np.zeros(n, dtype=int)
+        self._ds_mask = np.zeros(s, dtype=np.uint8)
+
+    def prepare(self, output_fn, mask=None, comm=None, conditions={}):
+        if comm is None:
+            size = 1
+            rank = 0
+        else:
+            size = comm.Get_size()
+            rank = comm.Get_rank()
+
+        self.mod0 = rank * self.num_modules // size
+        self.modN = (rank + 1) * self.num_modules // size
+
+        self._rank = rank
+        self._size = size
+        self._comm = comm
+
+        n = self.modN - self.mod0
+        if mask is None:
+            s = (n,) + self.shape
+            self.mask = np.ones(s, bool)
+        else:
+            self.mask = mask[self.mod0:self.modN] == 0
+
+        if rank == 0:
+            h5f = h5py.File(output_fn, "w")
+            # modules
+            h5f["powderSum/image/modules"] = self.modules
+            # conditions
+            conditions_grp = h5f.create_group("powderSum/conditions")
+            conditions_grp["detectorId"] = self.detector_id.encode("ascii")
+            for key, value in conditions.items():
+                conditions_grp[key] = value
+        else:
+            h5f = None
+
+        self._h5f = h5f
+
+        # average image
+        if size > 1:
+            self._create_output_arrays()
+        else:
+            self._create_h5datasets(h5f)
+
+    def iter_modules(self):
+        return enumerate(self.modules[self.mod0:self.modN], start=self.mod0)
+
+    def write_module(self):
+        i = self.modi - self.mod0
+        self._ds_mean[i] = self.mean
+        self._ds_std[i] = self.deviation
+        self._ds_count[i] = self.count
+        self._ds_nfrm[i] = self.num_frames
+        self._ds_mask[i] = ~self.mask[i] | (self.count == 0)
+
+    def _send(self):
+        self._comm.Gatherv(self._ds_mean, None, root=0)
+        self._comm.Gatherv(self._ds_std, None, root=0)
+        self._comm.Gatherv(self._ds_count, None, root=0)
+        self._comm.Gatherv(self._ds_nfrm, None, root=0)
+        self._comm.Gatherv(self._ds_mask, None, root=0)
+
+    def _recv_and_write(self):
+        s = (self.num_modules,) + self.shape
+        buf = np.zeros(s, dtype=np.float32)
+        self._comm.Gatherv(self._ds_mean, buf, root=0)
+        self._h5f["powderSum/image/mean"] = buf
+        self._comm.Gatherv(self._ds_std, buf, root=0)
+        self._h5f["powderSum/image/std"] = buf
+        buf = np.zeros(s, dtype=int)
+        self._comm.Gatherv(self._ds_count, buf, root=0)
+        self._h5f["powderSum/image/count"] = buf
+        buf = np.zeros(self.num_modules, dtype=int)
+        self._comm.Gatherv(self._ds_nfrm, buf, root=0)
+        self._h5f["powderSum/image/numFrames"] = buf
+        buf = np.zeros(s, dtype=np.uint8)
+        self._comm.Gatherv(self._ds_mask, buf, root=0)
+        self._h5f["powderSum/image/mask"] = buf
+
+    def flush(self):
+        if self._size <= 1:
+            return
+        if self._rank == 0:
+            self._recv_and_write()
+            self._h5f.close()
+        else:
+            self._send()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_ds_mean"] = None
+        state["_ds_std"] = None
+        state["_ds_count"] = None
+        state["_ds_nfrm"] = None
+        state["_ds_mask"] = None
+        state["_h5f"] = None
+        state["_comm"] = None
+        return state
diff --git a/src/geomtools/powder/powdersum.py b/src/geomtools/powder/powdersum.py
index b40c502506547f50556c103792f1e87a331f5862..82186aed994369cb69925b927eb4346b75c059c3 100644
--- a/src/geomtools/powder/powdersum.py
+++ b/src/geomtools/powder/powdersum.py
@@ -1,16 +1,15 @@
-import multiprocessing as mp
 import os
 import re
 import time
 from argparse import ArgumentParser
 
 import h5py
-import numpy as np
+import numpy as np  # noqa: F401
 import psutil
 from extra_data import RunDirectory, open_run
 from extra_data.read_machinery import find_proposal
 
-from . import shmemarray as shmem
+from .image_agg import ImageAgg
 from .misc import agipd_pixel_area, jungfrau_pixel_area
 
 try:
@@ -45,244 +44,6 @@ PX_AREA = {
 }
 
 
-class ImageAgg:
-    def __init__(self, nproc, detector_id, shape, modules,
-                 adu_per_unit=1, rounding_threshold=None, px_area=None):
-        self.adu_per_unit = adu_per_unit
-        self.modules = modules
-        self.num_modules = len(modules)
-        self.detector_id = detector_id
-        if rounding_threshold is None:
-            self.rounding = False
-            self.rounding_shift = 0.0
-        else:
-            self.rounding = True
-            self.rounding_shift = rounding_threshold - 0.5
-        self.shape = shape
-        self.nproc = nproc
-
-        arr_shape = (nproc,) + tuple(shape)
-        if px_area is None:
-            px_area = np.ones(shape, np.float32)
-
-        self.px_area = px_area
-
-        self.part_mean = shmem.empty(arr_shape, float)
-        self.part_deviation = shmem.empty(arr_shape, float)
-        self.part_count = shmem.empty(arr_shape, int)
-        self.part_num_frames = shmem.empty(nproc, int)
-
-    def reset(self, module_ix, det_source, mask=None):
-        self.det_source = det_source
-        self.modi = module_ix
-
-        self.r = None
-        self.count = None
-        self.mean = None
-        self.deviation = None
-        self.num_frames = None
-
-    def _compute_worker_part(self, args):
-        i, dc = args
-        mean = self.part_mean[i]
-        mean[:] = 0
-        deviation = self.part_deviation[i]
-        deviation[:] = 0
-        count = self.part_count[i]
-        count[:] = 0
-        self.part_num_frames[i] = np.sum(
-            dc[self.det_source].data_counts(index_group="image",
-                                            labelled=False)
-        )
-        inp_mask = self.mask[self.modi - self.mod0][None, ...]
-
-        for tid, data in dc.trains():
-            a = data[self.det_source]["image.data"] / self.adu_per_unit
-            msk = (data[self.det_source]["image.mask"] == 0) & inp_mask
-
-            a[~msk] = 0.0
-            if self.rounding:
-                msk &= a > -self.rounding_shift - 0.5
-                np.round(a - self.rounding_shift, out=a)
-                np.clip(a, 0.0, None, out=a)
-
-            mean += np.sum(a, 0)
-            deviation += np.sum(a * a, 0)
-            count += np.sum(msk.astype(int), 0)
-        return i
-
-    def compute(self, dc_img, source_pattern):
-        pool = mp.Pool(self.nproc)
-        for modi, modno in self.iter_modules():
-            tm0 = time.perf_counter()
-            print(f"[{self._rank:2d}] {modno:3d}: ", end="")
-
-            source_id = source_pattern.format(
-                detector_id=self.detector_id, modno=modno)
-            dc = dc_img.select(
-                [(source_id, "image.data"), (source_id, "image.mask")],
-                require_all=True
-            )
-            self.reset(modi, source_id)
-            result_iter = pool.imap_unordered(
-                self._compute_worker_part,
-                enumerate(dc.split_trains(self.nproc))
-            )
-
-            for r in result_iter:
-                pass
-
-            self.finish()
-            self.write_module()
-
-            tm1 = time.perf_counter()
-            print(f"{tm1 - tm0:.1f} s")
-
-        pool.terminate()
-        pool.join()
-
-    def finish(self):
-        self.r = type("Moments", (), {})
-        self.count = np.sum(self.part_count, 0)
-        nz = self.count > 0
-
-        self.mean = np.zeros(self.shape, np.float32)
-        self.mean = np.divide(
-            np.sum(self.part_mean, 0), self.count,
-            out=self.mean, where=nz
-        )
-        self.mean /= self.px_area
-
-        self.deviation = np.zeros(self.shape, np.float32)
-        self.deviation = np.divide(
-            np.sum(self.part_deviation, 0), self.count,
-            out=self.deviation, where=nz
-        )
-        self.deviation = np.sqrt(self.deviation - self.mean * self.mean)
-        self.deviation /= self.px_area
-
-        self.num_frames = np.sum(self.part_num_frames)
-
-    def _create_h5datasets(self, h5f):
-        s = (self.num_modules,) + self.shape
-        self._ds_mean = h5f.create_dataset(
-            "powderSum/image/mean", s, dtype="f4")
-        self._ds_std = h5f.create_dataset(
-            "powderSum/image/std", s, dtype="f4")
-        self._ds_count = h5f.create_dataset(
-            "powderSum/image/count", s, dtype="i8")
-        self._ds_nfrm = h5f.create_dataset(
-            "powderSum/image/numFrames", self.num_modules, dtype="i8")
-        self._ds_mask = h5f.create_dataset(
-            "powderSum/image/mask", s, dtype="u1")
-
-    def _create_output_arrays(self):
-        n = self.modN - self.mod0
-        s = (n,) + self.shape
-        self._ds_mean = np.zeros(s, dtype=np.float32)
-        self._ds_std = np.zeros(s, dtype=np.float32)
-        self._ds_count = np.zeros(s, dtype=int)
-        self._ds_nfrm = np.zeros(n, dtype=int)
-        self._ds_mask = np.zeros(s, dtype=np.uint8)
-
-    def prepare(self, output_fn, mask=None, comm=None, conditions={}):
-        if comm is None:
-            size = 1
-            rank = 0
-        else:
-            size = comm.Get_size()
-            rank = comm.Get_rank()
-
-        self.mod0 = rank * self.num_modules // size
-        self.modN = (rank + 1) * self.num_modules // size
-
-        self._rank = rank
-        self._size = size
-        self._comm = comm
-
-        n = self.modN - self.mod0
-        if mask is None:
-            s = (n,) + self.shape
-            self.mask = np.ones(s, bool)
-        else:
-            self.mask = mask[self.mod0:self.modN] == 0
-
-        if rank == 0:
-            h5f = h5py.File(output_fn, "w")
-            # modules
-            h5f["powderSum/image/modules"] = self.modules
-            # conditions
-            conditions_grp = h5f.create_group("powderSum/conditions")
-            conditions_grp["detectorId"] = self.detector_id.encode("ascii")
-            for key, value in conditions.items():
-                conditions_grp[key] = value
-        else:
-            h5f = None
-
-        self._h5f = h5f
-
-        # average image
-        if size > 1:
-            self._create_output_arrays()
-        else:
-            self._create_h5datasets(h5f)
-
-    def iter_modules(self):
-        return enumerate(self.modules[self.mod0:self.modN], start=self.mod0)
-
-    def write_module(self):
-        i = self.modi - self.mod0
-        self._ds_mean[i] = self.mean
-        self._ds_std[i] = self.deviation
-        self._ds_count[i] = self.count
-        self._ds_nfrm[i] = self.num_frames
-        self._ds_mask[i] = ~self.mask[i] | (self.count == 0)
-
-    def _send(self):
-        self._comm.Gatherv(self._ds_mean, None, root=0)
-        self._comm.Gatherv(self._ds_std, None, root=0)
-        self._comm.Gatherv(self._ds_count, None, root=0)
-        self._comm.Gatherv(self._ds_nfrm, None, root=0)
-        self._comm.Gatherv(self._ds_mask, None, root=0)
-
-    def _recv_and_write(self):
-        s = (self.num_modules,) + self.shape
-        buf = np.zeros(s, dtype=np.float32)
-        self._comm.Gatherv(self._ds_mean, buf, root=0)
-        self._h5f["powderSum/image/mean"] = buf
-        self._comm.Gatherv(self._ds_std, buf, root=0)
-        self._h5f["powderSum/image/std"] = buf
-        buf = np.zeros(s, dtype=int)
-        self._comm.Gatherv(self._ds_count, buf, root=0)
-        self._h5f["powderSum/image/count"] = buf
-        buf = np.zeros(self.num_modules, dtype=int)
-        self._comm.Gatherv(self._ds_nfrm, buf, root=0)
-        self._h5f["powderSum/image/numFrames"] = buf
-        buf = np.zeros(s, dtype=np.uint8)
-        self._comm.Gatherv(self._ds_mask, buf, root=0)
-        self._h5f["powderSum/image/mask"] = buf
-
-    def flush(self):
-        if self._size <= 1:
-            return
-        if self._rank == 0:
-            self._recv_and_write()
-            self._h5f.close()
-        else:
-            self._send()
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        state["_ds_mean"] = None
-        state["_ds_std"] = None
-        state["_ds_count"] = None
-        state["_ds_nfrm"] = None
-        state["_ds_mask"] = None
-        state["_h5f"] = None
-        state["_comm"] = None
-        return state
-
-
 def main(argv=None):
     tm0 = time.perf_counter()