diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 39b6c4b7a7267ef4cddb0dc3462845f6b14da377..214ce31125dad632724c6ef8e71ae1d48247b78e 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -551,6 +551,7 @@ class Model(TransformerMixin, BaseEstimator): self.fit_model = BNNModel() else: self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) + self.bnn = bnn self.kde_xgm = None self.mu_xgm = np.nan @@ -912,7 +913,7 @@ class Model(TransformerMixin, BaseEstimator): joblib.dump([self.x_select, self.x_model, self.y_model, - self.fit_model.state_dict(), + self.fit_model.state_dict() if self.bnn else self.fit_model, self.channel_pca, #self.channel_fit_model DataHolder(dict( @@ -925,6 +926,7 @@ class Model(TransformerMixin, BaseEstimator): resolution=self.resolution, transfer_function=self.transfer_function, impulse_response=self.impulse_response, + bnn=self.bnn, ) ), self.ood, @@ -948,15 +950,8 @@ class Model(TransformerMixin, BaseEstimator): ood, kde_xgm, ) = joblib.load(filename) + obj = Model() - obj.x_select = x_select - obj.x_model = x_model - obj.y_model = y_model - obj.fit_model = BNNModel(state_dict=fit_model) - obj.channel_pca = channel_pca - #obj.channel_fit_model = channel_fit_model - obj.ood = ood - obj.kde_xgm = kde_xgm extra = extra.get_data() obj.mu_xgm = extra["mu_xgm"] @@ -968,5 +963,19 @@ class Model(TransformerMixin, BaseEstimator): obj.resolution = extra["resolution"] obj.transfer_function = extra["transfer_function"] obj.impulse_response = extra["impulse_response"] + obj.bnn = extra["bnn"] + + obj.x_select = x_select + obj.x_model = x_model + obj.y_model = y_model + if obj.bnn: + obj.fit_model = BNNModel(state_dict=fit_model) + else: + obj.fit_model = fit_model + obj.channel_pca = channel_pca + #obj.channel_fit_model = channel_fit_model + obj.ood = ood + obj.kde_xgm = kde_xgm + return obj