diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py index 242a507481f322d87182a70239ae5df2d8bd51fc..b91d173e16a1e0f987c365f67385723202de8a41 100644 --- a/pes_to_spec/bnn.py +++ b/pes_to_spec/bnn.py @@ -526,6 +526,8 @@ class BNNModel(RegressorMixin, BaseEstimator): # train self.model.train() + if torch.cuda.is_available(): + self.model = self.model.to('cuda') for epoch in range(self.n_epochs): meter = {k: AverageMeter(k, ':6.3f') for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')} @@ -535,6 +537,10 @@ class BNNModel(RegressorMixin, BaseEstimator): prefix="Epoch: [{}]".format(epoch)) for i, batch in enumerate(loader): x_b, y_b, w_b = batch + if torch.cuda.is_available(): + x_b = x_b.to('cuda') + y_b = y_b.to('cuda') + w_b = w_b.to('cuda') y_b_pred = self.model(x_b) nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b) @@ -558,6 +564,8 @@ class BNNModel(RegressorMixin, BaseEstimator): self.model.prune() self.model.eval() + if torch.cuda.is_available(): + self.model = self.model.to('cpu') return self