#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application to showcase how RAG works.
# Author     : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou

# RAG code written by Farhikhta Farzan
# MongoDB database created by Farhikhta Farzan
# Documents and research gathered by Keira James, Farhikhta Farzan, and Tesneem Essa
#############################################################################################################################
# Import libraries.
import os
import streamlit as st

from dotenv import load_dotenv, find_dotenv
from huggingface_hub import InferenceClient
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from pymongo.collection import Collection
from typing import Dict, Any


#############################################################################################################################


class RAGQuestionAnswering:
    def __init__(self):
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Initializes the RAG Question Answering system by setting up configuration
        and loading environment variables.

        Assumptions
        -----------
        - Expects .env file with MONGO_URI and HF_TOKEN
        - Requires proper MongoDB setup with vector search index
        - Needs connection to Hugging Face API

        Notes
        -----
        This is the main class that handles all RAG operations
        """
        self.load_environment()
        self.setup_mongodb()
        self.setup_embedding_model()
        self.setup_vector_search()
        self.setup_rag_chain()

    def load_environment(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Loads environment variables from .env file and sets up configuration constants.

        Assumptions
        -----------
        Expects a .env file with MONGO_URI and HF_TOKEN defined

        Notes
        -----
        Will stop the application if required environment variables are missing
        """

        load_dotenv(find_dotenv())
        self.MONGO_URI = os.getenv("MONGO_URI")
        self.HF_TOKEN = os.getenv("HF_TOKEN")

        if not self.MONGO_URI or not self.HF_TOKEN:
            st.error("Please ensure MONGO_URI and HF_TOKEN are set in your .env file")
            st.stop()
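
        # For reference, the expected .env file looks roughly like this
        # (placeholder values, not real credentials):
        #
        #   MONGO_URI="mongodb+srv://<user>:<password>@<cluster>.mongodb.net/"
        #   HF_TOKEN="hf_..."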

        # MongoDB configuration.
        self.DB_NAME = "files"
        self.COLLECTION_NAME = "files_collection"
        self.VECTOR_SEARCH_INDEX = "vector_index"

    def setup_mongodb(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Initializes the MongoDB connection and sets up the collection.

        Assumptions
        -----------
        - Valid MongoDB URI is available
        - Database and collection exist in MongoDB Atlas

        Notes
        -----
        Uses st.cache_resource for efficient connection management
        """

        @st.cache_resource
        def init_mongodb() -> Collection:
            cluster = MongoClient(self.MONGO_URI)
            return cluster[self.DB_NAME][self.COLLECTION_NAME]

        self.mongodb_collection = init_mongodb()

    def setup_embedding_model(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Initializes the embedding model for vector search.

        Assumptions
        -----------
        - Valid Hugging Face API token
        - Internet connection to access the model

        Notes
        -----
        Uses the all-mpnet-base-v2 model from sentence-transformers
        """

        @st.cache_resource
        def init_embedding_model() -> HuggingFaceInferenceAPIEmbeddings:
            return HuggingFaceInferenceAPIEmbeddings(
                api_key=self.HF_TOKEN,
                model_name="sentence-transformers/all-mpnet-base-v2",
            )

        self.embedding_model = init_embedding_model()

    def setup_vector_search(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Sets up the vector search functionality using MongoDB Atlas.

        Assumptions
        -----------
        - MongoDB Atlas vector search index is properly configured
        - Valid embedding model is initialized

        Notes
        -----
        Creates a retriever with similarity search and score threshold
        """

        @st.cache_resource
        def init_vector_search() -> MongoDBAtlasVectorSearch:
            return MongoDBAtlasVectorSearch.from_connection_string(
                connection_string=self.MONGO_URI,
                namespace=f"{self.DB_NAME}.{self.COLLECTION_NAME}",
                embedding=self.embedding_model,
                index_name=self.VECTOR_SEARCH_INDEX,
            )

        self.vector_search = init_vector_search()
        # Note: LangChain enforces "score_threshold" only when
        # search_type="similarity_score_threshold"; with plain "similarity"
        # search, the threshold is forwarded to the vector store and may be
        # ignored.
        self.retriever = self.vector_search.as_retriever(
            search_type="similarity", search_kwargs={"k": 10, "score_threshold": 0.85}
        )

    def format_docs(self, docs: list[Document]) -> str:
        """
        Parameters
        ----------
        **docs:** list[Document] - List of documents to be formatted

        Output
        ------
        str: Formatted string containing concatenated document content

        Purpose
        -------
        Formats the retrieved documents into a single string for processing

        Assumptions
        -----------
        Documents have page_content attribute

        Notes
        -----
        Joins documents with double newlines for better readability
        """

        return "\n\n".join(doc.page_content for doc in docs)

    def generate_response(self, input_dict: Dict[str, Any]) -> str:
        """
        Parameters
        ----------
        **input_dict:** Dict[str, Any] - Dictionary containing context and question

        Output
        ------
        str: Generated response from the model

        Purpose
        -------
        Generates a response using the Hugging Face model based on context and question

        Assumptions
        -----------
        - Valid Hugging Face API token
        - Input dictionary contains 'context' and 'question' keys

        Notes
        -----
        Uses the Qwen/Qwen2.5-1.5B-Instruct model with a low temperature (0.2) for focused answers
        """
        hf_client = InferenceClient(api_key=self.HF_TOKEN)
        formatted_prompt = self.prompt.format(**input_dict)

        response = hf_client.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            messages=[
                {"role": "system", "content": formatted_prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )

        return response.choices[0].message.content

    def setup_rag_chain(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Sets up the RAG chain for processing questions and generating answers

        Assumptions
        -----------
        Retriever and response generator are properly initialized

        Notes
        -----
        Creates a chain that combines retrieval and response generation
        """

        self.prompt = PromptTemplate.from_template(
            """Use the following pieces of context to answer the question at the end.

            START OF CONTEXT:
            {context}
            END OF CONTEXT

            START OF QUESTION:
            {question}
            END OF QUESTION

            If you do not know the answer, just say that you do not know.
            NEVER assume things.
            """
        )
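
        # LangChain coerces the dict below into a RunnableParallel: the
        # incoming question string is routed to the retriever (its documents
        # flattened by format_docs) and, in parallel, passed through
        # unchanged, yielding the {"context", "question"} dict that
        # generate_response consumes.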

        self.rag_chain = {
            "context": self.retriever | RunnableLambda(self.format_docs),
            "question": RunnablePassthrough(),
        } | RunnableLambda(self.generate_response)

    def process_question(self, question: str) -> str:
        """
        Parameters
        ----------
        **question:** str - The user's question to be answered

        Output
        ------
        str: The generated answer to the question

        Purpose
        -------
        Processes a user question through the RAG chain and returns an answer

        Assumptions
        -----------
        - Question is a non-empty string
        - RAG chain is properly initialized

        Notes
        -----
        Main interface for question-answering functionality
        """

        return self.rag_chain.invoke(question)
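
    # Example usage (hypothetical question, for illustration only):
    #   rag = RAGQuestionAnswering()
    #   answer = rag.process_question("How can AI-generated images be detected?")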


#############################################################################################################################
def setup_streamlit_ui() -> None:
    """
    Parameters
    ----------
    None

    Output
    ------
    None

    Purpose
    -------
    Sets up the Streamlit user interface with proper styling and layout

    Assumptions
    -----------
    - CSS file exists at ./static/styles/style.css
    - Image file exists at ./static/images/poster.jpg

    Notes
    -----
    Handles all UI-related setup and styling
    """

    st.set_page_config(page_title="RAG Question Answering", page_icon="🤖")

    # Load CSS.
    with open("./static/styles/style.css") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    # Title and subtitles.
    st.markdown(
        '<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">RAG Question Answering</h1>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">Using Documents and Research</h3>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">Digital Detectives: AI VS Real Images</h2>',
        unsafe_allow_html=True,
    )

    # Display the poster image, centered via the middle of three columns.
    left_co, cent_co, last_co = st.columns(3)
    with cent_co:
        st.image("./static/images/poster.jpg")


#############################################################################################################################


def main():
    """
    Parameters
    ----------
    None

    Output
    ------
    None

    Purpose
    -------
    Main function that runs the Streamlit application

    Assumptions
    -----------
    All required environment variables and files are present

    Notes
    -----
    Entry point for the application
    """

    # Setup UI.
    setup_streamlit_ui()

    # Initialize RAG system.
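    # (RAGQuestionAnswering() runs on every Streamlit rerun, but the expensive
    # resources it sets up are memoized with st.cache_resource, so reruns
    # remain cheap.)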
    rag_system = RAGQuestionAnswering()

    # Create input elements.
    query = st.text_input("Question:", key="question_input")

    # Handle submission.
    if st.button("Submit", type="primary"):
        if query:
            with st.spinner("Generating response..."):
                response = rag_system.process_question(query)
                st.text_area("Answer:", value=response, height=200, disabled=True)
        else:
            st.warning("Please enter a question.")

    # Add GitHub link.
    st.markdown(
        """
        <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;">
        <b>Check out our <a href="https://github.com/KeiraJames/CTP-Project-2024/tree/main" style="color: #FAF9F6;">GitHub repository</a></b>
        </p>
        """,
        unsafe_allow_html=True,
    )


#############################################################################################################################
if __name__ == "__main__":
    main()