mgbam committed
Commit 113401c · verified · 1 Parent(s): 718c260

Update app.py

Files changed (1)
  1. app.py +87 -286
app.py CHANGED
@@ -1,346 +1,147 @@
- import os
- import re
- import json
- import math
- import requests
- import threading
- from concurrent.futures import ThreadPoolExecutor, as_completed
-
  import streamlit as st
- import pandas as pd
-
- # Set page config FIRST, before any other Streamlit calls:
- st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
-
- # NLP
- import nltk
- nltk.download('punkt')
- from nltk.tokenize import sent_tokenize
-
- # Transformers for summarization
- from transformers import pipeline
-
- # Optional: OpenAI and Google Generative AI
- import openai
- import google.generativeai as genai
-
- ###############################################################################
- #                                CONFIG & ENV                                 #
- ###############################################################################
-
-
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
- GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
- MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
-
- # Configure OpenAI if key is provided
- if OPENAI_API_KEY:
-     openai.api_key = OPENAI_API_KEY

- # Configure Google PaLM / Gemini if key is provided
- if GEMINI_API_KEY:
-     genai.configure(api_key=GEMINI_API_KEY)

  ###############################################################################
- #                           SUMMARIZATION PIPELINE                            #
  ###############################################################################
- @st.cache_resource
- def load_summarizer():
-     """
-     Load a summarization model (e.g., BART, PEGASUS, T5).
-     For a more concise summarization, consider 'google/pegasus-xsum'.
-     For a balanced approach, 'facebook/bart-large-cnn' is popular.
-     """
-     return pipeline(
-         "summarization",
-         model="facebook/bart-large-cnn",
-         tokenizer="facebook/bart-large-cnn"
-     )
-
- summarizer = load_summarizer()

  ###############################################################################
- #                    PUBMED RETRIEVAL (NCBI E-utilities)                      #
  ###############################################################################
- def search_pubmed(query, max_results=3):
-     """
-     Searches PubMed for PMIDs matching the query.
-     Includes recommended 'tool' and 'email' in the request.
-     """
-     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
-     params = {
-         "db": "pubmed",
-         "term": query,
-         "retmax": max_results,
-         "retmode": "json",
-         "tool": "ElysiumRAG",
-         "email": MY_PUBMED_EMAIL
-     }
-     resp = requests.get(base_url, params=params)
-     resp.raise_for_status()
-     data = resp.json()
-     id_list = data.get("esearchresult", {}).get("idlist", [])
-     return id_list

- def fetch_one_abstract(pmid):
      """
-     Fetches a single abstract for a given PMID using EFetch.
-     Returns (pmid, text).
      """
-     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
-     params = {
-         "db": "pubmed",
-         "retmode": "text",
-         "rettype": "abstract",
-         "id": pmid,
-         "tool": "ElysiumRAG",
-         "email": MY_PUBMED_EMAIL
-     }
-     resp = requests.get(base_url, params=params)
-     resp.raise_for_status()
-     raw_text = resp.text.strip()

-     if not raw_text:
-         return (pmid, "No abstract text found.")
-     return (pmid, raw_text)
-
- def fetch_pubmed_abstracts(pmids):
-     """
-     Parallel fetching of multiple PMIDs to reduce overall latency.
-     Returns {pmid: abstract_text}.
-     """
-     abstracts_map = {}
-     if not pmids:
-         return abstracts_map
-
-     with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
-         future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
-         for future in as_completed(future_to_pmid):
-             pmid = future_to_pmid[future]
-             try:
-                 pmid_result, text = future.result()
-                 abstracts_map[pmid_result] = text
-             except Exception as e:
-                 abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
-     return abstracts_map

  ###############################################################################
- #                  ABSTRACT CHUNKING + SUMMARIZATION LOGIC                    #
  ###############################################################################
- def chunk_and_summarize(abstract_text, chunk_size=512):
      """
-     Splits a large abstract into manageable chunks (by sentences),
-     then summarizes each chunk with the Hugging Face pipeline.
-     Returns a combined summary for the entire abstract.
      """
-     sentences = sent_tokenize(abstract_text)
-     chunks = []
-
-     current_chunk = []
-     current_length = 0
-     for sent in sentences:
-         tokens_in_sent = len(sent.split())
-         # If adding this sentence exceeds the chunk_size limit, finalize the chunk
-         if current_length + tokens_in_sent > chunk_size:
-             chunks.append(" ".join(current_chunk))
-             current_chunk = []
-             current_length = 0
-
-         current_chunk.append(sent)
-         current_length += tokens_in_sent
-
-     # Final chunk if it exists
-     if current_chunk:
-         chunks.append(" ".join(current_chunk))
-
-     summarized_pieces = []
-     for c in chunks:
-         summary_out = summarizer(
-             c,
-             max_length=100,  # Tweak for desired summary length
-             min_length=30,
-             do_sample=False
-         )
-         summarized_pieces.append(summary_out[0]['summary_text'])
-
-     final_summary = " ".join(summarized_pieces)
-     return final_summary.strip()
-
- ###############################################################################
- #                         LLM CALLS (OpenAI / Gemini)                         #
- ###############################################################################
- def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
-     """
-     Basic ChatCompletion with a system + user role for OpenAI.
-     """
-     if not OPENAI_API_KEY:
-         return "Error: OpenAI API key not provided."
-     try:
-         response = openai.ChatCompletion.create(
-             model=model,
-             messages=[
-                 {"role": "system", "content": system_prompt},
-                 {"role": "user", "content": user_message}
-             ],
-             temperature=temperature
-         )
-         return response.choices[0].message["content"].strip()
-     except Exception as e:
-         return f"Error calling OpenAI: {str(e)}"
-
- def gemini_chat(system_prompt, user_message, model_name="models/chat-bison-001", temperature=0.3):
-     """
-     Basic PaLM2/Gemini chat call using google.generativeai.
-     """
-     if not GEMINI_API_KEY:
-         return "Error: Gemini API key not provided."
-     try:
-         model = genai.GenerativeModel(model_name=model_name)
-         chat_session = model.start_chat(history=[("system", system_prompt)])
-         reply = chat_session.send_message(user_message, temperature=temperature)
-         return reply.text
-     except Exception as e:
-         return f"Error calling Gemini: {str(e)}"
-
- ###############################################################################
- #                        BUILD REFERENCES FOR ANSWER                          #
- ###############################################################################
- def build_system_prompt_with_refs(pmids, summarized_map):
-     """
-     Creates a system prompt that includes the summarized abstracts alongside
-     labeled references (e.g., [Ref1]) so the LLM can cite them in the final answer.
-     """
-     system_context = (
-         "You have access to the following summarized PubMed articles. "
-         "When relevant, cite them using their reference label.\n\n"
-     )
      for idx, pmid in enumerate(pmids, start=1):
          ref_label = f"[Ref{idx}]"
-         system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
-
-     system_context += "Use this contextual info to provide a concise, evidence-based answer."
      return system_context

  ###############################################################################
- #                                STREAMLIT APP                                #
  ###############################################################################
  def main():
-     # From here on, we do NOT call st.set_page_config() again (to avoid the error).
-     st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")

      st.markdown("""
-     **Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
-     using PubMed E-utilities, Hugging Face Summarization, and optional LLM calls
-     (OpenAI or Gemini).
-
-     This version includes:
-     - **Parallel** fetching for multiple PMIDs
-     - Advanced **chunking & summarization** of large abstracts
-     - **Reference labeling** in the final answer
-     - Clear disclaimers & best-practice structures
-
-     ---
-     **Disclaimer**: This is a demonstration prototype for educational or research purposes.
-     It is *not* a substitute for professional medical advice. Always consult a qualified
-     healthcare provider for personal health decisions.
-     """)

-     user_query = st.text_area(
-         "Enter your medical question or topic:",
-         placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
-         height=120
-     )

-     col1, col2 = st.columns(2)
      with col1:
-         max_papers = st.slider(
-             "Number of PubMed Articles to Retrieve",
-             min_value=1,
-             max_value=10,
-             value=3,
-             help="Number of articles to fetch & summarize."
-         )
      with col2:
-         selected_llm = st.selectbox(
-             "Select LLM for Final Generation",
-             ["OpenAI: GPT-3.5", "Gemini: PaLM2"],
-             help="Choose which large language model to finalize the answer."
-         )
-
-     # Additional advanced parameter: chunk size
-     chunk_size = st.slider(
-         "Summarization Chunk Size (words)",
-         min_value=256,
-         max_value=1024,
-         value=512,
-         help=(
-             "Larger chunks produce fewer summarization calls, but risk token limits. "
-             "Smaller chunks produce more robust summaries."
-         )
-     )

-     if st.button("Run Enhanced RAG Pipeline"):
          if not user_query.strip():
-             st.warning("Please enter a query before running RAG.")
              return

-         # 1. PubMed Search
          with st.spinner("Searching PubMed..."):
-             pmids = search_pubmed(query=user_query, max_results=max_papers)

          if not pmids:
-             st.error("No matching PubMed results. Try a different query.")
              return

-         # 2. Fetch & Summarize
-         with st.spinner("Fetching and summarizing abstracts..."):
-             abstracts_map = fetch_pubmed_abstracts(pmids)
              summarized_map = {}
-             for pmid, abstract_text in abstracts_map.items():
-                 if "Error fetching" in abstract_text:
-                     summarized_map[pmid] = abstract_text
                  else:
-                     summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)

-         # 3. Display Summaries
          st.subheader("Retrieved & Summarized PubMed Articles")
          for idx, pmid in enumerate(pmids, start=1):
-             ref_label = f"[Ref{idx}]"
-             st.markdown(f"**{ref_label} PMID {pmid}**")
              st.write(summarized_map[pmid])
              st.write("---")

-         # 4. Build Prompt & Generate Final Answer
          st.subheader("RAG-Enhanced Final Answer")
          system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
-
-         with st.spinner("Generating final answer..."):
-             if selected_llm == "OpenAI: GPT-3.5":
-                 answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
              else:
-                 answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)

          st.write(answer)
          st.success("RAG Pipeline Complete.")

-     # Production notes:
      st.markdown("---")
      st.markdown("""
-     ### Production-Ready Enhancements
-     1. **Vector Databases & Advanced Retrieval**
-        - For large-scale usage, index PubMed articles in a vector DB to quickly retrieve relevant passages.
-     2. **Citation Parsing**
-        - Automatically detect which chunk or article contributed to each sentence for more precise referencing.
-     3. **Multi-Lingual**
-        - Integrate translation pipelines for non-English queries or abstracts to expand global reach.
-     4. **Rate Limiting**
-        - Respect NCBI's ~3 requests/sec guideline if scaling up usage.
-     5. **Logging & Monitoring**
-        - In production, set up robust logging/observability for success/failure rates.
-     6. **Security & Privacy**
-        - Currently only uses public info. If patient data is included, ensure HIPAA/GDPR compliance.
      """)

-
  if __name__ == "__main__":
      main()
  import streamlit as st
+ import os

+ from config import (
+     OPENAI_API_KEY,
+     GEMINI_API_KEY,
+     DEFAULT_CHUNK_SIZE
+ )
+ from models import configure_llms, openai_chat, gemini_chat
+ from pubmed_utils import (
+     search_pubmed,
+     fetch_pubmed_abstracts,
+     chunk_and_summarize
+ )
+ from image_pipeline import load_image_model, analyze_image

  ###############################################################################
+ #                              PAGE CONFIG FIRST                              #
  ###############################################################################
+ st.set_page_config(page_title="RAG + Image: Production Scenario", layout="wide")

  ###############################################################################
+ #                          INITIALIZE & LOAD MODELS                           #
  ###############################################################################
+ def initialize_app():
      """
+     Configures LLMs, loads image model, etc.
+     Cache these calls for performance in HF Spaces.
      """
+     configure_llms()  # sets openai.api_key and genai.configure if keys are present
+     image_model = load_image_model()
+     return image_model

+ image_model = initialize_app()

  ###############################################################################
+ #                  HELPER: BUILD SYSTEM PROMPT WITH REFERENCES                #
  ###############################################################################
+ def build_system_prompt_with_refs(pmids, summaries):
      """
+     Creates a system prompt that includes references [Ref1], [Ref2], etc.
      """
+     system_context = "You have access to the following summarized PubMed articles:\n\n"
      for idx, pmid in enumerate(pmids, start=1):
          ref_label = f"[Ref{idx}]"
+         system_context += f"{ref_label} (PMID {pmid}): {summaries[pmid]}\n\n"
+     system_context += (
+         "Use this info to answer the user's question. Cite references as needed."
+     )
      return system_context

  ###############################################################################
+ #                                  MAIN APP                                   #
  ###############################################################################
  def main():
+     st.title("RAG + Image: Production-Ready Medical AI")

      st.markdown("""
+     **Features**:
+     1. *PubMed RAG Pipeline*: Search, fetch, summarize, then generate a final answer with LLM.
+     2. *Optional Image Analysis*: Upload an image for a simple caption or interpretive text.
+     3. *Separation of Concerns*: Each major function is in its own module for maintainability.

+     **Disclaimer**: Not a substitute for professional medical advice.
+     """)

+     # Section A: Image pipeline
+     st.subheader("Image Analysis")
+     uploaded_image = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
+     if uploaded_image:
+         with st.spinner("Analyzing image..."):
+             caption = analyze_image(uploaded_image, image_model)
+         st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
+         st.write("**Model Output:**", caption)
+         st.write("---")
+
+     # Section B: PubMed-based RAG
+     st.subheader("PubMed Retrieval & Summarization")
+     user_query = st.text_input("Enter your medical question:", "What are the latest treatments for type 2 diabetes complications?")
+
+     col1, col2, col3 = st.columns([2, 1, 1])
      with col1:
+         st.markdown("**Set Pipeline Params**")
+         max_papers = st.slider("PubMed Articles to Retrieve", 1, 10, 3)
+         chunk_size = st.slider("Summarization Chunk Size", 256, 1024, DEFAULT_CHUNK_SIZE)
      with col2:
+         selected_llm = st.selectbox("Select LLM", ["OpenAI GPT-3.5", "Gemini PaLM2"])
+     with col3:
+         temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.1)

+     if st.button("Run RAG Pipeline"):
          if not user_query.strip():
+             st.warning("Please enter a question.")
              return

+         # 1) PubMed retrieval
          with st.spinner("Searching PubMed..."):
+             pmids = search_pubmed(user_query, max_results=max_papers)

          if not pmids:
+             st.error("No relevant results found. Try a different query.")
              return

+         # 2) Fetch & Summarize
+         with st.spinner("Fetching & Summarizing abstracts..."):
+             abs_map = fetch_pubmed_abstracts(pmids)
              summarized_map = {}
+             for pmid, text in abs_map.items():
+                 if text.startswith("Error:"):
+                     summarized_map[pmid] = text
                  else:
+                     summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)

+         # 3) Display Summaries
          st.subheader("Retrieved & Summarized PubMed Articles")
          for idx, pmid in enumerate(pmids, start=1):
+             st.markdown(f"**[Ref{idx}] PMID {pmid}**")
              st.write(summarized_map[pmid])
              st.write("---")

+         # 4) Final LLM Answer
          st.subheader("RAG-Enhanced Final Answer")
          system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
+         with st.spinner("Generating answer..."):
+             if selected_llm == "OpenAI GPT-3.5":
+                 answer = openai_chat(system_prompt, user_query, temperature=temperature)
              else:
+                 answer = gemini_chat(system_prompt, user_query, temperature=temperature)

          st.write(answer)
          st.success("RAG Pipeline Complete.")

+     # Production tips
      st.markdown("---")
      st.markdown("""
+     ### Production Enhancements
+     - **Vector Database** for advanced retrieval
+     - **Citation Parsing** for accurate referencing
+     - **Multi-Lingual** expansions
+     - **Rate Limiting** for PubMed (max ~3 requests/sec)
+     - **Robust Logging / Monitoring**
+     - **Security & Privacy** if patient data is integrated
      """)

  if __name__ == "__main__":
      main()
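
The refactored app.py imports from four sibling modules (config.py, models.py, pubmed_utils.py, image_pipeline.py), none of which are included in this commit. For config and models, a minimal sketch consistent with the imported names and with the code removed from app.py could look as follows; the repository's actual modules may differ, and everything here beyond the imported names is an assumption.

# config.py -- minimal sketch; the real module is not shown in this commit.
import os

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
DEFAULT_CHUNK_SIZE = 512  # matches the old in-app slider default


# models.py -- sketch of the one genuinely new entry point; openai_chat and
# gemini_chat are presumably the helpers removed from app.py, moved here intact.
import openai
import google.generativeai as genai

from config import OPENAI_API_KEY, GEMINI_API_KEY


def configure_llms():
    """Apply the API keys once at startup, if present (was inline in app.py)."""
    if OPENAI_API_KEY:
        openai.api_key = OPENAI_API_KEY
    if GEMINI_API_KEY:
        genai.configure(api_key=GEMINI_API_KEY)

pubmed_utils.py presumably carries over search_pubmed, fetch_pubmed_abstracts, and chunk_and_summarize exactly as removed above, along with the summarizer pipeline they depend on.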
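
image_pipeline.py has no counterpart in the removed code, so its contents are a guess. app.py only relies on load_image_model() returning something that analyze_image(uploaded_file, model) can use; a small image-captioning pipeline (BLIP is a common choice, assumed here) satisfies that contract:

# image_pipeline.py -- hypothetical sketch; the real module may use a
# different model or return richer interpretive text.
import streamlit as st
from PIL import Image
from transformers import pipeline


@st.cache_resource
def load_image_model():
    # Any image-to-text checkpoint works; BLIP base is small enough for Spaces.
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def analyze_image(uploaded_file, image_model):
    """Return a one-line caption for an uploaded image file."""
    image = Image.open(uploaded_file).convert("RGB")
    outputs = image_model(image)
    return outputs[0]["generated_text"] if outputs else "No caption generated."

Caching load_image_model with st.cache_resource is what would make the "Cache these calls for performance in HF Spaces" note in initialize_app's docstring actually hold across Streamlit reruns.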
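
Both versions of the production notes mention NCBI's ~3 requests/sec E-utilities guideline, but neither version enforces it, and the ThreadPoolExecutor fetch can exceed it once several PMIDs are requested. A thread-safe throttle along these lines (illustrative only, not part of the commit) could be shared by the request helpers in pubmed_utils:

# rate_limit.py -- hypothetical helper; NCBI allows ~3 requests/sec without an
# API key (about 10/sec with one), so space out calls across worker threads.
import threading
import time


class RateLimiter:
    def __init__(self, max_per_sec=3):
        self.min_interval = 1.0 / max_per_sec
        self.lock = threading.Lock()
        self.last_call = 0.0

    def wait(self):
        # Serialize callers and sleep just long enough to honor the rate cap.
        with self.lock:
            delay = self.min_interval - (time.monotonic() - self.last_call)
            if delay > 0:
                time.sleep(delay)
            self.last_call = time.monotonic()


pubmed_limiter = RateLimiter(max_per_sec=3)
# Usage: call pubmed_limiter.wait() immediately before each requests.get(...)
# in search_pubmed and fetch_one_abstract.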