Update app.py
app.py
CHANGED
@@ -1,346 +1,147 @@
- import os
- import re
- import json
- import math
- import requests
- import threading
- from concurrent.futures import ThreadPoolExecutor, as_completed
-
- import streamlit as st
- import
-
- # Set page config FIRST, before any other Streamlit calls:
- st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
-
- # NLP
- import nltk
- nltk.download('punkt')
- from nltk.tokenize import sent_tokenize
-
- # Transformers for summarization
- from transformers import pipeline
-
- # Optional: OpenAI and Google Generative AI
- import openai
- import google.generativeai as genai
-
- ###############################################################################
- #                               CONFIG & ENV                                  #
- ###############################################################################
-
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
- GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
- MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
-
- # Configure OpenAI if key is provided
- if OPENAI_API_KEY:
-     openai.api_key = OPENAI_API_KEY
-
-
- ###############################################################################
- #                          SUMMARIZATION PIPELINE                             #
- ###############################################################################
-
- def load_summarizer():
-     """
-     Load a summarization model (e.g., BART, PEGASUS, T5).
-     For a more concise summarization, consider 'google/pegasus-xsum'.
-     For a balanced approach, 'facebook/bart-large-cnn' is popular.
-     """
-     return pipeline(
-         "summarization",
-         model="facebook/bart-large-cnn",
-         tokenizer="facebook/bart-large-cnn"
-     )
-
- summarizer = load_summarizer()
-
- ###############################################################################
- #                            PUBMED RETRIEVAL                                 #
- ###############################################################################
- def search_pubmed(query, max_results=3):
-     """
-     Searches PubMed for PMIDs matching the query.
-     Includes recommended 'tool' and 'email' in the request.
-     """
-     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
-     params = {
-         "db": "pubmed",
-         "term": query,
-         "retmax": max_results,
-         "retmode": "json",
-         "tool": "ElysiumRAG",
-         "email": MY_PUBMED_EMAIL
-     }
-     resp = requests.get(base_url, params=params)
-     resp.raise_for_status()
-     data = resp.json()
-     id_list = data.get("esearchresult", {}).get("idlist", [])
-     return id_list
-
- def fetch_one_abstract(pmid):
-     """
-     Fetches the abstract text for a single PMID.
-     """
-     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
-     params = {
-         "db": "pubmed",
-         "retmode": "text",
-         "rettype": "abstract",
-         "id": pmid,
-         "tool": "ElysiumRAG",
-         "email": MY_PUBMED_EMAIL
-     }
-     resp = requests.get(base_url, params=params)
-     resp.raise_for_status()
-     raw_text = resp.text.strip()
-
-     if not raw_text:
-         return (pmid, "No abstract text found.")
-     return (pmid, raw_text)
-
- def fetch_pubmed_abstracts(pmids):
-     """
-     Parallel fetching of multiple PMIDs to reduce overall latency.
-     Returns {pmid: abstract_text}.
-     """
-     abstracts_map = {}
-     if not pmids:
-         return abstracts_map
-
-     with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
-         future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
-         for future in as_completed(future_to_pmid):
-             pmid = future_to_pmid[future]
-             try:
-                 pmid_result, text = future.result()
-                 abstracts_map[pmid_result] = text
-             except Exception as e:
-                 abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
-     return abstracts_map
-
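For reference, with retmode=json ESearch wraps its results in an envelope that the esearchresult/idlist lookup above unpacks; a representative response, with hypothetical PMIDs, parsed the same way search_pubmed does:

example = {
    "esearchresult": {
        "count": "128",          # NCBI returns these totals as strings
        "retmax": "3",
        "idlist": ["38765432", "38712345", "38699876"]
    }
}
pmids = example.get("esearchresult", {}).get("idlist", [])  # -> ["38765432", ...]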
- ###############################################################################
- #                            CHUNK & SUMMARIZE                                #
- ###############################################################################
- def chunk_and_summarize(abstract_text, chunk_size=512):
-     """
-     Splits a large abstract into sentence-based chunks,
-     then summarizes each chunk with the Hugging Face pipeline.
-     Returns a combined summary for the entire abstract.
-     """
-     sentences = sent_tokenize(abstract_text)
-     chunks = []
-
-     current_chunk = []
-     current_length = 0
-     for sent in sentences:
-         tokens_in_sent = len(sent.split())
-         # If adding this sentence exceeds the chunk_size limit, finalize the chunk
-         if current_length + tokens_in_sent > chunk_size:
-             chunks.append(" ".join(current_chunk))
-             current_chunk = []
-             current_length = 0
-
-         current_chunk.append(sent)
-         current_length += tokens_in_sent
-
-     # Final chunk if it exists
-     if current_chunk:
-         chunks.append(" ".join(current_chunk))
-
-     summarized_pieces = []
-     for c in chunks:
-         summary_out = summarizer(
-             c,
-             max_length=100,  # Tweak for desired summary length
-             min_length=30,
-             do_sample=False
-         )
-         summarized_pieces.append(summary_out[0]['summary_text'])
-
-     final_summary = " ".join(summarized_pieces)
-     return final_summary.strip()
-
- ###############################################################################
- #                        LLM CALLS (OpenAI / Gemini)                          #
- ###############################################################################
- def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
-     """
-     Basic ChatCompletion with a system + user role for OpenAI.
-     """
-     if not OPENAI_API_KEY:
-         return "Error: OpenAI API key not provided."
-     try:
-         response = openai.ChatCompletion.create(
-             model=model,
-             messages=[
-                 {"role": "system", "content": system_prompt},
-                 {"role": "user", "content": user_message}
-             ],
-             temperature=temperature
-         )
-         return response.choices[0].message["content"].strip()
-     except Exception as e:
-         return f"Error calling OpenAI: {str(e)}"
-
- def gemini_chat(system_prompt, user_message, model_name="models/chat-bison-001", temperature=0.3):
-     """
-     Basic PaLM2/Gemini chat call using google.generativeai.
-     """
-     if not GEMINI_API_KEY:
-         return "Error: Gemini API key not provided."
-     try:
-         model = genai.GenerativeModel(model_name=model_name)
-         chat_session = model.start_chat(history=[("system", system_prompt)])
-         reply = chat_session.send_message(user_message, temperature=temperature)
-         return reply.text
-     except Exception as e:
-         return f"Error calling Gemini: {str(e)}"
-
- ###############################################################################
- #                       BUILD REFERENCES FOR ANSWER                           #
- ###############################################################################
- def build_system_prompt_with_refs(pmids, summarized_map):
-     """
-     Creates a system prompt that includes the summarized abstracts alongside
-     labeled references (e.g., [Ref1]) so the LLM can cite them in the final answer.
-     """
-     system_context = (
-         "You have access to the following summarized PubMed articles. "
-         "When relevant, cite them using their reference label.\n\n"
-     )
-     for idx, pmid in enumerate(pmids, start=1):
-         ref_label = f"[Ref{idx}]"
-         system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
-     return system_context
-
- ###############################################################################
- #                                 MAIN APP                                    #
- ###############################################################################
- def main():
-     st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")
-
-     st.markdown("""
-     **
-
-     This version includes:
-     - **Parallel** fetching for multiple PMIDs
-     - Advanced **chunking & summarization** of large abstracts
-     - **Reference labeling** in the final answer
-     - Clear disclaimers & best-practice structures
-
-     ---
-     **Disclaimer**: This is a demonstration prototype for educational or research purposes.
-     It is *not* a substitute for professional medical advice. Always consult a qualified
-     healthcare provider for personal health decisions.
-     """)
-
-     user_query = st.text_area(
-         "Enter your medical question:",
-         placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
-         height=120
-     )
-
-     col1, col2 = st.columns(2)
-     with col1:
-         max_papers = st.number_input(
-             "Number of PubMed Articles to Retrieve",
-             min_value=1,
-             max_value=10,
-             value=3,
-             help="Number of articles to fetch & summarize."
-         )
-     with col2:
-         selected_llm = st.selectbox(
-             "Select LLM",
-             ["OpenAI GPT-3.5", "Gemini PaLM2"],
-             help="Choose which large language model to finalize the answer."
-         )
-
-     # Additional advanced parameter: chunk size
-     chunk_size = st.slider(
-         "Summarization Chunk Size (words)",
-         min_value=256,
-         max_value=1024,
-         value=512,
-         help=(
-             "Larger chunks produce fewer summarization calls, but risk token limits. "
-             "Smaller chunks produce more robust summaries."
-         )
-     )
-
-     if st.button("Run RAG Pipeline"):
-         if not user_query.strip():
-             st.warning("Please enter a question.")
-             return
-
-         # 1) Search PubMed
-         with st.spinner("Searching PubMed..."):
-             pmids = search_pubmed(query=user_query, max_results=max_papers)
-
-         if not pmids:
-             st.error("No results found. Try a different query.")
-             return
-
-         # 2) Fetch & summarize abstracts
-         with st.spinner("Fetching & summarizing abstracts..."):
-             abstracts_map = fetch_pubmed_abstracts(pmids)
-             summarized_map = {}
-             for pmid, abstract_text in abstracts_map.items():
-                 if "Error" in abstract_text:
-                     summarized_map[pmid] = abstract_text
-                 else:
-                     summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)
-
-         # 3) Display summaries
-         st.subheader("Retrieved & Summarized PubMed Articles")
-         for idx, pmid in enumerate(pmids, start=1):
-             ref_label = f"[Ref{idx}]"
-             st.markdown(f"**{ref_label} PMID {pmid}**")
-             st.write(summarized_map[pmid])
-             st.write("---")
-
-         # 4) Final LLM answer
-         st.subheader("RAG-Enhanced Final Answer")
-         system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
-         with st.spinner("Generating final answer..."):
-             if selected_llm == "OpenAI GPT-3.5":
-                 answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
-             else:
-                 answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)
-
-         st.write(answer)
-         st.success("RAG Pipeline Complete.")
-
-         # Production considerations
-         st.markdown("---")
-         st.markdown("""
-         ### Production Considerations
-         1. **Vector Database**
-         2. **Citation Parsing**
-         3. **Multi-Lingual**
-         4. **Rate Limiting**
-            - Respect NCBI's ~3 requests/sec guideline if scaling up usage.
-         5. **Logging & Monitoring**
-            - In production, set up robust logging/observability for success/failure rates.
-         6. **Security & Privacy**
-            - Currently only uses public info. If patient data is included, ensure HIPAA/GDPR compliance.
-         """)
-
- if __name__ == "__main__":
-     main()
+ import streamlit as st
+ import os
+
+ from config import (
+     OPENAI_API_KEY,
+     GEMINI_API_KEY,
+     DEFAULT_CHUNK_SIZE
+ )
+ from models import configure_llms, openai_chat, gemini_chat
+ from pubmed_utils import (
+     search_pubmed,
+     fetch_pubmed_abstracts,
+     chunk_and_summarize
+ )
+ from image_pipeline import load_image_model, analyze_image
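The rewritten app.py now leans on four modules that are not part of this commit. config presumably just centralizes the os.getenv lookups from the old file plus a DEFAULT_CHUNK_SIZE (512 was the old slider default; the exact value is an assumption). pubmed_utils most likely carries the three retrieval helpers over from the code removed above; a condensed sketch under that assumption, with two small hardening tweaks (an empty-pmids guard, and an "Error:" prefix to match the check app.py performs below):

# pubmed_utils.py -- hypothetical reconstruction from the removed helpers
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import nltk
import requests
from nltk.tokenize import sent_tokenize
from transformers import pipeline

nltk.download("punkt", quiet=True)

MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

def search_pubmed(query, max_results=3):
    """ESearch: return a list of PMIDs matching the query."""
    resp = requests.get(f"{EUTILS}/esearch.fcgi", params={
        "db": "pubmed", "term": query, "retmax": max_results, "retmode": "json",
        "tool": "ElysiumRAG", "email": MY_PUBMED_EMAIL})
    resp.raise_for_status()
    return resp.json().get("esearchresult", {}).get("idlist", [])

def _fetch_one(pmid):
    """EFetch: return (pmid, abstract text) for one article."""
    resp = requests.get(f"{EUTILS}/efetch.fcgi", params={
        "db": "pubmed", "id": pmid, "retmode": "text", "rettype": "abstract",
        "tool": "ElysiumRAG", "email": MY_PUBMED_EMAIL})
    resp.raise_for_status()
    return pmid, (resp.text.strip() or "No abstract text found.")

def fetch_pubmed_abstracts(pmids):
    """Fetch abstracts in parallel; returns {pmid: text}."""
    results = {}
    if not pmids:
        return results
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as pool:
        futures = {pool.submit(_fetch_one, p): p for p in pmids}
        for fut in as_completed(futures):
            pmid = futures[fut]
            try:
                pmid, text = fut.result()
                results[pmid] = text
            except Exception as exc:
                # app.py tests text.startswith("Error:"), so keep that prefix
                results[pmid] = f"Error: {exc}"
    return results

def chunk_and_summarize(abstract_text, chunk_size=512):
    """Split into ~chunk_size-word sentence chunks, summarize each, re-join."""
    chunks, current, length = [], [], 0
    for sent in sent_tokenize(abstract_text):
        n = len(sent.split())
        if current and length + n > chunk_size:  # never emit an empty chunk
            chunks.append(" ".join(current))
            current, length = [], 0
        current.append(sent)
        length += n
    if current:
        chunks.append(" ".join(current))
    pieces = [summarizer(c, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
              for c in chunks]
    return " ".join(pieces).strip()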
|
+ ###############################################################################
+ #                            PAGE CONFIG FIRST                                #
+ ###############################################################################
+ st.set_page_config(page_title="RAG + Image: Production Scenario", layout="wide")
+
+ ###############################################################################
+ #                         INITIALIZE & LOAD MODELS                            #
+ ###############################################################################
+
+ def initialize_app():
+     """
+     Configures LLMs, loads image model, etc.
+     Cache these calls for performance in HF Spaces.
+     """
+     configure_llms()  # sets openai.api_key and genai.configure if keys are present
+     image_model = load_image_model()
+     return image_model
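The docstring above asks for these calls to be cached, but the commit shows no cache decorator. In current Streamlit that would presumably be st.cache_resource, e.g.:

@st.cache_resource  # reruns of the script then reuse the already-loaded models
def initialize_app():
    configure_llms()
    return load_image_model()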
+
+ image_model = initialize_app()
+
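models is likewise absent from the commit; the names app.py imports suggest roughly the shape below. Note the removed gemini_chat seeded start_chat with a ("system", ...) tuple and passed temperature straight to send_message, which does not match google.generativeai's API, so this sketch swaps in generate_content with a generation_config; the gemini-pro model name is also an assumption.

# models.py -- hypothetical sketch consistent with the imports in app.py
import os

import openai
import google.generativeai as genai

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

def configure_llms():
    """Wire up credentials once at startup, if the env vars are set."""
    if OPENAI_API_KEY:
        openai.api_key = OPENAI_API_KEY
    if GEMINI_API_KEY:
        genai.configure(api_key=GEMINI_API_KEY)

def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
    if not OPENAI_API_KEY:
        return "Error: OpenAI API key not provided."
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_message}],
            temperature=temperature,
        )
        return response.choices[0].message["content"].strip()
    except Exception as exc:
        return f"Error calling OpenAI: {exc}"

def gemini_chat(system_prompt, user_message, model_name="gemini-pro", temperature=0.3):
    if not GEMINI_API_KEY:
        return "Error: Gemini API key not provided."
    try:
        model = genai.GenerativeModel(model_name)
        resp = model.generate_content(
            f"{system_prompt}\n\n{user_message}",
            generation_config={"temperature": temperature},
        )
        return resp.text
    except Exception as exc:
        return f"Error calling Gemini: {exc}"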
+ ###############################################################################
+ #                HELPER: BUILD SYSTEM PROMPT WITH REFERENCES                  #
+ ###############################################################################
+ def build_system_prompt_with_refs(pmids, summaries):
+     """
+     Creates a system prompt that includes references [Ref1], [Ref2], etc.
+     """
+     system_context = "You have access to the following summarized PubMed articles:\n\n"
+     for idx, pmid in enumerate(pmids, start=1):
+         ref_label = f"[Ref{idx}]"
+         system_context += f"{ref_label} (PMID {pmid}): {summaries[pmid]}\n\n"
+     system_context += (
+         "Use this info to answer the user's question. Cite references as needed."
+     )
+     return system_context
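Assembled for two hits, the prompt handed to the LLM therefore reads (PMIDs and summaries are placeholders):

You have access to the following summarized PubMed articles:

[Ref1] (PMID 38765432): <summary of the first abstract>

[Ref2] (PMID 38712345): <summary of the second abstract>

Use this info to answer the user's question. Cite references as needed.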
+
+ ###############################################################################
+ #                                 MAIN APP                                    #
+ ###############################################################################
+ def main():
+     st.title("RAG + Image: Production-Ready Medical AI")
+
+     st.markdown("""
+     **Features**:
+     1. *PubMed RAG Pipeline*: Search, fetch, summarize, then generate a final answer with LLM.
+     2. *Optional Image Analysis*: Upload an image for a simple caption or interpretive text.
+     3. *Separation of Concerns*: Each major function is in its own module for maintainability.
+
+     **Disclaimer**: Not a substitute for professional medical advice.
+     """)
+
+     # Section A: Image pipeline
+     st.subheader("Image Analysis")
+     uploaded_image = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
+     if uploaded_image:
+         with st.spinner("Analyzing image..."):
+             caption = analyze_image(uploaded_image, image_model)
+         st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
+         st.write("**Model Output:**", caption)
+         st.write("---")
+
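image_pipeline is also not shown in this commit. A captioning model behind the transformers image-to-text pipeline would fit the "simple caption" behavior described above; both the model choice (BLIP base) and the function bodies here are assumptions:

# image_pipeline.py -- hypothetical sketch; the real module is not in this commit
from PIL import Image
from transformers import pipeline

def load_image_model():
    """Load a lightweight captioning model (BLIP base is an assumed choice)."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

def analyze_image(uploaded_file, image_model):
    """Caption a Streamlit UploadedFile and return the generated text."""
    image = Image.open(uploaded_file).convert("RGB")
    outputs = image_model(image)  # -> [{"generated_text": "..."}]
    return outputs[0]["generated_text"]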
+     # Section B: PubMed-based RAG
+     st.subheader("PubMed Retrieval & Summarization")
+     user_query = st.text_input("Enter your medical question:", "What are the latest treatments for type 2 diabetes complications?")
+
+     col1, col2, col3 = st.columns([2, 1, 1])
+     with col1:
+         st.markdown("**Set Pipeline Params**")
+         max_papers = st.slider("PubMed Articles to Retrieve", 1, 10, 3)
+         chunk_size = st.slider("Summarization Chunk Size", 256, 1024, DEFAULT_CHUNK_SIZE)
+     with col2:
+         selected_llm = st.selectbox("Select LLM", ["OpenAI GPT-3.5", "Gemini PaLM2"])
+     with col3:
+         temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.1)
+
+     if st.button("Run RAG Pipeline"):
+         if not user_query.strip():
+             st.warning("Please enter a question.")
+             return
+
+         # 1) PubMed retrieval
+         with st.spinner("Searching PubMed..."):
+             pmids = search_pubmed(user_query, max_results=max_papers)
+
+         if not pmids:
+             st.error("No relevant results found. Try a different query.")
+             return
+
+         # 2) Fetch & Summarize
+         with st.spinner("Fetching & Summarizing abstracts..."):
+             abs_map = fetch_pubmed_abstracts(pmids)
+             summarized_map = {}
+             for pmid, text in abs_map.items():
+                 if text.startswith("Error:"):
+                     summarized_map[pmid] = text
+                 else:
+                     summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)
+
+         # 3) Display Summaries
+         st.subheader("Retrieved & Summarized PubMed Articles")
+         for idx, pmid in enumerate(pmids, start=1):
+             st.markdown(f"**[Ref{idx}] PMID {pmid}**")
+             st.write(summarized_map[pmid])
+             st.write("---")
+
+         # 4) Final LLM Answer
+         st.subheader("RAG-Enhanced Final Answer")
+         system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
+         with st.spinner("Generating answer..."):
+             if selected_llm == "OpenAI GPT-3.5":
+                 answer = openai_chat(system_prompt, user_query, temperature=temperature)
+             else:
+                 answer = gemini_chat(system_prompt, user_query, temperature=temperature)
+
+         st.write(answer)
+         st.success("RAG Pipeline Complete.")
+
+         # Production tips
+         st.markdown("---")
+         st.markdown("""
+         ### Production Enhancements
+         - **Vector Database** for advanced retrieval
+         - **Citation Parsing** for accurate referencing
+         - **Multi-Lingual** expansions
+         - **Rate Limiting** for PubMed (max ~3 requests/sec)
+         - **Robust Logging / Monitoring**
+         - **Security & Privacy** if patient data is integrated
+         """)
+
+ if __name__ == "__main__":
+     main()