mgbam commited on
Commit
2b2ae99
·
verified ·
1 Parent(s): d55cbab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -33
app.py CHANGED
@@ -2,15 +2,9 @@
2
  # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
3
 
4
  import os, pathlib
5
-
6
- # ── Streamlit telemetry dir fix ───────────────────────────────────────
7
- os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
8
- os.environ["XDG_STATE_HOME"] = "/tmp"
9
- os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
10
- pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
11
-
12
  import asyncio, re
13
  from pathlib import Path
 
14
  import streamlit as st
15
  import pandas as pd
16
  import plotly.express as px
@@ -18,10 +12,16 @@ from fpdf import FPDF
18
  from streamlit_agraph import agraph
19
 
20
  from mcp.orchestrator import orchestrate_search, answer_ai_question
21
- from mcp.workspace import get_workspace, save_query
22
  from mcp.knowledge_graph import build_agraph
23
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
24
- from mcp.alerts import check_alerts
 
 
 
 
 
 
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
@@ -59,9 +59,17 @@ def _workspace_sidebar():
59
  def render_ui():
60
  st.set_page_config("MedGenesis AI", layout="wide")
61
 
62
- # Persist follow-up input
 
 
63
  if "followup_input" not in st.session_state:
64
  st.session_state.followup_input = ""
 
 
 
 
 
 
65
 
66
  _workspace_sidebar()
67
 
@@ -76,6 +84,7 @@ def render_ui():
76
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
77
  query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
78
 
 
79
  if get_workspace():
80
  try:
81
  news = asyncio.run(check_alerts([w["query"] for w in get_workspace()]))
@@ -87,20 +96,23 @@ def render_ui():
87
  except Exception:
88
  pass
89
 
90
- # Trigger search
91
  if st.button("Run Search 🚀") and query:
92
  with st.spinner("Collecting literature & biomedical data …"):
93
  res = asyncio.run(orchestrate_search(query, llm=llm))
94
  st.success(f"Completed with **{res['llm_used'].title()}**")
95
  st.session_state.query_result = res
 
 
96
  st.session_state.followup_input = ""
97
- else:
98
- res = st.session_state.get("query_result", None)
 
99
 
100
  if res:
101
  tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
102
 
103
- with tabs[0]: # Results
104
  for i, p in enumerate(res["papers"], 1):
105
  st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
106
  st.write(p["summary"])
@@ -111,22 +123,25 @@ def render_ui():
111
  with col2:
112
  st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
113
  if st.button("💾 Save"):
114
- save_query(query, res)
115
  st.success("Saved to workspace")
 
116
  st.subheader("UMLS concepts")
117
  for c in res["umls"]:
118
  if c.get("cui"):
119
  st.write(f"- **{c['name']}** ({c['cui']})")
 
120
  st.subheader("OpenFDA safety")
121
  for d in res["drug_safety"]:
122
  st.json(d)
 
123
  st.subheader("AI summary")
124
  st.info(res["ai_summary"])
125
 
126
- with tabs[1]: # Genes
127
  st.header("Gene / Variant signals")
128
  for g in res["genes"]:
129
- st.write(f"- **{g.get('name', g.get('geneid'))}** {g.get('description', '')}")
130
  if res["gene_disease"]:
131
  st.markdown("### DisGeNET links")
132
  st.json(res["gene_disease"][:15])
@@ -136,7 +151,7 @@ def render_ui():
136
  if d:
137
  st.write("-", d)
138
 
139
- with tabs[2]: # Trials
140
  st.header("Clinical trials")
141
  if not res["clinical_trials"]:
142
  st.info("No trials (rate-limited or none found).")
@@ -144,7 +159,7 @@ def render_ui():
144
  st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
145
  st.write(f"Phase {t.get('Phase',[''])[0]} | Status {t['OverallStatus'][0]}")
146
 
147
- with tabs[3]: # Graph
148
  nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
149
  hl = st.text_input("Highlight node:", key="hl")
150
  if hl:
@@ -153,7 +168,7 @@ def render_ui():
153
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
154
  agraph(nodes, edges, cfg)
155
 
156
- with tabs[4]: # Metrics
157
  nodes, edges, _ = build_agraph(res["papers"], res["umls"], res["drug_safety"])
158
  G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
159
  st.metric("Density", f"{get_density(G):.3f}")
@@ -162,26 +177,30 @@ def render_ui():
162
  lab = next((n.label for n in nodes if n.id == nid), nid)
163
  st.write(f"- {lab} {sc:.3f}")
164
 
165
- with tabs[5]: # Visuals
166
  years = [p["published"] for p in res["papers"] if p.get("published")]
167
  if years:
168
  st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
169
 
 
170
  st.markdown("---")
171
- follow = st.text_input(
172
- "Ask follow‑up question:",
173
- value=st.session_state.followup_input,
174
- key="followup_input"
175
- )
176
- if st.button("Ask AI"):
177
- st.session_state.followup_input = follow
178
  if follow.strip():
179
- with st.spinner("Generating AI response..."):
180
- ans = asyncio.run(answer_ai_question(
181
- follow, context=query, llm=llm))
182
- st.write(ans["answer"])
 
183
  else:
184
- st.warning("Please type a follow-up question before submitting.")
 
 
 
 
 
 
185
  else:
186
  st.info("Enter a question and press **Run Search 🚀**")
187
 
 
2
  # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
3
 
4
  import os, pathlib
 
 
 
 
 
 
 
5
  import asyncio, re
6
  from pathlib import Path
7
+
8
  import streamlit as st
9
  import pandas as pd
10
  import plotly.express as px
 
12
  from streamlit_agraph import agraph
13
 
14
  from mcp.orchestrator import orchestrate_search, answer_ai_question
15
+ from mcp.workspace import get_workspace, save_query
16
  from mcp.knowledge_graph import build_agraph
17
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
18
+ from mcp.alerts import check_alerts
19
+
20
+ # ── Streamlit telemetry dir fix ───────────────────────────────────────
21
+ os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
22
+ os.environ["XDG_STATE_HOME"] = "/tmp"
23
+ os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
24
+ pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
 
59
  def render_ui():
60
  st.set_page_config("MedGenesis AI", layout="wide")
61
 
62
+ # Initialize session state keys if missing
63
+ if "query_result" not in st.session_state:
64
+ st.session_state.query_result = None
65
  if "followup_input" not in st.session_state:
66
  st.session_state.followup_input = ""
67
+ if "followup_response" not in st.session_state:
68
+ st.session_state.followup_response = None
69
+ if "last_query" not in st.session_state:
70
+ st.session_state.last_query = ""
71
+ if "last_llm" not in st.session_state:
72
+ st.session_state.last_llm = ""
73
 
74
  _workspace_sidebar()
75
 
 
84
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
85
  query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
86
 
87
+ # Alerts
88
  if get_workspace():
89
  try:
90
  news = asyncio.run(check_alerts([w["query"] for w in get_workspace()]))
 
96
  except Exception:
97
  pass
98
 
99
+ # Run search button
100
  if st.button("Run Search 🚀") and query:
101
  with st.spinner("Collecting literature & biomedical data …"):
102
  res = asyncio.run(orchestrate_search(query, llm=llm))
103
  st.success(f"Completed with **{res['llm_used'].title()}**")
104
  st.session_state.query_result = res
105
+ st.session_state.last_query = query
106
+ st.session_state.last_llm = llm
107
  st.session_state.followup_input = ""
108
+ st.session_state.followup_response = None
109
+
110
+ res = st.session_state.query_result
111
 
112
  if res:
113
  tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
114
 
115
+ with tabs[0]:
116
  for i, p in enumerate(res["papers"], 1):
117
  st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
118
  st.write(p["summary"])
 
123
  with col2:
124
  st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
125
  if st.button("💾 Save"):
126
+ save_query(st.session_state.last_query, res)
127
  st.success("Saved to workspace")
128
+
129
  st.subheader("UMLS concepts")
130
  for c in res["umls"]:
131
  if c.get("cui"):
132
  st.write(f"- **{c['name']}** ({c['cui']})")
133
+
134
  st.subheader("OpenFDA safety")
135
  for d in res["drug_safety"]:
136
  st.json(d)
137
+
138
  st.subheader("AI summary")
139
  st.info(res["ai_summary"])
140
 
141
+ with tabs[1]:
142
  st.header("Gene / Variant signals")
143
  for g in res["genes"]:
144
+ st.write(f"- **{g.get('name', g.get('geneid'))}** {g.get('description','')}")
145
  if res["gene_disease"]:
146
  st.markdown("### DisGeNET links")
147
  st.json(res["gene_disease"][:15])
 
151
  if d:
152
  st.write("-", d)
153
 
154
+ with tabs[2]:
155
  st.header("Clinical trials")
156
  if not res["clinical_trials"]:
157
  st.info("No trials (rate-limited or none found).")
 
159
  st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
160
  st.write(f"Phase {t.get('Phase',[''])[0]} | Status {t['OverallStatus'][0]}")
161
 
162
+ with tabs[3]:
163
  nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
164
  hl = st.text_input("Highlight node:", key="hl")
165
  if hl:
 
168
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
169
  agraph(nodes, edges, cfg)
170
 
171
+ with tabs[4]:
172
  nodes, edges, _ = build_agraph(res["papers"], res["umls"], res["drug_safety"])
173
  G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
174
  st.metric("Density", f"{get_density(G):.3f}")
 
177
  lab = next((n.label for n in nodes if n.id == nid), nid)
178
  st.write(f"- {lab} {sc:.3f}")
179
 
180
+ with tabs[5]:
181
  years = [p["published"] for p in res["papers"] if p.get("published")]
182
  if years:
183
  st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
184
 
185
+ # Follow-up Q&A block with callback
186
  st.markdown("---")
187
+ st.text_input("Ask follow‑up question:", key="followup_input")
188
+ def handle_followup():
189
+ follow = st.session_state.followup_input
 
 
 
 
190
  if follow.strip():
191
+ ans = asyncio.run(answer_ai_question(
192
+ follow,
193
+ context=st.session_state.last_query,
194
+ llm=st.session_state.last_llm))
195
+ st.session_state.followup_response = ans["answer"]
196
  else:
197
+ st.session_state.followup_response = None
198
+
199
+ st.button("Ask AI", on_click=handle_followup)
200
+
201
+ if st.session_state.followup_response:
202
+ st.write(st.session_state.followup_response)
203
+
204
  else:
205
  st.info("Enter a question and press **Run Search 🚀**")
206