fhmsf committed
Commit f615b93 · verified · 1 Parent(s): e44e7a9

Create app.py

Files changed (1)
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
+ import os
+
+ import faiss
+ import gradio as gr
+ import numpy as np
+ import requests
+ from pypdf import PdfReader
+ from sentence_transformers import SentenceTransformer
+
+ ################################################################################
+ # 1. PDF Parsing and Chunking
+ ################################################################################
+
+ def extract_pdf_text(pdf_file) -> str:
+     """
+     Extracts text from each page of the uploaded PDF, then concatenates them.
+     """
+     reader = PdfReader(pdf_file)
+     all_text = []
+     for page in reader.pages:
+         text = page.extract_text() or ""
+         all_text.append(text.strip())
+     return "\n".join(all_text)
+
+ def chunk_text(text, chunk_size=300, overlap=50):
+     """
+     Splits text into overlapping chunks of roughly chunk_size words.
+     overlap is how many words from the end of one chunk are repeated at the
+     start of the next, so context is not lost at chunk boundaries.
+     """
+     words = text.split()
+     chunks = []
+     start = 0
+     while start < len(words):
+         end = start + chunk_size
+         chunk = words[start:end]
+         chunks.append(" ".join(chunk))
+         start += (chunk_size - overlap)
+     return chunks
+
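+ # Worked example (hypothetical numbers): with chunk_size=300 and overlap=50,
+ # a 600-word document yields three chunks covering words 0-299, 250-549,
+ # and 500-599, so every chunk boundary is seen twice by the retriever.
+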
+ ################################################################################
+ # 2. Embedding Model
+ ################################################################################
+
+ # Use a SentenceTransformer from Hugging Face to embed text
+ embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
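+ # all-MiniLM-L6-v2 produces 384-dimensional embeddings. Loading the model at
+ # module level means it is downloaded once at startup rather than per request.
+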
+ ################################################################################
+ # 3. Building the FAISS Index
+ ################################################################################
+
+ def build_faiss_index(chunks):
+     """
+     Creates a FAISS index from the text chunks. Returns (index, chunk_embeddings).
+     """
+     chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
+     chunk_embeddings = np.array(chunk_embeddings, dtype="float32")
+     dimension = chunk_embeddings.shape[1]
+
+     index = faiss.IndexFlatL2(dimension)  # exact search, L2 distance
+     index.add(chunk_embeddings)
+     return index, chunk_embeddings
+
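+ # Design note: IndexFlatL2 is a brute-force exact index, which is plenty fast
+ # for one paper's worth of chunks. For cosine similarity instead of L2
+ # distance, one could normalize the embeddings and use faiss.IndexFlatIP.
+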
+ ################################################################################
+ # 4. Retrieval Function
+ ################################################################################
+
+ def retrieve_chunks(query, index, chunks, top_k=3):
+     """
+     Embeds the user query and retrieves the top_k most relevant chunks via FAISS.
+     """
+     query_embedding = embedding_model.encode([query], show_progress_bar=False)
+     query_embedding = np.array(query_embedding, dtype="float32")
+
+     # Don't ask FAISS for more neighbours than the index holds; otherwise it
+     # pads the result with -1 ids, which would silently index chunks[-1].
+     top_k = min(top_k, len(chunks))
+     distances, indices = index.search(query_embedding, top_k)
+     relevant_chunks = [chunks[i] for i in indices[0]]
+     return relevant_chunks
+
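+ # index.search returns two arrays of shape (1, top_k): squared-L2 distances
+ # and the row ids of the nearest vectors, which map 1:1 onto `chunks`.
+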
+ ################################################################################
+ # 5. Gemini LLM Integration (Parsing 'candidates')
+ ################################################################################
+
+ def gemini_generate(prompt):
+     """
+     Calls Google's Gemini API using the environment variable GEMINI_API_KEY.
+     Assumes the 'generateContent' endpoint returns text under:
+         r_data["candidates"][0]["content"]["parts"][0]["text"]
+     """
+     gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
+     if not gemini_api_key:
+         return "Error: No GEMINI_API_KEY found in environment variables."
+
+     url = (
+         "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
+         f"?key={gemini_api_key}"
+     )
+
+     data = {
+         "contents": [
+             {
+                 "parts": [
+                     {"text": prompt}
+                 ]
+             }
+         ]
+     }
+     headers = {"Content-Type": "application/json"}
+
+     response = requests.post(url, headers=headers, json=data, timeout=60)
+     if response.status_code != 200:
+         return f"Error {response.status_code}: {response.text}"
+
+     r_data = response.json()
+     try:
+         return r_data["candidates"][0]["content"]["parts"][0]["text"]
+     except (KeyError, IndexError, TypeError):
+         return f"Parsing error or unexpected response structure: {r_data}"
+
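+ # Abridged shape of a successful response, matching the docstring above:
+ # {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
+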
+ ################################################################################
+ # 6. RAG QA Function
+ ################################################################################
+
+ def answer_question_with_RAG(user_question, index, chunks):
+     """
+     Retrieves relevant chunks, builds an augmented prompt, and calls gemini_generate.
+     """
+     relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=3)
+     context = "\n\n".join(relevant_chunks)
+
+     prompt = f"""
+ You are an AI assistant that knows the details from the uploaded research paper.
+ Answer the user's question accurately using the context below.
+ If something is not in the context, say you don't know.
+
+ Context:
+ {context}
+
+ User's question: {user_question}
+
+ Answer:
+ """
+     return gemini_generate(prompt)
+
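+ # The assembled prompt is plain text: the instructions above, a "Context:"
+ # section holding the three retrieved chunks separated by blank lines, then
+ # the user's question and an "Answer:" cue for the model to complete.
+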
+ ################################################################################
+ # 7. Gradio Interface
+ ################################################################################
+
+ def process_pdf(pdf_file):
+     """
+     Called after the user uploads a PDF and clicks 'Process PDF'.
+     Extracts text, chunks it, builds the FAISS index, and returns the new state.
+     """
+     if pdf_file is None:
+         return None, "Please upload a PDF file."
+
+     # Depending on the Gradio version, gr.File yields either a file path or a
+     # tempfile-like object with a .name attribute; handle both.
+     pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
+     text = extract_pdf_text(pdf_path)
+     if not text:
+         return None, "No text found in PDF."
+
+     chunks = chunk_text(text, chunk_size=300, overlap=50)
+     if not chunks:
+         return None, "No valid text to chunk."
+
+     faiss_index, _ = build_faiss_index(chunks)
+     return (faiss_index, chunks), "PDF processed successfully!"
+
+ def chat_with_paper(query, state):
+     """
+     Handles user queries after the PDF is processed.
+     'state' is a tuple: (faiss_index, doc_chunks).
+     """
+     if not state:
+         return "Please upload and process a PDF first."
+
+     faiss_index, doc_chunks = state
+     if not query or not query.strip():
+         return "Please enter a valid question."
+
+     return answer_question_with_RAG(query, faiss_index, doc_chunks)
+
+ ################################################################################
+ # 8. Gradio App with Sky-Blue Tiles
+ ################################################################################
+
+ demo_theme = gr.themes.Soft(primary_hue="slate")
+
+ css_code = """
+ /* Tiled sky-blue background */
+ body {
+     background: url('https://i.ibb.co/gvrZQ1C/sky-blue-tile.png');
+     background-repeat: repeat;
+     background-size: 150px 150px;
+ }
+ /* Centered headings */
+ #title-heading {
+     text-align: center;
+     font-size: 2.5rem;
+     font-weight: 700;
+     margin-bottom: 10px;
+ }
+ #welcome-text {
+     text-align: center;
+     font-size: 1.2rem;
+     color: #444;
+     margin-bottom: 25px;
+     margin-top: 0.5rem;
+ }
+ """
+
+ with gr.Blocks(theme=demo_theme, css=css_code) as demo:
+     gr.Markdown("<div id='title-heading'>AI-Powered Personalized Research Assistant</div>")
+     gr.Markdown("<div id='welcome-text'>Welcome! How may I help you?</div>")
+
+     # State to store (faiss_index, chunks)
+     state = gr.State()
+
+     with gr.Row():
+         pdf_input = gr.File(label="Upload your research paper (PDF)", file_types=[".pdf"])
+         process_button = gr.Button("Process PDF")
+         status_output = gr.Textbox(label="Status", interactive=False)
+
+     # When the user clicks "Process PDF", parse the file and build the index
+     process_button.click(
+         fn=process_pdf,
+         inputs=pdf_input,
+         outputs=[state, status_output]
+     )
+
+     with gr.Row():
+         user_query = gr.Textbox(label="Ask a question about your research paper:")
+         ask_button = gr.Button("Get Answer")
+         answer_output = gr.Textbox(label="Answer")
+
+     # When the user clicks "Get Answer", run the RAG query
+     ask_button.click(
+         fn=chat_with_paper,
+         inputs=[user_query, state],
+         outputs=answer_output
+     )
+
+ demo.launch()