File size: 3,487 Bytes
e10d69e
 
 
 
 
 
 
 
c58cb45
 
e10d69e
 
c58cb45
e10d69e
 
 
 
 
 
 
 
bb13b3d
 
 
 
 
 
e10d69e
 
bb13b3d
e10d69e
 
bb13b3d
 
 
 
 
 
 
 
 
 
e10d69e
bb13b3d
 
e10d69e
 
bb13b3d
 
e10d69e
 
 
bb13b3d
 
e10d69e
 
 
 
 
 
c58cb45
 
 
e10d69e
 
 
 
 
c58cb45
 
 
 
 
 
 
e10d69e
 
 
 
 
 
 
bb13b3d
e10d69e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import pdfplumber
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import pipeline

# Load models once at module import (weights are downloaded/cached by Hugging Face).
embedding_model = SentenceTransformer('all-mpnet-base-v2')  # sentence embedder used for chunks and queries
qa_pipeline = pipeline("question-answering", model="deepset/roberta-large-squad2")  # extractive QA over retrieved context

# Initialize FAISS index
dimension = 768  # Dimension of the embedding model — must match the embedder's output size
index = faiss.IndexFlatL2(dimension)  # exact (brute-force) L2 search; rebuilt on every upload

# Store text chunks and their embeddings
# NOTE: module-level mutable state shared by index_text_chunks() and answer_question().
text_chunks = []

def extract_text_from_pdf(pdf_file):
    """Extract all text from a PDF file.

    Args:
        pdf_file: A path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        The concatenated text of all pages, or ``""`` when the file cannot
        be opened or parsed. Returning an empty string (instead of an error
        message) keeps the caller's ``if not text`` check meaningful.
    """
    text = ""
    try:
        with pdfplumber.open(pdf_file) as pdf:
            for page in pdf.pages:
                # extract_text() may return None (e.g. scanned/empty pages).
                text += page.extract_text() or ""
    except Exception:
        # Bug fix: the original returned f"Error extracting text: {e}",
        # a *truthy* string that downstream code would happily index as
        # if it were document text. Signal failure with an empty string.
        return ""
    return text

def index_text_chunks(pdf_file):
    """Split the PDF's text into chunks, embed them, and (re)build the FAISS index.

    Args:
        pdf_file: Uploaded PDF (path or file-like object) from the Gradio UI.

    Returns:
        A human-readable status string shown in the "Upload Status" box.

    Side effects:
        Rebinds the module-level ``text_chunks`` and ``index``.
    """
    global text_chunks, index

    # Extract text from the uploaded PDF.
    text = extract_text_from_pdf(pdf_file)
    # Guard against both "no text" and the legacy error-string return of the
    # extractor, so an error message is never indexed as document content.
    if not text or text.startswith("Error extracting text:"):
        return "No text extracted from the PDF. Please upload a valid PDF file."

    # Paragraph-level chunking; drop whitespace-only fragments.
    chunks = [chunk for chunk in text.split("\n\n") if chunk.strip()]
    # Bug fix: text can be non-empty yet yield zero usable chunks
    # (e.g. all-whitespace pages); encoding/indexing an empty list would fail.
    if not chunks:
        return "No text extracted from the PDF. Please upload a valid PDF file."

    # Generate embeddings for the chunks.
    embeddings = embedding_model.encode(chunks)

    # Rebuild the index from scratch so chunks from a previous upload
    # never linger alongside the new document.
    index = faiss.IndexFlatL2(dimension)
    # FAISS requires float32 input; cast explicitly rather than relying on
    # the encoder's output dtype.
    index.add(np.array(embeddings, dtype=np.float32))
    # Commit the new chunks only after a successful (re)build.
    text_chunks = chunks

    return f"Paper uploaded and indexed successfully! Found {len(text_chunks)} chunks."

def answer_question(question):
    """Retrieve the most relevant chunks and extract an answer with the QA model.

    Args:
        question: Natural-language question about the indexed paper.

    Returns:
        The extracted answer string, or a user-facing message when no paper
        has been indexed or no answer can be found.
    """
    global text_chunks, index

    if not text_chunks:
        return "Please upload a paper first."

    # Embed the question (encode expects a list; returns a 2-D array).
    question_embedding = embedding_model.encode([question])

    # Bug fix: never ask FAISS for more neighbors than indexed vectors.
    # When k > ntotal, FAISS pads the result with -1, and text_chunks[-1]
    # would silently return the *last* chunk instead of being skipped.
    k = min(5, len(text_chunks))
    # FAISS expects float32 queries; cast explicitly.
    distances, indices = index.search(np.array(question_embedding, dtype=np.float32), k)
    # Filter any -1 padding defensively even after clamping k.
    relevant_chunks = [text_chunks[i] for i in indices[0] if i >= 0]

    # Use the QA model to generate an answer from the joined context.
    context = " ".join(relevant_chunks)
    result = qa_pipeline(question=question, context=context)

    # Post-process the answer: treat an empty/whitespace span as "no answer".
    answer = result['answer']
    if not answer.strip():
        return "The paper does not provide enough information to answer this question."

    return answer

# Gradio Interface: two rows (upload + Q&A) and two action buttons wired
# to the backend callbacks defined above.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Your Paper 📄")
    gr.Markdown("Upload a PDF of your research paper and ask questions about it.")

    # Row 1: PDF upload and indexing status.
    with gr.Row():
        uploaded_pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
        status_box = gr.Textbox(label="Upload Status", interactive=False)

    # Row 2: question entry and model answer.
    with gr.Row():
        question_box = gr.Textbox(label="Ask a Question", placeholder="What is the main contribution of this paper?")
        answer_box = gr.Textbox(label="Answer", interactive=False)

    # Action buttons.
    index_btn = gr.Button("Upload and Index Paper")
    query_btn = gr.Button("Ask Question")

    # Wire each button to its callback.
    index_btn.click(fn=index_text_chunks, inputs=uploaded_pdf, outputs=status_box)
    query_btn.click(fn=answer_question, inputs=question_box, outputs=answer_box)

# Launch the app
demo.launch()