rajesh1729's picture
Update app.py
c316c4f verified
raw
history blame
5.61 kB
import os
import streamlit as st
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader
# Seed Streamlit session state so later attribute reads never fail
# across script reruns.
_SESSION_DEFAULTS = {
    "messages": [],      # chat transcript rendered in the UI
    "chain": None,       # ConversationalRetrievalChain, set after PDF processing
    "vectorstore": None, # Chroma store backing the retriever
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
def create_sidebar():
    """Render the sidebar: title, OpenAI API-key input, and usage notes.

    Returns:
        str: the API key typed by the user (empty string until entered).
    """
    with st.sidebar:
        st.title("PDF Chat")
        st.markdown("### Quick Demo of RAG")
        entered_key = st.text_input("OpenAI API Key:", type="password")
        st.markdown("""
        ### Tools Used
        - OpenAI
        - LangChain
        - ChromaDB
        ### Steps
        1. Add API key
        2. Upload PDF
        3. Chat!
        """)
    return entered_key
def process_pdfs(papers, api_key):
    """Embed the uploaded PDFs into a Chroma store and build the QA chain.

    Args:
        papers: iterable of Streamlit UploadedFile objects (PDFs).
        api_key: OpenAI API key, used for both embeddings and the chat model.

    Returns:
        bool: True when processing succeeded, False otherwise.

    Side effects:
        Sets ``st.session_state.vectorstore`` and ``st.session_state.chain``;
        temporarily writes each upload under ./uploads (removed afterwards).
    """
    if not papers:
        return False
    with st.spinner("Processing PDFs..."):
        try:
            embeddings = OpenAIEmbeddings(openai_api_key=api_key)
            # One splitter for all documents (loop-invariant); the 200-char
            # overlap preserves context across chunk boundaries.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
            )
            os.makedirs('./uploads', exist_ok=True)
            all_texts = []
            for paper in papers:
                # PyPDFLoader needs a real file on disk, so spill the upload
                # to ./uploads and delete it once it has been split.
                file_path = os.path.join('./uploads', paper.name)
                with open(file_path, "wb") as f:
                    f.write(paper.getbuffer())
                try:
                    documents = PyPDFLoader(file_path).load()
                    all_texts.extend(text_splitter.split_documents(documents))
                finally:
                    # Remove the temp copy even if loading/splitting raised,
                    # so failed runs don't leave files behind.
                    os.remove(file_path)
            # Create new vectorstore (replaces any previous one in session state)
            st.session_state.vectorstore = Chroma.from_documents(
                documents=all_texts,
                embedding=embeddings,
            )
            # Conversational chain: retriever over the fresh store plus a
            # buffer memory keyed the way the chain expects ("chat_history").
            st.session_state.chain = ConversationalRetrievalChain.from_llm(
                llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_key=api_key),
                retriever=st.session_state.vectorstore.as_retriever(
                    search_kwargs={"k": 3}  # Retrieve top 3 most relevant chunks
                ),
                memory=ConversationBufferMemory(
                    memory_key="chat_history",
                    return_messages=True,
                ),
                return_source_documents=True,  # Include source documents in response
            )
            st.success(f"Processed {len(papers)} PDF(s) successfully!")
            return True
        except Exception as e:
            # Best-effort UI: surface the error and let the user retry.
            st.error(f"Error processing PDFs: {str(e)}")
            return False
def main():
    """Streamlit entry point: sidebar gate, PDF upload, and the chat loop."""
    st.set_page_config(page_title="PDF Chat")

    # The whole app is gated on an API key from the sidebar.
    api_key = create_sidebar()
    if not api_key:
        st.warning("Please enter your OpenAI API key")
        return

    st.title("Chat with PDF")

    # Upload + explicit "process" step so embedding only happens on demand.
    uploads = st.file_uploader("Upload PDFs", type=["pdf"], accept_multiple_files=True)
    if uploads and st.button("Process PDFs"):
        process_pdfs(uploads, api_key)

    # Replay prior turns so the conversation survives Streamlit reruns.
    for past in st.session_state.messages:
        with st.chat_message(past["role"]):
            st.markdown(past["content"])

    user_text = st.chat_input("Ask about your PDFs")
    if not user_text:
        return

    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": user_text})
    with st.chat_message("user"):
        st.markdown(user_text)

    # Produce the assistant's turn.
    with st.chat_message("assistant"):
        chain = st.session_state.chain
        if chain is None:
            reply = "Please upload and process a PDF first."
        else:
            with st.spinner("Thinking..."):
                result = chain({"question": user_text})
                reply = result["answer"]
                # Append retrieved snippets (with 1-based page numbers when
                # the loader provided them) under a "Sources:" footer.
                docs = result.get("source_documents") or []
                if docs:
                    reply += "\n\nSources:"
                    for idx, doc in enumerate(docs, 1):
                        page_info = f" (Page {doc.metadata['page'] + 1})" if 'page' in doc.metadata else ""
                        reply += f"\n{idx}.{page_info} {doc.page_content[:200]}..."
        st.markdown(reply)
        st.session_state.messages.append({"role": "assistant", "content": reply})


if __name__ == "__main__":
    main()