mgbam commited on
Commit
76418d6
Β·
verified Β·
1 Parent(s): 25baf98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -176
app.py CHANGED
@@ -1,13 +1,15 @@
1
  #!/usr/bin/env python3
2
- # app.py – MedGenesis AI Β· Streamlit front-end (v3)
3
- # ---------------------------------------------------
4
- # β€’ Dual-LLM selector (OpenAI | Gemini)
5
- # β€’ Robust PDF export (all Unicode β†’ Latin-1 safe)
6
- # β€’ Lazy session-state handling so a failed background
7
- # request never kills the whole app.
8
- # β€’ New β€œVariants” tab (cBioPortal) + null-safe β€œGraph”
9
- # and β€œMetrics” using the patched helpers.
10
 
 
 
 
 
 
 
 
 
11
  import os, pathlib, asyncio, re
12
  from pathlib import Path
13
 
@@ -24,35 +26,30 @@ from mcp.graph_metrics import build_nx, get_top_hubs, get_density
24
 
25
  # ── Streamlit telemetry dir fix ─────────────────────────────────────
26
  os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
27
- os.environ["XDG_STATE_HOME"] = "/tmp"
28
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
29
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
30
 
31
  ROOT = Path(__file__).parent
32
  LOGO = ROOT / "assets" / "logo.png"
33
 
34
- # ── PDF export helper (robust to ALL Unicode) ───────────────────────
35
- def _latin1_safe(txt: str) -> str:
36
  return txt.encode("latin-1", "replace").decode("latin-1")
37
 
38
- def _pdf(papers):
39
  pdf = FPDF()
40
  pdf.set_auto_page_break(auto=True, margin=15)
41
  pdf.add_page()
42
  pdf.set_font("Helvetica", size=11)
43
- pdf.cell(200, 8, _latin1_safe("MedGenesis AI – Results"), ln=True, align="C")
44
  pdf.ln(3)
45
-
46
  for i, p in enumerate(papers, 1):
47
  pdf.set_font("Helvetica", "B", 11)
48
- pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p['title']}"))
49
  pdf.set_font("Helvetica", "", 9)
50
- body = (
51
- f"{p['authors']}\n"
52
- f"{p['summary']}\n"
53
- f"{p['link']}\n"
54
- )
55
- pdf.multi_cell(0, 6, _latin1_safe(body))
56
  pdf.ln(1)
57
  return pdf.output(dest="S").encode("latin-1", "replace")
58
 
@@ -68,189 +65,192 @@ def _workspace_sidebar():
68
  with st.expander(f"{i}. {item['query']}"):
69
  st.write(item["result"]["ai_summary"])
70
 
71
- # ── UI main routine ─────────────────────────────────────────────────
72
- def render_ui():
73
  st.set_page_config("MedGenesis AI", layout="wide")
74
 
75
- # Session-state defaults
76
- for key, default in {
77
- "query_result" : None,
78
- "last_query" : "",
79
- "last_llm" : "openai",
80
- "followup_input" : "",
81
  "followup_response": None,
82
- }.items():
83
- if key not in st.session_state:
84
- st.session_state[key] = default
85
 
86
  _workspace_sidebar()
87
 
88
- # Header block
89
  c1, c2 = st.columns([0.15, 0.85])
90
  with c1:
91
  if LOGO.exists():
92
  st.image(str(LOGO), width=105)
93
  with c2:
94
  st.markdown("## 🧬 **MedGenesis AI**")
95
- st.caption("Multi-source biomedical assistant – OpenAI / Gemini")
96
 
97
  # Controls
98
- llm = st.radio("LLM engine", ["openai", "gemini"],
99
- horizontal=True, index=0)
100
  query = st.text_input("Enter biomedical question",
101
  placeholder="e.g. CRISPR glioblastoma therapy")
102
 
103
- # Run search
104
  if st.button("Run Search πŸš€") and query:
105
  with st.spinner("Collecting literature & biomedical data …"):
106
  res = asyncio.run(orchestrate_search(query, llm=llm))
