mgbam committed on
Commit
fe00e4d
·
verified ·
1 Parent(s): 85d5a4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -72
app.py CHANGED
@@ -1,6 +1,18 @@
1
- # ──────────────────────────── app.py ─────────────────────────────────
2
- """Streamlit UI – MedGenesis v2 with gene + variant + trial integration."""
3
- import os, pathlib, asyncio, re
 
 
 
 
 
 
 
 
 
 
 
 
4
  from pathlib import Path
5
 
6
  import streamlit as st
@@ -9,45 +21,42 @@ import plotly.express as px
9
  from fpdf import FPDF
10
  from streamlit_agraph import agraph
11
 
12
- from mcp.orchestrator import orchestrate_search, answer_ai_question
13
- from mcp.workspace import get_workspace, save_query
14
  from mcp.knowledge_graph import build_agraph
15
- from mcp.graph_utils import build_nx, get_top_hubs, get_density
16
- from mcp.alerts import check_alerts
17
 
18
- # ---- Streamlit telemetry patch -------------------------------------
19
- os.environ.update({
20
- "STREAMLIT_DATA_DIR": "/tmp/.streamlit",
21
- "XDG_STATE_HOME": "/tmp",
22
- "STREAMLIT_BROWSER_GATHERUSAGESTATS": "false",
23
- })
24
- pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
28
 
29
- # ---------------- helpers -------------------------------------------
 
30
 
31
- def _latin1_safe(t: str) -> str:
32
- return t.encode("latin-1", "replace").decode("latin-1")
33
-
34
- def _export_pdf(papers):
35
  pdf = FPDF()
36
  pdf.set_auto_page_break(auto=True, margin=15)
37
  pdf.add_page()
38
  pdf.set_font("Helvetica", size=11)
39
- pdf.cell(200, 8, _latin1_safe("MedGenesis AI – Results"), ln=True, align="C")
40
  pdf.ln(3)
 
41
  for i, p in enumerate(papers, 1):
42
  pdf.set_font("Helvetica", "B", 11)
43
- pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p['title']}"))
44
  pdf.set_font("Helvetica", size=9)
45
  body = f"{p['authors']}\n{p['summary']}\n{p['link']}\n"
46
- pdf.multi_cell(0, 6, _latin1_safe(body))
47
  pdf.ln(1)
48
  return pdf.output(dest="S").encode("latin-1", "replace")
49
 
50
- # ---------------- sidebar -------------------------------------------
51
 
52
  def _workspace_sidebar():
53
  with st.sidebar:
@@ -60,107 +69,159 @@ def _workspace_sidebar():
60
  with st.expander(f"{i}. {item['query']}"):
61
  st.write(item["result"]["ai_summary"])
62
 
63
- # ---------------- main ----------------------------------------------
64
 
65
  def render_ui():
66
  st.set_page_config("MedGenesis AI", layout="wide")
 
 
 
 
 
 
67
  _workspace_sidebar()
68
 
69
- # header ---------------------------------------------------------
70
  c1, c2 = st.columns([0.15, 0.85])
71
- if LOGO.exists():
72
- with c1: st.image(str(LOGO), width=105)
 
73
  with c2:
74
  st.markdown("## 🧬 **MedGenesis AI**")
75
  st.caption("Multi‑source biomedical assistant Β· OpenAI / Gemini")
76
 
77
- llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
78
- query = st.text_input("Enter biomedical question", "CRISPR glioblastoma therapy")
79
-
80
- if st.button("Run Search πŸš€") and query:
81
- with st.spinner("Collecting literature & biomedical data …"):
82
- res = asyncio.run(orchestrate_search(query, llm=llm))
83
- st.success(f"Completed with **{res['llm_used'].title()}**")
84
- st.session_state.result = res
85
- st.session_state.last_query = query
86
- st.session_state.last_llm = llm
87
-
88
- res = st.session_state.get("result")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  if not res:
90
  st.info("Enter a question and press **Run Search πŸš€**")
