Commit 00d50003 authored by Danilo Ferreira de Lima

Added a hyperprior on sigma.

parent e5449614
Part of merge request !11: "Includes input energy parameter in the model and adds non-linearities".
@@ -2,6 +2,7 @@ from sklearn.base import BaseEstimator, RegressorMixin
 from typing import Any, Dict, Optional, Union, Tuple
 import numpy as np
+from scipy.special import gamma
 import torch
 import torch.nn as nn
@@ -20,9 +21,9 @@ class BNN(nn.Module):
     between the prediction and the true value. The standard deviation of the Gaussian is left as a
     parameter to be fit: sigma.
     """
-    def __init__(self, input_dimension: int=1, output_dimension: int=1):
+    def __init__(self, input_dimension: int=1, output_dimension: int=1, sigma: float=0.1, fit_sigma: bool=True):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 400
         self.model = nn.Sequential(
             bnn.BayesLinear(prior_mu=0.0,
                             prior_sigma=0.1,
@@ -34,7 +35,7 @@
                             in_features=hidden_dimension,
                             out_features=output_dimension)
             )
-        self.log_sigma2 = nn.Parameter(torch.ones(1), requires_grad=True)
+        self.log_sigma2 = nn.Parameter(torch.ones(1)*np.log(sigma**2), requires_grad=fit_sigma)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -49,11 +50,17 @@
         n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        #return 0.5*squared_error.mean()
         sigma2 = torch.exp(self.log_sigma2)[0]
         norm_error = 0.5*squared_error/sigma2
         norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])*n_output
-        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+        L = norm_error.sum(dim=1).mean(dim=0) + norm_term
+        # hyperprior for sigma to avoid large or too small sigma
+        # with a standardized input, this hyperprior forces sigma to be
+        # on avg. 1 and it is broad enough to allow for different sigma
+        alpha = 2.0
+        beta = 0.15
+        hl = -alpha*np.log(beta) + (alpha + 1)*self.log_sigma2 + beta/sigma2 + gamma(alpha)
+        return L + hl

     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
@@ -126,7 +133,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         # train
         self.model.train()
-        epochs = 500
+        epochs = 200
         for epoch in range(epochs):
             losses = list()
             nlls = list()
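A minimal sketch of the training loop around the changed `epochs` line, under stated assumptions: the optimizer, the `kl_weight` trade-off, torchbnn's `BKLLoss` for the weight prior, the `neg_log_likelihood` method name, and the `x_t`/`y_t`/`w_t` tensors are all stand-ins, since the surrounding code is collapsed in the diff:

import torch
import torchbnn as bnn

model = BNN(input_dimension=2, output_dimension=1)
x_t = torch.randn(64, 2)    # stand-in training inputs
y_t = torch.randn(64, 1)    # stand-in targets
w_t = torch.ones_like(y_t)  # stand-in per-sample weights

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)
kl_weight = 0.01            # assumed weighting of the KL term

model.train()
epochs = 200
for epoch in range(epochs):
    optimizer.zero_grad()
    nll = model.neg_log_likelihood(model(x_t), y_t, w_t)  # hypothetical method name
    loss = nll + kl_weight*kl_loss(model)
    loss.backward()
    optimizer.step()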