
BNN optimization.

Merged Danilo Enoque Ferreira de Lima requested to merge bnn_opt into main
1 file  +19  −20
@@ -63,9 +63,9 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 50
         # controls the aleatoric uncertainty
-        self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
         self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
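Note on the reshaped noise parameter: each output dimension now gets its own aleatoric sigma instead of one shared scalar. A minimal sketch of the resulting shape (output_dimension=3 is an arbitrary illustration, not a value from this MR):

import numpy as np
import torch
import torch.nn as nn

output_dimension = 3
# one log-precision per output, initialized so that every sigma starts at 0.1
log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension) * np.log(0.1**2), requires_grad=True)
sigma = torch.exp(-0.5 * log_isigma2)  # per-output aleatoric sigma, shape (1, 3), all equal to 0.1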
@@ -82,8 +82,8 @@ class BNN(nn.Module):
         # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
         # An alternative would be to set the (alpha, beta) both to very low values, which makes the hyper prior closer to the non-informative Jeffrey's prior.
         # Using this alternative (ie: (0.1, 0.1) for the weights' hyper prior) leads to very large lambda and numerical issues with the fit.
-        self.alpha_lambda = 0.1
-        self.beta_lambda = 0.1
+        self.alpha_lambda = 0.001
+        self.beta_lambda = 0.001
         # Hyperprior choice on the likelihood noise level:
         # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different
@@ -92,8 +92,8 @@ class BNN(nn.Module):
         # Making both alpha and beta small makes the gamma distribution closer to Jeffrey's prior, which makes it non-informative
         # This seems to lead to a longer training time, though.
         # Since, after standardization, we expect the variance to be of order (1), we can also select alpha and beta leading to high variance in this range
-        self.alpha_sigma = 0.1
-        self.beta_sigma = 0.1
+        self.alpha_sigma = 0.001
+        self.beta_sigma = 0.001
         self.model = nn.Sequential(
                                    bnn.BayesLinear(prior_mu=0.0,
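On the (alpha, beta) choice above: a Gamma(alpha, beta) prior on a precision has mean alpha/beta and variance alpha/beta**2, so (0.001, 0.001) keeps the prior mean at 1 but is far broader than (0.1, 0.1) (variance 1000 vs. 100), i.e. closer to the non-informative limit. A hedged sketch of the corresponding negative log-density; the actual neg_log_gamma in this repository may be parameterized differently:

import numpy as np
import torch

def neg_log_gamma(log_x: torch.Tensor, x: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
    # -log p(x) for x ~ Gamma(shape=alpha, rate=beta), with log_x = log(x) passed in to avoid
    # recomputing it: -alpha*log(beta) + lgamma(alpha) - (alpha - 1)*log(x) + beta*x
    return -alpha * np.log(beta) + torch.lgamma(torch.tensor(alpha)) - (alpha - 1.0) * log_x + beta * x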
@@ -123,13 +123,12 @@ class BNN(nn.Module):
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
-        n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        sigma2 = torch.exp(-self.log_isigma2)[0]
+        sigma2 = torch.exp(-self.log_isigma2)
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2[0])*n_output
-        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2)
+        return (norm_error + norm_term).sum(dim=1).mean(dim=0)
     def neg_log_hyperprior(self) -> torch.Tensor:
         """
@@ -138,18 +137,18 @@ class BNN(nn.Module):
         # hyperprior for sigma to avoid too large or too small sigma
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
-        isigma2 = torch.exp(self.log_ilambda2)[0]
+        isigma2 = torch.exp(self.log_isigma2)
         neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
-        ilambda2 = torch.exp(self.log_ilambda2)[0]
+        ilambda2 = torch.exp(self.log_ilambda2)
         neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
-        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
+        return neg_log_hyperprior_noise.sum() + neg_log_hyperprior_weights.sum()
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
         Get the aleatoric component of the uncertainty.
         """
         #return 0
-        return torch.exp(-0.5*self.log_isigma2[0])
+        return torch.exp(-0.5*self.log_isigma2)
     def w_precision(self) -> torch.Tensor:
         """
@@ -201,10 +200,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
         # prepare data loader
-        B = 100
+        B = 50
         loader = DataLoader(ds,
                             batch_size=B,
-                            num_workers=5,
+                            num_workers=20,
                             shuffle=True,
                             #pin_memory=True,
                             drop_last=True,
@@ -223,7 +222,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         # train
         self.model.train()
-        epochs = 1000
+        epochs = 500
         for epoch in range(epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                      for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
@@ -248,7 +247,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 meter['-log(lkl)'].update(nll.detach().cpu().item(), B)
                 meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
                 meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
-                meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
+                meter['sigma'].update(self.model.aleatoric_uncertainty().mean().detach().cpu().numpy(), B)
                 meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
             progress.display(len(loader))
@@ -268,12 +267,12 @@ class BNNModel(RegressorMixin, BaseEstimator):
         K = 10
         y_pred = list()
         for _ in range(K):
-            y_k = self.model(torch.from_numpy(X)).detach().numpy()
+            y_k = self.model(torch.from_numpy(X)).detach().cpu().numpy()
             y_pred.append(y_k)
         y_pred = np.stack(y_pred, axis=1)
         y_mu = np.mean(y_pred, axis=1)
         y_epi = np.std(y_pred, axis=1)
-        y_ale = self.model.aleatoric_uncertainty().detach().numpy()
+        y_ale = self.model.aleatoric_uncertainty().detach().cpu().numpy()
         y_unc = (y_epi**2 + y_ale**2)**0.5
         if not return_std:
             return y_mu
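For context on the predict path: the epistemic uncertainty is the spread over K stochastic forward passes and the aleatoric part is the learned per-output sigma; the two are combined in quadrature. A minimal sketch with stand-in arrays (all shapes and values are illustrative only):

import numpy as np

K, n_samples, n_output = 10, 4, 2
y_pred = np.random.randn(n_samples, K, n_output)   # stand-in for K stochastic forward passes
y_mu = y_pred.mean(axis=1)                         # predictive mean
y_epi = y_pred.std(axis=1)                         # epistemic: spread across the K passes
y_ale = np.full((1, n_output), 0.1)                # stand-in for the learned aleatoric sigma
y_unc = np.sqrt(y_epi**2 + y_ale**2)               # total uncertainty, combined in quadrature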