From b623decd83036e82c982cc835aaf95db2f321b3c Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Mon, 23 Jan 2023 18:10:52 +0100 Subject: [PATCH] Better plots. --- pes_to_spec/test/offline_analysis.py | 35 ++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py index 37bdd6c..5944147 100755 --- a/pes_to_spec/test/offline_analysis.py +++ b/pes_to_spec/test/offline_analysis.py @@ -21,7 +21,19 @@ from typing import Dict, Optional from time import time_ns import pandas as pd -def plot_pes(filename: str, pes_raw_int: np.ndarray): +SMALL_SIZE = 12 +MEDIUM_SIZE = 18 +BIGGER_SIZE = 24 + +plt.rc('font', size=BIGGER_SIZE) # controls default text sizes +plt.rc('axes', titlesize=BIGGER_SIZE) # fontsize of the axes title +plt.rc('axes', labelsize=BIGGER_SIZE) # fontsize of the x and y labels +plt.rc('xtick', labelsize=BIGGER_SIZE) # fontsize of the tick labels +plt.rc('ytick', labelsize=BIGGER_SIZE) # fontsize of the tick labels +plt.rc('legend', fontsize=MEDIUM_SIZE) # legend fontsize +plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title + +def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int): """ Plot low-resolution spectrum. @@ -33,7 +45,7 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray): fig = plt.figure(figsize=(16, 8)) gs = GridSpec(1, 1) ax = fig.add_subplot(gs[0, 0]) - ax.plot(pes_raw_int, c='b', lw=3, label="Low-resolution measurement") + ax.plot(np.arange(first, last), pes_raw_int, c='b', lw=3, label="Low-resolution measurement") ax.legend() ax.set(title=f"", xlabel="ToF index", @@ -53,19 +65,21 @@ def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np spec_raw_int: Original true expected result with shape (features,). """ - fig = plt.figure(figsize=(16, 8)) + fig = plt.figure(figsize=(12, 8)) gs = GridSpec(1, 1) ax = fig.add_subplot(gs[0, 0]) unc_stat = np.mean(spec_pred["unc"]) unc_pca = np.mean(spec_pred["pca"]) - ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-resolution measurement (smoothened)") + #ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-resolution measurement (smoothened)") ax.plot(spec_raw_pe, spec_pred["expected"], c='r', lw=3, label="High-resolution prediction") - ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["unc"], spec_pred["expected"] + spec_pred["unc"], facecolor='red', alpha=0.6, label="68% unc. (stat.)") - ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["pca"], spec_pred["expected"] + spec_pred["pca"], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)") + unc = np.sqrt(spec_pred["unc"]**2 + spec_pred["pca"]**2) + ax.fill_between(spec_raw_pe, spec_pred["expected"] - unc, spec_pred["expected"] + unc, facecolor='red', alpha=0.6, label="68% unc.") + #ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["unc"], spec_pred["expected"] + spec_pred["unc"], facecolor='red', alpha=0.6, label="68% unc. (stat.)") + #ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["pca"], spec_pred["expected"] + spec_pred["pca"], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)") if spec_raw_int is not None: ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=1, ls='--', label="High-resolution measurement") - ax.legend() - ax.set(title=f"avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}", + ax.legend(frameon=False, borderaxespad=0) + ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}", xlabel="Photon energy [eV]", ylabel="Intensity") fig.savefig(filename) @@ -140,7 +154,6 @@ def main(): spec_raw_pe[train_idx, :]) t += [time_ns() - start] t_names += ["Fit"] - spec_smooth = model.preprocess_high_res(spec_raw_int) print("Saving the model") start = time_ns() @@ -174,6 +187,8 @@ def main(): print(df_time) print("Plotting") + spec_smooth = model.preprocess_high_res(spec_raw_int) + first, last = model.get_low_resolution_range() # plot for tid in test_tids: idx = np.where(tid==tids)[0][0] @@ -185,7 +200,7 @@ def main(): spec_raw_pe[idx, :], spec_raw_int[idx, :]) for ch in channels: - plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, :]) + plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last) if __name__ == '__main__': main() -- GitLab