From 6c2e51cb02024bd21bd77bc249bd3e2acb3c137d Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Tue, 18 Apr 2023 18:06:35 +0200 Subject: [PATCH] Deal with both BNN and classical model. --- pes_to_spec/model.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 39b6c4b..214ce31 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -551,6 +551,7 @@ class Model(TransformerMixin, BaseEstimator): self.fit_model = BNNModel() else: self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) + self.bnn = bnn self.kde_xgm = None self.mu_xgm = np.nan @@ -912,7 +913,7 @@ class Model(TransformerMixin, BaseEstimator): joblib.dump([self.x_select, self.x_model, self.y_model, - self.fit_model.state_dict(), + self.fit_model.state_dict() if self.bnn else self.fit_model, self.channel_pca, #self.channel_fit_model DataHolder(dict( @@ -925,6 +926,7 @@ class Model(TransformerMixin, BaseEstimator): resolution=self.resolution, transfer_function=self.transfer_function, impulse_response=self.impulse_response, + bnn=self.bnn, ) ), self.ood, @@ -948,15 +950,8 @@ class Model(TransformerMixin, BaseEstimator): ood, kde_xgm, ) = joblib.load(filename) + obj = Model() - obj.x_select = x_select - obj.x_model = x_model - obj.y_model = y_model - obj.fit_model = BNNModel(state_dict=fit_model) - obj.channel_pca = channel_pca - #obj.channel_fit_model = channel_fit_model - obj.ood = ood - obj.kde_xgm = kde_xgm extra = extra.get_data() obj.mu_xgm = extra["mu_xgm"] @@ -968,5 +963,19 @@ class Model(TransformerMixin, BaseEstimator): obj.resolution = extra["resolution"] obj.transfer_function = extra["transfer_function"] obj.impulse_response = extra["impulse_response"] + obj.bnn = extra["bnn"] + + obj.x_select = x_select + obj.x_model = x_model + obj.y_model = y_model + if obj.bnn: + obj.fit_model = BNNModel(state_dict=fit_model) + else: + obj.fit_model = fit_model + obj.channel_pca = channel_pca + #obj.channel_fit_model = channel_fit_model + obj.ood = ood + obj.kde_xgm = kde_xgm + return obj -- GitLab