diff --git a/pes_to_spec/__init__.py b/pes_to_spec/__init__.py index 9656fa30cb220100277700f0a41479a478cb4c82..d155d94123fbd7800c5e78b3d2a846a599a1b5af 100644 --- a/pes_to_spec/__init__.py +++ b/pes_to_spec/__init__.py @@ -2,4 +2,4 @@ Estimate high-resolution photon spectrometer data from low-resolution non-invasive measurements. """ -VERSION = "0.0.3" +VERSION = "0.0.5" diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 891637ef43121b40213b97a3ba48a449cbf76afc..8017ed551329183bc56b2f4aeb0d172ef9da91de 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -3,25 +3,28 @@ from __future__ import annotations import joblib import numpy as np +import scipy.stats from scipy.signal import fftconvolve -from scipy.signal import find_peaks_cwt +#from scipy.signal import find_peaks_cwt from scipy.optimize import fmin_l_bfgs_b -from sklearn.decomposition import PCA, IncrementalPCA +from sklearn.decomposition import PCA from sklearn.pipeline import Pipeline, FeatureUnion +from sklearn.preprocessing import FunctionTransformer from sklearn.base import TransformerMixin, BaseEstimator from sklearn.base import RegressorMixin from sklearn.kernel_approximation import Nystroem +from sklearn.preprocessing import PolynomialFeatures from sklearn.linear_model import ARDRegression -#from sklearn.svm import LinearSVR -#from sklearn.gaussian_process import GaussianProcessRegressor -#from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel +from sklearn.linear_model import BayesianRidge +from sklearn.neighbors import KernelDensity +from sklearn.ensemble import IsolationForest +#from sklearn.neighbors import LocalOutlierFactor +#from sklearn.covariance import EllipticEnvelope +from functools import reduce from itertools import product -from sklearn.model_selection import train_test_split from sklearn.base import clone, MetaEstimatorMixin from joblib import Parallel, delayed -from functools import partial -import multiprocessing as mp from copy import deepcopy from typing import Any, Dict, List, Optional, Union, Tuple @@ -33,160 +36,232 @@ from typing import Any, Dict, List, Optional, Union, Tuple # * using evidence maximization as an iterative method # * automatic relevance determination to reduce overtraining # -#from autograd import numpy as anp -#from autograd import grad -#class FitModel(RegressorMixin, BaseEstimator): -# """ -# Linear regression model with uncertainties. -# -# Args: -# l: Regularization coefficient. -# """ -# def __init__(self, l: float=1e-6): -# self.l = l -# -# # model parameter sizes -# self.Nx: int = 0 -# self.Ny: int = 0 -# -# # fit result -# self.A_inf: Optional[np.ndarray] = None -# self.b_inf: Optional[np.ndarray] = None -# self.u_inf: Optional[np.ndarray] = None -# -# # fit monitoring -# self.loss_train: List[float] = list() -# self.loss_test: List[float] = list() -# -# self.input_data = None -# -# def fit(self, X: np.ndarray, y: np.ndarray, **fit_params) -> RegressorMixin: -# """ -# Perform the fit and evaluate uncertainties with the test set. -# -# Args: -# X: The input. -# y: The target. -# fit_params: If it contains X_test and y_test, they are used to validate the model. -# -# Returns: The object itself. 
-# """ -# if 'X_test' in fit_params and 'y_test' in fit_params: -# X_test = fit_params['X_test'] -# y_test = fit_params['y_test'] -# else: -# X_test = X -# y_test = y -# -# # model parameter sizes -# self.Nx: int = int(X.shape[1]) -# self.Ny: int = int(y.shape[1]) -# -# # initial parameter values -# A0: np.ndarray = np.eye(self.Nx, self.Ny).reshape(self.Nx*self.Ny) -# b0: np.ndarray = np.zeros(self.Ny) -# Aeps: np.ndarray = np.zeros(self.Nx) -# u0: np.ndarray = np.zeros(self.Ny) -# x0: np.ndarray = np.concatenate((A0, b0, u0, Aeps), axis=0) -# -# # reset loss monitoring -# self.loss_train: List[float] = list() -# self.loss_test: List[float] = list() -# -# def loss(x: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float: -# """ -# Calculate the loss function value for a given parameter set `x`. -# -# Args: -# x: The parameters as a flat array. -# X: The independent-variable dataset. -# Y: The dependent-variable dataset. -# -# Returns: The loss. -# """ -# # diag( (in @ x - out) @ (in @ x - out)^T ) -# A = x[:self.Nx*self.Ny].reshape((self.Nx, self.Ny)) -# b = x[self.Nx*self.Ny:(self.Nx*self.Ny+self.Ny)].reshape((1, self.Ny)) -# -# b_eps = x[(self.Nx*self.Ny+self.Ny):(self.Nx*self.Ny+self.Ny+self.Ny)].reshape((1, self.Ny)) -# A_eps = x[(self.Nx*self.Ny+self.Ny+self.Ny):].reshape((self.Nx, 1)) -# log_unc = anp.matmul(X, A_eps) + b_eps -# -# #log_unc = anp.log(anp.exp(log_unc) + anp.exp(log_eps)) -# iunc2 = anp.exp(-2*log_unc) -# -# L = anp.mean( (0.5*((anp.matmul(X, A) + b - Y)**2)*iunc2 + log_unc).sum(axis=1), axis=0 ) -# weights2 = (anp.sum(anp.square(A.ravel())) -# #+ anp.sum(anp.square(b.ravel())) -# #+ anp.sum(anp.square(A_eps.ravel())) -# #+ anp.sum(anp.square(b_eps.ravel())) -# ) -# return L + self.l/2 * weights2 -# -# def loss_history(x: np.ndarray) -> float: -# """ -# Calculate the loss function and keep a history of it in training and testing. -# -# Args: -# x: Parameters flattened out. -# -# Returns: The loss value. -# """ -# l_train = loss(x, X, y) -# l_test = loss(x, X_test, y_test) -# -# self.loss_train += [l_train] -# self.loss_test += [l_test] -# return l_train -# -# def loss_train(x: np.ndarray) -> float: -# """ -# Calculate the loss function for the training dataset. -# -# Args: -# x: Parameters flattened out. -# -# Returns: The loss value. -# """ -# l_train = loss(x, X, y) -# return l_train -# -# grad_loss = grad(loss_train) -# sc_op = fmin_l_bfgs_b(loss_history, -# x0, -# grad_loss, -# disp=True, -# factr=1e7, -# maxiter=100, -# iprint=0) -# -# # Inference -# self.A_inf = sc_op[0][:self.Nx*self.Ny].reshape(self.Nx, self.Ny) -# self.b_inf = sc_op[0][self.Nx*self.Ny:(self.Nx*self.Ny+self.Ny)].reshape(1, self.Ny) -# self.u_inf = sc_op[0][(self.Nx*self.Ny+self.Ny):(self.Nx*self.Ny+self.Ny+self.Ny)].reshape(1, self.Ny) # removed np.exp -# self.A_eps = sc_op[0][(self.Nx*self.Ny+self.Ny+self.Ny):].reshape(self.Nx, 1) -# -# def predict(self, X: np.ndarray, return_std: bool=False) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: -# """ -# Predict y from X. -# -# Args: -# X: Input dataset. -# -# Returns: Predicted Y and, if return_std is True, also its uncertainty. -# """ -# -# # result -# y = (X @ self.A_inf + self.b_inf) -# if not return_std: -# return y -# # flat uncertainty -# y_unc = self.u_inf[0,:] -# # input-dependent uncertainty -# y_eps = np.exp(X @ self.A_eps + y_unc) -# return y, y_eps +from autograd import numpy as anp +from autograd import grad +class FitModel(RegressorMixin, BaseEstimator): + """ + Linear regression model with uncertainties. 
+ Args: + """ + def __init__(self): + # model parameter sizes + self.Nx: int = 0 + self.Ny: int = 0 + # fit result + self.pars: Optional[Dict[str, np.ndarray]] = None + self.structure: Optional[Dict[str, Tuple[int, int]]] = None + + # fit monitoring + self.loss_train: List[float] = list() + self.loss_test: List[float] = list() + + self.nll_train: Optional[np.ndarray] = None + self.nll_train_expected: Optional[np.ndarray] = None + + def fit(self, X: np.ndarray, y: np.ndarray, **fit_params) -> RegressorMixin: + """ + Perform the fit and evaluate uncertainties with the test set. + + Args: + X: The input. + y: The target. + fit_params: If it contains X_test and y_test, they are used to validate the model. + + Returns: The object itself. + """ + if 'X_test' in fit_params and 'y_test' in fit_params: + X_test = fit_params['X_test'] + y_test = fit_params['y_test'] + else: + X_test = X + y_test = y + + # model parameter sizes + self.Nx: int = int(X.shape[1]) + self.Ny: int = int(y.shape[1]) + + # initial parameter values + self.structure = dict(A_inf=(self.Nx, self.Ny), + b_inf=(1, self.Ny), + Ap_inf=(self.Nx, self.Ny), + log_inv_alpha=(1, self.Ny), + log_inv_alpha_prime=(self.Nx, 1), + #log_inv_alpha_prime2=(self.Nx, 1), + #log_inv_tau1=(1, 1), + #log_inv_tau2=(1, 1), + ) + # initialize close to the solution + # pinv(X) @ y solves the problem Ax = y + # assume b is zero, since both X and y are mean subtracted after the PCA + # assume a small noise uncertainty in alpha and in tau + init_method = dict(A_inf=lambda shape: np.linalg.pinv(X) @ (y - np.mean(y, axis=0, keepdims=True)), + b_inf=lambda shape: np.mean(y, axis=0, keepdims=True), + Ap_inf=lambda shape: np.zeros(shape), + log_inv_alpha=lambda shape: 1.0*np.ones(shape), + log_inv_alpha_prime=lambda shape: 1.0*np.ones(shape), + #log_inv_tau1=lambda shape: 1.0*np.ones(shape), + #log_inv_tau2=lambda shape: 1.0*np.ones(shape), + ) + x0: np.ndarray = FitModel.get_pars_init(self.structure, init_method) + + # reset loss monitoring + self.loss_train: List[float] = list() + self.loss_test: List[float] = list() + + def loss(x: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float: + """ + Calculate the loss function value for a given parameter set `x`. + + Args: + x: The parameters as a flat array. + X: The independent-variable dataset. + Y: The dependent-variable dataset. + + Returns: The loss. + """ + p = FitModel.get_pars(x, self.structure) + return anp.mean(self.nll(X, Y, pars=p), axis=0) + + def loss_history(x: np.ndarray) -> float: + """ + Calculate the loss function and keep a history of it in training and testing. + + Args: + x: Parameters flattened out. + + Returns: The loss value. + """ + l_train = loss(x, X, y) + l_test = loss(x, X_test, y_test) + + self.loss_train += [l_train] + self.loss_test += [l_test] + return l_train + + def loss_train(x: np.ndarray) -> float: + """ + Calculate the loss function for the training dataset. + + Args: + x: Parameters flattened out. + + Returns: The loss value. + """ + l_train = loss(x, X, y) + return l_train + + grad_loss = grad(loss_train) + sc_op = fmin_l_bfgs_b(loss_history, + x0, + grad_loss, + disp=True, + factr=1e7, + #factr=10, + maxiter=10000, + iprint=0) + + # Inference + self.pars = FitModel.get_pars(sc_op[0], self.structure) + self.nll_train = sc_op[1] + self.nll_train_expected = np.mean(self.nll(X, pars=self.pars), axis=0, keepdims=True) + + def predict(self, X: np.ndarray, return_std: bool=False) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: + """ + Predict y from X. 
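FitModel.fit above flattens all parameter matrices into a single vector so that scipy's L-BFGS can optimize them jointly, with autograd supplying the gradient; the length of that flat vector is the sum of the individual matrix sizes. A small self-contained sketch of the flatten/unflatten plus autograd pattern on a toy least-squares problem (names and shapes are illustrative only):

    import numpy as np
    from autograd import numpy as anp
    from autograd import grad
    from scipy.optimize import fmin_l_bfgs_b

    structure = {"A": (3, 2), "b": (1, 2)}                # name -> shape, as in self.structure

    def unpack(x, structure):
        """Slice the flat parameter vector back into named matrices (cf. FitModel.get_pars)."""
        pars, start = {}, 0
        for key, shape in structure.items():
            n = shape[0] * shape[1]
            pars[key] = x[start:start + n].reshape(shape)
            start += n
        return pars

    X = np.random.randn(100, 3)
    Y = X @ np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]]) + 0.1 * np.random.randn(100, 2)

    def loss(x):
        p = unpack(x, structure)
        return anp.mean(anp.sum((anp.matmul(X, p["A"]) + p["b"] - Y) ** 2, axis=1))

    x0 = np.zeros(sum(r * c for r, c in structure.values()))
    x_opt, f_opt, info = fmin_l_bfgs_b(loss, x0, grad(loss), maxiter=500)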
+ + Args: + X: Input dataset. + + Returns: Predicted Y and, if return_std is True, also its uncertainty. + """ + # result + A = self.pars["A_inf"] + b = self.pars["b_inf"] + X2 = anp.square(X) + Ap = self.pars["Ap_inf"] + y = anp.matmul(X2, Ap) + anp.matmul(X, A) + b + if not return_std: + return y + # input-dependent uncertainty + log_inv_alpha = anp.matmul(X, self.pars["log_inv_alpha_prime"]) + self.pars["log_inv_alpha"] + sqrt_inv_alpha = anp.exp(0.5*log_inv_alpha) + return y, sqrt_inv_alpha + + @staticmethod + def get_pars(x: np.ndarray, structure: Dict[str, Tuple[int, int]]) -> Dict[str, np.ndarray]: + pars = dict() + s = 0 + for key, value in structure.items(): + n = value[0]*value[1] + e = s + n + pars[key] = x[s:e].reshape(value) + s += n + return pars + + @staticmethod + def get_pars_size(structure: Dict[str, Tuple[int, int]]) -> int: + size = [value[0]*value[1] for _, value in structure.items()] + return reduce(lambda x, y: x*y, size) + + @staticmethod + def get_pars_init(structure: Dict[str, Tuple[int, int]], init_method: Optional[Dict[str, Any]]=dict()) -> int: + init = {key: np.zeros((value[0], value[1])).reshape(-1) if not key in init_method + else init_method[key](value).reshape(-1) + for key, value in structure.items()} + return np.concatenate(list(init.values()), axis=0) + + def nll(self, X: np.ndarray, Y: Optional[np.ndarray]=None, pars: Optional[Dict[str, np.ndarray]]=None) -> np.ndarray: + """ + To estimate p(M|X) = int_Y p(M|X,Y)p(Y) dY, we assume p(Y) is Normal(mean(X), std(X)). + p(M|X,Y=Normal(mu(X), var(X))) = 1/2b exp(-(mu(X)-mu(X))/b) = 1/2b. + -log p = log(2) + log(b) + We return -log [p(M|X,Y=mu(X))/p(M|X_train,Y_train)] = -log p(M|X,Y=mu(X)) + log p(M|X_train, Y_traun) + Negative log likelihood L allows one + + Args: + X: The input data. + Y: The true result. If None, set it to the expectation. + + Returns: negative log likelihood. + """ + if pars is None: + pars = self.pars + + A = pars["A_inf"] + b = pars["b_inf"] + X2 = anp.square(X) + Ap = pars["Ap_inf"] + Y_pred = anp.matmul(X2, Ap) + anp.matmul(X, A) + b + + log_inv_alpha = anp.matmul(X, pars["log_inv_alpha_prime"]) + pars["log_inv_alpha"] + sqrt_alpha = anp.exp(-0.5*log_inv_alpha) + #log_inv_tau1 = pars["log_inv_tau1"] + #tau1 = anp.exp(-log_inv_tau1) + #log_inv_tau2 = pars["log_inv_tau2"] + #tau2 = anp.exp(-log_inv_tau2) + if Y is None: + Y = self.predict(X) + # likelihood = p(D|x, y, A, b) = 1/sqrt(2 pi sigma^2) exp(-0.5*(Ax + b - y)/sigma^2) + # sigma is modelled as exp(A_e x + b_e) to make the aleatoric uncertainty data-dependent + L = anp.sum((anp.fabs(Y_pred - Y))*sqrt_alpha + 0.5*log_inv_alpha, axis=1) + # prior p(A_inf) p(b_inf) = p(A_inf) cte. 
(assume all b equally likely) + # A has normal prior with var = 1/tau + # 1/sqrt(2pi)1/sqrt(sigma**2) exp(-0.5 A**2/sigma**2) + # log p = -0.5 A **2/sigma**2 - 0.5 log sigma**2 - 0.5 log 2pi + # - log p = 0.5 A**2/sigma**2 + 0.5 log sigma**2 + 0.5 log 2pi + # 0.5 log sigma**2 = 0.5 log 1/tau + #L_prior = anp.sum(0.5*anp.square(pars["A_inf"].ravel())*tau.ravel() + 0.5*log_inv_tau.ravel()) + #L_prior = (anp.sum(0.5*anp.square(A.ravel())*tau1 + 0.5*log_inv_tau1) + # + anp.sum(0.5*anp.square(Ap.ravel())*tau2 + 0.5*log_inv_tau2) + # ) + #alpha = 2.0 + #beta = 2.0 + #L_hyper = anp.sum((alpha - 1)*log_inv_tau1 + beta*tau1 + # + (alpha - 1)*log_inv_tau2 + beta*tau2 + # ) + return L def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: """Returns list of train IDs common to sets a, b and c.""" @@ -379,8 +454,6 @@ class DataHolder(TransformerMixin, BaseEstimator): """ return X - - class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): """ Select only relevant entries in the low-resolution data. @@ -391,16 +464,21 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): Set to None to perform no selection delta_tof: Number of components to take from the low-resolution spectrometer. Set to None to perform no selection. + poly: Whether to output a polynomial expantion of the low-resolution data. """ def __init__(self, channels:List[str]=[f"channel_{j}_{k}" for j, k in product(range(1, 5), ["A", "B", "C", "D"])], tof_start: Optional[int]=None, delta_tof: Optional[int]=300, + poly: bool=False ): self.channels = channels self.tof_start = tof_start self.delta_tof = delta_tof + self.poly = poly + self.mean = dict() + self.std = dict() def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False) -> np.ndarray: """ @@ -422,7 +500,10 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof) y = {channel: item[:, first:last] for channel, item in X.items()} if not keep_dictionary_structure: - return np.concatenate(list(y.values()), axis=1) + selected = list(y.values()) + if self.poly: + selected += [np.sqrt(np.fabs(v)) for v in y.values()] + return np.concatenate(selected, axis=1) return y def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int: @@ -436,8 +517,14 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): """ # reduce on channel and on train ID sum_low_res = - np.mean(sum(list(X.values())), axis=0) - widths = np.arange(10, 50, step=1) - peak_idx = find_peaks_cwt(sum_low_res, widths) + axis = np.arange(0.0, sum_low_res.shape[0], 1.0) + #widths = np.arange(10, 50, step=5) + #peak_idx = find_peaks_cwt(sum_low_res, widths) + gaussian = np.exp(-0.5*(axis - sum_low_res.shape[0]//2)**2/20**2) + gaussian /= np.sum(gaussian, axis=0, keepdims=True) + # apply it to the data + smoothened = fftconvolve(sum_low_res, gaussian, mode="same", axes=0) + peak_idx = [np.argmax(smoothened)] if len(peak_idx) < 1: raise PromptNotFoundError() peak_idx = sorted(peak_idx, key=lambda k: np.fabs(sum_low_res[k]), reverse=True) @@ -460,7 +547,14 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): Returns: The object itself. 
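estimate_prompt_peak above now locates the prompt by smoothing the averaged, sign-flipped trace with a fixed-width Gaussian and taking the argmax, instead of find_peaks_cwt. A standalone sketch of that step, assuming a 1D trace and the same 20-bin width used in the patch:

    import numpy as np
    from scipy.signal import fftconvolve

    def smooth_and_find_peak(trace: np.ndarray, width: float = 20.0) -> int:
        """Convolve a 1D trace with a unit-area Gaussian and return the index of the maximum."""
        axis = np.arange(trace.shape[0], dtype=float)
        kernel = np.exp(-0.5 * (axis - trace.shape[0] // 2) ** 2 / width ** 2)
        kernel /= kernel.sum()
        smoothed = fftconvolve(trace, kernel, mode="same")
        return int(np.argmax(smoothed))

In the model the trace is -np.mean(sum(X.values()), axis=0), i.e. the channel- and train-averaged spectrum with the sign flipped so that the prompt appears as a maximum.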
""" + print("Finding peaks") self.tof_start = self.estimate_prompt_peak(X) + X_tr = self.transform(X, keep_dictionary_structure=True) + self.mean = {ch: np.mean(X_tr[ch], axis=0, keepdims=True) + for ch in X.keys()} + self.std = {ch: np.std(X_tr[ch], axis=0, keepdims=True) + for ch in X.keys()} + print("Found peaks") return self def debug_peak_finding(self, X: Dict[str, np.ndarray], filename: str): @@ -493,9 +587,9 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): plt.savefig(filename) plt.close(fig) -def _fit_estimator(estimator, X: np.ndarray, y: np.ndarray): +def _fit_estimator(estimator, X: np.ndarray, y: np.ndarray, w: np.ndarray): estimator = clone(estimator) - estimator.fit(X, y) + estimator.fit(X, y, w) return estimator class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator): @@ -504,7 +598,7 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator): self.estimator = estimator self.n_jobs = n_jobs - def fit(self, X: np.ndarray, y: np.ndarray): + def fit(self, X: np.ndarray, y: np.ndarray, weights: Optional[np.ndarray]=None): """Fit the model to data, separately for each output variable. Args: @@ -521,10 +615,12 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator): "y must have at least two dimensions for " "multi-output regression but has only one." ) + if weights is None: + weights = np.ones(y.shape[0]) self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_fit_estimator)( - self.estimator, X, y[:, i] + self.estimator, X, y[:, i], weights ) for i in range(y.shape[1]) ) @@ -551,7 +647,7 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator): return np.asarray(y).T, np.asarray(unc).T return np.asarray(y).T - + class Model(TransformerMixin, BaseEstimator): """ Object representing a previous fit of the model to be used to predict high-resolution @@ -570,25 +666,31 @@ class Model(TransformerMixin, BaseEstimator): validation and systematic uncertainty estimate. n_nonlinear_kernel: Number of nonlinear kernel components added at the preprocessing stage to obtain nonlinearities as an input and improve the prediction. + poly: Whether to use a polynomial expantion of the low-resolution data. 
""" def __init__(self, channels:List[str]=[f"channel_{j}_{k}" for j, k in product(range(1, 5), ["A", "B", "C", "D"])], - n_pca_lr: int=600, + n_pca_lr: int=400, n_pca_hr: int=20, high_res_sigma: float=0.2, tof_start: Optional[int]=None, delta_tof: Optional[int]=300, validation_size: float=0.05, - n_nonlinear_kernel: int=0): + n_nonlinear_kernel: int=0, + poly: bool=False, + ): # models + self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=poly) x_model_steps = list() - x_model_steps += [('select', SelectRelevantLowResolution(channels, tof_start, delta_tof))] + self.n_nonlinear_kernel = n_nonlinear_kernel if n_nonlinear_kernel > 0: + # Kernel PCA using Nystroem x_model_steps += [('fex', Pipeline([('prepca', PCA(n_pca_lr, whiten=True)), ('nystroem', Nystroem(n_components=n_nonlinear_kernel, kernel='rbf', gamma=None, n_jobs=8)), - ]))] + ])), + ] x_model_steps += [ ('pca', PCA(n_pca_lr, whiten=True)), ('unc', UncertaintyHolder()), @@ -599,18 +701,27 @@ class Model(TransformerMixin, BaseEstimator): ('pca', PCA(n_pca_hr, whiten=True)), ('unc', UncertaintyHolder()), ]) + self.ood = {ch: IsolationForest() + for ch in channels+['full']} + #self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) + self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) #self.fit_model = FitModel() - self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=30, tol=1e-4, verbose=True)) - self.channel_mean = {ch: np.nan for ch in channels} - self.channel_relevance = {ch: np.nan for ch in channels} + self.mu_intensity = np.nan + self.std_intensity = np.nan + + # we are reducing per channel from 2*delta_tof to delta_tof to check correlations + n_pca_lr_per_channel = delta_tof + self.channel_pca = {ch: PCA(n_pca_lr_per_channel, whiten=True) + for ch in channels} + #self.channel_fit_model = {ch: FitModel() for ch in channels} # size of the test subset self.validation_size = validation_size def get_channels(self) -> List[str]: """Get channels used in training.""" - return self.x_model.named_steps["select"].channels + return self.x_select.channels def get_energy_values(self) -> np.ndarray: """Get x-axis of high-resolution data.""" @@ -618,7 +729,9 @@ class Model(TransformerMixin, BaseEstimator): def get_low_resolution_range(self) -> Tuple[int, int]: """Get bin number with first and last relevant bins in the low-resolution spectrum.""" - return self.x_model.named_steps['select'].tof_start, (self.x_model.named_steps['select'].tof_start + self.x_model.named_steps['select'].delta_tof) + first = (self.x_select.tof_start - self.x_select.delta_tof) + last = (self.x_select.tof_start + self.x_select.delta_tof) + return first, last def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str): """ @@ -629,7 +742,7 @@ class Model(TransformerMixin, BaseEstimator): filename: The file name where to save the plot. """ - self.x_model['select'].debug_peak_finding(low_res_data, filename) + self.x_select.debug_peak_finding(low_res_data, filename) def preprocess_high_res(self, high_res_data: np.ndarray) -> np.ndarray: """ @@ -641,8 +754,30 @@ class Model(TransformerMixin, BaseEstimator): Returns: Smoothened spectrum. 
""" return self.y_model['smoothen'].transform(high_res_data) - - def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray: + + def uniformize(self, intensity: np.ndarray) -> np.ndarray: + """ + Calculate weights to uniformize data in variable intensity. + + Args: + intensity: The variable to uniformize the weights on. + + Return: weights. + """ + kde = KernelDensity() + kde.fit(intensity) + q = np.quantile(intensity, [0.10, 0.90]) + l, h = q[0], q[1] + x = intensity*((intensity > l) & (intensity < h)) + l*(intensity <= l) + h*(intensity >= h) + log_prob = kde.score_samples(x) + w = np.exp(-log_prob) + w = w/np.median(w) + return w + + def fit(self, low_res_data: Dict[str, np.ndarray], + high_res_data: np.ndarray, high_res_photon_energy: np.ndarray, + weights: Optional[np.ndarray]=None + ) -> np.ndarray: """ Train the model. @@ -657,13 +792,24 @@ class Model(TransformerMixin, BaseEstimator): Returns: Smoothened high resolution spectrum. """ + if weights is None: + weights = np.ones(high_Res_data.shape[0]) print("Fitting PCA on low-resolution data.") - x_t = self.x_model.fit_transform(low_res_data) + low_res_select = self.x_select.fit_transform(low_res_data) + n_components = min(self.x_model["pca"].n_components, low_res_select.shape[0]) + self.x_model.set_params(pca__n_components=n_components) + x_t = self.x_model.fit_transform(low_res_select) print("Fitting PCA on high-resolution data.") y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy) + intensity = np.sum(y_t, axis=1) + self.mu_intensity = np.mean(intensity, axis=0) + self.sigma_intensity = np.mean(intensity, axis=0) #self.fit_model.set_params(fex__gamma=1.0/float(x_t.shape[0])) + print("Fitting outlier detection") + self.ood['full'].fit(x_t) + inliers = self.ood['full'].predict(x_t) > 0.0 print("Fitting model.") - self.fit_model.fit(x_t, y_t) + self.fit_model.fit(x_t[inliers], y_t[inliers], weights[inliers]) # calculate the effect of the PCA print("Calculate PCA unc. on high-resolution data.") @@ -673,60 +819,20 @@ class Model(TransformerMixin, BaseEstimator): high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True)) self.y_model['unc'].set_uncertainty(high_pca_unc) - print("Calculate PCA unc. 
on low-resolution data.") - low_res = self.x_model['select'].transform(low_res_data) - pca_model = self.x_model['pca'] - if 'fex' in self.x_model.named_steps: - pca_model = self.x_model['fex'].named_steps['prepca'] - low_pca = pca_model.transform(low_res) - low_pca_rec = pca_model.inverse_transform(low_pca) - low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True) - self.x_model['unc'].set_uncertainty(low_pca_unc) - # for consistency check per channel - selection_model = self.x_model['select'] + selection_model = self.x_select + low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True) for channel in self.get_channels(): - self.channel_mean[channel] = np.mean(low_res_data[channel], axis=0, keepdims=True) - print(f"Calculate PCA relevance on {channel}") - # freeze input data in one channel only - low_res_data_frozen = {ch: low_res_data[ch] if ch != channel - else np.repeat(self.channel_mean[channel], low_res_data[ch].shape[0], axis=0) - for ch in self.get_channels()} - low_res = selection_model.transform(low_res_data_frozen) - low_pca = pca_model.fit_transform(low_res) - low_pca_rec = pca_model.inverse_transform(low_pca) - low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True) - self.channel_relevance[channel] = low_pca_unc + print(f"Calculate PCA on {channel}") + low_pca = self.channel_pca[channel].fit_transform(low_res_selected[channel]) + self.ood[channel].fit(low_pca) + #low_pca_rec = self.channel_pca[channel].inverse_transform(low_pca) + #self.channel_fit_model[channel].fit(low_pca, y_t) + print("End of fit.") return high_res - def get_channel_quality(self, channel: str, low_res_data: Dict[str, np.ndarray], - pca_model: PCA, - channel_relevance: Dict[str, float], - selection_model: SelectRelevantLowResolution, - channel_mean: Dict[str, np.ndarray]) -> float: - """ - Get the compatibility for a single channel. - - Args: - channel: The channel. - low_res: The data in a dictionary. - pca_model: The PCA model. - - Returns: the compatibility factor. - """ - # freeze input data in one channel only - low_res_data_frozen = {ch: low_res_data[ch] if ch != channel - else np.repeat(channel_mean[channel], low_res_data[ch].shape[0], axis=0) - for ch in low_res_data.keys()} - low_res_selected = selection_model.transform(low_res_data_frozen) - low_pca = pca_model.transform(low_res_selected) - low_pca_rec = pca_model.inverse_transform(low_pca) - low_pca_unc = channel_relevance[channel] - low_pca_dev = np.sqrt(np.mean((low_res_selected - low_pca_rec)**2, axis=1, keepdims=True)) - return low_pca_dev/low_pca_unc - def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """ Check if a new low-resolution data source is compatible with the one used in training, by @@ -735,61 +841,54 @@ class Model(TransformerMixin, BaseEstimator): Args: low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel). - Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model per channel. + Returns: Outlier score. If it is bigger than 0, this is an outlier. 
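The compatibility checks are now outlier tests: each channel gets its own whitened PCA fitted on the training data, and an IsolationForest is trained on the resulting scores (plus one forest on the full PCA space); at inference, the forest's predict() returns +1 for inliers and -1 for outliers. A compact sketch of that fit/score round trip on synthetic data:

    import numpy as np
    from sklearn.decomposition import PCA
    from sklearn.ensemble import IsolationForest

    rng = np.random.default_rng(1)
    train = rng.normal(size=(1000, 300))             # e.g. one channel's selected ToF slice

    pca = PCA(20, whiten=True).fit(train)
    ood = IsolationForest().fit(pca.transform(train))

    test_ok = rng.normal(size=(5, 300))
    test_bad = 10.0 + rng.normal(size=(5, 300))      # strongly shifted traces
    print(ood.predict(pca.transform(test_ok)))       # mostly +1 (inlier)
    print(ood.predict(pca.transform(test_bad)))      # mostly -1 (outlier)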
""" - selection_model = self.x_model['select'] + selection_model = self.x_select quality = {channel: 0.0 for channel in low_res_data.keys()} channels = list(low_res_data.keys()) - pca_model = self.x_model['pca'] - if 'fex' in self.x_model.named_steps: - pca_model = self.x_model['fex'].named_steps['prepca'] - #with mp.Pool(len(low_res_data.keys())) as p: - values = map(partial(self.get_channel_quality, - low_res_data=low_res_data, - pca_model=pca_model, - channel_relevance=self.channel_relevance, - selection_model=selection_model, - channel_mean=self.channel_mean - ), channels) - quality = dict(zip(channels, values)) - return quality + # check if each channel is close to the mean + low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True) + low_pca = {ch: self.channel_pca[ch].transform(low_res_selected[ch]) + for ch in channels} + return {ch: self.ood[ch].predict(low_pca[ch]) for ch in channels} + + ## for chi2 + #deviation = {ch: low_pca[ch] + # for ch in channels} + #ndof = {ch: float(deviation[ch].shape[1] - 1) + # for ch in channels} + #chi2 = {ch: np.sum(deviation[ch]**2, axis=1, keepdims=True) + # for ch in channels} + #chi2_mu = {ch: scipy.stats.chi2.mean(ndof[ch]) + # for ch in channels} + #chi2_sigma = {ch: scipy.stats.chi2.std(ndof[ch]) + # for ch in channels} + #return {ch: (chi2[ch] - chi2_mu[ch])/chi2_sigma[ch] + # for ch in channels} def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray: """ Check if a new low-resolution data source is compatible with the one used in training, by - comparing the effect of the trained PCA model on it. + using a robust covariance matrix estimate of the data Args: low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel). - Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model. + Returns: An outlier score: if it is greater than 0, this is an outlier. """ - low_res = self.x_model['select'].transform(low_res_data) - pca_model = self.x_model['pca'] - if 'fex' in self.x_model.named_steps: - pca_model = self.x_model['fex'].named_steps['prepca'] + low_res = self.x_select.transform(low_res_data) + pca_model = self.x_model + #pca_model = self.x_model['pca'] + #if 'fex' in self.x_model.named_steps: + # pca_model = self.x_model['fex'].named_steps['prepca'] low_pca = pca_model.transform(low_res) - low_pca_rec = pca_model.inverse_transform(low_pca) - low_pca_unc = self.x_model['unc'].uncertainty() - - #fig = plt.figure(figsize=(8, 16)) - #ax = plt.gca() - #ax.plot(low_res[0,...], - # c="b", - # label="LR") - #ax.plot(low_pca_rec[0,...], - # c="r", - # label="LR rec.") - #ax.set(title="", - # xlabel="Photon Spectrometer channel", - # ylabel="Low resolution spectrometer intensity") - #ax.legend() - #plt.savefig("check.png") - #plt.close(fig) - - low_pca_dev = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)) - return low_pca_dev/low_pca_unc - + return self.ood['full'].predict(low_pca) + #deviation = low_pca + #ndof = float(deviation.shape[1] - 1) + #chi2 = np.sum(deviation**2, axis=1, keepdims=True) + #chi2_mu = scipy.stats.chi2.mean(ndof) + #chi2_sigma = scipy.stats.chi2.std(ndof) + #return (chi2 - chi2_mu)/chi2_sigma def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """ @@ -803,20 +902,33 @@ class Model(TransformerMixin, BaseEstimator): the expected prediction in key "expected", the stat. 
uncertainty in key "unc" and a (1, energy channel) array for the PCA syst. uncertainty in key "pca". """ - low_pca = self.x_model.transform(low_res_data) + low_res_pre = self.x_select.transform(low_res_data) + low_pca = self.x_model.transform(low_res_pre) high_pca, high_pca_unc = self.fit_model.predict(low_pca, return_std=True) + intensity = np.sum(high_pca, axis=1, keepdims=True) + Z_intensity = (intensity - self.mu_intensity)/self.sigma_intensity #high_pca = self.fit_model.predict(low_pca) #high_pca_unc = 0 n_trains = high_pca.shape[0] + # Note: The whiten=True setting in the PCA model leads to an affine transformation pca_y = np.concatenate((high_pca, - high_pca + high_pca_unc), + high_pca + high_pca_unc, + ), axis=0) high_res_predicted = self.y_model["pca"].inverse_transform(pca_y) expected = high_res_predicted[:n_trains, :] - unc = high_res_predicted[n_trains:, :] - expected + e_plus = high_res_predicted[n_trains:(2*n_trains), :] + unc = np.fabs(e_plus - expected) + pca_unc = self.y_model['unc'].uncertainty() + total_unc = np.sqrt(pca_unc**2 + unc**2) + return dict(expected=expected, unc=unc, - pca=self.y_model['unc'].uncertainty()) + pca=pca_unc, + total_unc=total_unc, + inlier=self.ood['full'].predict(low_pca), + Z_intensity=Z_intensity + ) def save(self, filename: str): """ @@ -825,11 +937,15 @@ class Model(TransformerMixin, BaseEstimator): Args: filename: File name where to save this. """ - joblib.dump([self.x_model, + joblib.dump([self.x_select, + self.x_model, self.y_model, self.fit_model, - DataHolder(self.channel_mean), - DataHolder(self.channel_relevance) + self.channel_pca, + #self.channel_fit_model + DataHolder(self.mu_intensity), + DataHolder(self.sigma_intensity), + self.ood ], filename, compress='zlib') @staticmethod @@ -842,12 +958,23 @@ class Model(TransformerMixin, BaseEstimator): Returns: A new model object. 
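predict() above propagates the regressor's per-component standard deviation through the PCA decoder by decoding the mean scores and the mean-plus-std scores and taking the absolute difference, then combines the result in quadrature with the PCA reconstruction error stored during fit. A standalone sketch of that propagation with made-up shapes, using a plain PCA in place of y_model:

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.default_rng(2)
    spectra = rng.normal(size=(500, 1000))           # stand-in for smoothened high-res spectra
    pca = PCA(20, whiten=True).fit(spectra)

    scores = pca.transform(spectra[:10])             # latent means, as returned by the regressor
    scores_std = 0.1 * np.ones_like(scores)          # latent std. devs. (return_std=True)

    expected = pca.inverse_transform(scores)
    shifted = pca.inverse_transform(scores + scores_std)
    unc_stat = np.abs(shifted - expected)            # statistical part, per energy bin

    # PCA systematic: reconstruction residual estimated on the training set
    rec = pca.inverse_transform(pca.transform(spectra))
    unc_pca = np.sqrt(np.mean((spectra - rec) ** 2, axis=0, keepdims=True))

    total_unc = np.sqrt(unc_stat ** 2 + unc_pca ** 2)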
""" - x_model, y_model, fit_model, channel_mean, channel_relevance = joblib.load(filename) + (x_select, + x_model, y_model, fit_model, + channel_pca, + #channel_fit_model + mu_intensity, + sigma_intensity, + ood + ) = joblib.load(filename) obj = Model() + obj.x_select = x_select obj.x_model = x_model obj.y_model = y_model obj.fit_model = fit_model - obj.channel_mean = channel_mean.get_data() - obj.channel_relevance = channel_relevance.get_data() + obj.channel_pca = channel_pca + #obj.channel_fit_model = channel_fit_model + obj.ood = ood + obj.mu_intensity = mu_intensity.get_data() + obj.sigma_intensity = sigma_intensity.get_data() return obj diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py index 6db5415c6def669ebc50cc2adfe708929e0d2372..790a5b368054825f9672b165fe753b4ddbfbe993 100755 --- a/pes_to_spec/test/offline_analysis.py +++ b/pes_to_spec/test/offline_analysis.py @@ -4,8 +4,11 @@ import sys sys.path.append('.') sys.path.append('..') +import os +import argparse + import numpy as np -from extra_data import RunDirectory, by_id +from extra_data import open_run, by_id, RunDirectory from pes_to_spec.model import Model, matching_ids from itertools import product @@ -15,8 +18,8 @@ matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid.inset_locator import (inset_axes, InsetPosition, - mark_inset) +from mpl_toolkits.axes_grid.inset_locator import InsetPosition +import seaborn as sns from typing import Dict, Optional @@ -55,7 +58,14 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int): fig.savefig(filename) plt.close(fig) -def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np.ndarray, spec_raw_pe: np.ndarray, spec_raw_int: Optional[np.ndarray]=None, pes: Optional[np.ndarray]=None, pes_to_show: Optional[str]="", pes_bin: Optional[np.ndarray]=None): +def plot_result(filename: str, + spec_pred: Dict[str, np.ndarray], + spec_smooth: np.ndarray, + spec_raw_pe: np.ndarray, + spec_raw_int: Optional[np.ndarray]=None, + pes: Optional[np.ndarray]=None, + pes_to_show: Optional[str]="", + pes_bin: Optional[np.ndarray]=None): """ Plot result with uncertainty band. @@ -73,8 +83,8 @@ def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np fig = plt.figure(figsize=(12, 8)) gs = GridSpec(1, 1) ax = fig.add_subplot(gs[0, 0]) - unc_stat = np.mean(spec_pred["unc"]) - unc_pca = np.mean(spec_pred["pca"]) + unc_stat = spec_pred["unc"] + unc_pca = spec_pred["pca"] unc = np.sqrt(unc_stat**2 + unc_pca**2) ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-res. measurement (smoothened)") ax.plot(spec_raw_pe, spec_pred["expected"], c='r', ls='--', lw=3, label="High-res. prediction") @@ -116,90 +126,125 @@ def main(): """ Main entry point. Reads some data, trains and predicts. """ - run_dir = "/gpfs/exfel/exp/SA3/202121/p002935/raw" - run_id = "r0015" + parser = argparse.ArgumentParser(prog="offline_analysis", description="Test pes2spec doing an offline analysis of the data.") + parser.add_argument('-p', '--proposal', type=int, metavar='INT', help='Proposal number', default=2828) + parser.add_argument('-r', '--run', type=int, metavar='INT', help='Run number', default=206) + parser.add_argument('-m', '--model', type=str, metavar='FILENAME', default="", help='Model to load. 
If given, do not train a model and just do inference with this one.') + parser.add_argument('-d', '--directory', type=str, metavar='DIRECTORY', default=".", help='Where to save the results.') + parser.add_argument('-S', '--spec', type=str, metavar='NAME', default="SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output", help='SPEC name') + parser.add_argument('-P', '--pes', type=str, metavar='NAME', default="SA3_XTD10_PES/ADC/1:network", help='PES name') + parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name') + parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset') + parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=0, help='XGM intensity threshold in uJ.') + parser.add_argument('-e', '--poly', action="store_true", default=False, help='Wheteher to expand PES data in higher order polynomials.') + + args = parser.parse_args() + + print("Opening run ...") # get run - run = RunDirectory(f"{run_dir}/{run_id}") + run = open_run(proposal=args.proposal, run=args.run) + #run = RunDirectory("/gpfs/exfel/data/scratch/tmichela/data/r0206") + # ----------------Used in the first tests------------------------- # get train IDs and match them, so we are sure to have information from all needed sources # in this example, there is an offset of -2 in the SPEC train ID, so correct for it spec_offset = -2 - spec_tid = spec_offset + run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', - "data.trainId"].ndarray() - pes_tid = run['SA3_XTD10_PES/ADC/1:network', - "digitizers.trainId"].ndarray() - xgm_tid = run['SA3_XTD10_XGM/XGM/DOOCS:output', - "data.trainId"].ndarray() - # these are the train ID intersection - # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset - tids = matching_ids(spec_tid, pes_tid, xgm_tid) + spec_name = 'SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output' + pes_name = 'SA3_XTD10_PES/ADC/1:network' + + spec_offset = 0 + spec_name = 'SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output' + pes_name = 'SA3_XTD10_PES/ADC/1:network' + # -------------------End of test setup ---------------------------- + + spec_offset = args.offset + spec_name = args.spec + pes_name = args.pes + xgm_name = args.xgm + + pes_tid = run[pes_name, "digitizers.trainId"].ndarray() + xgm_tid = run[xgm_name, "data.trainId"].ndarray() + + if len(args.model) == 0: + spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray() + # these are the train ID intersection + # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset + tids = matching_ids(spec_tid, pes_tid, xgm_tid) + + # read the spec photon energy and intensity + spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray() + spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray() + + else: # when doing inference, no need to load SPEC data + tids = pes_tid + + # reserve part of it for the test stage train_tids = tids[:-10] test_tids = tids[-10:] - # read the spec photon energy and intensity - spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', - "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray() - spec_raw_int = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', - "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray() - # read the PES data 
for each channel channels = [f"channel_{i}_{l}" for i, l in product(range(1, 5), ["A", "B", "C", "D"])] - pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', - f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() + pes_raw = {ch: run[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels} - pes_raw_t = {ch: run['SA3_XTD10_PES/ADC/1:network', - f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray() + pes_raw_t = {ch: run[pes_name, f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray() for ch in channels} + print("Data in memory.") + # read the XGM information #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', "pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray() #xgm_pe = run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray() #retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray() #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray() + + xgm_flux = run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()[:, 0][:, np.newaxis] + xgm_flux_t = run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[test_tids]).ndarray()[:, 0][:, np.newaxis] t = list() t_names = list() - # these have been manually selected: - #useful_channels = ["channel_1_D", - # "channel_2_B", - # "channel_3_A", - # "channel_3_B", - # "channel_4_C", - # "channel_4_D"] - model = Model() + model = Model(poly=args.poly) - train_idx = np.isin(tids, train_tids) + train_idx = np.isin(tids, train_tids) & (xgm_flux[:,0] > args.xgm_cut) - model.debug_peak_finding(pes_raw, "test_peak_finding.png") - print("Fitting") - start = time_ns() - model.fit({k: v[train_idx, :] - for k, v in pes_raw.items()}, - spec_raw_int[train_idx, :], - spec_raw_pe[train_idx, :]) - t += [time_ns() - start] - t_names += ["Fit"] + model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.png")) + if len(args.model) == 0: + print("Fitting") + start = time_ns() + w = model.uniformize(xgm_flux[train_idx]) + print("w", np.amin(w), np.amax(w), np.median(w)) + model.fit({k: v[train_idx, :] + for k, v in pes_raw.items()}, + spec_raw_int[train_idx, :], + spec_raw_pe[train_idx, :], + w + ) + t += [time_ns() - start] + t_names += ["Fit"] - print("Saving the model") - start = time_ns() - model.save("model.joblib") - t += [time_ns() - start] - t_names += ["Save"] + print("Saving the model") + start = time_ns() + modelFilename = os.path.join(args.directory, "model.joblib") + model.save(modelFilename) + t += [time_ns() - start] + t_names += ["Save"] + else: + print("Model has been given, so I will just load it.") + modelFilename = args.model print("Loading the model") start = time_ns() - model = Model.load("model.joblib") + model = Model.load(modelFilename) t += [time_ns() - start] t_names += ["Load"] print("Check consistency") start = time_ns() - rmse = model.check_compatibility(pes_raw_t) - print("Consistency check RMSE ratios:", rmse) - rmse = model.check_compatibility_per_channel(pes_raw_t) - print("Consistency per channel check RMSE ratios:", rmse) + Z = model.check_compatibility(pes_raw_t) + print("Consistency check:", Z) + Z = model.check_compatibility_per_channel(pes_raw_t) + print("Consistency per channel:", Z) t += [time_ns() - start] t_names += ["Consistency"] @@ -216,27 +261,173 @@ def main(): print(df_time) print("Plotting") 
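The plots added below compare the smoothened SPEC measurement with the prediction through a per-train reduced chi-square computed from the total predicted uncertainty. A one-function sketch of that statistic, assuming arrays of shape (n_trains, n_energy_bins):

    import numpy as np

    def reduced_chi2(spec_smooth: np.ndarray, expected: np.ndarray, total_unc: np.ndarray) -> np.ndarray:
        """Per-train chi^2/ndof between smoothened measurement and prediction."""
        chi2 = np.sum((spec_smooth - expected) ** 2 / total_unc ** 2, axis=1)
        ndof = float(spec_smooth.shape[1]) - 1.0
        return chi2 / ndof

Values scattering around 1 indicate that the predicted total uncertainty describes the residuals well; the scatter against the XGM intensity below checks whether that holds across pulse energies.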
- spec_smooth = model.preprocess_high_res(spec_raw_int) + showSpec = False + if len(args.model) == 0: + showSpec = True + spec_smooth = model.preprocess_high_res(spec_raw_int) + + # chi2 w.r.t XGM intensity + erange = spec_raw_pe[0,-1] - spec_raw_pe[0,0] + de = (spec_raw_pe[0,1] - spec_raw_pe[0,0]) + chi2 = np.sum((spec_smooth - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=1) + ndof = float(spec_smooth.shape[1]) - 1.0 + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + ax.scatter(chi2/ndof, xgm_flux[:,0], c='r', s=20) + ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}", + xlabel=r"$\chi^2/$ndof", + ylabel="XGM intensity [uJ]", + xlim=(0, 5), + ) + ax2 = plt.axes([0,0,1,1]) + # Manually set the position and relative size of the inset axes within ax1 + ip = InsetPosition(ax, [0.65,0.6,0.35,0.4]) + ax2.set_axes_locator(ip) + ax2.scatter(chi2/ndof, xgm_flux[:,0], c='r', s=30) + #ax2.scatter(chi2/ndof, np.sum(spec_pred["expected"], axis=1)*de, c='b', s=30) + #ax2.scatter(chi2/ndof, np.sum(spec_raw_int, axis=1)*de, c='g', s=30) + ax2.set(title="", + xlabel=r"$\chi^2/$ndof", + ylabel=f"XGM intensity [uJ]", + ) + ax2.title.set_size(SMALL_SIZE) + ax2.xaxis.label.set_size(SMALL_SIZE) + ax2.yaxis.label.set_size(SMALL_SIZE) + ax2.tick_params(axis='both', which='major', labelsize=SMALL_SIZE) + fig.savefig(os.path.join(args.directory, "intensity_vs_chi2.png")) + plt.close(fig) + + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + sns.kdeplot(x=chi2/ndof, ax=ax) + ax.set(title=f"", + xlabel=r"$\chi^2/$ndof", + ylabel="Density [a.u.]", + xlim=(0, 5), + ) + ax.text(0.90, 0.95, fr"$\mu = ${np.mean(chi2/ndof):.2f}", + verticalalignment='top', horizontalalignment='right', + transform=ax.transAxes, + color='black', fontsize=15) + ax.text(0.90, 0.90, fr"$\sigma = ${np.std(chi2/ndof):.2f}", + verticalalignment='top', horizontalalignment='right', + transform=ax.transAxes, + color='black', fontsize=15) + fig.savefig(os.path.join(args.directory, "chi2.png")) + plt.close(fig) + + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + sns.kdeplot(x=xgm_flux[:,0], ax=ax) + ax.set(title=f"", + xlabel="XGM intensity [uJ]", + ylabel="Density [a.u.]", + ) + ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux[:,0]):.2f}", + verticalalignment='top', horizontalalignment='right', + transform=ax.transAxes, + color='black', fontsize=15) + ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux[:,0]):.2f}", + verticalalignment='top', horizontalalignment='right', + transform=ax.transAxes, + color='black', fontsize=15) + fig.savefig(os.path.join(args.directory, "intensity.png")) + plt.close(fig) + + # rmse + rmse = np.sqrt(np.mean((spec_smooth - spec_pred["expected"])**2, axis=1)) + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + ax.scatter(rmse, xgm_flux[:,0], c='r', s=30) + ax = plt.gca() + ax.set(title=f"", + xlabel=r"Root-mean-squared error", + ylabel="XGM intensity [uJ]", + ) + fig.savefig(os.path.join(args.directory, "intensity_vs_rmse.png")) + plt.close(fig) + + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + sns.kdeplot(x=rmse, ax=ax) + ax.set(title=f"", + xlabel="Root-mean-squared error", + ylabel="Density [a.u.]", + xlim=(0, 20), + ) + ax.text(0.90, 0.95, fr"$\mu = ${np.mean(rmse):.2f}", + verticalalignment='top', horizontalalignment='right', + transform=ax.transAxes, + color='black', 
fontsize=15) + ax.text(0.90, 0.90, fr"$\sigma = ${np.std(rmse):.2f}", + verticalalignment='top', horizontalalignment='right', + transform=ax.transAxes, + color='black', fontsize=15) + fig.savefig(os.path.join(args.directory, "rmse.png")) + plt.close(fig) + + # SPEC integral w.r.t XGM intensity + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + sns.regplot(x=np.sum(spec_raw_int, axis=1)*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax) + ax.set(title=f"", + xlabel="SPEC (raw) integral", + ylabel="XGM Intensity [uJ]", + ) + fig.savefig(os.path.join(args.directory, "xgm_vs_intensity.png")) + plt.close(fig) + + # SPEC integral w.r.t XGM intensity + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + sns.regplot(x=np.sum(spec_raw_int, axis=1)*de, y=np.sum(spec_pred["expected"], axis=1)*de, color='r', robust=True, ax=ax) + ax.set(title=f"", + xlabel="SPEC (raw) integral", + ylabel="Predicted integral", + ) + fig.savefig(os.path.join(args.directory, "expected_vs_intensity.png")) + plt.close(fig) + + fig = plt.figure(figsize=(12, 8)) + gs = GridSpec(1, 1) + ax = fig.add_subplot(gs[0, 0]) + sns.regplot(x=np.sum(spec_pred["expected"], axis=1)*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax) + ax.set(title=f"", + xlabel="Predicted integral", + ylabel="XGM intensity [uJ]", + ) + fig.savefig(os.path.join(args.directory, "xgm_vs_expected.png")) + plt.close(fig) + first, last = model.get_low_resolution_range() first += 10 - last -= 100 + last -= 10 pes_to_show = 'channel_1_D' # plot for tid in test_tids: idx = np.where(tid==tids)[0][0] - plot_result(f"test_{tid}.png", + plot_result(os.path.join(args.directory, f"test_{tid}.png"), {k: item[idx, ...] if k != "pca" else item[0, ...] for k, item in spec_pred.items()}, - spec_smooth[idx, :], - spec_raw_pe[idx, :], - spec_raw_int[idx, :], - pes=-pes_raw[pes_to_show][idx, first:last], - pes_to_show=pes_to_show.replace('_', ' '), - pes_bin=np.arange(first, last) + spec_smooth[idx, :] if showSpec else None, + spec_raw_pe[idx, :] if showSpec else None, + spec_raw_int[idx, :] if showSpec else None, + #pes=-pes_raw[pes_to_show][idx, first:last], + #pes_to_show=pes_to_show.replace('_', ' '), + #pes_bin=np.arange(first, last) ) for ch in channels: - plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last) + plot_pes(os.path.join(args.directory, f"test_pes_{tid}_{ch}.png"), + pes_raw[ch][idx, first:last], first, last) if __name__ == '__main__': main() + diff --git a/pyproject.toml b/pyproject.toml index 729a6382b1856f652257812a449e7a7442ff3307..32d592425209f6bcfc75d3000b62312009fc6419 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,12 +28,11 @@ dependencies = [ "numpy>=1.21", "scipy>=1.6", "scikit-learn>=1.0.2", - #"autograd", - #"h5py" + "autograd", ] [project.optional-dependencies] -offline = ["matplotlib", "extra_data"] +offline = ["seaborn", "statsmodels", "matplotlib", "extra_data"] [project.scripts] offline_analysis = "pes_to_spec.test.offline_analysis:main"
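Given the Model API in this patch, a minimal inference-only use of the library might look as follows; the file name and the random placeholder traces are illustrative only, standing in for real PES digitizer data:

    import numpy as np
    from pes_to_spec.model import Model

    # load a model previously trained and saved by offline_analysis.py
    model = Model.load("model.joblib")               # example path
    first, last = model.get_low_resolution_range()

    # dict of "channel_i_X" -> (n_trains, n_samples) raw digitizer traces;
    # random data here only to illustrate the expected shapes
    pes_raw = {ch: np.random.randn(16, 40000) for ch in model.get_channels()}

    pred = model.predict(pes_raw)
    print(pred["expected"].shape, pred["total_unc"].shape)
    print("inlier flags:", pred["inlier"])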