thesnak committed on
Commit
e10d69e
·
verified ·
1 Parent(s): abc712f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Third-party dependencies: UI, PDF parsing, embeddings, vector search, QA.
import gradio as gr
import pdfplumber
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Sentence embedder used to vectorize paper chunks and user questions.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Extractive QA model that answers from the retrieved context.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")

# Flat L2 FAISS index sized to the embedder's output dimension.
dimension = 384  # all-MiniLM-L6-v2 emits 384-dim vectors
index = faiss.IndexFlatL2(dimension)

# Paragraph chunks of the currently loaded paper; rows parallel the index.
text_chunks = []
def extract_text_from_pdf(pdf_file):
    """Extract the full text of a PDF.

    Args:
        pdf_file: Path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        str: Text of all pages joined with newlines. Pages with no
        extractable text (e.g. scanned images) contribute an empty string
        instead of raising — ``page.extract_text()`` returns ``None`` for
        such pages, which crashed the original ``+=`` concatenation.
    """
    with pdfplumber.open(pdf_file) as pdf:
        # join() avoids quadratic string concatenation; the newline
        # separator keeps the last word of a page from fusing with the
        # first word of the next.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
def index_text_chunks(text):
    """Split *text* into paragraph chunks, embed them, and (re)build the index.

    Args:
        text: Full text of the paper (paragraphs separated by blank lines).

    Returns:
        str: A status message for the UI.
    """
    global text_chunks, index
    # Drop empty / whitespace-only fragments produced by split("\n\n");
    # they would otherwise add meaningless vectors to the index.
    text_chunks = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
    if not text_chunks:
        # Nothing to embed — avoid calling FAISS with an empty array.
        return "No extractable text found in the paper."
    embeddings = embedding_model.encode(text_chunks)
    # Rebuild from scratch so a re-upload replaces the previous paper.
    index = faiss.IndexFlatL2(dimension)
    # FAISS requires contiguous float32 input.
    index.add(np.asarray(embeddings, dtype=np.float32))
    return "Paper uploaded and indexed successfully!"
def answer_question(question):
    """Retrieve the most relevant chunks and extract an answer.

    Args:
        question: Natural-language question about the indexed paper.

    Returns:
        str: The extracted answer, or a prompt to upload a paper first.
    """
    if not text_chunks:
        return "Please upload a paper first."

    # Embed the question with the same model used for the chunks.
    question_embedding = embedding_model.encode([question])

    # Never request more neighbours than the index holds: FAISS pads the
    # result with -1 indices, which would silently select text_chunks[-1].
    k = min(2, index.ntotal)
    _, indices = index.search(np.asarray(question_embedding, dtype=np.float32), k=k)
    relevant_chunks = [text_chunks[i] for i in indices[0] if i >= 0]

    # Feed the concatenated top chunks to the extractive QA model.
    context = " ".join(relevant_chunks)
    result = qa_pipeline(question=question, context=context)
    return result['answer']
# Gradio Interface
def _upload_and_index(pdf_file):
    """Upload-button handler: extract the PDF's text, then index it.

    The original wiring passed the raw uploaded file object straight to
    index_text_chunks, which expects a string — extract_text_from_pdf was
    never called, so every upload crashed. This wrapper chains the steps.
    """
    if pdf_file is None:
        return "Please select a PDF file first."
    return index_text_chunks(extract_text_from_pdf(pdf_file))

with gr.Blocks() as demo:
    gr.Markdown("# Chat with Your Paper 📄")
    gr.Markdown("Upload a PDF of your research paper and ask questions about it.")

    with gr.Row():
        pdf_input = gr.File(label="Upload PDF")
        upload_status = gr.Textbox(label="Upload Status", interactive=False)

    with gr.Row():
        question_input = gr.Textbox(label="Ask a Question", placeholder="What is the main contribution of this paper?")
        answer_output = gr.Textbox(label="Answer", interactive=False)

    # Buttons
    upload_button = gr.Button("Upload and Index Paper")
    ask_button = gr.Button("Ask Question")

    # Define actions: extraction + indexing on upload, QA on ask.
    upload_button.click(
        fn=_upload_and_index,
        inputs=pdf_input,
        outputs=upload_status
    )
    ask_button.click(
        fn=answer_question,
        inputs=question_input,
        outputs=answer_output
    )

# Launch the app
demo.launch()