From e54496144367f3ce9013a80b7a913f9f06857c87 Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Mon, 17 Apr 2023 18:41:49 +0200 Subject: [PATCH] ADapted regularization level. --- pes_to_spec/bnn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py index 1b6072e..bb66467 100644 --- a/pes_to_spec/bnn.py +++ b/pes_to_spec/bnn.py @@ -22,14 +22,14 @@ class BNN(nn.Module): """ def __init__(self, input_dimension: int=1, output_dimension: int=1): super(BNN, self).__init__() - hidden_dimension = 500 + hidden_dimension = 100 self.model = nn.Sequential( - bnn.BayesLinear(prior_mu=0, + bnn.BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=input_dimension, out_features=hidden_dimension), nn.ReLU(), - bnn.BayesLinear(prior_mu=0, + bnn.BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=hidden_dimension, out_features=output_dimension) @@ -126,7 +126,7 @@ class BNNModel(RegressorMixin, BaseEstimator): # train self.model.train() - epochs = 200 + epochs = 500 for epoch in range(epochs): losses = list() nlls = list() -- GitLab