fhmsf committed on
Commit
df2b51a
·
verified ·
1 Parent(s): a2161e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -103
app.py CHANGED
@@ -1,19 +1,16 @@
1
  import os
2
  import faiss
 
3
  import numpy as np
4
  import requests
5
- import streamlit as st
6
 
7
  from pypdf import PdfReader
8
  from sentence_transformers import SentenceTransformer
9
 
10
- ###############################################################################
11
  # 1. PDF Parsing and Chunking
12
- ###############################################################################
13
  def extract_pdf_text(pdf_file) -> str:
14
- """
15
- Read and extract text from each page of an uploaded PDF file.
16
- """
17
  reader = PdfReader(pdf_file)
18
  all_text = []
19
  for page in reader.pages:
@@ -22,10 +19,6 @@ def extract_pdf_text(pdf_file) -> str:
22
  return "\n".join(all_text)
23
 
24
  def chunk_text(text, chunk_size=300, overlap=50):
25
- """
26
- Splits text into overlapping chunks, each approx. 'chunk_size' tokens.
27
- 'overlap' is how many tokens from the previous chunk to include again.
28
- """
29
  words = text.split()
30
  chunks = []
31
  start = 0
@@ -36,48 +29,36 @@ def chunk_text(text, chunk_size=300, overlap=50):
36
  start += (chunk_size - overlap)
37
  return chunks
38
 
39
- ###############################################################################
40
  # 2. Embedding Model
41
- ###############################################################################
42
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
43
 
44
- ###############################################################################
45
  # 3. Build FAISS Index
46
- ###############################################################################
47
  def build_faiss_index(chunks):
48
- """
49
- Creates a FAISS index from embedded chunks.
50
- Returns (index, chunk_embeddings).
51
- """
52
  chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
53
  chunk_embeddings = np.array(chunk_embeddings, dtype='float32')
54
-
55
  dimension = chunk_embeddings.shape[1]
56
  index = faiss.IndexFlatL2(dimension)
57
  index.add(chunk_embeddings)
58
-
59
  return index, chunk_embeddings
60
 
61
- ###############################################################################
62
  # 4. Retrieval Function
63
- ###############################################################################
64
  def retrieve_chunks(query, index, chunks, top_k=3):
65
- """
66
- Embeds 'query' and retrieves the top_k most relevant chunks from 'index'.
67
- """
68
  query_embedding = embedding_model.encode([query], show_progress_bar=False)
69
  query_embedding = np.array(query_embedding, dtype='float32')
70
 
71
  distances, indices = index.search(query_embedding, top_k)
72
  return [chunks[i] for i in indices[0]]
73
 
74
- ###############################################################################
75
  # 5. Gemini LLM Integration
76
- ###############################################################################
77
  def gemini_generate(prompt):
78
- """
79
- Calls Google's Gemini API with the environment variable GEMINI_API_KEY.
80
- """
81
  gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
82
  if not gemini_api_key:
83
  return "Error: No GEMINI_API_KEY found in environment variables."
@@ -87,7 +68,8 @@ def gemini_generate(prompt):
87
  "v1beta/models/gemini-1.5-flash:generateContent"
88
  f"?key={gemini_api_key}"
89
  )