91
  return
92
 
 
93
  tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
94
 
95
- # results --------------------------------------------------------
96
  with tabs[0]:
97
  for i, p in enumerate(res["papers"], 1):
98
  st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
99
  st.write(p["summary"])
100
- c1, c2 = st.columns(2)
101
- with c1:
102
- st.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False), "papers.csv")
103
- with c2:
104
- st.download_button("PDF", _export_pdf(res["papers"]), "papers.pdf", mime="application/pdf")
105
  if st.button("πŸ’Ύ Save"):
106
- save_query(query, res)
107
  st.success("Saved to workspace")
 
 
 
 
 
 
 
 
 
 
108
  st.subheader("AI summary")
109
  st.info(res["ai_summary"])
110
 
111
- # gene tab -------------------------------------------------------
112
  with tabs[1]:
113
- if not res["genes"]:
 
 
114
  st.info("No gene hits (rate‑limited or none found).")
115
- for g in res["genes"]:
116
- st.json(g)
117
- if res["variants"]:
118
- st.markdown("### Tumour variants (cBioPortal)")
119
- for k, v in res["variants"].items():
120
- st.write(f"**{k}** – {len(v)} variants")
121
-
122
- # trials tab -----------------------------------------------------
 
 
 
 
 
 
123
  with tabs[2]:
124
  st.header("Clinical trials")
125
- if not res["clinical_trials"]:
 
126
  st.info("No trials (rate‑limited or none found).")
127
- for t in res["clinical_trials"]:
128
  st.markdown(f"**{t['nctId']}** – {t['briefTitle']}")
129
  st.write(f"Phase {t.get('phase')} | Status {t.get('status')}")
130
 
131
- # graph tab ------------------------------------------------------
132
  with tabs[3]:
133
  nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
134
- hl = st.text_input("Highlight node:")
135
  if hl:
136
  pat = re.compile(re.escape(hl), re.I)
137
  for n in nodes:
138
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
139
  agraph(nodes, edges, cfg)
140
 
141
- # metrics tab ----------------------------------------------------
142
  with tabs[4]:
 
143
  G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
144
  st.metric("Density", f"{get_density(G):.3f}")
 
145
  for nid, sc in get_top_hubs(G):
146
  lab = next((n.label for n in nodes if n.id == nid), nid)
147
  st.write(f"- {lab} {sc:.3f}")
148
 
149
- # visuals --------------------------------------------------------
150
  with tabs[5]:
151
- years = [p.get("published", "")[:4] for p in res["papers"] if p.get("published")]
152
  if years:
153
- fig = px.histogram(years, nbins=12, title="Publication Year")
154
- st.plotly_chart(fig)
155
 
156
- # follow‑up QA ---------------------------------------------------
157
  st.markdown("---")
158
- q = st.text_input("Ask follow‑up question:")
159
  if st.button("Ask AI"):
160
- with st.spinner("Querying LLM …"):
161
- ans = asyncio.run(answer_ai_question(q, context=st.session_state.last_query, llm=st.session_state.last_llm))
162
- st.write(ans["answer"])
163
-
 
 
 
 
 
164
 
165
  if __name__ == "__main__":
166
- render_ui()
 
1
+ # app.py – Streamlit front‑end for MedGenesis
2
+
3
+ """CPU‑only demo that can run on HF Spaces.
4
+ Requirements (environment variables / HF 🎙 secrets):
5
+
6
+ OPENAI_API_KEY / GEMINI_KEY – LLMs
7
+ PUB_KEY / UMLS_KEY / DISGENET_KEY ... – data APIs (optional)
8
+ MYGENE_KEY / OT_KEY / CBIO_KEY – new APIs (optional)
9
+
10
+ Run locally:
11
+ streamlit run app.py --server.headless true --server.address 0.0.0.0
12
+ """
13
+
14
+ from __future__ import annotations
15
+ import os, asyncio, re, pathlib
16
  from pathlib import Path
