Create app.py
app.py
ADDED
@@ -0,0 +1,243 @@
import os

import faiss
import gradio as gr
import numpy as np
import requests

from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

################################################################################
# 1. PDF Parsing and Chunking
################################################################################

def extract_pdf_text(pdf_file) -> str:
    """
    Extracts text from each page of the uploaded PDF, then concatenates them.
    """
    reader = PdfReader(pdf_file)
    all_text = []
    for page in reader.pages:
        text = page.extract_text() or ""
        all_text.append(text.strip())
    return "\n".join(all_text)

def chunk_text(text, chunk_size=300, overlap=50):
    """
    Splits text into overlapping chunks of roughly chunk_size words.
    overlap indicates how many words from the previous chunk are included
    again, so sentences straddling a chunk boundary are not lost to retrieval.
    """
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = words[start:end]
        chunks.append(" ".join(chunk))
        start += (chunk_size - overlap)
    return chunks

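# Illustrative example of the windowing (hypothetical 8-word input): with
# chunk_size=5 and overlap=2 the start index advances by 3 words per chunk:
#
#   chunk_text("w1 w2 w3 w4 w5 w6 w7 w8", chunk_size=5, overlap=2)
#   -> ["w1 w2 w3 w4 w5", "w4 w5 w6 w7 w8", "w7 w8"]
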
################################################################################
# 2. Embedding Model
################################################################################

# Use a SentenceTransformer from Hugging Face to embed text
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

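# Note: all-MiniLM-L6-v2 encodes each text into a 384-dimensional vector, so
# the FAISS index built below ends up with dimension 384.
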
################################################################################
# 3. Building the FAISS Index
################################################################################

def build_faiss_index(chunks):
    """
    Creates a FAISS index from the text chunks. Returns (index, chunk_embeddings).
    """
    chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
    chunk_embeddings = np.array(chunk_embeddings, dtype='float32')
    dimension = chunk_embeddings.shape[1]

    index = faiss.IndexFlatL2(dimension)  # exact search with L2 distance
    index.add(chunk_embeddings)
    return index, chunk_embeddings

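# A minimal sketch of an alternative index (not wired into the app): sentence
# embeddings are often compared by cosine similarity, which in FAISS amounts
# to L2-normalizing the vectors and using an inner-product index instead.
def build_faiss_index_cosine(chunks):
    embs = np.array(embedding_model.encode(chunks, show_progress_bar=False),
                    dtype='float32')
    faiss.normalize_L2(embs)                  # in-place normalization to unit length
    index = faiss.IndexFlatIP(embs.shape[1])  # inner product == cosine on unit vectors
    index.add(embs)
    return index, embs
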
################################################################################
# 4. Retrieval Function
################################################################################

def retrieve_chunks(query, index, chunks, top_k=3):
    """
    Embeds the user query and retrieves the top_k most relevant chunks via FAISS.
    """
    query_embedding = embedding_model.encode([query], show_progress_bar=False)
    query_embedding = np.array(query_embedding, dtype='float32')

    top_k = min(top_k, len(chunks))  # never ask FAISS for more vectors than exist
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads missing results with -1, so keep only valid row ids
    relevant_chunks = [chunks[i] for i in indices[0] if i >= 0]
    return relevant_chunks

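# For reference, index.search returns two arrays of shape (n_queries, top_k);
# the values below are illustrative only:
#
#   distances -> [[0.42, 0.57, 0.61]]   # squared L2 distances, ascending
#   indices   -> [[12, 3, 27]]          # row ids of the matching chunks
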
################################################################################
# 5. Gemini LLM Integration (Parsing 'candidates')
################################################################################

def gemini_generate(prompt):
    """
    Calls Google's Gemini API using the environment variable GEMINI_API_KEY.
    Assumes the 'generateContent' endpoint returns text under:
        r_data["candidates"][0]["content"]["parts"][0]["text"]
    """
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
    if not gemini_api_key:
        return "Error: No GEMINI_API_KEY found in environment variables."

    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
        f"?key={gemini_api_key}"
    )

    data = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ]
    }
    headers = {"Content-Type": "application/json"}

    # A timeout keeps a stalled request from hanging the whole app
    response = requests.post(url, headers=headers, json=data, timeout=60)
    if response.status_code != 200:
        return f"Error {response.status_code}: {response.text}"

    r_data = response.json()
    try:
        return r_data["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError):
        return f"Parsing error or unexpected response structure: {r_data}"

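# For reference, a successful generateContent response carries the text at the
# path parsed above; roughly (other fields omitted, values illustrative):
#
#   {"candidates": [{"content": {"parts": [{"text": "...answer..."}]}}]}
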
################################################################################
# 6. RAG QA Function
################################################################################

def answer_question_with_RAG(user_question, index, chunks):
    """
    Retrieves relevant chunks, builds an augmented prompt, and calls gemini_generate.
    """
    relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=3)
    context = "\n\n".join(relevant_chunks)

    prompt = f"""
You are an AI assistant that knows the details from the uploaded research paper.
Answer the user's question accurately using the context below.
If something is not in the context, say you don't know.

Context:
{context}

User's question: {user_question}

Answer:
"""
    return gemini_generate(prompt)

################################################################################
# 7. Gradio Interface
################################################################################

def process_pdf(pdf_file):
    """
    Called after the user uploads a PDF and clicks 'Process PDF'.
    Extracts text, chunks it, builds the FAISS index, and returns the new state.
    """
    if pdf_file is None:
        return None, "Please upload a PDF file."

    # Depending on the Gradio version, gr.File yields a tempfile-like object
    # (with a .name path) or a plain filepath string; handle both
    pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
    text = extract_pdf_text(pdf_path)
    if not text:
        return None, "No text found in PDF."

    chunks = chunk_text(text, chunk_size=300, overlap=50)
    if not chunks:
        return None, "No valid text to chunk."

    faiss_index, _ = build_faiss_index(chunks)
    return (faiss_index, chunks), "PDF processed successfully!"

def chat_with_paper(query, state):
    """
    Handles user queries after the PDF is processed.
    'state' is a tuple: (faiss_index, doc_chunks).
    """
    if not state:
        return "Please upload and process a PDF first."

    faiss_index, doc_chunks = state
    if not query or not query.strip():
        return "Please enter a valid question."

    return answer_question_with_RAG(query, faiss_index, doc_chunks)

################################################################################
# 8. Gradio App with Sky-Blue Tiles
################################################################################

demo_theme = gr.themes.Soft(primary_hue="slate")

css_code = """
/* Tiled sky-blue background */
body {
    background: url('https://i.ibb.co/gvrZQ1C/sky-blue-tile.png');
    background-repeat: repeat;
    background-size: 150px 150px;
}
/* Centered headings */
#title-heading {
    text-align: center;
    font-size: 2.5rem;
    font-weight: 700;
    margin-bottom: 10px;
}
#welcome-text {
    text-align: center;
    font-size: 1.2rem;
    color: #444;
    margin-bottom: 25px;
    margin-top: 0.5rem;
}
"""

with gr.Blocks(theme=demo_theme, css=css_code) as demo:
    gr.Markdown("<div id='title-heading'>AI-Powered Personalized Research Assistant</div>")
    gr.Markdown("<div id='welcome-text'>Welcome! How may I help you?</div>")

    # State to store (faiss_index, chunks)
    state = gr.State()

    with gr.Row():
        pdf_input = gr.File(label="Upload your research paper (PDF)", file_types=[".pdf"])
        process_button = gr.Button("Process PDF")
        status_output = gr.Textbox(label="Status", interactive=False)

    # When the user clicks "Process PDF", parse the file and build the index
    process_button.click(
        fn=process_pdf,
        inputs=pdf_input,
        outputs=[state, status_output]
    )

    with gr.Row():
        user_query = gr.Textbox(label="Ask a question about your research paper:")
        ask_button = gr.Button("Get Answer")
        answer_output = gr.Textbox(label="Answer")

    # When the user clicks "Get Answer", run the RAG pipeline on the query
    ask_button.click(
        fn=chat_with_paper,
        inputs=[user_query, state],
        outputs=answer_output
    )

demo.launch()
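# Note: on a hosted Space the bare launch() above is all that is needed; for a
# local run, demo.launch(share=True) would also create a temporary public URL.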