90
- payload = {
 
91
  "contents": [
92
  {
93
  "parts": [
@@ -97,32 +79,27 @@ def gemini_generate(prompt):
97
  ]
98
  }
99
  headers = {"Content-Type": "application/json"}
 
 
 
 
100
 
 
101
  try:
102
- response = requests.post(url, headers=headers, json=payload)
103
- response.raise_for_status()
104
- r_data = response.json()
105
- # Extract the text from the 'candidates' structure:
106
  return r_data["candidates"][0]["content"]["parts"][0]["text"]
107
- except requests.exceptions.RequestException as e:
108
- return f"Error calling Gemini API: {e}"
109
- except KeyError:
110
- return f"Parsing error or unexpected response format: {response.text}"
111
 
112
- ###############################################################################
113
  # 6. RAG QA Function
114
- ###############################################################################
115
  def answer_question_with_RAG(user_question, index, chunks):
116
- """
117
- Retrieves relevant chunks, builds an augmented prompt, and calls gemini_generate().
118
- """
119
  relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=3)
120
  context = "\n\n".join(relevant_chunks)
121
-
122
  prompt = f"""
123
  You are an AI assistant that knows the details from the uploaded research paper.
124
  Answer the user's question accurately using the context below.
125
- If something is not in the context, say 'I don't know'.
126
 
127
  Context:
128
  {context}
@@ -133,62 +110,122 @@ def answer_question_with_RAG(user_question, index, chunks):
133
  """
134
  return gemini_generate(prompt)
135
 
136
- ###############################################################################
137
- # Streamlit Application
138
- ###############################################################################
139
- def main():
140
- # Basic page config (optional):
141
- st.set_page_config(
142
- page_title="AI-Powered Personal Research Assistant",
143
- layout="centered"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
 
146
- # Title and Subheader
147
- st.title("AI-Powered Personal Research Assistant")
148
- st.write("Welcome! How may I help you?")
149
-
150
- # Store the FAISS index + chunks in session_state to persist across reruns
151
- if "faiss_index" not in st.session_state:
152
- st.session_state.faiss_index = None
153
- if "chunks" not in st.session_state:
154
- st.session_state.chunks = None
155
-
156
- # Step 1: Upload and Process PDF
157
- uploaded_pdf = st.file_uploader("Upload your research paper (PDF)", type=["pdf"])
158
- if st.button("Process PDF"):
159
- if uploaded_pdf is None:
160
- st.warning("Please upload a PDF file first.")
161
- else:
162
- # Read and chunk
163
- raw_text = extract_pdf_text(uploaded_pdf)
164
- if not raw_text.strip():
165
- st.error("No text found in PDF.")
166
- return
167
- chunks = chunk_text(raw_text, chunk_size=300, overlap=50)
168
- if not chunks:
169
- st.error("No valid text to chunk.")
170
- return
171
- # Build index
172
- faiss_index, _ = build_faiss_index(chunks)
173
- st.session_state.faiss_index = faiss_index
174
- st.session_state.chunks = chunks
175
- st.success("PDF processed successfully!")
176
-
177
- # Step 2: Ask a Question
178
- user_question = st.text_input("Ask a question about your research paper:")
179
- if st.button("Get Answer"):
180
- if not st.session_state.faiss_index or not st.session_state.chunks:
181
- st.warning("Please upload and process a PDF first.")
182
- elif not user_question.strip():
183
- st.warning("Please enter a valid question.")
184
- else:
185
- answer = answer_question_with_RAG(
186
- user_question,
187
- st.session_state.faiss_index,
188
- st.session_state.chunks
189
- )
190
- st.write("### Answer:")
191
- st.write(answer)
192
-
193
- if __name__ == "__main__":
194
- main()
 
1
  import os
2
  import faiss
3
+ import gradio as gr
4
  import numpy as np
5
  import requests
 
6
 
7
  from pypdf import PdfReader
8
  from sentence_transformers import SentenceTransformer
9
 
10
+ ################################################################################
11
  # 1. PDF Parsing and Chunking
12
+ ################################################################################
13
  def extract_pdf_text(pdf_file) -> str:
 
 
 
14
  reader = PdfReader(pdf_file)
15
  all_text = []
16
  for page in reader.pages:
 
19
  return "\n".join(all_text)
20
 
21
  def chunk_text(text, chunk_size=300, overlap=50):
 
 
 
 
22
  words = text.split()
23
  chunks = []
24
  start = 0
 
29
  start += (chunk_size - overlap)
30
  return chunks
31
 
32
+ ################################################################################
33
  # 2. Embedding Model
34
+ ################################################################################
35
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
36
 
37
+ ################################################################################
38
  # 3. Build FAISS Index
39
+ ################################################################################
40
def build_faiss_index(chunks):
    """Embed text chunks and load them into a flat L2 FAISS index.

    Args:
        chunks: Non-empty sequence of text chunks to embed.

    Returns:
        A ``(index, chunk_embeddings)`` tuple: the populated
        ``faiss.IndexFlatL2`` and the float32 array of embeddings that was
        added to it.

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if len(chunks) == 0:
        # Fail early with a clear message rather than further down, when the
        # embedding dimension is read from an array that has no rows.
        raise ValueError("Cannot build a FAISS index from an empty list of chunks.")

    chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
    # Vectors are cast to float32 before being handed to FAISS.
    chunk_embeddings = np.array(chunk_embeddings, dtype="float32")

    dimension = chunk_embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(chunk_embeddings)
    return index, chunk_embeddings
47
 
48
+ ################################################################################
49
  # 4. Retrieval Function
50
+ ################################################################################
51
def retrieve_chunks(query, index, chunks, top_k=3):
    """Return the chunks whose embeddings are closest to ``query``.

    Args:
        query: Question text to embed and search with.
        index: FAISS index to search; ids returned by the search are used as
            positions into ``chunks``.
        chunks: Sequence of text chunks the index was built from.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of at most ``top_k`` chunks, in the order the index search
        returned them. Empty when ``chunks`` is empty or ``top_k`` is not
        positive.
    """
    # Never ask for more neighbours than there are chunks.
    k = min(top_k, len(chunks))
    if k <= 0:
        return []

    query_embedding = embedding_model.encode([query], show_progress_bar=False)
    query_embedding = np.array(query_embedding, dtype="float32")

    _, indices = index.search(query_embedding, k)
    # FAISS reports missing neighbours as -1; indexing a Python list with -1
    # would silently return the last chunk, so discard out-of-range ids.
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
57
 
58
+ ################################################################################
59
  # 5. Gemini LLM Integration
60
+ ################################################################################
61
  def gemini_generate(prompt):
 
 
 
62
  gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
63
  if not gemini_api_key:
64
  return "Error: No GEMINI_API_KEY found in environment variables."
 
68
  "v1beta/models/gemini-1.5-flash:generateContent"
69
  f"?key={gemini_api_key}"
70
  )
71
+
72
+ data = {
73
  "contents": [
74
  {
75
  "parts": [
 
79
  ]
80
  }
81
  headers = {"Content-Type": "application/json"}
82
+ response = requests.post(url, headers=headers, json=data)
83
+
84
+ if response.status_code != 200:
85
+ return f"Error {response.status_code}: {response.text}"
86
 
87
+ r_data = response.json()
88
  try:
 
 
 
 
89
  return r_data["candidates"][0]["content"]["parts"][0]["text"]
90
+ except Exception:
91
+ return f"Parsing error or unexpected response structure: {r_data}"
 
 
92
 
93
+ ################################################################################
94
  # 6. RAG QA Function
95
+ ################################################################################
96
  def answer_question_with_RAG(user_question, index, chunks):
 
 
 
97
  relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=3)
98
  context = "\n\n".join(relevant_chunks)
 
99
  prompt = f"""
100
  You are an AI assistant that knows the details from the uploaded research paper.
101
  Answer the user's question accurately using the context below.
102
+ If something is not in the context, say you don't know.
103
 
104
  Context:
105
  {context}
 
110
  """
111
  return gemini_generate(prompt)
112
 
113
+ ################################################################################
114
+ # 7. Gradio Interface
115
+ ################################################################################
116
def process_pdf(pdf_file):
    """Extract, chunk and index an uploaded PDF for later questions.

    Args:
        pdf_file: Value delivered by the upload component: either an object
            exposing the file path as ``.name`` or the path itself. ``None``
            when nothing has been uploaded.

    Returns:
        A ``(state, status_message)`` tuple. ``state`` is
        ``(faiss_index, chunks)`` on success and ``None`` on any failure;
        ``status_message`` is the text shown to the user.
    """
    if pdf_file is None:
        return None, "Please upload a PDF file."

    # Accept both a temp-file style object (path in `.name`) and a plain path.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    text = extract_pdf_text(pdf_path)
    # Pages without extractable text still contribute newline separators, so
    # test the stripped text rather than plain truthiness.
    if not text or not text.strip():
        return None, "No text found in PDF."

    chunks = chunk_text(text, chunk_size=300, overlap=50)
    if not chunks:
        return None, "No valid text to chunk."

    faiss_index, _ = build_faiss_index(chunks)
    return (faiss_index, chunks), "PDF processed successfully!"
130
+
131
def chat_with_paper(query, state):
    """Answer ``query`` from the processed paper, or say what is missing.

    Args:
        query: Question text typed by the user (may be empty or ``None``).
        state: ``(faiss_index, chunks)`` pair produced by ``process_pdf``, or
            a falsy value when no PDF has been processed yet.

    Returns:
        The generated answer, or a short instruction when the PDF or the
        question is missing.
    """
    if not state:
        return "Please upload and process a PDF first."

    paper_index, paper_chunks = state

    question_is_blank = (query or "").strip() == ""
    if question_is_blank:
        return "Please enter a valid question."

    return answer_question_with_RAG(query, paper_index, paper_chunks)
138
+
139
# Visual theme for the whole interface.
demo_theme = gr.themes.Soft(primary_hue="slate")

# Custom CSS: light-blue page, a bordered white central card, centred
# title/welcome text, green buttons and centred text inputs.
css_code = """
body {
background-color: #E6F7FF !important; /* Lightest blue */
margin: 0;
padding: 0;
}

.block > .inside {
margin: auto !important;
max-width: 900px !important;
border: 4px solid black !important;
border-radius: 10px !important;
background-color: #FFFFFF !important;
padding: 20px !important;
}

#icon-container {
text-align: center !important;
margin-top: 1rem !important;
margin-bottom: 1rem !important;
}

#app-title {
text-align: center !important;
font-size: 3rem !important;
font-weight: 900 !important;
margin-bottom: 0.5rem !important;
margin-top: 0.5rem !important;
}

