diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 98f732821f9ffe7e1a2c48240a8ab11452f26b0d..8291089ef3ff0b9ef42d38b945df470d107ee41a 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -309,9 +309,14 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): if self.delta_tof is not None: first = max(0, self.tof_start - self.delta_tof) last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof) - y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]], axis=1) + y = {channel: [item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]] for channel, item in X.items() if channel in self.channels} + # pad it with zeros, if we reach the edge of the array + for channel in y.keys(): + y[channel] = [np.pad(y[channel][j], (0, 2*delta_tof - len(y[channel][j]))) + for j in range(len(y[channel]))] + y[channel] = np.stack(y[channel], axis=1) if not keep_dictionary_structure: selected = list(y.values()) if pulse_energy is not None: @@ -883,6 +888,7 @@ class Model(TransformerMixin, BaseEstimator): n_components = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0] if len(n_components) > 0: n_components = n_components[0] + n_components = max(600, n_components) print(f"Using {n_components} comp. for PES PCA.") self.x_model.set_params(pca__n_components=n_components) @@ -896,6 +902,7 @@ class Model(TransformerMixin, BaseEstimator): n_components_hr = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0] if len(n_components_hr) > 0: n_components_hr = n_components_hr[0] + n_components_hr = max(20, n_components_hr) print(f"Using {n_components_hr} comp. for grating spec. PCA.") self.y_model.set_params(pca__n_components=n_components_hr)