From 9d04d20e6c287a67e527a5b4371dd0040458799a Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Fri, 17 Mar 2023 15:23:06 +0100
Subject: [PATCH] (Ab)use entry point extras to specify detectors

TODO: add boilerplate to remaining detectors.  For now, AGIPD will have both
IntetgratedIntensityAddon and RandomFramesAddon whereas DSSC only has RandomFramesAddon.
---
 setup.py                                      |  4 ++--
 src/calng/base_correction.py                  | 14 +++++++-------
 src/calng/correction_addons/base_addon.py     |  1 -
 .../correction_addons/integrated_intensity.py |  1 -
 src/calng/corrections/AgipdCorrection.py      | 16 ++++++++--------
 src/calng/corrections/DsscCorrection.py       | 19 +++++++++++++------
 6 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/setup.py b/setup.py
index 64e89df5..6a674d21 100644
--- a/setup.py
+++ b/setup.py
@@ -51,8 +51,8 @@ setup(name='calng',
               'RoiTool = calng.RoiTool:RoiTool',
           ],
 
-          'karabo.calng_correction_addon': [
-              'IntegratedIntensity = calng.correction_addons.integrated_intensity:IntegratedIntensityAddon',
+          'calng.correction_addon': [
+              'IntegratedIntensity = calng.correction_addons.integrated_intensity:IntegratedIntensityAddon [agipd]',
               'RandomFrames = calng.correction_addons.random_frames:RandomFramesAddon',
           ]
       },
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 45e1830a..9fe8e7c9 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -1428,13 +1428,13 @@ def add_preview_outputs(schema, channels):
 
 
 def add_addon_nodes(schema, device_class, prefix="addons"):
-    for addon in entry_points().get("karabo.calng_correction_addon", []):
-        addon_class = addon.load()
-        if (
-            "*" in addon_class._target_devices
-            or device_class.__name__ in addon_class._target_devices
-        ):
-            device_class._available_addons.append(addon_class)
+    det_name = device_class.__name__[:-len("Correction")].lower()
+    device_class._available_addons = [
+        addon.load()
+        for addon in entry_points().get("calng.correction_addon", [])
+        if not addon.extras
+        or det_name in (extra[0] for extra in addon.extras)
+    ]
 
     for addon_class in device_class._available_addons:
         (
diff --git a/src/calng/correction_addons/base_addon.py b/src/calng/correction_addons/base_addon.py
index 95e330ca..72800488 100644
--- a/src/calng/correction_addons/base_addon.py
+++ b/src/calng/correction_addons/base_addon.py
@@ -1,5 +1,4 @@
 class BaseCorrectionAddon:
-    _target_devices = {"*"}  # if not "*", then full correction device names
     _node_name = None  # subclass must set (usually name of addon minus "Addon" suffix)
 
     @staticmethod
diff --git a/src/calng/correction_addons/integrated_intensity.py b/src/calng/correction_addons/integrated_intensity.py
index 247bebd9..61b1ab5b 100644
--- a/src/calng/correction_addons/integrated_intensity.py
+++ b/src/calng/correction_addons/integrated_intensity.py
@@ -11,7 +11,6 @@ def maybe_get(a):
 
 
 class IntegratedIntensityAddon(BaseCorrectionAddon):
-    _target_devices = {"AgipdCorrection"}
     _node_name = "integratedIntensity"
 
     @staticmethod
diff --git a/src/calng/corrections/AgipdCorrection.py b/src/calng/corrections/AgipdCorrection.py
index 33faafed..00df2f29 100644
--- a/src/calng/corrections/AgipdCorrection.py
+++ b/src/calng/corrections/AgipdCorrection.py
@@ -553,7 +553,15 @@ class AgipdCorrection(base_correction.BaseCorrection):
 
         # this is not automatically done by superclass for complicated class reasons
         base_correction.add_preview_outputs(expected, AgipdCorrection._preview_outputs)
+        base_correction.add_correction_step_schema(
+            expected,
+            AgipdCorrection._managed_keys,
+            AgipdCorrection._correction_steps,
+        )
         base_correction.add_addon_nodes(expected, AgipdCorrection)
+        base_correction.add_bad_pixel_config_node(
+            expected, AgipdCorrection._managed_keys
+        )
         AgipdCalcatFriend.add_schema(expected, AgipdCorrection._managed_keys)
         (
             # support both CPU and GPU kernels
@@ -571,14 +579,6 @@ class AgipdCorrection(base_correction.BaseCorrection):
             .commit(),
         )
         AgipdCorrection._managed_keys.add("kernelType")
-        base_correction.add_correction_step_schema(
-            expected,
-            AgipdCorrection._managed_keys,
-            AgipdCorrection._correction_steps,
-        )
-        base_correction.add_bad_pixel_config_node(
-            expected, AgipdCorrection._managed_keys
-        )
 
         # turn off the force MG / HG steps by default
         for step in ("forceMgIfBelow", "forceHgIfBelow"):
diff --git a/src/calng/corrections/DsscCorrection.py b/src/calng/corrections/DsscCorrection.py
index 4cb71797..43e3ff4a 100644
--- a/src/calng/corrections/DsscCorrection.py
+++ b/src/calng/corrections/DsscCorrection.py
@@ -220,7 +220,15 @@ class DsscCorrection(base_correction.BaseCorrection):
             .setNewDefaultValue("pulse")
             .commit(),
         )
+        base_correction.add_preview_outputs(expected, DsscCorrection._preview_outputs)
+        base_correction.add_correction_step_schema(
+            expected,
+            DsscCorrection._managed_keys,
+            DsscCorrection._correction_steps,
+        )
+        base_correction.add_addon_nodes(expected, DsscCorrection)
         DsscCalcatFriend.add_schema(expected, DsscCorrection._managed_keys)
+
         (
             # support both CPU and GPU kernels
             STRING_ELEMENT(expected)
@@ -237,12 +245,7 @@ class DsscCorrection(base_correction.BaseCorrection):
             .commit(),
         )
         DsscCorrection._managed_keys.add("kernelType")
-        base_correction.add_preview_outputs(expected, DsscCorrection._preview_outputs)
-        base_correction.add_correction_step_schema(
-            expected,
-            DsscCorrection._managed_keys,
-            DsscCorrection._correction_steps,
-        )
+
         (
             VECTOR_STRING_ELEMENT(expected)
             .key("managedKeys")
@@ -302,6 +305,10 @@ class DsscCorrection(base_correction.BaseCorrection):
 
         buffer_handle, buffer_array = self._shmem_buffer.next_slot()
         self.kernel_runner.correct(self._correction_flag_enabled)
+        for _, addon in self._enabled_addons:
+            addon.post_correction(
+                self.kernel_runner.processed_data, cell_table, pulse_table, data_hash
+            )
         self.kernel_runner.reshape(
             output_order=self.unsafe_get("dataFormat.outputAxisOrder"),
             out=buffer_array,
-- 
GitLab