17
 
18
  import streamlit as st
 
21
  from fpdf import FPDF
22
  from streamlit_agraph import agraph
23
 
24
+ from mcp.orchestrator import orchestrate_search, answer_ai_question
25
+ from mcp.workspace import get_workspace, save_query
26
  from mcp.knowledge_graph import build_agraph
27
+ from mcp.graph_utils import build_nx, get_top_hubs, get_density
28
+ from mcp.alerts import check_alerts
29
 
30
# --- Streamlit telemetry dir fix ------------------------------------------
# Fill in fallbacks only for variables the host has not already defined;
# anything present in the environment is left untouched.
os.environ.update(
    {
        name: fallback
        for name, fallback in {
            "STREAMLIT_DATA_DIR": "/tmp/.streamlit",
            "XDG_STATE_HOME": "/tmp",
        }.items()
        if name not in os.environ
    }
)
# Unlike the two fallbacks above, this flag is always forced to "false".
os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
# Create whichever data directory ended up configured (no error if present).
pathlib.Path(os.environ["STREAMLIT_DATA_DIR"]).mkdir(parents=True, exist_ok=True)
 
 
35
 
36
  ROOT = Path(__file__).parent
37
  LOGO = ROOT / "assets" / "logo.png"
38
 
39
+ # --- helpers --------------------------------------------------------------
40
class _NonLatin1ToQuestionMark:
    """Lazy ``str.translate`` table replacing non-latin-1 characters with "?".

    ``str.translate`` only requires ``__getitem__``: a returned string is
    substituted for the character, and a ``LookupError`` leaves the character
    unchanged. Answering per lookup avoids materialising a dictionary with one
    entry for every code point from U+0100 to U+10FFFF each time the script
    is executed.
    """

    __slots__ = ()

    def __getitem__(self, codepoint: int) -> str:
        # Code points 0..255 are representable in latin-1 and pass through.
        if codepoint > 0xFF:
            return "?"
        raise LookupError(codepoint)


# Usage: ``text.translate(LATIN1)`` -> text that is safe to encode as latin-1.
LATIN1 = _NonLatin1ToQuestionMark()
41
 
42
def _pdf(papers: list[dict]) -> bytes:
    """Render *papers* as a simple PDF report and return the document bytes.

    Args:
        papers: Paper records; each one is read through its ``title``,
            ``authors``, ``summary`` and ``link`` keys.

    Returns:
        The encoded PDF document.

    Raises:
        KeyError: If a paper record lacks one of the keys listed above.
    """
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)
    pdf.add_page()
    pdf.set_font("Helvetica", size=11)
    # The heading previously used an en dash (U+2013) and, unlike the paper
    # text below, was not passed through LATIN1. A plain hyphen keeps the
    # heading within latin-1, which is what the built-in Helvetica font is
    # expected to support (see the FPDF core-font documentation).
    pdf.multi_cell(0, 8, "MedGenesis AI - Results", align="C")
    pdf.ln(3)

    for index, paper in enumerate(papers, 1):
        pdf.set_font("Helvetica", "B", 11)
        pdf.multi_cell(0, 7, f"{index}. {paper['title']}".translate(LATIN1))
        pdf.set_font("Helvetica", size=9)
        details = f"{paper['authors']}\n{paper['summary']}\n{paper['link']}\n"
        pdf.multi_cell(0, 6, details.translate(LATIN1))
        pdf.ln(1)

    rendered = pdf.output(dest="S")
    # Classic PyFPDF hands back a latin-1 ``str``; fpdf2 is documented to
    # return a ``bytearray``, which has no ``.encode``. Support both.
    if isinstance(rendered, str):
        return rendered.encode("latin-1", "replace")
    return bytes(rendered)
58
 
59
+ # --- sidebar --------------------------------------------------------------
60
 
