From 02eded28b5db4cc617eb1507754c2f5557a9cb8f Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 19 Dec 2022 18:59:30 +0100
Subject: [PATCH] Save PCA in the same file.
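
Store the low- and high-resolution PCA decompositions as groups
("lr_pca" and "hr_pca") inside the same HDF5 file as the rest of the
fit model, instead of dumping them to separate joblib files.
Model.save() and Model.load() now take a single filename.

A minimal usage sketch (the "model.h5" filename is illustrative and
follows scripts/test_analysis.py; the import path is assumed from the
package layout):

    from pes_to_spec.model import Model

    model = Model()
    # ... fit the model on training data ...
    model.save("model.h5")     # writes fit parameters and both PCA groups

    restored = Model()
    restored.load("model.h5")  # rebuilds both PCA objects from the same file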

---
 pes_to_spec/model.py     | 94 +++++++++++++++++++++++++++++++++-------
 scripts/test_analysis.py |  8 ++++
 2 files changed, 86 insertions(+), 16 deletions(-)
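
Note: array-valued PCA properties are written as datasets and scalar
properties as HDF5 attributes under the "lr_pca" and "hr_pca" groups.
A quick way to inspect what was written (a sketch, assuming the model
was saved to "model.h5" as in scripts/test_analysis.py):

    import h5py

    with h5py.File("model.h5", "r") as hf:
        for name in ("lr_pca", "hr_pca"):
            grp = hf[name]
            # array-valued properties (components_, mean_, ...) are datasets
            for key, dset in grp.items():
                print(f"{name}/{key}: shape={dset.shape}")
            # scalar properties (n_components_, ...) are group attributes
            for key, value in grp.attrs.items():
                print(f"{name}@{key} = {value}")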

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index fa88929..2f73ec2 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -1,7 +1,7 @@
 import numpy as np
 from autograd import numpy as anp
 from autograd import grad
-import joblib
+#import joblib
 import h5py
 from scipy.signal import fftconvolve
 from scipy.signal import find_peaks_cwt
@@ -9,17 +9,67 @@ from scipy.optimize import fmin_l_bfgs_b
 from sklearn.decomposition import PCA, IncrementalPCA
 from sklearn.model_selection import train_test_split
 
-from time import time_ns
-
 import matplotlib.pyplot as plt
 
-from typing import Any, Dict, List, Optional
+from typing import Union, Any, Dict, List, Optional
 
 def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     """Returns list of train IDs common to sets a, b and c."""
     unique_ids = list(set(a).intersection(b).intersection(c))
     return np.array(unique_ids)
 
+def save_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group):
+    """
+    Save a PCA object into an H5 group.
+
+    Args:
+      pca_obj: The fitted PCA instance to save.
+      pca_group: The H5 group to write the PCA data into.
+
+    """
+
+    props = ["components_",
+             "explained_variance_",
+             "explained_variance_ratio_",
+             "singular_values_",
+             "mean_"]
+    attrs = ["n_components_",
+             "n_features_",
+             "n_samples_",
+             "noise_variance_",
+             "n_features_in_"]
+    for p in props:
+        pca_group.create_dataset(p, data=getattr(pca_obj, p))
+    for a in attrs:
+        pca_group.attrs[a] = getattr(pca_obj, a)
+
+def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Union[IncrementalPCA, PCA]:
+    """
+    Load a PCA object from an H5 group.
+
+    Args:
+      pca_obj: The PCA instance to load the data into.
+      pca_group: The H5 group to read the PCA data from.
+
+    Returns: Updated PCA instance.
+    """
+
+    props = ["components_",
+             "explained_variance_",
+             "explained_variance_ratio_",
+             "singular_values_",
+             "mean_"]
+    attrs = ["n_components_",
+             "n_features_",
+             "n_samples_",
+             "noise_variance_",
+             "n_features_in_"]
+    for p in props:
+        setattr(pca_obj, p, pca_group[p][()])
+    for a in attrs:
+        setattr(pca_obj, a, pca_group.attrs[a])
+    return pca_obj
+
 class PromptNotFoundError(Exception):
     """
     Exception representing the error condition generated by not finding the prompt peak.
@@ -238,14 +288,12 @@ class Model(object):
                            axis=2)
         return result
 
-    def save(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
+    def save(self, filename: str):
         """
         Save the fit model in a file.
 
         Args:
           filename: H5 file name where to save this.
-          lr_pca_filename: Name of the file where to save the low-resolution PCA decomposition.
-          hr_pca_filename: Name of the file where to save the high-resolution PCA decomposition.
         """
         with h5py.File(filename, 'w') as hf:
             d = self.fit_model.as_dict()
@@ -255,18 +303,23 @@ class Model(object):
                     hf.attrs[key] = value
                 else:
                     hf.create_dataset(key, data=value)
-        joblib.dump(self.lr_pca, lr_pca_filename)
-        joblib.dump(self.hr_pca, hr_pca_filename)
-
-
-    def load(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
+            # This is not ideal, because it relies on knowledge of the PCA
+            # objects' internal structure, but saving them with joblib would
+            # require separate files.
+            lr_pca = hf.create_group("lr_pca")
+            hr_pca = hf.create_group("hr_pca")
+            save_pca(self.lr_pca, lr_pca)
+            save_pca(self.hr_pca, hr_pca)
+        #joblib.dump(self.lr_pca, lr_pca_filename)
+        #joblib.dump(self.hr_pca, hr_pca_filename)
+
+
+    def load(self, filename: str):
         """
         Load model from a file.
 
         Args:
           filename: Name of the file where to read the model from.
-          lr_pca_filename: Name of the file from where to load the low-resolution PCA decomposition.
-          hr_pca_filename: Name of the file from where to load the high-resolution PCA decomposition.
 
         """
         with h5py.File(filename, 'r') as hf:
@@ -275,8 +328,17 @@ class Model(object):
             self.fit_model.from_dict(d)
             for key in self.parameters().keys():
                 setattr(self, key, d[key])
-        self.lr_pca = joblib.load(lr_pca_filename)
-        self.hr_pca = joblib.load(hr_pca_filename)
+            # This is not ideal, because it relies on knowledge of the PCA
+            # objects' internal structure, but saving them with joblib would
+            # require separate files.
+            lr_pca = hf["/lr_pca/"]
+            hr_pca = hf["/hr_pca/"]
+            self.lr_pca = IncrementalPCA(self.n_pca_lr)
+            self.hr_pca = PCA(self.n_pca_hr)
+            self.lr_pca = load_pca(self.lr_pca, lr_pca)
+            self.hr_pca = load_pca(self.hr_pca, hr_pca)
+        #self.lr_pca = joblib.load(lr_pca_filename)
+        #self.hr_pca = joblib.load(hr_pca_filename)
 
 class FitModel(object):
     """
diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py
index 3d848dc..e6cd510 100755
--- a/scripts/test_analysis.py
+++ b/scripts/test_analysis.py
@@ -121,10 +121,18 @@ def main():
     model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :])
     spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
 
+    print("Saving the model")
+    model.save("model.h5")
+
+    print("Loading the model")
+    model = Model()
+    model.load("model.h5")
+
     # test
     print("Predict")
     spec_pred = model.predict(pes_raw)
 
+    print("Plotting")
     # plot
     for tid in test_tids:
         idx = np.where(tid==tids)[0][0]
-- 
GitLab