From b623decd83036e82c982cc835aaf95db2f321b3c Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 23 Jan 2023 18:10:52 +0100
Subject: [PATCH] Better plots.

---
 pes_to_spec/test/offline_analysis.py | 35 ++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 37bdd6c..5944147 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -21,7 +21,19 @@ from typing import Dict, Optional
 from time import time_ns
 import pandas as pd
 
-def plot_pes(filename: str, pes_raw_int: np.ndarray):
+SMALL_SIZE = 12
+MEDIUM_SIZE = 18
+BIGGER_SIZE = 24
+
+plt.rc('font', size=BIGGER_SIZE)         # controls default text sizes
+plt.rc('axes', titlesize=BIGGER_SIZE)    # fontsize of the axes title
+plt.rc('axes', labelsize=BIGGER_SIZE)    # fontsize of the x and y labels
+plt.rc('xtick', labelsize=BIGGER_SIZE)   # fontsize of the tick labels
+plt.rc('ytick', labelsize=BIGGER_SIZE)   # fontsize of the tick labels
+plt.rc('legend', fontsize=MEDIUM_SIZE)   # legend fontsize
+plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
+
+def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int):
     """
     Plot low-resolution spectrum.
 
@@ -33,7 +45,7 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray):
     fig = plt.figure(figsize=(16, 8))
     gs = GridSpec(1, 1)
     ax = fig.add_subplot(gs[0, 0])
-    ax.plot(pes_raw_int, c='b', lw=3, label="Low-resolution measurement")
+    ax.plot(np.arange(first, last), pes_raw_int, c='b', lw=3, label="Low-resolution measurement")
     ax.legend()
     ax.set(title=f"",
            xlabel="ToF index",
@@ -53,19 +65,21 @@ def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np
       spec_raw_int: Original true expected result with shape (features,).
 
     """
-    fig = plt.figure(figsize=(16, 8))
+    fig = plt.figure(figsize=(12, 8))
     gs = GridSpec(1, 1)
     ax = fig.add_subplot(gs[0, 0])
     unc_stat = np.mean(spec_pred["unc"])
     unc_pca = np.mean(spec_pred["pca"])
-    ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-resolution measurement (smoothened)")
+    #ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-resolution measurement (smoothened)")
     ax.plot(spec_raw_pe, spec_pred["expected"], c='r', lw=3, label="High-resolution prediction")
-    ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["unc"], spec_pred["expected"] + spec_pred["unc"], facecolor='red', alpha=0.6, label="68% unc. (stat.)")
-    ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["pca"], spec_pred["expected"] + spec_pred["pca"], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
+    unc = np.sqrt(spec_pred["unc"]**2 + spec_pred["pca"]**2)
+    ax.fill_between(spec_raw_pe, spec_pred["expected"] - unc, spec_pred["expected"] + unc, facecolor='red', alpha=0.6, label="68% unc.")
+    #ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["unc"], spec_pred["expected"] + spec_pred["unc"], facecolor='red', alpha=0.6, label="68% unc. (stat.)")
+    #ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["pca"], spec_pred["expected"] + spec_pred["pca"], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
     if spec_raw_int is not None:
         ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=1, ls='--', label="High-resolution measurement")
-    ax.legend()
-    ax.set(title=f"avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}",
+    ax.legend(frameon=False, borderaxespad=0)
+    ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}",
            xlabel="Photon energy [eV]",
            ylabel="Intensity")
     fig.savefig(filename)
@@ -140,7 +154,6 @@ def main():
               spec_raw_pe[train_idx, :])
     t += [time_ns() - start]
     t_names += ["Fit"]
-    spec_smooth = model.preprocess_high_res(spec_raw_int)
 
     print("Saving the model")
     start = time_ns()
@@ -174,6 +187,8 @@ def main():
     print(df_time)
 
     print("Plotting")
+    spec_smooth = model.preprocess_high_res(spec_raw_int)
+    first, last = model.get_low_resolution_range()
     # plot
     for tid in test_tids:
         idx = np.where(tid==tids)[0][0]
@@ -185,7 +200,7 @@ def main():
                     spec_raw_pe[idx, :],
                     spec_raw_int[idx, :])
         for ch in channels:
-            plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, :])
+            plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last)
 
 if __name__ == '__main__':
     main()
-- 
GitLab