mgbam committed on
Commit
73fc0d7
·
verified ·
1 Parent(s): d7bf01e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -100
app.py CHANGED
@@ -14,20 +14,25 @@ from mcp.knowledge_graph import build_agraph
14
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
15
  from mcp.protocols import draft_protocol
16
 
17
- # Streamlit config
18
  st.set_page_config(page_title="MedGenesis AI", layout="wide")
 
 
19
  if "res" not in st.session_state:
20
  st.session_state.res = None
21
 
 
22
  st.title("🧬 MedGenesis AI")
23
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
24
  query = st.text_input("Enter biomedical question")
25
 
26
- # PDF generator
27
  def _make_pdf(papers):
28
  pdf = FPDF()
29
- pdf.add_page(); pdf.set_font("Helvetica", size=12)
30
- pdf.cell(0, 10, "MedGenesis AI – Results", ln=True, align="C"); pdf.ln(5)
 
 
31
  for i, p in enumerate(papers, 1):
32
  pdf.set_font("Helvetica", "B", 11)
33
  pdf.multi_cell(0, 7, f"{i}. {p.get('title','')}")
@@ -35,105 +40,101 @@ def _make_pdf(papers):
35
  body = f"{p.get('authors','')}
36
  {p.get('summary','')}
37
  {p.get('link','')}"
38
- pdf.multi_cell(0, 6, body); pdf.ln(3)
 
39
  return pdf.output(dest="S").encode("latin-1", errors="replace")
40
 
41
- # Run search
42
- enabled = st.button("Run Search 🚀") and query.strip()
43
- if enabled:
44
  with st.spinner("Gathering data…"):
45
  st.session_state.res = asyncio.run(orchestrate_search(query, llm))
 
 
46
  res = st.session_state.res
 
 
47
  if not res:
48
- st.info("Enter a query and press Run Search")
49
- st.stop()
50
-
51
- # Tabs
52
- tabs = st.tabs([
53
- "Results", "Graph", "Clusters", "Variants",
54
- "Trials", "Metrics", "Visuals", "Protocols"
55
- ])
56
-
57
- # Results
58
- title_tab, graph_tab, clust_tab, var_tab, trial_tab, met_tab, vis_tab, proto_tab = tabs
59
-
60
- with title_tab:
61
- for i, p in enumerate(res["papers"], 1):
62
- st.markdown(f"**{i}. [{p['title']}]({p['link']})**")
63
- st.write(p["summary"])
64
- c1, c2 = st.columns(2)
65
- c1.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False),
66
- "papers.csv", "text/csv")
67
- c2.download_button("PDF", _make_pdf(res["papers"]),
68
- "papers.pdf", "application/pdf")
69
- st.subheader("AI summary"); st.info(res["ai_summary"])
70
-
71
- # Graph
72
- with graph_tab:
73
- nodes, edges, cfg = build_agraph(
74
- res["papers"], res["umls"], res["drug_safety"], res["umls_relations"]
75
- )
76
- hl = st.text_input("Highlight node:", key="hl")
77
- if hl:
78
- pat = re.compile(re.escape(hl), re.I)
79
- for n in nodes:
80
- n.color = "#f1c40f" if pat.search(n.label) else n.color
81
- agraph(nodes, edges, cfg)
82
-
83
- # Clusters
84
- with clust_tab:
85
- clusters = res.get("clusters", [])
86
- if clusters:
87
- df = pd.DataFrame({
88
- "title": [p['title'] for p in res['papers']],
89
- "cluster": clusters
90
- })
91
- st.write("### Paper Clusters")
92
- for c in sorted(set(clusters)):
93
- st.write(f"**Cluster {c}**")
94
- for t in df[df['cluster']==c]['title']:
95
- st.write(f"- {t}")
96
- else:
97
- st.info("No clusters to show.")
98
-
99
- # Variants
100
- with var_tab:
101
- if res.get("variants"):
102
- st.json(res["variants"])
103
- else:
104
- st.warning("No variants found. Try 'TP53' or 'BRCA1'.")
105
-
106
- # Trials
107
- with trial_tab:
108
- if res.get("clinical_trials"):
109
- st.json(res["clinical_trials"])
110
- else:
111
- st.warning("No trials found. Try a disease or drug.")
112
-
113
- # Metrics
114
- with met_tab:
115
- G = build_nx(
116
- [n.__dict__ for n in nodes], [e.__dict__ for e in edges]
117
- )
118
- st.metric("Density", f"{get_density(G):.3f}")
119
- st.markdown("**Top hubs**")
120
- for nid, score in get_top_hubs(G):
121
- lbl = next((n.label for n in nodes if n.id==nid), nid)
122
- st.write(f"- {lbl}: {score:.3f}")
123
-
124
- # Visuals
125
- with vis_tab:
126
- years = [p.get("published","")[:4] for p in res["papers"] if p.get("published")]
127
- if years:
128
- st.plotly_chart(px.histogram(years, nbins=10, title="Publication Year"))
129
-
130
- # Protocols
131
- with proto_tab:
132
- hyp = st.text_input("Enter hypothesis for protocol:", key="proto_q")
133
- if st.button("Draft Protocol") and hyp.strip():
134
- with st.spinner("Generating protocol…"):
135
- doc = asyncio.run(draft_protocol(
136
- hyp, context=res["ai_summary"], llm=llm
137
- ))
138
- st.subheader("Experimental Protocol")
139
- st.write(doc)
 
