From d55c33a2bb8bb1d4cd71c319c252b05b95685c0d Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Thu, 15 Dec 2022 12:56:12 +0100
Subject: [PATCH] Saving PCA and plotting.

---
 pes_to_spec/model.py     | 14 ++++++++++++--
 requirements.txt         |  1 +
 scripts/test_analysis.py | 41 +++++++++++++++++++++++++++++++++-------
 3 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 60732b1..d5264de 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -1,6 +1,7 @@
 import numpy as np
 from autograd import numpy as anp
 from autograd import grad
+import joblib
 import h5py
 from scipy.signal import fftconvolve
 from scipy.optimize import fmin_l_bfgs_b
@@ -137,12 +138,14 @@ class Model(object):
         result = np.stack((high_res_predicted, high_res_unc, self.high_pca_unc), axis=0)
         return result
 
-    def save(self, filename: str):
+    def save(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
         """
         Save the fit model in a file.
 
         Args:
           filename: H5 file name where to save this.
+          lr_pca_filename: Name of the file where to save the low-resolution PCA decomposition.
+          hr_pca_filename: Name of the file where to save the high-resolution PCA decomposition.
         """
         with h5py.File(filename, 'w') as hf:
             d = self.fit_model.as_dict()
@@ -151,19 +154,26 @@ class Model(object):
                     hf.attrs[key] = value
                 else:
                     hf.create_dataset(key, data=value)
+        joblib.dump(self.lr_pca, lr_pca_filename)
+        joblib.dump(self.hr_pca, hr_pca_filename)
 
-    def load(self, filename: str):
+
+    def load(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
         """
         Load model from a file.
 
         Args:
           filename: Name of the file where to read the model from.
+          lr_pca_filename: Name of the file from where to load the low-resolution PCA decomposition.
+          hr_pca_filename: Name of the file from where to load the high-resolution PCA decomposition.
 
         """
         with h5py.File(filename, 'r') as hf:
             d = {k: hf[k][()] for k in hf.keys()}
             d.update({k: hf.attrs[k] for k in hf.attrs})
             self.fit_model.from_dict(d)
+        self.lr_pca = joblib.load(lr_pca_filename)
+        self.hr_pca = joblib.load(hr_pca_filename)
 
 class FitModel(object):
     """
diff --git a/requirements.txt b/requirements.txt
index eff7f6b..891565f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ scikit-learn
 extra_data
 autograd
 h5py
+joblib
 matplotlib
diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py
index 51fc27b..adb6b8e 100644
--- a/scripts/test_analysis.py
+++ b/scripts/test_analysis.py
@@ -7,6 +7,28 @@ import matplotlib
 matplotlib.use('Agg')
 
 import matplotlib.pyplot as plt
+from matplotlib.gridspec import GridSpec
+
+def plot_result(filename: str, spec_pred: np.ndarray, spec_raw_int: np.ndarray, spec_raw_pe: np.ndarray):
+    """
+    Plot result with uncertainty band.
+
+    Args:
+      filename: Output file name.
+      spec_pred: Predicted result with uncertainty bands in a shape of (3, features).
+      spec_raw_int: True expected result with shape (features,).
+      spec_raw_pe: x axis with the photon energy in eV.
+
+    """
+    fig = plt.figure(figsize=(10, 10))
+    gs = GridSpec(1, 1)
+    ax = fig.add_subplot(gs[0, 0])
+    ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=3, label="High resolution measurement")
+    ax.plot(spec_raw_pe, spec_pred[0,:], c='r', lw=3, label="High resolution prediction")
+    ax.fill_between(spec_raw_pe, spec_pred[0,:] - spec_pred[1,:], spec_pred[0,:] + spec_pred[1,:], fillcolor='red', alpha=0.6, label="68% unc. (stat.)")
+    ax.fill_between(spec_raw_pe, spec_pred[0,:] - spec_pred[2,:], spec_pred[0,:] + spec_pred[2,:], fillcolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
+    fig.savefig(filename)
+    plt.close(fig)
 
 def main():
     """
@@ -26,6 +48,8 @@ def main():
     # these are the train ID intersection
     # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
     tids = matching_ids(spec_tid, pes_tid, xgm_tid)
+    train_tids = tids[:-10]
+    test_tids = tids[-10:]
 
     # read the spec photon energy and intensity
     spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
@@ -36,17 +60,20 @@ def main():
     pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels}
 
     # read the XGM information
-    xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', f"pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
-    xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
-
-    retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
-    retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
+    #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', f"pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
+    #xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
+    #retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
+    #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
 
     model = Model()
-    model.fit(pes_raw, spec_raw_int)
+    model.fit({k: v[train_tids,:] for k, v in pes_raw}, spec_raw_int[train_tids,:], spec_raw_pe[train_tids, :])
 
     # test
-    model.predict(pes_raw)
+    spec_pred = model.predict(pes_raw)
+
+    # plot
+    for tid in test_tids:
+        plot_result(f"test_{tid}.png", spec_pred[:, tid, :], spec_raw_int[tid, :], spec_raw_pe[0, :])
 
 if __name__ == '__main__':
     main()
-- 
GitLab