From 4fd3934ff053bfa15d444e1e94832a28b0dc38c8 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 17 Apr 2023 15:17:01 +0200
Subject: [PATCH] Added bnn model.

---
 pes_to_spec/bnn.py                   | 183 +++++++++++++++++
 pes_to_spec/model.py                 | 288 ++-------------------------
 pes_to_spec/test/offline_analysis.py |  88 +++++---
 pyproject.toml                       |   3 +-
 4 files changed, 254 insertions(+), 308 deletions(-)
 create mode 100644 pes_to_spec/bnn.py

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
new file mode 100644
index 0000000..9ebf2a3
--- /dev/null
+++ b/pes_to_spec/bnn.py
@@ -0,0 +1,183 @@
+from sklearn.base import BaseEstimator, RegressorMixin
+from typing import Any, Dict, Optional, Union, Tuple
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torchbnn as bnn
+from torch.utils.data import Dataset, DataLoader
+
+class BNN(nn.Module):
+    """
+        A model Bayesian Neural network.
+        Each weight is represented by a Gaussian with a mean and a standard deviation.
+        Each evaluation of forward leads to a different choice of the weights, so running
+        forward several times we can check the effect of the weights variation on the same input.
+        The nll function implements the negative log likelihood to be used as the first part of the loss
+        function (the second shall be the Kullback-Leibler divergence).
+        The negative log-likelihood is simply the negative log likelihood of a Gaussian
+        between the prediction and the true value. The standard deviation of the Gaussian is left as a
+        parameter to be fit: sigma.
+    """
+    def __init__(self, input_dimension: int=1, output_dimension: int=1):
+        super(BNN, self).__init__()
+        hidden_dimension = 100
+        self.model = nn.Sequential(
+                                   bnn.BayesLinear(prior_mu=0,
+                                                   prior_sigma=0.1,
+                                                   in_features=input_dimension,
+                                                   out_features=hidden_dimension),
+                                   nn.ReLU(),
+                                   bnn.BayesLinear(prior_mu=0,
+                                                   prior_sigma=0.1,
+                                                   in_features=hidden_dimension,
+                                                   out_features=output_dimension)
+                                    )
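+        # log of sigma^2: the variance of the Gaussian observation noise,
+        # fitted together with the network weights (aleatoric uncertainty)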
+        self.log_sigma2 = nn.Parameter(torch.ones(1), requires_grad=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate f(x), the result of applying the network to the input x.
+        """
+        return self.model(x)
+
+    def nll(self, prediction: torch.Tensor, target: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate the negative log-likelihood (divided by the batch size, since we take the mean).
+        """
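+        # Gaussian negative log-likelihood per sample and output:
+        #   0.5*(w*(prediction - target))^2/sigma^2 + 0.5*log(2*pi*sigma^2),
+        # with a single learned sigma shared by all outputs (homoscedastic noise)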
+        error = w*(prediction - target)
+        squared_error = error**2
+        #return 0.5*squared_error.mean()
+        sigma2 = torch.exp(self.log_sigma2)[0]
+        norm_error = 0.5*squared_error/sigma2
+        norm_term = 0.5*(np.log(2*np.pi) + self.log_sigma2[0])
+        return norm_error.mean() + norm_term
+
+    def aleatoric_uncertainty(self) -> torch.Tensor:
+        """
+            Get the aleatoric component of the uncertainty.
+        """
+        #return 0
+        return torch.exp(0.5*self.log_sigma2[0])
+
+class BNNDataset(Dataset):
+    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray):
+        self.x = x
+        self.y = y
+        self.w = w
+        assert len(x) == len(y) and len(x) == len(w)
+    def __len__(self) -> int:
+        """How many samples do I have?"""
+        return len(self.x)
+    def __getitem__(self, idx):
+        return {"x": self.x[idx, :], "y": self.y[idx, :], "w": self.w[idx, :]}
+
+class BNNModel(RegressorMixin, BaseEstimator):
+    """
+    Regression model with uncertainties.
+
+    Args:
+    """
+    def __init__(self, state_dict=None):
+        if state_dict is not None:
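+            # recover the input and output dimensions from the shapes of the stored weights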
+            Nx = state_dict["model.0.weight_mu"].shape[1]
+            Ny = state_dict["model.2.weight_mu"].shape[0]
+            self.model = BNN(Nx, Ny)
+            self.model.load_state_dict(state_dict)
+        else:
+            self.model = BNN()
+        self.model.eval()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.model.state_dict()
+
+    def fit(self, X: np.ndarray, y: np.ndarray, weights: Optional[np.ndarray]=None, **fit_params) -> RegressorMixin:
+        """
+        Perform the fit of the Bayesian neural network.
+
+        Args:
+          X: The input.
+          y: The target.
+          weights: The weights.
+          fit_params: Additional fit parameters; accepted for API compatibility but currently unused.
+
+        Returns: The object itself.
+        """
+        if weights is None:
+            weights = np.ones(len(X), dtype=np.float32)
+        if len(weights.shape) == 1:
+            weights = weights[:, np.newaxis]
+        # cast to float32 to match the dtype of the torch model parameters
+        ds = BNNDataset(X.astype(np.float32, copy=False),
+                        y.astype(np.float32, copy=False),
+                        weights.astype(np.float32, copy=False))
+
+        # create model
+        self.model = BNN(X.shape[1], y.shape[1])
+
+        # prepare data loader
+        B = 20
+        loader = DataLoader(ds, batch_size=B)
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        number_of_batches = len(ds)/float(B)
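+        # weight the KL term so that, summed over all mini-batches of one epoch,
+        # the prior is counted once against the full-dataset likelihood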
+        weight_kl = 1.0/float(number_of_batches)
+
+        # KL loss
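+        # (KL divergence between the weight posteriors and their priors,
+        #  acting as the regularization term of the variational loss)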
+        kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)
+
+        # train
+        self.model.train()
+        epochs = 1000
+        for epoch in range(epochs):
+            losses = list()
+            nlls = list()
+            priors = list()
+            for batch in loader:
+                x_b = batch["x"]
+                y_b = batch["y"]
+                w_b = batch["w"]
+                y_b_pred = self.model(x_b)
+
+                nll = self.model.nll(y_b_pred, y_b, w_b)
+                prior = weight_kl * kl_loss(self.model)
+                loss = nll + prior
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                losses.append(loss.detach().cpu().item())
+                nlls.append(nll.detach().cpu().item())
+                priors.append(prior.detach().cpu().item())
+
+            # monitor
+            ale = self.model.aleatoric_uncertainty().detach().numpy()
+            losses = np.mean(np.array(losses))
+            nlls = np.mean(np.array(nlls))
+            priors = np.mean(np.array(priors))
+            print(f"Epoch {epoch}/{epochs}  total: {losses:.5f}, -LL: {nlls:.5f}, prior: {priors:.5f}, aleatoric unc.: {ale:.5f}")
+        self.model.eval()
+
+        return self
+
+    def predict(self, X: np.ndarray, return_std: bool=False) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
+        """
+        Predict y from X.
+
+        Args:
+          X: Input dataset.
+          return_std: Whether to also return the prediction uncertainty.
+
+        Returns: Predicted y and, if return_std is True, also its uncertainty.
+        """
+        K = 10
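+        # K Monte Carlo forward passes: each pass samples a new set of weights from the
+        # learned posterior, so the spread of the predictions estimates the epistemic uncertainty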
+        # cast once to float32 to match the dtype of the torch model parameters
+        x_t = torch.from_numpy(X.astype(np.float32, copy=False))
+        y_pred = list()
+        for _ in range(K):
+            y_k = self.model(x_t).detach().numpy()
+            y_pred.append(y_k)
+        y_pred = np.stack(y_pred, axis=1)
+        y_mu = np.mean(y_pred, axis=1)
+        y_epi = np.std(y_pred, axis=1)
+        y_ale = self.model.aleatoric_uncertainty().detach().numpy()
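+        # combine the epistemic (weight sampling) and aleatoric (learned noise) terms in quadrature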
+        y_unc = (y_epi**2 + y_ale**2)**0.5
+        if not return_std:
+            return y_mu
+        return y_mu, y_unc
+
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index e8fe9ef..37a7ae2 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -4,262 +4,23 @@ import joblib
 
 import numpy as np
 from scipy.signal import fftconvolve
-#from scipy.signal import find_peaks_cwt
-from scipy.optimize import fmin_l_bfgs_b
-from sklearn.decomposition import PCA
+from sklearn.decomposition import IncrementalPCA, PCA
 from sklearn.base import TransformerMixin, BaseEstimator
-from sklearn.base import RegressorMixin
 from sklearn.base import OutlierMixin
 from sklearn.pipeline import Pipeline
 from sklearn.kernel_approximation import Nystroem
-#from sklearn.linear_model import ARDRegression
 from sklearn.linear_model import BayesianRidge
-#from sklearn.covariance import EllipticEnvelope
+from sklearn.metrics import accuracy_score
 from scipy.stats import gaussian_kde
-from functools import reduce
 from itertools import product
-from time import time_ns
 
 from sklearn.base import clone, MetaEstimatorMixin
 from joblib import Parallel, delayed
 from copy import deepcopy
 
-from typing import Any, Dict, List, Optional, Union, Tuple
-
-# previous method with a single matrix, but using gradient descent
-# ARDRegression to be used instead
-# advantages:
-# * parallelization
-# * using evidence maximization as an iterative method
-# * automatic relevance determination to reduce overtraining
-#
-from autograd import numpy as anp
-from autograd import grad
-class FitModel(RegressorMixin, BaseEstimator):
-    """
-    Linear regression model with uncertainties.
-
-    Args:
-    """
-    def __init__(self):
-        # model parameter sizes
-        self.Nx: int = 0
-        self.Ny: int = 0
-
-        # fit result
-        self.pars: Optional[Dict[str, np.ndarray]] = None
-        self.structure: Optional[Dict[str, Tuple[int, int]]] = None
-
-        # fit monitoring
-        self.loss_train: List[float] = list()
-        self.loss_test: List[float] = list()
-
-        self.nll_train: Optional[np.ndarray] = None
-        self.nll_train_expected: Optional[np.ndarray] = None
-
-    def fit(self, X: np.ndarray, y: np.ndarray, **fit_params) -> RegressorMixin:
-        """
-        Perform the fit and evaluate uncertainties with the test set.
-
-        Args:
-          X: The input.
-          y: The target.
-          fit_params: If it contains X_test and y_test, they are used to validate the model.
-
-        Returns: The object itself.
-        """
-        if 'X_test' in fit_params and 'y_test' in fit_params:
-            X_test = fit_params['X_test']
-            y_test = fit_params['y_test']
-        else:
-            X_test = X
-            y_test = y
-
-        # model parameter sizes
-        self.Nx: int = int(X.shape[1])
-        self.Ny: int = int(y.shape[1])
-
-        # initial parameter values
-        self.structure = dict(A_inf=(self.Nx, self.Ny),
-                              b_inf=(1, self.Ny),
-                              Ap_inf=(self.Nx, self.Ny),
-                              log_inv_alpha=(1, self.Ny),
-                              log_inv_alpha_prime=(self.Nx, 1),
-                              #log_inv_alpha_prime2=(self.Nx, 1),
-                              #log_inv_tau1=(1, 1),
-                              #log_inv_tau2=(1, 1),
-                             )
-        # initialize close to the solution
-        # pinv(X) @ y solves the problem Ax = y
-        # assume b is zero, since both X and y are mean subtracted after the PCA
-        # assume a small noise uncertainty in alpha and in tau
-        init_method = dict(A_inf=lambda shape: np.linalg.pinv(X) @ (y - np.mean(y, axis=0, keepdims=True)),
-                           b_inf=lambda shape: np.mean(y, axis=0, keepdims=True),
-                           Ap_inf=lambda shape: np.zeros(shape),
-                           log_inv_alpha=lambda shape: 1.0*np.ones(shape),
-                           log_inv_alpha_prime=lambda shape: 1.0*np.ones(shape),
-                           #log_inv_tau1=lambda shape: 1.0*np.ones(shape),
-                           #log_inv_tau2=lambda shape: 1.0*np.ones(shape),
-                          )
-        x0: np.ndarray = FitModel.get_pars_init(self.structure, init_method)
-
-        # reset loss monitoring
-        self.loss_train: List[float] = list()
-        self.loss_test: List[float] = list()
-
-        def loss(x: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float:
-            """
-            Calculate the loss function value for a given parameter set `x`.
-
-            Args:
-              x: The parameters as a flat array.
-              X: The independent-variable dataset.
-              Y: The dependent-variable dataset.
-
-            Returns: The loss.
-            """
-            p = FitModel.get_pars(x, self.structure)
-            return anp.mean(self.nll(X, Y, pars=p), axis=0)
-
-        def loss_history(x: np.ndarray) -> float:
-            """
-            Calculate the loss function and keep a history of it in training and testing.
-
-            Args:
-              x: Parameters flattened out.
-
-            Returns: The loss value.
-            """
-            l_train = loss(x, X, y)
-            l_test = loss(x, X_test, y_test)
-
-            self.loss_train += [l_train]
-            self.loss_test += [l_test]
-            return l_train
-
-        def loss_train(x: np.ndarray) -> float:
-            """
-            Calculate the loss function for the training dataset.
-
-            Args:
-              x: Parameters flattened out.
-
-            Returns: The loss value.
-            """
-            l_train = loss(x, X, y)
-            return l_train
-
-        grad_loss = grad(loss_train)
-        sc_op = fmin_l_bfgs_b(loss_history,
-                              x0,
-                              grad_loss,
-                              disp=True,
-                              factr=1e7,
-                              #factr=10,
-                              maxiter=10000,
-                              iprint=0)
-
-        # Inference
-        self.pars = FitModel.get_pars(sc_op[0], self.structure)
-        self.nll_train = sc_op[1]
-        self.nll_train_expected = np.mean(self.nll(X, pars=self.pars), axis=0, keepdims=True)
-        return self
-
-    def predict(self, X: np.ndarray, return_std: bool=False) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
-        """
-        Predict y from X.
+from pes_to_spec.bnn import BNNModel
 
-        Args:
-          X: Input dataset.
-
-        Returns: Predicted Y and, if return_std is True, also its uncertainty.
-        """
-        # result
-        A = self.pars["A_inf"]
-        b = self.pars["b_inf"]
-        X2 = anp.square(X)
-        Ap = self.pars["Ap_inf"]
-        y = anp.matmul(X2, Ap) + anp.matmul(X, A) + b
-        if not return_std:
-            return y
-        # input-dependent uncertainty
-        log_inv_alpha = anp.matmul(X, self.pars["log_inv_alpha_prime"]) + self.pars["log_inv_alpha"]
-        sqrt_inv_alpha = anp.exp(0.5*log_inv_alpha)
-        return y, sqrt_inv_alpha
-
-    @staticmethod
-    def get_pars(x: np.ndarray, structure: Dict[str, Tuple[int, int]]) -> Dict[str, np.ndarray]:
-        pars = dict()
-        s = 0
-        for key, value in structure.items():
-            n = value[0]*value[1]
-            e = s + n
-            pars[key] = x[s:e].reshape(value)
-            s += n
-        return pars
-
-    @staticmethod
-    def get_pars_size(structure: Dict[str, Tuple[int, int]]) -> int:
-        size = [value[0]*value[1] for _, value in structure.items()]
-        return reduce(lambda x, y: x*y, size)
-
-    @staticmethod
-    def get_pars_init(structure: Dict[str, Tuple[int, int]], init_method: Optional[Dict[str, Any]]=dict()) -> int:
-        init = {key: np.zeros((value[0], value[1])).reshape(-1) if not key in init_method
-                else init_method[key](value).reshape(-1)
-                for key, value in structure.items()}
-        return np.concatenate(list(init.values()), axis=0)
-
-    def nll(self, X: np.ndarray, Y: Optional[np.ndarray]=None, pars: Optional[Dict[str, np.ndarray]]=None) -> np.ndarray:
-        """
-        To estimate p(M|X) = int_Y p(M|X,Y)p(Y) dY, we assume p(Y) is Normal(mean(X), std(X)).
-        p(M|X,Y=Normal(mu(X), var(X))) = 1/2b exp(-(mu(X)-mu(X))/b) = 1/2b.
-        -log p = log(2) + log(b)
-        We return -log [p(M|X,Y=mu(X))/p(M|X_train,Y_train)] = -log p(M|X,Y=mu(X)) + log p(M|X_train, Y_traun)
-        Negative log likelihood L allows one
-
-        Args:
-          X: The input data.
-          Y: The true result. If None, set it to the expectation.
-
-        Returns: negative log likelihood.
-        """
-        if pars is None:
-            pars = self.pars
-
-        A = pars["A_inf"]
-        b = pars["b_inf"]
-        X2 = anp.square(X)
-        Ap = pars["Ap_inf"]
-        Y_pred = anp.matmul(X2, Ap) + anp.matmul(X, A) + b
-
-        log_inv_alpha = anp.matmul(X, pars["log_inv_alpha_prime"]) + pars["log_inv_alpha"]
-        sqrt_alpha = anp.exp(-0.5*log_inv_alpha)
-        #log_inv_tau1 = pars["log_inv_tau1"]
-        #tau1 = anp.exp(-log_inv_tau1)
-        #log_inv_tau2 = pars["log_inv_tau2"]
-        #tau2 = anp.exp(-log_inv_tau2)
-        if Y is None:
-            Y = self.predict(X)
-        # likelihood = p(D|x, y, A, b) = 1/sqrt(2 pi sigma^2) exp(-0.5*(Ax + b - y)/sigma^2)
-        # sigma is modelled as exp(A_e x + b_e) to make the aleatoric uncertainty data-dependent
-        L = anp.sum((anp.fabs(Y_pred - Y))*sqrt_alpha + 0.5*log_inv_alpha, axis=1)
-        # prior p(A_inf) p(b_inf) = p(A_inf) cte. (assume all b equally likely)
-        # A has normal prior with var = 1/tau
-        # 1/sqrt(2pi)1/sqrt(sigma**2) exp(-0.5 A**2/sigma**2)
-        # log p = -0.5 A **2/sigma**2 - 0.5 log sigma**2 - 0.5 log 2pi
-        # - log p = 0.5 A**2/sigma**2 + 0.5 log sigma**2 + 0.5 log 2pi
-        # 0.5 log sigma**2 = 0.5 log 1/tau
-        #L_prior = anp.sum(0.5*anp.square(pars["A_inf"].ravel())*tau.ravel() + 0.5*log_inv_tau.ravel())
-        #L_prior = (anp.sum(0.5*anp.square(A.ravel())*tau1 + 0.5*log_inv_tau1)
-        #           + anp.sum(0.5*anp.square(Ap.ravel())*tau2 + 0.5*log_inv_tau2)
-        #          )
-        #alpha = 2.0
-        #beta = 2.0
-        #L_hyper = anp.sum((alpha - 1)*log_inv_tau1 + beta*tau1
-        #           + (alpha - 1)*log_inv_tau2 + beta*tau2
-        #          )
-        return L
+from typing import Any, Dict, List, Optional, Union, Tuple
 
 def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     """Returns list of train IDs common to sets a, b and c."""
@@ -478,7 +239,10 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         self.mean = dict()
         self.std = dict()
 
-    def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> np.ndarray:
+    def transform(self, X: Dict[str, np.ndarray],
+                  keep_dictionary_structure: bool=False,
+                  pulse_spacing: Optional[Dict[str, List[int]]]=None,
+                  pulse_energy: Optional[np.ndarray]=None) -> Union[np.ndarray, Dict[str, np.ndarray]]:
         """
         Get a dictionary with the channel names for the inut low resolution data and output
         only the relevant input data in an array.
@@ -761,7 +525,7 @@ class Model(TransformerMixin, BaseEstimator):
     def __init__(self,
                  channels:List[str]=[f"channel_{j}_{k}"
                                      for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
-                 n_pca_lr: int=400,
+                 n_pca_lr: int=600,
                  n_pca_hr: int=20,
                  high_res_sigma: float=0.2,
                  tof_start: Optional[int]=None,
@@ -791,28 +555,19 @@ class Model(TransformerMixin, BaseEstimator):
                                 ('pca', PCA(n_pca_hr, whiten=True)),
                                 ('unc', UncertaintyHolder()),
                                 ])
-        #self.ood = {ch: IsolationForest(n_jobs=-1)
-        #            for ch in channels+['full']}
         self.ood = {ch: UncorrelatedDeviation(sigma=5)
                     for ch in channels+['full']}
-        #self.ood = {ch: EllipticEnvelope(contamination=0.003)
-        #            for ch in channels+['full']}
-        #self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
-        self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
-        #self.fit_model = FitModel()
+        #self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
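+        # Bayesian neural network regressor providing epistemic and aleatoric uncertainties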
+        self.fit_model = BNNModel()
 
         self.kde_xgm = None
         self.mu_xgm = np.nan
         self.sigma_xgm = np.nan
-        self.kde_intensity = None
-        self.mu_intensity = np.nan
-        self.sigma_intensity = np.nan
 
         # we are reducing per channel from 2*delta_tof to delta_tof to check correlations
         n_pca_lr_per_channel = delta_tof
-        self.channel_pca = {ch: PCA(n_pca_lr_per_channel, whiten=True)
+        self.channel_pca = {ch: IncrementalPCA(n_pca_lr_per_channel, whiten=True)
                             for ch in channels}
-        #self.channel_fit_model = {ch: FitModel() for ch in channels}
 
         # size of the test subset
         self.validation_size = validation_size
@@ -907,7 +662,6 @@ class Model(TransformerMixin, BaseEstimator):
         print("Fitting PCA on high-resolution data.")
         y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
 
-        #self.fit_model.set_params(fex__gamma=1.0/float(x_t.shape[0]))
         print("Fitting outlier detection")
         self.ood['full'].fit(x_t)
         inliers = self.ood['full'].predict(x_t) > 0.0
@@ -1020,12 +774,6 @@ class Model(TransformerMixin, BaseEstimator):
         self.resolution = np.sqrt(energy_var)
         #print("Resolution:", self.resolution)
 
-        # get intensity effect
-        intensity = np.sum(z, axis=1)
-        self.kde_intensity = gaussian_kde(intensity.reshape(-1), bw_method="scott")
-        self.mu_intensity = np.mean(intensity.reshape(-1), axis=0)
-        self.sigma_intensity = np.std(intensity.reshape(-1), axis=0)
-
         # for consistency check per channel
         selection_model = self.x_select
         low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True)
@@ -1084,10 +832,6 @@ class Model(TransformerMixin, BaseEstimator):
         """Get KDE for the XGM intensity."""
         return self.kde_xgm
 
-    def intensity_profile(self) -> gaussian_kde:
-        """Get KDE for the predicted intensity."""
-        return self.kde_intensity
-
     def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> Dict[str, np.ndarray]:
         """
         Predict a high-resolution spectrum from a low resolution given one.
@@ -1176,7 +920,7 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
-                     self.fit_model,
+                     self.fit_model.state_dict(),
                      self.channel_pca,
                      #self.channel_fit_model
                      DataHolder(dict(mu_intensity=self.mu_intensity,
@@ -1194,7 +938,6 @@ class Model(TransformerMixin, BaseEstimator):
                                ),
                      self.ood,
                      self.kde_xgm,
-                     self.kde_intensity,
                      ], filename, compress='zlib')
 
     @staticmethod
@@ -1210,22 +953,19 @@ class Model(TransformerMixin, BaseEstimator):
         (x_select,
          x_model, y_model, fit_model,
          channel_pca,
-         #channel_fit_model
          extra,
          ood,
          kde_xgm,
-         kde_intensity,
         ) = joblib.load(filename)
         obj = Model()
         obj.x_select = x_select
         obj.x_model = x_model
         obj.y_model = y_model
-        obj.fit_model = fit_model
+        obj.fit_model = BNNModel(state_dict=fit_model)
         obj.channel_pca = channel_pca
         #obj.channel_fit_model = channel_fit_model
         obj.ood = ood
         obj.kde_xgm = kde_xgm
-        obj.kde_intensity = kde_intensity
 
         extra = extra.get_data()
         obj.mu_intensity = extra["mu_intensity"]
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 78c4771..edd3efd 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -134,7 +134,8 @@ def main():
     """
     parser = argparse.ArgumentParser(prog="offline_analysis", description="Test pes2spec doing an offline analysis of the data.")
     parser.add_argument('-p', '--proposal', type=int, metavar='INT', help='Proposal number', default=2828)
-    parser.add_argument('-r', '--run', type=int, metavar='INT', help='Run number', default=206)
+    parser.add_argument('-r', '--run', type=str, metavar='INT,INT,...', help='Run numbers, comma-separated.', default="206")
+    parser.add_argument('-t', '--test-run', type=int, metavar='INT', help='Run to test', default=None)
     parser.add_argument('-m', '--model', type=str, metavar='FILENAME', default="", help='Model to load. If given, do not train a model and just do inference with this one.')
     parser.add_argument('-d', '--directory', type=str, metavar='DIRECTORY', default=".", help='Where to save the results.')
     parser.add_argument('-S', '--spec', type=str, metavar='NAME', default="SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output", help='SPEC name')
@@ -147,8 +148,20 @@ def main():
     args = parser.parse_args()
 
     print("Opening run ...")
+    runs = args.run.split(',')
+    runs = [int(r) for r in runs]
     # get run
-    run = open_run(proposal=args.proposal, run=args.run)
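+    # when several runs are given, merge them into a single virtual run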
+    run_list = [open_run(proposal=args.proposal, run=r) for r in runs]
+    run = run_list[0]
+    if len(run_list) > 1:
+        run = run.union(*run_list[1:])
+
+    run_test = run
+    other_run_test = False
+    if "test_run" in args and args.test_run is not None:
+        other_run_test = True
+        run_test = open_run(proposal=args.proposal, run=args.test_run)
+
     #run = RunDirectory("/gpfs/exfel/data/scratch/tmichela/data/r0206")
 
     # ----------------Used in the first tests-------------------------
@@ -186,15 +199,25 @@ def main():
 
     # reserve part of it for the test stage
     train_tids = tids[:-10]
-    test_tids = tids[-10:]
+    if other_run_test:
+        pes_tidt = run_test[pes_name, "digitizers.trainId"].ndarray()
+        xgm_tidt = run_test[xgm_name, "data.trainId"].ndarray()
+        spec_tidt = run_test[spec_name, "data.trainId"].ndarray()
+        test_tids = matching_ids(spec_tidt, pes_tidt, xgm_tidt)
+    else:
+        test_tids = tids
 
     # read the PES data for each channel
     channels = [f"channel_{i}_{l}"
                 for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
     pes_raw = {ch: run[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray()
                for ch in channels}
-    pes_raw_t = {ch: run[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray()
-               for ch in channels}
+    pes_raw_t = {ch: run_test[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray()
+                   for ch in channels}
+
+    # select test SPEC data
+    spec_raw_pe_t = run_test[spec_name, "data.photonEnergy"].select_trains(by_id[test_tids - spec_offset]).ndarray()
+    spec_raw_int_t = run_test[spec_name, "data.intensityDistribution"].select_trains(by_id[test_tids - spec_offset]).ndarray()
 
     print("Data in memory.")
 
@@ -203,9 +226,9 @@ def main():
     #xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
     #retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
     #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
-    
+
     xgm_flux =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()[:, 0][:, np.newaxis]
-    xgm_flux_t =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[test_tids]).ndarray()[:, 0][:, np.newaxis]
+    xgm_flux_t =  run_test['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[test_tids]).ndarray()[:, 0][:, np.newaxis]
 
     t = list()
     t_names = list()
@@ -213,6 +236,9 @@ def main():
     model = Model(poly=args.poly)
 
     train_idx = np.isin(tids, train_tids) & (xgm_flux[:,0] > args.xgm_cut)
+    # only the training subset is needed from here on; select it in place to avoid a copy, which blows up the memory usage
+    for k in pes_raw.keys():
+        pes_raw[k] = pes_raw[k][train_idx]
 
     model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.png"))
     if len(args.model) == 0:
@@ -220,10 +246,11 @@ def main():
         start = time_ns()
         w = model.uniformize(xgm_flux[train_idx])
         print("w", np.amin(w), np.amax(w), np.median(w))
-        model.fit({k: v[train_idx, :]
-                   for k, v in pes_raw.items()},
-                   spec_raw_int[train_idx, :],
-                   spec_raw_pe[train_idx, :],
+        model.fit(pes_raw,
+                   spec_raw_int[train_idx],
+                   spec_raw_pe[train_idx],
                    w,
                    pulse_energy=xgm_flux[train_idx],
                    )
@@ -282,7 +309,7 @@ def main():
     # test
     print("Predict")
     start = time_ns()
-    spec_pred = model.predict(pes_raw, pulse_energy=xgm_flux)
+    spec_pred = model.predict(pes_raw_t, pulse_energy=xgm_flux_t)
     spec_pred["deconvolved"] = model.deconvolve(spec_pred["expected"])
     t += [time_ns() - start]
     t_names += ["Predict"]
@@ -296,17 +323,16 @@ def main():
     showSpec = False
     if len(args.model) == 0:
         showSpec = True
-        spec_smooth = model.preprocess_high_res(spec_raw_int)
+        spec_smooth = model.preprocess_high_res(spec_raw_int_t)
 
         # chi2 w.r.t XGM intensity
-        erange = spec_raw_pe[0,-1] - spec_raw_pe[0,0]
-        de = (spec_raw_pe[0,1] - spec_raw_pe[0,0])
+        de = (spec_raw_pe_t[0,1] - spec_raw_pe_t[0,0])
         chi2 = np.sum((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=(-1, -2))
         ndof = float(spec_smooth.shape[1]) - 1.0
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        ax.scatter(chi2/ndof, xgm_flux[:,0], c='r', s=20)
+        ax.scatter(chi2/ndof, xgm_flux_t[:,0], c='r', s=20)
         ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}",
                xlabel=r"$\chi^2/$ndof",
                ylabel="XGM intensity [uJ]",
@@ -316,7 +342,7 @@ def main():
         # Manually set the position and relative size of the inset axes within ax1
         ip = InsetPosition(ax, [0.65,0.6,0.35,0.4])
         ax2.set_axes_locator(ip)
-        ax2.scatter(chi2/ndof, xgm_flux[:,0], c='r', s=30)
+        ax2.scatter(chi2/ndof, xgm_flux_t[:,0], c='r', s=30)
         #ax2.scatter(chi2/ndof, np.sum(spec_pred["expected"], axis=1)*de, c='b', s=30)
         #ax2.scatter(chi2/ndof, np.sum(spec_raw_int, axis=1)*de, c='g', s=30)
         ax2.set(title="",
@@ -353,16 +379,16 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.kdeplot(x=xgm_flux[:,0], ax=ax)
+        sns.kdeplot(x=xgm_flux_t[:,0], ax=ax)
         ax.set(title=f"",
                xlabel="XGM intensity [uJ]",
                ylabel="Density [a.u.]",
                )
-        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux[:,0]):.2f}",
+        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux_t[:,0]):.2f}",
                 verticalalignment='top', horizontalalignment='right',
                 transform=ax.transAxes,
                 color='black', fontsize=15)
-        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux[:,0]):.2f}",
+        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux_t[:,0]):.2f}",
                 verticalalignment='top', horizontalalignment='right',
                 transform=ax.transAxes,
                 color='black', fontsize=15)
@@ -374,7 +400,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        ax.scatter(rmse, xgm_flux[:,0], c='r', s=30)
+        ax.scatter(rmse, xgm_flux_t[:,0], c='r', s=30)
         ax = plt.gca()
         ax.set(title=f"",
                xlabel=r"Root-mean-squared error",
@@ -407,7 +433,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_raw_int, axis=1)*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_raw_int_t, axis=1)*de, y=xgm_flux_t[:,0], color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="SPEC (raw) integral",
                ylabel="XGM Intensity [uJ]",
@@ -419,7 +445,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_raw_int, axis=-1)*de, y=np.sum(spec_pred["expected"], axis=(-1, -2))*de, color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_raw_int_t, axis=-1)*de, y=np.sum(spec_pred["expected"], axis=(-1, -2))*de, color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="SPEC (raw) integral",
                ylabel="Predicted integral",
@@ -430,7 +456,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_pred["expected"], axis=(-1, -2))*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_pred["expected"], axis=(-1, -2))*de, y=xgm_flux_t[:,0], color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="Predicted integral",
                ylabel="XGM intensity [uJ]",
@@ -443,23 +469,19 @@ def main():
     last -= 10
     pes_to_show = 'channel_1_D'
     # plot
-    for tid in test_tids:
-        idx = np.where(tid==tids)[0][0]
+    for idx in range(len(spec_raw_int_t) - 10, len(spec_raw_int_t)):
+        tid = test_tids[idx]
         plot_result(os.path.join(args.directory, f"test_{tid}.png"),
                    {k: item[idx, 0, ...] if k != "pca"
                        else item[0, ...]
                        for k, item in spec_pred.items()},
                     spec_smooth[idx, :] if showSpec else None,
-                    spec_raw_pe[idx, :] if showSpec else None,
-                    spec_raw_int[idx, :] if showSpec else None,
-                    #pes=-pes_raw[pes_to_show][idx, first:last],
-                    #pes_to_show=pes_to_show.replace('_', ' '),
-                    #pes_bin=np.arange(first, last),
-                    #wiener=model.wiener_filter
+                    spec_raw_pe_t[idx, :] if showSpec else None,
+                    spec_raw_int_t[idx, :] if showSpec else None,
                     )
         for ch in channels:
             plot_pes(os.path.join(args.directory, f"test_pes_{tid}_{ch}.png"),
-                     pes_raw[ch][idx, first:last], first, last)
+                     pes_raw_t[ch][idx, first:last], first, last)
 
 if __name__ == '__main__':
     main()
diff --git a/pyproject.toml b/pyproject.toml
index a41df3d..86ea8d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,8 @@ dependencies = [
           "numpy>=1.21",
           "scipy>=1.6",
           "scikit-learn>=1.2.0",
-          "autograd",
+          "torch",
+          "torchbnn",
           ]
 
 [project.optional-dependencies]
-- 
GitLab