From 8b150174f3e4fd8e3ebd086242ac7aa1f3cfd608 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 26 Apr 2023 16:49:04 +0200
Subject: [PATCH] Fixed sigma bug and allowed sigma to operate per class.

---
 pes_to_spec/bnn.py | 27 +++++++++++++--------------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index cf057e6..aa9cf7d 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -63,9 +63,9 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 30
+        hidden_dimension = 50
         # controls the aleatoric uncertainty
-        self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
         self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
 
@@ -123,13 +123,12 @@ class BNN(nn.Module):
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
-        n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        sigma2 = torch.exp(-self.log_isigma2)[0]
+        sigma2 = torch.exp(-self.log_isigma2)
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2[0])*n_output
-        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2)
+        return (norm_error + norm_term).sum(dim=1).mean(dim=0)
 
     def neg_log_hyperprior(self) -> torch.Tensor:
         """
@@ -138,18 +137,18 @@ class BNN(nn.Module):
         # hyperprior for sigma to avoid large or too small sigma
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
-        isigma2 = torch.exp(self.log_ilambda2)[0]
+        isigma2 = torch.exp(self.log_isigma2)
         neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
-        ilambda2 = torch.exp(self.log_ilambda2)[0]
+        ilambda2 = torch.exp(self.log_ilambda2)
         neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
-        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
+        return neg_log_hyperprior_noise.sum() + neg_log_hyperprior_weights.sum()
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
             Get the aleatoric component of the uncertainty.
         """
         #return 0
-        return torch.exp(-0.5*self.log_isigma2[0])
+        return torch.exp(-0.5*self.log_isigma2)
 
     def w_precision(self) -> torch.Tensor:
         """
@@ -201,7 +200,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
-        B = 200
+        B = 100
         loader = DataLoader(ds,
                             batch_size=B,
                             num_workers=32,
@@ -248,7 +247,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 meter['-log(lkl)'].update(nll.detach().cpu().item(), B)
                 meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
                 meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
-                meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
+                meter['sigma'].update(self.model.aleatoric_uncertainty().mean().detach().cpu().numpy(), B)
                 meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
 
             progress.display(len(loader))
@@ -268,12 +267,12 @@ class BNNModel(RegressorMixin, BaseEstimator):
         K = 10
         y_pred = list()
         for _ in range(K):
-            y_k = self.model(torch.from_numpy(X)).detach().numpy()
+            y_k = self.model(torch.from_numpy(X)).detach().cpu().numpy()
             y_pred.append(y_k)
         y_pred = np.stack(y_pred, axis=1)
         y_mu = np.mean(y_pred, axis=1)
         y_epi = np.std(y_pred, axis=1)
-        y_ale = self.model.aleatoric_uncertainty().detach().numpy()
+        y_ale = self.model.aleatoric_uncertainty().detach().cpu().numpy()
         y_unc = (y_epi**2 + y_ale**2)**0.5
         if not return_std:
             return y_mu
-- 
GitLab