from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain  # "stuffs" all retrieved docs into one prompt context
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_groq import ChatGroq
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
import os
import tempfile
import streamlit as st
from dotenv import load_dotenv

load_dotenv()

# HF_TOKEN is optional for public models; assigning None into os.environ would
# raise a TypeError, so only export the token when it is actually defined.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

def initialize_session_state():
    """Initialize session state variables if they don't exist."""
    session_state_defaults = {
        'vectorstore': None,
        'retriever': None,
        'conversation_chain': None,
        'chat_history': [],
        'uploaded_file_names': set()
    }
    for key, default_value in session_state_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default_value

def setup_rag_pipeline(documents):
    """Set up the RAG pipeline with embeddings and retrieval."""
    # Split documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
    splits = text_splitter.split_documents(documents)
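    # Note on chunk sizing: all-MiniLM-L6-v2 truncates its input at 256 tokens,
    # so only roughly the first fifth of a 5000-character chunk influences its
    # embedding; a smaller chunk_size would retrieve more precisely.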

    # Create vector store and retriever
    vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever()

    # Configure LLM
    groq_api_key = os.getenv("GROQ_API_KEY")
    llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")

    # Contextualization prompt
    contextualize_q_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given a chat history and the latest user question, "
                   "formulate a standalone question which can be understood "
                   "without the chat history. Do NOT answer the question, "
                   "just reformulate it if needed and otherwise return it as is."),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])

    # QA prompt
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an assistant for question-answering tasks. "
                   "Use the following pieces of retrieved context to answer "
                   "the question. If you don't know the answer, say that you "
                   "don't know. Use three sentences maximum and keep the "
                   "answer concise.\n\n{context}"),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])
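    # create_stuff_documents_chain (below) fills the {context} placeholder by
    # concatenating ("stuffing") the page_content of every retrieved document.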

    # Create chains
    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    # Conversational RAG chain with per-session message history: a lambda that
    # returns a fresh ChatMessageHistory on every call would silently discard it.
    if "message_store" not in st.session_state:
        st.session_state.message_store = {}

    def get_session_history(session_id: str) -> ChatMessageHistory:
        if session_id not in st.session_state.message_store:
            st.session_state.message_store[session_id] = ChatMessageHistory()
        return st.session_state.message_store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer"
    )
    return conversational_rag_chain, vectorstore, retriever
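
# Minimal usage sketch (illustrative; `docs` stands for any list of LangChain
# Documents, e.g. the output of PyPDFLoader.load(). It must run inside a
# Streamlit session because the chain keeps its history in st.session_state):
#   chain, vectorstore, retriever = setup_rag_pipeline(docs)
#   result = chain.invoke(
#       {"input": "What is this document about?"},
#       config={"configurable": {"session_id": "demo"}},
#   )
#   print(result["answer"])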

def chat():
    # Initialize Streamlit app
    st.title("RAG with PDF Uploads")
    st.write("Upload PDFs and chat with their content")

    # Initialize session state
    initialize_session_state()
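
    # The 'user-label' / 'assistant-label' spans used below reference CSS
    # classes this file never defines; inject minimal styles so they render.
    # (The colors are placeholder choices, not taken from the original app.)
    st.markdown(
        """
        <style>
        .user-label { color: #1a8cff; font-weight: bold; }
        .assistant-label { color: #53ff1a; font-weight: bold; }
        </style>
        """,
        unsafe_allow_html=True,
    )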

    # Reset session state when the reset button is clicked
# if st.button("Reset"): | |
# # Clear session state variables | |
# st.session_state.clear() | |
# # Reinitialize the session state | |
# initialize_session_state() | |
# st.success("Session reset successfully!") | |
if st.button("Reset"): | |
# Clear all session state variables | |
for key in list(st.session_state.keys()): | |
del st.session_state[key] | |
# Reinitialize the session state | |
initialize_session_state() | |
st.success("Session reset successfully!") | |
# Force a rerun of the app to clear the UI | |
st.rerun() | |

    # API key check
    if not os.getenv("GROQ_API_KEY"):
        st.error("Please set the GROQ_API_KEY environment variable.")
        return

    # File upload
    uploaded_files = st.file_uploader("Upload PDF files", type='pdf', accept_multiple_files=True)

    # Process uploaded files
    if uploaded_files:
        # Get current file names
        current_file_names = {file.name for file in uploaded_files}

        # Rebuild the pipeline only when the set of uploaded files has changed
        if current_file_names != st.session_state.uploaded_file_names:
            st.session_state.uploaded_file_names = current_file_names

            # Process PDF documents
            documents = []
            for uploaded_file in uploaded_files:
                # Write each upload to its own temporary file (a shared
                # "./temp.pdf" path would break with concurrent sessions),
                # load it, then clean up.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                    tmp.write(uploaded_file.getvalue())
                    tmp_path = tmp.name
                try:
                    loader = PyPDFLoader(tmp_path)
                    documents.extend(loader.load())
                finally:
                    os.remove(tmp_path)

            # Set up the RAG pipeline
            st.session_state.conversation_chain, st.session_state.vectorstore, st.session_state.retriever = setup_rag_pipeline(documents)

    # Chat interface
    user_input = st.text_input("Ask a question about your documents:")

    if user_input and st.session_state.conversation_chain:
        try:
            # Invoke the conversational chain
            response = st.session_state.conversation_chain.invoke(
                {"input": user_input},
                config={"configurable": {"session_id": "default_session"}}
            )

            # Display the answer
            st.markdown(f"<span class='assistant-label'>Assistant:</span> {response['answer']}", unsafe_allow_html=True)

            # Update chat history
            st.session_state.chat_history.append({"user": user_input, "assistant": response['answer']})
        except Exception as e:
            st.error(f"An error occurred: {e}")

    # Display chat history
    if st.session_state.chat_history:
        st.markdown("<h4 style='color:#53ff1a;'>Chat History</h4>", unsafe_allow_html=True)
        with st.expander("Show conversation"):
            for chat_turn in st.session_state.chat_history:
                st.markdown(f"<span class='user-label'>You:</span> {chat_turn['user']}", unsafe_allow_html=True)
                st.markdown(f"<span class='assistant-label'>Assistant:</span> {chat_turn['assistant']}", unsafe_allow_html=True)