mgbam commited on
Commit
8ab7297
Β·
verified Β·
1 Parent(s): 2c2342d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -186
app.py CHANGED
@@ -1,10 +1,7 @@
1
- #!/usr/bin/env python3
2
- # ──────────────────────────────────────────────────────────────────────
3
- # MedGenesis AI – Streamlit UI (OpenAI + Gemini, CPU-only)
4
- # ──────────────────────────────────────────────────────────────────────
5
  import os, pathlib, asyncio, re
6
  from pathlib import Path
7
- from datetime import datetime
8
 
9
  import streamlit as st
10
  import pandas as pd
@@ -12,214 +9,174 @@ import plotly.express as px
12
  from fpdf import FPDF
13
  from streamlit_agraph import agraph
14
 
15
- # ── internal helpers --------------------------------------------------
16
- from mcp.orchestrator import orchestrate_search, answer_ai_question
17
- from mcp.workspace import get_workspace, save_query
18
  from mcp.knowledge_graph import build_agraph
19
- from mcp.graph_metrics import build_nx, get_top_hubs, get_density
20
- from mcp.alerts import check_alerts
21
 
22
- # ── Streamlit telemetry dir fix (HF Spaces sandbox quirks) ------------
23
- os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
24
- os.environ["XDG_STATE_HOME"] = "/tmp"
25
- os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
26
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
27
 
28
  ROOT = Path(__file__).parent
29
  LOGO = ROOT / "assets" / "logo.png"
30
 
31
- # ══════════════════════════════════════════════════════════════════════
32
- # Small util helpers
33
- # ══════════════════════════════════════════════════════════════════════
34
  def _latin1_safe(txt: str) -> str:
35
- """Replace non-Latin-1 chars – keeps FPDF happy."""
36
  return txt.encode("latin-1", "replace").decode("latin-1")
37
 
38
-
39
- def _pdf(papers: list[dict]) -> bytes:
40
  pdf = FPDF()
41
  pdf.set_auto_page_break(auto=True, margin=15)
42
  pdf.add_page()
43
  pdf.set_font("Helvetica", size=11)
44
- pdf.cell(200, 8, _latin1_safe("MedGenesis AI – Literature results"),
45
- ln=True, align="C")
46
  pdf.ln(3)
47
-
48
  for i, p in enumerate(papers, 1):
49
  pdf.set_font("Helvetica", "B", 11)
50
- pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p['title']}"))
51
  pdf.set_font("Helvetica", "", 9)
52
- body = (
53
- f"{p['authors']}\n"
54
- f"{p['summary']}\n"
55
- f"{p['link']}\n"
56
- )
57
  pdf.multi_cell(0, 6, _latin1_safe(body))
58
  pdf.ln(1)
59
-
60
- # FPDF already returns latin-1 bytes – no extra encode needed
61
  return pdf.output(dest="S").encode("latin-1", "replace")
62
 
63
-
64
- def _workspace_sidebar() -> None:
65
  with st.sidebar:
66
- st.header("πŸ—‚ Workspace")
67
  ws = get_workspace()
68
  if not ws:
69
  st.info("Run a search then press **Save** to populate this list.")
70
  return
71
  for i, item in enumerate(ws, 1):
72
  with st.expander(f"{i}. {item['query']}"):
73
- st.write(item["result"]["ai_summary"])
74
-
75
 
76
- # ══════════════════════════════════════════════════════════════════════
77
- # Main Streamlit UI
78
- # ══════════════════════════════════════════════════════════════════════
79
- def render_ui() -> None:
80
  st.set_page_config("MedGenesis AI", layout="wide")
81
 
82
- # ── Session-state defaults ────────────────────────────────────────
83
- for k, v in {
84
- "query_result": None,
85
- "followup_input": "",
86
- "followup_response": None,
87
- "last_query": "",
88
- "last_llm": "",
89
- }.items():
90
- st.session_state.setdefault(k, v)
91
 
92
  _workspace_sidebar()
93
-
94
- col_logo, col_title = st.columns([0.15, 0.85])
95
- with col_logo:
96
  if LOGO.exists():
97
- st.image(LOGO, width=110)
98
- with col_title:
99
  st.markdown("## 🧬 **MedGenesis AI**")
100
  st.caption("Multi-source biomedical assistant Β· OpenAI / Gemini")
101
 
