import sys
sys.path.append("./")
sys.path.append("..")

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from src.utils.utils import load_train_test_h5, load_rec_data_h5, load_rec_data_h5_bfgs_pca, load_checkpoint, load_rec_data_h5_bfgs_pca_eps
import numpy as np
import h5py

# Experiment directory whose trained model is evaluated by this script.
exp_dir = "test3_pulseen_short_test_eps_r0015"

# Load the train/test split plus metadata for the experiment.
(
    Y_train, Y_test,
    spec_train, spec_test, spec_raw_pe,
    X_train, X_test,
    pes_train, pes_test,
    att_dict,
    xgm_pulseen_trains, xgm_pulseen_tests,
) = load_train_test_h5(exp_dir)

# Keep only the first element of each entry of xgm_pulseen_tests.
xgm_pulseen_test = [entry[0] for entry in xgm_pulseen_tests]

n_pca_comps = att_dict["n_pca_comps_pes"]

if att_dict["model_type"]=="bfgs":

    # Load the reconstruction and its uncertainty for this experiment.
    Y_rec, Y_rec_unc = load_rec_data_h5(exp_dir)
    print(Y_test.shape, Y_rec.shape)

    fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
    axes_2 = axes_2.flatten()

    # Plot up to 15 samples as before, but never more rows than the
    # reconstruction / ground truth actually have (avoids IndexError on a
    # small test set) and never more than the available subplot slots.
    n_plots = min(15, Y_rec.shape[0], Y_test.shape[0], len(axes_2))
    # Component index shared by every subplot; computed once.
    x_axis = np.arange(0, Y_rec.shape[1], 1)

    for i_bfgs in range(n_plots):
        ax = axes_2[i_bfgs]
        ax.plot(x_axis, Y_rec[i_bfgs,:], "r", label=f"SPEC rec {i_bfgs}")
        ax.plot(x_axis, Y_test[i_bfgs,:], label=f"SPEC gt {i_bfgs}")
        # Band: reconstruction +/- Y_rec_unc.
        ax.fill_between(
            x_axis, Y_rec[i_bfgs,:] - Y_rec_unc, Y_rec[i_bfgs,:] + Y_rec_unc,
            color="pink", alpha=0.5, label="BFGS std",
        )
        # Was the placeholder 'XXX '; Y_test is plotted against 'PCA comps.'
        # elsewhere in this script.
        ax.set_xlabel('PCA comps.', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.legend(fontsize=10)
        ax.set_title(f"{exp_dir}", y=1.0, pad=-10)

    # Hide the subplot slots that received no data instead of leaving empty axes.
    for ax in axes_2[n_plots:]:
        ax.axis("off")

    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/reconstruction.png", bbox_inches='tight')
    plt.close(fig)


if att_dict["model_type"]=="bfgs_pca":

    # Load the reconstruction and its uncertainty terms, in the loader's
    # return order: BFGS, spectral-PCA and total uncertainty.
    Y_rec, Y_rec_unc_bfgs, Y_rec_unc_specpca, Y_rec_unc = load_rec_data_h5_bfgs_pca(exp_dir)
    print(Y_test.shape, Y_rec.shape)

    # TODO: this model type provides no eps. A zero array shaped like the BFGS
    # uncertainty stands in for it so the eps-based figures below still run;
    # all eps values in this branch are therefore zero.
    Y_eps = 0*Y_rec_unc_bfgs
    eps_mean = np.mean(Y_eps, axis=1)
    print("len eps mean", len(eps_mean))

    # Integrated intensity of every row (sum along axis 1).
    def int_sum_1(A):
        """Return the sum of each row of the 2-D array *A*."""
        return np.sum(A, axis=1)
    pes_sum = int_sum_1(pes_test)
    spec_sum = int_sum_1(spec_test)

    # np.corrcoef treats each row of pes_test as one variable, so cor_mat holds
    # the pairwise correlation between test samples; cor_mat_int is its row sum.
    cor_mat = np.corrcoef(pes_test)
    cor_mat_int = np.sum(cor_mat, axis=1)

    # --- Quick FFT evaluation --------------------------------------------
    # Compare FFT magnitudes of ground-truth and reconstructed spectra over the
    # components FC_START..FC_END-1, grouped into regions of interest (ROIs).
    FC_START = 0
    FC_END = 80
    N_ROI = 40
    # Keep the first half of the FFT magnitudes.
    n_Y_comps = spec_test.shape[1]
    n_half = round(n_Y_comps/2)
    Y_test_fft_raw = np.abs(np.fft.fft(spec_test))[:, :n_half]
    Y_rec_fft_raw = np.abs(np.fft.fft(Y_rec))[:, :n_half]

    # Keep only the relevant (not close to 0) Fourier components.
    Y_test_fft = Y_test_fft_raw[:, FC_START:FC_END]
    Y_rec_fft = Y_rec_fft_raw[:, FC_START:FC_END]

    # Width of one ROI. Clamped to >= 1 because round() yields 0 when N_ROI is
    # larger than the component range, and range() with step 0 raises.
    # NOTE: the actual number of ROIs is len(delta_Y_fft_rois_sum_mean); it
    # differs from N_ROI when the range is not divisible by N_ROI or when the
    # spectra have fewer than 2*FC_END components.
    split_size = max(1, round((FC_END - FC_START)/N_ROI))

    def split_to_rois(a, splitedSize = split_size):
        """Split the columns of 2-D array *a* into consecutive chunks of width *splitedSize*."""
        return [a[:, x:x+splitedSize] for x in range(0, a.shape[1], splitedSize)]

    Y_test_fft_rois = split_to_rois(Y_test_fft)
    Y_rec_fft_rois = split_to_rois(Y_rec_fft)

    # Absolute gt-vs-rec difference per ROI, summed over the ROI's components
    # (one value per sample), then averaged over samples (one value per ROI).
    delta_Y_fft_rois = [np.abs(a-b) for a, b in zip(Y_test_fft_rois, Y_rec_fft_rois)]
    delta_Y_fft_rois_sum = [np.sum(a, axis=1) for a in delta_Y_fft_rois]
    print("len(delta_Y_fft_rois_sum)", len(delta_Y_fft_rois_sum))
    delta_Y_fft_rois_sum_mean = [x.mean() for x in delta_Y_fft_rois_sum]

    # --- Per-sample agreement metrics (reconstruction vs ground truth) ---

    def corr_coef(a, b):
        """Pearson correlation of 1-D arrays *a* and *b*, rounded to 2 decimals.

        nan when either input is constant (np.corrcoef behaviour).
        """
        return np.round(np.corrcoef(a, b)[0][1], 2)

    def chi2_distance(A, B):
        """Chi-squared distance 0.5 * sum((A - B)**2 / (A + B)).

        Bins where A + B == 0 are skipped: evaluating them gave 0/0 = nan (or
        inf), which turned the whole sample's distance into nan. The distance
        is only meaningful for non-negative inputs.
        """
        A = np.asarray(A, dtype=float)
        B = np.asarray(B, dtype=float)
        denom = A + B
        valid = denom != 0
        return 0.5 * np.sum((A[valid] - B[valid]) ** 2 / denom[valid])

    def rmse_value(A, B):
        """Root-mean-square error between arrays *A* and *B*."""
        return np.sqrt(np.mean((A-B)**2))

    cc = [corr_coef(rec, gt) for rec, gt in zip(Y_rec, spec_test)]
    chi_2 = [chi2_distance(rec, gt) for rec, gt in zip(Y_rec, spec_test)]
    rmse = [rmse_value(rec, gt) for rec, gt in zip(Y_rec, spec_test)]
    # RMSE restricted to the selected FFT magnitudes.
    rmse_quickfft = [rmse_value(rec, gt) for rec, gt in zip(Y_rec_fft, Y_test_fft)]






    # --- Save images ---

    # Correlation matrix of pes_test between test samples, with a horizontal
    # colorbar placed above the image.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    im = ax.imshow(cor_mat)
    ax.set_xlabel('train IDs.', fontsize=22)
    ax.set_ylabel('train IDs.', fontsize=22)
    ax.tick_params(axis="x", labelsize=20)
    ax.tick_params(axis="y", labelsize=20)
    # No ax.legend(): an image has no labelled artists, so legend() only
    # emitted a "No artists with labels found" warning.
    ax.set_title(f"cor mat: PCA(PES)", y=1.0, pad=-20, fontsize=22)

    cax = fig.add_axes([0.27, 0.95, 0.5, 0.05])
    fig.colorbar(im, cax=cax, orientation='horizontal')
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/cor_mat_pca_pes.png", bbox_inches='tight')
    plt.close(fig)

    # Single rows of the correlation matrix for a few hand-picked samples.
    # Rows that do not exist in a smaller test set are skipped instead of
    # raising IndexError.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    for row in (42, 43, 171):
        if row < cor_mat.shape[0]:
            ax.plot(cor_mat[row,:], label=f"{row}")
    ax.legend(fontsize=20)
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/cor_mat_pca_pes_line.png", bbox_inches='tight')
    plt.close(fig)


    # --- One-series summary figures (20x10, a single x/y series each) ---

    def _save_xy_figure(x, y, label, xlabel, ylabel, title, filename, scatter=False):
        """Draw *y* against *x* (line plot, or scatter when *scatter* is True)
        with the shared summary styling and save it as *filename* in the
        experiment's summaries folder."""
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        if scatter:
            ax.scatter(x, y, label=label)
        else:
            ax.plot(x, y, label=label)
        ax.set_xlabel(xlabel, fontsize=22)
        ax.set_ylabel(ylabel, fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(title, y=1.0, pad=-20, fontsize=22)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    # Normalised series reused below. np.nanmax ignores nan coefficients
    # (constant rows); the builtin max() returned an order-dependent result then.
    cc_arr = np.asarray(cc, dtype=float)
    cc_norm = cc_arr / np.nanmax(cc_arr)
    xgm_arr = np.asarray(xgm_pulseen_test)

    # XGM pulse energy per test sample.
    _save_xy_figure(np.arange(len(xgm_pulseen_test)), xgm_pulseen_test, "xgm pulse energy: PES",
                    'train IDs.', 'int a.u.', "xgm test pulse energy: PES", "xgm_pulse_energy.png")

    # eps_mean vs pulse energy (title previously said "pes_sum vs pulse energy").
    _save_xy_figure(eps_mean, xgm_pulseen_test, "eps_mean xgm pulse energy",
                    'eps_mean', 'pulse energy', "eps_mean vs pulse energy",
                    "pulse_energy_vs_eps_mean.png", scatter=True)

    # eps_mean vs QuickFFT RMSE.
    _save_xy_figure(eps_mean, rmse_quickfft, "eps_mean QuickFFT",
                    'eps_mean', 'QuickFFT', "eps_mean QuickFFT",
                    "eps_mean_vs_quick_fft.png", scatter=True)

    # Pulse energy vs QuickFFT RMSE.
    _save_xy_figure(xgm_pulseen_test, rmse_quickfft, "xgm_pulseen_test QuickFFT",
                    'xgm_pulseen_test', 'QuickFFT', "pulse energy QuickFFT",
                    "xgm_pulseen_test_vs_quick_fft.png", scatter=True)

    # Correlation coefficient vs QuickFFT RMSE.
    _save_xy_figure(cc, rmse_quickfft, "cc vs QuickFFT",
                    'cc', 'QuickFFT', "cc vs QuickFFT",
                    "cc_vs_quick_fft.png", scatter=True)

    # Integrated SPEC intensity vs pulse energy.
    _save_xy_figure(spec_sum, xgm_pulseen_test, "xgm pulse energy vs spec int",
                    'spec_sum', 'pulse energy', "spec_sum  vs pulse energy",
                    "pulse_energy_vs_spec_int.png", scatter=True)

    # Row sum of the correlation matrix vs pulse energy (title previously
    # said "spec_sum vs pulse energy").
    _save_xy_figure(cor_mat_int, xgm_pulseen_test, "xgm pulse energy vs cor_mat_int",
                    'cor_mat_int', 'pulse energy', "cor_mat_int vs pulse energy",
                    "pulse_energy_vs_cor_mat_int.png", scatter=True)

    # Integrated SPEC vs integrated PES intensity. Drawn as a scatter: x is not
    # sorted, so a line plot produced a zig-zag. The y label was "pulse energy".
    _save_xy_figure(spec_sum, pes_sum, " spec int vs  pes_sum",
                    'spec_sum', 'pes_sum', "spec_sum  vs pes_sum",
                    "spec_int_vs_pes_int.png", scatter=True)

    # Integrated PES intensity per test sample.
    _save_xy_figure(np.arange(len(pes_sum)), pes_sum, "integrated int: PES",
                    'train IDs.', 'int a.u.', "integrated int: PES", "pes_int_sum.png")

    # Integrated SPEC intensity per test sample.
    _save_xy_figure(np.arange(len(spec_sum)), spec_sum, "integrated int: SPEC",
                    'train IDs.', 'int a.u.', "integrated int: SPEC", "spec_int_sum.png")

    # Normalised pes_sum vs normalised correlation coefficient (scatter: unsorted x).
    _save_xy_figure(pes_sum/np.max(pes_sum), cc_norm, "correlation coef.",
                    'pes_sum', 'cor. coef.', "pes_sum vs cor. coef.",
                    "pes_int_sum_vs_cor_coef.png", scatter=True)

    # Normalised spec_sum vs normalised correlation coefficient (scatter: unsorted x).
    _save_xy_figure(spec_sum/np.max(spec_sum), cc_norm, "spec_sum correlation coef",
                    'spec_sum', 'cor. coef.', "spec_sum vs cor. coef.",
                    "spec_int_sum_vs_corr_coef.png", scatter=True)

    # Mean FFT-magnitude error per ROI. The x axis follows the real number of
    # ROIs, which can differ from N_ROI (np.arange(N_ROI) then crashed the plot).
    _save_xy_figure(np.arange(len(delta_Y_fft_rois_sum_mean)), delta_Y_fft_rois_sum_mean, "FFT eval 0",
                    'FFT bin comps.', 'int a.u.', "quick fft eval", "quick_fft_eval.png")

    # Correlation coefficient per test sample.
    _save_xy_figure(np.arange(len(cc)), cc, "correlation coef",
                    'train IDs.', 'int a.u.', "cor. coef.", "corr_coef.png")

    # Normalised correlation coefficient vs eps_mean / spec_sum (the y label
    # previously said just "eps mean").
    _save_xy_figure(cc_norm, eps_mean/spec_sum, "cc vs eps_mean",
                    'cor coef', 'eps_mean / spec_sum', "cc vs eps_mean",
                    "cor_coef_vs_eps_mean.png", scatter=True)

    # Normalised correlation coefficient vs normalised pulse energy.
    _save_xy_figure(cc_norm, xgm_arr/np.max(xgm_arr), "xgm_pulseen_test",
                    'cor coef', 'pulse energy', "cor. coef. vs pulse energy",
                    "cor_coef_vs_pulse_energy.png", scatter=True)


    # --- Best / worst reconstruction according to the correlation coefficient ---
    # nanargmax / nanargmin skip samples whose coefficient is nan (np.corrcoef
    # gives nan for a constant row). np.argmax / np.argmin would return the
    # index of the first nan and present it as the "best"/"worst" sample.
    i_best = int(np.nanargmax(cc))
    i_worst = int(np.nanargmin(cc))

    def _save_rec_2unc_figure(i_sample, title, filename):
        """Save reconstruction vs ground truth of test sample *i_sample*, with
        the u_bfgs (red) and u_specpca (green) bands around the reconstruction.

        NOTE(review): row 0 of Y_rec_unc_bfgs is used for every sample, as in
        the rest of this branch; confirm that this is intended.
        """
        rec = Y_rec[i_sample,:]
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.plot(spec_raw_pe, rec, "r", label=f"SPEC rec {i_sample}")
        ax.plot(spec_raw_pe, spec_test[i_sample,:], label=f"SPEC gt {i_sample}")
        ax.fill_between(
            spec_raw_pe, rec - Y_rec_unc_bfgs[0,:], rec + Y_rec_unc_bfgs[0,:],
            color="r", alpha=0.5, label="u_bfgs",
        )
        ax.fill_between(
            spec_raw_pe, rec - Y_rec_unc_specpca, rec + Y_rec_unc_specpca,
            color="g", alpha=0.5, label="u_specpca",
        )
        ax.set_xlabel('Energy (eV) ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(title, y=1.0, pad=-20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    def _save_pes_raw_figure(i_sample, title, filename):
        """Save the test PES row of sample *i_sample*, plotted against its index."""
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.plot(pes_test[i_sample,:], "r", label=f"PES test {i_sample}")
        ax.set_xlabel(' x comps ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(title, y=1.0, pad=-20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    _save_rec_2unc_figure(i_best, "cc best rec", "cc_best_rec.png")
    _save_rec_2unc_figure(i_worst, "cc worst rec", "cc_worse_rec.png")
    _save_pes_raw_figure(i_best, "pes raw (cc best rec)", "pes_raw_best_rec.png")
    _save_pes_raw_figure(i_worst, "pes raw (cc worst rec)", "pes_raw_worst_rec.png")


    # Per-sample error metrics: chi_2 and rmse against the sample index, then
    # rmse against eps_mean. Each row is:
    # (Axes method, x, y, legend label, x label, y label, title, output file).
    metric_figures = [
        ("plot", np.arange(len(chi_2)), chi_2, "chi_2",
         'train IDs.', 'int a.u.', "chi_2", "chi_2.png"),
        ("plot", np.arange(len(rmse)), rmse, "rmse",
         'train IDs.', 'int a.u.', "rmse", "rmse.png"),
        ("scatter", eps_mean, rmse, "rmse_vs_eps_mean",
         'eps mean', 'rmse', "rmse_vs_eps_mean", "rmse_vs_eps_mean.png"),
    ]

    for draw, xs, ys, legend_label, x_label, y_label, plot_title, out_name in metric_figures:
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        getattr(ax, draw)(xs, ys, label=legend_label)
        ax.set_xlabel(x_label, fontsize=22)
        ax.set_ylabel(y_label, fontsize=22)
        for axis_name in ("x", "y"):
            ax.tick_params(axis=axis_name, labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(plot_title, y=1.0, pad=-20, fontsize=22)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{out_name}", bbox_inches='tight')
        plt.close(fig)


    # --- Reconstruction figures with uncertainty bands ---
    # Grids show up to 20 test samples on a 10x2 layout. Clamp to the samples
    # that exist, so a small test set no longer raises IndexError.
    n_grid = min(20, Y_rec.shape[0], spec_test.shape[0])
    # Grids whose titles show eps_mean[i] must also stay within eps_mean.
    n_grid_eps = min(n_grid, len(eps_mean))

    # Bands as (half-width, colour, legend label); each is drawn as
    # reconstruction +/- half-width, in list order.
    band_total = (Y_rec_unc[0,:], "pink", "u_total")
    band_bfgs = (Y_rec_unc_bfgs[0,:], "r", "u_bfgs")
    band_specpca = (Y_rec_unc_specpca, "g", "u_specpca")
    band_eps = (Y_eps[0,:], "blue", "eps")

    def _draw_rec_with_bands(ax, i_sample, bands):
        """Draw the reconstruction (red) and ground truth of test sample
        *i_sample* on *ax*, add the given bands, then set the axis labels."""
        rec = Y_rec[i_sample,:]
        ax.plot(spec_raw_pe, rec, "r", label=f"SPEC rec {i_sample}")
        ax.plot(spec_raw_pe, spec_test[i_sample,:], label=f"SPEC gt {i_sample}")
        for half_width, color, band_label in bands:
            ax.fill_between(
                spec_raw_pe, rec - half_width, rec + half_width,
                color=color, alpha=0.5, label=band_label,
            )
        ax.set_xlabel('Energy (eV) ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)

    def _save_rec_grid(n_samples, bands, show_eps, filename):
        """Save a 10x2 grid of the first *n_samples* test samples; the subplot
        titles include eps_mean when *show_eps* is True."""
        fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
        axes_2 = axes_2.flatten()
        for i_sample in range(n_samples):
            ax = axes_2[i_sample]
            _draw_rec_with_bands(ax, i_sample, bands)
            ax.legend(fontsize=10)
            if show_eps:
                ax.set_title(f"{exp_dir} eps={eps_mean[i_sample]}", y=1.0, pad=-10)
            else:
                ax.set_title(f"{exp_dir}", y=1.0, pad=-10)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    def _save_rec_single(bands, filename):
        """Save a 20x10 figure of test sample 0 with the given bands."""
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        _draw_rec_with_bands(ax, 0, bands)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    # Grid with the total uncertainty.
    _save_rec_grid(n_grid_eps, [band_total], True, "reconstruction.png")
    # Grid with the spectral-PCA uncertainty plus eps.
    _save_rec_grid(n_grid_eps, [band_specpca, band_eps], True, "reconstruction_eps.png")

    print(Y_rec_unc_bfgs[0,:])
    # Grid with both the BFGS and the spectral-PCA uncertainty.
    _save_rec_grid(n_grid, [band_bfgs, band_specpca], False, "reconstruction_2unc.png")

    # Single-sample figures (sample index 0, despite the "_id1" file names).
    _save_rec_single([band_bfgs, band_specpca], "reconstruction_2unc_id1.png")
    _save_rec_single([(Y_rec_unc[0,:], "r", "u_total")], "reconstruction_unc_id1.png")


    # --- Example single-sample figures (sample index 0) ---

    def _save_single_example(y, label, xlabel, filename, x=None):
        """Plot one 1-D series (against *x*, or its index when *x* is None)
        and save it as *filename* in the experiment's summaries folder."""
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        if x is None:
            ax.plot(y, label=label)
        else:
            ax.plot(x, y, label=label)
        ax.set_xlabel(xlabel, fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    # Ground-truth spectrum against the photon-energy axis.
    _save_single_example(spec_test[0,:], "SPEC gt 0", 'Energy (eV) ', "single_spec.png", x=spec_raw_pe)
    # Target Y (previously mislabelled "SPEC gt").
    _save_single_example(Y_test[0,:], "Y gt 0", 'PCA comps.', "single_Y.png")
    # NOTE(review): the two examples below come from the TRAIN split while the
    # ones above use the test split; confirm which split is intended.
    # Legends previously read "SPEC gt" for these as well.
    _save_single_example(pes_train[0,:], "PES train 0", 'Channel comps.', "single_pes.png")
    _save_single_example(X_train[0,:], "X train 0", 'PCA comps.', "single_X.png")


    # --- Reconstructions ordered by eps_mean ---
    # Sample indices in order of increasing eps_mean.
    args_eps = np.argsort(eps_mean)
    # A 10x2 grid holds at most 20 samples.
    n_sorted = min(20, len(args_eps))

    def _save_eps_sorted_grid(sample_ids, filename):
        """Save a 10x2 grid for the test samples in *sample_ids* (at most 20),
        with the u_specpca (green) and eps (blue) bands around each reconstruction."""
        fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
        axes_2 = axes_2.flatten()
        for i_bfgs, i_sample in enumerate(sample_ids):
            ax = axes_2[i_bfgs]
            rec = Y_rec[i_sample,:]
            ax.plot(spec_raw_pe, rec, "r", label=f"SPEC rec {i_bfgs, i_sample}")
            ax.plot(spec_raw_pe, spec_test[i_sample,:], label=f"SPEC gt {i_bfgs, i_sample}")
            ax.fill_between(
                spec_raw_pe, rec - Y_rec_unc_specpca, rec + Y_rec_unc_specpca,
                color="g", alpha=0.5, label="u_specpca",
            )
            ax.fill_between(
                spec_raw_pe, rec - Y_eps[0,:], rec + Y_eps[0,:],
                color="blue", alpha=0.5, label="eps",
            )
            ax.set_xlabel('Energy (eV) ', fontsize=22)
            ax.set_ylabel('int a.u.', fontsize=22)
            ax.tick_params(axis="x", labelsize=20)
            ax.legend(fontsize=10)
            ax.set_title(f"eps={eps_mean[i_sample]}", y=1.0, pad=-10)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{filename}", bbox_inches='tight')
        plt.close(fig)

    # The n_sorted samples with the LARGEST eps_mean (tail of the ascending
    # order). This replaces the hard-coded slice [230:], which picked the top 20
    # only for a test set of exactly 250 samples and raised IndexError below that.
    args_eps_top = args_eps[len(args_eps) - n_sorted:]
    _save_eps_sorted_grid(args_eps_top, "reconstruction_eps_SORTED.png")

    # The n_sorted samples with the SMALLEST eps_mean, in ascending order.
    # NOTE(review): despite the "_descending" file name, this order is ascending.
    _save_eps_sorted_grid(args_eps[:n_sorted], "reconstruction_eps_SORTED_descending.png")

    print("finished eval")


if att_dict["model_type"] == "bfgs_pca_eps":

    # Load the reconstruction, its uncertainty components and the eps array
    # written for this model type.
    (Y_rec, Y_rec_unc_bfgs, Y_rec_unc_specpca,
     Y_rec_unc, Y_eps) = load_rec_data_h5_bfgs_pca_eps(exp_dir)
    print(Y_test.shape, Y_rec.shape)

    # Mean of eps along axis 1 (one value per row), rounded to 2 decimals;
    # used below in plot titles and to order samples.
    eps_mean = np.round(np.mean(Y_eps, axis=1), 2)
    print("len eps mean", len(eps_mean))

    # Integrated intensity: row-wise sum along axis 1.
    def int_sum_1(A):
        """Return ``np.sum(A, axis=1)``."""
        return np.sum(A, axis=1)

    pes_sum, spec_sum = int_sum_1(pes_test), int_sum_1(spec_test)

    # np.corrcoef treats each row of pes_test as one variable, so this is the
    # row-vs-row correlation matrix; cor_mat_int holds its row sums.
    cor_mat = np.corrcoef(pes_test)
    cor_mat_int = np.sum(cor_mat, axis=1)


    # --- Quick FFT evaluation ------------------------------------------------
    # Compare |FFT| of ground-truth (spec_test) and reconstructed (Y_rec)
    # spectra over Fourier components [FC_START, FC_END), grouped into regions
    # of interest (ROIs) of roughly (FC_END - FC_START) / N_ROI components.
    FC_START = 0
    FC_END = 80
    N_ROI = 40

    # Row-wise FFT magnitude, keeping the first half of the components
    # (presumably the spectra are real-valued, so the rest mirrors it).
    n_Y_comps = spec_test.shape[1]
    Y_test_fft_raw = np.abs(np.fft.fft(spec_test))[:, :round(n_Y_comps / 2)]
    Y_rec_fft_raw = np.abs(np.fft.fft(Y_rec))[:, :round(n_Y_comps / 2)]

    # Keep only the selected window of Fourier components.
    Y_test_fft = Y_test_fft_raw[:, FC_START:FC_END]
    Y_rec_fft = Y_rec_fft_raw[:, FC_START:FC_END]

    # ROI width in components. Clamped to >= 1: a width of 0 (N_ROI larger
    # than the window) would make range() below raise ValueError.
    split_size = max(1, round((FC_END - FC_START) / N_ROI))

    def split_to_rois(a, splitedSize=split_size):
        """Split 2-D ``a`` column-wise into consecutive chunks of width ``splitedSize``.

        The last chunk is narrower when the column count is not a multiple
        of ``splitedSize``.
        """
        return [a[:, x:x + splitedSize] for x in range(0, a.shape[1], splitedSize)]

    Y_test_fft_rois = split_to_rois(Y_test_fft)
    Y_rec_fft_rois = split_to_rois(Y_rec_fft)

    # Component indices belonging to each ROI (x axis).
    pixel_rois = split_to_rois(np.arange(FC_START, FC_END, 1).reshape(1, -1))

    # |gt - rec| of the Fourier magnitudes per ROI, per sample.
    delta_Y_fft_rois = [np.abs(a - b) for a, b in zip(Y_test_fft_rois, Y_rec_fft_rois)]
    # Sum over the components inside each ROI -> one value per sample per ROI.
    delta_Y_fft_rois_sum = [np.sum(a, axis=1) for a in delta_Y_fft_rois]
    print("len(delta_Y_fft_rois_sum)", len(delta_Y_fft_rois_sum))
    # Mean over samples -> one value per ROI.
    delta_Y_fft_rois_sum_mean = [x.mean() for x in delta_Y_fft_rois_sum]


    # --- Per-sample reconstruction metrics: Y_rec[i] vs spec_test[i] ---------

    def corr_coef(a, b):
        """Return the Pearson correlation coefficient of ``a`` and ``b``, rounded to 2 decimals."""
        cor_coef = np.corrcoef(a, b)[0][1]
        return np.round(cor_coef, 2)

    cc = [corr_coef(Y_rec[i, :], spec_test[i, :]) for i in range(Y_rec.shape[0])]

    def chi2_distance(A, B):
        """Return the chi-squared distance 0.5 * sum((A - B)**2 / (A + B)).

        Computed element-wise over ``A`` and ``B``. Positions where A and B are
        both 0 (a 0/0 term) contribute 0 instead of NaN, which would otherwise
        turn the whole distance into NaN. A nonzero numerator over a zero
        denominator (A == -B != 0) still yields inf, as before.
        """
        A = np.asarray(A)
        B = np.asarray(B)
        num = (A - B) ** 2
        den = A + B
        # Silence only the 0/0 "invalid" warning; those terms are zeroed below.
        with np.errstate(invalid="ignore"):
            terms = num / den
        terms = np.where((num == 0) & (den == 0), 0.0, terms)
        return 0.5 * np.sum(terms)

    chi_2 = [chi2_distance(Y_rec[i, :], spec_test[i, :]) for i in range(Y_rec.shape[0])]

    def rmse_value(A, B):
        """Return the root-mean-square error between ``A`` and ``B``."""
        return np.sqrt(np.mean((A - B) ** 2))

    rmse = [rmse_value(Y_rec[i, :], spec_test[i, :]) for i in range(Y_rec.shape[0])]
    # Same metric on the truncated FFT magnitudes from the quick FFT eval.
    rmse_quickfft = [rmse_value(Y_rec_fft[i, :], Y_test_fft[i, :]) for i in range(Y_rec_fft.shape[0])]






    # --- Save images -----------------------------------------------------------

    # Sample-vs-sample correlation matrix (np.corrcoef(pes_test), computed
    # earlier in this branch) as an image with a horizontal colorbar on top.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    im = ax.imshow(cor_mat)
    ax.set_xlabel('train IDs.', fontsize=22)
    ax.set_ylabel('train IDs.', fontsize=22)
    ax.tick_params(axis="x", labelsize=20)
    ax.tick_params(axis="y", labelsize=20)
    # No ax.legend() here: the image has no labelled artists, so a legend call
    # only produced a "No artists with labels" warning.
    ax.set_title("cor mat: PCA(PES)", y=1.0, pad=-20, fontsize=22)

    cax = fig.add_axes([0.27, 0.95, 0.5, 0.05])
    fig.colorbar(im, cax=cax, orientation='horizontal')
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/cor_mat_pca_pes.png", bbox_inches='tight')
    plt.close(fig)

    # Correlation profiles of a few hand-picked samples. Rows beyond the size
    # of the test set are skipped instead of raising IndexError.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    for row in (42, 43, 171):
        if row < cor_mat.shape[0]:
            ax.plot(cor_mat[row, :], label=f"{row}")
    ax.legend(fontsize=20)
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/cor_mat_pca_pes_line.png", bbox_inches='tight')
    plt.close(fig)


    # --- Per-sample scalar diagnostics -----------------------------------------
    # One figure per entry:
    # (Axes method, x, y, legend label, x label, y label, title, output file).
    # Two per-sample quantities plotted against each other use "scatter": the
    # samples are not ordered along x, so a connected line would only draw a
    # meaningless zigzag. Index-based series keep "plot".
    cc_norm = np.asarray(cc) / max(cc)
    sample_plots = [
        ("plot", np.arange(len(xgm_pulseen_test)), xgm_pulseen_test,
         "xgm pulse energy: PES", 'train IDs.', 'int a.u.',
         "xgm test pulse energy: PES", "xgm_pulse_energy.png"),
        ("scatter", eps_mean, xgm_pulseen_test,
         "eps_mean xgm pulse energy", 'eps_mean', 'pulse energy',
         "eps_mean vs pulse energy", "pulse_energy_vs_eps_mean.png"),
        ("scatter", eps_mean, rmse_quickfft,
         "eps_mean QuickFFT", 'eps_mean', 'QuickFFT',
         "eps_mean QuickFFT", "eps_mean_vs_quick_fft.png"),
        ("scatter", xgm_pulseen_test, rmse_quickfft,
         "xgm_pulseen_test QuickFFT", 'xgm_pulseen_test', 'QuickFFT',
         "pulse energy QuickFFT", "xgm_pulseen_test_vs_quick_fft.png"),
        ("scatter", cc, rmse_quickfft,
         "cc vs QuickFFT", 'cc', 'QuickFFT',
         "cc vs QuickFFT", "cc_vs_quick_fft.png"),
        ("scatter", spec_sum, xgm_pulseen_test,
         "xgm pulse energy vs spec int", 'spec_sum', 'pulse energy',
         "spec_sum  vs pulse energy", "pulse_energy_vs_spec_int.png"),
        ("scatter", cor_mat_int, xgm_pulseen_test,
         "xgm pulse energy vs cor_mat_int", 'cor_mat_int', 'pulse energy',
         "cor_mat_int vs pulse energy", "pulse_energy_vs_cor_mat_int.png"),
        ("scatter", spec_sum, pes_sum,
         " spec int vs  pes_sum", 'spec_sum', 'pes_sum',
         "spec_sum  vs pes_sum", "spec_int_vs_pes_int.png"),
        ("plot", np.arange(len(pes_sum)), pes_sum,
         "integrated int: PES", 'train IDs.', 'int a.u.',
         "integrated int: PES", "pes_int_sum.png"),
        ("plot", np.arange(len(spec_sum)), spec_sum,
         "integrated int: SPEC", 'train IDs.', 'int a.u.',
         "integrated int: SPEC", "spec_int_sum.png"),
        ("scatter", pes_sum / max(np.array(pes_sum)), cc_norm,
         "correlation coef.", 'pes_sum', 'cor. coef.',
         "pes_sum vs cor. coef.", "pes_int_sum_vs_cor_coef.png"),
        ("scatter", spec_sum / max(np.array(spec_sum)), cc_norm,
         "spec_sum correlation coef", 'spec_sum', 'cor. coef.',
         "spec_sum vs cor. coef.", "spec_int_sum_vs_corr_coef.png"),
        # Quick FFT eval: x follows the actual number of ROIs, which is not
        # always N_ROI (rounding of the ROI width, or a short spectrum).
        ("plot", np.arange(len(delta_Y_fft_rois_sum_mean)), delta_Y_fft_rois_sum_mean,
         "FFT eval 0", 'FFT bin comps.', 'int a.u.',
         "quick fft eval", "quick_fft_eval.png"),
        ("plot", np.arange(len(cc)), cc,
         "correlation coef", 'train IDs.', 'int a.u.',
         "cor. coef.", "corr_coef.png"),
        # y is eps_mean normalised by the integrated spectrum intensity.
        ("scatter", cc_norm, eps_mean / spec_sum,
         "cc vs eps_mean", 'cor coef', 'eps mean / spec_sum',
         "cc vs eps_mean", "cor_coef_vs_eps_mean.png"),
        ("scatter", cc_norm, np.asarray(xgm_pulseen_test) / max(np.array(xgm_pulseen_test)),
         "xgm_pulseen_test", 'cor coef', 'pulse energy',
         "cor. coef. vs pulse energy", "cor_coef_vs_pulse_energy.png"),
    ]

    for method, x, y, label, xlabel, ylabel, title, fname in sample_plots:
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        getattr(ax, method)(x, y, label=label)
        ax.set_xlabel(xlabel, fontsize=22)
        ax.set_ylabel(ylabel, fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(title, y=1.0, pad=-20, fontsize=22)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{fname}", bbox_inches='tight')
        plt.close(fig)


    # Best and worst reconstruction by correlation coefficient: first the
    # spectra with the BFGS and spectral-PCA bands, then the raw PES traces.
    idx_best, idx_worst = np.argmax(cc), np.argmin(cc)

    for k, spec_title, spec_file in ((idx_best, "cc best rec", "cc_best_rec.png"),
                                     (idx_worst, "cc worse rec", "cc_worse_rec.png")):
        rec_k = Y_rec[k, :]
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.plot(spec_raw_pe, rec_k, "r", label=f"SPEC rec {k}")
        ax.plot(spec_raw_pe, spec_test[k, :], label=f"SPEC gt {k}")
        ax.fill_between(spec_raw_pe, rec_k - Y_rec_unc_bfgs[0, :], rec_k + Y_rec_unc_bfgs[0, :],
                        color="r", alpha=0.5, label="u_bfgs")
        ax.fill_between(spec_raw_pe, rec_k - Y_rec_unc_specpca, rec_k + Y_rec_unc_specpca,
                        color="g", alpha=0.5, label="u_specpca")
        ax.set_xlabel('Energy (eV) ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        for which in ("x", "y"):
            ax.tick_params(axis=which, labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(spec_title, y=1.0, pad=-20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{spec_file}", bbox_inches='tight')
        plt.close(fig)

    for k, pes_title, pes_file in ((idx_best, "pes raw (cc best rec)", "pes_raw_best_rec.png"),
                                   (idx_worst, "pes raw (cc worst rec)", "pes_raw_worst_rec.png")):
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.plot(pes_test[k, :], "r", label=f"PES test {k}")
        ax.set_xlabel(' x comps ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        for which in ("x", "y"):
            ax.tick_params(axis=which, labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(pes_title, y=1.0, pad=-20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{pes_file}", bbox_inches='tight')
        plt.close(fig)


    # Chi-squared distance and RMSE per sample index, then RMSE vs eps_mean.
    metric_figures = (
        ("plot", np.arange(len(chi_2)), chi_2, "chi_2",
         'train IDs.', 'int a.u.', "chi_2", "chi_2.png"),
        ("plot", np.arange(len(rmse)), rmse, "rmse",
         'train IDs.', 'int a.u.', "rmse", "rmse.png"),
        ("scatter", eps_mean, rmse, "rmse_vs_eps_mean",
         'eps mean', 'rmse', "rmse_vs_eps_mean", "rmse_vs_eps_mean.png"),
    )
    for draw, xs, ys, legend_label, x_label, y_label, fig_title, out_name in metric_figures:
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        getattr(ax, draw)(xs, ys, label=legend_label)
        ax.set_xlabel(x_label, fontsize=22)
        ax.set_ylabel(y_label, fontsize=22)
        for which in ("x", "y"):
            ax.tick_params(axis=which, labelsize=20)
        ax.legend(fontsize=20)
        ax.set_title(fig_title, y=1.0, pad=-20, fontsize=22)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{out_name}", bbox_inches='tight')
        plt.close(fig)


    # --- Reconstruction grids for the first test samples ----------------------
    # Up to 20 panels in a 10x2 grid. Capped by the number of samples so a test
    # set with fewer than 20 samples does not raise IndexError.
    n_grid = min(20, Y_rec.shape[0])

    # Reconstruction with the total-uncertainty band.
    fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
    axes_2 = axes_2.flatten()
    for i_bfgs in range(n_grid):
        ax = axes_2[i_bfgs]
        ax.plot(spec_raw_pe, Y_rec[i_bfgs, :], "r", label=f"SPEC rec {i_bfgs}")
        ax.plot(spec_raw_pe, spec_test[i_bfgs, :], label=f"SPEC gt {i_bfgs}")
        ax.fill_between(
            spec_raw_pe, Y_rec[i_bfgs, :] - Y_rec_unc[0, :], Y_rec[i_bfgs, :] + Y_rec_unc[0, :],
            color="pink", alpha=0.5, label="u_total"
        )
        ax.set_xlabel('Energy (eV) ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.legend(fontsize=10)
        ax.set_title(f"{exp_dir} eps={eps_mean[i_bfgs]}", y=1.0, pad=-10)

    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/reconstruction.png", bbox_inches='tight')
    plt.close(fig)

    # Reconstruction with the spectral-PCA band and the eps band. The eps band
    # uses each panel's own row of Y_eps, so it matches the eps_mean in the
    # title (previously every panel showed row 0).
    fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
    axes_2 = axes_2.flatten()
    for i_bfgs in range(n_grid):
        ax = axes_2[i_bfgs]
        ax.plot(spec_raw_pe, Y_rec[i_bfgs, :], "r", label=f"SPEC rec {i_bfgs}")
        ax.plot(spec_raw_pe, spec_test[i_bfgs, :], label=f"SPEC gt {i_bfgs}")
        ax.fill_between(
            spec_raw_pe, Y_rec[i_bfgs, :] - Y_rec_unc_specpca, Y_rec[i_bfgs, :] + Y_rec_unc_specpca,
            color="g", alpha=0.5, label="u_specpca"
        )
        ax.fill_between(
            spec_raw_pe, Y_rec[i_bfgs, :] - Y_eps[i_bfgs, :], Y_rec[i_bfgs, :] + Y_eps[i_bfgs, :],
            color="blue", alpha=0.5, label="eps"
        )
        ax.set_xlabel('Energy (eV) ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.legend(fontsize=10)
        ax.set_title(f"{exp_dir} eps={eps_mean[i_bfgs]}", y=1.0, pad=-10)

    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/reconstruction_eps.png", bbox_inches='tight')
    plt.close(fig)

    print(Y_rec_unc_bfgs[0, :])
    # Reconstruction with the BFGS and spectral-PCA uncertainty bands.
    fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
    axes_2 = axes_2.flatten()
    for i_bfgs in range(n_grid):
        ax = axes_2[i_bfgs]
        ax.plot(spec_raw_pe, Y_rec[i_bfgs, :], "r", label=f"SPEC rec {i_bfgs}")
        ax.plot(spec_raw_pe, spec_test[i_bfgs, :], label=f"SPEC gt {i_bfgs}")
        ax.fill_between(
            spec_raw_pe, Y_rec[i_bfgs, :] - Y_rec_unc_bfgs[0, :], Y_rec[i_bfgs, :] + Y_rec_unc_bfgs[0, :],
            color="r", alpha=0.5, label="u_bfgs"
        )
        ax.fill_between(
            spec_raw_pe, Y_rec[i_bfgs, :] - Y_rec_unc_specpca, Y_rec[i_bfgs, :] + Y_rec_unc_specpca,
            color="g", alpha=0.5, label="u_specpca"
        )
        ax.set_xlabel('Energy (eV) ', fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.legend(fontsize=10)
        ax.set_title(f"{exp_dir}", y=1.0, pad=-10)

    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/reconstruction_2unc.png", bbox_inches='tight')
    plt.close(fig)


    # --- Single-sample example figures (sample index 0) -----------------------
    i_sample = 0
    rec0 = Y_rec[i_sample, :]

    # Sample 0 reconstruction with the BFGS and spectral-PCA uncertainty bands.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    ax.plot(spec_raw_pe, rec0, "r", label=f"SPEC rec {i_sample}")
    ax.plot(spec_raw_pe, spec_test[i_sample, :], label=f"SPEC gt {i_sample}")
    ax.fill_between(spec_raw_pe, rec0 - Y_rec_unc_bfgs[0, :], rec0 + Y_rec_unc_bfgs[0, :],
                    color="r", alpha=0.5, label="u_bfgs")
    ax.fill_between(spec_raw_pe, rec0 - Y_rec_unc_specpca, rec0 + Y_rec_unc_specpca,
                    color="g", alpha=0.5, label="u_specpca")
    ax.set_xlabel('Energy (eV) ', fontsize=22)
    ax.set_ylabel('int a.u.', fontsize=22)
    ax.tick_params(axis="x", labelsize=20)
    ax.tick_params(axis="y", labelsize=20)
    ax.legend(fontsize=20)
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/reconstruction_2unc_id1.png", bbox_inches='tight')
    plt.close(fig)

    # Sample 0 reconstruction with the total-uncertainty band.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    ax.plot(spec_raw_pe, rec0, "r", label=f"SPEC rec {i_sample}")
    ax.plot(spec_raw_pe, spec_test[i_sample, :], label=f"SPEC gt {i_sample}")
    ax.fill_between(spec_raw_pe, rec0 - Y_rec_unc[0, :], rec0 + Y_rec_unc[0, :],
                    color="r", alpha=0.5, label="u_total")
    ax.set_xlabel('Energy (eV) ', fontsize=22)
    ax.set_ylabel('int a.u.', fontsize=22)
    ax.tick_params(axis="x", labelsize=20)
    ax.tick_params(axis="y", labelsize=20)
    ax.legend(fontsize=20)
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/reconstruction_unc_id1.png", bbox_inches='tight')
    plt.close(fig)

    # Ground-truth spectrum of sample 0 on the energy axis.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    ax.plot(spec_raw_pe, spec_test[i_sample, :], label=f"SPEC gt {i_sample}")
    ax.set_xlabel('Energy (eV) ', fontsize=22)
    ax.set_ylabel('int a.u.', fontsize=22)
    ax.tick_params(axis="x", labelsize=20)
    ax.tick_params(axis="y", labelsize=20)
    ax.legend(fontsize=20)
    fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/single_spec.png", bbox_inches='tight')
    plt.close(fig)

    # Raw row 0 of Y_test, pes_train and X_train against their component index.
    # Legend labels now name the plotted array (previously all read "SPEC gt").
    for row, label, xlabel, fname in (
        (Y_test[i_sample, :], f"Y test {i_sample}", 'PCA comps.', "single_Y.png"),
        (pes_train[i_sample, :], f"PES train {i_sample}", 'Channel comps.', "single_pes.png"),
        (X_train[i_sample, :], f"X train {i_sample}", 'PCA comps.', "single_X.png"),
    ):
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.plot(row, label=label)
        ax.set_xlabel(xlabel, fontsize=22)
        ax.set_ylabel('int a.u.', fontsize=22)
        ax.tick_params(axis="x", labelsize=20)
        ax.tick_params(axis="y", labelsize=20)
        ax.legend(fontsize=20)
        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{fname}", bbox_inches='tight')
        plt.close(fig)


    # --- Reconstruction grids ordered by eps_mean ------------------------------
    # Sample indices sorted by eps_mean, ascending.
    args_eps = np.argsort(eps_mean)
    # Up to 20 panels (10x2 grid), capped by the number of samples.
    n_grid = min(20, len(args_eps))
    # Window of the eps-sorted samples starting at rank 230 (with exactly 250
    # test samples these are the 20 largest eps_mean). The start is clamped so
    # the window always fits; for >= 250 samples it is still 230, while smaller
    # test sets no longer raise IndexError.
    start = min(230, len(args_eps) - n_grid)
    args_cor_mat_int = args_eps[start:]

    # Two grids with the spectral-PCA band and each sample's own eps band.
    # NOTE(review): "SORTED_descending" actually shows the n_grid *smallest*
    # eps_mean values in ascending order; the file name is kept unchanged for
    # anything that reads it.
    for order, fname in ((args_cor_mat_int, "reconstruction_eps_SORTED.png"),
                         (args_eps, "reconstruction_eps_SORTED_descending.png")):
        fig, axes_2 = plt.subplots(10, 2, figsize=(15, 35))
        axes_2 = axes_2.flatten()
        for i_bfgs in range(n_grid):
            idx = order[i_bfgs]
            ax = axes_2[i_bfgs]
            rec = Y_rec[idx, :]
            ax.plot(spec_raw_pe, rec, "r", label=f"SPEC rec {i_bfgs, idx}")
            ax.plot(spec_raw_pe, spec_test[idx, :], label=f"SPEC gt {i_bfgs, idx}")
            ax.fill_between(spec_raw_pe, rec - Y_rec_unc_specpca, rec + Y_rec_unc_specpca,
                            color="g", alpha=0.5, label="u_specpca")
            # Use this sample's eps row so the band matches the eps in the title.
            ax.fill_between(spec_raw_pe, rec - Y_eps[idx, :], rec + Y_eps[idx, :],
                            color="blue", alpha=0.5, label="eps")
            ax.set_xlabel('Energy (eV) ', fontsize=22)
            ax.set_ylabel('int a.u.', fontsize=22)
            ax.tick_params(axis="x", labelsize=20)
            ax.legend(fontsize=10)
            ax.set_title(f"eps={eps_mean[idx]}", y=1.0, pad=-10)

        fig.savefig(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/summaries/{fname}", bbox_inches='tight')
        plt.close(fig)

    print("finished eval")





#import joblib
#spec_pca_model = joblib.load(f"/home/adavtyan/my_repos/invasive/experiments/{exp_dir}/checkpoints/spec_pca_model.joblib")