
Includes input energy parameter in the model and adds non-linearities

Merged Danilo Enoque Ferreira de Lima requested to merge with_energy into main
2 files changed: +21 −28
@@ -7,7 +7,7 @@ from scipy.special import gamma
 import torch
 import torch.nn as nn
 import torchbnn as bnn
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import TensorDataset, DataLoader

 class AverageMeter(object):
     """Computes and stores the average and current value"""
@@ -41,7 +41,7 @@ class ProgressMeter(object):
     def display(self, batch):
         entries = [self.prefix + self.batch_fmtstr.format(batch)]
         entries += [str(meter) for meter in self.meters]
-        print('\t'.join(entries))
+        print(' '.join(entries))

     def _get_batch_fmtstr(self, num_batches):
         num_digits = len(str(num_batches // 1))
@@ -151,23 +151,11 @@ class BNN(nn.Module):
         #return 0
         return torch.exp(-0.5*self.log_isigma2[0])

-    def l(self) -> torch.Tensor:
+    def w_precision(self) -> torch.Tensor:
         """
-        Get the weights std. dev.
+        Get the weights precision.
         """
-        return torch.exp(-0.5*self.log_ilambda2[0])
-
-class BNNDataset(Dataset):
-    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray):
-        self.x = x
-        self.y = y
-        self.w = w
-        assert len(x) == len(y) and len(x) == len(w)
-
-    def __len__(self) -> int:
-        """How many samples do I have?"""
-        return len(self.x)
-
-    def __getitem__(self, idx):
-        return {"x": self.x[idx, :], "y": self.y[idx, :], "w": self.w[idx, :]}
+        return torch.exp(self.log_ilambda2[0])

 class BNNModel(RegressorMixin, BaseEstimator):
     """
@@ -204,14 +192,23 @@ class BNNModel(RegressorMixin, BaseEstimator):
             weights = np.ones(len(X), dtype=np.float32)
         if len(weights.shape) == 1:
             weights = weights[:, np.newaxis]
-        ds = BNNDataset(X, y, weights)
+        ds = TensorDataset(torch.from_numpy(X),
+                           torch.from_numpy(y),
+                           torch.from_numpy(weights))

         # create model
         self.model = BNN(X.shape[1], y.shape[1])

         # prepare data loader
         B = 5
-        loader = DataLoader(ds, batch_size=B)
+        loader = DataLoader(ds,
+                            batch_size=B,
+                            num_workers=5,
+                            shuffle=True,
+                            #pin_memory=True,
+                            drop_last=True,
+                            )

         optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
         number_of_batches = len(ds)/float(B)
         weight_prior = 1.0/float(number_of_batches)
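For reference, a minimal sketch of what the TensorDataset swap buys (array shapes are illustrative): torch.utils.data.TensorDataset wraps pre-built tensors and yields plain (x, y, w) tuples, which is why the hand-rolled BNNDataset with its dict-style __getitem__ can go, and why the training loop below unpacks batches positionally.

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Illustrative arrays standing in for X, y and the per-sample weights.
X = np.random.rand(100, 3).astype(np.float32)
y = np.random.rand(100, 1).astype(np.float32)
w = np.ones((100, 1), dtype=np.float32)

ds = TensorDataset(torch.from_numpy(X), torch.from_numpy(y), torch.from_numpy(w))
loader = DataLoader(ds, batch_size=5, shuffle=True, drop_last=True)
x_b, y_b, w_b = next(iter(loader))  # tuples now, so no batch["x"] lookups
print(x_b.shape)                    # torch.Size([5, 3])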
@@ -226,18 +223,16 @@ class BNNModel(RegressorMixin, BaseEstimator):
         # train
         self.model.train()
-        epochs = 100
+        epochs = 200
         for epoch in range(epochs):
             meter = {k: AverageMeter(k, ':6.3f')
-                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'lambda')}
+                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
             progress = ProgressMeter(
                 len(loader),
                 meter.values(),
                 prefix="Epoch: [{}]".format(epoch))
             for i, batch in enumerate(loader):
-                x_b = batch["x"]
-                y_b = batch["y"]
-                w_b = batch["w"]
+                x_b, y_b, w_b = batch

                 y_b_pred = self.model(x_b)
                 nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
@@ -254,10 +249,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
                 meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
                 meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
-                meter['lambda'].update(self.model.l().detach().cpu().item(), B)
+                meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
-                if i % 100 == 0:
-                    progress.display(i)
+            progress.display(len(loader))

         self.model.eval()
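One interaction worth noting, sketched below with illustrative numbers: drop_last=True means every batch the loop sees holds exactly B samples, so updating each AverageMeter with a fixed count of B stays exact, and the single per-epoch progress.display(len(loader)) reports over full batches only.

import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.zeros(23, 2))                 # 23 samples, batch size 5
loader = DataLoader(ds, batch_size=5, drop_last=True)
print(len(loader))                                     # 4: the 3 leftover samples are dropped
assert all(x.shape[0] == 5 for (x,) in loader)         # every batch is a full batch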