From af216b8594fe7f12c3ace15e73c4fe8c6a824c1e Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 16 Oct 2023 14:29:51 +0200
Subject: [PATCH] Use GPU if available.

---
 pes_to_spec/bnn.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 242a507..b91d173 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -526,6 +526,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
         # train
         self.model.train()
+        if torch.cuda.is_available():
+            self.model = self.model.to('cuda')
         for epoch in range(self.n_epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
@@ -535,6 +537,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
                             prefix="Epoch: [{}]".format(epoch))
             for i, batch in enumerate(loader):
                 x_b, y_b, w_b = batch
+                if torch.cuda.is_available():
+                    x_b = x_b.to('cuda')
+                    y_b = y_b.to('cuda')
+                    w_b = w_b.to('cuda')
                 y_b_pred = self.model(x_b)
 
                 nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
@@ -558,6 +564,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
             self.model.prune()
 
         self.model.eval()
+        if torch.cuda.is_available():
+            self.model = self.model.to('cpu')
 
         return self
 
-- 
GitLab