diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..eefa81288ee7d23ed21c5acf89efb2a52ec514a9
--- /dev/null
+++ b/pes_to_spec/bnn.py
@@ -0,0 +1,281 @@
+from sklearn.base import BaseEstimator, RegressorMixin
+from typing import Any, Dict, Optional, Union, Tuple
+
+import numpy as np
+from scipy.special import gammaln
+
+import torch
+import torch.nn as nn
+import torchbnn as bnn
+from torch.utils.data import TensorDataset, DataLoader
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('  '.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+class BNN(nn.Module):
+    """
+        A model Bayesian Neural network.
+        Each weight is represented by a Gaussian with a mean and a standard deviation.
+        Each evaluation of forward leads to a different choice of the weights, so running
+        forward several times we can check the effect of the weights variation on the same input.
+        The neg_log_likelihood function implements the negative log likelihood to be used as the first part of the loss
+        function (the second shall be the Kullback-Leibler divergence).
+        The negative log-likelihood is simply the negative log likelihood of a Gaussian
+        between the prediction and the true value. The standard deviation of the Gaussian is left as a
+        parameter to be fit: sigma.
+    """
+    def __init__(self, input_dimension: int=1, output_dimension: int=1):
+        super(BNN, self).__init__()
+        hidden_dimension = 100
+        # controls the aleatoric uncertainty
+        self.log_isigma2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+        # controls the weight hyperprior
+        self.log_ilambda2 = nn.Parameter(-torch.ones(1)*np.log(0.1**2), requires_grad=True)
+
+        # inverse Gamma hyper prior alpha and beta
+        #
+        # Hyperprior choice on the weights:
+        # We want to allow the hyperprior on the weights' variance to have large variance,
+        # so that the weights prior can be anything, if possible, but at the same time prevent it from going to infinity
+        # (which would allow the weights to be anything, but remove regularization and de-stabilize the fit).
+        # Therefore, the weights should be allowed to have high std. dev. on their priors, just not so much so that the fit is unstable.
+        # At the same time, the prior std. dev. should not be too small (that would regularize too much).
+        # The values below have been taken from BoTorch (alpha, beta) = (3.0, 6.0) and seem to work well if the inputs have been standardized.
+        # They lead to a high mean for the weights std. dev. (18) and a large variance (sqrt(var) = 10.4), so that the weights prior is large
+        # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
+        # An alternative would be to set both (alpha, beta) to very low values, which makes the hyperprior closer to the non-informative Jeffreys prior.
+        # Using this alternative (i.e. (0.1, 0.1) for the weights' hyperprior) leads to very large lambda and numerical issues in the fit.
+        self.alpha_lambda = 3.0
+        self.beta_lambda = 6.0
+
+        # Hyperprior choice on the likelihood noise level:
+        # The likelihood noise level is controlled by sigma in the likelihood, and it should be allowed to be very broad; but unlike
+        # the weights prior, it must also be allowed to become small, since with a lot of data there may be very little noise in it.
+        # We therefore want high variance in the hyperprior for sigma, but we do not need to prevent it from becoming small.
+        # Making both alpha and beta small brings the gamma distribution closer to the Jeffreys prior, which makes it non-informative.
+        # This seems to lead to a longer training time, though.
+        # Since, after standardization, we expect the variance to be of order 1, we can instead choose alpha and beta so that the hyperprior has high variance in that range.
+        self.alpha_sigma = 2.0
+        self.beta_sigma = 0.15
+
+        self.model = nn.Sequential(
+                                   bnn.BayesLinear(prior_mu=0.0,
+                                                   prior_sigma=torch.exp(-0.5*self.log_ilambda2),
+                                                   in_features=input_dimension,
+                                                   out_features=hidden_dimension),
+                                   nn.ReLU(),
+                                   bnn.BayesLinear(prior_mu=0.0,
+                                                   prior_sigma=torch.exp(-0.5*self.log_ilambda2),
+                                                   in_features=hidden_dimension,
+                                                   out_features=output_dimension)
+                                    )
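+        # NB: prior_sigma above is evaluated from the *initial* value of log_ilambda2; if torchbnn
+        # stores it as a fixed constant, later updates of log_ilambda2 enter the loss only through
+        # the hyperprior term, not through the KL prior itself.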
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate the result f(x) applied on the input x.
+        """
+        return self.model(x)
+
+    def neg_log_gamma(self, log_x: torch.Tensor, x: torch.Tensor, alpha, beta) -> torch.Tensor:
+        """
+        Return the negative log of the gamma pdf.
+        """
+        return -alpha*np.log(beta) - (alpha - 1)*log_x + beta*x + gammaln(alpha)
+
+    def neg_log_likelihood(self, prediction: torch.Tensor, target: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
+        """
+        n_output = target.shape[1]
+        error = w*(prediction - target)
+        squared_error = error**2
+        sigma2 = torch.exp(-self.log_isigma2)[0]
+        norm_error = 0.5*squared_error/sigma2
+        norm_term = 0.5*(np.log(2*np.pi) - self.log_isigma2[0])*n_output
+        return norm_error.sum(dim=1).mean(dim=0) + norm_term
+
+    def neg_log_hyperprior(self) -> torch.Tensor:
+        """
+        Calculate the negative log of the hyperpriors.
+        """
+        # hyperprior for sigma, to avoid a too-large or too-small sigma:
+        # with standardized inputs, this hyperprior keeps sigma close to 1 on average
+        # and it is broad enough to allow for different values of sigma
+        isigma2 = torch.exp(self.log_isigma2)[0]
+        neg_log_hyperprior_noise = self.neg_log_gamma(self.log_isigma2, isigma2, self.alpha_sigma, self.beta_sigma)
+        ilambda2 = torch.exp(self.log_ilambda2)[0]
+        neg_log_hyperprior_weights = self.neg_log_gamma(self.log_ilambda2, ilambda2, self.alpha_lambda, self.beta_lambda)
+        return neg_log_hyperprior_noise + neg_log_hyperprior_weights
+
+    def aleatoric_uncertainty(self) -> torch.Tensor:
+        """
+            Get the aleatoric component of the uncertainty.
+        """
+        #return 0
+        return torch.exp(-0.5*self.log_isigma2[0])
+
+    def w_precision(self) -> torch.Tensor:
+        """
+            Get the weights precision.
+        """
+        return torch.exp(self.log_ilambda2[0])
+
+class BNNModel(RegressorMixin, BaseEstimator):
+    """
+    Scikit-learn-style regression model with uncertainties, backed by the BNN above.
+
+    Args:
+      state_dict: If given, the network is rebuilt and initialized from this PyTorch state dict.
+    """
+    def __init__(self, state_dict=None):
+        if state_dict is not None:
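+            # recover the input/output dimensions from the saved weight means of the
+            # first and last BayesLinear layers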
+            Nx = state_dict["model.0.weight_mu"].shape[1]
+            Ny = state_dict["model.2.weight_mu"].shape[0]
+            self.model = BNN(Nx, Ny)
+            self.model.load_state_dict(state_dict)
+        else:
+            self.model = BNN()
+        self.model.eval()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.model.state_dict()
+
+    def fit(self, X: np.ndarray, y: np.ndarray, weights: Optional[np.ndarray]=None, **fit_params) -> RegressorMixin:
+        """
+        Perform the fit and evaluate uncertainties with the test set.
+
+        Args:
+          X: The input.
+          y: The target.
+          weights: The weights.
+          fit_params: If it contains X_test and y_test, they are used to validate the model.
+
+        Returns: The object itself.
+        """
+        if weights is None:
+            weights = np.ones(len(X), dtype=np.float32)
+        if len(weights.shape) == 1:
+            weights = weights[:, np.newaxis]
+
+        ds = TensorDataset(torch.from_numpy(X),
+                           torch.from_numpy(y),
+                           torch.from_numpy(weights))
+
+        # create model
+        self.model = BNN(X.shape[1], y.shape[1])
+
+        # prepare data loader
+        B = 5
+        loader = DataLoader(ds,
+                            batch_size=B,
+                            num_workers=5,
+                            shuffle=True,
+                            #pin_memory=True,
+                            drop_last=True,
+                            )
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        number_of_batches = len(ds)/float(B)
+        weight_prior = 1.0/float(number_of_batches)
+        # the NLL below is averaged over the batch, so also divide the prior and hyperprior
+        # terms by the batch size, so that the quantity optimized is F / (number of samples)
+        # (see Blundell et al., https://arxiv.org/pdf/1505.05424.pdf)
+        weight_prior /= float(B)
+
+        # KL loss
+        kl_loss = bnn.BKLLoss(reduction='sum', last_layer_only=False)
+
+        # train
+        self.model.train()
+        epochs = 200
+        for epoch in range(epochs):
+            meter = {k: AverageMeter(k, ':6.3f')
+                    for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
+            progress = ProgressMeter(
+                            len(loader),
+                            meter.values(),
+                            prefix="Epoch: [{}]".format(epoch))
+            for i, batch in enumerate(loader):
+                x_b, y_b, w_b = batch
+                y_b_pred = self.model(x_b)
+
+                nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
+                nlprior = weight_prior * kl_loss(self.model)
+                nlhyper = weight_prior * self.model.neg_log_hyperprior()
+                loss = nll + nlprior + nlhyper
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                meter['loss'].update(loss.detach().cpu().item(), B)
+                meter['-log(lkl)'].update(nll.detach().cpu().item(), B)
+                meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
+                meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
+                meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
+                meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
+
+            progress.display(len(loader))
+        self.model.eval()
+
+        return self
+
+    def predict(self, X: np.ndarray, return_std: bool=False) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
+        """
+        Predict y from X.
+
+        Args:
+          X: Input dataset.
+          return_std: Whether to also return the prediction uncertainty.
+
+        Returns: Predicted Y and, if return_std is True, also its uncertainty.
+        """
+        K = 10  # number of Monte Carlo samples of the weights
+        y_pred = list()
+        for _ in range(K):
+            y_k = self.model(torch.from_numpy(X)).detach().numpy()
+            y_pred.append(y_k)
+        y_pred = np.stack(y_pred, axis=1)
+        y_mu = np.mean(y_pred, axis=1)
+        y_epi = np.std(y_pred, axis=1)
+        y_ale = self.model.aleatoric_uncertainty().detach().numpy()
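+        # total predictive uncertainty: epistemic spread over the K weight samples and the
+        # learned aleatoric sigma, combined in quadrature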
+        y_unc = (y_epi**2 + y_ale**2)**0.5
+        if not return_std:
+            return y_mu
+        return y_mu, y_unc
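+
+# Minimal usage sketch (illustrative only; the shapes and data below are made up):
+#
+#   X = np.random.randn(256, 10).astype(np.float32)
+#   y = np.random.randn(256, 3).astype(np.float32)
+#   model = BNNModel().fit(X, y)
+#   mu, unc = model.predict(X, return_std=True)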
+
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 2e1d360af294006dac8bce4253efad73ce74460c..214ce31125dad632724c6ef8e71ae1d48247b78e 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -4,262 +4,23 @@ import joblib
 
 import numpy as np
 from scipy.signal import fftconvolve
-#from scipy.signal import find_peaks_cwt
-from scipy.optimize import fmin_l_bfgs_b
-from sklearn.decomposition import PCA
+from sklearn.decomposition import IncrementalPCA, PCA
 from sklearn.base import TransformerMixin, BaseEstimator
-from sklearn.base import RegressorMixin
 from sklearn.base import OutlierMixin
 from sklearn.pipeline import Pipeline
 from sklearn.kernel_approximation import Nystroem
-#from sklearn.linear_model import ARDRegression
 from sklearn.linear_model import BayesianRidge
-#from sklearn.covariance import EllipticEnvelope
+from sklearn.metrics import accuracy_score
 from scipy.stats import gaussian_kde
-from functools import reduce
 from itertools import product
-from time import time_ns
 
 from sklearn.base import clone, MetaEstimatorMixin
 from joblib import Parallel, delayed
 from copy import deepcopy
 
-from typing import Any, Dict, List, Optional, Union, Tuple
-
-# previous method with a single matrix, but using gradient descent
-# ARDRegression to be used instead
-# advantages:
-# * parallelization
-# * using evidence maximization as an iterative method
-# * automatic relevance determination to reduce overtraining
-#
-from autograd import numpy as anp
-from autograd import grad
-class FitModel(RegressorMixin, BaseEstimator):
-    """
-    Linear regression model with uncertainties.
-
-    Args:
-    """
-    def __init__(self):
-        # model parameter sizes
-        self.Nx: int = 0
-        self.Ny: int = 0
-
-        # fit result
-        self.pars: Optional[Dict[str, np.ndarray]] = None
-        self.structure: Optional[Dict[str, Tuple[int, int]]] = None
+from pes_to_spec.bnn import BNNModel
 
-        # fit monitoring
-        self.loss_train: List[float] = list()
-        self.loss_test: List[float] = list()
-
-        self.nll_train: Optional[np.ndarray] = None
-        self.nll_train_expected: Optional[np.ndarray] = None
-
-    def fit(self, X: np.ndarray, y: np.ndarray, **fit_params) -> RegressorMixin:
-        """
-        Perform the fit and evaluate uncertainties with the test set.
-
-        Args:
-          X: The input.
-          y: The target.
-          fit_params: If it contains X_test and y_test, they are used to validate the model.
-
-        Returns: The object itself.
-        """
-        if 'X_test' in fit_params and 'y_test' in fit_params:
-            X_test = fit_params['X_test']
-            y_test = fit_params['y_test']
-        else:
-            X_test = X
-            y_test = y
-
-        # model parameter sizes
-        self.Nx: int = int(X.shape[1])
-        self.Ny: int = int(y.shape[1])
-
-        # initial parameter values
-        self.structure = dict(A_inf=(self.Nx, self.Ny),
-                              b_inf=(1, self.Ny),
-                              Ap_inf=(self.Nx, self.Ny),
-                              log_inv_alpha=(1, self.Ny),
-                              log_inv_alpha_prime=(self.Nx, 1),
-                              #log_inv_alpha_prime2=(self.Nx, 1),
-                              #log_inv_tau1=(1, 1),
-                              #log_inv_tau2=(1, 1),
-                             )
-        # initialize close to the solution
-        # pinv(X) @ y solves the problem Ax = y
-        # assume b is zero, since both X and y are mean subtracted after the PCA
-        # assume a small noise uncertainty in alpha and in tau
-        init_method = dict(A_inf=lambda shape: np.linalg.pinv(X) @ (y - np.mean(y, axis=0, keepdims=True)),
-                           b_inf=lambda shape: np.mean(y, axis=0, keepdims=True),
-                           Ap_inf=lambda shape: np.zeros(shape),
-                           log_inv_alpha=lambda shape: 1.0*np.ones(shape),
-                           log_inv_alpha_prime=lambda shape: 1.0*np.ones(shape),
-                           #log_inv_tau1=lambda shape: 1.0*np.ones(shape),
-                           #log_inv_tau2=lambda shape: 1.0*np.ones(shape),
-                          )
-        x0: np.ndarray = FitModel.get_pars_init(self.structure, init_method)
-
-        # reset loss monitoring
-        self.loss_train: List[float] = list()
-        self.loss_test: List[float] = list()
-
-        def loss(x: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float:
-            """
-            Calculate the loss function value for a given parameter set `x`.
-
-            Args:
-              x: The parameters as a flat array.
-              X: The independent-variable dataset.
-              Y: The dependent-variable dataset.
-
-            Returns: The loss.
-            """
-            p = FitModel.get_pars(x, self.structure)
-            return anp.mean(self.nll(X, Y, pars=p), axis=0)
-
-        def loss_history(x: np.ndarray) -> float:
-            """
-            Calculate the loss function and keep a history of it in training and testing.
-
-            Args:
-              x: Parameters flattened out.
-
-            Returns: The loss value.
-            """
-            l_train = loss(x, X, y)
-            l_test = loss(x, X_test, y_test)
-
-            self.loss_train += [l_train]
-            self.loss_test += [l_test]
-            return l_train
-
-        def loss_train(x: np.ndarray) -> float:
-            """
-            Calculate the loss function for the training dataset.
-
-            Args:
-              x: Parameters flattened out.
-
-            Returns: The loss value.
-            """
-            l_train = loss(x, X, y)
-            return l_train
-
-        grad_loss = grad(loss_train)
-        sc_op = fmin_l_bfgs_b(loss_history,
-                              x0,
-                              grad_loss,
-                              disp=True,
-                              factr=1e7,
-                              #factr=10,
-                              maxiter=10000,
-                              iprint=0)
-
-        # Inference
-        self.pars = FitModel.get_pars(sc_op[0], self.structure)
-        self.nll_train = sc_op[1]
-        self.nll_train_expected = np.mean(self.nll(X, pars=self.pars), axis=0, keepdims=True)
-        return self
-
-    def predict(self, X: np.ndarray, return_std: bool=False) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
-        """
-        Predict y from X.
-
-        Args:
-          X: Input dataset.
-
-        Returns: Predicted Y and, if return_std is True, also its uncertainty.
-        """
-        # result
-        A = self.pars["A_inf"]
-        b = self.pars["b_inf"]
-        X2 = anp.square(X)
-        Ap = self.pars["Ap_inf"]
-        y = anp.matmul(X2, Ap) + anp.matmul(X, A) + b
-        if not return_std:
-            return y
-        # input-dependent uncertainty
-        log_inv_alpha = anp.matmul(X, self.pars["log_inv_alpha_prime"]) + self.pars["log_inv_alpha"]
-        sqrt_inv_alpha = anp.exp(0.5*log_inv_alpha)
-        return y, sqrt_inv_alpha
-
-    @staticmethod
-    def get_pars(x: np.ndarray, structure: Dict[str, Tuple[int, int]]) -> Dict[str, np.ndarray]:
-        pars = dict()
-        s = 0
-        for key, value in structure.items():
-            n = value[0]*value[1]
-            e = s + n
-            pars[key] = x[s:e].reshape(value)
-            s += n
-        return pars
-
-    @staticmethod
-    def get_pars_size(structure: Dict[str, Tuple[int, int]]) -> int:
-        size = [value[0]*value[1] for _, value in structure.items()]
-        return reduce(lambda x, y: x*y, size)
-
-    @staticmethod
-    def get_pars_init(structure: Dict[str, Tuple[int, int]], init_method: Optional[Dict[str, Any]]=dict()) -> int:
-        init = {key: np.zeros((value[0], value[1])).reshape(-1) if not key in init_method
-                else init_method[key](value).reshape(-1)
-                for key, value in structure.items()}
-        return np.concatenate(list(init.values()), axis=0)
-
-    def nll(self, X: np.ndarray, Y: Optional[np.ndarray]=None, pars: Optional[Dict[str, np.ndarray]]=None) -> np.ndarray:
-        """
-        To estimate p(M|X) = int_Y p(M|X,Y)p(Y) dY, we assume p(Y) is Normal(mean(X), std(X)).
-        p(M|X,Y=Normal(mu(X), var(X))) = 1/2b exp(-(mu(X)-mu(X))/b) = 1/2b.
-        -log p = log(2) + log(b)
-        We return -log [p(M|X,Y=mu(X))/p(M|X_train,Y_train)] = -log p(M|X,Y=mu(X)) + log p(M|X_train, Y_traun)
-        Negative log likelihood L allows one
-
-        Args:
-          X: The input data.
-          Y: The true result. If None, set it to the expectation.
-
-        Returns: negative log likelihood.
-        """
-        if pars is None:
-            pars = self.pars
-
-        A = pars["A_inf"]
-        b = pars["b_inf"]
-        X2 = anp.square(X)
-        Ap = pars["Ap_inf"]
-        Y_pred = anp.matmul(X2, Ap) + anp.matmul(X, A) + b
-
-        log_inv_alpha = anp.matmul(X, pars["log_inv_alpha_prime"]) + pars["log_inv_alpha"]
-        sqrt_alpha = anp.exp(-0.5*log_inv_alpha)
-        #log_inv_tau1 = pars["log_inv_tau1"]
-        #tau1 = anp.exp(-log_inv_tau1)
-        #log_inv_tau2 = pars["log_inv_tau2"]
-        #tau2 = anp.exp(-log_inv_tau2)
-        if Y is None:
-            Y = self.predict(X)
-        # likelihood = p(D|x, y, A, b) = 1/sqrt(2 pi sigma^2) exp(-0.5*(Ax + b - y)/sigma^2)
-        # sigma is modelled as exp(A_e x + b_e) to make the aleatoric uncertainty data-dependent
-        L = anp.sum((anp.fabs(Y_pred - Y))*sqrt_alpha + 0.5*log_inv_alpha, axis=1)
-        # prior p(A_inf) p(b_inf) = p(A_inf) cte. (assume all b equally likely)
-        # A has normal prior with var = 1/tau
-        # 1/sqrt(2pi)1/sqrt(sigma**2) exp(-0.5 A**2/sigma**2)
-        # log p = -0.5 A **2/sigma**2 - 0.5 log sigma**2 - 0.5 log 2pi
-        # - log p = 0.5 A**2/sigma**2 + 0.5 log sigma**2 + 0.5 log 2pi
-        # 0.5 log sigma**2 = 0.5 log 1/tau
-        #L_prior = anp.sum(0.5*anp.square(pars["A_inf"].ravel())*tau.ravel() + 0.5*log_inv_tau.ravel())
-        #L_prior = (anp.sum(0.5*anp.square(A.ravel())*tau1 + 0.5*log_inv_tau1)
-        #           + anp.sum(0.5*anp.square(Ap.ravel())*tau2 + 0.5*log_inv_tau2)
-        #          )
-        #alpha = 2.0
-        #beta = 2.0
-        #L_hyper = anp.sum((alpha - 1)*log_inv_tau1 + beta*tau1
-        #           + (alpha - 1)*log_inv_tau2 + beta*tau2
-        #          )
-        return L
+from typing import Any, Dict, List, Optional, Union, Tuple
 
 def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     """Returns list of train IDs common to sets a, b and c."""
@@ -478,7 +239,10 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         self.mean = dict()
         self.std = dict()
 
-    def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: Optional[Dict[str, List[int]]]=None) -> np.ndarray:
+    def transform(self, X: Dict[str, np.ndarray],
+                  keep_dictionary_structure: bool=False,
+                  pulse_spacing: Optional[Dict[str, List[int]]]=None,
+                  pulse_energy: Optional[np.ndarray]=None) -> Union[np.ndarray, Dict[str, np.ndarray]]:
         """
         Get a dictionary with the channel names for the inut low resolution data and output
         only the relevant input data in an array.
@@ -488,6 +252,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
              where i is a number between 1 and 4 and k is a letter between A and D.
           keep_dictionary_structure: Whether to concatenate all channels, or keep them as a dictionary.
           pulse_spacing: Distances between pulses in multi-pulse data. If there is only one pulse, set it to a list containing only the element zero.
+          pulse_energy: Pulse energy per train. If given, it is appended as an extra input feature
+                        (and multiplied with each channel when the polynomial option is set).
 
         Returns: Concatenated and pre-processed low-resolution data of shape (train_id, pulse_id, features).
         """
@@ -503,8 +268,10 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
                  for channel, item in X.items()}
         if not keep_dictionary_structure:
             selected = list(y.values())
-            if self.poly:
-                selected += [np.sqrt(np.fabs(v)) for v in y.values()]
+            if pulse_energy is not None:
+                selected += [pulse_energy[:, np.newaxis, :]]
+                if self.poly:
+                    selected += [pulse_energy[:, np.newaxis, :]*v for v in y.values()]
             return np.concatenate(selected, axis=-1)
         return y
 
@@ -750,34 +517,24 @@ class Model(TransformerMixin, BaseEstimator):
                  Set to None to perform no selection.
       validation_size: Fraction (number between 0 and 1) of the data to take for
                        validation and systematic uncertainty estimate.
-      n_nonlinear_kernel: Number of nonlinear kernel components added at the preprocessing stage
-                       to obtain nonlinearities as an input and improve the prediction.
-      poly: Whether to use a polynomial expantion of the low-resolution data.
+      bnn: Whether to use a Bayesian neural network for the regression. If False, Bayesian ridge regression is used.
 
     """
     def __init__(self,
                  channels:List[str]=[f"channel_{j}_{k}"
                                      for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
-                 n_pca_lr: int=400,
+                 n_pca_lr: int=600,
                  n_pca_hr: int=20,
                  high_res_sigma: float=0.2,
                  tof_start: Optional[int]=None,
                  delta_tof: Optional[int]=300,
                  validation_size: float=0.05,
-                 n_nonlinear_kernel: int=0,
-                 poly: bool=False,
+                 bnn: bool=True,
                 ):
         self.high_res_sigma = high_res_sigma
         # models
-        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=poly)
+        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=not bnn)
         x_model_steps = list()
-        self.n_nonlinear_kernel = n_nonlinear_kernel
-        if n_nonlinear_kernel > 0:
-            # Kernel PCA using Nystroem
-            x_model_steps += [('fex', Pipeline([('prepca', PCA(n_pca_lr, whiten=True)),
-                                                ('nystroem', Nystroem(n_components=n_nonlinear_kernel, kernel='rbf', gamma=None, n_jobs=8)),
-                                                ])),
-                             ]
         x_model_steps += [
                           ('pca', PCA(n_pca_lr, whiten=True)),
                           ('unc', UncertaintyHolder()),
@@ -788,28 +545,22 @@ class Model(TransformerMixin, BaseEstimator):
                                 ('pca', PCA(n_pca_hr, whiten=True)),
                                 ('unc', UncertaintyHolder()),
                                 ])
-        #self.ood = {ch: IsolationForest(n_jobs=-1)
-        #            for ch in channels+['full']}
         self.ood = {ch: UncorrelatedDeviation(sigma=5)
                     for ch in channels+['full']}
-        #self.ood = {ch: EllipticEnvelope(contamination=0.003)
-        #            for ch in channels+['full']}
-        #self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
-        self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
-        #self.fit_model = FitModel()
+        if bnn:
+            self.fit_model = BNNModel()
+        else:
+            self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
+        self.bnn = bnn
 
         self.kde_xgm = None
         self.mu_xgm = np.nan
         self.sigma_xgm = np.nan
-        self.kde_intensity = None
-        self.mu_intensity = np.nan
-        self.sigma_intensity = np.nan
 
         # we are reducing per channel from 2*delta_tof to delta_tof to check correlations
         n_pca_lr_per_channel = delta_tof
-        self.channel_pca = {ch: PCA(n_pca_lr_per_channel, whiten=True)
+        self.channel_pca = {ch: IncrementalPCA(n_pca_lr_per_channel, whiten=True)
                             for ch in channels}
-        #self.channel_fit_model = {ch: FitModel() for ch in channels}
 
         # size of the test subset
         self.validation_size = validation_size
@@ -872,7 +623,8 @@ class Model(TransformerMixin, BaseEstimator):
 
     def fit(self, low_res_data: Dict[str, np.ndarray],
             high_res_data: np.ndarray, high_res_photon_energy: np.ndarray,
-            weights: Optional[np.ndarray]=None
+            weights: Optional[np.ndarray]=None,
+            pulse_energy: Optional[np.ndarray]=None,
             ) -> np.ndarray:
         """
         Train the model.
@@ -891,7 +643,8 @@ class Model(TransformerMixin, BaseEstimator):
         if weights is None:
             weights = np.ones(high_res_data.shape[0])
         print("Fitting PCA on low-resolution data.")
-        low_res_select = self.x_select.fit_transform(low_res_data)
+        self.x_select.fit(low_res_data)
+        low_res_select = self.x_select.transform(low_res_data, pulse_energy=pulse_energy)
         # keep the number of pulses
         B, P, _ = low_res_select.shape
         low_res_select = low_res_select.reshape((B*P, -1))
@@ -902,7 +655,6 @@ class Model(TransformerMixin, BaseEstimator):
         print("Fitting PCA on high-resolution data.")
         y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
 
-        #self.fit_model.set_params(fex__gamma=1.0/float(x_t.shape[0]))
         print("Fitting outlier detection")
         self.ood['full'].fit(x_t)
         inliers = self.ood['full'].predict(x_t) > 0.0
@@ -1015,12 +767,6 @@ class Model(TransformerMixin, BaseEstimator):
         self.resolution = np.sqrt(energy_var)
         #print("Resolution:", self.resolution)
 
-        # get intensity effect
-        intensity = np.sum(z, axis=1)
-        self.kde_intensity = gaussian_kde(intensity.reshape(-1), bw_method="scott")
-        self.mu_intensity = np.mean(intensity.reshape(-1), axis=0)
-        self.sigma_intensity = np.std(intensity.reshape(-1), axis=0)
-
         # for consistency check per channel
         selection_model = self.x_select
         low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True)
@@ -1057,7 +803,7 @@ class Model(TransformerMixin, BaseEstimator):
         result = {ch: is_inlier(low_res_selected[ch], ch) for ch in channels}
         return result
 
-    def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> np.ndarray:
+    def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> np.ndarray:
         """
         Check if a new low-resolution data source is compatible with the one used in training, by
         using a robust covariance matrix estimate of the data
@@ -1065,10 +811,11 @@ class Model(TransformerMixin, BaseEstimator):
         Args:
           low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
           pulse_spacing: The pulse spacing in multi-pulse data.
+          pulse_energy: Pulse energy, forwarded to the low-resolution data selection.
 
         Returns: An outlier score: if it is greater than 0, this is an outlier.
         """
-        low_res = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
+        low_res = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing, pulse_energy=pulse_energy)
         B, P, _ = low_res.shape
         pca_model = self.x_model
         low_pca = pca_model.transform(low_res.reshape((B*P, -1)))
@@ -1078,11 +825,7 @@ class Model(TransformerMixin, BaseEstimator):
         """Get KDE for the XGM intensity."""
         return self.kde_xgm
 
-    def intensity_profile(self) -> gaussian_kde:
-        """Get KDE for the predicted intensity."""
-        return self.kde_intensity
-
-    def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> Dict[str, np.ndarray]:
+    def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> Dict[str, np.ndarray]:
         """
         Predict a high-resolution spectrum from a low resolution given one.
         The output includes the uncertainty in its second and third entries of the first dimension.
@@ -1100,7 +843,7 @@ class Model(TransformerMixin, BaseEstimator):
         #t += [time_ns()*1e-9]
         #n += ["Initial"]
 
-        low_res_pre = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
+        low_res_pre = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing, pulse_energy=pulse_energy)
         B, P, _ = low_res_pre.shape
         low_res_pre = low_res_pre.reshape((B*P, -1))
         #t += [time_ns()*1e-9]
@@ -1170,11 +913,10 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
-                     self.fit_model,
+                     self.fit_model.state_dict() if self.bnn else self.fit_model,
                      self.channel_pca,
                      #self.channel_fit_model
-                     DataHolder(dict(mu_intensity=self.mu_intensity,
-                                     sigma_intensity=self.sigma_intensity,
+                     DataHolder(dict(
                                      mu_xgm=self.mu_xgm,
                                      sigma_xgm=self.sigma_xgm,
                                      wiener_filter_ft=self.wiener_filter_ft,
@@ -1184,11 +926,11 @@ class Model(TransformerMixin, BaseEstimator):
                                      resolution=self.resolution,
                                      transfer_function=self.transfer_function,
                                      impulse_response=self.impulse_response,
+                                     bnn=self.bnn,
                                     )
                                ),
                      self.ood,
                      self.kde_xgm,
-                     self.kde_intensity,
                      ], filename, compress='zlib')
 
     @staticmethod
@@ -1204,26 +946,14 @@ class Model(TransformerMixin, BaseEstimator):
         (x_select,
          x_model, y_model, fit_model,
          channel_pca,
-         #channel_fit_model
          extra,
          ood,
          kde_xgm,
-         kde_intensity,
         ) = joblib.load(filename)
+
         obj = Model()
-        obj.x_select = x_select
-        obj.x_model = x_model
-        obj.y_model = y_model
-        obj.fit_model = fit_model
-        obj.channel_pca = channel_pca
-        #obj.channel_fit_model = channel_fit_model
-        obj.ood = ood
-        obj.kde_xgm = kde_xgm
-        obj.kde_intensity = kde_intensity
 
         extra = extra.get_data()
-        obj.mu_intensity = extra["mu_intensity"]
-        obj.sigma_intensity = extra["sigma_intensity"]
         obj.mu_xgm = extra["mu_xgm"]
         obj.sigma_xgm = extra["sigma_xgm"]
         obj.wiener_filter_ft = extra["wiener_filter_ft"]
@@ -1233,5 +963,19 @@ class Model(TransformerMixin, BaseEstimator):
         obj.resolution = extra["resolution"]
         obj.transfer_function = extra["transfer_function"]
         obj.impulse_response = extra["impulse_response"]
+        obj.bnn = extra["bnn"]
+
+        obj.x_select = x_select
+        obj.x_model = x_model
+        obj.y_model = y_model
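+        # for the BNN, only the state dict was serialized (see save()), so rebuild the
+        # wrapper from it; the Bayesian-ridge fit_model is stored by joblib as a whole object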
+        if obj.bnn:
+            obj.fit_model = BNNModel(state_dict=fit_model)
+        else:
+            obj.fit_model = fit_model
+        obj.channel_pca = channel_pca
+        #obj.channel_fit_model = channel_fit_model
+        obj.ood = ood
+        obj.kde_xgm = kde_xgm
+
         return obj
 
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 34ba12e303534f51bc1e161e24c3472b2a9e398d..5ede9532259921dd4033e2f67dfc5492073a7ac7 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -134,7 +134,8 @@ def main():
     """
     parser = argparse.ArgumentParser(prog="offline_analysis", description="Test pes2spec doing an offline analysis of the data.")
     parser.add_argument('-p', '--proposal', type=int, metavar='INT', help='Proposal number', default=2828)
-    parser.add_argument('-r', '--run', type=int, metavar='INT', help='Run number', default=206)
+    parser.add_argument('-r', '--run', type=str, metavar='INT,INT,...', help='Run numbers, comma-separated.', default="206")
+    parser.add_argument('-t', '--test-run', type=int, metavar='INT', help='Run to test', default=None)
     parser.add_argument('-m', '--model', type=str, metavar='FILENAME', default="", help='Model to load. If given, do not train a model and just do inference with this one.')
     parser.add_argument('-d', '--directory', type=str, metavar='DIRECTORY', default=".", help='Where to save the results.')
     parser.add_argument('-S', '--spec', type=str, metavar='NAME', default="SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output", help='SPEC name')
@@ -142,13 +143,26 @@ def main():
     parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
-    parser.add_argument('-e', '--poly', action="store_true", default=False, help='Wheteher to expand PES data in higher order polynomials.')
+    parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Whether to use a Bayesian neural network instead of Bayesian ridge regression.')
+    parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight the training data as a function of the pulse energy, so that the fit does not depend on its distribution.')
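+    # For instance (run numbers are placeholders):
+    #   offline_analysis.py -p 2828 -r 206,207 -t 208 --bnn -d ./results
+    # trains on runs 206 and 207 using the BNN and tests on run 208.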
 
     args = parser.parse_args()
 
     print("Opening run ...")
+    runs = args.run.split(',')
+    runs = [int(r) for r in runs]
     # get run
-    run = open_run(proposal=args.proposal, run=args.run)
+    run_list = [open_run(proposal=args.proposal, run=r) for r in runs]
+    run = run_list[0]
+    if len(run_list) > 1:
+        run = run.union(*run_list[1:])
+
+    run_test = run
+    other_run_test = False
+    if "test_run" in args and args.test_run is not None:
+        other_run_test = True
+        run_test = open_run(proposal=args.proposal, run=args.test_run)
+
     #run = RunDirectory("/gpfs/exfel/data/scratch/tmichela/data/r0206")
 
     # ----------------Used in the first tests-------------------------
@@ -186,15 +200,25 @@ def main():
 
     # reserve part of it for the test stage
     train_tids = tids[:-10]
-    test_tids = tids[-10:]
+    if other_run_test:
+        pes_tidt = run_test[pes_name, "digitizers.trainId"].ndarray()
+        xgm_tidt = run_test[xgm_name, "data.trainId"].ndarray()
+        spec_tidt = run_test[spec_name, "data.trainId"].ndarray()
+        test_tids = matching_ids(spec_tidt, pes_tidt, xgm_tidt)
+    else:
+        test_tids = tids
 
     # read the PES data for each channel
     channels = [f"channel_{i}_{l}"
                 for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
     pes_raw = {ch: run[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray()
                for ch in channels}
-    pes_raw_t = {ch: run[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray()
-               for ch in channels}
+    pes_raw_t = {ch: run_test[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray()
+                   for ch in channels}
+
+    # select test SPEC data
+    spec_raw_pe_t = run_test[spec_name, "data.photonEnergy"].select_trains(by_id[test_tids - spec_offset]).ndarray()
+    spec_raw_int_t = run_test[spec_name, "data.intensityDistribution"].select_trains(by_id[test_tids - spec_offset]).ndarray()
 
     print("Data in memory.")
 
@@ -203,28 +227,35 @@ def main():
     #xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
     #retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
     #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
-    
+
     xgm_flux =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()[:, 0][:, np.newaxis]
-    xgm_flux_t =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[test_tids]).ndarray()[:, 0][:, np.newaxis]
+    xgm_flux_t =  run_test['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[test_tids]).ndarray()[:, 0][:, np.newaxis]
 
     t = list()
     t_names = list()
 
-    model = Model(poly=args.poly)
+    model = Model(bnn=args.bnn)
 
     train_idx = np.isin(tids, train_tids) & (xgm_flux[:,0] > args.xgm_cut)
+    # we only need this for training, and we want to avoid copying it, which blows up the memory usage
+    for k in pes_raw.keys():
+        pes_raw[k] = pes_raw[k][train_idx]
 
     model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.png"))
     if len(args.model) == 0:
         print("Fitting")
         start = time_ns()
         w = model.uniformize(xgm_flux[train_idx])
+        if not args.weight:
+            w[...] = 1.0
         print("w", np.amin(w), np.amax(w), np.median(w))
-        model.fit({k: v[train_idx, :]
-                   for k, v in pes_raw.items()},
-                   spec_raw_int[train_idx, :],
-                   spec_raw_pe[train_idx, :],
-                   w
+        model.fit(pes_raw,
+                  #{k: v[train_idx]
+                  #for k, v in pes_raw.items()},
+                   spec_raw_int[train_idx],
+                   spec_raw_pe[train_idx],
+                   w,
+                   pulse_energy=xgm_flux[train_idx],
                    )
         t += [time_ns() - start]
         t_names += ["Fit"]
@@ -271,7 +302,7 @@ def main():
 
     print("Check consistency")
     start = time_ns()
-    Z = model.check_compatibility(pes_raw_t)
+    Z = model.check_compatibility(pes_raw_t, pulse_energy=xgm_flux_t)
     print("Consistency check:", Z)
     Z = model.check_compatibility_per_channel(pes_raw_t)
     print("Consistency per channel:", Z)
@@ -281,7 +312,7 @@ def main():
     # test
     print("Predict")
     start = time_ns()
-    spec_pred = model.predict(pes_raw)
+    spec_pred = model.predict(pes_raw_t, pulse_energy=xgm_flux_t)
     spec_pred["deconvolved"] = model.deconvolve(spec_pred["expected"])
     t += [time_ns() - start]
     t_names += ["Predict"]
@@ -295,17 +326,16 @@ def main():
     showSpec = False
     if len(args.model) == 0:
         showSpec = True
-        spec_smooth = model.preprocess_high_res(spec_raw_int)
+        spec_smooth = model.preprocess_high_res(spec_raw_int_t)
 
         # chi2 w.r.t XGM intensity
-        erange = spec_raw_pe[0,-1] - spec_raw_pe[0,0]
-        de = (spec_raw_pe[0,1] - spec_raw_pe[0,0])
+        de = (spec_raw_pe_t[0,1] - spec_raw_pe_t[0,0])
         chi2 = np.sum((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=(-1, -2))
         ndof = float(spec_smooth.shape[1]) - 1.0
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        ax.scatter(chi2/ndof, xgm_flux[:,0], c='r', s=20)
+        ax.scatter(chi2/ndof, xgm_flux_t[:,0], c='r', s=20)
         ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}",
                xlabel=r"$\chi^2/$ndof",
                ylabel="XGM intensity [uJ]",
@@ -315,7 +345,7 @@ def main():
         # Manually set the position and relative size of the inset axes within ax1
         ip = InsetPosition(ax, [0.65,0.6,0.35,0.4])
         ax2.set_axes_locator(ip)
-        ax2.scatter(chi2/ndof, xgm_flux[:,0], c='r', s=30)
+        ax2.scatter(chi2/ndof, xgm_flux_t[:,0], c='r', s=30)
         #ax2.scatter(chi2/ndof, np.sum(spec_pred["expected"], axis=1)*de, c='b', s=30)
         #ax2.scatter(chi2/ndof, np.sum(spec_raw_int, axis=1)*de, c='g', s=30)
         ax2.set(title="",
@@ -352,16 +382,16 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.kdeplot(x=xgm_flux[:,0], ax=ax)
+        sns.kdeplot(x=xgm_flux_t[:,0], ax=ax)
         ax.set(title=f"",
                xlabel="XGM intensity [uJ]",
                ylabel="Density [a.u.]",
                )
-        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux[:,0]):.2f}",
+        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux_t[:,0]):.2f}",
                 verticalalignment='top', horizontalalignment='right',
                 transform=ax.transAxes,
                 color='black', fontsize=15)
-        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux[:,0]):.2f}",
+        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux_t[:,0]):.2f}",
                 verticalalignment='top', horizontalalignment='right',
                 transform=ax.transAxes,
                 color='black', fontsize=15)
@@ -373,7 +403,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        ax.scatter(rmse, xgm_flux[:,0], c='r', s=30)
+        ax.scatter(rmse, xgm_flux_t[:,0], c='r', s=30)
         ax = plt.gca()
         ax.set(title=f"",
                xlabel=r"Root-mean-squared error",
@@ -406,7 +436,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_raw_int, axis=1)*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_raw_int_t, axis=1)*de, y=xgm_flux_t[:,0], color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="SPEC (raw) integral",
                ylabel="XGM Intensity [uJ]",
@@ -418,7 +448,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_raw_int, axis=-1)*de, y=np.sum(spec_pred["expected"], axis=(-1, -2))*de, color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_raw_int_t, axis=-1)*de, y=np.sum(spec_pred["expected"], axis=(-1, -2))*de, color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="SPEC (raw) integral",
                ylabel="Predicted integral",
@@ -429,7 +459,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_pred["expected"], axis=(-1, -2))*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_pred["expected"], axis=(-1, -2))*de, y=xgm_flux_t[:,0], color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="Predicted integral",
                ylabel="XGM intensity [uJ]",
@@ -442,23 +472,19 @@ def main():
     last -= 10
     pes_to_show = 'channel_1_D'
     # plot
-    for tid in test_tids:
-        idx = np.where(tid==tids)[0][0]
+    for idx in range(len(spec_raw_int_t) - 10, len(spec_raw_int_t)):
+        tid = test_tids[idx]
         plot_result(os.path.join(args.directory, f"test_{tid}.png"),
                    {k: item[idx, 0, ...] if k != "pca"
                        else item[0, ...]
                        for k, item in spec_pred.items()},
                     spec_smooth[idx, :] if showSpec else None,
-                    spec_raw_pe[idx, :] if showSpec else None,
-                    spec_raw_int[idx, :] if showSpec else None,
-                    #pes=-pes_raw[pes_to_show][idx, first:last],
-                    #pes_to_show=pes_to_show.replace('_', ' '),
-                    #pes_bin=np.arange(first, last),
-                    #wiener=model.wiener_filter
+                    spec_raw_pe_t[idx, :] if showSpec else None,
+                    spec_raw_int_t[idx, :] if showSpec else None,
                     )
         for ch in channels:
             plot_pes(os.path.join(args.directory, f"test_pes_{tid}_{ch}.png"),
-                     pes_raw[ch][idx, first:last], first, last)
+                     pes_raw_t[ch][idx, first:last], first, last)
 
 if __name__ == '__main__':
     main()
diff --git a/pyproject.toml b/pyproject.toml
index a41df3d5f76768af6c1d4179e804a4e939c0c64f..86ea8d577381b8e30672742abe3d63e0d5ff09b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,8 @@ dependencies = [
           "numpy>=1.21",
           "scipy>=1.6",
           "scikit-learn>=1.2.0",
-          "autograd",
+          "torch",
+          "torchbnn",
           ]
 
 [project.optional-dependencies]