import os

import chromadb
import streamlit as st
from aift import setting
from aift.multimodal import textqa
from langchain.document_loaders import PyPDFLoader
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from sentence_transformers import SentenceTransformer

# Drop chromadb's cached shared-system registry before any client is created.
# NOTE(review): presumably a workaround for "instance already exists" errors
# when Streamlit re-executes this script in the same process - confirm.
chromadb.api.client.SharedSystemClient.clear_system_cache()
# Set API key for Pathumma.
# SECURITY: prefer supplying the key through the AIFT_API_KEY environment
# variable. The literal fallback is kept only so existing deployments keep
# working; it is committed to source control and should be rotated and removed.
setting.set_api_key(
    os.environ.get("AIFT_API_KEY", 'T69FqnYgOdreO5G0nZaM8gHcjo1sifyU')
)

# App Configuration
# The page icon literal was mis-encoded (UTF-8 bytes shown through a
# single-byte codepage); it is restored to the intended robot emoji.
st.set_page_config(page_title="Nong Nok", page_icon="🤖")

# Inject the page header: a <style> block (Kanit web font, outlined title with
# a fade-in animation, rotated sub-title) followed by the header markup.
# unsafe_allow_html is required because the payload is raw HTML/CSS; the
# content is a static literal, so no user input reaches it.
st.markdown(
    """
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Kanit:wght@700&display=swap');
        
        body {
            margin: 0;
            padding: 0;
        }
        .header-container {
            position: absolute;
            top: 100%;
            left: 50%;
            transform: translate(-50%, -50%);
            text-align: center;
            margin-bottom: 25px;
        }
        .header-title {
            font-size: 4em;
            margin: 0;
            white-space: nowrap;
            font-family: 'Kanit', sans-serif;
            color: white; /* Fallback color */
            -webkit-text-stroke: 2px black; /* Stroke width and color */
            text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.5); /* Optional shadow for better visibility */
            animation: fadeIn 1s forwards;
        }
        .sub-title {
            position: absolute;
            bottom: -10px;
            right: -20px;
            font-size: 1.5em;
            transform: rotate(-10deg);
            color: #21A2DB;
            white-space: nowrap;    
        }
        @keyframes fadeIn {
            0% {
                color: transparent;
            }
            100% {
                color: white;
            }
        }
    </style>
    <div class="header-container">
        <h1 class="header-title">
            PDPA Chatbot
        </h1>
        <div class="sub-title">( Noknoy-0.5 )</div>
    </div>
    """,
    unsafe_allow_html=True
)

