From 6c2e51cb02024bd21bd77bc249bd3e2acb3c137d Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 18 Apr 2023 18:06:35 +0200
Subject: [PATCH] Deal with both BNN and classical model.

---
 pes_to_spec/model.py | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 39b6c4b..214ce31 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -551,6 +551,7 @@ class Model(TransformerMixin, BaseEstimator):
             self.fit_model = BNNModel()
         else:
             self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
+        self.bnn = bnn
 
         self.kde_xgm = None
         self.mu_xgm = np.nan
@@ -912,7 +913,7 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
-                     self.fit_model.state_dict(),
+                     self.fit_model.state_dict() if self.bnn else self.fit_model,
                      self.channel_pca,
                      #self.channel_fit_model
                      DataHolder(dict(
@@ -925,6 +926,7 @@ class Model(TransformerMixin, BaseEstimator):
                                      resolution=self.resolution,
                                      transfer_function=self.transfer_function,
                                      impulse_response=self.impulse_response,
+                                     bnn=self.bnn,
                                     )
                                ),
                      self.ood,
@@ -948,15 +950,8 @@ class Model(TransformerMixin, BaseEstimator):
          ood,
          kde_xgm,
         ) = joblib.load(filename)
+
         obj = Model()
-        obj.x_select = x_select
-        obj.x_model = x_model
-        obj.y_model = y_model
-        obj.fit_model = BNNModel(state_dict=fit_model)
-        obj.channel_pca = channel_pca
-        #obj.channel_fit_model = channel_fit_model
-        obj.ood = ood
-        obj.kde_xgm = kde_xgm
 
         extra = extra.get_data()
         obj.mu_xgm = extra["mu_xgm"]
@@ -968,5 +963,19 @@ class Model(TransformerMixin, BaseEstimator):
         obj.resolution = extra["resolution"]
         obj.transfer_function = extra["transfer_function"]
         obj.impulse_response = extra["impulse_response"]
+        obj.bnn = extra["bnn"]
+
+        obj.x_select = x_select
+        obj.x_model = x_model
+        obj.y_model = y_model
+        if obj.bnn:
+            obj.fit_model = BNNModel(state_dict=fit_model)
+        else:
+            obj.fit_model = fit_model
+        obj.channel_pca = channel_pca
+        #obj.channel_fit_model = channel_fit_model
+        obj.ood = ood
+        obj.kde_xgm = kde_xgm
+
         return obj
 
-- 
GitLab