From 00d500033cd034dc2f78781bae705185f16d375a Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 18 Apr 2023 00:40:37 +0200
Subject: [PATCH] Added a hyperprior on sigma.

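The negative log-likelihood is extended with the negative log-density
of an Inverse-Gamma(alpha=2.0, beta=0.15) hyperprior on sigma^2, which
penalizes both very large and very small noise scales:

    -log p(sigma^2) = -alpha*log(beta) + log(Gamma(alpha))
                      + (alpha + 1)*log(sigma^2) + beta/sigma^2

As a quick sanity check (a sketch, not part of the patch; it assumes
SciPy's shape/scale parameterization of the inverse-gamma
distribution), the added term can be compared against scipy.stats:

    import numpy as np
    from scipy.special import gammaln
    from scipy.stats import invgamma

    alpha, beta, sigma2 = 2.0, 0.15, 0.2
    # negative log inverse-gamma density, as added to the BNN loss
    nll = (-alpha*np.log(beta) + gammaln(alpha)
           + (alpha + 1)*np.log(sigma2) + beta/sigma2)
    # must agree with SciPy's parameterization: a=alpha, scale=beta
    assert np.isclose(nll, -invgamma.logpdf(sigma2, a=alpha, scale=beta))

The initial sigma and whether it is fit are now configurable, and the
training defaults were adjusted (hidden_dimension 100 -> 400, epochs
500 -> 200).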
---
 pes_to_spec/bnn.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index bb66467..57f4dc5 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -2,6 +2,7 @@ from sklearn.base import BaseEstimator, RegressorMixin
 from typing import Any, Dict, Optional, Union, Tuple
 
 import numpy as np
+from scipy.special import gammaln
 
 import torch
 import torch.nn as nn
@@ -20,9 +21,9 @@ class BNN(nn.Module):
         between the prediction and the true value. The standard deviation of the Gaussian is left as a
         parameter to be fit: sigma.
     """
-    def __init__(self, input_dimension: int=1, output_dimension: int=1):
+    def __init__(self, input_dimension: int=1, output_dimension: int=1, sigma: float=0.1, fit_sigma: bool=True):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 400
         self.model = nn.Sequential(
                                    bnn.BayesLinear(prior_mu=0.0,
                                                    prior_sigma=0.1,
@@ -34,7 +35,7 @@ class BNN(nn.Module):
                                                    in_features=hidden_dimension,
                                                    out_features=output_dimension)
                                     )
-        self.log_sigma2 = nn.Parameter(torch.ones(1), requires_grad=True)
+        self.log_sigma2 = nn.Parameter(torch.ones(1)*np.log(sigma**2), requires_grad=fit_sigma)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -49,11 +50,17 @@ class BNN(nn.Module):
         n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        #return 0.5*squared_error.mean()
         sigma2 = torch.exp(self.log_sigma2)[0]
         norm_error = 0.5*squared_error/sigma2
         norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])*n_output
-        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+        L = norm_error.sum(dim=1).mean(dim=0) + norm_term
+        # Inverse-Gamma(alpha, beta) hyperprior on sigma^2, penalizing both
+        # very large and very small sigma; with standardized inputs it keeps
+        # sigma at a moderate scale while staying broad enough to adapt
+        alpha = 2.0
+        beta = 0.15
+        hl = -alpha*np.log(beta) + (alpha + 1)*self.log_sigma2[0] + beta/sigma2 + gammaln(alpha)
+        return L + hl
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
@@ -126,7 +133,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
         # train
         self.model.train()
-        epochs = 500
+        epochs = 200
         for epoch in range(epochs):
             losses = list()
             nlls = list()
-- 
GitLab