diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 564beaba948b346a78248a1c627905886a1cf048..d45e477cfc0e7a8af9f8e3d72bee72cc0adfd6de 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -625,9 +625,8 @@ class Model(TransformerMixin, BaseEstimator): selection_model = self.x_model['select'] low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True) for channel in self.get_channels(): - pca_model = self.channel_pca_model[channel].named_steps["pca"] - low_pca = pca_model.transform(low_res) - low_pca_rec = pca_model.inverse_transform(low_pca) + low_pca = self.channel_pca_model[channel].named_steps["pca"].fit_transform(low_res[channel]) + low_pca_rec = self.channel_pca_model[channel].named_steps["pca"].inverse_transform(low_pca) low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True) self.channel_pca_model[channel]['unc'].set_uncertainty(low_pca_unc) @@ -728,8 +727,9 @@ class Model(TransformerMixin, BaseEstimator): """ joblib.dump([self.x_model, self.y_model, - self.fit_model], - filename, compress='zlib') + self.fit_model, + self.channel_pca_model + ], filename, compress='zlib') @staticmethod def load(filename: str) -> Model: @@ -741,10 +741,11 @@ class Model(TransformerMixin, BaseEstimator): Returns: A new model object. """ - x_model, y_model, fit_model = joblib.load(filename) + x_model, y_model, fit_model, channel_pca_model = joblib.load(filename) obj = Model() obj.x_model = x_model obj.y_model = y_model obj.fit_model = fit_model + obj.channel_pca_model = channel_pca_model return obj