Skip to content
Snippets Groups Projects

Includes input energy parameter in the model and adds non-linearities

Merged Danilo Enoque Ferreira de Lima requested to merge with_energy into main
1 file changed: +13 −6
Compare changes
  • Side-by-side
  • Inline
+13 −6
@@ -2,6 +2,7 @@ from sklearn.base import BaseEstimator, RegressorMixin
from typing import Any, Dict, Optional, Union, Tuple
import numpy as np
from scipy.special import gamma
import torch
import torch.nn as nn
@@ -20,9 +21,9 @@ class BNN(nn.Module):
between the prediction and the true value. The standard deviation of the Gaussian is left as a
parameter to be fit: sigma.
"""
def __init__(self, input_dimension: int=1, output_dimension: int=1):
def __init__(self, input_dimension: int=1, output_dimension: int=1, sigma: float=0.1, fit_sigma: bool=True):
super(BNN, self).__init__()
hidden_dimension = 100
hidden_dimension = 400
self.model = nn.Sequential(
bnn.BayesLinear(prior_mu=0.0,
prior_sigma=0.1,
@@ -34,7 +35,7 @@ class BNN(nn.Module):
in_features=hidden_dimension,
out_features=output_dimension)
)
self.log_sigma2 = nn.Parameter(torch.ones(1), requires_grad=True)
self.log_sigma2 = nn.Parameter(torch.ones(1)*np.log(sigma**2), requires_grad=fit_sigma)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -49,11 +50,17 @@ class BNN(nn.Module):
n_output = target.shape[1]
error = w*(prediction - target)
squared_error = error**2
#return 0.5*squared_error.mean()
sigma2 = torch.exp(self.log_sigma2)[0]
norm_error = 0.5*squared_error/sigma2
norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])*n_output
return norm_error.sum(dim=1).mean(dim=0) + norm_term
L = norm_error.sum(dim=1).mean(dim=0) + norm_term
# hyperprior for sigma to avoid too large or too small values of sigma
# with a standardized input, this hyperprior forces sigma to be
# on avg. 1, and it is broad enough to allow for different sigma values
alpha = 2.0
beta = 0.15
hl = -alpha*np.log(beta) + (alpha + 1)*self.log_sigma2 + beta/sigma2 + gamma(alpha)
return L + hl
def aleatoric_uncertainty(self) -> torch.Tensor:
"""
@@ -126,7 +133,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
# train
self.model.train()
epochs = 500
epochs = 200
for epoch in range(epochs):
losses = list()
nlls = list()
Loading