From fe96d70ea1996736d446a91a7bfa98a12cadf189 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 21 Dec 2022 12:19:11 +0100
Subject: [PATCH] Fixed saving the model into an HDF5 file and loading it back.

---
 pes_to_spec/model.py                 | 178 ++++++++++++++++++---
 pes_to_spec/test/offline_analysis.py |  37 ++++--
 2 files changed, 143 insertions(+), 72 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 8e5d701..aac55a5 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -1,13 +1,13 @@
 import numpy as np
 from autograd import numpy as anp
 from autograd import grad
-#import joblib
 import h5py
 from scipy.signal import fftconvolve
 from scipy.signal import find_peaks_cwt
 from scipy.optimize import fmin_l_bfgs_b
 from sklearn.decomposition import PCA, IncrementalPCA
 from sklearn.model_selection import train_test_split
+from sklearn.base import TransformerMixin, BaseEstimator
 
 import matplotlib.pyplot as plt
 
@@ -18,13 +18,41 @@ def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     unique_ids = list(set(a).intersection(b).intersection(c))
     return np.array(unique_ids)
 
-def save_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group):
+def dump_in_group(content: Dict[str, Any], hf: h5py.Group):
     """
-    Save a PCA object to an H5 file.
+    Write a dictionary into an HDF5 group.
+
+    Args:
+      content: The content to be written.
+      hf: The HDF5 group.
+    """
+    for key, value in content.items():
+        # scalars and lists of scalars/strings fit into HDF5 attributes;
+        # everything else (typically a numpy array) becomes a compressed dataset
+        if isinstance(value, (int, float, list)):
+            hf.attrs[key] = value
+        else:
+            hf.create_dataset(key, data=value, compression="gzip")
+
+def read_from_group(hf: h5py.Group) -> Dict[str, Any]:
+    """
+    Read dictionary from an HDF5 group.
+
+    Args:
+      hf: The HDF5 group.
+
+    Returns: The content read from the group.
+    """
+    d = {k: hf[k][()] for k in hf.keys() if not isinstance(hf[k], h5py.Group)}
+    d.update({k: hf.attrs[k] for k in hf.attrs})
+    return d
+
+def get_pca_props(pca_obj: Union[IncrementalPCA, PCA]):
+    """
+    Get PCA properties and attributes.
 
     Args:
-      pca_obj: An instance where to load data to.
+      pca_obj: The fitted PCA instance to read the properties from.
-      pca_group: An H5 file group.
 
     """
 
@@ -32,25 +60,23 @@ def save_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group):
              "explained_variance_",
              "explained_variance_ratio_",
              "singular_values_",
-             "mean_"]
-    attrs = ["n_components_",
-             #"n_features_",
-             #"n_samples_",
+             "mean_",
+             "n_components_",
+             "n_features_",
+             "n_samples_",
              "noise_variance_",
-             #"n_features_in_"
+             "n_features_in_"
              ]
-    for p in props:
-        pca_group.create_dataset(p, data=getattr(pca_obj, p))
-    for a in attrs:
-        pca_group.attrs[a] = getattr(pca_obj, a)
+    return {p: getattr(pca_obj, p)
+            for p in props}
 
-def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Union[IncrementalPCA, PCA]:
+def set_pca_props(pca_obj: Union[IncrementalPCA, PCA], pca_props: Dict[str, Any]) -> Union[IncrementalPCA, PCA]:
     """
-    Load a PCA object from an H5 file.
+    Set the properties of a PCA object from a dictionary.
 
     Args:
-      pca_obj: An instance where to load data to.
+      pca_obj: The PCA instance to load the properties into.
-      pca_group: An H5 file group.
+      pca_props: A dictionary of previously saved PCA properties.
 
     Returns: Updated PCA instance.
     """
@@ -59,17 +85,15 @@ def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Unio
              "explained_variance_",
              "explained_variance_ratio_",
              "singular_values_",
-             "mean_"]
-    attrs = ["n_components_",
-             #"n_features_",
-             #"n_samples_",
+             "mean_",
+             "n_components_",
+             "n_features_",
+             "n_samples_",
              "noise_variance_",
-             #"n_features_in_"
+             "n_features_in_"
              ]
     for p in props:
-        setattr(pca_obj, p, pca_group[p][()])
-    for a in attrs:
-        setattr(pca_obj, a, pca_group.attrs[a])
+        setattr(pca_obj, p, pca_props[p])
     return pca_obj
 
 class PromptNotFoundError(Exception):
@@ -81,7 +105,7 @@ class PromptNotFoundError(Exception):
     def __str__(self) -> str:
         return "No prompt peak has been detected."
 
-class Model(object):
+class Model(TransformerMixin, BaseEstimator):
     """
     Object representing a previous fit of the model to be used to predict high-resolution
     spectrum from a low-resolution one.
@@ -91,9 +115,12 @@ class Model(object):
       n_pca_lr: Number of low-resolution data PCA components.
       n_pca_hr: Number of high-resolution data PCA components.
       high_res_sigma: Resolution of the high-resolution spectrometer in electron-Volts.
-      tof_start: Start looking at this index from the low-resolution spectrometer data. Set to None to perform no selection
-      delta_tof: Number of components to take from the low-resolution spectrometer. Set to None to perform no selection.
-      validation_size: Fraction (number between 0 and 1) of the data to take for validation and systematic uncertainty estimate.
+      tof_start: Start looking at this index from the low-resolution spectrometer data.
+                 Set to None to perform no selection.
+      delta_tof: Number of components to take from the low-resolution spectrometer.
+                 Set to None to perform no selection.
+      validation_size: Fraction (number between 0 and 1) of the data to take for
+                       validation and systematic uncertainty estimate.
 
     """
     def __init__(self,
@@ -114,7 +141,7 @@ class Model(object):
         self.n_pca_hr = n_pca_hr
 
         # PCA models
-        self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True)
+        self.lr_pca = PCA(n_pca_lr, whiten=True)
         self.hr_pca = PCA(n_pca_hr, whiten=True)
 
         # PCA unc. in high resolution
@@ -146,7 +173,10 @@ class Model(object):
                     high_res_sigma=self.high_res_sigma,
                     tof_start=self.tof_start,
                     delta_tof=self.delta_tof,
-                    validation_size=self.validation_size)
+                    validation_size=self.validation_size,
+                    high_pca_unc=self.high_pca_unc,
+                    high_res_photon_energy=self.high_res_photon_energy,
+                    )
 
     def preprocess_low_res(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
         """
@@ -154,7 +184,8 @@ class Model(object):
         only the relevant input data in an array.
 
         Args:
-          low_res_data: Dictionary with keys named channel_{i}_{k}, where i is a number between 1 and 4 and k is a letter between A and D.
+          low_res_data: Dictionary with keys named channel_{i}_{k},
+                        where i is a number between 1 and 4 and k is a letter between A and D.
 
         Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
         """
@@ -172,7 +203,9 @@ class Model(object):
 
         Args:
           high_res_data: High resolution data with shape (train_id, features).
-          high_res_photon_energy: High resolution photon energy values (the "x"-axis of the high resolution data) with shape (train_id, features).
+          high_res_photon_energy: High resolution photon energy values
+                                  (the "x"-axis of the high resolution data) with
+                                  shape (train_id, features).
 
-        Returns: Pre-processed high-resolution data of shape (train_id, features) before.
+        Returns: Pre-processed high-resolution data of shape (train_id, features).
         """
@@ -200,7 +233,14 @@ class Model(object):
         if len(peak_idx) < 1:
             raise PromptNotFoundError()
         peak_idx = sorted(peak_idx, key=lambda k: np.fabs(sum_low_res[k]), reverse=True)
-        return peak_idx[0]
+        best_guess = int(peak_idx[0])
+        # refine the estimate: look for the actual maximum of the summed trace
+        # within a +-10-sample window (this refinement may well be redundant)
+        min_search = max(best_guess - 10, 0)
+        max_search = min(best_guess + 10, len(sum_low_res))
+        restricted_arr = sum_low_res[min_search:max_search]
+        improved_guess = min_search + int(np.argmax(restricted_arr))
+        return improved_guess
 
     def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str):
         """
@@ -236,8 +276,12 @@ class Model(object):
         Train the model.
 
         Args:
-          low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`, where i is a number between 1 and 4 and k is a letter between A and D. For each dictionary entry, a numpy array is expected with shape (train_id, ToF channel).
-          high_res_data: Reference high resolution data with a one-to-one match to the low resolution data in the train_id dimension. Shape (train_id, ToF channel).
+          low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`,
+                        where i is a number between 1 and 4 and k is a letter between A and D.
+                        For each dictionary entry, a numpy array is expected with shape
+                        (train_id, ToF channel).
+          high_res_data: Reference high resolution data with a one-to-one match to the
+                         low resolution data in the train_id dimension. Shape (train_id, ToF channel).
           high_res_photon_energy: Photon energy axis for the high-resolution data.
 
         Returns: Smoothened high resolution spectrum.
@@ -272,7 +316,9 @@ class Model(object):
         Args:
           low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
 
-        Returns: High resolution data with shape (train_id, ToF channel, 3). The component 0 of the last dimension is the predicted spectrum. Components 1 and 2 correspond to two sources of uncertainty.
+        Returns: High resolution data with shape (train_id, ToF channel, 3).
+                 The component 0 of the last dimension is the predicted spectrum.
+                 Components 1 and 2 correspond to two sources of uncertainty.
         """
         low_res = self.preprocess_low_res(low_res_data)
         low_pca = self.lr_pca.transform(low_res)
@@ -297,23 +343,27 @@ class Model(object):
         Args:
           filename: H5 file name where to save this.
         """
+        # everything is serialized into a single HDF5 file instead of joblib
         with h5py.File(filename, 'w') as hf:
+            # transform parameters into a dict
             d = self.fit_model.as_dict()
             d.update(self.parameters())
-            for key, value in d.items():
-                if isinstance(value, int):
-                    hf.attrs[key] = value
-                else:
-                    hf.create_dataset(key, data=value)
+            # dump them in the file
+            dump_in_group(d, hf)
             # this is not ideal, because it depends on the knowledge of the PCA
             # object structure, but saving to a joblib file would mean creating several
             # files
+            # create a group
             lr_pca = hf.create_group("lr_pca")
+            # get PCA properties
+            lr_pca_props = get_pca_props(self.lr_pca)
+            # create the HR group
             hr_pca = hf.create_group("hr_pca")
-            save_pca(self.lr_pca, lr_pca)
-            save_pca(self.hr_pca, hr_pca)
-        #joblib.dump(self.lr_pca, lr_pca_filename)
-        #joblib.dump(self.hr_pca, hr_pca_filename)
+            # get PCA properties
+            hr_pca_props = get_pca_props(self.hr_pca)
+            # dump them
+            dump_in_group(lr_pca_props, lr_pca)
+            dump_in_group(hr_pca_props, hr_pca)
 
 
     def load(self, filename: str):
@@ -325,22 +375,31 @@ class Model(object):
 
         """
         with h5py.File(filename, 'r') as hf:
-            d = {k: hf[k][()] for k in hf.keys() if not isinstance(hf[k], h5py.Group)}
-            d.update({k: hf.attrs[k] for k in hf.attrs})
+            # read from file
+            d = read_from_group(hf)
+            # load fit_model parameters
             self.fit_model.from_dict(d)
+            # load parameters of this class
             for key in self.parameters().keys():
-                setattr(self, key, d[key])
+                value = d[key]
+                if key == 'channels':
+                    value = [item.decode() if isinstance(item, bytes)
+                             else item
+                             for item in value]
+                setattr(self, key, value)
             # this is not ideal, because it depends on the knowledge of the PCA
             # object structure, but saving to a joblib file would mean creating several
             # files
             lr_pca = hf["/lr_pca/"]
             hr_pca = hf["/hr_pca/"]
-            self.lr_pca = IncrementalPCA(self.n_pca_lr, whiten=True)
+            self.lr_pca = PCA(self.n_pca_lr, whiten=True)
             self.hr_pca = PCA(self.n_pca_hr, whiten=True)
-            self.lr_pca = load_pca(self.lr_pca, lr_pca)
-            self.hr_pca = load_pca(self.hr_pca, hr_pca)
-        #self.lr_pca = joblib.load(lr_pca_filename)
-        #self.hr_pca = joblib.load(hr_pca_filename)
+            # read properties in dictionaries
+            lr_pca_props = read_from_group(lr_pca)
+            hr_pca_props = read_from_group(hr_pca)
+            # set them
+            self.lr_pca = set_pca_props(self.lr_pca, lr_pca_props)
+            self.hr_pca = set_pca_props(self.hr_pca, hr_pca_props)
 
 class FitModel(object):
     """
@@ -388,8 +447,8 @@ class FitModel(object):
         self.Y_test: np.ndarray = Y_test
 
         # model parameter sizes
-        self.Nx: int = self.X_train.shape[1]
-        self.Ny: int = self.Y_train.shape[1]
+        self.Nx: int = int(self.X_train.shape[1])
+        self.Ny: int = int(self.Y_train.shape[1])
 
         # initial parameter values
         A0: np.ndarray = np.eye(self.Nx, self.Ny).reshape(self.Nx*self.Ny)
@@ -424,9 +483,6 @@ class FitModel(object):
             #log_unc = anp.log(anp.exp(log_unc) + anp.exp(log_eps))
             iunc2 = anp.exp(-2*log_unc)
 
-            #print("iunc2", iunc2)
-            #print("log_unc", log_unc)
-
             return anp.mean( (0.5*((X@ A + b - Y)**2)*iunc2 + log_unc).sum(axis=1), axis=0 ) # Put RELU on (XX@x) and introduce new matrix W
 
         def loss_history(x: np.ndarray) -> float:
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index e6cd510..62fb6b3 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -78,9 +78,12 @@ def main():
     # get train IDs and match them, so we are sure to have information from all needed sources
     # in this example, there is an offset of -2 in the SPEC train ID, so correct for it
     spec_offset = -2
-    spec_tid = spec_offset + run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.trainId"].ndarray()
-    pes_tid = run['SA3_XTD10_PES/ADC/1:network', f"digitizers.trainId"].ndarray()
-    xgm_tid = run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.trainId"].ndarray()
+    spec_tid = spec_offset + run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output',
+                                 "data.trainId"].ndarray()
+    pes_tid = run['SA3_XTD10_PES/ADC/1:network',
+                  "digitizers.trainId"].ndarray()
+    xgm_tid = run['SA3_XTD10_XGM/XGM/DOOCS:output',
+                  "data.trainId"].ndarray()
     # these are the train ID intersection
     # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
     tids = matching_ids(spec_tid, pes_tid, xgm_tid)
@@ -88,16 +91,21 @@ def main():
     test_tids = tids[-10:]
 
     # read the spec photon energy and intensity
-    spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
-    spec_raw_int = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
+    spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output',
+                      "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
+    spec_raw_int = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output',
+                       "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
 
     # read the PES data for each channel
-    channels = [f"channel_{i}_{l}" for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
-    pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels}
+    channels = [f"channel_{i}_{l}"
+                for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
+    pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network',
+                   f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray()
+               for ch in channels}
 
     # read the XGM information
-    #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', f"pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
-    #xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
+    #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', "pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
+    #xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', "data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
     #retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
     #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
 
@@ -118,7 +126,10 @@ def main():
 
     model.debug_peak_finding(pes_raw, "test_peak_finding.png")
     print("Fitting")
-    model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :])
+    model.fit({k: v[train_idx, :]
+               for k, v in pes_raw.items()},
+              spec_raw_int[train_idx, :],
+              spec_raw_pe[train_idx, :])
     spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
 
     print("Saving the model")
@@ -136,7 +147,11 @@ def main():
     # plot
     for tid in test_tids:
         idx = np.where(tid==tids)[0][0]
-        plot_result(f"test_{tid}.png", spec_pred[idx, :, :], spec_smooth[idx, :], spec_raw_pe[idx, :], spec_raw_int[idx, :])
+        plot_result(f"test_{tid}.png",
+                    spec_pred[idx, :, :],
+                    spec_smooth[idx, :],
+                    spec_raw_pe[idx, :],
+                    spec_raw_int[idx, :])
         for ch in channels:
             plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, :])
 
-- 
GitLab
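
As a sanity check for the new serialization path, below is a minimal, illustrative
round-trip sketch for the dump_in_group/read_from_group helpers added in this patch.
The file name and dictionary contents are made up for the example; the behaviour it
relies on (scalars and lists stored as HDF5 attributes, numpy arrays stored as
gzip-compressed datasets) is the one implemented above. Depending on the h5py version
and how the strings were written, string lists may come back as bytes, which is why
Model.load decodes the channels entry.

    import numpy as np
    import h5py

    from pes_to_spec.model import dump_in_group, read_from_group

    # made-up content mixing attribute-like and dataset-like values
    content = {
        "n_pca_lr": 600,                              # int   -> HDF5 attribute
        "high_res_sigma": 0.2,                        # float -> HDF5 attribute
        "channels": ["channel_1_A", "channel_1_B"],   # list  -> HDF5 attribute
        "mean_": np.linspace(0.0, 1.0, 50),           # array -> gzip-compressed dataset
    }

    with h5py.File("roundtrip_example.h5", "w") as hf:
        dump_in_group(content, hf)

    with h5py.File("roundtrip_example.h5", "r") as hf:
        recovered = read_from_group(hf)

    assert int(recovered["n_pca_lr"]) == content["n_pca_lr"]
    assert np.allclose(recovered["mean_"], content["mean_"])

Model.save and Model.load are thin wrappers around the same helpers: the fit-model
state and the model parameters go into the root group, while the low- and
high-resolution PCA properties go into the lr_pca and hr_pca subgroups.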