Commit 8b150174 authored by Danilo Ferreira de Lima

Fixed sigma bug and allowed sigma to operate per class.

parent 43b757ef
1 merge request: !12 BNN optimization
@@ -63,9 +63,9 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 30
+        hidden_dimension = 50
         # controls the aleatoric uncertainty
-        self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
         self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
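For context on the "per class" part of the commit message: log_isigma2 now stores one log-precision per output dimension instead of a single shared scalar. A minimal standalone sketch of the shape change (only the parameter construction mirrors the diff; everything around it is omitted):

import numpy as np
import torch
import torch.nn as nn

output_dimension = 3

# before: one aleatoric noise scale shared by every output
log_isigma2_shared = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)

# after: one aleatoric noise scale per output, shape (1, output_dimension)
log_isigma2_per_output = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)

# sigma = exp(-0.5*log(1/sigma^2)) broadcasts against (batch, output) predictions
print(torch.exp(-0.5*log_isigma2_per_output).shape)  # torch.Size([1, 3])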
@@ -123,13 +123,12 @@ class BNN(nn.Module):
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
-        n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        sigma2 = torch.exp(-self.log_isigma2)[0]
+        sigma2 = torch.exp(-self.log_isigma2)
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2[0])*n_output
-        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2)
+        return (norm_error + norm_term).sum(dim=1).mean(dim=0)
 
     def neg_log_hyperprior(self) -> torch.Tensor:
         """
@@ -138,18 +137,18 @@ class BNN(nn.Module):
         # hyperprior for sigma to avoid large or too small sigma
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
-        isigma2 = torch.exp(self.log_ilambda2)[0]
+        isigma2 = torch.exp(self.log_isigma2)
         neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
-        ilambda2 = torch.exp(self.log_ilambda2)[0]
+        ilambda2 = torch.exp(self.log_ilambda2)
         neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
-        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
+        return neg_log_hyperprior_noise.sum() + neg_log_hyperprior_weights.sum()
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
         Get the aleatoric component of the uncertainty.
         """
         #return 0
-        return torch.exp(-0.5*self.log_isigma2[0])
+        return torch.exp(-0.5*self.log_isigma2)
 
     def w_precision(self) -> torch.Tensor:
         """
@@ -201,7 +200,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
-        B = 200
+        B = 100
         loader = DataLoader(ds,
                             batch_size=B,
                             num_workers=32,
......@@ -248,7 +247,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
meter['-log(lkl)'].update(nll.detach().cpu().item(), B)
meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
meter['sigma'].update(self.model.aleatoric_uncertainty().mean().detach().cpu().numpy(), B)
meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
progress.display(len(loader))
@@ -268,12 +267,12 @@ class BNNModel(RegressorMixin, BaseEstimator):
         K = 10
         y_pred = list()
         for _ in range(K):
-            y_k = self.model(torch.from_numpy(X)).detach().numpy()
+            y_k = self.model(torch.from_numpy(X)).detach().cpu().numpy()
             y_pred.append(y_k)
         y_pred = np.stack(y_pred, axis=1)
         y_mu = np.mean(y_pred, axis=1)
         y_epi = np.std(y_pred, axis=1)
-        y_ale = self.model.aleatoric_uncertainty().detach().numpy()
+        y_ale = self.model.aleatoric_uncertainty().detach().cpu().numpy()
         y_unc = (y_epi**2 + y_ale**2)**0.5
         if not return_std:
             return y_mu
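In predict, epistemic uncertainty comes from the spread of the K stochastic forward passes, while the aleatoric term is now a per-output vector; the two combine in quadrature, with the aleatoric part broadcasting over the batch. A small numpy illustration of the shapes involved (all values are synthetic):

import numpy as np

n_samples, K, n_outputs = 5, 10, 2
y_pred = np.random.randn(n_samples, K, n_outputs)    # stand-in for K stochastic forward passes

y_mu = np.mean(y_pred, axis=1)                       # (n_samples, n_outputs)
y_epi = np.std(y_pred, axis=1)                       # epistemic part, (n_samples, n_outputs)
y_ale = np.array([0.1, 0.3])                         # per-output aleatoric sigma, (n_outputs,)

# total predictive uncertainty in quadrature; y_ale broadcasts across the batch
y_unc = (y_epi**2 + y_ale**2)**0.5
print(y_mu.shape, y_unc.shape)                       # (5, 2) (5, 2)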