61
  def _workspace_sidebar():
62
  with st.sidebar:
 
69
  with st.expander(f"{i}. {item['query']}"):
70
  st.write(item["result"]["ai_summary"])
71
 
72
+ # --- UI -------------------------------------------------------------------
73
 
74
def render_ui():
    """Draw the MedGenesis Streamlit page.

    The page consists of the workspace sidebar, a header, the search form,
    optional "new papers" alerts in the sidebar, six result tabs (Results,
    Genes, Trials, Graph, Metrics, Visuals) and a follow-up question box.
    The latest search result and follow-up answer are kept in
    ``st.session_state`` so they survive Streamlit's script re-runs.
    """
    st.set_page_config("MedGenesis AI", layout="wide")

    # Session state --------------------------------------------------------
    session_defaults = {
        "query_result": None,
        "followup_input": "",
        "followup_response": None,
        "last_query": "",
        "last_llm": "openai",
        "tab": 0,
    }
    for key, default in session_defaults.items():
        st.session_state.setdefault(key, default)

    _workspace_sidebar()

    logo_col, title_col = st.columns([0.15, 0.85])
    with logo_col:
        if LOGO.exists():
            st.image(str(LOGO), width=105)
    with title_col:
        st.markdown("## 🧬 **MedGenesis AI**")
        st.caption("Multi‑source biomedical assistant · OpenAI / Gemini")

    llm = st.radio(
        "LLM engine",
        ["openai", "gemini"],
        horizontal=True,
        index=0 if st.session_state.last_llm == "openai" else 1,
    )
    # The example is a placeholder, not the field's value: as a value it
    # would be submitted verbatim (including "e.g.") when the user presses
    # the search button without typing anything.
    query = st.text_input(
        "Enter biomedical question",
        value=st.session_state.last_query,
        placeholder="e.g. CRISPR glioblastoma therapy",
    )

    # alerts ---------------------------------------------------------------
    if work := get_workspace():
        try:
            news = asyncio.run(check_alerts([w["query"] for w in work]))
            if news:
                with st.sidebar:
                    st.subheader("🔔 New papers")
                    for saved_query, links in news.items():
                        st.write(f"**{saved_query}** – {len(links)} new")
        except Exception as exc:  # alerts are best-effort; never block the page
            st.sidebar.caption(f"Alerts unavailable ({type(exc).__name__}).")

    # run search -----------------------------------------------------------
    if st.button("Run Search 🚀"):
        if not query.strip():
            st.warning("Please enter a biomedical question first.")
        else:
            with st.spinner("Collecting literature & biomedical data …"):
                res = asyncio.run(orchestrate_search(query, llm=llm))
            # A new search invalidates any earlier follow-up exchange.
            st.session_state.update({
                "query_result": res,
                "last_query": query,
                "last_llm": llm,
                "followup_input": "",
                "followup_response": None,
            })
            st.success(f"Completed with **{res['llm_used'].title()}**")

    res = st.session_state.query_result
    if not res:
        st.info("Enter a question and press **Run Search 🚀**")
        return

    # Read every result section defensively so that a missing or empty
    # section shows as "nothing found" instead of raising KeyError.
    papers = res.get("papers") or []
    umls = res.get("umls") or []
    drug_safety = res.get("drug_safety") or []

    # --- tabs -------------------------------------------------------------
    tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])

    # Results --------------------------------------------------------------
    with tabs[0]:
        for i, p in enumerate(papers, 1):
            st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
            st.write(p["summary"])
        csv_col, pdf_col = st.columns(2)
        with csv_col:
            st.download_button("CSV", pd.DataFrame(papers).to_csv(index=False), "papers.csv", "text/csv")
        with pdf_col:
            st.download_button("PDF", _pdf(papers), "papers.pdf", "application/pdf")
        if st.button("💾 Save"):
            # Save under the query that produced ``res``, not whatever is
            # currently typed in the search box.
            save_query(st.session_state.last_query, res)
            st.success("Saved to workspace")

        st.subheader("UMLS concepts")
        for concept in umls:
            if isinstance(concept, dict) and concept.get("cui"):
                st.write(f"- **{concept.get('name', '')}** ({concept['cui']})")

        st.subheader("OpenFDA safety")
        for record in drug_safety:
            st.json(record)

        st.subheader("AI summary")
        st.info(res["ai_summary"])

    # Genes ----------------------------------------------------------------
    with tabs[1]:
        st.header("Gene / Variant signals")
        genes = res.get("genes") or []
        if not genes:
            st.info("No gene hits (rate‑limited or none found).")
        for g in genes:
            sym = g.get("symbol") or g.get("approvedSymbol") or g.get("name", "")
            summ = g.get("summary") or g.get("description", "")
            st.write(f"- **{sym}** {summ}")
        gene_disease = res.get("gene_disease") or []
        if gene_disease:
            st.markdown("### DisGeNET links")
            st.json(gene_disease[:15])
        mesh_defs = res.get("mesh_defs") or []
        if mesh_defs:
            st.markdown("### MeSH definitions")
            for definition in mesh_defs:
                if definition:
                    st.write("-", definition)

    # Trials ---------------------------------------------------------------
    with tabs[2]:
        st.header("Clinical trials")
        trials = res.get("clinical_trials") or []
        if not trials:
            st.info("No trials (rate‑limited or none found).")
        for t in trials:
            st.markdown(f"**{t.get('nctId', 'n/a')}** – {t.get('briefTitle', '')}")
            st.write(f"Phase {t.get('phase')} | Status {t.get('status')}")

    # Graph ---------------------------------------------------------------
    with tabs[3]:
        nodes, edges, cfg = build_agraph(papers, umls, drug_safety)
        hl = st.text_input("Highlight node:", key="hl")
        if hl:
            # Case-insensitive literal match; re.escape keeps user input
            # from being interpreted as a regular expression.
            pat = re.compile(re.escape(hl), re.I)
            for n in nodes:
                n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
        agraph(nodes, edges, cfg)

    # Metrics -------------------------------------------------------------
    with tabs[4]:
        # NOTE(review): the graph is rebuilt here rather than reusing the
        # Graph tab's nodes; presumably so the highlight recolouring above
        # does not leak into the metrics input — confirm before merging the
        # two calls.
        nodes, edges, _ = build_agraph(papers, umls, drug_safety)
        G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
        st.metric("Density", f"{get_density(G):.3f}")
        st.markdown("**Top hubs**")
        for nid, sc in get_top_hubs(G):
            lab = next((n.label for n in nodes if n.id == nid), nid)
            st.write(f"- {lab} {sc:.3f}")

    # Visuals -------------------------------------------------------------
    with tabs[5]:
        # Only the first four characters need to be digits: testing the
        # whole string rejects dates such as "2023-05-01", which left the
        # histogram permanently empty.
        years = []
        for p in papers:
            year = str(p.get("published") or "")[:4]
            if len(year) == 4 and year.isdecimal():
                years.append(int(year))
        if years:
            st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))

    # Follow‑up QA --------------------------------------------------------
    st.markdown("---")
    st.text_input("Ask follow‑up question:", key="followup_input", placeholder="e.g. Any phase III trials recruiting now?")
    if st.button("Ask AI"):
        q = st.session_state.followup_input.strip()
        if not q:
            st.warning("Please type a question first.")
        else:
            with st.spinner("Querying LLM …"):
                ans = asyncio.run(answer_ai_question(q, context=st.session_state.last_query, llm=st.session_state.last_llm))
            st.session_state.followup_response = ans["answer"]
    # Shown outside the button branch so the answer persists across re-runs.
    if st.session_state.followup_response:
        st.write(st.session_state.followup_response)
225
 
226
  if __name__ == "__main__":
227
+ render_ui()