import os
import torch
import streamlit as st
import PyPDF2
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
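# Likely dependency set inferred from the imports above (an assumption, not a
# pinned requirements list): streamlit, PyPDF2, numpy, faiss-cpu,
# sentence-transformers, transformers, torch, langchain.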
# Load embedding model (shared with the FAISS vector store below)
def load_embedding_model():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Parse PDF file
def parse_pdf(file):
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None for image-only pages
        text += page.extract_text() or ""
    return text
# Split text into chunks
def split_text(text, chunk_size=500, chunk_overlap=100):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_text(text)
# Create FAISS index from the text chunks using the shared embedding model
def create_faiss_index(texts, embedding_model):
    vectorstore = FAISS.from_texts(texts, embedding_model)
    return vectorstore
# Generate response from the model
def generate_response(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=512, num_return_sequences=1)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
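# Note: generate_response() is not wired into the Streamlit flow below; the
# ConversationalRetrievalChain drives generation instead. It is kept here as a
# standalone helper for direct prompting.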
# Main Streamlit app
def main():
    st.title("Advanced Chat with Your Document")

    # Initialize session state for conversation history and documents
    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = []
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None

    # Step 1: Upload multiple PDF files
    uploaded_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
    if uploaded_files:
        st.write(f"{len(uploaded_files)} file(s) uploaded successfully!")

        # Process PDFs
        with st.spinner("Processing PDFs..."):
            all_texts = []
            for uploaded_file in uploaded_files:
                pdf_text = parse_pdf(uploaded_file)
                chunks = split_text(pdf_text)
                all_texts.extend(chunks)

            # Create a unified vector database
            embedding_model = load_embedding_model()
            st.session_state.vectorstore = create_faiss_index(all_texts, embedding_model)
        st.success("PDFs processed! You can now ask questions.")
    # Step 2: Chat interface
    user_input = st.text_input("Ask a question about the document(s):")
    if user_input:
        if st.session_state.vectorstore is None:
            st.error("Please upload and process documents first.")
            return

        with st.spinner("Generating response..."):
            # Load the LLM
            model_name = "meta-llama/Llama-2-7b-chat-hf"  # Replace with your local path
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

            # Set up LangChain components
            retriever = st.session_state.vectorstore.as_retriever()
            llm = HuggingFacePipeline(pipeline=pipeline("text-generation", model=model, tokenizer=tokenizer))
            # Define a custom prompt template for Chain-of-Thought reasoning
            prompt_template = """
            Answer the following question based ONLY on the provided context.
            If the question requires multi-step reasoning, break it down step by step.
            Context: {context}
            Question: {question}
            Answer:
            """
            prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

            # Create a conversational retrieval chain
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=retriever,
                combine_docs_chain_kwargs={"prompt": prompt},
                return_source_documents=True
            )
            # Add conversation history (the chain expects a list of (question, answer) tuples)
            chat_history = st.session_state.conversation_history[-3:]  # Last 3 interactions
            result = qa_chain({"question": user_input, "chat_history": chat_history})

            # Extract response and update conversation history
            response = result["answer"]
            st.session_state.conversation_history.append((user_input, response))
            st.write(f"**Response:** {response}")
            # Display source documents (optional)
            if "source_documents" in result:
                st.subheader("Source Documents")
                for doc in result["source_documents"]:
                    st.write(doc.page_content)
    # Display conversation history
    st.subheader("Conversation History")
    for question, answer in st.session_state.conversation_history:
        st.write(f"**User:** {question}")
        st.write(f"**Bot:** {answer}")
if __name__ == "__main__":
    main()
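# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py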