import streamlit as st
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
from nltk.tokenize import RegexpTokenizer
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import re
import string
from nltk.stem import WordNetLemmatizer
import time
import transformers
import json
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

from biLSTM1 import biLSTM
from lstm_preprocessing import (
                                data_preprocessing, 
                                get_words_by_freq, 
                                padding, 
                                preprocess_single_string
                                )



# Code sections by author: 1 - Lesha, 2 - Lena, 3 - Gala
# +++++++++++
# 1 -Lesha

# Load the saved TF-IDF vectorizer and logistic-regression model.
# NOTE: pickle.load can execute arbitrary code; these files must be trusted
# local artifacts, never user-supplied uploads.
with open('logistic_regression_model.pkl', 'rb') as file:
    loaded_model_1 = pickle.load(file)

with open('tfidf_vectorizer.pkl', 'rb') as file:
    vectorizer_1 = pickle.load(file)

# English stop words, stored as a frozenset: data_preprocessing checks every
# token against this collection, and set membership is O(1) whereas a list
# lookup scans all entries.
stop_words = frozenset(stopwords.words('english'))
# Tokenizer that splits text into runs of word characters.
tokenizer = RegexpTokenizer(r'\w+')

def data_preprocessing(text: str) -> str:
    """Normalise a raw review for the TF-IDF + logistic-regression model.

    The steps run in this order: lowercase the text, strip HTML tags, delete
    punctuation characters, and split into word tokens. Tokens made only of
    digits and English stop words are then discarded, and the remaining
    tokens are lemmatized.

    Note: this definition replaces the ``data_preprocessing`` imported from
    ``lstm_preprocessing`` for the rest of this module.

    Args:
        text (str): raw review text.

    Returns:
        str: the surviving lemmatized tokens joined by single spaces.
    """
    # NOTE(review): presumably this mirrors the preprocessing used when the
    # pickled vectorizer was fitted; keep the two in step.
    cleaned = re.sub(r'<.*?>', '', text.lower())
    cleaned = cleaned.translate(str.maketrans('', '', string.punctuation))
    lemmatize = WordNetLemmatizer().lemmatize
    kept = []
    for token in tokenizer.tokenize(cleaned):
        if token.isdigit() or token in stop_words:
            continue
        kept.append(lemmatize(token))
    return ' '.join(kept)

# ++++
# Lena


def load_model_l():
    """Build the ERNIE 2.0 encoder with fine-tuned weights, plus its tokenizer.

    The base architecture and tokenizer come from the
    "nghuyong/ernie-2.0-base-en" checkpoint. Its weights are then replaced
    by the state dict stored in 'ErnieModel_imdb.pt', loaded onto the CPU.

    Returns:
        tuple: (encoder model, matching tokenizer).
    """
    checkpoint = "nghuyong/ernie-2.0-base-en"
    ernie = transformers.AutoModel.from_pretrained(
        checkpoint,
        output_attentions=False,
        output_hidden_states=False,
    )
    weights = torch.load('ErnieModel_imdb.pt', map_location=torch.device('cpu'))
    ernie.load_state_dict(weights)
    ernie_tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    return ernie, ernie_tokenizer

def preprocess_text(text_input, max_len, tokenizer):
    """Encode *text_input* with a Hugging Face-style *tokenizer* as PyTorch tensors.

    Padding is enabled, and sequences longer than *max_len* tokens are
    truncated.

    Returns:
        Whatever the tokenizer returns for this call; it is passed on
        unchanged to the model as keyword arguments.
    """
    encode_options = dict(
        return_tensors='pt',
        padding=True,
        max_length=max_len,
        truncation=True,
    )
    return tokenizer(text_input, **encode_options)

def predict_sentiment(model, input_tokens, classifier=None):
    """Classify an encoded review using ERNIE features and a logistic-regression head.

    Args:
        model: ERNIE encoder; the ``pooler_output`` of its forward pass is
            used as the feature vector.
        input_tokens: tokenizer output for a single review (e.g. from
            ``preprocess_text``), passed to the model as keyword arguments.
        classifier: optional pre-loaded classifier exposing ``predict``. If
            omitted, it is unpickled from 'LogReg_imdb_Ernie.pkl' on every
            call, so callers may pass one in to avoid the repeated load.

    Returns:
        str: "negative" or "positive".
    """
    id2label = {0: "negative", 1: "positive"}
    # Inference only: no_grad avoids building an autograd graph through the encoder.
    with torch.no_grad():
        # .cpu() so that .numpy() also works when the model lives on a GPU.
        features = model(**input_tokens).pooler_output.detach().cpu().numpy()
    if classifier is None:
        # NOTE: pickle.load can execute arbitrary code; only load a trusted local file.
        with open('LogReg_imdb_Ernie.pkl', 'rb') as file:
            classifier = pickle.load(file)
    # predict() returns one label per row; take the first (single review).
    # Calling int() directly on a 1-element array is deprecated in NumPy >= 1.25.
    return id2label[int(classifier.predict(features)[0])]

# ++++
# Gala
# Word -> integer id mapping used to encode reviews for the biLSTM.
# JSON is UTF-8 by specification; state it explicitly so non-ASCII words load
# the same way regardless of the platform's default locale encoding.
with open('vocab_to_int.json', 'r', encoding='utf-8') as fp:
    vocab_to_int = json.load(fp)


