# app.py — "Chat with Your Paper": Gradio demo for question answering over an
# uploaded PDF. (Originally uploaded by thesnak, commit c58cb45, verified.)
import gradio as gr
import pdfplumber
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import pipeline
# Load models once at import time (downloaded from the Hugging Face hub on
# first run; subsequent runs use the local cache).
embedding_model = SentenceTransformer('all-mpnet-base-v2') # Better embedding model
qa_pipeline = pipeline("question-answering", model="deepset/roberta-large-squad2") # Larger QA model
# Initialize FAISS index
# NOTE(review): `dimension` must equal the embedding model's output size —
# assumed to be 768 for all-mpnet-base-v2; confirm if the model is swapped.
dimension = 768 # Dimension of the embedding model
index = faiss.IndexFlatL2(dimension)
# Store text chunks and their embeddings
# Module-level state shared between upload and question handlers; rebuilt on
# every upload by index_text_chunks().
text_chunks = []
def extract_text_from_pdf(pdf_file):
    """Extract the concatenated text of every page in a PDF.

    Args:
        pdf_file: A path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        str: All extracted page text joined together, or "" when nothing
        could be extracted (image-only pages, or an open/parse failure).
        Returning "" on failure — instead of the error message itself, as
        before — lets callers detect failure with a simple truthiness check
        rather than accidentally treating the error text as document content.
    """
    pages = []
    try:
        with pdfplumber.open(pdf_file) as pdf:
            for page in pdf.pages:
                # extract_text() may return None for image-only/empty pages.
                pages.append(page.extract_text() or "")
    except Exception:
        # Bug fix: previously returned f"Error extracting text: {e}", which
        # the caller would index as if it were the paper's text.
        return ""
    return "".join(pages)
def index_text_chunks(pdf_file):
    """Split the uploaded PDF's text into chunks, embed them, and (re)build the FAISS index.

    Args:
        pdf_file: The uploaded file from the Gradio ``File`` component
            (may be None when nothing was selected).

    Returns:
        str: A human-readable status message for the UI.

    Side effects:
        Replaces the module-level ``text_chunks`` list and ``index`` FAISS
        index with data from this PDF.
    """
    global text_chunks, index
    if pdf_file is None:
        return "Please select a PDF file first."
    text = extract_text_from_pdf(pdf_file)
    # Guard against both failure signals: empty/whitespace-only text, and the
    # legacy error-message return value of extract_text_from_pdf.
    if not text or not text.strip() or text.startswith("Error extracting text:"):
        return "No text extracted from the PDF. Please upload a valid PDF file."
    # Chunk on blank lines (paragraph boundaries); drop whitespace-only chunks.
    text_chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
    if not text_chunks:
        # Non-empty text with no paragraph breaks: index it as a single chunk
        # rather than calling encode([]) and crashing.
        text_chunks = [text.strip()]
    embeddings = embedding_model.encode(text_chunks)
    # Rebuild the index from scratch so chunks from a previous upload are gone.
    index = faiss.IndexFlatL2(dimension)
    # FAISS requires float32 input; be explicit rather than relying on the
    # encoder's output dtype.
    index.add(np.array(embeddings, dtype="float32"))
    return f"Paper uploaded and indexed successfully! Found {len(text_chunks)} chunks."
def answer_question(question):
    """Retrieve the most relevant indexed chunks and answer the question.

    Args:
        question: The user's free-text question.

    Returns:
        str: The extracted answer span, or a helpful message when no paper
        is indexed, the question is blank, or the model finds no answer.
    """
    global text_chunks, index
    if not text_chunks:
        return "Please upload a paper first."
    if not question or not question.strip():
        return "Please enter a question."
    # Embed the question with the same model used for the chunks.
    question_embedding = embedding_model.encode([question])
    # Bug fix: never request more neighbors than indexed chunks — FAISS pads
    # the result with -1 indices, and text_chunks[-1] would silently
    # duplicate the last chunk in the context.
    k = min(5, len(text_chunks))
    distances, indices = index.search(np.array(question_embedding, dtype="float32"), k)
    # Filter any -1 padding defensively before indexing into text_chunks.
    relevant_chunks = [text_chunks[i] for i in indices[0] if i >= 0]
    # Concatenate the retrieved chunks as the reading-comprehension context.
    context = " ".join(relevant_chunks)
    result = qa_pipeline(question=question, context=context)
    answer = result['answer'].strip()
    if not answer:
        return "The paper does not provide enough information to answer this question."
    return answer
# ---------------------------------------------------------------------------
# Gradio interface: upload and index a paper, then ask questions about it.
# (Also fixes the mojibake page emoji in the title — was "πŸ“„".)
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Your Paper 📄")
    gr.Markdown("Upload a PDF of your research paper and ask questions about it.")

    with gr.Row():
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_status = gr.Textbox(label="Upload Status", interactive=False)

    with gr.Row():
        question_input = gr.Textbox(
            label="Ask a Question",
            placeholder="What is the main contribution of this paper?",
        )
        answer_output = gr.Textbox(label="Answer", interactive=False)

    # Buttons
    upload_button = gr.Button("Upload and Index Paper")
    ask_button = gr.Button("Ask Question")

    # Wire each button to its backend handler.
    upload_button.click(
        fn=index_text_chunks,
        inputs=pdf_input,
        outputs=upload_status,
    )
    ask_button.click(
        fn=answer_question,
        inputs=question_input,
        outputs=answer_output,
    )

# Launch the app
demo.launch()