From 201a15685c8568b50d5d553f4f0e971b0bef8b3a Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Fri, 16 Dec 2022 11:09:20 +0100
Subject: [PATCH] Corrected smoothing normalization to always match the
 original normalization.

---
 pes_to_spec/model.py     | 10 +++++-----
 scripts/test_analysis.py | 15 ++++++++++-----
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 9a45f68..8f19fbf 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -4,6 +4,7 @@ from autograd import grad
 import joblib
 import h5py
 from scipy.signal import fftconvolve
+from scipy.signal.windows import gaussian as gaussian_window
 from scipy.optimize import fmin_l_bfgs_b
 from sklearn.decomposition import PCA
 from sklearn.model_selection import train_test_split
@@ -96,11 +97,10 @@ class Model(object):
         """
         # Apply smoothing
         n_features = high_res_data.shape[1]
-        mu = high_res_photon_energy[0, n_features//2]
-        gaussian = np.exp(-((high_res_photon_energy - mu)/self.high_res_sigma)**2/2)/np.sqrt(2*np.pi*self.high_res_sigma**2)
-        print(np.sum(gaussian))
-        # 80 to match normalization (empirically taken)
-        high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)/80.0
+        mu = high_res_photon_energy[:, n_features//2, np.newaxis]
+        gaussian = np.exp(-0.5*(high_res_photon_energy - mu)**2/self.high_res_sigma**2)
+        gaussian /= np.sum(gaussian, axis=1, keepdims=True)
+        high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)
         return high_res_gc
 
     def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py
index 0a6bc59..438dedc 100755
--- a/scripts/test_analysis.py
+++ b/scripts/test_analysis.py
@@ -15,25 +15,30 @@ matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 from matplotlib.gridspec import GridSpec
 
-def plot_result(filename: str, spec_pred: np.ndarray, spec_raw_int: np.ndarray, spec_raw_pe: np.ndarray):
+from typing import Optional
+
+def plot_result(filename: str, spec_pred: np.ndarray, spec_smooth: np.ndarray, spec_raw_pe: np.ndarray, spec_raw_int: Optional[np.ndarray]=None):
     """
     Plot result with uncertainty band.
 
     Args:
       filename: Output file name.
       spec_pred: Predicted result with uncertainty bands in a shape of (3, features).
-      spec_raw_int: True expected result with shape (features,).
+      spec_smooth: Smoothened expected result with shape (features,).
       spec_raw_pe: x axis with the photon energy in eV.
+      spec_raw_int: Optional original true expected result with shape (features,); drawn as a dashed line when given.
 
     """
-    fig = plt.figure(figsize=(10, 10))
+    fig = plt.figure(figsize=(16, 8))
     gs = GridSpec(1, 1)
     ax = fig.add_subplot(gs[0, 0])
     eps = np.mean(spec_pred[:, 1])
-    ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=3, label="High resolution measurement (smoothened)")
+    ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High resolution measurement (smoothened)")
     ax.plot(spec_raw_pe, spec_pred[:, 0], c='r', lw=3, label="High resolution prediction")
     ax.fill_between(spec_raw_pe, spec_pred[:, 0] - spec_pred[:, 1], spec_pred[:, 0] + spec_pred[:, 1], facecolor='red', alpha=0.6, label="68% unc. (stat.)")
     ax.fill_between(spec_raw_pe, spec_pred[:, 0] - spec_pred[:, 2], spec_pred[:, 0] + spec_pred[:, 2], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
+    if spec_raw_int is not None:
+        ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=1, ls='--', label="High resolution measurement")
     ax.legend()
     ax.set(title=f"avg(unc) = {eps}",
            xlabel="Photon energy [eV]",
@@ -87,7 +92,7 @@ def main():
     # plot
     for tid in test_tids:
         idx = np.where(tid==tids)[0][0]
-        plot_result(f"test_{tid}.png", spec_pred[idx, :, :], spec_smooth[idx, :], spec_raw_pe[idx, :])
+        plot_result(f"test_{tid}.png", spec_pred[idx, :, :], spec_smooth[idx, :], spec_raw_pe[idx, :], spec_raw_int[idx, :])
 
 if __name__ == '__main__':
     main()
-- 
GitLab