diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 32461c28a8c9313fe8adea470240d397f12e5627..7c0b3411199f5337f3bf60f3d1ac40af3d8f2ac2 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -2,13 +2,259 @@ from sklearn.base import BaseEstimator, RegressorMixin
 from typing import Any, Dict, Optional, Union, Tuple
 
 import numpy as np
+import math
 from scipy.special import gamma
 
 import torch
 import torch.nn as nn
-import torchbnn as bnn
+import torch.nn.functional as F
 from torch.utils.data import TensorDataset, DataLoader
 
+class BayesLinearEmpiricalPrior(nn.Module):
+    """
+    Applies Bayesian Linear
+
+    Args:
+        prior_mu (Float): mean of prior normal distribution.
+        prior_sigma (Float): sigma of prior normal distribution.
+
+    """
+    __constants__ = ['prior_mu', 'prior_sigma', 'bias', 'in_features', 'out_features']
+
+    def __init__(self, prior_mu, prior_sigma, in_features, out_features, bias=True):
+        super(BayesLinearEmpiricalPrior, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.prior_mu = prior_mu
+        self.prior_sigma = prior_sigma
+
+        self.prior_log_sigma_w = nn.Parameter(torch.ones((out_features, in_features))*np.log(prior_sigma))
+        self.prior_log_sigma_b = nn.Parameter(torch.ones((out_features,))*np.log(prior_sigma))
+
+        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
+        self.weight_log_sigma = nn.Parameter(torch.Tensor(out_features, in_features))
+        self.register_buffer('weight_eps', None)
+
+        if bias is None or bias is False:
+            self.bias = False
+        else:
+            self.bias = True
+
+        if self.bias:
+            self.bias_mu = nn.Parameter(torch.Tensor(out_features))
+            self.bias_log_sigma = nn.Parameter(torch.Tensor(out_features))
+            self.register_buffer('bias_eps', None)
+        else:
+            self.register_parameter('bias_mu', None)
+            self.register_parameter('bias_log_sigma', None)
+            self.register_buffer('bias_eps', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / np.sqrt(self.weight_mu.size(1))
+        self.weight_mu.data.uniform_(-stdv, stdv)
+        self.weight_log_sigma.data.fill_(np.log(self.prior_sigma))
+        if self.bias:
+            self.bias_mu.data.uniform_(-stdv, stdv)
+            self.bias_log_sigma.data.fill_(np.log(self.prior_sigma))
+
+    def freeze(self):
+        self.weight_eps = torch.randn_like(self.weight_log_sigma)
+        if self.bias:
+            self.bias_eps = torch.randn_like(self.bias_log_sigma)
+
+    def unfreeze(self):
+        self.weight_eps = None
+        if self.bias:
+            self.bias_eps = None
+
+    def forward(self, input):
+        r"""
+        Sample weights (and bias) from the variational posterior and apply the linear map.
+        """
+        if self.weight_eps is None:
+            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
+        else:
+            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
+
+        if self.bias:
+            if self.bias_eps is None:
+                bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
+            else:
+                bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
+        else:
+            bias = None
+
+        return F.linear(input, weight, bias)
+
+    def extra_repr(self):
+        r"""
+        Overridden from nn.Module.
+        """
+        return 'prior_mu={}, prior_sigma={}, in_features={}, out_features={}, bias={}'.format(self.prior_mu, self.prior_sigma, self.in_features, self.out_features, self.bias)
+
+class BayesLinear(nn.Module):
+    r"""
+    Bayesian linear layer with a single, learnable prior scale shared by all weights.
+
+    Arguments:
+        prior_mu (Float): mean of the prior normal distribution.
+        prior_sigma (Float): initial standard deviation of the prior normal distribution;
+            the shared log-scale is a learnable parameter.
+    """
+    __constants__ = ['prior_mu', 'bias', 'in_features', 'out_features']
+
+    def __init__(self, prior_mu, prior_sigma, in_features, out_features, bias=True):
+        super(BayesLinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.prior_mu = prior_mu
+        self.prior_log_sigma = nn.Parameter(torch.ones(1)*np.log(prior_sigma), requires_grad=True)
+
+        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
+        self.weight_log_sigma = nn.Parameter(torch.Tensor(out_features, in_features))
+        self.register_buffer('weight_eps', None)
+
+        if bias is None or bias is False:
+            self.bias = False
+        else:
+            self.bias = True
+
+        if self.bias:
+            self.bias_mu = nn.Parameter(torch.Tensor(out_features))
+            self.bias_log_sigma = nn.Parameter(torch.Tensor(out_features))
+            self.register_buffer('bias_eps', None)
+        else:
+            self.register_parameter('bias_mu', None)
+            self.register_parameter('bias_log_sigma', None)
+            self.register_buffer('bias_eps', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / np.sqrt(self.weight_mu.size(1))
+        self.weight_mu.data.uniform_(-stdv, stdv)
+        self.weight_log_sigma.data.fill_(self.prior_log_sigma.detach().item())
+        if self.bias:
+            self.bias_mu.data.uniform_(-stdv, stdv)
+            self.bias_log_sigma.data.fill_(self.prior_log_sigma.detach().item())
+
+    def freeze(self):
+        self.weight_eps = torch.randn_like(self.weight_log_sigma)
+        if self.bias:
+            self.bias_eps = torch.randn_like(self.bias_log_sigma)
+
+    def unfreeze(self):
+        self.weight_eps = None
+        if self.bias:
+            self.bias_eps = None
+
+    def forward(self, input):
+        r"""
+        Sample weights (and bias) from the variational posterior and apply the linear map.
+        """
+        if self.weight_eps is None:
+            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
+        else:
+            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
+
+        if self.bias:
+            if self.bias_eps is None:
+                bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
+            else:
+                bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
+        else:
+            bias = None
+
+        return F.linear(input, weight, bias)
+
+    def extra_repr(self):
+        r"""
+        Overridden from nn.Module.
+        """
+        return 'prior_mu={}, prior_sigma={}, in_features={}, out_features={}, bias={}'.format(self.prior_mu, torch.exp(self.prior_log_sigma).item(), self.in_features, self.out_features, self.bias)
+
+def _kl_loss(mu_0, log_sigma_0, mu_1, log_sigma_1):
+    """
+    Compute the KL divergence between two normal distributions,
+    KL(N(mu_0, sigma_0) || N(mu_1, sigma_1)), summed over all elements.
+
+    Arguments:
+        mu_0 (Tensor): mean of the variational (posterior) distribution.
+        log_sigma_0 (Tensor): log of the standard deviation of the variational distribution.
+        mu_1 (Tensor or Float): mean of the prior distribution.
+        log_sigma_1 (Tensor or Float): log of the standard deviation of the prior distribution.
+    """
+    if isinstance(log_sigma_1, float):
+        sigma_1 = np.exp(log_sigma_1)
+    else:
+        sigma_1 = torch.exp(log_sigma_1)
+    kl = (log_sigma_1 - log_sigma_0
+          + (torch.exp(log_sigma_0)**2 + (mu_0 - mu_1)**2)/(2*sigma_1**2) - 0.5)
+    return kl.sum()
+
+class BKLLoss(nn.Module):
+    """
+    KL-divergence loss of a Bayesian neural network model, accumulated over all
+    BayesLinear and BayesLinearEmpiricalPrior layers.
+
+    Args:
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'mean'``: the summed KL divergence is divided by the number of weight elements.
+            ``'sum'``: the KL divergence contributions are summed.
+        last_layer_only (Bool): if True, return only the last layer's KL divergence.
+    """
+    __constants__ = ['reduction']
+
+    def __init__(self, reduction='mean', last_layer_only=False):
+        super(BKLLoss, self).__init__()
+        self.last_layer_only = last_layer_only
+        self.reduction = reduction
+
+    def forward(self, model):
+        """
+        Args:
+            model (nn.Module): a model to be calculated for KL-divergence.
+        """
+        #return bayesian_kl_loss(model, reduction=self.reduction, last_layer_only=self.last_layer_only)
+        device = torch.device("cuda" if next(model.parameters()).is_cuda else "cpu")
+        kl = torch.Tensor([0]).to(device)
+        kl_sum = torch.Tensor([0]).to(device)
+        n = torch.Tensor([0]).to(device)
+
+        for m in model.modules():
+            if isinstance(m, BayesLinearEmpiricalPrior):
+                kl = _kl_loss(m.weight_mu, m.weight_log_sigma, m.prior_mu, m.prior_log_sigma_w)
+                kl_sum += kl
+                n += len(m.weight_mu.view(-1))
+
+                if m.bias:
+                    kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma_b)
+                    kl_sum += kl
+                    n += len(m.bias_mu.view(-1))
+            elif isinstance(m, BayesLinear):
+                kl = _kl_loss(m.weight_mu, m.weight_log_sigma, m.prior_mu, m.prior_log_sigma)
+                kl_sum += kl
+                n += len(m.weight_mu.view(-1))
+
+                if m.bias:
+                    kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma)
+                    kl_sum += kl
+                    n += len(m.bias_mu.view(-1))
+
+        if self.last_layer_only or n == 0:
+            return kl
+
+        if self.reduction == 'mean':
+            return kl_sum/n
+        elif self.reduction == 'sum':
+            return kl_sum
+        else:
+            raise ValueError(f"{self.reduction} is not valid")
+
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self, name, fmt=':f'):
@@ -49,6 +295,7 @@ class ProgressMeter(object):
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
 
+
 class BNN(nn.Module):
     """
         A model Bayesian Neural network.
@@ -61,13 +308,13 @@ class BNN(nn.Module):
         between the prediction and the true value. The standard deviation of the Gaussian is left as a
         parameter to be fit: sigma.
     """
-    def __init__(self, input_dimension: int=1, output_dimension: int=1):
+    def __init__(self, input_dimension: int=1, output_dimension: int=1, rvm: bool=False):
         super(BNN, self).__init__()
         hidden_dimension = 50
         # controls the aleatoric uncertainty
         self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
-        self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        self.log_ilambda2 = -np.log(0.1**2)  # initial log-precision of the weights; the prior scale itself is now learned inside the Bayesian layers
 
         # inverse Gamma hyper prior alpha and beta
         #
@@ -82,8 +329,8 @@ class BNN(nn.Module):
         # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
         # An alternative would be to set the (alpha, beta) both to very low values, whichmakes the hyper prior become closer to the non-informative Jeffrey's prior.
         # Using this alternative (ie: (0.1, 0.1) for the weights' hyper prior) leads to very large lambda and numerical issues with the fit.
-        self.alpha_lambda = 0.001
-        self.beta_lambda = 0.001
+        self.alpha_lambda = 0.0001
+        self.beta_lambda = 0.0001
 
         # Hyperprior choice on the likelihood noise level:
         # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different
@@ -92,20 +339,45 @@ class BNN(nn.Module):
         # Making both alpha and beta small makes the gamma distribution closer to the Jeffey's prior, which makes it non-informative
         # This seems to lead to a larger training time, though.
         # Since, after standardization, we know to expect the variance to be of order (1), we can select also alpha and beta leading to high variance in this range
-        self.alpha_sigma = 0.001
-        self.beta_sigma = 0.001
-
-        self.model = nn.Sequential(
-                                   bnn.BayesLinear(prior_mu=0.0,
-                                                   prior_sigma=torch.exp(-0.5*self.log_ilambda2),
+        self.alpha_sigma = 0.0001
+        self.beta_sigma = 0.0001
+
+        if rvm:
+            self.model = nn.Sequential(
+                                       BayesLinearEmpiricalPrior(prior_mu=0.0,
+                                                                     prior_sigma=np.exp(-0.5*self.log_ilambda2),
+                                                                     in_features=input_dimension,
+                                                                     out_features=hidden_dimension),
+                                       nn.ReLU(),
+                                       BayesLinearEmpiricalPrior(prior_mu=0.0,
+                                                                     prior_sigma=np.exp(-0.5*self.log_ilambda2),
+                                                                     in_features=hidden_dimension,
+                                                                     out_features=output_dimension)
+                                        )
+        else:
+            self.model = nn.Sequential(
+                                       BayesLinear(prior_mu=0.0,
+                                                   prior_sigma=np.exp(-0.5*self.log_ilambda2),
                                                    in_features=input_dimension,
                                                    out_features=hidden_dimension),
-                                   nn.ReLU(),
-                                   bnn.BayesLinear(prior_mu=0.0,
-                                                   prior_sigma=torch.exp(-0.5*self.log_ilambda2),
+                                       nn.ReLU(),
+                                       BayesLinear(prior_mu=0.0,
+                                                   prior_sigma=np.exp(-0.5*self.log_ilambda2),
                                                    in_features=hidden_dimension,
                                                    out_features=output_dimension)
-                                    )
+                                        )
+        self.rvm = rvm
+
+    def prune(self):
+        """Prune weights whose learned prior precision has diverged (RVM-style sparsification)."""
+        with torch.no_grad():
+            for layer in self.model.modules():
+                if isinstance(layer, BayesLinearEmpiricalPrior):
+                    log_isigma2 = -2.0*layer.prior_log_sigma_w
+                    isigma2 = torch.exp(log_isigma2)
+                    keep = isigma2 < 1e4
+                    layer.weight_mu[~keep] *= 0.0          # zero the mean of pruned weights
+                    layer.weight_log_sigma[~keep] = -12.0  # and collapse their variance
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -138,10 +410,21 @@ class BNN(nn.Module):
         # with a standardized input, this hyperprior forces sigma to be
         # on avg. 1 and it is broad enough to allow for different sigma
         isigma2 = torch.exp(self.log_isigma2)
-        neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
-        ilambda2 = torch.exp(self.log_ilambda2)
-        neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
-        return neg_log_hyperprior_noise.sum() + neg_log_hyperprior_weights.sum()
+        neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma).sum()
+        if self.rvm:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma_w,
+                            -2.0*self.model[2].prior_log_sigma_w,
+                            -2.0*self.model[0].prior_log_sigma_b,
+                            -2.0*self.model[2].prior_log_sigma_b
+                            ]
+        else:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma,
+                            -2.0*self.model[2].prior_log_sigma,
+                            ]
+        ilambda2 = [torch.exp(k) for k in log_ilambda2]
+        neg_log_hyperprior_weights = sum(self.neg_log_gamma(log_k, k, self.alpha_lambda, self.beta_lambda).sum()
+                                         for log_k, k in zip(log_ilambda2, ilambda2))
+        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
 
     def aleatoric_uncertainty(self) -> torch.Tensor:
         """
@@ -154,7 +437,18 @@ class BNN(nn.Module):
         """
             Get the weights precision.
         """
-        return torch.exp(self.log_ilambda2[0])
+        if self.rvm:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma_w,
+                            -2.0*self.model[2].prior_log_sigma_w,
+                            -2.0*self.model[0].prior_log_sigma_b,
+                            -2.0*self.model[2].prior_log_sigma_b
+                            ]
+        else:
+            log_ilambda2 = [-2.0*self.model[0].prior_log_sigma,
+                            -2.0*self.model[2].prior_log_sigma,
+                            ]
+        ilambda2 = [torch.exp(k) for k in log_ilambda2]
+        return sum(k.mean() for k in ilambda2)/len(ilambda2)
 
 class BNNModel(RegressorMixin, BaseEstimator):
     """
@@ -162,14 +456,15 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
     Args:
     """
-    def __init__(self, state_dict=None):
+    def __init__(self, state_dict=None, rvm: bool=False):
         if state_dict is not None:
             Nx = state_dict["model.0.weight_mu"].shape[1]
             Ny = state_dict["model.2.weight_mu"].shape[0]
-            self.model = BNN(Nx, Ny)
+            self.model = BNN(Nx, Ny, rvm=rvm)
             self.model.load_state_dict(state_dict)
         else:
-            self.model = BNN()
+            self.model = BNN(rvm=rvm)
+        self.rvm = rvm
         self.model.eval()
 
     def state_dict(self) -> Dict[str, Any]:
@@ -197,7 +492,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
                            torch.from_numpy(weights))
 
         # create model
-        self.model = BNN(X.shape[1], y.shape[1])
+        self.model = BNN(X.shape[1], y.shape[1], rvm=self.rvm)
 
         # prepare data loader
         B = 50
@@ -218,11 +513,11 @@ class BNNModel(RegressorMixin, BaseEstimator):
         weight_prior /= float(B)
 
         # KL loss
-        kl_loss = bnn.BKLLoss(reduction='sum', last_layer_only=False)
+        kl_loss = BKLLoss(reduction='sum', last_layer_only=False)
 
         # train
         self.model.train()
-        epochs = 1000
+        epochs = 250
         for epoch in range(epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
@@ -251,6 +546,9 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
 
             progress.display(len(loader))
+        if self.rvm:
+            self.model.prune()
+
         self.model.eval()
 
         return self
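
For context, here is a minimal usage sketch (not part of the patch) of the in-repo replacements for the torchbnn API introduced above. The layer and loss signatures follow the hunk exactly; the layer sizes, the 0.1 prior scale, the seed and the toy objective are illustrative only:

    import torch
    import torch.nn as nn
    from pes_to_spec.bnn import BayesLinearEmpiricalPrior, BKLLoss

    torch.manual_seed(0)

    # Two-layer Bayesian MLP with learnable per-weight prior scales.
    net = nn.Sequential(
        BayesLinearEmpiricalPrior(prior_mu=0.0, prior_sigma=0.1,
                                  in_features=8, out_features=16),
        nn.ReLU(),
        BayesLinearEmpiricalPrior(prior_mu=0.0, prior_sigma=0.1,
                                  in_features=16, out_features=2),
    )
    kl_loss = BKLLoss(reduction='sum', last_layer_only=False)

    x = torch.randn(32, 8)
    y_sample = net(x)    # each forward pass draws fresh weights (weight_eps is None)
    kl = kl_loss(net)    # KL(q(w) || p(w)) against the learnable empirical prior
    loss = y_sample.pow(2).mean() + 1e-3 * kl   # toy objective for illustration
    loss.backward()      # gradients flow to both variational and prior parameters
    net[0].freeze()      # pin one epsilon sample so repeated forward passes are deterministic

Each BayesLinearEmpiricalPrior carries its own prior_log_sigma_w / prior_log_sigma_b parameters; this is what allows BNN.prune() to zero out weights whose learned prior precision diverges.
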
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 5751510b4a5e98187ffd1081849c4dba0223d2fa..44c14289b3ac237cbdf9ac29275caa7cfe89bb6e 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -574,7 +574,7 @@ class Model(TransformerMixin, BaseEstimator):
                  Set to None to perform no selection.
       validation_size: Fraction (number between 0 and 1) of the data to take for
                        validation and systematic uncertainty estimate.
-      model_type: Which model to use. "bnn" for a BNN, "ridge" for Ridge and "ard" for ARD.
+      model_type: Which model to use. "bnn" for a BNN, "bnn_rvm" for a BNN with an RVM-like sparsifying (empirical) prior, "ridge" for Ridge and "ard" for ARD.
 
     """
     def __init__(self,
@@ -586,11 +586,11 @@ class Model(TransformerMixin, BaseEstimator):
                  tof_start: Optional[int]=None,
                  delta_tof: Optional[int]=300,
                  validation_size: float=0.05,
-                 model_type: Literal["bnn", "ridge", "ard"]="ard",
+                 model_type: Literal["bnn", "bnn_rvm", "ridge", "ard"]="ard",
                 ):
         self.high_res_sigma = high_res_sigma
         # models
-        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=(model_type not in ["bnn"]))
+        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=(model_type not in ["bnn", "bnn_rvm"]))
         x_model_steps = list()
         x_model_steps += [
                           ('pca', PCA(n_pca_lr, whiten=True)),
@@ -606,6 +606,8 @@ class Model(TransformerMixin, BaseEstimator):
                     for ch in channels+['full']}
         if model_type == "bnn":
             self.fit_model = BNNModel()
+        elif model_type == "bnn_rvm":
+            self.fit_model = BNNModel(rvm=True)
         elif model_type == "ridge":
             self.fit_model = MultiOutputRidgeWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
         elif model_type == "ard":
@@ -627,7 +629,7 @@ class Model(TransformerMixin, BaseEstimator):
 
     def n_pars(self) -> float:
         """Get number of parameters."""
-        if self.model_type == "bnn":
+        if self.model_type in ("bnn", "bnn_rvm"):
             return sum(p.numel() for p in self.fit_model.model.parameters())
         return sum(len(estimator.coef_) + 1 for estimator in self.fit_model.estimators_)
 
@@ -1015,7 +1017,7 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
-                     self.fit_model.state_dict() if self.model_type == "bnn" else self.fit_model,
+                     self.fit_model.state_dict() if self.model_type in ("bnn", "bnn_rvm") else self.fit_model,
                      self.channel_pca,
                      #self.channel_fit_model
                      DataHolder(dict(
@@ -1073,7 +1075,9 @@ class Model(TransformerMixin, BaseEstimator):
         obj.x_model = x_model
         obj.y_model = y_model
         if obj.model_type == "bnn":
-            obj.fit_model = BNNModel(state_dict=fit_model)
+            obj.fit_model = BNNModel(state_dict=fit_model, rvm=False)
+        elif obj.model_type == "bnn_rvm":
+            obj.fit_model = BNNModel(state_dict=fit_model, rvm=True)
         else:
             obj.fit_model = fit_model
         obj.channel_pca = channel_pca
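
For reference, a small sketch (not part of the patch) of the persistence path the hunks above extend to the new "bnn_rvm" type. It assumes BNNModel.state_dict() returns the underlying network's state dict (which is what the loading branch above expects); no fitting is shown and all other arguments keep their defaults:

    from pes_to_spec.bnn import BNNModel

    model = BNNModel(rvm=True)                    # BNN built from the empirical-prior layers
    sd = model.state_dict()                       # the object model.py hands to joblib.dump
    restored = BNNModel(state_dict=sd, rvm=True)  # infers Nx, Ny from the weights and reloads them

Passing rvm explicitly on load matters because the state dict of an empirical-prior network contains the extra prior_log_sigma_w / prior_log_sigma_b entries that a plain BayesLinear stack does not have.
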
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 6c530092c3464c3dd3272d1b789dca55e64b24f7..a199e8b663e5a0d9c813c9c967afbdd5729da4c6 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -131,7 +131,7 @@ def main():
     parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=0, help='XGM intensity threshold in uJ.')
-    parser.add_argument('-T', '--model-type', type=str, metavar='TYPE', default="ard", choices=["bnn", "ridge", "ard"], help='Which model type to use.')
+    parser.add_argument('-T', '--model-type', type=str, metavar='TYPE', default="ard", choices=["bnn", "bnn_rvm", "ridge", "ard"], help='Which model type to use.')
     parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
 
     args = parser.parse_args()
diff --git a/pes_to_spec/test/prepare_plots.py b/pes_to_spec/test/prepare_plots.py
index 7dc0888e72f74e18164d0c4a0ac37de81b858693..12ca07cd4d98a448665ee6a1d876d0adc2d294ed 100755
--- a/pes_to_spec/test/prepare_plots.py
+++ b/pes_to_spec/test/prepare_plots.py
@@ -12,8 +12,8 @@ from matplotlib.gridspec import GridSpec
 import seaborn as sns
 
 SMALL_SIZE = 12
-MEDIUM_SIZE = 18
-BIGGER_SIZE = 24
+MEDIUM_SIZE = 22
+BIGGER_SIZE = 26
 
 plt.rc('font', size=BIGGER_SIZE)         # controls default text sizes
 plt.rc('axes', titlesize=BIGGER_SIZE)    # fontsize of the axes title
diff --git a/pyproject.toml b/pyproject.toml
index 86ea8d577381b8e30672742abe3d63e0d5ff09b6..9ea807b4f8cce6209cbb409d43aec2cc6e905656 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
           "scipy>=1.6",
           "scikit-learn>=1.2.0",
           "torch",
-          "torchbnn",
           ]
 
 [project.optional-dependencies]