From 73e78edea20e59f604ab7f82688115d99a8448d7 Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Thu, 17 Oct 2024 10:33:22 +0200
Subject: [PATCH] Fix according refactoring, add photon rounding and extra
 masks

---
 .../correction_addons/autocorrelation.py      | 110 +++++++++++++-----
 1 file changed, 79 insertions(+), 31 deletions(-)

diff --git a/src/calng/correction_addons/autocorrelation.py b/src/calng/correction_addons/autocorrelation.py
index 9cfbf7be..e9330a53 100644
--- a/src/calng/correction_addons/autocorrelation.py
+++ b/src/calng/correction_addons/autocorrelation.py
@@ -1,9 +1,12 @@
+import h5py
 import numpy as np
+from calngUtils import shmem_utils
+from karabo.bound import (
+    DOUBLE_ELEMENT, INT32_ELEMENT, NDARRAY_ELEMENT, STRING_ELEMENT,
+    VECTOR_STRING_ELEMENT)
 
-from karabo.bound import NDARRAY_ELEMENT, DOUBLE_ELEMENT, STRING_ELEMENT
-
+from .. import base_kernel_runner
 from .base_addon import BaseCorrectionAddon
-from .. import shmem_utils
 
 
 def block(b):
@@ -77,6 +80,11 @@ class Autocorrelation(BaseCorrectionAddon):
     @staticmethod
     def extend_device_schema(schema, prefix):
         (
+            INT32_ELEMENT(schema)
+            .key(f"{prefix}.moduleNumber")
+            .readOnly().initialValue(-1)
+            .commit(),
+
             STRING_ELEMENT(schema)
             .key(f"{prefix}.boundary")
             .tags("managed")
@@ -93,6 +101,22 @@ class Autocorrelation(BaseCorrectionAddon):
             .defaultValue(9.3)
             .reconfigurable()
             .commit(),
+
+            DOUBLE_ELEMENT(schema)
+            .key(f"{prefix}.roundingThreshold")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue(0.7)
+            .reconfigurable()
+            .commit(),
+
+            VECTOR_STRING_ELEMENT(schema)
+            .key(f"{prefix}.maskPaths")
+            .tags("managed")
+            .assignmentOptional()
+            .defaultValue([])
+            .reconfigurable()
+            .commit(),
         )
 
     @staticmethod
@@ -104,35 +128,51 @@ class Autocorrelation(BaseCorrectionAddon):
             .commit(),
         )
 
-    def __init__(self, config):
+    def __init__(self, device, prefix, config):
+        super().__init__(device, prefix, config)
+        self._shape = None
+        self._shmem_buffer = None
+
         global cupy
         import cupy
 
+        global fft
+        kernel_type = base_kernel_runner.KernelRunnerTypes[
+            device.unsafe_get("kernelType")
+        ]
+        if kernel_type is base_kernel_runner.KernelRunnerTypes.CPU:
+            import scipy.fft as fft
+        else:
+            import cupyx.scipy.fft as fft
+
         self._intensity_per_photon = config["intenstityPerPhoton"]
         self._autocorr = AUTOCORR_FUN[config["boundary"]]
+        self._rounding_threshold = config["roundingThreshold"]
 
-        self._shmem_buffer = None
+        da = device["constantParameters.karaboDa"]
+        self._modno = int(da[-2:])
+        device.set(f"{prefix}.moduleNumber", self._modno)
+
+        self._load_mask(config["maskPaths"])
 
     def __del__(self):
         del self._shmem_buffer
         super().__del__()
 
-    def _initialization(self):
-        global fft
-        if self._device.kernel_runner._gpu_based:
-            import cupyx.scipy.fft as fft
-        else:
-            import scipy.fft as fft
-
-        self._update_buffer()
-
-    def post_correction(self, processed_data, cell_table, pulse_table, output_hash):
-        nf, nx, ny = processed_data.shape
-        if nf != self._nframe:
-            self._update_buffer()
+    def post_correction(self, train_id, processed_data, cell_table, pulse_table, output_hash):
+        shape = processed_data.shape
+        if shape != self._shape:
+            self._update_buffer(shape)
 
-        data = np.around(processed_data / self._intensity_per_photon)
-        data[np.isnan(data) | (data < 0.0)] = 0.0
+        np.around(
+            processed_data / self._intensity_per_photon - (self._rounding_threshold - 0.5),
+            out=processed_data
+        )
+        processed_data[processed_data < 0.0] = 0.0
+        if self._mask is not None:
+            processed_data[:, self._mask] = np.nan
+        data = processed_data.copy()
+        data[np.isnan(data)] = 0.0
 
         autocorr = self._autocorr(data)
 
@@ -150,23 +190,18 @@ class Autocorrelation(BaseCorrectionAddon):
         else:
             output_hash.set("image.autocorr", buffer_array)
 
-        delete_paths = output_hash.get("deleteThese", default=[])
-        output_hash.set("deleteThese",
-                        delete_paths + ["image.autocorr"])
-
-    # def post_reshape(self, reshaped_data, cell_table, pulse_table, output_hash):
-    #     pass
-
     def reconfigure(self, changed_config):
         if changed_config.has("intenstityPerPhoton"):
             self._intensity_per_photon = changed_config["intenstityPerPhoton"]
         if changed_config.has("boundary"):
             self._autocorr = AUTOCORR_FUN[changed_config["boundary"]]
+        if changed_config.has("roundingThreshold"):
+            self._rounding_threshold = changed_config["roundingThreshold"]
+        if changed_config.has("maskPaths"):
+            self._load_mask(changed_config["maskPaths"])
 
-    def _update_buffer(self):
-        nf, _, nx, ny = self._device.input_data_shape
-        self._nframe = nf
-        shape = (nf, nx, ny)
+    def _update_buffer(self, shape):
+        self._shape = shape
 
         if self._shmem_buffer is None:
             shmem_buffer_name = self._device.getInstanceId() + ":dataOutput/autocorrelation"
@@ -181,3 +216,16 @@ class Autocorrelation(BaseCorrectionAddon):
                 self._shmem_buffer.cuda_pin()
         else:
             self._shmem_buffer.change_shape(shape)
+
+    def _load_mask(self, paths):
+        self._mask = None
+        for fn in paths:
+            try:
+                with h5py.File(fn, 'r') as f:
+                    mask = f["entry_1/data_1/mask"][self._modno] != 0
+                if self._mask is None:
+                    self._mask = mask
+                else:
+                    self._mask = self._mask | mask
+            except Exception as e:
+                self._device.logger.warn(f"Aurocorrelation addon: {e}")
-- 
GitLab