Skip to content
Snippets Groups Projects
Commit b623decd authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Better plots.

parent cde3eebf
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,19 @@ from typing import Dict, Optional ...@@ -21,7 +21,19 @@ from typing import Dict, Optional
from time import time_ns from time import time_ns
import pandas as pd import pandas as pd
def plot_pes(filename: str, pes_raw_int: np.ndarray): SMALL_SIZE = 12
MEDIUM_SIZE = 18
BIGGER_SIZE = 24
plt.rc('font', size=BIGGER_SIZE) # controls default text sizes
plt.rc('axes', titlesize=BIGGER_SIZE) # fontsize of the axes title
plt.rc('axes', labelsize=BIGGER_SIZE) # fontsize of the x and y labels
plt.rc('xtick', labelsize=BIGGER_SIZE) # fontsize of the tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE) # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE) # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int):
""" """
Plot low-resolution spectrum. Plot low-resolution spectrum.
...@@ -33,7 +45,7 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray): ...@@ -33,7 +45,7 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray):
fig = plt.figure(figsize=(16, 8)) fig = plt.figure(figsize=(16, 8))
gs = GridSpec(1, 1) gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0]) ax = fig.add_subplot(gs[0, 0])
ax.plot(pes_raw_int, c='b', lw=3, label="Low-resolution measurement") ax.plot(np.arange(first, last), pes_raw_int, c='b', lw=3, label="Low-resolution measurement")
ax.legend() ax.legend()
ax.set(title=f"", ax.set(title=f"",
xlabel="ToF index", xlabel="ToF index",
...@@ -53,19 +65,21 @@ def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np ...@@ -53,19 +65,21 @@ def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np
spec_raw_int: Original true expected result with shape (features,). spec_raw_int: Original true expected result with shape (features,).
""" """
fig = plt.figure(figsize=(16, 8)) fig = plt.figure(figsize=(12, 8))
gs = GridSpec(1, 1) gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0]) ax = fig.add_subplot(gs[0, 0])
unc_stat = np.mean(spec_pred["unc"]) unc_stat = np.mean(spec_pred["unc"])
unc_pca = np.mean(spec_pred["pca"]) unc_pca = np.mean(spec_pred["pca"])
ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-resolution measurement (smoothened)") #ax.plot(spec_raw_pe, spec_smooth, c='b', lw=3, label="High-resolution measurement (smoothened)")
ax.plot(spec_raw_pe, spec_pred["expected"], c='r', lw=3, label="High-resolution prediction") ax.plot(spec_raw_pe, spec_pred["expected"], c='r', lw=3, label="High-resolution prediction")
ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["unc"], spec_pred["expected"] + spec_pred["unc"], facecolor='red', alpha=0.6, label="68% unc. (stat.)") unc = np.sqrt(spec_pred["unc"]**2 + spec_pred["pca"]**2)
ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["pca"], spec_pred["expected"] + spec_pred["pca"], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)") ax.fill_between(spec_raw_pe, spec_pred["expected"] - unc, spec_pred["expected"] + unc, facecolor='red', alpha=0.6, label="68% unc.")
#ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["unc"], spec_pred["expected"] + spec_pred["unc"], facecolor='red', alpha=0.6, label="68% unc. (stat.)")
#ax.fill_between(spec_raw_pe, spec_pred["expected"] - spec_pred["pca"], spec_pred["expected"] + spec_pred["pca"], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
if spec_raw_int is not None: if spec_raw_int is not None:
ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=1, ls='--', label="High-resolution measurement") ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=1, ls='--', label="High-resolution measurement")
ax.legend() ax.legend(frameon=False, borderaxespad=0)
ax.set(title=f"avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}", ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}",
xlabel="Photon energy [eV]", xlabel="Photon energy [eV]",
ylabel="Intensity") ylabel="Intensity")
fig.savefig(filename) fig.savefig(filename)
...@@ -140,7 +154,6 @@ def main(): ...@@ -140,7 +154,6 @@ def main():
spec_raw_pe[train_idx, :]) spec_raw_pe[train_idx, :])
t += [time_ns() - start] t += [time_ns() - start]
t_names += ["Fit"] t_names += ["Fit"]
spec_smooth = model.preprocess_high_res(spec_raw_int)
print("Saving the model") print("Saving the model")
start = time_ns() start = time_ns()
...@@ -174,6 +187,8 @@ def main(): ...@@ -174,6 +187,8 @@ def main():
print(df_time) print(df_time)
print("Plotting") print("Plotting")
spec_smooth = model.preprocess_high_res(spec_raw_int)
first, last = model.get_low_resolution_range()
# plot # plot
for tid in test_tids: for tid in test_tids:
idx = np.where(tid==tids)[0][0] idx = np.where(tid==tids)[0][0]
...@@ -185,7 +200,7 @@ def main(): ...@@ -185,7 +200,7 @@ def main():
spec_raw_pe[idx, :], spec_raw_pe[idx, :],
spec_raw_int[idx, :]) spec_raw_int[idx, :])
for ch in channels: for ch in channels:
plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, :]) plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment