# Scraped Hugging Face Spaces page header (kept as a comment so the file parses):
# Spaces: Sleeping, Sleeping
import faiss
import gradio as gr
import numpy as np
import pdfplumber
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# --- Models and retrieval index ---

# Sentence embedder used for both paper chunks and questions (768-dim output).
embedding_model = SentenceTransformer('all-mpnet-base-v2')

# Extractive QA model: selects an answer span from the retrieved context.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-large-squad2")

# FAISS exact L2 index; dimension must match the embedding model's output size.
dimension = 768
index = faiss.IndexFlatL2(dimension)

# Chunks of the currently loaded paper, kept in the same order as the vectors
# stored in `index` (row i of the index is the embedding of text_chunks[i]).
text_chunks = []
def extract_text_from_pdf(pdf_file):
    """Extract all text from a PDF file.

    Args:
        pdf_file: Path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        The text of every page, pages separated by blank lines so that
        downstream paragraph splitting on "\\n\\n" sees page boundaries.
        On failure, an "Error extracting text: ..." message string is
        returned instead (kept for backward compatibility — callers show
        this string as a status message).
    """
    pages = []
    try:
        with pdfplumber.open(pdf_file) as pdf:
            for page in pdf.pages:
                # extract_text() returns None for image-only/empty pages.
                pages.append(page.extract_text() or "")
    except Exception as e:
        return f"Error extracting text: {e}"
    # Join with a blank line between pages; the original concatenated pages
    # with no separator, which made paragraph splitting across pages impossible.
    return "\n\n".join(pages)
def index_text_chunks(pdf_file):
    """Split the uploaded PDF into chunks, embed them, and rebuild the index.

    Args:
        pdf_file: The uploaded PDF (as delivered by the ``gr.File`` component),
            or None when nothing has been uploaded.

    Returns:
        A human-readable status string for the UI.
    """
    global text_chunks, index
    if pdf_file is None:
        return "No text extracted from the PDF. Please upload a valid PDF file."
    text = extract_text_from_pdf(pdf_file)
    # extract_text_from_pdf signals failure with an error-message string; it is
    # truthy, so without this check the error message itself would be indexed.
    if text.startswith("Error extracting text:"):
        return text
    if not text.strip():
        return "No text extracted from the PDF. Please upload a valid PDF file."
    # Paragraph-level chunks: split on blank lines and drop empty pieces.
    text_chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
    embeddings = embedding_model.encode(text_chunks)
    # Rebuild the index from scratch so vectors from a previously uploaded
    # paper are discarded.
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings))
    return f"Paper uploaded and indexed successfully! Found {len(text_chunks)} chunks."
def answer_question(question):
    """Answer a question about the indexed paper (retrieve-then-read).

    Args:
        question: The user's question as a plain string.

    Returns:
        The extracted answer, or a status message when no paper is indexed
        or no answer can be found.
    """
    global text_chunks, index
    if not text_chunks:
        return "Please upload a paper first."
    # Embed the question with the same model used for the chunks.
    question_embedding = embedding_model.encode([question])
    # Never request more neighbors than there are chunks: FAISS pads missing
    # results with index -1, and text_chunks[-1] would silently duplicate
    # the last chunk.
    k = min(5, len(text_chunks))
    distances, indices = index.search(np.array(question_embedding), k)
    relevant_chunks = [text_chunks[i] for i in indices[0] if i >= 0]
    # Concatenate the retrieved chunks into a single context window.
    context = " ".join(relevant_chunks)
    result = qa_pipeline(question=question, context=context)
    answer = result['answer']
    if answer.strip() == "":
        return "The paper does not provide enough information to answer this question."
    return answer
# --- Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Your Paper π")
    gr.Markdown("Upload a PDF of your research paper and ask questions about it.")
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_status = gr.Textbox(label="Upload Status", interactive=False)
    with gr.Row():
        question_input = gr.Textbox(label="Ask a Question", placeholder="What is the main contribution of this paper?")
        answer_output = gr.Textbox(label="Answer", interactive=False)

    upload_button = gr.Button("Upload and Index Paper")
    ask_button = gr.Button("Ask Question")

    # Wire the buttons to the indexing and question-answering functions.
    upload_button.click(
        fn=index_text_chunks,
        inputs=pdf_input,
        outputs=upload_status
    )
    ask_button.click(
        fn=answer_question,
        inputs=question_input,
        outputs=answer_output
    )

# Launch the app
demo.launch()