#app-welcome {
text-align: center !important;
font-size: 1.5rem !important;
color: #444 !important;
margin-bottom: 25px !important;
font-weight: 700 !important;
}

button {
background-color: #3CB371 !important;
color: #ffffff !important;
border: none !important;
font-weight: 600 !important;
cursor: pointer;
}

button:hover {
background-color: #2E8B57 !important;
}

textarea, input[type="text"] {
text-align: center !important;
}
"""

with gr.Blocks(theme=demo_theme, css=css_code) as demo:
    # Header: icon, title and welcome line.
    gr.Markdown("""
<div id="icon-container">
<img src="https://i.ibb.co/3Wp3yBZ/ai-icon.png" alt="AI icon" style="width:100px;">
</div>
""")

    gr.Markdown("<div id='app-title'>AI-Powered Personal Research Assistant</div>")
    gr.Markdown("<div id='app-welcome'>Welcome! How may I help you?</div>")

    # Holds the (faiss_index, chunks) pair returned by process_pdf so that
    # chat_with_paper can use it on later clicks.
    state = gr.State()

    # Step 1: upload and process the PDF.
    with gr.Row():
        pdf_input = gr.File(label="Upload your research paper (PDF)", file_types=[".pdf"])
        process_button = gr.Button("Process PDF")

    status_output = gr.Textbox(label="Status", interactive=False)

    process_button.click(
        fn=process_pdf,
        inputs=pdf_input,
        outputs=[state, status_output],
    )

    # Step 2: ask a question about the processed paper.
    with gr.Row():
        user_query = gr.Textbox(label="Ask a question about your research paper:")
        ask_button = gr.Button("Get Answer")
    answer_output = gr.Textbox(label="Answer")

    ask_button.click(
        fn=chat_with_paper,
        inputs=[user_query, state],
        outputs=answer_output,
    )

# Start the server only when run as a script, so importing this module
# (e.g. for tests or tooling) does not launch the web app as a side effect.
if __name__ == "__main__":
    demo.launch()