From 703dc404652ba5e0e969d2fa7316521ec015a913 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 18 Apr 2023 11:28:18 +0200
Subject: [PATCH] Added hyperpriors and better logging of training.

---
 pes_to_spec/bnn.py                   | 169 +++++++++++++++++++++------
 pes_to_spec/model.py                 |  22 ++--
 pes_to_spec/test/offline_analysis.py |   4 +-
 3 files changed, 140 insertions(+), 55 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 57f4dc5..234fed3 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -9,33 +9,103 @@ import torch.nn as nn
 import torchbnn as bnn
 from torch.utils.data import Dataset, DataLoader
 
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
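+    """Displays the values of a set of AverageMeters, prefixed by the epoch and batch progress."""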
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
 class BNN(nn.Module):
     """
         A model Bayesian Neural network.
         Each weight is represented by a Gaussian with a mean and a standard deviation.
         Each evaluation of forward leads to a different choice of the weights, so running
         forward several times we can check the effect of the weights variation on the same input.
-        The nll function implements the negative log likelihood to be used as the first part of the loss
+        The neg_log_likelihood function implements the negative log likelihood to be used as the first part of the loss
         function (the second shall be the Kullback-Leibler divergence).
         The negative log-likelihood is simply the negative log likelihood of a Gaussian
         between the prediction and the true value. The standard deviation of the Gaussian is left as a
         parameter to be fit: sigma.
     """
-    def __init__(self, input_dimension: int=1, output_dimension: int=1, sigma: float=0.1, fit_sigma: bool=True):
+    def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 400
+        hidden_dimension = 100
+        # controls the aleatoric uncertainty
+        self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        # controls the weight hyperprior
+        self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
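+        # both parameters hold the log of the inverse variance (the precision), so that
+        # sigma = exp(-0.5*log_isigma2) and lambda = exp(-0.5*log_ilambda2), each starting at 0.1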
+
+        # inverse Gamma hyperprior alpha and beta
+        #
+        # Hyperprior choice on the weights:
+        # We want the hyperprior on the weights' variance to itself have a large variance,
+        # so that the weights prior can be almost anything, while still preventing it from going to infinity
+        # (which would remove the regularization altogether and destabilize the fit).
+        # In other words, the weights should be allowed a high std. dev. on their priors, just not so high that the fit becomes unstable.
+        # At the same time, the prior std. dev. should not be too small (that would regularize too much).
+        # The values below have been taken from BoTorch, (alpha, beta) = (3.0, 6.0), and seem to work well if the inputs have been standardized.
+        # They lead to a high mean for the weights' std. dev. (18) and a large variance (sqrt(var) = 10.4), so the weights prior is broad
+        # and the only regularization is to prevent the weights from growing beyond 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
+        # An alternative would be to set both alpha and beta to very low values, which brings the hyperprior closer to the non-informative Jeffreys prior.
+        # Using this alternative (i.e. (0.1, 0.1) for the weights' hyperprior) leads to a very large lambda and numerical issues in the fit.
+        self.alpha_lambda = 3.0
+        self.beta_lambda = 6.0
+
+        # Hyperprior choice on the likelihood noise level:
+        # The likelihood noise level is controlled by sigma in the likelihood; its hyperprior should be broad, but unlike
+        # the weights prior it must allow sigma to become small, since with a lot of data it is conceivable that there is little noise in the data.
+        # We therefore want high variance in the hyperprior for sigma, but we do not need to prevent it from becoming small.
+        # Making both alpha and beta small brings the gamma distribution closer to the non-informative Jeffreys prior,
+        # but this seems to lead to a longer training time.
+        # Since, after standardization, we expect the variance to be of order 1, we instead select alpha and beta so that the hyperprior has high variance in that range.
+        self.alpha_sigma = 2.0
+        self.beta_sigma = 0.15
+
         self.model = nn.Sequential(
                                    bnn.BayesLinear(prior_mu=0.0,
-                                                   prior_sigma=0.1,
+                                                   prior_sigma=torch.exp(-0.5*self.log_ilambda2),
                                                    in_features=input_dimension,
                                                    out_features=hidden_dimension),
                                    nn.ReLU(),
                                    bnn.BayesLinear(prior_mu=0.0,
-                                                   prior_sigma=0.1,
+                                                   prior_sigma=torch.exp(-0.5*self.log_ilambda2),
                                                    in_features=hidden_dimension,
                                                    out_features=output_dimension)
                                     )
-        self.log_sigma2 = nn.Parameter(torch.ones(1)*np.log(sigma**2), requires_grad=fit_sigma)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -43,31 +113,49 @@ class BNN(nn.Module):
         """
         return self.model(x)
 
-    def nll(self, prediction: torch.Tensor, target: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+    def neg_log_gamma(self, log_x: torch.Tensor, x: torch.Tensor, alpha, beta) -> torch.Tensor:
+        """
+        Return the negative log of the gamma pdf.
+        """
+        return -alpha*np.log(beta) - (alpha - 1)*log_x + beta*x + gamma(alpha)
+
+    def neg_log_likelihood(self, prediction: torch.Tensor, target: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
         n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        sigma2 = torch.exp(self.log_sigma2)[0]
+        sigma2 = torch.exp(-self.log_isigma2)[0]
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])*n_output
-        L = norm_error.sum(dim=1).mean(dim=0) + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2[0])*n_output
+        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+
+    def neg_log_hyperprior(self) -> torch.Tensor:
+        """
+        Calculate the negative log of the hyperpriors.
+        """
         # hyperprior for sigma to avoid large or too small sigma
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
-        alpha = 2.0
-        beta = 0.15
-        hl = -alpha*np.log(beta) + (alpha + 1)*self.log_sigma2 + beta/sigma2 + gamma(alpha)
-        return L + hl
+        isigma2 = torch.exp(self.log_isigma2)[0]
+        neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
+        ilambda2 = torch.exp(self.log_ilambda2)[0]
+        neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
+        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
             Get the aleatoric component of the uncertainty.
         """
         #return 0
-        return torch.exp(0.5*self.log_sigma2[0])
+        return torch.exp(-0.5*self.log_isigma2[0])
+
+    def l(self) -> torch.Tensor:
+        """
+            Get the weights std. dev.
+        """
+        return torch.exp(-0.5*self.log_ilambda2[0])
 
 class BNNDataset(Dataset):
     def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray):
@@ -122,50 +210,55 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
-        B = 10
+        B = 5
         loader = DataLoader(ds, batch_size=B)
         optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
         number_of_batches = len(ds)/float(B)
-        weight_kl = 1.0/float(number_of_batches)
+        weight_prior = 1.0/float(number_of_batches)
+        # the NLL is divided by the number of batch samples,
+        # so also divide the prior losses by the number of batch elements, so that the
+        # quantity optimized is F divided by the number of samples
+        # https://arxiv.org/pdf/1505.05424.pdf
+        weight_prior /= float(B)
 
         # KL loss
         kl_loss = bnn.BKLLoss(reduction='sum', last_layer_only=False)
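+        # sums the KL divergence between the variational weight posteriors and their priors over all Bayesian layers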
 
         # train
         self.model.train()
-        epochs = 200
+        epochs = 100
         for epoch in range(epochs):
-            losses = list()
-            nlls = list()
-            priors = list()
-            for batch in loader:
+            meter = {k: AverageMeter(k, ':6.3f')
+                    for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'lambda')}
+            progress = ProgressMeter(
+                            len(loader),
+                            meter.values(),
+                            prefix="Epoch: [{}]".format(epoch))
+            for i, batch in enumerate(loader):
                 x_b = batch["x"]
                 y_b = batch["y"]
                 w_b = batch["w"]
                 y_b_pred = self.model(x_b)
 
-                # the NLL is divided by the number of batch samples
-                # so divide also the KL loss by the number of batch elements, so that the
-                # function optimized is F/# samples
-                # https://arxiv.org/pdf/1505.05424.pdf
-                nll = self.model.nll(y_b_pred, y_b, w_b)
-                prior = weight_kl * kl_loss(self.model)/float(B)
-                loss = nll + prior
+                nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
+                nlprior = weight_prior * kl_loss(self.model)
+                nlhyper = weight_prior * self.model.neg_log_hyperprior()
+                loss = nll + nlprior + nlhyper
 
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
 
-                losses.append(loss.detach().cpu().item())
-                nlls.append(nll.detach().cpu().item())
-                priors.append(prior.detach().cpu().item())
+                meter['loss'].update(loss.detach().cpu().item(), B)
+                meter['-log(lkl)'].update(nll.detach().cpu().item(), B)
+                meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
+                meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
+                meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
+                meter['lambda'].update(self.model.l().detach().cpu().item(), B)
 
-            # monitor
-            ale = self.model.aleatoric_uncertainty().detach().numpy()
-            losses = np.mean(np.array(losses))
-            nlls = np.mean(np.array(nlls))
-            priors = np.mean(np.array(priors))
-            print(f"Epoch {epoch}/{epochs}  total: {losses:.5f}, -LL: {nlls:.5f}, prior: {priors:.5f}, aleatoric unc.: {ale:.5f}")
+                if i % 100 == 0:
+                    progress.display(i)
+            progress.display(len(loader))
         self.model.eval()
 
         return self
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index a961db3..39b6c4b 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -517,9 +517,7 @@ class Model(TransformerMixin, BaseEstimator):
                  Set to None to perform no selection.
       validation_size: Fraction (number between 0 and 1) of the data to take for
                        validation and systematic uncertainty estimate.
-      n_nonlinear_kernel: Number of nonlinear kernel components added at the preprocessing stage
-                       to obtain nonlinearities as an input and improve the prediction.
-      poly: Whether to use a polynomial expantion of the low-resolution data.
+      bnn: Whether to use a Bayesian neural network (instead of Bayesian ridge regression) as the final prediction model.
 
     """
     def __init__(self,
@@ -531,20 +529,12 @@ class Model(TransformerMixin, BaseEstimator):
                  tof_start: Optional[int]=None,
                  delta_tof: Optional[int]=300,
                  validation_size: float=0.05,
-                 n_nonlinear_kernel: int=0,
-                 poly: bool=False,
+                 bnn: bool=True,
                 ):
         self.high_res_sigma = high_res_sigma
         # models
-        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=poly)
+        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=not bnn)
         x_model_steps = list()
-        self.n_nonlinear_kernel = n_nonlinear_kernel
-        if n_nonlinear_kernel > 0:
-            # Kernel PCA using Nystroem
-            x_model_steps += [('fex', Pipeline([('prepca', PCA(n_pca_lr, whiten=True)),
-                                                ('nystroem', Nystroem(n_components=n_nonlinear_kernel, kernel='rbf', gamma=None, n_jobs=8)),
-                                                ])),
-                             ]
         x_model_steps += [
                           ('pca', PCA(n_pca_lr, whiten=True)),
                           ('unc', UncertaintyHolder()),
@@ -557,8 +547,10 @@ class Model(TransformerMixin, BaseEstimator):
                                 ])
         self.ood = {ch: UncorrelatedDeviation(sigma=5)
                     for ch in channels+['full']}
-        #self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
-        self.fit_model = BNNModel()
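+        # the Bayesian neural network is nonlinear; Bayesian ridge regression is kept as a simpler linear alternative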
+        if bnn:
+            self.fit_model = BNNModel()
+        else:
+            self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
 
         self.kde_xgm = None
         self.mu_xgm = np.nan
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 0322a00..5ede953 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -143,7 +143,7 @@ def main():
     parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
-    parser.add_argument('-e', '--poly', action="store_true", default=False, help='Wheteher to expand PES data in higher order polynomials.')
+    parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Whether to use a Bayesian neural network instead of Bayesian ridge regression as the model.')
     parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
 
     args = parser.parse_args()
@@ -234,7 +234,7 @@ def main():
     t = list()
     t_names = list()
 
-    model = Model(poly=args.poly)
+    model = Model(bnn=args.bnn)
 
     train_idx = np.isin(tids, train_tids) & (xgm_flux[:,0] > args.xgm_cut)
     # we just need this for training and we need to avoid copying it, which blows up the memory usage
-- 
GitLab