From 9da93ac8fc6c32ece8fe1ed15aa77341a69549be Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 22 May 2024 18:47:33 +0200
Subject: [PATCH] Fixed roll.

---
 pes_to_spec/model.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index b696dba..bf61877 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -520,19 +520,20 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
                 y[channel] = np.stack(y[channel], axis=1)
 
         # roll data in the first pulse to check best shift
-        shifts = np.arange(-3, 3+1)
-        chi2 = {ch: np.zeros_like(shifts, dtype=np.float32)
-                for ch in y.keys()}
-        for i, s in enumerate(shifts):
-            meanX = {ch: np.mean(np.roll(y[ch][:,0,:], s, axis=-1), axis=0, keepdims=True)
-                     for ch in y.keys()}
-            for ch in y.keys():
-                chi2[ch] = np.sum(((meanX[ch] - self.mean[ch][:,0,:])/self.std[ch][:,0,:])**2)
-        shift = {ch: shifts[np.argmin(chi2[ch])]
-                for ch in y.keys()}
-        print("Shifts", shift)
-        y = {ch: np.roll(y[ch], shift[ch], axis=-1)
-             for ch in y.keys()}
+        if len(self.mean) > 0:
+            shifts = np.arange(-3, 3+1)
+            chi2 = {ch: np.zeros_like(shifts, dtype=np.float32)
+                    for ch in y.keys()}
+            for i, s in enumerate(shifts):
+                meanX = {ch: np.mean(np.roll(y[ch][:,:1,:], s, axis=-1), axis=(0, 1), keepdims=True)
+                         for ch in y.keys()}
+                for ch in y.keys():
+                    chi2[ch][i] = np.sum(((meanX[ch] - self.mean[ch])/self.std[ch])**2)
+            shift = {ch: shifts[np.argmin(chi2[ch])]
+                    for ch in y.keys()}
+            print("Shifts", shift)
+            y = {ch: np.roll(y[ch], shift[ch], axis=-1)
+                 for ch in y.keys()}
 
         if not keep_dictionary_structure:
             selected = list(y.values())
-- 
GitLab