bsiddhharth committed on
Commit
df852e4
·
1 Parent(s): 9c5f440

Added rag.py: takes PDFs as input and supports chatting with their content, along with chat history

Files changed (3)
  1. .gitignore +11 -0
  2. rag.py +283 -0
  3. requirements.txt +18 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
+ # Ignore virtual environment
+ venv2/
+
+ # Ignore environment files
+ .env
+
+ # Ignore Python compiled files
+ *.pyc
+ __pycache__/
+
+ temp.pdf
rag.py ADDED
@@ -0,0 +1,283 @@
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain  # stuffs all retrieved docs into the prompt context
+ from langchain_chroma import Chroma
+ from langchain_community.chat_message_histories import ChatMessageHistory
+ from langchain_core.chat_history import BaseChatMessageHistory
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_groq import ChatGroq
+ from langchain_core.runnables.history import RunnableWithMessageHistory
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_community.vectorstores import FAISS
+ import os
+ import streamlit as st
+
+ from dotenv import load_dotenv
+ load_dotenv()
+
+
+ # Propagate tokens from .env into the environment; skip when a variable is
+ # unset, since assigning None to os.environ raises a TypeError
+ if os.getenv("HF_TOKEN"):
+     os.environ['HF_TOKEN'] = os.getenv("HF_TOKEN")
+ if os.getenv("GROQ_API_KEY"):
+     os.environ['GROQ_API_KEY'] = os.getenv("GROQ_API_KEY")
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+ groq_api_key = os.getenv("GROQ_API_KEY")
+
+
+ def initialize_session_state():
+     """Initialize session state variables if they don't exist."""
+     session_state_defaults = {
+         'vectorstore': None,
+         'retriever': None,
+         'conversation_chain': None,
+         'chat_history': [],
+         'uploaded_file_names': set()
+     }
+
+     for key, default_value in session_state_defaults.items():
+         if key not in st.session_state:
+             st.session_state[key] = default_value
+
+
+ def setup_rag_pipeline(documents):
+     """Set up the RAG pipeline with embeddings and retrieval."""
+     # Use HuggingFace embeddings
+     embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+     # Split documents into overlapping chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
+     splits = text_splitter.split_documents(documents)
+
+     # Create vector store and retriever
+     vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings)
+     retriever = vectorstore.as_retriever()
+
+     # Configure LLM
+     groq_api_key = os.getenv("GROQ_API_KEY")
+     llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
+
+     # Contextualization prompt: rewrite the latest question so it stands alone
+     contextualize_q_prompt = ChatPromptTemplate.from_messages([
+         ("system", "Given a chat history and the latest user question, "
+                    "formulate a standalone question which can be understood "
+                    "without the chat history. Do NOT answer the question, "
+                    "just reformulate it if needed and otherwise return it as is."),
+         MessagesPlaceholder("chat_history"),
+         ("human", "{input}")
+     ])
+
+     # QA prompt
+     qa_prompt = ChatPromptTemplate.from_messages([
+         ("system", "You are an assistant for question-answering tasks. "
+                    "Use the following pieces of retrieved context to answer "
+                    "the question. If you don't know the answer, say that you "
+                    "don't know. Answer in at least three sentences, but keep "
+                    "the answer concise.\n\n{context}"),
+         MessagesPlaceholder("chat_history"),
+         ("human", "{input}")
+     ])
+
+     # Create chains
+     history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
+     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+     rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+     # Per-session message history. Returning a fresh ChatMessageHistory on
+     # every call (e.g. from a lambda) would silently drop all prior turns,
+     # so histories are cached in st.session_state keyed by session_id
+     def get_session_history(session_id: str) -> BaseChatMessageHistory:
+         if 'store' not in st.session_state:
+             st.session_state.store = {}
+         if session_id not in st.session_state.store:
+             st.session_state.store[session_id] = ChatMessageHistory()
+         return st.session_state.store[session_id]
+
+     # Conversational RAG chain with message history
+     conversational_rag_chain = RunnableWithMessageHistory(
+         rag_chain,
+         get_session_history,
+         input_messages_key="input",
+         history_messages_key="chat_history",
+         output_messages_key="answer"
+     )
+
+     return conversational_rag_chain, vectorstore, retriever
+
+
+ def main():
+     # Initialize Streamlit app
+     st.title("RAG with PDF Uploads")
+     st.write("Upload PDFs and chat with their content")
+
+     # Initialize session state
+     initialize_session_state()
+
+     # Reset session state when the reset button is clicked
+     if st.button("Reset"):
+         # Clear all session state variables
+         for key in list(st.session_state.keys()):
+             del st.session_state[key]
+         # Reinitialize the session state
+         initialize_session_state()
+         st.success("Session reset successfully!")
+         # Force a rerun of the app to clear the UI
+         st.rerun()
+
+     # API key check
+     if not os.getenv("GROQ_API_KEY"):
+         st.error("Please set the GROQ_API_KEY environment variable.")
+         return
+
+     # File upload
+     uploaded_files = st.file_uploader("Upload PDF files", type='pdf', accept_multiple_files=True)
+
+     # Process uploaded files
+     if uploaded_files:
+         # Get current file names
+         current_file_names = {file.name for file in uploaded_files}
+
+         # Rebuild the pipeline only when the set of uploaded files changes
+         if current_file_names != st.session_state.uploaded_file_names:
+             # Update the set of uploaded file names
+             st.session_state.uploaded_file_names = current_file_names
+
+             # Process PDF documents
+             documents = []
+             for uploaded_file in uploaded_files:
+                 # PyPDFLoader reads from disk, so each upload is written to
+                 # (and overwrites) the same temporary path before loading
+                 with open("./temp.pdf", "wb") as file:
+                     file.write(uploaded_file.getvalue())
+
+                 # Load the PDF
+                 loader = PyPDFLoader("./temp.pdf")
+                 docs = loader.load()
+                 documents.extend(docs)
+
+             # Set up the RAG pipeline
+             st.session_state.conversation_chain, st.session_state.vectorstore, st.session_state.retriever = setup_rag_pipeline(documents)
+
+     # Chat interface
+     user_input = st.text_input("Ask a question about your documents:")
+
+     if user_input and st.session_state.conversation_chain:
+         try:
+             # Invoke the conversational chain
+             response = st.session_state.conversation_chain.invoke(
+                 {"input": user_input},
+                 config={"configurable": {"session_id": "default_session"}}
+             )
+
+             # Display the answer
+             st.write("Assistant:", response['answer'])
+
+             # Update chat history
+             st.session_state.chat_history.append({"user": user_input, "assistant": response['answer']})
+
+         except Exception as e:
+             st.error(f"An error occurred: {e}")
+
+     # Display chat history
+     if st.session_state.chat_history:
+         st.subheader("Chat History")
+         for chat in st.session_state.chat_history:
+             st.markdown(f"**You:** {chat['user']}")
+             st.markdown(f"**Assistant:** {chat['assistant']}")
+
+
+ if __name__ == "__main__":
+     main()
+
+
+ # Earlier inline version of the app, left commented out for reference:
+
+ # st.title("RAG with PDF uploads and chat history")
+ # st.write("Upload PDFs and chat with their content")
+
+ # llm = ChatGroq(groq_api_key=groq_api_key, model_name="Gemma2-9b-It")
+
+ # session_id = st.text_input("Session ID", value="default_session")
+ # # statefully manages the session history
+
+ # if 'store' not in st.session_state:
+ #     st.session_state.store = {}
+
+ # uploaded_files = st.file_uploader("Upload the pdf file", type='pdf', accept_multiple_files=True)
+
+ # # process uploaded files
+ # if uploaded_files:
+ #     documents = []
+ #     for uploaded_file in uploaded_files:
+ #         tempfile = "./temp.pdf"
+ #         with open(tempfile, "wb") as file:
+ #             file.write(uploaded_file.getvalue())
+ #             # file.name = uploaded_file.name
+
+ #         loader = PyPDFLoader(tempfile)  # PyPDFLoader only reads saved files, hence the temp file
+ #         docs = loader.load()
+ #         documents.extend(docs)
+
+ #     text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
+ #     splits = text_splitter.split_documents(documents)
+ #     vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+ #     retriever = vectorstore.as_retriever()
+
+ #     contextualize_q_system_prompt = (
+ #         "Given a chat history and the latest user question "
+ #         "which might reference context in the chat history, "
+ #         "formulate a standalone question which can be understood "
+ #         "without the chat history. Do NOT answer the question, "
+ #         "just reformulate it if needed and otherwise return it as is."
+ #     )
+
+ #     contextualize_q_prompt = ChatPromptTemplate.from_messages(
+ #         [
+ #             ("system", contextualize_q_system_prompt),
+ #             MessagesPlaceholder("chat_history"),
+ #             ("human", "{input}")
+ #         ]
+ #     )
+
+ #     history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
+
+ #     # Answer question
+ #     system_prompt = (
+ #         "You are an assistant for question-answering tasks. "
+ #         "Use the following pieces of retrieved context to answer "
+ #         "the question. If you don't know the answer, say that you "
+ #         "don't know. Use three sentences maximum and keep the "
+ #         "answer concise."
+ #         "\n\n"
+ #         "{context}"
+ #     )
+ #     qa_prompt = ChatPromptTemplate.from_messages(
+ #         [
+ #             ("system", system_prompt),
+ #             MessagesPlaceholder("chat_history"),
+ #             ("human", "{input}"),
+ #         ]
+ #     )
+ #     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+ #     rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+ #     def get_session_history(session_id: str) -> BaseChatMessageHistory:
+ #         if session_id not in st.session_state.store:
+ #             st.session_state.store[session_id] = ChatMessageHistory()
+ #         return st.session_state.store[session_id]
+
+ #     conversational_rag_chain = RunnableWithMessageHistory(
+ #         rag_chain, get_session_history,
+ #         input_messages_key="input",
+ #         history_messages_key="chat_history",
+ #         output_messages_key="answer"
+ #     )
+
+ #     user_input = st.text_input("Question: ")
+
+ #     if user_input:
+ #         session_history = get_session_history(session_id)
+ #         response = conversational_rag_chain.invoke(
+ #             {"input": user_input},
+ #             config={
+ #                 "configurable": {"session_id": session_id}  # keys the entry in `store`
+ #             },
+ #         )
+ #         st.write(st.session_state.store)
+ #         st.write("Assistant:", response['answer'])
+ #         st.write("Chat History:", session_history.messages)
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ langchain
+ python-dotenv
+ ipykernel
+ langchain_community
+ pypdf
+ bs4
+ arxiv
+ pymupdf
+ wikipedia
+ langchain-text-splitters
+ sentence_transformers
+ langchain_huggingface
+ faiss-cpu
+ streamlit
+ langchain-groq
+ chromadb
+ langserve
+ langchain_chroma
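To try the commit locally, a minimal sketch, assuming a .env file in the repository root that provides GROQ_API_KEY (and optionally HF_TOKEN), as rag.py expects:

    pip install -r requirements.txt
    streamlit run rag.py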