diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py index 1b6072e0e801a2657fa5c449c905c021b6364407..bb664678926aa5de703005c05acc549670fa1d0b 100644 --- a/pes_to_spec/bnn.py +++ b/pes_to_spec/bnn.py @@ -22,14 +22,14 @@ class BNN(nn.Module): """ def __init__(self, input_dimension: int=1, output_dimension: int=1): super(BNN, self).__init__() - hidden_dimension = 500 + hidden_dimension = 100 self.model = nn.Sequential( - bnn.BayesLinear(prior_mu=0, + bnn.BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=input_dimension, out_features=hidden_dimension), nn.ReLU(), - bnn.BayesLinear(prior_mu=0, + bnn.BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=hidden_dimension, out_features=output_dimension) @@ -126,7 +126,7 @@ class BNNModel(RegressorMixin, BaseEstimator): # train self.model.train() - epochs = 200 + epochs = 500 for epoch in range(epochs): losses = list() nlls = list()