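"""Streamlit app: conversational RAG over uploaded PDFs.

Pipeline: PyPDFLoader -> RecursiveCharacterTextSplitter -> FAISS index built
from HuggingFace MiniLM embeddings -> history-aware retrieval chain answered
by a Groq-hosted Llama model, with per-session chat history.
"""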
import os
import tempfile

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
# create_stuff_documents_chain "stuffs" every retrieved document into the prompt context
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()


# Propagate HF_TOKEN from .env into the process environment; guard against
# None, since os.environ rejects non-string values.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ['HF_TOKEN'] = hf_token

# Shared embedding model, loaded once at module import
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")


def initialize_session_state():
    """Initialize session state variables if they don't exist."""
    session_state_defaults = {
        'vectorstore': None,
        'retriever': None,
        'conversation_chain': None,
        'chat_history': [],
        'uploaded_file_names': set()
    }
    
    for key, default_value in session_state_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default_value

def setup_rag_pipeline(documents):
    """Set up the RAG pipeline: chunk, embed, index, and build the chains."""
    # Split documents into large overlapping chunks for retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
    splits = text_splitter.split_documents(documents)
    
    # Create vector store and retriever
    vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings)
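    # as_retriever() defaults to plain similarity search (top-k, k=4 in current
    # LangChain releases); pass search_kwargs={"k": ...} here to tune recall.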
    retriever = vectorstore.as_retriever()
    
    # Configure LLM
    groq_api_key = os.getenv("GROQ_API_KEY")
    llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
    
    # Contextualization prompt
    contextualize_q_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given a chat history and the latest user question, "
         "formulate a standalone question which can be understood "
         "without the chat history. Do NOT answer the question, "
         "just reformulate it if needed and otherwise return it as is."),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])
    
    # QA prompt
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an assistant for question-answering tasks. "
         "Use the following pieces of retrieved context to answer "
         "the question. If you don't know the answer, say that you "
         "don't know. Use three sentences maximum and keep the "
         "answer concise.\n\n{context}"),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])
    
    # Create chains
    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    
    # Conversational RAG chain with per-session message history. A fresh
    # ChatMessageHistory on every call would silently drop the conversation
    # between turns, so store and reuse one history per session_id.
    store = {}

    def get_session_history(session_id):
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer"
    )
    
    return conversational_rag_chain, vectorstore, retriever

def chat():
    # Initialize Streamlit app
    st.title("RAG with PDF Uploads")
    st.write("Upload PDFs and chat with their content")
    
    # Initialize session state
    initialize_session_state()

    if st.button("Reset"):
        # Clear all session state variables
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        # Reinitialize the session state
        initialize_session_state()
        st.success("Session reset successfully!")
        # Force a rerun of the app to clear the UI
        st.rerun()
    
    # API Key check
    if not os.getenv("GROQ_API_KEY"):
        st.error("Please set the GROQ_API_KEY environment variable.")
        return
    
    # File upload
    uploaded_files = st.file_uploader("Upload PDF files", type='pdf', accept_multiple_files=True)
    
    # Process uploaded files
    if uploaded_files:
        # Get current file names
        current_file_names = {file.name for file in uploaded_files}
        
        # Check if new files have been uploaded
        if current_file_names != st.session_state.uploaded_file_names:
            # Update the set of uploaded file names
            st.session_state.uploaded_file_names = current_file_names
            
            # Process PDF documents
            documents = []
            for uploaded_file in uploaded_files:
                # Persist the upload to a temporary file (PyPDFLoader needs a
                # path on disk), load it, then clean up
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                    tmp.write(uploaded_file.getvalue())
                    tmp_path = tmp.name
                try:
                    documents.extend(PyPDFLoader(tmp_path).load())
                finally:
                    os.remove(tmp_path)
            
            # Setup RAG pipeline
            st.session_state.conversation_chain, st.session_state.vectorstore, st.session_state.retriever = setup_rag_pipeline(documents)
    
    # Chat interface
    user_input = st.text_input("Ask a question about your documents:")
    
    if user_input and st.session_state.conversation_chain:
        try:
            # Invoke the conversational chain
            response = st.session_state.conversation_chain.invoke(
                {"input": user_input},
                config={"configurable": {"session_id": "default_session"}}
            )
            
            # Display the answer
            # st.write("Assistant:", response['answer'])
            st.markdown(f"<span class='assistant-label'>Assistant:</span> {response['answer']}", unsafe_allow_html=True)
            
            # Update chat history
            st.session_state.chat_history.append({"user": user_input, "assistant": response['answer']})
        
        except Exception as e:
            st.error(f"An error occurred: {e}")
    
    # Display chat history
    if st.session_state.chat_history:
        st.markdown("<h4 style='color:#53ff1a;'>Chat History</h4>", unsafe_allow_html=True)
        with st.expander("Show conversation"):
            for turn in st.session_state.chat_history:
                st.markdown(f"<span class='user-label'>You:</span> {turn['user']}", unsafe_allow_html=True)
                st.markdown(f"<span class='assistant-label'>Assistant:</span> {turn['assistant']}", unsafe_allow_html=True)