import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import dotenv_values
from keras.layers import Input, Dense, Flatten
from keras.models import Model
from sklearn.metrics import mean_squared_error
# from tensorflow.python.ops.confusion_matrix import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split

from Database import Database

class Autoencoder:
    """Dense autoencoder used as a one-class detector over stored embeddings.

    The model is fitted on the positive embeddings only and then evaluated on
    a mix of positive and negative embeddings: a sample whose reconstruction
    error is at or below a threshold is predicted as class 1, anything above
    the threshold as class 0.
    """

    def __get_autoencoder(self, input_dim) -> "Model":
        """Build and compile the 128-64-128 dense autoencoder.

        Args:
            input_dim: Length of each input vector; also the size of the
                reconstructed output.

        Returns:
            A Keras model compiled with the Adam optimizer and MSE loss.
        """
        input_layer = Input(shape=(input_dim,))

        # Encoder. Flatten leaves an already flat feature vector unchanged;
        # it is kept so the architecture stays exactly as before.
        encoded = Flatten()(input_layer)
        encoded = Dense(128, activation='relu')(encoded)
        encoded = Dense(64, activation='relu')(encoded)

        # Decoder.
        # NOTE(review): a sigmoid output is bounded to (0, 1). Confirm the
        # stored embeddings are scaled to that range; values outside it can
        # never be reconstructed exactly.
        decoded = Dense(128, activation='relu')(encoded)
        decoded = Dense(input_dim, activation='sigmoid')(decoded)

        autoencoder = Model(inputs=input_layer, outputs=decoded)
        autoencoder.compile(optimizer='adam', loss='mse')
        return autoencoder

    def __print_summary(self, model: "Model"):
        """Print the layer-by-layer summary of ``model``."""
        # Model.summary() prints on its own and returns None, so wrapping it
        # in print() would emit a stray "None" line.
        model.summary()

    def __fit_autoencoder(self, epochs, batch_size, model: "Model", train_var, valid_var=None):
        """Fit ``model`` to reconstruct ``train_var``.

        Args:
            epochs: Number of training epochs.
            batch_size: Mini-batch size.
            model: Compiled Keras model to train.
            train_var: Training samples; used both as input and as target.
            valid_var: Optional validation samples. When given they are used
                as input and target for validation; when ``None`` no
                validation is run.

        Returns:
            A ``(history, model)`` tuple with the Keras training history and
            the fitted model.
        """
        validation_data = None if valid_var is None else (valid_var, valid_var)
        history = model.fit(
            train_var,
            train_var,
            validation_data=validation_data,
            epochs=epochs,
            batch_size=batch_size,
        )
        return history, model

    def __split_train_test_val(self, data):
        """Split ``data`` into train, validation and test parts.

        20% of the data is held out for testing, then 10% of the remainder
        is held out for validation. Both splits use a fixed seed.

        Returns:
            A ``(train, validation, test)`` tuple.
        """
        remainder, test_array = train_test_split(data, test_size=0.2, random_state=42)
        train_array, valid_array = train_test_split(remainder, test_size=0.1, random_state=42)
        return train_array, valid_array, test_array

    @staticmethod
    def __compute_metrics(conf_matrix):
        """Compute precision, recall and F1 from a 2x2 confusion matrix.

        Args:
            conf_matrix: Matrix indexed as ``[true_class][predicted_class]``,
                where class 1 is the positive class.

        Returns:
            A ``(precision, recall, f1)`` tuple. A metric whose denominator
            is zero is reported as ``0.0`` instead of raising or yielding NaN.
        """
        true_pos = conf_matrix[1][1]
        false_pos = conf_matrix[0][1]
        false_neg = conf_matrix[1][0]

        predicted_pos = true_pos + false_pos
        precision = true_pos / predicted_pos if predicted_pos else 0.0

        # Debug aid kept from the original: dump the matrix whenever the
        # precision is exactly 1.
        if precision == 1:
            print(conf_matrix)

        actual_pos = true_pos + false_neg
        recall = true_pos / actual_pos if actual_pos else 0.0

        pr_sum = precision + recall
        f1 = (2 * precision * recall) / pr_sum if pr_sum else 0.0
        return precision, recall, f1

    @staticmethod
    def __labels_from_threshold(errors, threshold):
        """Turn reconstruction errors into class predictions.

        Returns an integer array holding 0 where the error is strictly above
        ``threshold`` and 1 everywhere else.
        """
        return np.where(np.asarray(errors) > threshold, 0, 1)

    @staticmethod
    def __reconstruction_errors(inputs, reconstructions):
        """Return the per-sample mean squared error between two 2-D arrays."""
        diff = np.asarray(inputs) - np.asarray(reconstructions)
        return np.mean(np.power(diff, 2), axis=1)

    def __find_optimal_modified(self, error_df: pd.DataFrame, steps=50):
        """Score the midpoint threshold between the smallest and largest error.

        Args:
            error_df: Frame with ``Reconstruction_error`` and ``True_class``
                columns.
            steps: Unused; kept so the signature matches ``__find_optimal``.

        Returns:
            A ``(threshold, precision, recall, f1)`` tuple, where the metrics
            are macro-averaged.
        """
        errors = error_df["Reconstruction_error"]
        threshold = (errors.min() + errors.max()) / 2
        y_pred = self.__labels_from_threshold(errors.values, threshold)
        precision, recall, f1, _ = precision_recall_fscore_support(
            error_df.True_class, y_pred, average='macro')
        return threshold, precision, recall, f1

    def __find_optimal(self, error_df: pd.DataFrame, steps=50):
        """Sweep thresholds and keep the one with the highest F1.

        Args:
            error_df: Frame with ``Reconstruction_error`` and ``True_class``
                columns.
            steps: Number of evenly spaced thresholds tried between the
                smallest and the largest reconstruction error (inclusive).

        Returns:
            A ``(threshold, precision, recall, f1)`` tuple for the best
            threshold, using binary-averaged metrics. When no threshold
            reaches an F1 above 0 the smallest error is returned with all
            metrics at 0.
        """
        errors = error_df["Reconstruction_error"].values
        min_error, max_error = errors.min(), errors.max()

        optimal_threshold = min_error
        max_f1 = 0
        max_pr = 0
        max_re = 0
        # An evenly spaced grid scales with the observed error range; a fixed
        # stride can collapse to a single candidate when the range is narrow.
        for threshold in np.linspace(min_error, max_error, num=max(int(steps), 1)):
            y_pred = self.__labels_from_threshold(errors, threshold)
            precision, recall, f1, _ = precision_recall_fscore_support(
                error_df.True_class, y_pred, average='binary')

            if f1 > max_f1:
                max_f1 = f1
                optimal_threshold = threshold
                max_pr = precision
                max_re = recall
        print(f"Result optimal_threshold={optimal_threshold}, max_precision={max_pr}, max_recall={max_re}, max_f1={max_f1}")
        return optimal_threshold, max_pr, max_re, max_f1

    @staticmethod
    def __split_by_percent(data, percent):
        """Split ``data`` into ``(train, test)`` with a fixed seed.

        Args:
            data: Samples to split.
            percent: Share of the samples that goes to the test part.
        """
        return train_test_split(data, test_size=percent, random_state=42)

    def __load_embeddings(self):
        """Read the positive and negative embeddings from the database.

        Returns:
            A ``(positives, negatives)`` tuple of lists, one entry per
            document whose ``refactoring_type`` is "Extract Method".
        """
        db = Database(dotenv_values(".env")['COLLECTION_NAME'])
        pos_emb_list, neg_emb_list = [], []
        for doc in db.find_docs({"refactoring_type": "Extract Method"}):
            pos_emb_list.append(doc['embedding_pos'])
            neg_emb_list.append(doc['embedding_neg'])
        return pos_emb_list, neg_emb_list

    def train_autoencoder(self, epochs=25, batch_size=32):
        """Train on positive embeddings, evaluate, and store the results.

        Writes the fitted model to ``./results/autoencoder_<epochs>.hdf5``,
        the threshold and metrics to ``./results/metrics.json`` and the
        training-loss curve to ``./results/training_graph.png``.

        Args:
            epochs: Number of training epochs.
            batch_size: Mini-batch size used while fitting.

        Raises:
            ValueError: If the database query returns no documents.
        """
        # Make sure the output directory exists before any training time is
        # spent, so the later writes cannot fail on a missing folder.
        os.makedirs('./results', exist_ok=True)

        # 768 is the embedding width (original note: GraphCodeBERT).
        autoencoder = self.__get_autoencoder(768)
        self.__print_summary(autoencoder)

        pos_emb_list, neg_emb_list = self.__load_embeddings()
        if not pos_emb_list:
            raise ValueError('No "Extract Method" documents found; nothing to train on.')

        pos_emb_list_train, pos_emb_list_test = self.__split_by_percent(pos_emb_list, 0.3)
        _, neg_emb_list_test = self.__split_by_percent(neg_emb_list, 0.3)

        # Only positives are used for fitting; the test set mixes held-out
        # positives (label 1) with held-out negatives (label 0).
        x_train = np.array(pos_emb_list_train)
        x_test = np.array(pos_emb_list_test + neg_emb_list_test)
        y_test = np.concatenate([
            np.ones(len(pos_emb_list_test), dtype=int),
            np.zeros(len(neg_emb_list_test), dtype=int),
        ])

        history, trained_model = self.__fit_autoencoder(epochs, batch_size, autoencoder, x_train)
        trained_model.save('./results/autoencoder_' + str(epochs) + '.hdf5')

        # Evaluate: per-sample reconstruction error on the mixed test set.
        test_predict = trained_model.predict(x_test)
        mse = self.__reconstruction_errors(x_test, test_predict)

        error_df = pd.DataFrame({'Reconstruction_error': mse,
                                 'True_class': y_test})

        print("Max: ", error_df["Reconstruction_error"].max())
        print("Min: ", error_df["Reconstruction_error"].min())

        optimal_threshold, precision, recall, f1 = self.__find_optimal_modified(error_df, 100)
        print(f"Result optimal_threshold={optimal_threshold}, max_precision={precision}, max_recall={recall}, max_f1={f1}")

        # Cast to built-in floats so json can serialise any NumPy scalar type.
        metrics = {
            "Threshold": float(optimal_threshold),
            "Precision": float(precision),
            "Recall": float(recall),
            "F1": float(f1),
        }
        with open('./results/metrics.json', 'w') as fp:
            json.dump(metrics, fp)

        plt.plot(history.history['loss'])
        plt.savefig("./results/training_graph.png")

if __name__ == "__main__":
    # Script entry point: build the trainer and run the full
    # train / evaluate / save pipeline once.
    trainer = Autoencoder()
    trainer.train_autoencoder()