From a9552016028dcbcbaca0b86abb75e7e308af15fd Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 11 Oct 2023 08:41:59 +0200
Subject: [PATCH] Pad arrays with zeros when in multi-bunch mode.

---
 pes_to_spec/model.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 98f7328..8291089 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -309,9 +309,14 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         if self.delta_tof is not None:
             first = max(0, self.tof_start - self.delta_tof)
             last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
-            y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]], axis=1)
+            y = {channel: [item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]]
                  for channel, item in X.items()
                    if channel in self.channels}
+            # zero-pad slices that were truncated at the edge of the array, so they can be stacked
+            for channel in y.keys():
+                y[channel] = [np.pad(y[channel][j], (0, 2*delta_tof - len(y[channel][j])))
+                              for j in range(len(y[channel]))]
+                y[channel] = np.stack(y[channel], axis=1)
         if not keep_dictionary_structure:
             selected = list(y.values())
             if pulse_energy is not None:
@@ -883,6 +888,7 @@ class Model(TransformerMixin, BaseEstimator):
         n_components = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0]
         if len(n_components) > 0:
             n_components = n_components[0]
+        n_components = max(600, n_components)
 
         print(f"Using {n_components} comp. for PES PCA.")
         self.x_model.set_params(pca__n_components=n_components)
@@ -896,6 +902,7 @@ class Model(TransformerMixin, BaseEstimator):
         n_components_hr = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0]
         if len(n_components_hr) > 0:
             n_components_hr = n_components_hr[0]
+        n_components_hr = max(20, n_components_hr)
 
         print(f"Using {n_components_hr} comp. for grating spec. PCA.")
         self.y_model.set_params(pca__n_components=n_components_hr)
-- 
GitLab