import sys
sys.path.append("./")
sys.path.append("..")

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from src.data.data import ReadPesSpec, PesChannelSelector
from src.data.data_preproc import SpecPreprocessing

from src.models.find_components.Find_Component import FindPCAcomps
from src.models.fit_methods.Fit_Methods import FitBFGS, FitBFGSPCA
from src.models.fit_methods.model import Model

from src.utils.utils import load_train_test_h5, load_rec_data_h5, load_rec_data_h5_bfgs_pca, load_checkpoint
from src.utils.utils import create_experiment_dirs

from sklearn.model_selection import train_test_split
import numpy as np


# ---------------------------------------------------------------------------
# Inference script: reload a previously trained "bfgs_pca_eps" experiment,
# read and preprocess a different run with the same settings, project it with
# the loaded PES / spectrum PCA models, and save predictions into a new
# "<exp_dir>_inference_TestDATA<run_number>" experiment directory.
# ---------------------------------------------------------------------------

# Load the saved train/test data and attribute dict of the trained experiment.
exp_dir = "test3_pulseen_short_test_eps_r0015"
(Y_train_model, Y_test_model, spec_train_model, spec_test_model, spec_raw_pe,
 X_train_model, X_test_model, pes_train_model, pes_test_model, att_dict,
 xgm_pulseen_train, xgm_pulseen_test) = load_train_test_h5(exp_dir)

# Mark the PCA preprocessing as non-trainable: the PCA models returned by
# load_model() below are reused instead of being fitted on the new run.
att_dict["pes_pca_preprocessing"] = False
att_dict["spec_pca_preprocessing"] = False

# Select the run to perform inference on. The experiment dir is redirected
# further below (see inf_dir), so the commented override is not needed.
# att_dict["exp_dir"] = "test3_inference"
att_dict["run_number"] = "r0014"
# NOTE(review): key spelling "spec_ofset" kept as-is; presumably it is read
# under this exact name by the data/preprocessing code - confirm before renaming.
att_dict["spec_ofset"] = -2

# Read the raw PES / spectrum data for the selected run.
RPS = ReadPesSpec(att_dict)
data_dict = RPS.get_data()

# Preprocess the new run with the model's preprocessing pipeline.
model_instance = Model(model_type="bfgs_pca_eps", data_info=att_dict)
pes_train, pes_test, spec_train, spec_test = model_instance.preprocess(data_dict, att_dict)

# Load the trained model and its fitted PCA transforms.
# TODO: move model loading into a dedicated inference script.
inference_model, pes_pca_model, spec_pca_model = model_instance.load_model()

# Project the new run onto the trained PCA bases so the model sees inputs and
# targets in the same reduced space it was trained in.
model_instance.X_train, model_instance.X_test = pes_pca_model.transform(pes_train), pes_pca_model.transform(pes_test)
model_instance.Y_train, model_instance.Y_test = spec_pca_model.transform(spec_train), spec_pca_model.transform(spec_test)

# Write all inference outputs into a separate experiment directory so the
# trained experiment's files are not overwritten.
inf_dir = exp_dir + "_inference_TestDATA" + att_dict["run_number"]
model_instance.data_info["exp_dir"] = inf_dir

create_experiment_dirs(inf_dir)
model_instance.save_preproc()

# Predict on the new run's test split. spec_target must be the spectra that
# belong to the same samples as input_value/target_value, i.e. this run's
# spec_test (from which Y_test was derived). The previous code passed
# spec_test_model, the training experiment's test spectra, which do not match
# these inputs.
result = model_instance.predict(input_value=model_instance.X_test,
                                target_value=model_instance.Y_test,
                                spec_target=spec_test)

model_instance.save_prediction()

print("Finished Inference")