Skip to content
Snippets Groups Projects
Commit af216b85 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Use GPU if available.

parent d0c9fa19
No related branches found
No related tags found
1 merge request!19Handle pedestal in PES and use GPU if available in BNN
Pipeline #120472 passed
......@@ -526,6 +526,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
# train
self.model.train()
if torch.cuda.is_available():
self.model = self.model.to('cuda')
for epoch in range(self.n_epochs):
meter = {k: AverageMeter(k, ':6.3f')
for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
......@@ -535,6 +537,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
prefix="Epoch: [{}]".format(epoch))
for i, batch in enumerate(loader):
x_b, y_b, w_b = batch
if torch.cuda.is_available():
x_b = x_b.to('cuda')
y_b = y_b.to('cuda')
w_b = w_b.to('cuda')
y_b_pred = self.model(x_b)
nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
......@@ -558,6 +564,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
self.model.prune()
self.model.eval()
if torch.cuda.is_available():
self.model = self.model.to('cpu')
return self
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment