Skip to content
Snippets Groups Projects
Commit 88c6699c authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Fixed plotting.

parent d2894a77
No related branches found
No related tags found
1 merge request!10Handle multi-pulse data
This commit is part of merge request !10. Comments created here will be created in the context of that merge request.
......@@ -300,7 +300,7 @@ def main():
# chi2 w.r.t XGM intensity
erange = spec_raw_pe[0,-1] - spec_raw_pe[0,0]
de = (spec_raw_pe[0,1] - spec_raw_pe[0,0])
chi2 = np.sum((spec_smooth[:, np.newaxis] - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=1)
chi2 = np.sum((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=(-1, -2))
ndof = float(spec_smooth.shape[1]) - 1.0
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(1, 1)
......@@ -369,7 +369,7 @@ def main():
plt.close(fig)
# rmse
rmse = np.sqrt(np.mean((spec_smooth[:, np.newaxis] - spec_pred["expected"])**2, axis=1))
rmse = np.sqrt(np.mean((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2, axis=(-1, -2)))
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0])
......@@ -418,7 +418,7 @@ def main():
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0])
sns.regplot(x=np.sum(spec_raw_int, axis=1)*de, y=np.sum(spec_pred["expected"], axis=1)*de, color='r', robust=True, ax=ax)
sns.regplot(x=np.sum(spec_raw_int, axis=-1)*de, y=np.sum(spec_pred["expected"], axis=(-1, -2))*de, color='r', robust=True, ax=ax)
ax.set(title=f"",
xlabel="SPEC (raw) integral",
ylabel="Predicted integral",
......@@ -429,7 +429,7 @@ def main():
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0])
sns.regplot(x=np.sum(spec_pred["expected"], axis=1)*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
sns.regplot(x=np.sum(spec_pred["expected"], axis=(-1, -2))*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
ax.set(title=f"",
xlabel="Predicted integral",
ylabel="XGM intensity [uJ]",
......@@ -445,7 +445,7 @@ def main():
for tid in test_tids:
idx = np.where(tid==tids)[0][0]
plot_result(os.path.join(args.directory, f"test_{tid}.png"),
{k: item[idx, ...] if k != "pca"
{k: item[idx, 0, ...] if k != "pca"
else item[0, ...]
for k, item in spec_pred.items()},
spec_smooth[idx, :] if showSpec else None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment