From 8f0ff39983af774498738879c4a5f666c018d63d Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Sat, 7 Oct 2023 09:31:37 +0200
Subject: [PATCH] HOTFIX: Backwards compatibility for autocorrelation.

---
 pes_to_spec/model.py | 44 +++++++++++++++++---------------------------
 1 file changed, 17 insertions(+), 27 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 57cf076..44e2ba6 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -681,6 +681,15 @@ class Model(TransformerMixin, BaseEstimator):
         self.impulse_response = None
         self.auto_corr = None
 
+        self.extra_options = ["mu_xgm", "sigma_xgm",
+                              "wiener_filter_ft", "wiener_filter",
+                              "wiener_energy_ft", "wiener_energy",
+                              "resolution",
+                              "transfer_function", "impulse_response",
+                              "auto_corr",
+                              "model_type",
+                              "n_obs"]
+
     def n_pars(self) -> float:
         """Get number of parameters."""
         if self.model_type in ("bnn", "bnn_rvm"):
@@ -1038,27 +1047,15 @@ class Model(TransformerMixin, BaseEstimator):
         Args:
           filename: File name where to save this.
         """
+        extra = {k: getattr(self, k)
+                for k in self.extra_options}
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
                      self.fit_model.state_dict() if self.model_type in ("bnn", "bnn_rvm") else self.fit_model,
                      self.channel_pca,
                      #self.channel_fit_model
-                     DataHolder(dict(
-                                     mu_xgm=self.mu_xgm,
-                                     sigma_xgm=self.sigma_xgm,
-                                     wiener_filter_ft=self.wiener_filter_ft,
-                                     wiener_filter=self.wiener_filter,
-                                     wiener_energy=self.wiener_energy,
-                                     wiener_energy_ft=self.wiener_energy_ft,
-                                     resolution=self.resolution,
-                                     transfer_function=self.transfer_function,
-                                     impulse_response=self.impulse_response,
-                                     auto_corr=self.auto_corr,
-                                     model_type=self.model_type,
-                                     n_obs=self.n_obs,
-                                    )
-                               ),
+                     DataHolder(extra),
                      self.ood,
                      self.kde_xgm,
                      ], filename, compress='zlib')
@@ -1084,18 +1081,11 @@ class Model(TransformerMixin, BaseEstimator):
         obj = Model()
 
         extra = extra.get_data()
-        obj.mu_xgm = extra["mu_xgm"]
-        obj.sigma_xgm = extra["sigma_xgm"]
-        obj.wiener_filter_ft = extra["wiener_filter_ft"]
-        obj.wiener_filter = extra["wiener_filter"]
-        obj.wiener_energy_ft = extra["wiener_energy_ft"]
-        obj.wiener_energy = extra["wiener_energy"]
-        obj.resolution = extra["resolution"]
-        obj.transfer_function = extra["transfer_function"]
-        obj.impulse_response = extra["impulse_response"]
-        obj.auto_corr = extra["auto_corr"]
-        obj.model_type = extra["model_type"]
-        obj.n_obs = extra["n_obs"]
+        for k in obj.extra_options:
+            if k not in extra:
+                setattr(obj, k, None)
+            else:
+                setattr(obj, k, extra[k])
 
         obj.x_select = x_select
         obj.x_model = x_model
-- 
GitLab