from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import numpy as np

import pandas as pd

from src.models.fit_methods.Fit_Methods import FitBFGSPCA_EPS
from src.data.data_preproc import SpecPreprocessing
from src.models.find_components.Find_Component import FindPCAcomps

from src.utils.utils import save_chp_bfgs


class Model:
    """Preprocess PES/spectrum data and delegate fitting to a fit-method object.

    ``model_type`` selects the object stored in ``self.user_defined_model``:
    ``'bfgs'``, ``'bfgs_pca'`` or ``'bfgs_pca_eps'``.

    Any other value, including the default ``None``, leaves
    ``user_defined_model`` set to ``None``. Methods that delegate to it will
    then fail. ``preprocess`` still works as long as spectrum PCA is disabled.
    """

    def __init__(self, model_type=None, data_info=None):
        """Store the configuration and build the fit-method object for ``model_type``.

        Args:
            model_type: One of ``'bfgs'``, ``'bfgs_pca'``, ``'bfgs_pca_eps'``
                or ``None``.
            data_info: Configuration mapping. It is replaced again by
                ``preprocess``.
        """
        self.model_type = model_type
        self.data_info = data_info

        self.data_dict = None

        # Raw / preprocessed arrays, filled in by ``preprocess``.
        self.spec_data = None
        self.spec_data_pe = None
        self.spec_data_gc = None
        self.pes_data = None

        self.pes_train = None
        self.pes_test = None
        self.spec_train = None
        self.spec_test = None

        self.xgm_pulseen = None
        self.xgm_pulseen_train = None
        self.xgm_pulseen_test = None

        # Model inputs/targets. ``preprocess`` fills X_*/Y_*; the legacy
        # ``split`` fills X_*/y_* (lower-case y).
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.Y_train = None
        self.Y_test = None

        self.pes_pca_model = None
        self.spec_pca_model = None

        # Whatever ``user_defined_model.fit_eval`` / ``load_model`` returns.
        self.model = None

        if self.model_type == 'bfgs':
            # FitBFGS was never imported at module level, which made this
            # branch raise NameError. It is imported lazily so a problem here
            # cannot break the other model types.
            # NOTE(review): assumes it lives next to FitBFGSPCA_EPS; confirm.
            from src.models.fit_methods.Fit_Methods import FitBFGS
            self.user_defined_model = FitBFGS()
        elif self.model_type == 'bfgs_pca':
            # Same lazy-import fix as above for FitBFGSPCA.
            from src.models.fit_methods.Fit_Methods import FitBFGSPCA
            self.user_defined_model = FitBFGSPCA()
        elif self.model_type == 'bfgs_pca_eps':
            self.user_defined_model = FitBFGSPCA_EPS()
        else:
            # Unknown or ``None`` type: no fit method is available.
            self.user_defined_model = None

    def split(self, test_size):
        """Legacy split of a tabular ``self.df``: humidity/pressure -> temperature.

        NOTE(review): ``self.df`` is never assigned in this class. Calling this
        therefore raises AttributeError unless a caller sets ``model.df``
        first. It appears unrelated to the pipeline in ``preprocess``.

        Args:
            test_size: Forwarded to ``train_test_split``.
        """
        X = np.array(self.df[['Humidity', 'Pressure (millibars)']])
        y = np.array(self.df['Temperature (C)'])
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=test_size, random_state=42
        )

    def preprocess(self, data_dict, data_info):
        """Smooth the spectra, split train/test, and optionally PCA-reduce each side.

        Args:
            data_dict: Mapping holding the raw data under these keys:
                ``"xgm_pulseen"``, ``"spec_raw_int"`` (spectrum intensity),
                ``"spec_raw_pe"`` (spectrum photon energy) and ``"pes_data"``
                (PES intensity).
            data_info: Configuration mapping. Keys read here:
                ``"use_data_subset"``, ``"data_subset_start"``,
                ``"data_subset_end"``, ``"test_size"``,
                ``"pes_pca_preprocessing"``, ``"n_pca_comps_pes"``,
                ``"spec_pca_preprocessing"`` and ``"n_pca_comps_spec"``.

        Returns:
            ``(X_train, X_test, Y_train, Y_test)``. X comes from the PES data
            and Y from the smoothed spectra; each side is PCA-projected only if
            its flag is set.
        """
        self.data_dict = data_dict
        self.data_info = data_info

        self.xgm_pulseen = data_dict["xgm_pulseen"]
        self.spec_data = data_dict["spec_raw_int"].copy()    # Spec data intensity
        self.spec_data_pe = data_dict["spec_raw_pe"].copy()  # Spec data photon energy
        self.pes_data = data_dict["pes_data"].copy()         # PES data intensity

        # Smooth the spectra. Only the first value returned by
        # gaussian_convolve is used; the second is discarded.
        spec_preproc = SpecPreprocessing(self.spec_data, self.spec_data_pe)
        self.spec_data_gc, _ = spec_preproc.gaussian_convolve()
        self.spec_data = self.spec_data_gc

        # Optionally restrict all three arrays to the same row range, then
        # split them together so the rows stay aligned.
        pes, spec, xgm = self.pes_data, self.spec_data, self.xgm_pulseen
        if self.data_info["use_data_subset"]:
            subset = slice(self.data_info["data_subset_start"],
                           self.data_info["data_subset_end"])
            pes, spec, xgm = pes[subset], spec[subset], xgm[subset]

        (self.pes_train, self.pes_test,
         self.spec_train, self.spec_test,
         self.xgm_pulseen_train, self.xgm_pulseen_test) = train_test_split(
            pes, spec, xgm,
            test_size=self.data_info["test_size"],
            random_state=42,
        )

        if self.data_info["pes_pca_preprocessing"]:
            pes_pca = FindPCAcomps(data_train=self.pes_train,
                                   data_test=self.pes_test,
                                   n_pca_comps=self.data_info["n_pca_comps_pes"])
            self.pes_pca_model, self.X_train, self.X_test = pes_pca.get_pca()
        else:
            self.X_train, self.X_test = self.pes_train, self.pes_test

        if self.data_info["spec_pca_preprocessing"]:
            spec_pca = FindPCAcomps(data_train=self.spec_train,
                                    data_test=self.spec_test,
                                    n_pca_comps=self.data_info["n_pca_comps_spec"])
            self.spec_pca_model, self.Y_train, self.Y_test = spec_pca.get_pca()
            # The fit method receives the spectrum PCA model and the raw test
            # spectra. Presumably this lets it evaluate in spectrum space.
            # TODO confirm.
            self.user_defined_model.pca_model_spec = self.spec_pca_model
            self.user_defined_model.spec_test = self.spec_test
        else:
            self.Y_train, self.Y_test = self.spec_train, self.spec_test

        print("print Y shape train and test in preprocessing last step :::: ",
              self.Y_train.shape, self.Y_test.shape)
        return self.X_train, self.X_test, self.Y_train, self.Y_test

    def save_preproc(self):
        """Persist the preprocessed splits and both PCA models.

        Delegates the split data to ``user_defined_model.save_preproc``. The
        ``'bfgs_pca'`` and ``'bfgs_pca_eps'`` types also receive the XGM
        train/test arrays.

        Both PCA models are then written with joblib to
        ``experiments/<exp_dir>/checkpoints/``. Either model may be ``None``
        if its PCA step was disabled. For any other ``model_type`` this is a
        no-op.
        """
        if self.model_type == 'bfgs':
            extra_args = ()
        elif self.model_type in ('bfgs_pca', 'bfgs_pca_eps'):
            extra_args = (self.xgm_pulseen_train, self.xgm_pulseen_test)
        else:
            return

        self.user_defined_model.save_preproc(self.pes_train, self.pes_test,
                                             self.spec_train, self.spec_test,
                                             self.X_train, self.X_test,
                                             self.Y_train, self.Y_test,
                                             self.spec_data_pe,
                                             self.data_info,
                                             *extra_args)

        import joblib
        ckpt_dir = f"experiments/{self.data_info['exp_dir']}/checkpoints"
        joblib.dump(self.pes_pca_model, f"{ckpt_dir}/pes_pca_model.joblib")
        joblib.dump(self.spec_pca_model, f"{ckpt_dir}/spec_pca_model.joblib")

    def fit_eval(self, X_train, y_train, X_test, Y_test):
        """Fit and evaluate via the fit-method object.

        Its return value is stored in ``self.model``; this method itself
        returns ``None``.
        """
        self.model = self.user_defined_model.fit_eval(X_train, y_train, X_test, Y_test)

    def save_latest_ckp(self):
        """Ask the fit-method object to save its latest checkpoint under ``data_info["exp_dir"]``."""
        self.user_defined_model.save_latest_chp(self.data_info["exp_dir"])

    def load_model(self):
        """Load ``(model, pes_pca_model, spec_pca_model)`` from ``data_info["exp_dir"]``.

        The three values are stored on ``self`` and also returned.
        """
        self.model, self.pes_pca_model, self.spec_pca_model = \
            self.user_defined_model.load_model(self.data_info["exp_dir"])
        return self.model, self.pes_pca_model, self.spec_pca_model

    def predict(self, input_value, target_value=None, spec_target=None):
        """Return the fit-method object's prediction for ``input_value``."""
        return self.user_defined_model.predict(input_value, target_value, spec_target)

    def save_prediction(self):
        """Ask the fit-method object to save its predictions under ``data_info["exp_dir"]``."""
        self.user_defined_model.save_prediction(self.data_info["exp_dir"])

if __name__ == '__main__':
    print("Creating Model Instance")
    # Running this module directly only prints a message; the real pipeline
    # needs a data_dict/data_info from elsewhere.
    # NOTE(review): call order below is inferred from the Model methods;
    # confirm it against the actual driver script.
    #   model = Model(model_type='bfgs_pca_eps', data_info=data_info)
    #   X_train, X_test, Y_train, Y_test = model.preprocess(data_dict, data_info)
    #   model.save_preproc()
    #   model.fit_eval(X_train, Y_train, X_test, Y_test)
    #   model.save_latest_ckp()