mgbam committed on
Commit
718c260
·
verified ·
1 Parent(s): 8225d31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -52
app.py CHANGED
@@ -9,12 +9,15 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
9
  import streamlit as st
10
  import pandas as pd
11
 
 
 
 
12
  # NLP
13
  import nltk
14
  nltk.download('punkt')
15
  from nltk.tokenize import sent_tokenize
16
 
17
- # Hugging Face Transformers
18
  from transformers import pipeline
19
 
20
  # Optional: OpenAI and Google Generative AI
@@ -24,30 +27,17 @@ import google.generativeai as genai
24
  ###############################################################################
25
  # CONFIG & ENV #
26
  ###############################################################################
27
- """
28
- In your Hugging Face Space:
29
- 1. Add environment secrets:
30
- - OPENAI_API_KEY (if using OpenAI)
31
- - GEMINI_API_KEY (if using Google PaLM/Gemini)
32
- - MY_PUBMED_EMAIL (to identify yourself to NCBI)
33
- 2. In requirements.txt, install:
34
- - streamlit
35
- - requests
36
- - nltk
37
- - transformers
38
- - torch
39
- - openai (if using OpenAI)
40
- - google-generativeai (if using Gemini)
41
- - pandas
42
- """
43
 
44
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
45
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
46
  MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
47
 
 
48
  if OPENAI_API_KEY:
49
  openai.api_key = OPENAI_API_KEY
50
 
 
51
  if GEMINI_API_KEY:
52
  genai.configure(api_key=GEMINI_API_KEY)
53
 
@@ -58,12 +48,12 @@ if GEMINI_API_KEY:
58
  def load_summarizer():
59
  """
60
  Load a summarization model (e.g., BART, PEGASUS, T5).
61
- For a more concise summarization, consider: 'google/pegasus-xsum'
62
  For a balanced approach, 'facebook/bart-large-cnn' is popular.
63
  """
64
  return pipeline(
65
- "summarization",
66
- model="facebook/bart-large-cnn",
67
  tokenizer="facebook/bart-large-cnn"
68
  )
69
 
@@ -109,11 +99,9 @@ def fetch_one_abstract(pmid):
109
  resp = requests.get(base_url, params=params)
110
  resp.raise_for_status()
111
  raw_text = resp.text.strip()
112
-
113
- # If there's no clear text returned, mark as empty
114
  if not raw_text:
115
  return (pmid, "No abstract text found.")
116
-
117
  return (pmid, raw_text)
118
 
119
  def fetch_pubmed_abstracts(pmids):
@@ -122,6 +110,9 @@ def fetch_pubmed_abstracts(pmids):
122
  Returns {pmid: abstract_text}.
123
  """
124
  abstracts_map = {}
 
 
 
125
  with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
126
  future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
127
  for future in as_completed(future_to_pmid):
@@ -142,10 +133,9 @@ def chunk_and_summarize(abstract_text, chunk_size=512):
142
  then summarizes each chunk with the Hugging Face pipeline.
143
  Returns a combined summary for the entire abstract.
144
  """
145
- # We first split by sentences
146
  sentences = sent_tokenize(abstract_text)
147
  chunks = []
148
-
149
  current_chunk = []
150
  current_length = 0
151
  for sent in sentences:
@@ -155,6 +145,7 @@ def chunk_and_summarize(abstract_text, chunk_size=512):
155
  chunks.append(" ".join(current_chunk))
156
  current_chunk = []
157
  current_length = 0
 
158
  current_chunk.append(sent)
159
  current_length += tokens_in_sent
160
 
@@ -162,18 +153,16 @@ def chunk_and_summarize(abstract_text, chunk_size=512):
162
  if current_chunk:
163
  chunks.append(" ".join(current_chunk))
164
 
165
- # Summarize each chunk to avoid hitting token or length constraints
166
  summarized_pieces = []
167
  for c in chunks:
168
  summary_out = summarizer(
169
  c,
170
- max_length=100, # tweak for desired summary length
171
  min_length=30,
172
  do_sample=False
173
  )
174
  summarized_pieces.append(summary_out[0]['summary_text'])
175
-
176
- # Combine partial summaries into one final text
177
  final_summary = " ".join(summarized_pieces)
178
  return final_summary.strip()
179
 