107
- st.session_state.query_result = res
108
- st.session_state.last_query = query
109
- st.session_state.last_llm = llm
110
- st.session_state.followup_input = ""
111
- st.session_state.followup_response = None
112
-
113
- res = st.session_state.query_result
114
- if res:
115
- # Guard against missing keys
116
- for key in (
117
- "papers", "umls", "drug_safety", "genes", "mesh_defs",
118
- "gene_disease", "clinical_trials", "variants"
119
- ):
120
- res.setdefault(key, [])
121
-
122
- # -------------- TABS -------------------------------------------------
123
- tabs = st.tabs([
124
- "Results", "Genes", "Trials", "Variants",
125
- "Graph", "Metrics", "Visuals"
126
- ])
127
-
128
- # ── Results tab ─────────────────────────────────────────────────────
129
- with tabs[0]:
130
- st.subheader("Literature")
131
- for i, p in enumerate(res["papers"], 1):
132
- st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
133
- st.write(p["summary"])
134
- col1, col2 = st.columns(2)
135
- with col1:
136
- st.download_button(
137
- "CSV",
138
- pd.DataFrame(res["papers"]).to_csv(index=False),
139
- "papers.csv",
140
- "text/csv",
141
- )
142
- with col2:
143
- st.download_button(
144
- "PDF",
145
- _pdf(res["papers"]),
146
- "papers.pdf",
147
- "application/pdf",
148
- )
149
- if st.button("πŸ’Ύ Save"):
150
- save_query(st.session_state.last_query, res)
151
- st.success("Saved to workspace")
152
-
153
- st.subheader("UMLS concepts")
154
- for c in res["umls"]:
155
- if c.get("cui"):
156
- st.write(f"- **{c['name']}** ({c['cui']})")
157
-
158
- st.subheader("OpenFDA safety signals")
159
- for d in res["drug_safety"]:
160
- st.json(d)
161
-
162
- st.subheader("AI summary")
163
- st.info(res["ai_summary"])
164
-
165
- # ── Genes tab ───────────────────────────────────────────────────────
166
- with tabs[1]:
167
- st.header("Gene / Variant signals")
168
- for g in res["genes"]:
169
- lab = g.get("name") or g.get("symbol") or g.get("geneid")
 
 
 
 
 
170
  st.write(f"- **{lab}**")
171
- if res["gene_disease"]:
172
- st.markdown("### DisGeNET associations")
173
- st.json(res["gene_disease"][:15])
174
- if res["mesh_defs"]:
175
- st.markdown("### MeSH definitions")
176
- for d in res["mesh_defs"]:
177
- if d:
178
- st.write("-", d)
179
-
180
- # ── Trials tab ──────────────────────────────────────────────────────
181
- with tabs[2]:
182
- st.header("Clinical trials")
183
- if not res["clinical_trials"]:
184
- st.info("No trials (rate-limited or none found).")
 
 
 
185
  for t in res["clinical_trials"]:
186
  st.markdown(f"**{t['nctId']}** – {t['briefTitle']}")
187
  st.write(f"Phase {t.get('phase')} | Status {t.get('status')}")
188
 
189
- # ── Variants tab ────────────────────────────────────────────────────
190
- with tabs[3]:
191
- st.header("Cancer variants (cBioPortal)")
192
- if not res["variants"]:
193
- st.info("No variant data.")
194
- else:
195
- st.json(res["variants"][:50])
196
-
197
- # ── Graph tab ───────────────────────────────────────────────────────
198
- with tabs[4]:
199
- nodes, edges, cfg = build_agraph(
200
- res["papers"], res["umls"], res["drug_safety"]
201
- )
202
- hl = st.text_input("Highlight node:", key="hl")
203
- if hl:
204
- pat = re.compile(re.escape(hl), re.I)
205
- for n in nodes:
206
- n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
207
- agraph(nodes, edges, cfg)
208
-
209
- # ── Metrics tab ─────────────────────────────────────────────────────
210
- with tabs[5]:
211
- G = build_nx(
212
- [n.__dict__ for n in nodes],
213
- [e.__dict__ for e in edges],
214
- )
215
- st.metric("Density", f"{get_density(G):.3f}")
216
- st.markdown("**Top hubs**")
217
- for nid, sc in get_top_hubs(G):
218
- lab = next((n.label for n in nodes if n.id == nid), nid)
219
- st.write(f"- {lab} {sc:.3f}")
220
-
221
- # ── Visuals tab ────────────────────────────────────────────────────
222
- with tabs[6]:
223
- years = [p.get("published", "")[:4] for p in res["papers"] if p.get("published")]
224
- if years:
225
- st.plotly_chart(px.histogram(years, nbins=12,
226
- title="Publication Year"))
227
-
228
- # ── Follow-up Q-A block ────────────────────────────────────────────
229
- st.markdown("---")
230
- st.text_input("Ask follow-up question:", key="followup_input")
231
-
232
- def _on_ask():
233
- q = st.session_state.followup_input.strip()
234
- if not q:
235
- st.warning("Please type a question first.")
236
- return
237
- with st.spinner("Querying LLM …"):
238
- ans = asyncio.run(
239
- answer_ai_question(
240
- q,
241
- context=st.session_state.last_query,
242
- llm=st.session_state.last_llm,
243
- )
244
  )
245
- st.session_state.followup_response = ans["answer"]
246
-
247
- st.button("Ask AI", on_click=_on_ask)
248
 
249
- if st.session_state.followup_response:
250
- st.write(st.session_state.followup_response)
251
 
252
- else:
253
- st.info("Enter a question and press **Run Search πŸš€**")
254
 
