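"""Streamlit front-end for MedGenesis AI.

Orchestrates multi-source biomedical search (PubMed, ArXiv, OpenFDA, UMLS,
NCBI, DisGeNET, ClinicalTrials) and renders the results, knowledge graph,
and graph metrics, with a saved-query workspace and new-paper alerts.
"""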
import asyncio
import re
from pathlib import Path

import pandas as pd
import plotly.express as px
import streamlit as st
from fpdf import FPDF
from streamlit_agraph import agraph

from mcp.orchestrator import orchestrate_search, answer_ai_question
from mcp.workspace import get_workspace, save_query
from mcp.knowledge_graph import build_agraph
from mcp.graph_metrics import build_nx, get_top_hubs, get_density
from mcp.alerts import check_alerts

ROOT = Path(__file__).parent
LOGO = ROOT / "assets" / "logo.png"


def pdf_from_papers(papers):
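    """Render the result papers as a simple PDF report and return the encoded bytes."""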
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    pdf.cell(200, 10, "MedGenesis AI - Results", ln=True, align="C")
    pdf.ln(8)
    for i, p in enumerate(papers, 1):
        # FPDF's built-in fonts are Latin-1 only; replace characters they cannot encode.
        title = f"{i}. {p['title']}".encode("latin-1", "replace").decode("latin-1")
        meta = f"{p['authors']}\n{p['summary']}\n{p['link']}\n".encode("latin-1", "replace").decode("latin-1")
        pdf.set_font("Arial", "B", 12)
        pdf.multi_cell(0, 8, title)
        pdf.set_font("Arial", "", 9)
        pdf.multi_cell(0, 6, meta)
        pdf.ln(2)
    return pdf.output(dest="S").encode("latin-1", "replace")


def sidebar_workspace():
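    """List saved workspace queries in the sidebar, each with its AI summary and a CSV download."""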
    with st.sidebar:
        st.header("🗂️ Workspace")
        ws = get_workspace()
        if not ws:
            st.info("Run a search and click **Save** to build your workspace.")
            return
        for i, item in enumerate(ws, 1):
            with st.expander(f"{i}. {item['query']}"):
                st.write("**AI Summary**:", item["result"]["ai_summary"])
                df = pd.DataFrame(item["result"]["papers"])
                st.download_button("📥 CSV", df.to_csv(index=False),
                                   f"workspace_{i}.csv", "text/csv")


def render_ui():
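    """Build the main page: alert check, workspace sidebar, search box, and result tabs."""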
    st.set_page_config(page_title="MedGenesis AI", layout="wide")
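    # Check saved workspace queries for newly published papers; alert failures are non-fatal.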
    saved_q = [q["query"] for q in get_workspace()]
    if saved_q:
        try:
            alerts = asyncio.run(check_alerts(saved_q))
            if alerts:
                with st.sidebar:
                    st.subheader("🔔 New Papers")
                    for q, links in alerts.items():
                        st.write(f"**{q}** – {len(links)} new")
        except Exception as e:
            st.sidebar.warning(f"Alert check failed: {e}")

    sidebar_workspace()

    col1, col2 = st.columns([0.15, 0.85])
    with col1:
        if LOGO.exists():
            st.image(str(LOGO), width=100)
    with col2:
        st.markdown("## 🧬 **MedGenesis AI**")
        st.caption("PubMed · ArXiv · OpenFDA · UMLS · NCBI · DisGeNET · ClinicalTrials · GPT-4o")

    st.markdown("---")
    query = st.text_input("🔍 Ask your biomedical question:",
                          placeholder="e.g. CRISPR for glioblastoma")

    if st.button("Run Search 🚀") and query:
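        # Run the async multi-source search, then lay the results out across the tabs below.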
        with st.spinner("Synthesizing multi-source biomedical intel…"):
            res = asyncio.run(orchestrate_search(query))
        st.success("Ready!")

        tabs = st.tabs([
            "Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"
        ])

        with tabs[0]:
            st.header("📚 Literature")
            for i, p in enumerate(res["papers"], 1):
                st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
                st.markdown(f"<span style='color:gray'>{p['summary']}</span>",
                            unsafe_allow_html=True)

            colA, colB = st.columns(2)
            with colA:
                if st.button("💾 Save to Workspace"):
                    save_query(query, res)
                    st.success("Saved!")
            with colB:
                st.download_button("📥 CSV", pd.DataFrame(res["papers"]).to_csv(index=False),
                                   "papers.csv", "text/csv")

            st.download_button("📄 PDF", pdf_from_papers(res["papers"]),
                               "papers.pdf", "application/pdf")

            st.subheader("🧠 UMLS")
            for c in res["umls"]:
                if c.get("cui"):
                    st.write(f"- **{c['name']}** ({c['cui']})")

            st.subheader("💊 OpenFDA Safety")
            for d in res["drug_safety"]:
                st.json(d)

            st.subheader("🤖 AI Summary")
            st.info(res["ai_summary"])

        with tabs[1]:
            st.header("🧬 Gene Signals")
            for g in res["genes"]:
                st.write(f"- **{g.get('name', g.get('geneid'))}** – {g.get('description', '')}")
            if res["gene_disease"]:
                st.markdown("### DisGeNET Links")
                st.json(res["gene_disease"][:15])
            if res["mesh_defs"]:
                st.markdown("### MeSH Definitions")
                for d in res["mesh_defs"]:
                    st.write("-", d)

        with tabs[2]:
            st.header("💊 Clinical Trials")
            if not res["clinical_trials"]:
                st.info("No trials retrieved (rate-limited or none found).")
            for t in res["clinical_trials"]:
                st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
                st.write(f"Phase: {t.get('Phase', [''])[0]} | Status: {t['OverallStatus'][0]}")

        with tabs[3]:
            st.header("🗺️ Knowledge Graph")
            nodes, edges, cfg = build_agraph(res["papers"],
                                             res["umls"],
                                             res["drug_safety"])
            hl = st.text_input("Highlight node:", key="hl")
            if hl:
                pat = re.compile(re.escape(hl), re.I)
                for n in nodes:
                    if pat.search(n.label):
                        n.color, n.size = "#f1c40f", 30
                    else:
                        n.color = "#d3d3d3"
            agraph(nodes=nodes, edges=edges, config=cfg)

        with tabs[4]:
            st.header("📈 Graph Metrics")
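            # Reuse the nodes/edges built in the Graph tab; Streamlit runs every tab block in the same script pass.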
            G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
            st.metric("Density", f"{get_density(G):.3f}")
            st.markdown("#### Hub Nodes")
            for nid, sc in get_top_hubs(G):
                lab = next((n.label for n in nodes if n.id == nid), nid)
                st.write(f"- **{lab}** – {sc:.3f}")
        with tabs[5]:
            st.header("📊 Visuals")
            years = [p["published"] for p in res["papers"] if p.get("published")]
            if years:
                st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))

        st.markdown("---")
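        # Free-form follow-up question, answered with the original query as context.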
        follow = st.text_input("Ask a follow-up question:")
        if st.button("Ask AI"):
            st.write(asyncio.run(answer_ai_question(follow, context=query))["answer"])

    else:
        st.info("Enter a question and press **Run Search 🚀**")


if __name__ == "__main__":
    render_ui()