@@ -218,17 +207,17 @@ def gemini_chat(system_prompt, user_message, model_name="models/chat-bison-001",
218
  ###############################################################################
219
  def build_system_prompt_with_refs(pmids, summarized_map):
220
  """
221
- Creates a system prompt that includes the summarized abstracts alongside
222
- labeled references. This allows the LLM to quote or cite specific references.
223
  """
224
- # Example of labeling references: [Ref1], [Ref2], etc.
225
  system_context = (
226
  "You have access to the following summarized PubMed articles. "
227
- "When relevant, cite them in your final answer using their reference label.\n\n"
228
  )
229
  for idx, pmid in enumerate(pmids, start=1):
230
  ref_label = f"[Ref{idx}]"
231
  system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
 
232
  system_context += "Use this contextual info to provide a concise, evidence-based answer."
233
  return system_context
234
 
@@ -236,12 +225,13 @@ def build_system_prompt_with_refs(pmids, summarized_map):
236
  # STREAMLIT APP #
237
  ###############################################################################
238
  def main():
239
- st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
240
  st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")
241
 
242
  st.markdown("""
243
- **Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
244
- using PubMed E-utilities, Hugging Face Summarization, and optional LLM calls (OpenAI or Gemini).
 
245
 
246
  This version includes:
247
  - **Parallel** fetching for multiple PMIDs
@@ -261,7 +251,6 @@ def main():
261
  height=120
262
  )
263
 
264
- # Sidebar or columns for parameters
265
  col1, col2 = st.columns(2)
266
  with col1:
267
  max_papers = st.slider(
@@ -284,7 +273,10 @@ def main():
284
  min_value=256,
285
  max_value=1024,
286
  value=512,
287
- help="Larger chunks might produce fewer summaries, but risk token limits. Smaller chunks produce more robust summaries."
 
 
 
288
  )
289
 
290
  if st.button("Run Enhanced RAG Pipeline"):
@@ -295,12 +287,12 @@ def main():
295
  # 1. PubMed Search
296
  with st.spinner("Searching PubMed..."):
297
  pmids = search_pubmed(query=user_query, max_results=max_papers)
298
-
299
  if not pmids:
300
  st.error("No matching PubMed results. Try a different query.")
301
  return
302
 
303
- # 2. Fetch abstracts in parallel
304
  with st.spinner("Fetching and summarizing abstracts..."):
305
  abstracts_map = fetch_pubmed_abstracts(pmids)
306
  summarized_map = {}
@@ -318,8 +310,8 @@ def main():
318
  st.write(summarized_map[pmid])
319
  st.write("---")
320
 
321
- # 4. Build System Prompt
322
- st.subheader("Final Answer")
323
  system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
324
 
325
  with st.spinner("Generating final answer..."):
@@ -331,23 +323,24 @@ def main():
331
  st.write(answer)
332
  st.success("RAG Pipeline Complete.")
333
 
334
- # Production Considerations & Next Steps
335
  st.markdown("---")
336
  st.markdown("""
337
- ### Production-Ready Enhancements:
338
  1. **Vector Databases & Advanced Retrieval**
339
- - For large-scale usage, index PubMed articles in a vector DB (e.g. Pinecone, Weaviate) to quickly retrieve relevant passages.
340
  2. **Citation Parsing**
341
- - Automatically detect which abstract chunks contributed to each sentence.
342
  3. **Multi-Lingual**
343
- - Integrate translation pipelines for non-English queries or abstracts.
344
  4. **Rate Limiting**
345
- - Respect NCBI's ~3 requests/sec guideline if you're scaling out.
346
- 5. **Robust Logging & Error Handling**
347
- - Build out logs, handle exceptions gracefully, and provide fallback prompts if an LLM fails or an abstract is missing.
348
- 6. **Privacy & Security**
349
- - This demo only fetches public info. For patient data, ensure HIPAA/GDPR compliance and encrypted data pipelines.
350
  """)
351
 
 
352
  if __name__ == "__main__":
353
  main()
 
9
  import streamlit as st
10
  import pandas as pd
11
 
12
+ # Set page config FIRST, before any other Streamlit calls:
13
+ st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
14
+
15
  # NLP
16
  import nltk
17
  nltk.download('punkt')
18
  from nltk.tokenize import sent_tokenize
19
 
20
+ # Transformers for summarization
21
  from transformers import pipeline
22
 
23
  # Optional: OpenAI and Google Generative AI
 
27
  ###############################################################################
28
  # CONFIG & ENV #
29
  ###############################################################################
30
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
33
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
34
  MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
35
 
36
+ # Configure OpenAI if key is provided
37
  if OPENAI_API_KEY:
38
  openai.api_key = OPENAI_API_KEY
39
 
40
+ # Configure Google PaLM / Gemini if key is provided
41
  if GEMINI_API_KEY:
42
  genai.configure(api_key=GEMINI_API_KEY)
43
 
 
48
  def load_summarizer():
49
  """
50
  Load a summarization model (e.g., BART, PEGASUS, T5).
51
+ For a more concise summarization, consider 'google/pegasus-xsum'.
52
  For a balanced approach, 'facebook/bart-large-cnn' is popular.
53
  """
54
  return pipeline(
55
+ "summarization",
56
+ model="facebook/bart-large-cnn",
57
  tokenizer="facebook/bart-large-cnn"
58
  )
59
 
 
99
  resp = requests.get(base_url, params=params)
100
  resp.raise_for_status()
101
  raw_text = resp.text.strip()
102
+
 
103
  if not raw_text:
104
  return (pmid, "No abstract text found.")
 
105
  return (pmid, raw_text)
106
 
107
  def fetch_pubmed_abstracts(pmids):
 
110
  Returns {pmid: abstract_text}.
111
  """
112
  abstracts_map = {}
113
+ if not pmids:
114
+ return abstracts_map
115
+
116
  with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
117
  future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
118
  for future in as_completed(future_to_pmid):
 
133
  then summarizes each chunk with the Hugging Face pipeline.
134
  Returns a combined summary for the entire abstract.
135
  """
 
136
  sentences = sent_tokenize(abstract_text)
137
  chunks = []
138
+
139
  current_chunk = []
140
  current_length = 0
141
  for sent in sentences:
 
145
  chunks.append(" ".join(current_chunk))
146
  current_chunk = []
147
  current_length = 0
148
+
149
  current_chunk.append(sent)
150
  current_length += tokens_in_sent
151
 
 
153
  if current_chunk:
154
  chunks.append(" ".join(current_chunk))
155
 
 
156
  summarized_pieces = []
157
  for c in chunks:
158
  summary_out = summarizer(
159
  c,
160
+ max_length=100, # Tweak for desired summary length
161
  min_length=30,
162
  do_sample=False
163
  )
164
  summarized_pieces.append(summary_out[0]['summary_text'])
165
+
 
166
  final_summary = " ".join(summarized_pieces)
167
  return final_summary.strip()
168
 
 
207
  ###############################################################################
208
  def build_system_prompt_with_refs(pmids, summarized_map):
209
  """
210
+ Creates a system prompt that includes the summarized abstracts alongside
211
+ labeled references (e.g., [Ref1]) so the LLM can cite them in the final answer.
212
  """
 
213
  system_context = (
214
  "You have access to the following summarized PubMed articles. "
215
+ "When relevant, cite them using their reference label.\n\n"
216
  )
217
  for idx, pmid in enumerate(pmids, start=1):
218
  ref_label = f"[Ref{idx}]"
219
  system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
220
+
221
  system_context += "Use this contextual info to provide a concise, evidence-based answer."
222
  return system_context
223
 
 
225
  # STREAMLIT APP #
226
  ###############################################################################
227
  def main():
228
+ # From here on, we do NOT call st.set_page_config() again (to avoid the error).
229
  st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")
230
 
231
  st.markdown("""
232
+ **Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
233
+ using PubMed E-utilities, Hugging Face Summarization, and optional LLM calls
234
+ (OpenAI or Gemini).
235
 
236
  This version includes:
237
  - **Parallel** fetching for multiple PMIDs
 
251
  height=120
252
  )
253
 
 
254
  col1, col2 = st.columns(2)
255
  with col1:
256
  max_papers = st.slider(
 
273
  min_value=256,
274
  max_value=1024,
275
  value=512,
276
+ help=(
277
+ "Larger chunks produce fewer summarization calls, but risk token limits. "
278
+ "Smaller chunks produce more robust summaries."
279
+ )
280
  )
281
 
282
  if st.button("Run Enhanced RAG Pipeline"):
 
287
  # 1. PubMed Search
288
  with st.spinner("Searching PubMed..."):
289
  pmids = search_pubmed(query=user_query, max_results=max_papers)
290
+
291
  if not pmids:
292
  st.error("No matching PubMed results. Try a different query.")
293
  return
294
 
295
+ # 2. Fetch & Summarize
296
  with st.spinner("Fetching and summarizing abstracts..."):
297
  abstracts_map = fetch_pubmed_abstracts(pmids)
298
  summarized_map = {}
 
310
  st.write(summarized_map[pmid])
311
  st.write("---")
312
 
313
+ # 4. Build Prompt & Generate Final Answer
314
+ st.subheader("RAG-Enhanced Final Answer")
315
  system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
316
 
317
  with st.spinner("Generating final answer..."):
 
323
  st.write(answer)
324
  st.success("RAG Pipeline Complete.")
325
 
326
+ # Production notes:
327
  st.markdown("---")
328
  st.markdown("""
329
+ ### Production-Ready Enhancements
330
  1. **Vector Databases & Advanced Retrieval**
331
+ - For large-scale usage, index PubMed articles in a vector DB to quickly retrieve relevant passages.
332
  2. **Citation Parsing**
333
+ - Automatically detect which chunk or article contributed to each sentence for more precise referencing.
334
  3. **Multi-Lingual**
335
+ - Integrate translation pipelines for non-English queries or abstracts to expand global reach.
336
  4. **Rate Limiting**
337
+ - Respect NCBI's ~3 requests/sec guideline if scaling up usage.
338
+ 5. **Logging & Monitoring**
339
+ - In production, set up robust logging/observability for success/failure rates.
340
+ 6. **Security & Privacy**
341
+ - Currently only uses public info. If patient data is included, ensure HIPAA/GDPR compliance.
342
  """)
343
 
344
+
345
  if __name__ == "__main__":
346
  main()