From d394f8bb1516d263863a195596182e22be6c5a41 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Sun, 7 May 2023 11:45:50 +0200
Subject: [PATCH] Added number of observations and more plots.

---
 pes_to_spec/bnn.py                   |   2 +-
 pes_to_spec/model.py                 |  12 +++
 pes_to_spec/test/offline_analysis.py | 142 ++++++++++++++++++++-------
 3 files changed, 120 insertions(+), 36 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 887512e..73aca75 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -63,7 +63,7 @@ class BNN(nn.Module):
     """
     def __init__(self, input_dimension: int=1, output_dimension: int=1):
         super(BNN, self).__init__()
-        hidden_dimension = 50
+        hidden_dimension = 100
         # controls the aleatoric uncertainty
         self.log_isigma2 = nn.Parameter(-torch.ones(1, output_dimension)*np.log(0.1**2), requires_grad=True)
         # controls the weight hyperprior
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 6b86d5e..1e24b05 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -611,6 +611,7 @@ class Model(TransformerMixin, BaseEstimator):
         elif model_type == "ard":
             self.fit_model = MultiOutputGenericWithStd(ARDRegression(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
         self.model_type = model_type
+        self.n_obs = 0
 
         self.kde_xgm = None
         self.mu_xgm = np.nan
@@ -624,6 +625,12 @@ class Model(TransformerMixin, BaseEstimator):
         # size of the test subset
         self.validation_size = validation_size
 
+    def n_pars(self) -> int:
+        """Return the total parameter count: every element of the BNN parameters, or len(coef_) + 1 per estimator otherwise."""
+        if self.model_type == "bnn":
+            return sum(p.numel() for p in self.fit_model.model.parameters())
+        return sum(len(estimator.coef_) + 1 for estimator in self.fit_model.estimators_)  # the +1 is presumably the intercept; confirm
+
     def get_channels(self) -> List[str]:
         """Get channels used in training."""
         return self.x_select.channels
@@ -719,6 +726,7 @@ class Model(TransformerMixin, BaseEstimator):
         if weights is not None:
             weights = weights[inliers]
         self.fit_model.fit(x_t[inliers], y_t[inliers], weights)
+        self.n_obs = len(x_t[inliers])
 
         # calculate the effect of the PCA
         print("Calculate PCA unc. on high-resolution data.")
@@ -944,6 +952,8 @@ class Model(TransformerMixin, BaseEstimator):
                     unc=unc.reshape((B, P, -1)),
                     pca=pca_unc,
                     total_unc=total_unc.reshape((B, P, -1)),
+                    expected_pca=high_pca.reshape((B, P, -1)),
+                    expected_pca_unc=high_pca_unc.reshape((B, P, -1)),
                     )
 
     def deconvolve(self, expected: np.ndarray) -> np.ndarray:
@@ -986,6 +996,7 @@ class Model(TransformerMixin, BaseEstimator):
                                      transfer_function=self.transfer_function,
                                      impulse_response=self.impulse_response,
                                      model_type=self.model_type,
+                                     n_obs=self.n_obs,
                                     )
                                ),
                      self.ood,
@@ -1023,6 +1034,7 @@ class Model(TransformerMixin, BaseEstimator):
         obj.transfer_function = extra["transfer_function"]
         obj.impulse_response = extra["impulse_response"]
         obj.model_type = extra["model_type"]
+        obj.n_obs = extra.get("n_obs", 0)  # files saved before "n_obs" was stored lack the key; fall back to 0 (assumes extra is dict-like)
 
         obj.x_select = x_select
         obj.x_model = x_model
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 93c66b7..ebe3a52 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -105,15 +105,17 @@ def plot_result(filename: str,
     Y = np.amax(spec_smooth)
     ax.legend(frameon=False, borderaxespad=0, loc='upper left')
     ax.set_title(f"Beam intensity: {intensity*1e-3:.1f} mJ", loc="left")
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
     ax.set(
            xlabel="Photon energy [eV]",
            ylabel="Intensity",
-           ylim=(0, 1.5*Y))
+           ylim=(0, 1.3*Y))
     if pes is not None:
         ax2 = plt.axes([0,0,1,1])
         # Manually set the position and relative size of the inset axes within ax1
         #ip = InsetPosition(ax, [0.65,0.6,0.35,0.4])
-        ip = InsetPosition(ax, [0.7,0.7,0.35,0.4])
+        ip = InsetPosition(ax, [0.72,0.7,0.35,0.4])
         ax2.set_axes_locator(ip)
         if pes_to_show == "sum":
             pes_plot = sum([pes[k][pes_bin] for k in pes.keys()])
@@ -283,6 +285,20 @@ def main():
     t += [time_ns() - start]
     t_names += ["Load"]
 
+    # transfer function
+    fig = plt.figure(figsize=(12, 8))
+    gs = GridSpec(1, 1)
+    ax = fig.add_subplot(gs[0, 0])
+    plt.plot(model.wiener_energy, np.absolute(model.impulse_response))
+    ax.set(title=f"",
+           xlabel=r"Energy [eV]",
+           ylabel="Response [a.u.]",
+           yscale='log',
+           )
+    fig.savefig(os.path.join(args.directory, "impulse.png"))
+    plt.close(fig)
+    print(f"Resolution: {model.resolution:.2f} eV")
+
     # plot Wiener filter
     fig = plt.figure(figsize=(12, 8))
     gs = GridSpec(1, 1)
@@ -303,6 +319,7 @@ def main():
     ax.set(title=f"",
            xlabel=r"Energy [eV]",
            ylabel="Filter value [a.u.]",
+           yscale='log',
            )
     fig.savefig(os.path.join(args.directory, "wiener.png"))
     plt.close(fig)
@@ -335,10 +352,15 @@ def main():
         showSpec = True
         spec_smooth = model.preprocess_high_res(spec_raw_int_t)
 
+        mse = np.mean((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2, axis=(0, 1))
+        var = np.std(spec_smooth, axis=0)**2
+        r2 = 1 - mse/var
+        print(f"MSE = {np.mean(mse):.2f}, var = {np.mean(var):.2f}, R^2 = {np.mean(r2):.2f}")
+
         # chi2 w.r.t XGM intensity
-        de = (spec_raw_pe_t[0,1] - spec_raw_pe_t[0,0])
         chi2 = np.sum((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=(-1, -2))
-        ndof = float(spec_smooth.shape[1]) - 1.0
+        ndof = spec_smooth.shape[1]
+        print(f"Chi2 after PCA: {np.mean(chi2):.2f}, ndof: {ndof}, chi2/ndof: {np.mean(chi2/ndof):.2f}")
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
@@ -369,39 +391,90 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.kdeplot(x=chi2/ndof, ax=ax)
+        sns.histplot(x=chi2/ndof, kde=True, linewidth=3, ax=ax)
         ax.set(title=f"",
                xlabel=r"$\chi^2/$ndof",
-               ylabel="Density [a.u.]",
+               ylabel="Counts [a.u.]",
                xlim=(0, 5),
                )
-        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(chi2/ndof):.2f}",
-                verticalalignment='top', horizontalalignment='right',
-                transform=ax.transAxes,
-                color='black', fontsize=15)
-        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(chi2/ndof):.2f}",
-                verticalalignment='top', horizontalalignment='right',
-                transform=ax.transAxes,
-                color='black', fontsize=15)
+        #ax.text(0.90, 0.95, fr"$\mu = ${np.mean(chi2/ndof):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
+        #ax.text(0.90, 0.90, fr"$\sigma = ${np.std(chi2/ndof):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
         fig.savefig(os.path.join(args.directory, "chi2.png"))
         plt.close(fig)
 
+        spec_smooth_pca = model.y_model['pca'].transform(spec_smooth)
+        chi2_prepca = np.sum((spec_smooth_pca[:, np.newaxis, :] - spec_pred["expected_pca"])**2/(spec_pred["expected_pca_unc"]**2), axis=(-1, -2))
+        ndof_prepca = float(spec_smooth_pca.shape[-1])
+        print(f"Chi2 before PCA: {np.mean(chi2_prepca):.2f}, ndof: {ndof_prepca}, chi2/ndof: {np.mean(chi2_prepca/ndof_prepca):.2f} +/- {np.std(chi2_prepca/ndof_prepca):.2f}")
+        fig = plt.figure(figsize=(12, 8))
+        gs = GridSpec(1, 1)
+        ax = fig.add_subplot(gs[0, 0])
+        sns.histplot(x=chi2_prepca/ndof_prepca, kde=True, linewidth=3, ax=ax)
+        ax.set(title=f"",
+               xlabel=r"$\chi^2/$ndof before undoing PCA",
+               ylabel="Counts [a.u.]",
+               xlim=(0, 5),
+               )
+        #ax.text(0.90, 0.95, fr"$\mu = ${np.mean(chi2/ndof):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
+        #ax.text(0.90, 0.90, fr"$\sigma = ${np.std(chi2/ndof):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
+        fig.savefig(os.path.join(args.directory, "chi2_prepca.png"))
+        plt.close(fig)
+
+        fig = plt.figure(figsize=(12, 8))
+        gs = GridSpec(1, 1)
+        ax = fig.add_subplot(gs[0, 0])
+        ax.scatter(chi2_prepca/ndof_prepca, xgm_flux_t[:,0], c='r', s=20)
+        ax.set(title=f"",
+               xlabel=r"$\chi^2/$ndof before undoing PCA",
+               ylabel="XGM intensity [uJ]",
+               xlim=(0, 5),
+               ylim=(0, np.mean(xgm_flux_t) + 3*np.std(xgm_flux_t))
+               )
+        fig.savefig(os.path.join(args.directory, "intensity_vs_chi2_prepca.png"))
+        plt.close(fig)
+
+        res_prepca = np.sum((spec_smooth_pca[:, np.newaxis, :] - spec_pred["expected_pca"])/spec_pred["expected_pca_unc"], axis=1)
+        fig = plt.figure(figsize=(12, 8))
+        gs = GridSpec(1, 1)
+        ax = fig.add_subplot(gs[0, 0])
+        sns.kdeplot(data={f"Dim. {k}": res_prepca[:, k] for k in range(res_prepca.shape[1])}, linewidth=3, ax=ax)
+        ax.set(title=f"",
+               xlabel=r"residue/uncertainty [a.u.]",
+               ylabel="Counts [a.u.]",
+               xlim=(-3, 3),
+               )
+        fig.savefig(os.path.join(args.directory, "res_prepca.png"))
+        plt.close(fig)
+
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.kdeplot(x=xgm_flux_t[:,0], ax=ax)
+        sns.histplot(x=xgm_flux_t[:,0], kde=True, linewidth=3, ax=ax)
         ax.set(title=f"",
                xlabel="XGM intensity [uJ]",
-               ylabel="Density [a.u.]",
+               ylabel="Counts [a.u.]",
                )
-        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux_t[:,0]):.2f}",
-                verticalalignment='top', horizontalalignment='right',
-                transform=ax.transAxes,
-                color='black', fontsize=15)
-        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux_t[:,0]):.2f}",
-                verticalalignment='top', horizontalalignment='right',
-                transform=ax.transAxes,
-                color='black', fontsize=15)
+        #ax.text(0.90, 0.95, fr"$\mu = ${np.mean(xgm_flux_t[:,0]):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
+        #ax.text(0.90, 0.90, fr"$\sigma = ${np.std(xgm_flux_t[:,0]):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
+        plt.tight_layout()
         fig.savefig(os.path.join(args.directory, "intensity.png"))
         plt.close(fig)
 
@@ -422,20 +495,19 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.kdeplot(x=rmse, ax=ax)
+        sns.histplot(x=rmse, kde=True, linewidth=3, ax=ax)
         ax.set(title=f"",
                xlabel="Root-mean-squared error",
-               ylabel="Density [a.u.]",
-               xlim=(0, 20),
+               ylabel="Counts [a.u.]",
                )
-        ax.text(0.90, 0.95, fr"$\mu = ${np.mean(rmse):.2f}",
-                verticalalignment='top', horizontalalignment='right',
-                transform=ax.transAxes,
-                color='black', fontsize=15)
-        ax.text(0.90, 0.90, fr"$\sigma = ${np.std(rmse):.2f}",
-                verticalalignment='top', horizontalalignment='right',
-                transform=ax.transAxes,
-                color='black', fontsize=15)
+        #ax.text(0.90, 0.95, fr"$\mu = ${np.mean(rmse):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
+        #ax.text(0.90, 0.90, fr"$\sigma = ${np.std(rmse):.2f}",
+        #        verticalalignment='top', horizontalalignment='right',
+        #        transform=ax.transAxes,
+        #        color='black', fontsize=15)
         fig.savefig(os.path.join(args.directory, "rmse.png"))
         plt.close(fig)
 
-- 
GitLab