From 0171eeca1e159d4b5019a4f80ef1cf2033ebe768 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 17 Apr 2023 17:20:47 +0200
Subject: [PATCH] Corrected the BNN loss weighting, which was skewed by
 averaging over the output dimensions. Removed the intensity calculation, as
 we rely on the XGM now.
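
For reference, a minimal sketch of the corrected NLL reduction (the shapes,
the scalar log_sigma2, and the free-standing helper are illustrative, not
the actual module code):

    import numpy as np
    import torch

    def nll(prediction, target, w, log_sigma2):
        # prediction, target, w: (batch, n_output); log_sigma2: scalar tensor
        n_output = target.shape[1]
        squared_error = (w * (prediction - target))**2
        sigma2 = torch.exp(log_sigma2)
        norm_error = 0.5 * squared_error / sigma2
        # one log-variance normalization term per output dimension
        norm_term = 0.5 * (np.log(2 * np.pi) + log_sigma2) * n_output
        # sum over outputs, average over the batch only; a plain .mean()
        # would also average over outputs and shrink the data term by
        # a factor of n_output relative to the normalization term
        return norm_error.sum(dim=1).mean(dim=0) + norm_term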

---
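Note on the KL rescaling in pes_to_spec/bnn.py: weight_kl is 1/M, where
M = N/B is the number of minibatches, so the extra division by the batch
size B leaves the KL term scaled by 1/N overall. The KL and the batch-mean
NLL are then both per-sample quantities, as in Blundell et al.
(https://arxiv.org/pdf/1505.05424.pdf). A quick check, with N and B as
illustrative values rather than the real dataset size:

    import math

    N = 1000             # total number of training samples (illustrative)
    B = 10               # batch size, as set in BNNModel.fit
    M = N / float(B)     # number of minibatches per epoch
    weight_kl = 1.0 / M  # KL weight as computed in fit()
    # dividing the KL term by B as well scales it by 1/N overall
    assert math.isclose(weight_kl / B, 1.0 / N)
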
 pes_to_spec/bnn.py                   | 19 ++++++++++++-------
 pes_to_spec/model.py                 |  5 +----
 pes_to_spec/test/offline_analysis.py |  3 +++
 3 files changed, 16 insertions(+), 11 deletions(-)
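
The offline_analysis.py change below gates the pulse-energy reweighting
behind a new --weight flag. A minimal sketch of its effect (w stands in for
the output of model.uniformize(xgm_flux[train_idx]), and use_weighting for
args.weight; both names here are illustrative):

    import numpy as np

    w = np.random.rand(100)  # stand-in for the uniformization weights
    use_weighting = False    # the --weight flag defaults to False
    if not use_weighting:
        w[...] = 1.0         # uniform weights: reweighting disabled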

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 9ebf2a3..1b6072e 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -22,7 +22,7 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 500
         self.model = nn.Sequential(
                                    bnn.BayesLinear(prior_mu=0,
                                                    prior_sigma=0.1,
@@ -46,13 +46,14 @@ class BNN(nn.Module):
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
+        n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
         #return 0.5*squared_error.mean()
         sigma2 = torch.exp(self.log_sigma2)[0]
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])
-        return norm_error.mean() + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])*n_output
+        return norm_error.sum(dim=1).mean(dim=0) + norm_term
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
@@ -114,18 +115,18 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
-        B = 20
+        B = 10
         loader = DataLoader(ds, batch_size=B)
         optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
         number_of_batches = len(ds)/float(B)
         weight_kl = 1.0/float(number_of_batches)
 
         # KL loss
-        kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)
+        kl_loss = bnn.BKLLoss(reduction='sum', last_layer_only=False)
 
         # train
         self.model.train()
-        epochs = 1000
+        epochs = 200
         for epoch in range(epochs):
             losses = list()
             nlls = list()
@@ -136,8 +137,12 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 w_b = batch["w"]
                 y_b_pred = self.model(x_b)
 
+                # The NLL above is averaged over the batch samples, so divide
+                # the KL loss by the batch size as well: the objective being
+                # optimized is then F / (number of samples), following
+                # https://arxiv.org/pdf/1505.05424.pdf
                 nll = self.model.nll(y_b_pred, y_b, w_b)
-                prior = weight_kl * kl_loss(self.model)
+                prior = weight_kl * kl_loss(self.model)/float(B)
                 loss = nll + prior
 
                 optimizer.zero_grad()
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 37a7ae2..a961db3 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -923,8 +923,7 @@ class Model(TransformerMixin, BaseEstimator):
                      self.fit_model.state_dict(),
                      self.channel_pca,
                      #self.channel_fit_model
-                     DataHolder(dict(mu_intensity=self.mu_intensity,
-                                     sigma_intensity=self.sigma_intensity,
+                     DataHolder(dict(
                                      mu_xgm=self.mu_xgm,
                                      sigma_xgm=self.sigma_xgm,
                                      wiener_filter_ft=self.wiener_filter_ft,
@@ -968,8 +967,6 @@ class Model(TransformerMixin, BaseEstimator):
         obj.kde_xgm = kde_xgm
 
         extra = extra.get_data()
-        obj.mu_intensity = extra["mu_intensity"]
-        obj.sigma_intensity = extra["sigma_intensity"]
         obj.mu_xgm = extra["mu_xgm"]
         obj.sigma_xgm = extra["sigma_xgm"]
         obj.wiener_filter_ft = extra["wiener_filter_ft"]
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index edd3efd..0322a00 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -144,6 +144,7 @@ def main():
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
     parser.add_argument('-e', '--poly', action="store_true", default=False, help='Whether to expand PES data in higher order polynomials.')
+    parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight the training data as a function of the pulse energy, making the model invariant to it.')
 
     args = parser.parse_args()
 
@@ -245,6 +246,8 @@ def main():
         print("Fitting")
         start = time_ns()
         w = model.uniformize(xgm_flux[train_idx])
+        if not args.weight:
+            w[...] = 1.0
         print("w", np.amin(w), np.amax(w), np.median(w))
         model.fit(pes_raw,
                   #{k: v[train_idx]
-- 
GitLab