102
- llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
103
- query = st.text_input("Enter biomedical question",
104
- placeholder="e.g. CRISPR glioblastoma therapy")
105
 
106
- # ── alert notifications (async) ───────────────────────────────────
107
- saved_qs = [w["query"] for w in get_workspace()]
108
- if saved_qs:
109
  try:
110
- news = asyncio.run(check_alerts(saved_qs))
111
  if news:
112
  with st.sidebar:
113
  st.subheader("πŸ”” New papers")
114
  for q, lnks in news.items():
115
  st.write(f"**{q}** – {len(lnks)} new")
116
  except Exception:
117
- pass # network hiccups – silent
118
 
119
- # ── Run Search ----------------------------------------------------
120
- if st.button("Run Search πŸš€") and query.strip():
121
  with st.spinner("Collecting literature & biomedical data …"):
122
  res = asyncio.run(orchestrate_search(query, llm=llm))
123
-
124
- # store in session
125
- st.session_state.update(
126
- query_result=res,
127
- last_query=query,
128
- last_llm=llm,
129
- followup_input="",
130
- followup_response=None,
131
- )
132
- st.success(f"Completed with **{res['llm_used'].title()}**")
133
 
134
  res = st.session_state.query_result
135
  if not res:
136
- st.info("Enter a biomedical question and press **Run Search πŸš€**")
137
  return
138
 
139
- # ── Tabs ----------------------------------------------------------
140
- tabs = st.tabs(["Results", "Genes", "Trials",
141
- "Graph", "Metrics", "Visuals"])
142
-
143
- # 1) Results -------------------------------------------------------
144
  with tabs[0]:
145
- for i, p in enumerate(res["papers"], 1):
146
- st.markdown(
147
- f"**{i}. [{p['title']}]({p['link']})** "
148
- f"*{p['authors']}*"
149
- )
150
- st.write(p["summary"])
151
-
152
- c_csv, c_pdf = st.columns(2)
153
- with c_csv:
154
- st.download_button(
155
- "CSV",
156
- pd.DataFrame(res["papers"]).to_csv(index=False),
157
- "papers.csv",
158
- "text/csv",
159
- )
160
- with c_pdf:
161
- st.download_button("PDF", _pdf(res["papers"]),
162
- "papers.pdf", "application/pdf")
163
-
164
  if st.button("πŸ’Ύ Save"):
165
  save_query(st.session_state.last_query, res)
166
  st.success("Saved to workspace")
167
-
168
  st.subheader("UMLS concepts")
169
- for c in (res["umls"] or []):
170
  if isinstance(c, dict) and c.get("cui"):
171
- st.write(f"- **{c['name']}** ({c['cui']})")
172
-
173
  st.subheader("OpenFDA safety signals")
174
- for d in (res["drug_safety"] or []):
175
- st.json(d)
176
-
177
  st.subheader("AI summary")
178
- st.info(res["ai_summary"])
179
 
180
- # 2) Genes ---------------------------------------------------------
181
  with tabs[1]:
182
  st.header("Gene / Variant signals")
183
- genes_list = [
184
- g for g in res["genes"]
185
- if isinstance(g, dict) and (g.get("symbol") or g.get("name"))
186
- ]
187
- if not genes_list:
188
  st.info("No gene hits (rate-limited or none found).")
189
- for g in genes_list:
190
- st.write(f"- **{g.get('symbol') or g.get('name')}** "
191
- f"{g.get('description','')}")
192
- if res["gene_disease"]:
 
 
193
  st.markdown("### DisGeNET associations")
194
- ok = [d for d in res["gene_disease"] if isinstance(d, dict)]
195
- if ok:
196
- st.json(ok[:15])
197
-
198
- defs = [d for d in res["mesh_defs"] if isinstance(d, str) and d]
199
- if defs:
200
  st.markdown("### MeSH definitions")
201
- for d in defs:
202
- st.write("-", d)
 
203
 
204
- # 3) Trials --------------------------------------------------------
205
  with tabs[2]:
206
  st.header("Clinical trials")
207
- ct = res["clinical_trials"]
208
- if not ct:
209
  st.info("No trials (rate-limited or none found).")
210
- for t in ct:
211
- nct = t.get("NCTId", [""])[0]
212
- bttl = t.get("BriefTitle", [""])[0]
213
- phase= t.get("Phase", [""])[0]
214
- stat = t.get("OverallStatus", [""])[0]
215
- st.markdown(f"**{nct}** – {bttl}")
216
- st.write(f"Phase {phase} | Status {stat}")
217
-
218
- # 4) Graph ---------------------------------------------------------
 
