From f145f5cb5d36e986f736bea2ea60912bf440fead Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Thu, 17 Aug 2023 17:00:48 +0200
Subject: [PATCH] Corrected bugs in the BNN and moved relevant code from
 torchbnn to remove its dependency. Split log_sigma per layer, so that one
 can do proper empirical Bayes.

---
 pes_to_spec/bnn.py                   | 350 +++++++++++++++++++++++++--
 pes_to_spec/model.py                 |  16 +-
 pes_to_spec/test/offline_analysis.py |   2 +-
 pes_to_spec/test/prepare_plots.py    |   4 +-
 pyproject.toml                       |   1 -
 5 files changed, 337 insertions(+), 36 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 32461c2..7c0b341 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -2,13 +2,259 @@ from sklearn.base import BaseEstimator, RegressorMixin
 from typing import Any, Dict, Optional, Union, Tuple
 
 import numpy as np
+import math
 from scipy.special import gamma
 
 import torch
 import torch.nn as nn
-import torchbnn as bnn
+import torch.nn.functional as F
 from torch.utils.data import TensorDataset, DataLoader
 
+class BayesLinearEmpiricalPrior(nn.Module):
+    """
+    Applies a Bayesian linear layer whose prior scale is a learnable parameter
+    per weight (empirical Bayes).
+
+    Args:
+        prior_mu (Float): mean of prior normal distribution.
+        prior_sigma (Float): sigma of prior normal distribution.
+
+    """
+    __constants__ = ['prior_mu', 'prior_sigma', 'bias', 'in_features', 'out_features']
+
+    def __init__(self, prior_mu, prior_sigma, in_features, out_features, bias=True):
+        super(BayesLinearEmpiricalPrior, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.prior_mu = prior_mu
+        self.prior_sigma = prior_sigma
+
+        self.prior_log_sigma_w = nn.Parameter(torch.ones((out_features, in_features))*np.log(prior_sigma))
+        self.prior_log_sigma_b = nn.Parameter(torch.ones((out_features,))*np.log(prior_sigma))
+
+        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
+        self.weight_log_sigma = nn.Parameter(torch.Tensor(out_features, in_features))
+        self.register_buffer('weight_eps', None)
+
+        if bias is None or bias is False:
+            self.bias = False
+        else:
+            self.bias = True
+
+        if self.bias:
+            self.bias_mu = nn.Parameter(torch.Tensor(out_features))
+            self.bias_log_sigma = nn.Parameter(torch.Tensor(out_features))
+            self.register_buffer('bias_eps', None)
+        else:
+            self.register_parameter('bias_mu', None)
+            self.register_parameter('bias_log_sigma', None)
+            self.register_buffer('bias_eps', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / np.sqrt(self.weight_mu.size(1))
+        self.weight_mu.data.uniform_(-stdv, stdv)
+        self.weight_log_sigma.data.fill_(np.log(self.prior_sigma))
+        if self.bias:
+            self.bias_mu.data.uniform_(-stdv, stdv)
+            self.bias_log_sigma.data.fill_(np.log(self.prior_sigma))
+
+    def freeze(self):
+        self.weight_eps = torch.randn_like(self.weight_log_sigma)
+        if self.bias:
+            self.bias_eps = torch.randn_like(self.bias_log_sigma)
+
+    def unfreeze(self):
+        self.weight_eps = None
+        if self.bias:
+            self.bias_eps = None
+
+    def forward(self, input):
+        r"""
+        Overridden.
+ """ + if self.weight_eps is None : + weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma) + else : + weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps + + if self.bias: + if self.bias_eps is None : + bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma) + else : + bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps + else : + bias = None + + return F.linear(input, weight, bias) + + def extra_repr(self): + r""" + Overriden. + """ + return 'prior_mu={}, prior_sigma={}, in_features={}, out_features={}, bias={}'.format(self.prior_mu, self.prior_sigma, self.in_features, self.out_features, self.bias is not None) + +class BayesLinear(nn.Module): + r""" + Applies Bayesian Linear + + Arguments: + prior_mu (Float): mean of prior normal distribution. + prior_sigma (Float): sigma of prior normal distribution. + + """ + __constants__ = ['prior_mu', 'bias', 'in_features', 'out_features'] + + def __init__(self, prior_mu, prior_sigma, in_features, out_features, bias=True): + super(BayesLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + self.prior_mu = prior_mu + self.prior_log_sigma = nn.Parameter(torch.ones(1)*np.log(prior_sigma), requires_grad=True) + + self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features)) + self.weight_log_sigma = nn.Parameter(torch.Tensor(out_features, in_features)) + self.register_buffer('weight_eps', None) + + if bias is None or bias is False: + self.bias = False + else : + self.bias = True + + if self.bias: + self.bias_mu = nn.Parameter(torch.Tensor(out_features)) + self.bias_log_sigma = nn.Parameter(torch.Tensor(out_features)) + self.register_buffer('bias_eps', None) + else: + self.register_parameter('bias_mu', None) + self.register_parameter('bias_log_sigma', None) + self.register_buffer('bias_eps', None) + + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / np.sqrt(self.weight_mu.size(1)) + self.weight_mu.data.uniform_(-stdv, stdv) + self.weight_log_sigma.data.fill_(self.prior_log_sigma.detach()[0]) + if self.bias : + self.bias_mu.data.uniform_(-stdv, stdv) + self.bias_log_sigma.data.fill_(self.prior_log_sigma.detach()[0]) + + def freeze(self) : + self.weight_eps = torch.randn_like(self.weight_log_sigma) + if self.bias : + self.bias_eps = torch.randn_like(self.bias_log_sigma) + + def unfreeze(self) : + self.weight_eps = None + if self.bias: + self.bias_eps = None + + def forward(self, input): + r""" + Overriden. + """ + if self.weight_eps is None : + weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma) + else : + weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps + + if self.bias: + if self.bias_eps is None : + bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma) + else : + bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps + else : + bias = None + + return F.linear(input, weight, bias) + + def extra_repr(self): + r""" + Overriden. + """ + return 'prior_mu={}, prior_sigma={}, in_features={}, out_features={}, bias={}'.format(self.prior_mu, self.prior_sigma, self.in_features, self.out_features, self.bias is not None) + +def _kl_loss(mu_0, log_sigma_0, mu_1, log_sigma_1) : + """ + An method for calculating KL divergence between two Normal distribtuion. + + Arguments: + mu_0 (Float) : mean of normal distribution. 
+        log_sigma_0 (Float): log(standard deviation of normal distribution).
+        mu_1 (Float): mean of normal distribution.
+        log_sigma_1 (Float): log(standard deviation of normal distribution).
+
+    """
+    if isinstance(log_sigma_1, float):
+        sigma_1 = np.exp(log_sigma_1)
+    else:
+        sigma_1 = torch.exp(log_sigma_1)
+    kl = log_sigma_1 - log_sigma_0 + \
+        (torch.exp(log_sigma_0)**2 + (mu_0-mu_1)**2)/(2*sigma_1**2) - 0.5
+    return kl.sum()
+
+class BKLLoss(nn.Module):
+    """
+    Loss for calculating the KL divergence of a Bayesian neural network model.
+
+    Args:
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements of the output.
+            ``'sum'``: the output will be summed.
+        last_layer_only (Bool): if True, return only the last layer's KL divergence.
+    """
+    __constants__ = ['reduction']
+
+    def __init__(self, reduction='mean', last_layer_only=False):
+        super(BKLLoss, self).__init__()
+        self.last_layer_only = last_layer_only
+        self.reduction = reduction
+
+    def forward(self, model):
+        """
+        Args:
+            model (nn.Module): a model to be calculated for KL-divergence.
+        """
+        #return bayesian_kl_loss(model, reduction=self.reduction, last_layer_only=self.last_layer_only)
+        device = torch.device("cuda" if next(model.parameters()).is_cuda else "cpu")
+        kl = torch.Tensor([0]).to(device)
+        kl_sum = torch.Tensor([0]).to(device)
+        n = torch.Tensor([0]).to(device)
+
+        for m in model.modules():
+            if isinstance(m, (BayesLinearEmpiricalPrior)):
+                kl = _kl_loss(m.weight_mu, m.weight_log_sigma, m.prior_mu, m.prior_log_sigma_w)
+                kl_sum += kl
+                n += len(m.weight_mu.view(-1))
+
+                if m.bias:
+                    kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma_b)
+                    kl_sum += kl
+                    n += len(m.bias_mu.view(-1))
+            if isinstance(m, (BayesLinear)):
+                kl = _kl_loss(m.weight_mu, m.weight_log_sigma, m.prior_mu, m.prior_log_sigma)
+                kl_sum += kl
+                n += len(m.weight_mu.view(-1))
+
+                if m.bias:
+                    kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma)
+                    kl_sum += kl
+                    n += len(m.bias_mu.view(-1))
+
+        if self.last_layer_only or n == 0:
+            return kl
+
+        if self.reduction == 'mean':
+            return kl_sum/n
+        elif self.reduction == 'sum':
+            return kl_sum
+        else:
+            raise ValueError(f"{self.reduction} is not valid")
+
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self, name, fmt=':f'):
@@ -49,6 +295,7 @@ class ProgressMeter(object):
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
 
+
 class BNN(nn.Module):
     """
     A model Bayesian Neural network.
@@ -61,13 +308,13 @@ class BNN(nn.Module):
     between the prediction and the true value. The standard deviation of the
     Gaussian is left as a parameter to be fit: sigma.
     """
-    def __init__(self, input_dimension: int=1, output_dimension: int=1):
+    def __init__(self, input_dimension: int=1, output_dimension: int=1, rvm: bool=False):
         super(BNN, self).__init__()
         hidden_dimension = 50
         # controls the aleatoric uncertainty
         self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
-        self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        self.log_ilambda2 = -np.log(0.1**2)
 
         # inverse Gamma hyper prior alpha and beta
         #
@@ -82,8 +329,8 @@ class BNN(nn.Module):
         # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
         # An alternative would be to set the (alpha, beta) both to very low values, which makes the hyper prior become closer to the non-informative Jeffrey's prior.
         # Using this alternative (ie: (0.1, 0.1) for the weights' hyper prior) leads to very large lambda and numerical issues with the fit.
-        self.alpha_lambda = 0.001
-        self.beta_lambda = 0.001
+        self.alpha_lambda = 0.0001
+        self.beta_lambda = 0.0001
 
         # Hyperprior choice on the likelihood noise level:
         # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different
@@ -92,20 +339,45 @@ class BNN(nn.Module):
         # Making both alpha and beta small makes the gamma distribution closer to the Jeffrey's prior, which makes it non-informative
         # This seems to lead to a larger training time, though.
         # Since, after standardization, we know to expect the variance to be of order (1), we can select also alpha and beta leading to high variance in this range
-        self.alpha_sigma = 0.001
-        self.beta_sigma = 0.001
-
-        self.model = nn.Sequential(
-                        bnn.BayesLinear(prior_mu=0.0,
-                                        prior_sigma=torch.exp(-0.5*self.log_ilambda2),
+        self.alpha_sigma = 0.0001
+        self.beta_sigma = 0.0001
+
+        if rvm:
+            self.model = nn.Sequential(
+                            BayesLinearEmpiricalPrior(prior_mu=0.0,
+                                        prior_sigma=np.exp(-0.5*self.log_ilambda2),
+                                        in_features=input_dimension,
+                                        out_features=hidden_dimension),
+                            nn.ReLU(),
+                            BayesLinearEmpiricalPrior(prior_mu=0.0,
+                                        prior_sigma=np.exp(-0.5*self.log_ilambda2),
+                                        in_features=hidden_dimension,
+                                        out_features=output_dimension)
+                            )
+        else:
+            self.model = nn.Sequential(
+                            BayesLinear(prior_mu=0.0,
+                                        prior_sigma=np.exp(-0.5*self.log_ilambda2),
                                         in_features=input_dimension,
                                         out_features=hidden_dimension),
-                        nn.ReLU(),
-                        bnn.BayesLinear(prior_mu=0.0,
-                                        prior_sigma=torch.exp(-0.5*self.log_ilambda2),
+                            nn.ReLU(),
+                            BayesLinear(prior_mu=0.0,
+                                        prior_sigma=np.exp(-0.5*self.log_ilambda2),
                                         in_features=hidden_dimension,
                                         out_features=output_dimension)
-                        )
+                            )
+        self.rvm = rvm
+
+    def prune(self):
+        """Prune weights."""
+        with torch.no_grad():
+            for layer in self.model.modules():
+                if isinstance(layer, BayesLinearEmpiricalPrior):
+                    log_isigma2 = -2.0*layer.prior_log_sigma_w
+                    isigma2 = torch.exp(log_isigma2)
+                    keep = isigma2 < 1e4
+                    layer.weight_mu[~keep] *= 0.0
+                    layer.weight_log_sigma[~keep] = -12.0
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -138,10 +410,21 @@ class BNN(nn.Module):
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
         isigma2 = torch.exp(self.log_isigma2)
-        neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
-        ilambda2 = torch.exp(self.log_ilambda2)
-        neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
-        return neg_log_hyperprior_noise.sum() + neg_log_hyperprior_weights.sum()
+        neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma).sum()
+        if self.rvm:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma_w,
+                            -2.0*self.model[2].prior_log_sigma_w,
+                            -2.0*self.model[0].prior_log_sigma_b,
+                            -2.0*self.model[2].prior_log_sigma_b
+                            ]
+        else:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma,
+                            -2.0*self.model[2].prior_log_sigma,
+                            ]
+        ilambda2 = [torch.exp(k) for k in log_ilambda2]
+        neg_log_hyperprior_weights = sum(self.neg_log_gamma(log_k, k, self.alpha_lambda, self.beta_lambda).sum()
+                                         for log_k, k in zip(log_ilambda2, ilambda2))
+        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
@@ -154,7 +437,18 @@
         """
         Get the weights precision.
         """
-        return torch.exp(self.log_ilambda2[0])
+        if self.rvm:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma_w,
+                            -2.0*self.model[2].prior_log_sigma_w,
+                            -2.0*self.model[0].prior_log_sigma_b,
+                            -2.0*self.model[2].prior_log_sigma_b
+                            ]
+        else:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma,
+                            -2.0*self.model[2].prior_log_sigma,
+                            ]
+        ilambda2 = [torch.exp(k) for k in log_ilambda2]
+        return sum(k.mean() for k in ilambda2)/len(ilambda2)
 
 class BNNModel(RegressorMixin, BaseEstimator):
     """
@@ -162,14 +456,15 @@
 
     Args:
     """
-    def __init__(self, state_dict=None):
+    def __init__(self, state_dict=None, rvm: bool=False):
         if state_dict is not None:
             Nx = state_dict["model.0.weight_mu"].shape[1]
             Ny = state_dict["model.2.weight_mu"].shape[0]
-            self.model = BNN(Nx, Ny)
+            self.model = BNN(Nx, Ny, rvm=rvm)
             self.model.load_state_dict(state_dict)
         else:
-            self.model = BNN()
+            self.model = BNN(rvm=rvm)
+        self.rvm = rvm
         self.model.eval()
 
     def state_dict(self) -> Dict[str, Any]:
@@ -197,7 +492,7 @@
                                 torch.from_numpy(weights))
 
         # create model
-        self.model = BNN(X.shape[1], y.shape[1])
+        self.model = BNN(X.shape[1], y.shape[1], rvm=self.rvm)
 
         # prepare data loader
         B = 50
@@ -218,11 +513,11 @@
         weight_prior /= float(B)
 
         # KL loss
-        kl_loss = bnn.BKLLoss(reduction='sum', last_layer_only=False)
+        kl_loss = BKLLoss(reduction='sum', last_layer_only=False)
 
         # train
         self.model.train()
-        epochs = 1000
+        epochs = 250
         for epoch in range(epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                      for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
@@ -251,6 +546,9 @@
                 meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
             progress.display(len(loader))
 
+        if self.rvm:
+            self.model.prune()
+
         self.model.eval()
 
         return self
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 5751510..44c1428 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -574,7 +574,7 @@ class Model(TransformerMixin, BaseEstimator):
            Set to None to perform no selection.
         validation_size: Fraction (number between 0 and 1) of the data to take for
            validation and systematic uncertainty estimate.
-        model_type: Which model to use. "bnn" for a BNN, "ridge" for Ridge and "ard" for ARD.
+        model_type: Which model to use. "bnn" for a BNN, "bnn_rvm" for a BNN with an RVM-like empirical Bayes prior, "ridge" for Ridge and "ard" for ARD.
         """
 
     def __init__(self,
@@ -586,11 +586,11 @@ class Model(TransformerMixin, BaseEstimator):
                 tof_start: Optional[int]=None,
                 delta_tof: Optional[int]=300,
                 validation_size: float=0.05,
-                model_type: Literal["bnn", "ridge", "ard"]="ard",
+                model_type: Literal["bnn", "bnn_rvm", "ridge", "ard"]="ard",
                 ):
         self.high_res_sigma = high_res_sigma
         # models
-        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=(model_type not in ["bnn"]))
+        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=(model_type not in ["bnn", "bnn_rvm"]))
         x_model_steps = list()
         x_model_steps += [
                 ('pca', PCA(n_pca_lr, whiten=True)),
@@ -606,6 +606,8 @@ class Model(TransformerMixin, BaseEstimator):
                             for ch in channels+['full']}
         if model_type == "bnn":
             self.fit_model = BNNModel()
+        elif model_type == "bnn_rvm":
+            self.fit_model = BNNModel(rvm=True)
         elif model_type == "ridge":
             self.fit_model = MultiOutputRidgeWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
         elif model_type == "ard":
@@ -627,7 +629,7 @@ class Model(TransformerMixin, BaseEstimator):
 
     def n_pars(self) -> float:
         """Get number of parameters."""
-        if self.model_type == "bnn":
+        if self.model_type in ("bnn", "bnn_rvm"):
             return sum(p.numel() for p in self.fit_model.model.parameters())
         return sum(len(estimator.coef_) + 1 for estimator in self.fit_model.estimators_)
 
@@ -1015,7 +1017,7 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
-                     self.fit_model.state_dict() if self.model_type == "bnn" else self.fit_model,
+                     self.fit_model.state_dict() if self.model_type in ("bnn", "bnn_rvm") else self.fit_model,
                      self.channel_pca,
                      #self.channel_fit_model
                      DataHolder(dict(
@@ -1073,7 +1075,9 @@ class Model(TransformerMixin, BaseEstimator):
         obj.x_model = x_model
         obj.y_model = y_model
         if obj.model_type == "bnn":
-            obj.fit_model = BNNModel(state_dict=fit_model)
+            obj.fit_model = BNNModel(state_dict=fit_model, rvm=False)
+        elif obj.model_type == "bnn_rvm":
+            obj.fit_model = BNNModel(state_dict=fit_model, rvm=True)
         else:
             obj.fit_model = fit_model
         obj.channel_pca = channel_pca
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 6c53009..a199e8b 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -131,7 +131,7 @@ def main():
     parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=0, help='XGM intensity threshold in uJ.')
-    parser.add_argument('-T', '--model-type', type=str, metavar='TYPE', default="ard", choices=["bnn", "ridge", "ard"], help='Which model type to use.')
+    parser.add_argument('-T', '--model-type', type=str, metavar='TYPE', default="ard", choices=["bnn", "bnn_rvm", "ridge", "ard"], help='Which model type to use.')
     parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
 
     args = parser.parse_args()
diff --git a/pes_to_spec/test/prepare_plots.py b/pes_to_spec/test/prepare_plots.py
index 7dc0888..12ca07c 100755
--- a/pes_to_spec/test/prepare_plots.py
+++ b/pes_to_spec/test/prepare_plots.py
@@ -12,8 +12,8 @@ from matplotlib.gridspec import GridSpec
 import seaborn as sns
 
 SMALL_SIZE = 12
-MEDIUM_SIZE = 18
-BIGGER_SIZE = 24
+MEDIUM_SIZE = 22
+BIGGER_SIZE = 26
 
 plt.rc('font', size=BIGGER_SIZE)          # controls default text sizes
 plt.rc('axes', titlesize=BIGGER_SIZE)     # fontsize of the axes title
diff --git a/pyproject.toml b/pyproject.toml
index 86ea8d5..9ea807b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
     "scipy>=1.6",
     "scikit-learn>=1.2.0",
     "torch",
-    "torchbnn",
 ]
 
 [project.optional-dependencies]
--
GitLab
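
Illustrative usage sketch (not part of the patch): the snippet below exercises the torchbnn replacements added in pes_to_spec/bnn.py. It builds a small Bayesian MLP from the new BayesLinear layer and evaluates the BKLLoss term that BNN combines with the likelihood and hyperprior terms during training. It assumes the patched pes_to_spec package is importable; the layer sizes and inputs are made up for demonstration.

    # Sketch: stochastic forward passes and the KL regularizer of the new layers.
    import torch
    import torch.nn as nn
    from pes_to_spec.bnn import BayesLinear, BKLLoss

    net = nn.Sequential(
        BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=8, out_features=16),
        nn.ReLU(),
        BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=16, out_features=2),
    )
    kl = BKLLoss(reduction='mean', last_layer_only=False)

    x = torch.randn(4, 8)
    y = net(x)                      # stochastic: weights are resampled on every call
    print(y.shape, kl(net).item())  # KL term between weight posterior and prior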
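The _kl_loss helper added above implements the closed form KL(N(mu_0, sigma_0^2) || N(mu_1, sigma_1^2)) = log(sigma_1/sigma_0) + (sigma_0^2 + (mu_0 - mu_1)^2) / (2 sigma_1^2) - 1/2, summed over elements. A quick sanity check against torch.distributions (illustrative only, with arbitrary test values):

    import torch
    from torch.distributions import Normal, kl_divergence
    from pes_to_spec.bnn import _kl_loss

    mu0, log_s0 = torch.tensor([0.3]), torch.tensor([-1.0])
    mu1, log_s1 = torch.tensor([0.0]), torch.tensor([-2.3])

    ours = _kl_loss(mu0, log_s0, mu1, log_s1)
    ref = kl_divergence(Normal(mu0, log_s0.exp()), Normal(mu1, log_s1.exp())).sum()
    print(torch.allclose(ours, ref))  # expected: True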