# Three blank markdown elements act as vertical spacing below the
# absolutely-positioned header.
st.markdown(" ")
st.markdown(" ")
st.markdown(" ")
# Custom Embeddings
class CustomEmbeddings:
    """Embedding adapter exposing ``embed_query`` / ``embed_documents``.

    Wraps a ``SentenceTransformer`` model so it can be passed as the
    ``embedding`` argument of the vector store built later in this file.
    """

    def __init__(self, model_name="mrp/simcse-model-m-bert-thai-cased"):
        """Load the sentence-transformers model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    def embed_query(self, text: str) -> list:
        """Return the embedding of a single query string as a list of floats."""
        return self.model.encode([text])[0].tolist()

    def embed_documents(self, texts) -> list:
        """Return one embedding (list of floats) per entry of *texts*.

        All texts are handed to ``encode`` in a single call so the model can
        batch them, instead of one forward pass per document. An empty input
        yields an empty list without touching the model.
        """
        texts = list(texts)
        if not texts:
            return []
        return self.model.encode(texts).tolist()

# Pathumma Model Wrapper
class PathummaModel:
    """Callable adapter around ``textqa.generate`` from ``aift.multimodal``."""

    def __init__(self):
        """Create the wrapper; no state is stored on the instance."""

    def generate(self, instruction: str, return_json: bool = False):
        """Send *instruction* to ``textqa.generate`` and return its answer.

        With ``return_json=True`` the ``"content"`` entry of the reply is
        returned (empty string when the key is absent); otherwise the reply
        is passed through unchanged.
        """
        reply = textqa.generate(instruction=instruction, return_json=return_json)
        return reply.get("content", "") if return_json else reply

    def __call__(self, input: str):
        """Shorthand for ``generate(input, return_json=False)``."""
        return self.generate(instruction=input, return_json=False)

# Initialize Pathumma model
model_local = PathummaModel()

# Load PDF file
file_path = "langchain.pdf"


@st.cache_resource
def _build_vectorstore(pdf_path):
    """Parse *pdf_path*, split it into chunks and index them in Chroma.

    Streamlit re-executes this script on every interaction. Without caching,
    each chat message would re-read the PDF, reload the embedding model and
    rebuild the "rag-chroma" collection. ``st.cache_resource`` keeps one
    vector store per distinct *pdf_path* for the lifetime of the server
    process, so changes to the PDF are only picked up after the cache is
    cleared or the app is restarted.
    """
    pages = PyPDFLoader(pdf_path).load()

    # Split text into manageable chunks.
    # NOTE(review): 7500-token chunks may exceed what the embedding model
    # actually reads - confirm against the model's max sequence length.
    splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=7500, chunk_overlap=100)
    chunks = splitter.split_documents(pages)

    # Convert documents to embeddings and store them in Chroma.
    return Chroma.from_documents(
        documents=chunks,
        collection_name="rag-chroma",
        embedding=CustomEmbeddings(model_name="mrp/simcse-model-m-bert-thai-cased"),
    )


vectorstore = _build_vectorstore(file_path)
retriever = vectorstore.as_retriever()

# Generate a response using retriever
def get_response(user_query):
    """Answer *user_query* from the documents returned by the retriever.

    The page contents of the retrieved documents are joined with single
    spaces, inserted together with the question into a Thai prompt, and the
    prompt is passed to ``model_local``; whatever the model returns is
    returned unchanged.
    """
    retrieved_docs = retriever.get_relevant_documents(user_query)
    retrieved_context = " ".join(doc.page_content for doc in retrieved_docs)

    # The Thai literal was mis-encoded (UTF-8 bytes rendered through a
    # single-byte codepage) and is restored here. It reads: "Answer the
    # question by considering only the following context: ... Question: ...".
    after_rag_template = """ตอบคำถามโดยพิจารณาจากบริบทต่อไปนี้เท่านั้น:
    {context}
    คำถาม: {question}
    """
    prompt = after_rag_template.format(context=retrieved_context, question=user_query)
    return model_local(prompt)

# Initialize session state
# On the first run of a session, seed the history with the assistant's
# greeting. The literal was mis-encoded (UTF-8 bytes rendered through a
# single-byte codepage) and is restored here; it welcomes the user to
# "Nong Nok", a chatbot for the Personal Data Protection Act (PDPA).
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content='🐦 ยินดีต้อนรับสู่น้องนก แชทบอทที่พร้อมจะให้ข้อมูลคุณเกี่ยวกับพระราชบัญญัติคุ้มครองข้อมูลส่วนบุคคล (PDPA) มีอะไรให้ช่วยไหมครับ?'),
    ]

# Render chat history
for message in st.session_state.chat_history:
    # Pick the chat bubble from the message class; entries that are neither
    # AI nor human messages are not rendered.
    if isinstance(message, AIMessage):
        speaker = "AI"
    elif isinstance(message, HumanMessage):
        speaker = "Human"
    else:
        continue
    with st.chat_message(speaker):
        st.write(message.content)

# User input
# The two Thai literals below were mis-encoded (UTF-8 bytes rendered through a
# single-byte codepage) and are restored: the input placeholder reads
# "Type your message here..." and the interim text "Generating answer...".
user_query = st.chat_input("พิมพ์ข้อความที่นี่...")
# chat_input yields None until the user submits; whitespace-only input is ignored.
if user_query and user_query.strip():
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        # Show interim text in a placeholder, then overwrite it in place
        # with the answer once get_response returns.
        placeholder = st.empty()
        placeholder.markdown("กำลังสร้างคำตอบ...")
        response = get_response(user_query)
        placeholder.markdown(response)

    # Persist the answer so it is re-rendered from history on the next rerun.
    st.session_state.chat_history.append(AIMessage(content=response))