From 8323a1575baf0450c3a4d51e9d0898dcbe5fd07a Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 9 Jan 2023 15:36:55 +0100
Subject: [PATCH] Use two pipelines to split the preprocessing steps more evenly.

---
 README.md                            |   5 +-
 pes_to_spec/model.py                 | 610 +++++++++++++++------------
 pes_to_spec/test/offline_analysis.py |   7 +-
 3 files changed, 351 insertions(+), 271 deletions(-)

diff --git a/README.md b/README.md
index 123f9b3..bc85448 100644
--- a/README.md
+++ b/README.md
@@ -32,12 +32,11 @@ model.fit(low_resolution_raw_data,
           high_resolution_photon_energy)
 
 # save it for later usage:
-model.save("model.h5")
+model.save("model.joblib")
 
 # when performing inference:
 # load a model:
-model = Model()
-model.load("model.h5")
+model = Model.load("model.joblib")
 
 # and use it to map a low-resolution spectrum to a high-resolution one
 # as before, the low_resolution_raw_data refers to a dictionary mapping the channel name
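
For orientation, a minimal end-to-end sketch of the updated API. The shapes,
channel names and random inputs below are illustrative placeholders; real PES
traces are needed for the prompt-peak search inside fit() to succeed.

    import numpy as np
    from itertools import product
    from pes_to_spec.model import Model

    channels = [f"channel_{j}_{k}" for j, k in product(range(1, 5), "ABCD")]
    rng = np.random.default_rng(0)
    low_res = {ch: rng.normal(size=(100, 5000)) for ch in channels}  # (train_id, ToF)
    high_res = rng.normal(size=(100, 1000))                          # (train_id, energy)
    energy = np.broadcast_to(np.linspace(700.0, 740.0, 1000), high_res.shape)

    model = Model()
    model.fit(low_res, high_res, energy)
    model.save("model.joblib")
    model = Model.load("model.joblib")
    prediction = model.predict(low_res)  # keys: "expected", "unc", "pca"
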
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 429a250..d33806a 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -6,11 +6,16 @@ from scipy.signal import fftconvolve
 from scipy.signal import find_peaks_cwt
 from scipy.optimize import fmin_l_bfgs_b
 from sklearn.decomposition import PCA, IncrementalPCA
-from sklearn.model_selection import train_test_split
+from sklearn.pipeline import Pipeline
 from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.base import RegressorMixin
+from sklearn.compose import TransformedTargetRegressor
 from itertools import product
+from sklearn.model_selection import train_test_split
 from time import time_ns
 
+import joblib
+
 import matplotlib.pyplot as plt
 
 from typing import Union, Any, Dict, List, Optional
@@ -107,89 +112,146 @@ class PromptNotFoundError(Exception):
     def __str__(self) -> str:
         return "No prompt peak has been detected."
 
-class Model(TransformerMixin, BaseEstimator):
+class HighResolutionSmoother(TransformerMixin, BaseEstimator):
     """
-    Object representing a previous fit of the model to be used to predict high-resolution
-    spectrum from a low-resolution one.
+    Smoothens out the high resolution data.
+
+    Args:
+      high_res_sigma: Energy resolution in eV.
+    """
+    def __init__(self,
+                 high_res_sigma: float=0.2
+                 ):
+        self.high_res_sigma = high_res_sigma
+        self.energy = None
+
+    def fit(self, X, y=None, **fit_params) -> TransformerMixin:
+        """
+        Fit records the energy axis.
+
+        Args:
+          X: Irrelevant.
+          y: Irrelevant.
+          fit_params: Contains the energy axis in the key "energy" with shape (any, features).
+
+        Returns: The object itself.
+        """
+        self.energy = fit_params["energy"]
+        if len(self.energy.shape) == 2:
+            self.energy = self.energy[0,:]
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Apply smoothing in X using the energy axis.
+
+        Args:
+          X: Input to smoothen with shape (train id, features).
+
+        Returns: Smoothened out spectrum.
+        """
+        # broadcast the energy axis recorded during fit so that it
+        # matches the shape of the input data
+        energy = np.broadcast_to(self.energy, X.shape)
+
+        # Apply smoothing
+        n_features = X.shape[1]
+        # get the centre value of the energy axis
+        mu = energy[:, n_features//2, np.newaxis]
+        # generate a gaussian
+        gaussian = np.exp(-0.5*(energy - mu)**2/self.high_res_sigma**2)
+        gaussian /= np.sum(gaussian, axis=1, keepdims=True)
+        # apply it to the data
+        high_res_gc = fftconvolve(X, gaussian, mode="same", axes=1)
+        return high_res_gc
+
+    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map: the smoothing is not invertible, but providing an
+        identity inverse lets Pipeline.inverse_transform pass through this step.
+
+        Args:
+          X: The input.
+        """
+        return X
+
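
The transform above convolves each spectrum with a unit-area Gaussian built on
the energy axis recorded during fit. A self-contained sketch of the same
computation, assuming a uniform synthetic axis:

    import numpy as np
    from scipy.signal import fftconvolve

    energy = np.linspace(700.0, 740.0, 1000)              # eV
    X = np.random.default_rng(1).normal(size=(10, 1000))  # (train_id, features)
    sigma = 0.2                                           # eV

    mu = energy[energy.size // 2]                         # centre of the axis
    gaussian = np.exp(-0.5 * (energy - mu)**2 / sigma**2)
    gaussian /= gaussian.sum()                            # unit area preserves the intensity scale
    smoothed = fftconvolve(X, gaussian[np.newaxis, :], mode="same", axes=1)
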
+class UncertaintyHolder(TransformerMixin, BaseEstimator):
+    """
+    Keep track of uncertainties.
+    """
+    def __init__(self):
+        self.unc: np.ndarray = np.zeros((1, 0), dtype=float)
+
+    def set_uncertainty(self, unc: np.ndarray):
+        """
+        Set the uncertainty.
+
+        Args:
+          unc: The uncertainty.
+        """
+        self.unc = np.copy(unc)
+
+    def fit(self, X, y=None) -> TransformerMixin:
+        """
+        Does nothing.
+
+        Args:
+          X: Irrelevant.
+          y: Irrelevant.
+
+        Returns: Itself.
+        """
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+    def uncertainty(self) -> np.ndarray:
+        """The uncertainty recorded."""
+        return self.unc
+
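
UncertaintyHolder is deliberately a no-op transformer: it rides along in a
pipeline so that an uncertainty computed elsewhere is persisted and reloaded
together with the fitted steps. A short sketch:

    import numpy as np

    holder = UncertaintyHolder()
    holder.set_uncertainty(np.array([[0.1, 0.2, 0.3]]))
    holder.transform(np.ones((5, 3)))  # identity: returns the input unchanged
    holder.uncertainty()               # -> array([[0.1, 0.2, 0.3]])
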
+class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
+    """
+    Select only relevant entries in the low-resolution data.
 
     Args:
       channels: Selected channels to use as an input for the low resolution data.
-      n_pca_lr: Number of low-resolution data PCA components.
-      n_pca_hr: Number of high-resolution data PCA components.
-      high_res_sigma: Resolution of the high-resolution spectrometer in electron-Volts.
       tof_start: Start looking at this index from the low-resolution spectrometer data.
                  Set to None to perform no selection
       delta_tof: Number of components to take from the low-resolution spectrometer.
                  Set to None to perform no selection.
-      validation_size: Fraction (number between 0 and 1) of the data to take for
-                       validation and systematic uncertainty estimate.
-
     """
     def __init__(self,
                  channels:List[str]=[f"channel_{j}_{k}"
                                      for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
-                 n_pca_lr: int=600,
-                 n_pca_hr: int=20,
-                 high_res_sigma: float=0.2,
                  tof_start: Optional[int]=None,
                  delta_tof: Optional[int]=300,
-                 validation_size: float=0.05):
+                 ):
         self.channels = channels
-        self.n_pca_lr = n_pca_lr
-        self.n_pca_hr = n_pca_hr
-
-        # PCA models
-        self.lr_pca = PCA(n_pca_lr, whiten=True)
-        self.hr_pca = PCA(n_pca_hr, whiten=True)
-
-        # PCA unc. in high resolution
-        self.high_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
-        self.low_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
-
-        # fit model
-        self.fit_model = FitModel()
-
-        # size of the test subset
-        self.validation_size = validation_size
-
-        # where to cut on the ToF PES data
         self.tof_start = tof_start
         self.delta_tof = delta_tof
 
-        # high-resolution photon energy axis
-        self.high_res_photon_energy: Optional[np.ndarray] = None
-
-        # smoothing of the SPEC data in eV
-        self.high_res_sigma = high_res_sigma
-
-    def parameters(self) -> Dict[str, Any]:
-        """
-        Dump parameters as a dictionary.
-        """
-        return dict(channels=self.channels,
-                    n_pca_lr=self.n_pca_lr,
-                    n_pca_hr=self.n_pca_hr,
-                    high_res_sigma=self.high_res_sigma,
-                    tof_start=self.tof_start,
-                    delta_tof=self.delta_tof,
-                    validation_size=self.validation_size,
-                    high_pca_unc=self.high_pca_unc,
-                    low_pca_unc=self.low_pca_unc,
-                    high_res_photon_energy=self.high_res_photon_energy,
-                    )
-
-    def preprocess_low_res(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
+    def transform(self, X: Dict[str, np.ndarray]) -> np.ndarray:
         """
         Get a dictionary with the channel names for the input low resolution data and output
         only the relevant input data in an array.
 
         Args:
-          low_res_data: Dictionary with keys named channel_{i}_{k},
-                        where i is a number between 1 and 4 and k is a letter between A and D.
+          X: Dictionary with keys named channel_{i}_{k},
+             where i is a number between 1 and 4 and k is a letter between A and D.
 
         Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
         """
-        items = [low_res_data[k] for k in self.channels]
+        if self.tof_start is None:
+            raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
+        items = [X[k] for k in self.channels]
         if self.delta_tof is not None:
             items = [item[:, self.tof_start:(self.tof_start + self.delta_tof)] for item in items]
         else:
@@ -197,37 +259,17 @@ class Model(TransformerMixin, BaseEstimator):
         cat = np.concatenate(items, axis=1)
         return cat
 
-    def preprocess_high_res(self, high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
-        """
-        Get the high resolution data and preprocess it.
-
-        Args:
-          high_res_data: High resolution data with shape (train_id, features).
-          high_res_photon_energy: High resolution photon energy values
-                                  (the "x"-axis of the high resolution data) with
-                                  shape (train_id, features).
-
-        Returns: Pre-processed high-resolution data of shape (train_id, features) before.
-        """
-        # Apply smoothing
-        n_features = high_res_data.shape[1]
-        mu = high_res_photon_energy[:, n_features//2, np.newaxis]
-        gaussian = np.exp(-0.5*(high_res_photon_energy - mu)**2/self.high_res_sigma**2)
-        gaussian /= np.sum(gaussian, axis=1, keepdims=True)
-        high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)
-        return high_res_gc
-
-    def estimate_prompt_peak(self, low_res_data: Dict[str, np.ndarray]) -> int:
+    def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int:
         """
         Estimate the prompt peak index.
 
         Args:
-          low_res_data: Low resolution data with a dictionary containing the channel names.
+          X: Low resolution data as a dictionary keyed by channel name.
 
-        Returns: The prompt peak index.
+        Returns: The estimated prompt peak index.
         """
         # reduce on channel and on train ID
-        sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0)
+        sum_low_res = - np.mean(sum(list(X.values())), axis=0)
         widths = np.arange(10, 50, step=1)
         peak_idx = find_peaks_cwt(sum_low_res, widths)
         if len(peak_idx) < 1:
@@ -242,17 +284,30 @@ class Model(TransformerMixin, BaseEstimator):
         improved_guess = min_search + int(np.argmax(restricted_arr))
         return improved_guess
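
A condensed sketch of the steps above, assuming low_res_data is the usual
channel dictionary and at least one candidate peak is found; the +/-10
refinement window here is illustrative, the class uses its own search window:

    import numpy as np
    from scipy.signal import find_peaks_cwt

    sum_low_res = -np.mean(sum(low_res_data.values()), axis=0)  # average over trains
    peak_idx = find_peaks_cwt(sum_low_res, np.arange(10, 50))   # candidate peak positions
    first = int(peak_idx[0])
    lo = max(0, first - 10)
    improved = lo + int(np.argmax(sum_low_res[lo:first + 10]))
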
 
-    def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str):
+    def fit(self, X: Dict[str, np.ndarray], y: Optional[Any]=None) -> TransformerMixin:
+        """
+        Fit by estimating the prompt peak index and storing it as the start of the selection window.
+
+        Args:
+          X: Low resolution data as a dictionary keyed by channel name.
+          y: Ignored.
+
+        Returns: The object itself.
+        """
+        self.tof_start = self.estimate_prompt_peak(X)
+        return self
+
+    def debug_peak_finding(self, X: Dict[str, np.ndarray], filename: str):
         """
         Produce image to understand if the peak finding step worked well.
 
         Args:
-          low_res_data: Low resolution data with a dictionary containing the channel names.
+          X: Low resolution data as a dictionary keyed by channel name.
           filename: The file name where to save the plot.
 
         """
-        sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0)
-        peak_idx = self.estimate_prompt_peak(low_res_data)
+        sum_low_res = - np.mean(sum(list(X.values())), axis=0)
+        peak_idx = self.estimate_prompt_peak(X)
         fig = plt.figure(figsize=(8, 16))
         ax = plt.gca()
         ax.plot(np.arange(peak_idx-100, peak_idx+300),
@@ -271,178 +326,7 @@ class Model(TransformerMixin, BaseEstimator):
         plt.savefig(filename)
         plt.close(fig)
 
-    def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
-        """
-        Train the model.
-
-        Args:
-          low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`,
-                        where i is a number between 1 and 4 and k is a letter between A and D.
-                        For each dictionary entry, a numpy array is expected with shape
-                        (train_id, ToF channel).
-          high_res_data: Reference high resolution data with a one-to-one match to the
-                         low resolution data in the train_id dimension. Shape (train_id, ToF channel).
-          high_res_photon_energy: Photon energy axis for the high-resolution data.
-
-        Returns: Smoothened high resolution spectrum.
-        """
-
-        self.high_res_photon_energy = high_res_photon_energy[0, np.newaxis, :]
-
-        # if the prompt peak has not been given, guess it
-        if self.tof_start is None:
-            self.tof_start = self.estimate_prompt_peak(low_res_data)
-
-        low_res = self.preprocess_low_res(low_res_data)
-        high_res = self.preprocess_high_res(high_res_data, high_res_photon_energy)
-        # fit PCA
-        low_pca = self.lr_pca.fit_transform(low_res)
-        high_pca = self.hr_pca.fit_transform(high_res)
-        # split in train and test for PCA uncertainty evaluation
-        (low_pca_train, low_pca_test,
-         high_pca_train, high_pca_test) = train_test_split(low_pca, high_pca,
-                                                           test_size=self.validation_size,
-                                                           random_state=42)
-        # fit the linear model
-        self.fit_model.fit(low_pca_train,
-                           high_pca_train,
-                           low_pca_test,
-                           high_pca_test)
-
-        high_pca_rec = self.hr_pca.inverse_transform(high_pca)
-        self.high_pca_unc =  np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
-
-        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
-        self.low_pca_unc =  np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
-
-        return high_res
-
-    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> float:
-        """
-        Check if a new low-resolution data source is compatible with the one used in training, by
-        comparing the effect of the trained PCA model on it.
-
-        Args:
-          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
-
-        Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model.
-        """
-        low_res = self.preprocess_low_res(low_res_data)
-        low_pca = self.lr_pca.transform(low_res)
-        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
-
-        #fig = plt.figure(figsize=(8, 16))
-        #ax = plt.gca()
-        #ax.plot(low_res[0,...],
-        #        c="b",
-        #        label="LR")
-        #ax.plot(low_pca_rec[0,...],
-        #        c="r",
-        #        label="LR rec.")
-        #ax.set(title="",
-        #       xlabel="Photon Spectrometer channel",
-        #       ylabel="Low resolution spectrometer intensity")
-        #ax.legend()
-        #plt.savefig("check.png")
-        #plt.close(fig)
-
-        low_pca_unc =  np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True))
-        return low_pca_unc/self.low_pca_unc
-
-
-    def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
-        """
-        Predict a high-resolution spectrum from a low resolution given one.
-        The output includes the uncertainty in its second and third entries of the first dimension.
-
-        Args:
-          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
-
-        Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
-                 the expected prediction in key "expected", the stat. uncertainty in key "unc" and
-                 a (1, energy channel) array for the PCA syst. uncertainty in key "pca".
-        """
-        low_res = self.preprocess_low_res(low_res_data)
-        low_pca = self.lr_pca.transform(low_res)
-        # Get high res.
-        high_pca = self.fit_model.predict(low_pca)
-        n_trains = low_pca.shape[0]
-        pca_y = np.concatenate((high_pca["Y"],
-                                high_pca["Y"] + high_pca["Y_eps"]),
-                               axis=0)
-        high_res_predicted = self.hr_pca.inverse_transform(pca_y)
-        expected = high_res_predicted[:n_trains, :]
-        unc = high_res_predicted[n_trains:, :] - expected
-        return dict(expected=expected,
-                    unc=unc,
-                    pca=self.high_pca_unc)
-
-    def save(self, filename: str):
-        """
-        Save the fit model in a file.
-
-        Args:
-          filename: H5 file name where to save this.
-        """
-        #joblib.dump(self, filename)
-        with h5py.File(filename, 'w') as hf:
-            # transform parameters into a dict
-            d = self.fit_model.as_dict()
-            d.update(self.parameters())
-            # dump them in the file
-            dump_in_group(d, hf)
-            # this is not ideal, because it depends on the knowledge of the PCA
-            # object structure, but saving to a joblib file would mean creating several
-            # files
-            # create a group
-            lr_pca = hf.create_group("lr_pca")
-            # get PCA properties
-            lr_pca_props = get_pca_props(self.lr_pca)
-            # create the HR group
-            hr_pca = hf.create_group("hr_pca")
-            # get PCA properties
-            hr_pca_props = get_pca_props(self.hr_pca)
-            # dump them
-            dump_in_group(lr_pca_props, lr_pca)
-            dump_in_group(hr_pca_props, hr_pca)
-
-
-    def load(self, filename: str):
-        """
-        Load model from a file.
-
-        Args:
-          filename: Name of the file where to read the model from.
-
-        """
-        with h5py.File(filename, 'r') as hf:
-            # read from file
-            d = read_from_group(hf)
-            # load fit_model parameters
-            self.fit_model.from_dict(d)
-            # load parameters of this class
-            for key in self.parameters().keys():
-                value = d[key]
-                if key == 'channels':
-                    value = [item.decode() if isinstance(item, bytes)
-                             else item
-                             for item in value]
-                setattr(self, key, value)
-            # this is not ideal, because it depends on the knowledge of the PCA
-            # object structure, but saving to a joblib file would mean creating several
-            # files
-            lr_pca = hf["/lr_pca/"]
-            hr_pca = hf["/hr_pca/"]
-            self.lr_pca = PCA(self.n_pca_lr, whiten=True)
-            self.hr_pca = PCA(self.n_pca_hr, whiten=True)
-            # read properties in dictionaries
-            lr_pca_props = read_from_group(lr_pca)
-            hr_pca_props = read_from_group(hr_pca)
-            # set them
-            self.lr_pca = set_pca_props(self.lr_pca, lr_pca_props)
-            self.hr_pca = set_pca_props(self.hr_pca, hr_pca_props)
-
-class FitModel(object):
+class FitModel(RegressorMixin, BaseEstimator):
     """
     Linear regression model with uncertainties.
     """
@@ -462,14 +346,27 @@ class FitModel(object):
 
         self.input_data = None
 
-    def fit(self, X_train: np.ndarray, Y_train: np.ndarray, X_test: np.ndarray, Y_test: np.ndarray):
+    def fit(self, X: np.ndarray, y: np.ndarray, **fit_params) -> RegressorMixin:
         """
         Perform the fit and evaluate uncertainties with the test set.
+
+        Args:
+          X: The input.
+          y: The target.
+          fit_params: If it contains X_test and y_test, they are used to validate the model.
+
+        Returns: The object itself.
         """
+        if 'X_test' in fit_params and 'y_test' in fit_params:
+            X_test = fit_params['X_test']
+            y_test = fit_params['y_test']
+        else:
+            X_test = X
+            y_test = y
 
         # model parameter sizes
-        self.Nx: int = int(X_train.shape[1])
-        self.Ny: int = int(Y_train.shape[1])
+        self.Nx: int = int(X.shape[1])
+        self.Ny: int = int(y.shape[1])
 
         # initial parameter values
         A0: np.ndarray = np.eye(self.Nx, self.Ny).reshape(self.Nx*self.Ny)
@@ -515,8 +412,8 @@ class FitModel(object):
 
             Returns: The loss value.
             """
-            l_train = loss(x, X_train, Y_train)
-            l_test = loss(x, X_test, Y_test)
+            l_train = loss(x, X, y)
+            l_test = loss(x, X_test, y_test)
 
             self.loss_train += [l_train]
             self.loss_test += [l_test]
@@ -531,7 +428,7 @@ class FitModel(object):
 
             Returns: The loss value.
             """
-            l_train = loss(x, X_train, Y_train)
+            l_train = loss(x, X, y)
             return l_train
 
         grad_loss = grad(loss_train)
@@ -603,3 +500,188 @@ class FitModel(object):
         result["Y_eps"] = np.exp(X @ self.A_eps + result["Y_unc"])
         return result
 
+
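
FitModel.fit now accepts an optional validation split through its fit
parameters; without one, the training set doubles as the validation set used
for the uncertainty estimate. A sketch with placeholder arrays:

    fm = FitModel()
    fm.fit(X_train, y_train, X_test=X_val, y_test=y_val)  # explicit validation split
    fm.fit(X_train, y_train)                              # validates on the training data
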
+class Model(TransformerMixin, BaseEstimator):
+    """
+    Object representing a previous fit of the model, to be used to predict a
+    high-resolution spectrum from a low-resolution one.
+
+    Args:
+      channels: Selected channels to use as an input for the low resolution data.
+      n_pca_lr: Number of low-resolution data PCA components.
+      n_pca_hr: Number of high-resolution data PCA components.
+      high_res_sigma: Resolution of the high-resolution spectrometer in electron-Volts.
+      tof_start: Start looking at this index from the low-resolution spectrometer data.
+                 Set to None to perform no selection
+      delta_tof: Number of components to take from the low-resolution spectrometer.
+                 Set to None to perform no selection.
+      validation_size: Fraction (number between 0 and 1) of the data to take for
+                       validation and systematic uncertainty estimate.
+
+    """
+    def __init__(self,
+                 channels:List[str]=[f"channel_{j}_{k}"
+                                     for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
+                 n_pca_lr: int=600,
+                 n_pca_hr: int=20,
+                 high_res_sigma: float=0.2,
+                 tof_start: Optional[int]=None,
+                 delta_tof: Optional[int]=300,
+                 validation_size: float=0.05):
+        # models
+        self.x_model = Pipeline([
+                                ('select', SelectRelevantLowResolution(channels, tof_start, delta_tof)),
+                                ('pca', PCA(n_pca_lr, whiten=True)),
+                                ('unc', UncertaintyHolder()),
+                               ])
+        self.y_model = Pipeline([('smoothen', HighResolutionSmoother(high_res_sigma)),
+                                ('pca', PCA(n_pca_hr, whiten=True)),
+                                ('unc', UncertaintyHolder()),
+                                ])
+        self.fit_model = FitModel()
+
+        # size of the test subset
+        self.validation_size = validation_size
+
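
The two pipelines rely on standard scikit-learn mechanics: fit keyword
arguments of the form step__param are routed to the fit method of the named
step, and fitted steps are retrievable by name. A sketch of both, as used
further down in fit():

    # routed to HighResolutionSmoother.fit(..., energy=...) inside y_model
    y_t = y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
    smoother = y_model['smoothen']  # access a fitted step by name
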
+    def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str):
+        """
+        Produce image to understand if the peak finding step worked well.
+
+        Args:
+          low_res_data: Low resolution data as a dictionary keyed by channel name.
+          filename: The file name where to save the plot.
+
+        """
+        self.x_model['select'].debug_peak_finding(low_res_data, filename)
+
+    def preprocess_high_res(self, high_res_data: np.ndarray) -> np.ndarray:
+        """
+        Preprocess high-resolution data to remove high-frequency components.
+
+        Args:
+          high_res_data: The high-resolution data.
+
+        Returns: Smoothened spectrum.
+        """
+        return self.y_model['smoothen'].transform(high_res_data)
+
+    def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
+        """
+        Train the model.
+
+        Args:
+          low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`,
+                        where i is a number between 1 and 4 and k is a letter between A and D.
+                        For each dictionary entry, a numpy array is expected with shape
+                        (train_id, ToF channel).
+          high_res_data: Reference high resolution data with a one-to-one match to the
+                         low resolution data in the train_id dimension. Shape (train_id, ToF channel).
+          high_res_photon_energy: Photon energy axis for the high-resolution data.
+
+        Returns: Smoothened high resolution spectrum.
+        """
+        x_t = self.x_model.fit_transform(low_res_data)
+        y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
+        self.fit_model.fit(x_t, y_t)
+
+        # calculate the effect of the PCA
+        high_res = self.y_model['smoothen'].transform(high_res_data)
+        high_pca = self.y_model.transform(high_res_data)
+        high_pca_rec = self.y_model['pca'].inverse_transform(high_pca)
+        high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
+        self.y_model['unc'].set_uncertainty(high_pca_unc)
+
+        low_res = self.x_model['select'].transform(low_res_data)
+        low_pca = self.x_model['pca'].transform(low_res)
+        low_pca_rec = self.x_model['pca'].inverse_transform(low_pca)
+        low_pca_unc =  np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
+        self.x_model['unc'].set_uncertainty(low_pca_unc)
+
+        return high_res
+
+    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> float:
+        """
+        Check if a new low-resolution data source is compatible with the one used in training, by
+        comparing the effect of the trained PCA model on it.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+
+        Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model.
+        """
+        low_res = self.x_model['select'].transform(low_res_data)
+        low_pca = self.x_model['pca'].transform(low_res)
+        low_pca_rec = self.x_model['pca'].inverse_transform(low_pca)
+        low_pca_unc = self.x_model['unc'].uncertainty()
+
+        #fig = plt.figure(figsize=(8, 16))
+        #ax = plt.gca()
+        #ax.plot(low_res[0,...],
+        #        c="b",
+        #        label="LR")
+        #ax.plot(low_pca_rec[0,...],
+        #        c="r",
+        #        label="LR rec.")
+        #ax.set(title="",
+        #       xlabel="Photon Spectrometer channel",
+        #       ylabel="Low resolution spectrometer intensity")
+        #ax.legend()
+        #plt.savefig("check.png")
+        #plt.close(fig)
+
+        low_pca_dev = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True))
+        return low_pca_dev/low_pca_unc
+
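
The returned quantity is a per-train ratio of the PCA reconstruction error on
the new data to the error recorded at training time, so values near 1 indicate
a compatible source. A usage sketch with an illustrative threshold:

    import numpy as np

    ratio = model.check_compatibility(new_low_res_data)  # shape (train_id, 1)
    if np.any(ratio > 2.0):
        print("Low-resolution data deviates from the training conditions")
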
+
+    def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        """
+        Predict a high-resolution spectrum from a given low-resolution one.
+        The returned dictionary also contains the statistical and PCA uncertainties.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+
+        Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
+                 the expected prediction in key "expected", the stat. uncertainty in key "unc" and
+                 a (1, energy channel) array for the PCA syst. uncertainty in key "pca".
+        """
+        low_pca = self.x_model.transform(low_res_data)
+        high_pca = self.fit_model.predict(low_pca)
+        n_trains = high_pca["Y"].shape[0]
+        pca_y = np.concatenate((high_pca["Y"],
+                                high_pca["Y"] + high_pca["Y_eps"]),
+                               axis=0)
+        high_res_predicted = self.y_model.inverse_transform(pca_y)
+        expected = high_res_predicted[:n_trains, :]
+        unc = high_res_predicted[n_trains:, :] - expected
+        return dict(expected=expected,
+                    unc=unc,
+                    pca=self.y_model['unc'].uncertainty())
+
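
A sketch of consuming the prediction, combining the statistical and PCA
uncertainties in quadrature (a reasonable convention, not one the code
prescribes):

    import numpy as np

    pred = model.predict(low_res_data)
    total_unc = np.sqrt(pred["unc"]**2 + pred["pca"]**2)  # (1, energy) broadcasts over trains
    band_lo = pred["expected"] - total_unc
    band_hi = pred["expected"] + total_unc
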
+    def save(self, filename: str):
+        """
+        Save the fit model in a file.
+
+        Args:
+          filename: File name where to save this.
+        """
+        joblib.dump([self.x_model,
+                     self.y_model,
+                     self.fit_model],
+                    filename)
+
+    @staticmethod
+    def load(filename: str) -> 'Model':
+        """
+        Load model from a file.
+
+        Args:
+          filename: Name of the file where to read the model from.
+
+        """
+        x_model, y_model, fit_model = joblib.load(filename)
+        obj = Model()
+        obj.x_model = x_model
+        obj.y_model = y_model
+        obj.fit_model = fit_model
+        return obj
+
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index b1a155c..66cc778 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -139,18 +139,17 @@ def main():
               spec_raw_pe[train_idx, :])
     t += [time_ns() - start]
     t_names += ["Fit"]
-    spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
+    spec_smooth = model.preprocess_high_res(spec_raw_int)
 
     print("Saving the model")
     start = time_ns()
-    model.save("model.h5")
+    model.save("model.joblib")
     t += [time_ns() - start]
     t_names += ["Save"]
 
     print("Loading the model")
     start = time_ns()
-    model = Model()
-    model.load("model.h5")
+    model = Model.load("model.joblib")
     t += [time_ns() - start]
     t_names += ["Load"]
 
-- 
GitLab