diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index 262defdcbdfd1b0b3297527babf5b45900030c77..1783c0b81dd2bd37726a3503db44a3c3d8a3da34 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -28,6 +28,7 @@ from cal_tools.agipdutils import (
     melt_snowy_pixels,
 )
 from cal_tools.enums import AgipdGainMode, BadPixels, SnowResolution
+from cal_tools.litpx_counter import LitPixelCounter
 from logging import warning
 
 
@@ -1542,71 +1543,6 @@ class AgipdCorrections:
             counter.process(slice(first, last))
 
 
-class LitPixelCounter:
-    channel = "litpx"
-    output_fields = [
-        "cellId", "pulseId", "trainId", "litPixels", "goodPixels"]
-
-    def __init__(self, data, threshold=0.8):
-        self.data = data.copy()
-        for name in ["data", "mask", "cellId", "pulseId", "trainId"]:
-            assert name in data
-
-        self.image = data["data"]
-        self.mask = data["mask"]
-
-        self.threshold = threshold
-        self.max_images = data["data"].shape[0]
-        self.num_images = self.max_images
-
-        self.num_good_px = sharedmem.full(self.max_images, 0, int)
-        self.num_lit_px = sharedmem.full(self.max_images, 0, int)
-
-        self.data["litPixels"] = self.num_lit_px
-        self.data["goodPixels"] = self.num_good_px
-
-    def set_num_images(self, num_images):
-        self.num_images = num_images
-
-    def process(self, chunk):
-        ix = range(*chunk.indices(self.num_images))
-        for i in ix:
-            mask = self.mask[i] == 0
-            self.num_lit_px[i] = np.sum(
-                self.image[i] > self.threshold, initial=0, where=mask)
-            self.num_good_px[i] = np.sum(mask)
-
-    def create_schema(self, source, file_trains=None, count=None):
-        if file_trains is None:
-            file_trains = source.file["INDEX/trainId"][:]
-
-        if count is None:
-            tid = self.data["trainId"][:self.num_images]
-            trains, count = np.unique(tid, return_counts=True)
-            count = count[np.in1d(trains, file_trains)]
-
-        if len(file_trains) != len(count):
-            raise ValueError(
-                "The length of data count does not match the number of trains")
-        if np.sum(count) != self.num_images:
-            raise ValueError(
-                "The sum of data count does not match "
-                "the total number of data entries")
-
-        source.create_index(**{self.channel: count})
-        for key in self.output_fields:
-            source.create_dataset(
-                f"{self.channel}/{key}",
-                shape=(self.num_images,),
-                dtype=self.data[key].dtype
-            )
-
-    def write(self, source):
-        channel = source[self.channel]
-        for key in self.output_fields:
-            channel[key][:] = self.data[key][:self.num_images]
-
-
 def validate_selected_pulses(
     max_pulses: List[int], max_cells: int
 ) -> List[int]:
diff --git a/src/cal_tools/litpx_counter.py b/src/cal_tools/litpx_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..660ddbf72eaea1a42dc083925d09e08e0188a650
--- /dev/null
+++ b/src/cal_tools/litpx_counter.py
@@ -0,0 +1,128 @@
+import numpy as np
+import sharedmem
+
+
+class AnalysisAddon:
+    """Base class for per-image analysis addons that write an extra
+    output channel alongside corrected detector data.
+
+    Subclasses override `process()` and extend `output_fields` /
+    `required_data` as needed.
+    """
+    # Name of the output channel in the data source.
+    channel = "data"
+    # Fields written out by `write()`.
+    output_fields = [
+        "cellId", "pulseId", "trainId"]
+    # Fields that must be present in the `data` mapping.
+    required_data = [
+        "cellId", "pulseId", "trainId"]
+
+    def __init__(self, data):
+        """Validate `data` and take a shallow copy of the mapping.
+
+        :param data: mapping of field name to array-like; must contain
+            every field in `required_data` (plus "pulseId", which sizes
+            the image axis).
+        :raises ValueError: if a required field is absent.
+        """
+        # "pulseId" is always needed to size the image axis.
+        required_data = set(self.required_data) | {"pulseId"}
+        for name in required_data:
+            if name not in data:
+                raise ValueError(f"The field '{name}' is missing in 'data'")
+
+        # Shallow copy: subclasses may add derived fields without
+        # mutating the caller's mapping.
+        self.data = data.copy()
+        self.max_images = data["pulseId"].shape[0]
+        self.num_images = self.max_images
+
+    def set_num_images(self, num_images):
+        """Set the number of valid images (may be below capacity)."""
+        self.num_images = num_images
+
+    def process(self, chunk):
+        """Analyse the images selected by slice `chunk`. Override."""
+        raise NotImplementedError
+
+    def create_schema(self, source, file_trains=None, count=None):
+        """Create the output datasets and index for this channel.
+
+        :param source: output writer exposing `file`, `create_index`
+            and `create_dataset`.
+        :param file_trains: train IDs in the file; read from
+            "INDEX/trainId" if not given.
+        :param count: entries per train; computed from the data's
+            train IDs if not given.
+        :raises ValueError: if `count` is inconsistent with the trains
+            or the total number of entries.
+        """
+        if file_trains is None:
+            file_trains = source.file["INDEX/trainId"][:]
+
+        if count is None:
+            tid = self.data["trainId"][:self.num_images]
+            trains, count = np.unique(tid, return_counts=True)
+            # np.isin supersedes the deprecated np.in1d
+            count = count[np.isin(trains, file_trains)]
+
+        if len(file_trains) != len(count):
+            raise ValueError(
+                "The length of data count does not match the number of trains")
+        if np.sum(count) != self.num_images:
+            raise ValueError(
+                "The sum of data count does not match "
+                "the total number of data entries")
+
+        source.create_index(**{self.channel: count})
+        for key in self.output_fields:
+            source.create_dataset(
+                f"{self.channel}/{key}",
+                shape=(self.num_images,),
+                dtype=self.data[key].dtype
+            )
+
+    def write(self, source):
+        """Write the first `num_images` entries of each output field."""
+        channel = source[self.channel]
+        for key in self.output_fields:
+            channel[key][:] = self.data[key][:self.num_images]
+
+
+class LitPixelCounter(AnalysisAddon):
+    """Counts lit (above-threshold) and good (unmasked) pixels per image."""
+    channel = "litpx"
+    output_fields = [
+        "cellId", "pulseId", "trainId", "litPixels", "goodPixels"]
+    required_data = [
+        "data", "mask", "cellId", "pulseId", "trainId"]
+
+    def __init__(self, data, threshold=0.8):
+        """
+        :param data: mapping with image array "data", bad-pixel "mask"
+            and the ID fields.
+        :param threshold: intensity above which a pixel counts as lit.
+        """
+        super().__init__(data)
+
+        self.image = data["data"]
+        self.mask = data["mask"]
+
+        self.threshold = threshold
+        # Shared-memory counters so worker processes can fill them.
+        self.num_good_px = sharedmem.full(self.max_images, 0, int)
+        self.num_lit_px = sharedmem.full(self.max_images, 0, int)
+
+        self.data["litPixels"] = self.num_lit_px
+        self.data["goodPixels"] = self.num_good_px
+
+    def process(self, chunk):
+        """Count lit and good pixels for each image in slice `chunk`."""
+        ix = range(*chunk.indices(self.num_images))
+        for i in ix:
+            # mask == 0 marks good pixels
+            mask = self.mask[i] == 0
+            self.num_lit_px[i] = np.sum(
+                self.image[i] > self.threshold, initial=0, where=mask)
+            self.num_good_px[i] = np.sum(mask)