219
  with tabs[3]:
220
- nodes, edges, cfg = build_agraph(
221
- res["papers"], res["umls"], res["drug_safety"]
222
- )
 
 
 
 
 
 
 
 
223
  hl = st.text_input("Highlight node:", key="hl")
224
  if hl:
225
  pat = re.compile(re.escape(hl), re.I)
@@ -227,60 +184,38 @@ def render_ui() -> None:
227
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
228
  agraph(nodes, edges, cfg)
229
 
230
- # 5) Metrics -------------------------------------------------------
231
- with tabs[4]:
232
- G = build_nx(
233
- [n.__dict__ for n in nodes],
234
- [e.__dict__ for e in edges],
235
- )
236
  st.metric("Density", f"{get_density(G):.3f}")
237
  st.markdown("**Top hubs**")
238
- for nid, sc in get_top_hubs(G, k=5):
239
- label = next((n.label for n in nodes if n.id == nid), nid)
240
- st.write(f"- {label} {sc:.3f}")
241
 
242
- # 6) Visuals -------------------------------------------------------
243
- with tabs[5]:
244
- years = [
245
- p["published"][:4] for p in res["papers"]
246
- if p.get("published") and len(p["published"]) >= 4
247
- ]
248
  if years:
249
- st.plotly_chart(
250
- px.histogram(
251
- years, nbins=min(15, len(set(years))),
252
- title="Publication Year"
253
- )
254
- )
255
 
256
- # ── Follow-up Q-A -------------------------------------------------
257
  st.markdown("---")
258
- st.text_input("Ask follow-up question:",
259
- key="followup_input",
260
- placeholder="e.g. Any Phase III trials recruiting now?")
261
-
262
- def _on_ask():
263
- q = st.session_state.followup_input.strip()
264
- if not q:
265
- st.warning("Please type a question first.")
266
- return
267
- with st.spinner("Querying LLM …"):
268
- ans = asyncio.run(
269
- answer_ai_question(
270
- q,
271
- context=st.session_state.last_query,
272
- llm=st.session_state.last_llm)
273
- )
274
- st.session_state.followup_response = (
275
- ans.get("answer") or "LLM unavailable or quota exceeded."
276
- )
277
-
278
- st.button("Ask AI", on_click=_on_ask)
279
-
280
  if st.session_state.followup_response:
281
  st.write(st.session_state.followup_response)
282
 
283
-
284
- # ── entry-point ───────────────────────────────────────────────────────
285
  if __name__ == "__main__":
286
  render_ui()
 
1
+ # app.py - MedGenesis AI Streamlit app (OpenAI/Gemini)
2
+
 
 
3
  import os, pathlib, asyncio, re
4
  from pathlib import Path
 
5
 
6
  import streamlit as st
7
  import pandas as pd
 
9
  from fpdf import FPDF
10
  from streamlit_agraph import agraph
11
 
12
+ from mcp.orchestrator import orchestrate_search, answer_ai_question
13
+ from mcp.workspace import get_workspace, save_query
 
14
  from mcp.knowledge_graph import build_agraph
15
+ from mcp.graph_metrics import build_nx, get_top_hubs, get_density
16
+ from mcp.alerts import check_alerts
17
 
18
+ # --- Fix Streamlit temp dir ---
19
+ os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
20
+ os.environ["XDG_STATE_HOME"] = "/tmp"
21
+ os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
22
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
23
 
24
  ROOT = Path(__file__).parent
25
  LOGO = ROOT / "assets" / "logo.png"
26
 
 
 
 
27
  def _latin1_safe(txt: str) -> str:
 
28
  return txt.encode("latin-1", "replace").decode("latin-1")
29
 
30
+ def _pdf(papers):
 
31
  pdf = FPDF()
32
  pdf.set_auto_page_break(auto=True, margin=15)
33
  pdf.add_page()
34
  pdf.set_font("Helvetica", size=11)
35
+ pdf.cell(200, 8, _latin1_safe("MedGenesis AI – Results"), ln=True, align="C")
 
36
  pdf.ln(3)
 
37
  for i, p in enumerate(papers, 1):
