From 095cb4ea69743668d0ee3c7d44066e087ec524aa Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 16 Oct 2023 16:52:34 +0200
Subject: [PATCH] Estimate pedestal per channel and reduce memory usage by
 freeing unused variables as soon as possible. Not important usually, but
 avoids OOM crashes when dealing with very long traces.

---
 pes_to_spec/model.py | 28 ++++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 911ddd4..8d53d57 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Optional, Union, Tuple, Literal
 
+import sys
 import joblib
 
 import numpy as np
@@ -282,6 +283,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         self.poly = poly
         self.mean = dict()
         self.std = dict()
+        self.pedestal = {ch: 0.0 for ch in self.channels}
 
     def transform(self, X: Dict[str, np.ndarray],
                   keep_dictionary_structure: bool=False,
@@ -309,7 +311,8 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         if self.delta_tof is not None:
             first = max(0, self.tof_start - self.delta_tof)
             last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
-            y = {channel: [item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]]
+            y = {channel: [item[:, (first + delta):(last + delta)] - self.pedestal[channel]
+                           for delta in pulse_spacing[channel]]
                  for channel, item in X.items()
                    if channel in self.channels}
             # pad it with zeros, if we reach the edge of the array
@@ -336,15 +339,14 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         Returns: The index.
         """
         # reduce on channel and on train ID
-        sum_low_res = - np.mean(sum(list(X.values())), axis=0)
-        zero_level = np.amin(sum_low_res)
+        sum_low_res = np.mean(sum([-(v - self.pedestal[ch]) for ch, v in X.items()]), axis=0)
         axis = np.arange(0.0, sum_low_res.shape[0], 1.0)
         #widths = np.arange(10, 50, step=5)
         #peak_idx = find_peaks_cwt(sum_low_res, widths)
         gaussian = np.exp(-0.5*(axis - sum_low_res.shape[0]//2)**2/20**2)
         gaussian /= np.sum(gaussian, axis=0, keepdims=True)
         # apply it to the data
-        smoothened = fftconvolve(sum_low_res - zero_level, gaussian, mode="same", axes=0)
+        smoothened = fftconvolve(sum_low_res, gaussian, mode="same", axes=0)
         peak_idx = [np.argmax(smoothened)]
         if len(peak_idx) < 1:
             raise PromptNotFoundError()
@@ -368,6 +370,9 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
 
         Returns: The object itself.
         """
+        # estimate pedestal
+        X_pedestal = {ch: np.mean(v[:10])
+                      for ch, v in X.items()}
         self.tof_start = self.estimate_prompt_peak(X)
         X_tr = self.transform(X, keep_dictionary_structure=True)
         self.mean = {ch: np.mean(X_tr[ch], axis=0, keepdims=True)
@@ -864,6 +869,11 @@ class Model(TransformerMixin, BaseEstimator):
         print("Finding region-of-interest")
         self.x_select.fit(low_res_data)
         low_res_select = self.x_select.transform(low_res_data, pulse_energy=pulse_energy)
+        low_res_selected_dict = self.x_select.transform(low_res_data, keep_dictionary_structure=True)
+
+        # after this we do not need low_res_data anymore
+        del low_res_data
+
         # keep the number of pulses
         B, P, _ = low_res_select.shape
         low_res_select = low_res_select.reshape((B*P, -1))
@@ -874,6 +884,8 @@ class Model(TransformerMixin, BaseEstimator):
         cov = EllipticEnvelope(random_state=0).fit(lr_sums)
         good_low_res = cov.predict(lr_sums)
         filter_lr = (good_low_res > 0)
+        del cov
+
         low_res_filter = low_res_select[filter_hr & filter_lr, :]
         high_res_filter = high_res_data[filter_hr & filter_lr, :]
         weights_filter = weights
@@ -890,6 +902,7 @@ class Model(TransformerMixin, BaseEstimator):
         if len(n_components) > 0:
             n_components = n_components[0]
         n_components = max(600, n_components)
+        del pca_test
 
         print(f"Using {n_components} comp. for PES PCA.")
         self.x_model.set_params(pca__n_components=n_components)
@@ -904,6 +917,7 @@ class Model(TransformerMixin, BaseEstimator):
         if len(n_components_hr) > 0:
             n_components_hr = n_components_hr[0]
         n_components_hr = max(20, n_components_hr)
+        del pca_test
 
         print(f"Using {n_components_hr} comp. for grating spec. PCA.")
         self.y_model.set_params(pca__n_components=n_components_hr)
@@ -1003,12 +1017,10 @@ class Model(TransformerMixin, BaseEstimator):
             return high_res.reshape((B, P, -1))
 
         # for consistency check per channel
-        selection_model = self.x_select
-        low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True)
         for channel in self.get_channels():
-            B, P, _ = low_res_selected[channel].shape
+            B, P, _ = low_res_selected_dict[channel].shape
             print(f"Calculate PCA on {channel}")
-            low_pca = self.channel_pca[channel].fit_transform(low_res_selected[channel].reshape(B*P, -1))
+            low_pca = self.channel_pca[channel].fit_transform(low_res_selected_dict[channel].reshape(B*P, -1))
             self.ood[channel].fit(low_pca)
 
         print("End of fit.")
-- 
GitLab