# +1 reserves one id beyond the vocabulary entries
# (presumably 0 for padding -- confirm against lstm_preprocessing).
VOCAB_SIZE = len(vocab_to_int)+1
# biLSTM hyperparameters; these must match the ones used to train the
# checkpoint, or load_state_dict will fail on shape mismatches.
EMBEDDING_DIM = 32
HIDDEN_DIM = 64
N_LAYERS = 3
# Sequence length passed to preprocess_single_string.
SEQ_LEN = 128

def load_model_g():
    """Build the bidirectional LSTM and load its trained weights for inference.

    The architecture is configured from the module-level VOCAB_SIZE,
    EMBEDDING_DIM, HIDDEN_DIM and N_LAYERS constants. Weights are read from
    'biLSTM_model_do_05_lr001_best.pt' and mapped onto the CPU.

    Returns:
        biLSTM: the loaded model, switched to evaluation mode.
    """
    model = biLSTM(
        vocab_size=VOCAB_SIZE,
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        n_layers=N_LAYERS,
    )
    model.load_state_dict(
        torch.load('biLSTM_model_do_05_lr001_best.pt', map_location=torch.device('cpu'))
    )
    # Modules start in training mode. Without eval(), any dropout layers
    # stay active at inference and the same review can get different labels.
    model.eval()
    return model

# Run the biLSTM on a GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def predict_sentence(text: str, model: nn.Module) -> str:
    """Classify *text* with the biLSTM and return "negative" or "positive".

    The text is encoded by ``preprocess_single_string`` (using SEQ_LEN and
    vocab_to_int), given a leading batch dimension of one, and fed to
    *model* on ``device``. Note that *model* is moved to ``device`` in place.
    """
    id2label = {0: "negative", 1: "positive"}
    model = model.to(device)
    batch = preprocess_single_string(text, SEQ_LEN, vocab_to_int).unsqueeze(0).to(device)
    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    # Assumes the model outputs a single probability in [0, 1] (sigmoid head),
    # so rounding yields the class id. TODO: confirm against biLSTM.
    pred = int(output.round().item())
    return id2label[pred]
 


# ++++++
# Lesha


# Create the Streamlit app
def _show_logreg_prediction(review: str) -> None:
    """Classify *review* with TF-IDF + logistic regression and render the result."""
    start_time = time.perf_counter()
    # Preprocessing and vectorisation are timed too, so this figure is
    # comparable with the other columns, whose timings include tokenisation.
    input_vector = vectorizer_1.transform([data_preprocessing(review)])
    proba_positive = float(loaded_model_1.predict_proba(input_vector)[0, 1])
    elapsed = time.perf_counter() - start_time

    st.header('Classic ML (LogReg on TF-IDF)')
    if round(proba_positive) == 0:
        st.write('The sentiment of your review is negative.')
        confidence = 1 - proba_positive
    else:
        st.write('The sentiment of your review is positive.')
        confidence = proba_positive
    # Round after scaling, to avoid float artefacts like 57.99999999999999 %.
    st.write('Predicted probability:', round(confidence * 100, 2), '%')
    st.write('Processing time:', round(elapsed, 4), 'seconds')


def _show_ernie_prediction(review: str) -> None:
    """Classify *review* with the fine-tuned ERNIE model and render the result."""
    ernie_model, ernie_tokenizer = load_model_l()
    start_time = time.perf_counter()
    input_tokens = preprocess_text(review, 500, ernie_tokenizer)
    output = predict_sentiment(ernie_model, input_tokens)
    elapsed = time.perf_counter() - start_time

    st.header('ErnieModel')
    st.write('The sentiment of your review is', output)
    st.write('Processing time:', round(elapsed, 4), 'seconds')


def _show_lstm_prediction(review: str) -> None:
    """Classify *review* with the bidirectional LSTM and render the result."""
    lstm_model = load_model_g()
    start_time = time.perf_counter()
    output = predict_sentence(review, lstm_model)
    elapsed = time.perf_counter() - start_time

    st.header('bidirectional LSTM')
    st.write('The sentiment of your review is', output)
    st.write('Processing time:', round(elapsed, 4), 'seconds')


# Create the Streamlit app
def main():
    """Render the sentiment-analysis page and, on "Predict!", run all three models.

    Each model shows its label and elapsed time in its own column: classic ML
    (TF-IDF + logistic regression), the fine-tuned ERNIE encoder with a
    logistic-regression head, and the bidirectional LSTM.
    """
    st.title('Sentiment Analysis App')
    st.header('Classic ML, ErnieModel, bidirectional LSTM')
    user_input = st.text_area('Please enter your review:')
    st.write(user_input)
    submit = st.button("Predict!")
    col1, col2, col3 = st.columns(3)
    if not submit:
        return
    # An empty text area yields '' rather than None, so test for actual
    # content; otherwise empty reviews would be sent to the models.
    if not (user_input and user_input.strip()):
        st.warning('Please enter a review before predicting.')
        return
    with col1:
        _show_logreg_prediction(user_input)
    with col2:
        _show_ernie_prediction(user_input)
    with col3:
        _show_lstm_prediction(user_input)
        
        


# Entry point: render the app when this file is executed as a script.
if __name__ == '__main__':
    main()