import os
import shutil
import json
import re
import gradio as gr
import requests
import random
import urllib.parse
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from typing import List
from tempfile import NamedTemporaryFile
from bs4 import BeautifulSoup
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
from llama_parse import LlamaParse

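# API credentials are read from the environment; set HUGGINGFACE_TOKEN and
# LLAMA_CLOUD_API_KEY before launching the app.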
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")

# Load SentenceTransformer model
sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

def load_spacy_model():
    try:
        # Try to load the model
        return spacy.load("en_core_web_sm")
    except OSError:
        # If the model is missing, download it via spaCy's own CLI, then load it again
        spacy.cli.download("en_core_web_sm")
        return spacy.load("en_core_web_sm")

# Load spaCy model
nlp = load_spacy_model()

class EnhancedContextDrivenChatbot:
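    """Keeps a sliding window of recent conversation turns and a tracker of named
    entities, and rewrites incoming questions with that context when they look
    like follow-ups or are semantically close to the history."""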
    def __init__(self, history_size=10):
        self.history = []
        self.history_size = history_size
        self.entity_tracker = {}

    def add_to_history(self, text):
        self.history.append(text)
        if len(self.history) > self.history_size:
            self.history.pop(0)
        
        # Update entity tracker
        doc = nlp(text)
        for ent in doc.ents:
            if ent.label_ not in self.entity_tracker:
                self.entity_tracker[ent.label_] = set()
            self.entity_tracker[ent.label_].add(ent.text)

    def get_context(self):
        return " ".join(self.history)

    def is_follow_up_question(self, question):
        doc = nlp(question.lower())
        follow_up_indicators = set(['it', 'this', 'that', 'these', 'those', 'he', 'she', 'they', 'them'])
        return any(token.text in follow_up_indicators for token in doc)

    def extract_topics(self, text):
        doc = nlp(text)
        return [chunk.text for chunk in doc.noun_chunks]

    def get_most_relevant_context(self, question):
        if not self.history:
            return question

        # Create a combined context from history
        combined_context = self.get_context()
        
        # Get embeddings
        context_embedding = sentence_model.encode([combined_context])[0]
        question_embedding = sentence_model.encode([question])[0]
        
        # Calculate similarity
        similarity = cosine_similarity([context_embedding], [question_embedding])[0][0]
        
        # If similarity is low, it might be a new topic
        if similarity < 0.3:  # This threshold can be adjusted
            return question
        
        # Otherwise, prepend the context
        return f"{combined_context} {question}"

    def process_question(self, question):
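        """Contextualize the question against the conversation history, extract its
        topics, record it in history, and return (contextualized_question, topics, entity_tracker)."""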
        contextualized_question = self.get_most_relevant_context(question)
        
        # Extract topics from the question
        topics = self.extract_topics(question)
        
        # Check if it's a follow-up question
        if self.is_follow_up_question(question):
            # If it's a follow-up, make sure to include previous context
            contextualized_question = f"{self.get_context()} {question}"
        
        # Add the new question to history
        self.add_to_history(question)
        
        return contextualized_question, topics, self.entity_tracker
        
# Initialize LlamaParse
llama_parser = LlamaParse(
    api_key=llama_cloud_api_key,
    result_type="markdown",
    num_workers=4,
    verbose=True,
    language="en",
)

def load_document(file: NamedTemporaryFile, parser: str = "pypdf") -> List[Document]:
    """Loads and splits the document into pages."""
    if parser == "pypdf":
        loader = PyPDFLoader(file.name)
        return loader.load_and_split()
    elif parser == "llamaparse":
        try:
            documents = llama_parser.load_data(file.name)
            return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
        except Exception as e:
            print(f"Error using Llama Parse: {str(e)}")
            print("Falling back to PyPDF parser")
            loader = PyPDFLoader(file.name)
            return loader.load_and_split()
    else:
        raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")

def update_vectors(files, parser):
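    """Parse the uploaded PDFs with the selected parser and add their pages to the persistent FAISS vector store."""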
    if not files:
        return "Please upload at least one PDF file."
    
    embed = get_embeddings()
    total_chunks = 0
    
    all_data = []
    for file in files:
        data = load_document(file, parser)
        all_data.extend(data)
        total_chunks += len(data)
    
    if os.path.exists("faiss_database"):
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
        database.add_documents(all_data)
    else:
        database = FAISS.from_documents(all_data, embed)
    
    database.save_local("faiss_database")
    
    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."

def get_embeddings():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

def clear_cache():
    if os.path.exists("faiss_database"):
        # FAISS.save_local writes a directory, so remove the whole tree
        shutil.rmtree("faiss_database")
        return "Cache cleared successfully."
    else:
        return "No cache to clear."

def get_model(temperature, top_p, repetition_penalty):
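    """Build the HuggingFaceHub LLM (Mistral-7B-Instruct-v0.3) with the requested sampling parameters."""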
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        model_kwargs={
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "max_length": 1000
        },
        huggingfacehub_api_token=huggingface_token
    )

def generate_chunked_response(model, prompt, max_tokens=1000, max_chunks=5):
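    """Query the model up to max_chunks times, appending each chunk to the running
    response, and stop early once a chunk ends with sentence-final punctuation."""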
    full_response = ""
    for i in range(max_chunks):
        try:
            chunk = model(prompt + full_response, max_new_tokens=max_tokens)
            chunk = chunk.strip()
            if chunk.endswith((".", "!", "?")):
                full_response += chunk
                break
            full_response += chunk
        except Exception as e:
            print(f"Error in generate_chunked_response: {e}")
            break
    return full_response.strip()

def extract_text_from_webpage(html):
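    """Remove script and style tags and return the page's visible text, one non-empty chunk per line."""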
    soup = BeautifulSoup(html, 'html.parser')
    for script in soup(["script", "style"]):
        script.extract()
    text = soup.get_text()
    lines = (line.strip() for line in text.splitlines())
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    text = '\n'.join(chunk for chunk in chunks if chunk)
    return text

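# Desktop user-agent strings; one is chosen at random for each search request.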
_useragent_list = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
]

def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
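    """Scrape Google search result pages for `term`, fetch each linked page, and return a
    list of {"link", "text"} dicts with visible text truncated to max_chars_per_page."""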
    escaped_term = urllib.parse.quote_plus(term)
    start = 0
    all_results = []
    max_chars_per_page = 8000

    print(f"Starting Google search for term: '{term}'")

    with requests.Session() as session:
        while start < num_results:
            try:
                user_agent = random.choice(_useragent_list)
                headers = {
                    'User-Agent': user_agent
                }
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                print(f"Successfully retrieved search results page (start={start})")
            except requests.exceptions.RequestException as e:
                print(f"Error retrieving search results: {e}")
                break

            soup = BeautifulSoup(resp.text, "html.parser")
            result_block = soup.find_all("div", attrs={"class": "g"})
            if not result_block:
                print("No results found on this page")
                break
            
            print(f"Found {len(result_block)} results on this page")
            for result in result_block:
                link = result.find("a", href=True)
                if link:
                    link = link["href"]
                    print(f"Processing link: {link}")
                    try:
                        webpage = session.get(link, headers=headers, timeout=timeout)
                        webpage.raise_for_status()
                        visible_text = extract_text_from_webpage(webpage.text)
                        if len(visible_text) > max_chars_per_page:
                            visible_text = visible_text[:max_chars_per_page] + "..."
                        all_results.append({"link": link, "text": visible_text})
                        print(f"Successfully extracted text from {link}")
                    except requests.exceptions.RequestException as e:
                        print(f"Error retrieving webpage content: {e}")
                        all_results.append({"link": link, "text": None})
                else:
                    print("No link found for this result")
                    all_results.append({"link": None, "text": None})
            start += len(result_block)

    print(f"Search completed. Total results: {len(all_results)}")
    
    if not all_results:
        print("No search results found. Returning a default message.")
        return [{"link": None, "text": "No information found in the web search results."}]

    return all_results

def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
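    """Answer a question either from web search results (with the query contextualized by the
    chatbot's history) or from the uploaded PDFs in the local FAISS store."""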
    if not question:
        return "Please enter a question."

    model = get_model(temperature, top_p, repetition_penalty)
    embed = get_embeddings()

    if os.path.exists("faiss_database"):
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    else:
        database = None

    max_attempts = 3
    context_reduction_factor = 0.7

    if web_search:
        contextualized_question, topics, entity_tracker = chatbot.process_question(question)
        serializable_entity_tracker = {k: list(v) for k, v in entity_tracker.items()}
        search_results = google_search(contextualized_question)
        all_answers = []
        web_docs = []  # keep a reference outside the retry loop for the sources section

        for attempt in range(max_attempts):
            try:
                web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]

                if database is None:
                    database = FAISS.from_documents(web_docs, embed)
                else:
                    database.add_documents(web_docs)

                database.save_local("faiss_database")

                context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])

                prompt_template = """
                Answer the question based on the following web search results, conversation context, and entity information:
                Web Search Results:
                {context}
                Conversation Context: {conv_context}
                Current Question: {question}
                Topics: {topics}
                Entity Information: {entities}
                If the web search results don't contain relevant information, state that the information is not available in the search results.
                Provide a summarized and direct answer to the question without mentioning the web search or these instructions.
                Do not include any source information in your answer.                    
                """

                prompt_val = ChatPromptTemplate.from_template(prompt_template)
                formatted_prompt = prompt_val.format(
                    context=context_str, 
                    conv_context=chatbot.get_context(), 
                    question=question,
                    topics=", ".join(topics),
                    entities=json.dumps(serializable_entity_tracker)
                )

                full_response = generate_chunked_response(model, formatted_prompt)
                answer = extract_answer(full_response)
                all_answers.append(answer)
                break

            except Exception as e:
                print(f"Error in ask_question (attempt {attempt + 1}): {e}")
                if attempt == max_attempts - 1:
                    all_answers.append(f"I apologize, but I'm having trouble processing the query due to its length or complexity.")

        answer = "\n\n".join(all_answers)
        sources = set(doc.metadata['source'] for doc in web_docs)
        if sources:
            answer += "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)

        return answer

    else:  # PDF document chat
        for attempt in range(max_attempts):
            try:
                if database is None:
                    return "No documents available. Please upload PDF documents to answer questions."

                retriever = database.as_retriever()
                relevant_docs = retriever.get_relevant_documents(question)
                context_str = "\n".join([doc.page_content for doc in relevant_docs])

                if attempt > 0:
                    words = context_str.split()
                    context_str = " ".join(words[:int(len(words) * context_reduction_factor)])

                prompt_template = """
                Answer the question based on the following context from the PDF document:
                Context:
                {context}
                Question: {question}
                If the context doesn't contain relevant information, state that the information is not available in the document.
                Provide a summarized and direct answer to the question.
                """

                prompt_val = ChatPromptTemplate.from_template(prompt_template)
                formatted_prompt = prompt_val.format(context=context_str, question=question)

                full_response = generate_chunked_response(model, formatted_prompt)
                answer = extract_answer(full_response)

                return answer

            except Exception as e:
                print(f"Error in ask_question (attempt {attempt + 1}): {e}")
                if attempt == max_attempts - 1:
                    return f"I apologize, but I'm having trouble processing your question. Could you please try rephrasing it more concisely?"

    return "An unexpected error occurred. Please try again later."

def extract_answer(full_response):
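    """Strip the echoed prompt from the model output by splitting on known instruction
    phrases and returning the text after the last match."""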
    answer_patterns = [
        r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
        r"Provide a concise and direct answer to the question:",
        r"Answer:",
        r"Provide a summarized and direct answer to the question\.",
        r"Provide a summarized and direct answer to the original question without mentioning the web search or these instructions:",
        r"Do not include any source information in your answer\."
    ]

    for pattern in answer_patterns:
        match = re.split(pattern, full_response, flags=re.IGNORECASE)
        if len(match) > 1:
            return match[-1].strip()
    return full_response.strip()

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced PDF Document Chat and Web Search")
    
    with gr.Row():
        file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
        parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="pypdf")
        update_button = gr.Button("Upload PDF")
    
    update_output = gr.Textbox(label="Update Status")
    update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)
    
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            question_input = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=1):
            temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
            top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
            repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
            web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)

    enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()

    def chat(question, history, temperature, top_p, repetition_penalty, web_search):
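        """Gradio callback: generate an answer, append the (question, answer) pair to the chat history, and clear the input box."""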
        answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot)
        history.append((question, answer))
        return "", history
    
    submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox], outputs=[question_input, chatbot])
    
    clear_button = gr.Button("Clear Cache")
    clear_output = gr.Textbox(label="Cache Status")
    clear_button.click(clear_cache, inputs=[], outputs=clear_output)

if __name__ == "__main__":
    demo.launch()