diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 242a507481f322d87182a70239ae5df2d8bd51fc..b91d173e16a1e0f987c365f67385723202de8a41 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -526,6 +526,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
         # train
         self.model.train()
+        if torch.cuda.is_available():
+            self.model = self.model.to('cuda')
         for epoch in range(self.n_epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
@@ -535,6 +537,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
                             prefix="Epoch: [{}]".format(epoch))
             for i, batch in enumerate(loader):
                 x_b, y_b, w_b = batch
+                if torch.cuda.is_available():
+                    x_b = x_b.to('cuda')
+                    y_b = y_b.to('cuda')
+                    w_b = w_b.to('cuda')
                 y_b_pred = self.model(x_b)
 
                 nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
@@ -558,6 +564,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
             self.model.prune()
 
         self.model.eval()
+        if torch.cuda.is_available():
+            self.model = self.model.to('cpu')
 
         return self