14
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
15
  from mcp.protocols import draft_protocol
16
 
17
+ # Streamlit configuration
18
  st.set_page_config(page_title="MedGenesis AI", layout="wide")
19
+
20
+ # Initialize session state
21
  if "res" not in st.session_state:
22
  st.session_state.res = None
23
 
24
+ # Header UI
25
  st.title("🧬 MedGenesis AI")
26
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
27
  query = st.text_input("Enter biomedical question")
28
 
29
+ # PDF generation helper
30
  def _make_pdf(papers):
31
  pdf = FPDF()
32
+ pdf.add_page()
33
+ pdf.set_font("Helvetica", size=12)
34
+ pdf.cell(0, 10, "MedGenesis AI – Results", ln=True, align="C")
35
+ pdf.ln(5)
36
  for i, p in enumerate(papers, 1):
37
  pdf.set_font("Helvetica", "B", 11)
38
  pdf.multi_cell(0, 7, f"{i}. {p.get('title','')}")
 
40
  body = f"{p.get('authors','')}
41
  {p.get('summary','')}
42
  {p.get('link','')}"
43
+ pdf.multi_cell(0, 6, body)
44
+ pdf.ln(3)
45
  return pdf.output(dest="S").encode("latin-1", errors="replace")
46
 
47
+ # Trigger search
48
+ if st.button("Run Search 🚀") and query.strip():
 
49
  with st.spinner("Gathering data…"):
50
  st.session_state.res = asyncio.run(orchestrate_search(query, llm))
51
+
52
+ # Retrieve results
53
  res = st.session_state.res
54
+
55
+ # If no results yet, prompt user
56
  if not res:
57
+ st.info("Enter a question and press **Run Search 🚀** to begin.")
58
+ else:
59
+ # Create tabs
60
+ tabs = st.tabs(["Results", "Graph", "Clusters", "Variants", "Trials", "Metrics", "Visuals", "Protocols"])
61
+ title_tab, graph_tab, clust_tab, var_tab, trial_tab, met_tab, vis_tab, proto_tab = tabs
62
+
63
+ # Results Tab
64
+ with title_tab:
65
+ for i, p in enumerate(res["papers"], 1):
66
+ st.markdown(f"**{i}. [{p['title']}]({p['link']})**")
67
+ st.write(p["summary"])
68
+ c1, c2 = st.columns(2)
69
+ c1.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False),
70
+ "papers.csv", "text/csv")
71
+ c2.download_button("PDF", _make_pdf(res["papers"]),
72
+ "papers.pdf", "application/pdf")
73
+ st.subheader("AI summary")
74
+ st.info(res["ai_summary"])
75
+
76
+ # Graph Tab
77
+ with graph_tab:
78
+ nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res.get("drug_safety", []), res.get("umls_relations", []))
79
+ hl = st.text_input("Highlight node:", key="hl")
80
+ if hl:
81
+ pat = re.compile(re.escape(hl), re.I)
82
+ for n in nodes:
83
+ n.color = "#f1c40f" if pat.search(n.label) else n.color
84
+ agraph(nodes, edges, cfg)
85
+
86
+ # Clusters Tab
87
+ with clust_tab:
88
+ clusters = res.get("clusters", [])
89
+ if clusters:
90
+ df = pd.DataFrame({
91
+ "title": [p['title'] for p in res['papers']],
92
+ "cluster": clusters
93
+ })
94
+ st.write("### Paper Clusters")
95
+ for c in sorted(set(clusters)):
96
+ st.write(f"**Cluster {c}**")
97
+ for t in df[df['cluster'] == c]['title']:
98
+ st.write(f"- {t}")
99
+ else:
100
+ st.info("No clusters to show.")
101
+
102
+ # Variants Tab
103
+ with var_tab:
104
+ variants = res.get("variants", [])
105
+ if variants:
106
+ st.json(variants)
107
+ else:
108
+ st.warning("No variants found. Try a well-known gene like 'TP53'.")
109
+
110
+ # Trials Tab
111
+ with trial_tab:
112
+ trials = res.get("clinical_trials", [])
113
+ if trials:
114
+ st.json(trials)
115
+ else:
116
+ st.warning("No trials found. Try a disease name or specific drug.")
117
+
118
+ # Metrics Tab
119
+ with met_tab:
120
+ G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
121
+ st.metric("Density", f"{get_density(G):.3f}")
122
+ st.markdown("**Top hubs**")
123
+ for nid, sc in get_top_hubs(G):
124
+ label = next((n.label for n in nodes if n.id == nid), nid)
125
+ st.write(f"- {label}: {sc:.3f}")
126
+
127
+ # Visuals Tab
128
+ with vis_tab:
129
+ years = [p.get("published", "")[:4] for p in res["papers"] if p.get("published")]
130
+ if years:
131
+ st.plotly_chart(px.histogram(years, nbins=10, title="Publication Year"))
132
+
133
+ # Protocols Tab
134
+ with proto_tab:
135
+ hyp = st.text_input("Enter hypothesis for protocol:", key="proto_q")
136
+ if st.button("Draft Protocol") and hyp.strip():
137
+ with st.spinner("Generating protocol…"):
138
+ doc = asyncio.run(draft_protocol(hyp, context=res["ai_summary"], llm=llm))
139
+ st.subheader("Experimental Protocol")
140
+ st.write(doc)