255
 
256
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
+ """
3
+ MedGenesis AI – Streamlit UI (v3, June 2025)
 
 
 
 
 
 
4
 
5
+ β€’ Dual-LLM selector (OpenAI | Gemini)
6
+ β€’ Tabs: Results | Genes | Trials | Variants | Graph | Metrics | Visuals
7
+ β€’ Robust PDF export (all Unicode β†’ Latin-1 safe)
8
+ β€’ Null-safe handling of any RuntimeError / HTTPStatusError objects that
9
+ slip through the async pipeline.
10
+ """
11
+
12
+ from __future__ import annotations
13
  import os, pathlib, asyncio, re
14
  from pathlib import Path
15
 
 
26
 
27
  # ── Streamlit telemetry dir fix ─────────────────────────────────────
28
  os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
29
+ os.environ["XDG_STATE_HOME"] = "/tmp"
30
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
31
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
32
 
33
  ROOT = Path(__file__).parent
34
  LOGO = ROOT / "assets" / "logo.png"
35
 
36
+ # ── PDF helper ──────────────────────────────────────────────────────
37
+ def _latin1(txt: str) -> str:
38
  return txt.encode("latin-1", "replace").decode("latin-1")
39
 
40
+ def _pdf(papers: list[dict]) -> bytes:
41
  pdf = FPDF()
42
  pdf.set_auto_page_break(auto=True, margin=15)
43
  pdf.add_page()
44
  pdf.set_font("Helvetica", size=11)
45
+ pdf.cell(200, 8, _latin1("MedGenesis AI – Results"), ln=True, align="C")
46
  pdf.ln(3)
 
47
  for i, p in enumerate(papers, 1):
48
  pdf.set_font("Helvetica", "B", 11)
49
+ pdf.multi_cell(0, 7, _latin1(f"{i}. {p['title']}"))
50
  pdf.set_font("Helvetica", "", 9)
51
+ body = f"{p['authors']}\n{p['summary']}\n{p['link']}\n"
52
+ pdf.multi_cell(0, 6, _latin1(body))
 
 
 
 
53
  pdf.ln(1)
54
  return pdf.output(dest="S").encode("latin-1", "replace")
55
 
 
65
  with st.expander(f"{i}. {item['query']}"):
66
  st.write(item["result"]["ai_summary"])
67
 
68
+ # ── Main UI ─────────────────────────────────────────────────────────
69
+ def render_ui() -> None:
70
  st.set_page_config("MedGenesis AI", layout="wide")
71
 
72
+ # Session defaults
73
+ defaults = {
74
+ "query_result": None,
75
+ "last_query": "",
76
+ "last_llm": "openai",
77
+ "followup_input": "",
78
  "followup_response": None,
79
+ }
80
+ for k, v in defaults.items():
81
+ st.session_state.setdefault(k, v)
82
 
83
  _workspace_sidebar()
84
 
85
+ # Header
86
  c1, c2 = st.columns([0.15, 0.85])
87
  with c1:
88
  if LOGO.exists():
89
  st.image(str(LOGO), width=105)
90
  with c2:
91
  st.markdown("## 🧬 **MedGenesis AI**")
92
+ st.caption("Multi-source biomedical assistant Β· OpenAI / Gemini")
93
 
94
  # Controls
95
+ llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
 
96
  query = st.text_input("Enter biomedical question",
97
  placeholder="e.g. CRISPR glioblastoma therapy")
98
 
 
99
  if st.button("Run Search πŸš€") and query:
100
  with st.spinner("Collecting literature & biomedical data …"):
101
  res = asyncio.run(orchestrate_search(query, llm=llm))
