Skip to content
Snippets Groups Projects

Includes input energy parameter in the model and adds non-linearities

Merged Danilo Enoque Ferreira de Lima requested to merge with_energy into main
1 file
+ 18
9
Compare changes
  • Side-by-side
  • Inline
+ 18
9
@@ -551,6 +551,7 @@ class Model(TransformerMixin, BaseEstimator):
self.fit_model = BNNModel()
else:
self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
self.bnn = bnn
self.kde_xgm = None
self.mu_xgm = np.nan
@@ -912,7 +913,7 @@ class Model(TransformerMixin, BaseEstimator):
joblib.dump([self.x_select,
self.x_model,
self.y_model,
self.fit_model.state_dict(),
self.fit_model.state_dict() if self.bnn else self.fit_model,
self.channel_pca,
#self.channel_fit_model
DataHolder(dict(
@@ -925,6 +926,7 @@ class Model(TransformerMixin, BaseEstimator):
resolution=self.resolution,
transfer_function=self.transfer_function,
impulse_response=self.impulse_response,
bnn=self.bnn,
)
),
self.ood,
@@ -948,15 +950,8 @@ class Model(TransformerMixin, BaseEstimator):
ood,
kde_xgm,
) = joblib.load(filename)
obj = Model()
obj.x_select = x_select
obj.x_model = x_model
obj.y_model = y_model
obj.fit_model = BNNModel(state_dict=fit_model)
obj.channel_pca = channel_pca
#obj.channel_fit_model = channel_fit_model
obj.ood = ood
obj.kde_xgm = kde_xgm
extra = extra.get_data()
obj.mu_xgm = extra["mu_xgm"]
@@ -968,5 +963,19 @@ class Model(TransformerMixin, BaseEstimator):
obj.resolution = extra["resolution"]
obj.transfer_function = extra["transfer_function"]
obj.impulse_response = extra["impulse_response"]
obj.bnn = extra["bnn"]
obj.x_select = x_select
obj.x_model = x_model
obj.y_model = y_model
if obj.bnn:
obj.fit_model = BNNModel(state_dict=fit_model)
else:
obj.fit_model = fit_model
obj.channel_pca = channel_pca
#obj.channel_fit_model = channel_fit_model
obj.ood = ood
obj.kde_xgm = kde_xgm
return obj
Loading