diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index aaf850988631569f696274aea52fd1fca3625f28..cf057e6dae9e6d618edd5ed909732c25cdb42957 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -63,7 +63,7 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 30
         # controls the aleatoric uncertainty
         self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
@@ -82,8 +82,8 @@ class BNN(nn.Module):
         # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
         # An alternative would be to set the (alpha, beta) both to very low values, which makes the hyperprior become closer to the non-informative Jeffrey's prior.
         # Using this alternative (i.e. (0.1, 0.1) for the weights' hyperprior) leads to very large lambda and numerical issues with the fit.
-        self.alpha_lambda = 0.1
-        self.beta_lambda = 0.1
+        self.alpha_lambda = 0.001
+        self.beta_lambda = 0.001
 
         # Hyperprior choice on the likelihood noise level:
         # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different
@@ -92,8 +92,8 @@ class BNN(nn.Module):
         # Making both alpha and beta small makes the gamma distribution closer to Jeffrey's prior, which makes it non-informative.
         # This seems to lead to a longer training time, though.
         # Since, after standardization, we expect the variance to be of order 1, we can also select alpha and beta leading to high variance in this range.
-        self.alpha_sigma = 0.1
-        self.beta_sigma = 0.1
+        self.alpha_sigma = 0.001
+        self.beta_sigma = 0.001
 
         self.model = nn.Sequential(
             bnn.BayesLinear(prior_mu=0.0,
@@ -201,10 +201,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
-        B = 100
+        B = 200
         loader = DataLoader(ds,
                             batch_size=B,
-                            num_workers=5,
+                            num_workers=32,
                             shuffle=True,
                             #pin_memory=True,
                             drop_last=True,
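
Note on the hyperprior change: as the comments in the hunks above say, shrinking (alpha, beta) together moves the Gamma hyperprior closer to the non-informative Jeffrey's prior. For a rate-parameterized Gamma(alpha, beta), mean = alpha/beta and variance = alpha/beta**2, so going from (0.1, 0.1) to (0.001, 0.001) keeps the prior mean at 1 while inflating its variance a hundredfold. A minimal standalone sketch of that arithmetic (illustration only, not code from pes_to_spec):

    # Compare how diffuse the old (0.1, 0.1) and new (0.001, 0.001) hyperpriors are.
    # Rate-parameterized Gamma(alpha, beta): mean = alpha/beta, variance = alpha/beta**2.
    for alpha, beta in [(0.1, 0.1), (0.001, 0.001)]:
        print(f"Gamma({alpha}, {beta}): mean={alpha / beta:.0f}, variance={alpha / beta**2:.0f}")
    # Gamma(0.1, 0.1):     mean=1, variance=10
    # Gamma(0.001, 0.001): mean=1, variance=1000

Both hyperpriors keep the expected precision at 1 (consistent with standardized data); the flatter one simply commits far less to that value.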