102
+ st.session_state.update(
103
+ query_result=res,
104
+ last_query=query,
105
+ last_llm=llm,
106
+ followup_input="",
107
+ followup_response=None,
108
+ )
109
+
110
+ res: dict | None = st.session_state.query_result
111
+ if not res:
112
+ st.info("Enter a question and press **Run Search πŸš€**")
113
+ return
114
+
115
+ # Guarantee all expected keys exist
116
+ for k in (
117
+ "papers", "umls", "drug_safety", "genes", "mesh_defs",
118
+ "gene_disease", "clinical_trials", "variants"
119
+ ):
120
+ res.setdefault(k, [])
121
+
122
+ # Tabs
123
+ tabs = st.tabs([
124
+ "Results", "Genes", "Trials", "Variants",
125
+ "Graph", "Metrics", "Visuals"
126
+ ])
127
+
128
+ # ---- Results ----------------------------------------------------
129
+ with tabs[0]:
130
+ st.subheader("Literature")
131
+ for i, p in enumerate(res["papers"], 1):
132
+ st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
133
+ st.write(p["summary"])
134
+ col1, col2 = st.columns(2)
135
+ with col1:
136
+ st.download_button(
137
+ "CSV",
138
+ pd.DataFrame(res["papers"]).to_csv(index=False),
139
+ "papers.csv",
140
+ "text/csv",
141
+ )
142
+ with col2:
143
+ st.download_button("PDF", _pdf(res["papers"]),
144
+ "papers.pdf", "application/pdf")
145
+ if st.button("πŸ’Ύ Save"):
146
+ save_query(st.session_state.last_query, res)
147
+ st.success("Saved to workspace")
148
+
149
+ st.subheader("UMLS concepts")
150
+ for c in res["umls"]:
151
+ if isinstance(c, dict) and c.get("cui"):
152
+ st.write(f"- **{c['name']}** ({c['cui']})")
153
+
154
+ st.subheader("OpenFDA safety signals")
155
+ for d in res["drug_safety"]:
156
+ st.json(d)
157
+
158
+ st.subheader("AI summary")
159
+ st.info(res["ai_summary"])
160
+
161
+ # ---- Genes ------------------------------------------------------
162
+ with tabs[1]:
163
+ st.header("Gene / Variant signals")
164
+ clean = [g for g in res["genes"] if isinstance(g, dict)]
165
+ if not clean:
166
+ st.info("No gene metadata (API may be rate-limited).")
167
+ else:
168
+ for g in clean:
169
+ lab = g.get("name") or g.get("symbol") or str(g.get("geneid", ""))
170
  st.write(f"- **{lab}**")
171
+
172
+ if res["gene_disease"]:
173
+ st.markdown("### DisGeNET associations")
174
+ st.json(res["gene_disease"][:15])
175
+
176
+ if res["mesh_defs"]:
177
+ st.markdown("### MeSH definitions")
178
+ for d in res["mesh_defs"]:
179
+ if d:
180
+ st.write("-", d)
181
+
182
+ # ---- Trials -----------------------------------------------------
183
+ with tabs[2]:
184
+ st.header("Clinical trials")
185
+ if not res["clinical_trials"]:
186
+ st.info("No trials (rate-limited or none found).")
187
+ else:
188
  for t in res["clinical_trials"]:
189
  st.markdown(f"**{t['nctId']}** – {t['briefTitle']}")
190
  st.write(f"Phase {t.get('phase')} | Status {t.get('status')}")
191
 
192
+ # ---- Variants ---------------------------------------------------
193
+ with tabs[3]:
194
+ st.header("Cancer variants (cBioPortal)")
195
+ if not res["variants"]:
196
+ st.info("No variant data.")
197
+ else:
198
+ st.json(res["variants"][:50])
199
+
200
+ # ---- Graph ------------------------------------------------------
201
+ with tabs[4]:
202
+ nodes, edges, cfg = build_agraph(
203
+ res["papers"], res["umls"], res["drug_safety"]
204
+ )
205
+ hl = st.text_input("Highlight node:", key="hl")
206
+ if hl:
207
+ pat = re.compile(re.escape(hl), re.I)
208
+ for n in nodes:
209
+ n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
210
+ agraph(nodes, edges, cfg)
211
+
212
+ # ---- Metrics ----------------------------------------------------
213
+ with tabs[5]:
214
+ G = build_nx(
215
+ [n.__dict__ for n in nodes],
216
+ [e.__dict__ for e in edges],
217
+ )
218
+ st.metric("Density", f"{get_density(G):.3f}")
219
+ st.markdown("**Top hubs**")
220
+ for nid, sc in get_top_hubs(G):
221
+ lab = next((n.label for n in nodes if n.id == nid), nid)
222
+ st.write(f"- {lab} {sc:.3f}")
223
+
224
+ # ---- Visuals ----------------------------------------------------
225
+ with tabs[6]:
226
+ years = [p.get("published", "")[:4] for p in res["papers"] if p.get("published")]
227
+ if years:
228
+ st.plotly_chart(px.histogram(years, nbins=12,
229
+ title="Publication Year"))
230
+
231
+ # ---- Follow-up QA ----------------------------------------------
232
+ st.markdown("---")
233
+ st.text_input("Ask follow-up question:", key="followup_input")
234
+
235
+ def _on_ask():
236
+ q = st.session_state.followup_input.strip()
237
+ if not q:
238
+ st.warning("Please type a question first.")
239
+ return
240
+ with st.spinner("Querying LLM …"):
241
+ ans = asyncio.run(
242
+ answer_ai_question(
243
+ q,
244
+ context=st.session_state.last_query,
245
+ llm=st.session_state.last_llm,
 
246
  )
247
+ )
248
+ st.session_state.followup_response = ans["answer"]
 
249
 
250
+ st.button("Ask AI", on_click=_on_ask)
 
251
 
252
+ if st.session_state.followup_response:
253
+ st.write(st.session_state.followup_response)
254
 
255
 
256
  if __name__ == "__main__":