#!/usr/bin/env python3
# ──────────────────────────────────────────────────────────────────────
# MedGenesis AI – Streamlit UI (OpenAI + Gemini, CPU-only)
# ──────────────────────────────────────────────────────────────────────
import os
import pathlib
import asyncio
import re
from pathlib import Path
from datetime import datetime

import streamlit as st
import pandas as pd
import plotly.express as px
from fpdf import FPDF
from streamlit_agraph import agraph

# ── internal helpers --------------------------------------------------
from mcp.orchestrator import orchestrate_search, answer_ai_question
from mcp.workspace import get_workspace, save_query
from mcp.knowledge_graph import build_agraph
from mcp.graph_metrics import build_nx, get_top_hubs, get_density
from mcp.alerts import check_alerts

# ── Streamlit telemetry dir fix (HF Spaces sandbox quirks) ------------
# The HF Spaces sandbox is read-only outside /tmp, so redirect every
# directory Streamlit wants to write to before it initialises.
os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
os.environ["XDG_STATE_HOME"] = "/tmp"
os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)

ROOT = Path(__file__).parent
LOGO = ROOT / "assets" / "logo.png"


# ══════════════════════════════════════════════════════════════════════
# Small util helpers
# ══════════════════════════════════════════════════════════════════════
def _latin1_safe(txt: str) -> str:
    """Replace non-Latin-1 chars – keeps FPDF happy."""
    return txt.encode("latin-1", "replace").decode("latin-1")


def _pdf(papers: list[dict]) -> bytes:
    """Render *papers* into a simple one-column PDF and return its bytes.

    Each paper dict is expected to carry ``title``, ``authors``,
    ``summary`` and ``link`` keys (the shape produced by the
    orchestrator).  All text is squeezed through :func:`_latin1_safe`
    because classic FPDF only supports the Latin-1 charset.
    """
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)
    pdf.add_page()
    pdf.set_font("Helvetica", size=11)
    pdf.cell(200, 8, _latin1_safe("MedGenesis AI – Literature results"),
             ln=True, align="C")
    pdf.ln(3)
    for i, p in enumerate(papers, 1):
        pdf.set_font("Helvetica", "B", 11)
        pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p['title']}"))
        pdf.set_font("Helvetica", "", 9)
        body = (
            f"{p['authors']}\n"
            f"{p['summary']}\n"
            f"{p['link']}\n"
        )
        pdf.multi_cell(0, 6, _latin1_safe(body))
        pdf.ln(1)
    # FPDF already returns latin-1 bytes – no extra encode needed
    return pdf.output(dest="S").encode("latin-1", "replace")


def _workspace_sidebar() -> None:
    """Render the saved-query workspace in the sidebar (read-only list)."""
    with st.sidebar:
        st.header("🗂 Workspace")
        ws = get_workspace()
        if not ws:
            st.info("Run a search then press **Save** to populate this list.")
            return
        for i, item in enumerate(ws, 1):
            with st.expander(f"{i}. {item['query']}"):
                st.write(item["result"]["ai_summary"])


