From a9552016028dcbcbaca0b86abb75e7e308af15fd Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Wed, 11 Oct 2023 08:41:59 +0200 Subject: [PATCH] Pad arrays with zeros when in multi-bunch mode. --- pes_to_spec/model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 98f7328..8291089 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -309,9 +309,14 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): if self.delta_tof is not None: first = max(0, self.tof_start - self.delta_tof) last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof) - y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]], axis=1) + y = {channel: [item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]] for channel, item in X.items() if channel in self.channels} + # pad it with zeros, if we reach the edge of the array + for channel in y.keys(): + y[channel] = [np.pad(y[channel][j], (0, 2*delta_tof - len(y[channel][j]))) + for j in range(len(y[channel]))] + y[channel] = np.stack(y[channel], axis=1) if not keep_dictionary_structure: selected = list(y.values()) if pulse_energy is not None: @@ -883,6 +888,7 @@ class Model(TransformerMixin, BaseEstimator): n_components = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0] if len(n_components) > 0: n_components = n_components[0] + n_components = max(600, n_components) print(f"Using {n_components} comp. for PES PCA.") self.x_model.set_params(pca__n_components=n_components) @@ -896,6 +902,7 @@ class Model(TransformerMixin, BaseEstimator): n_components_hr = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0] if len(n_components_hr) > 0: n_components_hr = n_components_hr[0] + n_components_hr = max(20, n_components_hr) print(f"Using {n_components_hr} comp. for grating spec. PCA.") self.y_model.set_params(pca__n_components=n_components_hr) -- GitLab