From 0171eeca1e159d4b5019a4f80ef1cf2033ebe768 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 17 Apr 2023 17:20:47 +0200
Subject: [PATCH] Corrected BNN weighting due to averaging. Removed intensity
 calculation, as we rely on the XGM now.

---
 pes_to_spec/bnn.py                   | 19 ++++++++++++-------
 pes_to_spec/model.py                 |  5 +----
 pes_to_spec/test/offline_analysis.py |  3 +++
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 9ebf2a3..1b6072e 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -22,7 +22,7 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 500
         self.model = nn.Sequential(
                 bnn.BayesLinear(prior_mu=0,
                                 prior_sigma=0.1,
@@ -46,13 +46,14 @@ class BNN(nn.Module):
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
+        n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
         #return 0.5*squared_error.mean()
         sigma2 = torch.exp(self.log_sigma2)[0]
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])
-        return norm_error.mean() + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])*n_output
+        return norm_error.sum(dim=1).mean(dim=0) + norm_term
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
@@ -114,18 +115,18 @@ class BNNModel(RegressorMixin, BaseEstimator):
             self.model = BNN(X.shape[1], y.shape[1])
         # prepare data loader
-        B = 20
+        B = 10
         loader = DataLoader(ds, batch_size=B)
 
         optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
         number_of_batches = len(ds)/float(B)
         weight_kl = 1.0/float(number_of_batches)
 
         # KL loss
-        kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)
+        kl_loss = bnn.BKLLoss(reduction='sum', last_layer_only=False)
 
         # train
         self.model.train()
-        epochs = 1000
+        epochs = 200
         for epoch in range(epochs):
             losses = list()
             nlls = list()
@@ -136,8 +137,12 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 w_b = batch["w"]
 
                 y_b_pred = self.model(x_b)
+                # the NLL is divided by the number of batch samples
+                # so divide also the KL loss by the number of batch elements, so that the
+                # function optimized is F/# samples
+                # https://arxiv.org/pdf/1505.05424.pdf
                 nll = self.model.nll(y_b_pred, y_b, w_b)
-                prior = weight_kl * kl_loss(self.model)
+                prior = weight_kl * kl_loss(self.model)/float(B)
                 loss = nll + prior
 
                 optimizer.zero_grad()
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 37a7ae2..a961db3 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -923,8 +923,7 @@ class Model(TransformerMixin, BaseEstimator):
                 self.fit_model.state_dict(),
                 self.channel_pca,
                 #self.channel_fit_model
-                DataHolder(dict(mu_intensity=self.mu_intensity,
-                                sigma_intensity=self.sigma_intensity,
+                DataHolder(dict(
                                 mu_xgm=self.mu_xgm,
                                 sigma_xgm=self.sigma_xgm,
                                 wiener_filter_ft=self.wiener_filter_ft,
@@ -968,8 +967,6 @@ class Model(TransformerMixin, BaseEstimator):
         obj.kde_xgm = kde_xgm
 
         extra = extra.get_data()
-        obj.mu_intensity = extra["mu_intensity"]
-        obj.sigma_intensity = extra["sigma_intensity"]
         obj.mu_xgm = extra["mu_xgm"]
         obj.sigma_xgm = extra["sigma_xgm"]
         obj.wiener_filter_ft = extra["wiener_filter_ft"]
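Note on the loss rescaling in pes_to_spec/bnn.py above: the previous nll()
took the mean over both the batch and the output dimensions, so the data term
was shrunk by a factor of n_output relative to the normalization term and the
KL prior. The corrected version returns the full per-sample Gaussian negative
log-likelihood,

    NLL_i = sum_j [ w_ij^2 (y_ij - yhat_ij)^2 / (2 sigma^2)
                    + (1/2) log(2 pi sigma^2) ],

summed over the n_output dimensions j and averaged over the batch. A minimal,
self-contained sketch of that reduction (the shapes and the single shared
log-variance are assumptions for illustration, not the package's API):

    import numpy as np
    import torch

    B, n_output = 4, 3
    pred, target = torch.randn(B, n_output), torch.randn(B, n_output)
    w = torch.ones(B, 1)         # per-sample weights, broadcast over outputs
    log_sigma2 = torch.zeros(1)  # one shared log-variance, as in BNN.log_sigma2

    squared_error = (w * (pred - target))**2
    sigma2 = torch.exp(log_sigma2)[0]
    norm_error = 0.5 * squared_error / sigma2
    norm_term = 0.5 * (np.log(2 * np.pi) + log_sigma2[0]) * n_output
    nll = norm_error.sum(dim=1).mean(dim=0) + norm_term

    # cross-check: mean over the batch of the full per-sample NLL
    per_sample = (norm_error + 0.5 * (np.log(2 * np.pi) + log_sigma2[0])).sum(dim=1)
    assert torch.allclose(nll, per_sample.mean())

With nll() already averaged over the batch, scaling the summed KL term by
weight_kl / B = 1 / (number_of_batches * B) = 1 / len(ds) means each step
optimizes F divided by the number of samples, the minibatch objective of
Blundell et al. (https://arxiv.org/pdf/1505.05424.pdf) per sample.
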
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index edd3efd..0322a00 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -144,6 +144,7 @@ def main():
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
     parser.add_argument('-e', '--poly', action="store_true", default=False, help='Whether to expand PES data in higher order polynomials.')
+    parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
 
     args = parser.parse_args()
 
@@ -245,6 +246,8 @@ def main():
     print("Fitting")
     start = time_ns()
     w = model.uniformize(xgm_flux[train_idx])
+    if not args.weight:
+        w[...] = 1.0
     print("w", np.amin(w), np.amax(w), np.median(w))
     model.fit(pes_raw,
               #{k: v[train_idx]
-- 
GitLab
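A note on the new --weight flag: model.uniformize() derives per-shot training
weights from the XGM pulse energy (judging by its name, so that the weighted
energy distribution becomes approximately flat); with the flag off, the
weights are overwritten with 1.0 and the fit uses the raw distribution. A
hedged sketch of the inverse-density idea behind such weights (illustration
only; not necessarily how Model.uniformize is implemented):

    import numpy as np
    from scipy.stats import gaussian_kde

    def inverse_density_weights(x: np.ndarray) -> np.ndarray:
        """Weights proportional to 1/p(x), normalized to a median of 1."""
        density = gaussian_kde(x)(x)  # KDE estimate of p(x) at each sample
        w = 1.0 / np.clip(density, 1e-12, None)
        return w / np.median(w)

    rng = np.random.default_rng(0)
    xgm_flux = rng.gamma(shape=3.0, scale=300.0, size=1000)  # mock pulse energies in uJ
    w = inverse_density_weights(xgm_flux)
    print("w", np.amin(w), np.amax(w), np.median(w))  # same diagnostics as main()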