diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index aaf850988631569f696274aea52fd1fca3625f28..887512e093e555bb65bef8442585eb0c84801146 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -63,9 +63,9 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 100
+        hidden_dimension = 50
         # controls the aleatoric uncertainty
-        self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
         self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
 
@@ -82,8 +82,8 @@ class BNN(nn.Module):
         # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
         # An alternative would be to set the (alpha, beta) both to very low values, which makes the hyper prior become closer to the non-informative Jeffreys prior.
         # Using this alternative (i.e.: (0.1, 0.1) for the weights' hyper prior) leads to very large lambda and numerical issues with the fit.
-        self.alpha_lambda = 0.1
-        self.beta_lambda = 0.1
+        self.alpha_lambda = 0.001
+        self.beta_lambda = 0.001
 
         # Hyperprior choice on the likelihood noise level:
         # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different
@@ -92,8 +92,8 @@ class BNN(nn.Module):
         # Making both alpha and beta small makes the gamma distribution closer to the Jeffreys prior, which makes it non-informative
         # This seems to lead to a larger training time, though.
         # Since, after standardization, we know to expect the variance to be of order (1), we can select also alpha and beta leading to high variance in this range
-        self.alpha_sigma = 0.1
-        self.beta_sigma = 0.1
+        self.alpha_sigma = 0.001
+        self.beta_sigma = 0.001
 
         self.model = nn.Sequential(
                                    bnn.BayesLinear(prior_mu=0.0,
@@ -123,13 +123,12 @@ class BNN(nn.Module):
         """
         Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
         """
-        n_output = target.shape[1]
         error = w*(prediction - target)
         squared_error = error**2
-        sigma2 = torch.exp(-self.log_isigma2)[0]
+        sigma2 = torch.exp(-self.log_isigma2)
         norm_error = 0.5*squared_error/sigma2
-        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2[0])*n_output
-        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2)
+        return (norm_error + norm_term).sum(dim=1).mean(dim=0)
 
     def neg_log_hyperprior(self) -> torch.Tensor:
         """
@@ -138,18 +137,18 @@ class BNN(nn.Module):
         # hyperprior for sigma to avoid large or too small sigma
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
-        isigma2 = torch.exp(self.log_ilambda2)[0]
+        isigma2 = torch.exp(self.log_isigma2)
         neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
-        ilambda2 = torch.exp(self.log_ilambda2)[0]
+        ilambda2 = torch.exp(self.log_ilambda2)
         neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
-        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
+        return neg_log_hyperprior_noise.sum() + neg_log_hyperprior_weights.sum()
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
         Get the aleatoric component of the uncertainty.
         """
         #return 0
-        return torch.exp(-0.5*self.log_isigma2[0])
+        return torch.exp(-0.5*self.log_isigma2)
 
     def w_precision(self) -> torch.Tensor:
         """
@@ -201,10 +200,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
-        B = 100
+        B = 50
         loader = DataLoader(ds,
                             batch_size=B,
-                            num_workers=5,
+                            num_workers=20,
                             shuffle=True,
                             #pin_memory=True,
                             drop_last=True,
@@ -223,7 +222,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
         # train
        self.model.train()
-        epochs = 1000
+        epochs = 500
        for epoch in range(epochs):
            meter = {k: AverageMeter(k, ':6.3f')
                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
@@ -248,7 +247,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
                meter['-log(lkl)'].update(nll.detach().cpu().item(), B)
                meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
                meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
-                meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
+                meter['sigma'].update(self.model.aleatoric_uncertainty().mean().detach().cpu().numpy(), B)
                meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
 
            progress.display(len(loader))
@@ -268,12 +267,12 @@ class BNNModel(RegressorMixin, BaseEstimator):
         K = 10
         y_pred = list()
         for _ in range(K):
-            y_k = self.model(torch.from_numpy(X)).detach().numpy()
+            y_k = self.model(torch.from_numpy(X)).detach().cpu().numpy()
             y_pred.append(y_k)
         y_pred = np.stack(y_pred, axis=1)
         y_mu = np.mean(y_pred, axis=1)
         y_epi = np.std(y_pred, axis=1)
-        y_ale = self.model.aleatoric_uncertainty().detach().numpy()
+        y_ale = self.model.aleatoric_uncertainty().detach().cpu().numpy()
         y_unc = (y_epi**2 + y_ale**2)**0.5
         if not return_std:
             return y_mu
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 2da28ebe6d2c4cca4206b69e5ad7d67266a2d4de..5ede9532259921dd4033e2f67dfc5492073a7ac7 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -144,7 +144,7 @@ def main():
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
     parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Use BNN?')
-    parser.add_argument('-w', '--weight', action="store_true", default=True, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
+    parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
 
     args = parser.parse_args()