38
  pdf.set_font("Helvetica", "B", 11)
39
+ pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p.get('title', '')}"))
40
  pdf.set_font("Helvetica", "", 9)
41
+ body = f"{p.get('authors','')}\n{p.get('summary','')}\n{p.get('link','')}\n"
 
 
 
 
42
  pdf.multi_cell(0, 6, _latin1_safe(body))
43
  pdf.ln(1)
 
 
44
  return pdf.output(dest="S").encode("latin-1", "replace")
45
 
46
+ def _workspace_sidebar():
 
47
  with st.sidebar:
48
+ st.header("πŸ—‚οΈ Workspace")
49
  ws = get_workspace()
50
  if not ws:
51
  st.info("Run a search then press **Save** to populate this list.")
52
  return
53
  for i, item in enumerate(ws, 1):
54
  with st.expander(f"{i}. {item['query']}"):
55
+ st.write(item["result"].get("ai_summary", ""))
 
56
 
57
+ def render_ui():
 
 
 
58
  st.set_page_config("MedGenesis AI", layout="wide")
59
 
60
+ # Session state
61
+ for k, v in [
62
+ ("query_result", None), ("followup_input", ""),
63
+ ("followup_response", None), ("last_query", ""), ("last_llm", "")
64
+ ]:
65
+ if k not in st.session_state:
66
+ st.session_state[k] = v
 
 
67
 
68
  _workspace_sidebar()
69
+ c1, c2 = st.columns([0.15, 0.85])
70
+ with c1:
 
71
  if LOGO.exists():
72
+ st.image(str(LOGO), width=105)
73
+ with c2:
74
  st.markdown("## 🧬 **MedGenesis AI**")
75
  st.caption("Multi-source biomedical assistant Β· OpenAI / Gemini")
76
 
77
+ llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
78
+ query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
 
79
 
80
+ # Alerts
81
+ wsq = get_workspace()
82
+ if wsq:
83
  try:
84
+ news = asyncio.run(check_alerts([w["query"] for w in wsq]))
85
  if news:
86
  with st.sidebar:
87
  st.subheader("πŸ”” New papers")
88
  for q, lnks in news.items():
89
  st.write(f"**{q}** – {len(lnks)} new")
90
  except Exception:
91
+ pass
92
 
93
+ if st.button("Run Search πŸš€") and query:
 
94
  with st.spinner("Collecting literature & biomedical data …"):
95
  res = asyncio.run(orchestrate_search(query, llm=llm))
96
+ st.success(f"Completed with **{res.get('llm_used','LLM').title()}**")
97
+ st.session_state.query_result = res
98
+ st.session_state.last_query = query
99
+ st.session_state.last_llm = llm
100
+ st.session_state.followup_input = ""
101
+ st.session_state.followup_response = None
 
 
 
 
102
 
103
  res = st.session_state.query_result
104
  if not res:
105
+ st.info("Enter a question and press **Run Search πŸš€**")
106
  return
107
 
108
+ tabs = st.tabs(["Results", "Genes", "Trials", "Variants", "Graph", "Metrics", "Visuals"])
109
+ # --------------- Results Tab ---------------
 
 
 
110
  with tabs[0]:
111
+ for i, p in enumerate(res.get("papers", []), 1):
112
+ st.markdown(f"**{i}. [{p.get('title','')}]({p.get('link','')})** *{p.get('authors','')}*")
113
+ st.write(p.get("summary", ""))
114
+ col1, col2 = st.columns(2)
115
+ with col1:
116
+ st.download_button("CSV", pd.DataFrame(res.get("papers", [])).to_csv(index=False),
117
+ "papers.csv", "text/csv")
118
+ with col2:
119
+ st.download_button("PDF", _pdf(res.get("papers", [])), "papers.pdf", "application/pdf")
 
 
 
 
 
 
 
 
 
 
120
  if st.button("πŸ’Ύ Save"):
121
  save_query(st.session_state.last_query, res)
122
  st.success("Saved to workspace")
 
123
  st.subheader("UMLS concepts")
124
+ for c in res.get("umls", []):
125
  if isinstance(c, dict) and c.get("cui"):
126
+ st.write(f"- **{c.get('name','')}** ({c.get('cui')})")
 
127
  st.subheader("OpenFDA safety signals")
128
+ st.json(res.get("drug_safety", []))
 
 
129
  st.subheader("AI summary")
