From 9da93ac8fc6c32ece8fe1ed15aa77341a69549be Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Wed, 22 May 2024 18:47:33 +0200 Subject: [PATCH] Fixed roll. --- pes_to_spec/model.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index b696dba..bf61877 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -520,19 +520,20 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): y[channel] = np.stack(y[channel], axis=1) # roll data in the first pulse to check best shift - shifts = np.arange(-3, 3+1) - chi2 = {ch: np.zeros_like(shifts, dtype=np.float32) - for ch in y.keys()} - for i, s in enumerate(shifts): - meanX = {ch: np.mean(np.roll(y[ch][:,0,:], s, axis=-1), axis=0, keepdims=True) - for ch in y.keys()} - for ch in y.keys(): - chi2[ch] = np.sum(((meanX[ch] - self.mean[ch][:,0,:])/self.std[ch][:,0,:])**2) - shift = {ch: shifts[np.argmin(chi2[ch])] - for ch in y.keys()} - print("Shifts", shift) - y = {ch: np.roll(y[ch], shift[ch], axis=-1) - for ch in y.keys()} + if len(self.mean) > 0: + shifts = np.arange(-3, 3+1) + chi2 = {ch: np.zeros_like(shifts, dtype=np.float32) + for ch in y.keys()} + for i, s in enumerate(shifts): + meanX = {ch: np.mean(np.roll(y[ch][:,:1,:], s, axis=-1), axis=(0, 1), keepdims=True) + for ch in y.keys()} + for ch in y.keys(): + chi2[ch][i] = np.sum(((meanX[ch] - self.mean[ch])/self.std[ch])**2) + shift = {ch: shifts[np.argmin(chi2[ch])] + for ch in y.keys()} + print("Shifts", shift) + y = {ch: np.roll(y[ch], shift[ch], axis=-1) + for ch in y.keys()} if not keep_dictionary_structure: selected = list(y.values()) -- GitLab