from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain  # "stuffs" all retrieved chunks into a single prompt context
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_groq import ChatGroq
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
import os
import streamlit as st
from dotenv import load_dotenv
load_dotenv()
# Hugging Face token for downloading the embedding model; guard against a
# missing value so os.environ is never assigned None
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ['HF_TOKEN'] = hf_token

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
def initialize_session_state():
"""Initialize session state variables if they don't exist."""
    session_state_defaults = {
        'vectorstore': None,
        'retriever': None,
        'conversation_chain': None,
        'chat_history': [],
        'uploaded_file_names': set(),
        'message_store': {}  # per-session ChatMessageHistory objects
    }
for key, default_value in session_state_defaults.items():
if key not in st.session_state:
st.session_state[key] = default_value
def setup_rag_pipeline(documents):
"""Set up the RAG pipeline with embeddings and retrieval."""
    # Embeddings are created once at module level and reused here
# Split documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
splits = text_splitter.split_documents(documents)
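    # 5000-char chunks with 500-char overlap; note that all-MiniLM-L6-v2
    # truncates inputs beyond ~256 tokens, so smaller chunks may embed more
    # of each passage faithfully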
# Create vector store and retriever
vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever()
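    # as_retriever() uses similarity search and returns the top 4 chunks by default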
# Configure LLM
groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
# Contextualization prompt
contextualize_q_prompt = ChatPromptTemplate.from_messages([
("system", "Given a chat history and the latest user question, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."),
MessagesPlaceholder("chat_history"),
("human", "{input}")
])
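    # The history-aware retriever first has the LLM rewrite the question into a
    # standalone query using this prompt, then retrieves documents with it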
# QA prompt
qa_prompt = ChatPromptTemplate.from_messages([
("system", "You are an assistant for question-answering tasks. "
"Use the following pieces of retrieved context to answer "
"the question. If you don't know the answer, say that you "
"don't know. Use three sentences minimum and keep the "
"answer concise. Can include any number of words\n\n{context}"),
MessagesPlaceholder("chat_history"),
("human", "{input}")
])
# Create chains
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
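    # rag_chain.invoke({"input": ..., "chat_history": [...]}) yields a dict with
    # the retrieved docs under "context" and the LLM reply under "answer"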
    # Conversational RAG chain with message history
    def get_session_history(session_id):
        # Reuse one ChatMessageHistory per session; returning a fresh object on
        # every call (as a bare lambda would) wipes the memory each turn
        store = st.session_state.setdefault('message_store', {})
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer"
    )
return conversational_rag_chain, vectorstore, retriever
def chat():
# Initialize Streamlit app
st.title("RAG with PDF Uploads")
st.write("Upload PDFs and chat with their content")
# Initialize session state
initialize_session_state()
# Reset session state when the reset button is clicked
# if st.button("Reset"):
# # Clear session state variables
# st.session_state.clear()
# # Reinitialize the session state
# initialize_session_state()
# st.success("Session reset successfully!")
if st.button("Reset"):
# Clear all session state variables
for key in list(st.session_state.keys()):
del st.session_state[key]
# Reinitialize the session state
initialize_session_state()
st.success("Session reset successfully!")
# Force a rerun of the app to clear the UI
st.rerun()
# API Key check
if not os.getenv("GROQ_API_KEY"):
st.error("Please set the GROQ_API_KEY environment variable.")
return
# File upload
uploaded_files = st.file_uploader("Upload PDF files", type='pdf', accept_multiple_files=True)
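    # st.file_uploader returns in-memory UploadedFile objects rather than paths,
    # so each one is written to disk below before PyPDFLoader can read it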
# Process uploaded files
if uploaded_files:
# Get current file names
current_file_names = {file.name for file in uploaded_files}
# Check if new files have been uploaded
if current_file_names != st.session_state.uploaded_file_names:
# Update the set of uploaded file names
st.session_state.uploaded_file_names = current_file_names
# Process PDF documents
documents = []
        for uploaded_file in uploaded_files:
            # Save the upload to a scratch file so PyPDFLoader can read it from
            # disk (the same path is reused; files are processed one at a time)
            with open("./temp.pdf", "wb") as file:
                file.write(uploaded_file.getvalue())
            # Load the PDF
            loader = PyPDFLoader("./temp.pdf")
            documents.extend(loader.load())
        # Remove the scratch file once every PDF has been loaded
        if os.path.exists("./temp.pdf"):
            os.remove("./temp.pdf")
# Setup RAG pipeline
st.session_state.conversation_chain, st.session_state.vectorstore, st.session_state.retriever = setup_rag_pipeline(documents)
# Chat interface
user_input = st.text_input("Ask a question about your documents:")
if user_input and st.session_state.conversation_chain:
try:
# Invoke the conversational chain
response = st.session_state.conversation_chain.invoke(
{"input": user_input},
config={"configurable": {"session_id": "default_session"}}
)
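            # Besides 'answer', the response dict also carries the retrieved
            # documents under 'context', useful for displaying sources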
# Display the answer
# st.write("Assistant:", response['answer'])
st.markdown(f"<span class='assistant-label'>Assistant:</span> {response['answer']}", unsafe_allow_html=True)
# Update chat history
st.session_state.chat_history.append({"user": user_input, "assistant": response['answer']})
except Exception as e:
st.error(f"An error occurred: {e}")
# Display chat history
if st.session_state.chat_history:
st.markdown("<h4 style='color:#53ff1a;'>Chat History</h4>", unsafe_allow_html=True)
with st.expander(""):
# st.subheader("Chat History")
for chat in st.session_state.chat_history:
# st.markdown(f"**You:** {chat['user']}")
st.markdown(f"<span class='user-label'>You:</span> {chat['user']}", unsafe_allow_html=True)
# st.markdown(f"**Assistant:** {chat['assistant']}")
st.markdown(f"<span class='assistant-label'>Assistant:</span> {chat['assistant']}", unsafe_allow_html=True)