import os
import torch
import streamlit as st
import PyPDF2
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from transformers import pipeline
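
# NOTE: meta-llama/Llama-2-7b-chat-hf is a gated checkpoint; running this app assumes the
# license has been accepted on Hugging Face and an access token is configured, plus a GPU
# with roughly 14 GB of memory for the float16 weights. The MiniLM embedding model is
# downloaded automatically on first use.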
# Load embedding model (LangChain wrapper around all-MiniLM-L6-v2), cached across reruns
@st.cache_resource
def load_embedding_model():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Parse PDF file into plain text
def parse_pdf(file):
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None for pages with no extractable text
        page_text = page.extract_text()
        if page_text:
            text += page_text
    return text
# Split text into overlapping chunks
def split_text(text, chunk_size=500, chunk_overlap=100):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_text(text)
# Create FAISS index from the text chunks, using the cached embedding model
def create_faiss_index(texts, embedding_model):
    vectorstore = FAISS.from_texts(texts, embedding_model)
    return vectorstore
# Generate a response directly from the model (standalone helper; the app below answers
# questions through the LangChain retrieval chain instead)
def generate_response(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=512, num_return_sequences=1)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
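
# Cached loader for the chat LLM. st.cache_resource keeps the tokenizer and model in memory
# so the 7B checkpoint is not reloaded on every Streamlit rerun. max_new_tokens and
# return_full_text are set so the pipeline produces a full answer without echoing the prompt.
@st.cache_resource
def load_llm(model_name="meta-llama/Llama-2-7b-chat-hf"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", torch_dtype=torch.float16
    )
    text_gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        return_full_text=False,
    )
    return HuggingFacePipeline(pipeline=text_gen)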
# Main Streamlit app
def main():
    st.title("Advanced Chat with Your Document")

    # Initialize session state for conversation history and documents
    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = []
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None
    # Step 1: Upload multiple PDF files
    uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
    if uploaded_files:
        st.write(f"{len(uploaded_files)} file(s) uploaded successfully!")

        # Process PDFs (Streamlit reruns this block on every interaction, so large
        # uploads are re-indexed each time a question is submitted)
        with st.spinner("Processing PDFs..."):
            all_texts = []
            for uploaded_file in uploaded_files:
                pdf_text = parse_pdf(uploaded_file)
                chunks = split_text(pdf_text)
                all_texts.extend(chunks)

            # Create a unified vector database
            embedding_model = load_embedding_model()
            st.session_state.vectorstore = create_faiss_index(all_texts, embedding_model)

        st.success("PDFs processed! You can now ask questions.")
    # Step 2: Chat interface
    user_input = st.text_input("Ask a question about the document(s):")
    if user_input:
        if st.session_state.vectorstore is None:
            st.error("Please upload and process documents first.")
            return
with st.spinner("Generating response..."):
# Load the LLM
model_name = "meta-llama/Llama-2-7b-chat-hf" # Replace with your local path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
# Set up LangChain components
retriever = st.session_state.vectorstore.as_retriever()
llm = HuggingFacePipeline(pipeline=pipeline("text-generation", model=model, tokenizer=tokenizer))
            # Define a custom prompt template for Chain-of-Thought reasoning
            prompt_template = """
            Answer the following question based ONLY on the provided context.
            If the question requires multi-step reasoning, break it down step by step.

            Context: {context}
            Question: {question}
            Answer:
            """
            prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
            # Create a conversational retrieval chain
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=retriever,
                combine_docs_chain_kwargs={"prompt": prompt},
                return_source_documents=True,
            )
            # Pass recent conversation history; the chain expects (question, answer) tuples
            chat_history = st.session_state.conversation_history[-3:]  # Last 3 interactions
            result = qa_chain({"question": user_input, "chat_history": chat_history})

            # Extract the response and update conversation history
            response = result["answer"]
            st.session_state.conversation_history.append((user_input, response))
            st.write(f"**Response:** {response}")
            # Display source documents (optional)
            if "source_documents" in result:
                st.subheader("Source Documents")
                for doc in result["source_documents"]:
                    st.write(doc.page_content)
    # Display conversation history
    st.subheader("Conversation History")
    for question, answer in st.session_state.conversation_history:
        st.write(f"**User:** {question}")
        st.write(f"**Bot:** {answer}")

if __name__ == "__main__":
    main()