import os
import sys
sys.path.append("./")
sys.path.append("..")

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from src.data.data import ReadPesSpec, PesChannelSelector
from src.data.data_preproc import SpecPreprocessing

from src.models.find_components.Find_Component import FindPCAcomps
from src.models.fit_methods.Fit_Methods import FitBFGS, FitBFGSPCA
from src.models.fit_methods.model import Model

from src.utils.utils import load_train_test_h5, load_rec_data_h5, load_rec_data_h5_bfgs_pca, load_checkpoint
from src.utils.utils import create_experiment_dirs

from sklearn.model_selection import train_test_split
import numpy as np


# ---------------------------------------------------------------------------
# Load a previously trained experiment and rebuild its inference model.
#
# The experiment directory defaults to the hard-coded run below. Set the
# EXP_DIR environment variable to load a different run without editing this
# file, e.g.  EXP_DIR=my_other_run python <this script>
# ---------------------------------------------------------------------------
exp_dir = os.environ.get("EXP_DIR", "test3_pulseen_short_test_eps_r0015")

# load_train_test_h5 returns 12 values. Most are not used below, but they are
# deliberately left bound at module level so they remain available for
# inspection after the script runs (e.g. under `python -i`).
(
    Y_train_model,
    Y_test_model,
    spec_train_model,
    spec_test_model,
    spec_raw_pe,
    X_train_model,
    X_test_model,
    pes_train_model,
    pes_test_model,
    att_dict,
    xgm_pulseen_train,
    xgm_pulseen_test,
) = load_train_test_h5(exp_dir)

# Disable the PES and spectrum PCA preprocessing flags in the saved metadata
# (original author's note: "set pca non trainable"). Presumably this makes
# Model.load_model reuse the stored PCA models instead of refitting them --
# NOTE(review): confirm against Model's handling of these keys.
att_dict["pes_pca_preprocessing"] = False
att_dict["spec_pca_preprocessing"] = False

model_instance = Model(model_type="bfgs_pca_eps", data_info=att_dict)

# TODO: move model loading into a dedicated inference script.
inference_model, pes_pca_model, spec_pca_model = model_instance.load_model()

print("Finish")