# ══════════════════════════════════════════════════════════════════════
# Main Streamlit UI
# ══════════════════════════════════════════════════════════════════════
def render_ui() -> None:
    """Build the whole MedGenesis AI page for the current Streamlit run.

    Layout: sidebar workspace + alerts, a search bar, six result tabs
    (Results / Genes / Trials / Graph / Metrics / Visuals) and a
    follow-up Q-A box at the bottom.  All cross-run state lives in
    ``st.session_state``.
    """
    st.set_page_config("MedGenesis AI", layout="wide")

    # ── Session-state defaults ────────────────────────────────────────
    for k, v in {
        "query_result":      None,
        "followup_input":    "",
        "followup_response": None,
        "last_query":        "",
        "last_llm":          "",
    }.items():
        st.session_state.setdefault(k, v)

    _workspace_sidebar()

    col_logo, col_title = st.columns([0.15, 0.85])
    with col_logo:
        if LOGO.exists():
            st.image(LOGO, width=110)
    with col_title:
        st.markdown("## 🧬 **MedGenesis AI**")
        st.caption("Multi-source biomedical assistant · OpenAI / Gemini")

    llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
    query = st.text_input("Enter biomedical question",
                          placeholder="e.g. CRISPR glioblastoma therapy")

    # ── alert notifications (async) ───────────────────────────────────
    saved_qs = [w["query"] for w in get_workspace()]
    if saved_qs:
        try:
            news = asyncio.run(check_alerts(saved_qs))
            if news:
                with st.sidebar:
                    st.subheader("🔔 New papers")
                    for q, lnks in news.items():
                        st.write(f"**{q}** – {len(lnks)} new")
        except Exception:
            # Best-effort feature: a failed alert poll (network hiccup,
            # rate limit) must never break the page – stay silent.
            pass

    # ── Run Search ----------------------------------------------------
    if st.button("Run Search 🚀") and query.strip():
        with st.spinner("Collecting literature & biomedical data …"):
            res = asyncio.run(orchestrate_search(query, llm=llm))
        # Store everything needed for later reruns (tabs, follow-up Q-A).
        # Resetting followup_* here is safe: the follow-up widget has not
        # been instantiated yet in this script run.
        st.session_state.update(
            query_result=res,
            last_query=query,
            last_llm=llm,
            followup_input="",
            followup_response=None,
        )
        st.success(f"Completed with **{res['llm_used'].title()}**")

    res = st.session_state.query_result
    if not res:
        st.info("Enter a biomedical question and press **Run Search 🚀**")
        return

    # ── Tabs ----------------------------------------------------------
    tabs = st.tabs(["Results", "Genes", "Trials",
                    "Graph", "Metrics", "Visuals"])

    # 1) Results -------------------------------------------------------
    with tabs[0]:
        for i, p in enumerate(res["papers"], 1):
            st.markdown(
                f"**{i}. [{p['title']}]({p['link']})** "
                f"*{p['authors']}*"
            )
            st.write(p["summary"])

        c_csv, c_pdf = st.columns(2)
        with c_csv:
            st.download_button(
                "CSV",
                pd.DataFrame(res["papers"]).to_csv(index=False),
                "papers.csv",
                "text/csv",
            )
        with c_pdf:
            st.download_button("PDF", _pdf(res["papers"]),
                               "papers.pdf", "application/pdf")

        if st.button("💾 Save"):
            save_query(st.session_state.last_query, res)
            st.success("Saved to workspace")

        st.subheader("UMLS concepts")
        for c in (res["umls"] or []):
            if isinstance(c, dict) and c.get("cui"):
                st.write(f"- **{c['name']}** ({c['cui']})")

        st.subheader("OpenFDA safety signals")
        for d in (res["drug_safety"] or []):
            st.json(d)

        st.subheader("AI summary")
        st.info(res["ai_summary"])

    # 2) Genes ---------------------------------------------------------
    with tabs[1]:
        st.header("Gene / Variant signals")
        # Upstream sources may return junk entries on rate-limit; keep
        # only dicts that actually carry an identifier.
        genes_list = [
            g for g in res["genes"]
            if isinstance(g, dict) and (g.get("symbol") or g.get("name"))
        ]
        if not genes_list:
            st.info("No gene hits (rate-limited or none found).")
        for g in genes_list:
            st.write(f"- **{g.get('symbol') or g.get('name')}** "
                     f"{g.get('description','')}")

        if res["gene_disease"]:
            st.markdown("### DisGeNET associations")
            ok = [d for d in res["gene_disease"] if isinstance(d, dict)]
            if ok:
                st.json(ok[:15])

        defs = [d for d in res["mesh_defs"] if isinstance(d, str) and d]
        if defs:
            st.markdown("### MeSH definitions")
            for d in defs:
                st.write("-", d)

    # 3) Trials --------------------------------------------------------
    with tabs[2]:
        st.header("Clinical trials")
        ct = res["clinical_trials"]
        if not ct:
            st.info("No trials (rate-limited or none found).")
        for t in ct:
            # ClinicalTrials.gov fields arrive as single-element lists.
            nct = t.get("NCTId", [""])[0]
            bttl = t.get("BriefTitle", [""])[0]
            phase = t.get("Phase", [""])[0]
            stat = t.get("OverallStatus", [""])[0]
            st.markdown(f"**{nct}** – {bttl}")
            st.write(f"Phase {phase} | Status {stat}")

    # 4) Graph ---------------------------------------------------------
    with tabs[3]:
        nodes, edges, cfg = build_agraph(
            res["papers"], res["umls"], res["drug_safety"]
        )
        hl = st.text_input("Highlight node:", key="hl")
        if hl:
            pat = re.compile(re.escape(hl), re.I)
            for n in nodes:
                n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
        agraph(nodes, edges, cfg)

    # 5) Metrics -------------------------------------------------------
    with tabs[4]:
        # Reuses the nodes/edges built in the Graph tab above.
        G = build_nx(
            [n.__dict__ for n in nodes],
            [e.__dict__ for e in edges],
        )
        st.metric("Density", f"{get_density(G):.3f}")
        st.markdown("**Top hubs**")
        for nid, sc in get_top_hubs(G, k=5):
            label = next((n.label for n in nodes if n.id == nid), nid)
            st.write(f"- {label} {sc:.3f}")

    # 6) Visuals -------------------------------------------------------
    with tabs[5]:
        years = [
            p["published"][:4] for p in res["papers"]
            if p.get("published") and len(p["published"]) >= 4
        ]
        if years:
            st.plotly_chart(
                px.histogram(
                    years,
                    nbins=min(15, len(set(years))),
                    title="Publication Year",
                )
            )

    # ── Follow-up Q-A --------------------------------------------------
    st.markdown("---")
    st.text_input("Ask follow-up question:",
                  key="followup_input",
                  placeholder="e.g. Any Phase III trials recruiting now?")

    def _on_ask():
        """Button callback: query the LLM with the follow-up question."""
        q = st.session_state.followup_input.strip()
        if not q:
            st.warning("Please type a question first.")
            return
        with st.spinner("Querying LLM …"):
            ans = asyncio.run(
                answer_ai_question(
                    q,
                    context=st.session_state.last_query,
                    llm=st.session_state.last_llm)
            )
        st.session_state.followup_response = (
            ans.get("answer") or "LLM unavailable or quota exceeded."
        )

    st.button("Ask AI", on_click=_on_ask)
    if st.session_state.followup_response:
        st.write(st.session_state.followup_response)


# ── entry-point ───────────────────────────────────────────────────────
if __name__ == "__main__":
    render_ui()