diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index b696dba65c23de698e0740f721bb639f4062597e..bf6187769394a3d9e3d00184c26cd0b41ff58413 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -520,19 +520,20 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): y[channel] = np.stack(y[channel], axis=1) # roll data in the first pulse to check best shift - shifts = np.arange(-3, 3+1) - chi2 = {ch: np.zeros_like(shifts, dtype=np.float32) - for ch in y.keys()} - for i, s in enumerate(shifts): - meanX = {ch: np.mean(np.roll(y[ch][:,0,:], s, axis=-1), axis=0, keepdims=True) - for ch in y.keys()} - for ch in y.keys(): - chi2[ch] = np.sum(((meanX[ch] - self.mean[ch][:,0,:])/self.std[ch][:,0,:])**2) - shift = {ch: shifts[np.argmin(chi2[ch])] - for ch in y.keys()} - print("Shifts", shift) - y = {ch: np.roll(y[ch], shift[ch], axis=-1) - for ch in y.keys()} + if len(self.mean) > 0: + shifts = np.arange(-3, 3+1) + chi2 = {ch: np.zeros_like(shifts, dtype=np.float32) + for ch in y.keys()} + for i, s in enumerate(shifts): + meanX = {ch: np.mean(np.roll(y[ch][:,:1,:], s, axis=-1), axis=(0, 1), keepdims=True) + for ch in y.keys()} + for ch in y.keys(): + chi2[ch][i] = np.sum(((meanX[ch] - self.mean[ch])/self.std[ch])**2) + shift = {ch: shifts[np.argmin(chi2[ch])] + for ch in y.keys()} + print("Shifts", shift) + y = {ch: np.roll(y[ch], shift[ch], axis=-1) + for ch in y.keys()} if not keep_dictionary_structure: selected = list(y.values())