130
+ st.info(res.get("ai_summary", ""))
131
 
132
+ # --------------- Genes Tab ---------------
133
  with tabs[1]:
134
  st.header("Gene / Variant signals")
135
+ genes = res.get("genes", [])
136
+ if not genes:
 
 
 
137
  st.info("No gene hits (rate-limited or none found).")
138
+ else:
139
+ for g in genes:
140
+ if isinstance(g, dict):
141
+ lab = g.get("name") or g.get("symbol") or g.get("geneid")
142
+ st.write(f"- **{lab}** {g.get('description','')}")
143
+ if res.get("gene_disease"):
144
  st.markdown("### DisGeNET associations")
145
+ st.json(res.get("gene_disease")[:15])
146
+ if res.get("mesh_defs"):
 
 
 
 
147
  st.markdown("### MeSH definitions")
148
+ for d in res["mesh_defs"]:
149
+ if d:
150
+ st.write("-", d)
151
 
152
+ # --------------- Trials Tab ---------------
153
  with tabs[2]:
154
  st.header("Clinical trials")
155
+ trials = res.get("clinical_trials", [])
156
+ if not trials:
157
  st.info("No trials (rate-limited or none found).")
158
+ else:
159
+ for t in trials:
160
+ nct = t.get("nctId") or (t.get("NCTId", [""])[0] if isinstance(t.get("NCTId"), list) else "")
161
+ title = t.get("briefTitle") or (t.get("BriefTitle", [""])[0] if isinstance(t.get("BriefTitle"), list) else "")
162
+ phase = t.get("phase") or (t.get("Phase", [""])[0] if isinstance(t.get("Phase"), list) else "")
163
+ status = t.get("status") or (t.get("OverallStatus", [""])[0] if isinstance(t.get("OverallStatus"), list) else "")
164
+ st.markdown(f"**{nct}** – {title}")
165
+ st.write(f"Phase {phase} | Status {status}")
166
+
167
+ # --------------- Variants Tab ---------------
168
  with tabs[3]:
169
+ st.header("Cancer variants (cBioPortal)")
170
+ variants = res.get("variants", [])
171
+ if not variants:
172
+ st.info("No variant data.")
173
+ else:
174
+ for v in variants:
175
+ st.json(v)
176
+
177
+ # --------------- Graph Tab ---------------
178
+ with tabs[4]:
179
+ nodes, edges, cfg = build_agraph(res.get("papers", []), res.get("umls", []), res.get("drug_safety", []))
180
  hl = st.text_input("Highlight node:", key="hl")
181
  if hl:
182
  pat = re.compile(re.escape(hl), re.I)
 
184
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
185
  agraph(nodes, edges, cfg)
186
 
187
+ # --------------- Metrics Tab ---------------
188
+ with tabs[5]:
189
+ nodes, edges, _ = build_agraph(res.get("papers", []), res.get("umls", []), res.get("drug_safety", []))
190
+ G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
 
 
191
  st.metric("Density", f"{get_density(G):.3f}")
192
  st.markdown("**Top hubs**")
193
+ for nid, sc in get_top_hubs(G):
194
+ lab = next((n.label for n in nodes if n.id == nid), nid)
195
+ st.write(f"- {lab} {sc:.3f}")
196
 
197
+ # --------------- Visuals Tab ---------------
198
+ with tabs[6]:
199
+ years = [p.get("published", "") for p in res.get("papers", []) if p.get("published")]
 
 
 
200
  if years:
201
+ st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
 
 
 
 
 
202
 
203
+ # --------------- Follow-up Q&A ---------------
204
  st.markdown("---")
205
+ st.text_input("Ask follow‑up question:", key="followup_input")
206
+ def handle_followup():
207
+ follow = st.session_state.followup_input
208
+ if follow.strip():
209
+ ans = asyncio.run(answer_ai_question(
210
+ follow,
211
+ context=st.session_state.last_query,
212
+ llm=st.session_state.last_llm))
213
+ st.session_state.followup_response = ans.get("answer", "No answer.")
214
+ else:
215
+ st.session_state.followup_response = None
216
+ st.button("Ask AI", on_click=handle_followup)
 
 
 
 
 
 
 
 
 
 
217
  if st.session_state.followup_response:
218
  st.write(st.session_state.followup_response)
219
 
 
 
220
  if __name__ == "__main__":
221
  render_ui()