From 8f0ff39983af774498738879c4a5f666c018d63d Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Sat, 7 Oct 2023 09:31:37 +0200 Subject: [PATCH] HOTFIX: Backwards compatibility for autocorrelation. --- pes_to_spec/model.py | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 57cf076..44e2ba6 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -681,6 +681,15 @@ class Model(TransformerMixin, BaseEstimator): self.impulse_response = None self.auto_corr = None + self.extra_options = ["mu_xgm", "sigma_xgm", + "wiener_filter_ft", "wiener_filter", + "wiener_energy_ft", "wiener_energy", + "resolution", + "transfer_function", "impulse_response", + "auto_corr", + "model_type", + "n_obs"] + def n_pars(self) -> float: """Get number of parameters.""" if self.model_type in ("bnn", "bnn_rvm"): @@ -1038,27 +1047,15 @@ class Model(TransformerMixin, BaseEstimator): Args: filename: File name where to save this. """ + extra = {k: getattr(self, k) + for k in self.extra_options} joblib.dump([self.x_select, self.x_model, self.y_model, self.fit_model.state_dict() if self.model_type in ("bnn", "bnn_rvm") else self.fit_model, self.channel_pca, #self.channel_fit_model - DataHolder(dict( - mu_xgm=self.mu_xgm, - sigma_xgm=self.sigma_xgm, - wiener_filter_ft=self.wiener_filter_ft, - wiener_filter=self.wiener_filter, - wiener_energy=self.wiener_energy, - wiener_energy_ft=self.wiener_energy_ft, - resolution=self.resolution, - transfer_function=self.transfer_function, - impulse_response=self.impulse_response, - auto_corr=self.auto_corr, - model_type=self.model_type, - n_obs=self.n_obs, - ) - ), + DataHolder(extra), self.ood, self.kde_xgm, ], filename, compress='zlib') @@ -1084,18 +1081,11 @@ class Model(TransformerMixin, BaseEstimator): obj = Model() extra = extra.get_data() - obj.mu_xgm = extra["mu_xgm"] - obj.sigma_xgm = extra["sigma_xgm"] - obj.wiener_filter_ft = extra["wiener_filter_ft"] - obj.wiener_filter = extra["wiener_filter"] - obj.wiener_energy_ft = extra["wiener_energy_ft"] - obj.wiener_energy = extra["wiener_energy"] - obj.resolution = extra["resolution"] - obj.transfer_function = extra["transfer_function"] - obj.impulse_response = extra["impulse_response"] - obj.auto_corr = extra["auto_corr"] - obj.model_type = extra["model_type"] - obj.n_obs = extra["n_obs"] + for k in obj.extra_options: + if k not in extra: + setattr(obj, k, None) + else: + setattr(obj, k, extra[k]) obj.x_select = x_select obj.x_model = x_model -- GitLab