Update mcp/orchestrator.py
Browse files- mcp/orchestrator.py +157 -104
mcp/orchestrator.py
CHANGED
@@ -1,127 +1,180 @@
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
-
MedGenesis – dual-LLM orchestrator
|
3 |
-
|
4 |
-
• Accepts llm
|
5 |
-
•
|
|
|
|
|
6 |
"""
|
|
|
7 |
from __future__ import annotations
|
8 |
import asyncio, itertools, logging
|
9 |
-
from typing import Dict, Any, List
|
10 |
-
|
11 |
-
from mcp.arxiv
|
12 |
-
from mcp.pubmed
|
13 |
-
from mcp.ncbi
|
14 |
-
from mcp.mygene
|
15 |
-
from mcp.ensembl
|
16 |
-
from mcp.opentargets
|
17 |
-
from mcp.umls
|
18 |
-
from mcp.openfda
|
19 |
-
from mcp.disgenet
|
20 |
-
from mcp.clinicaltrials
|
21 |
-
from mcp.cbio
|
22 |
-
from mcp.openai_utils
|
23 |
-
from mcp.gemini
|
24 |
|
25 |
log = logging.getLogger(__name__)
|
26 |
-
|
27 |
|
28 |
|
29 |
-
|
30 |
-
|
31 |
if engine.lower() == "gemini":
|
32 |
return gemini_summarize, gemini_qa, "gemini"
|
33 |
return ai_summarize, ai_qa, "openai"
|
34 |
|
35 |
-
async def _gather_safely(*aws, as_list: bool = True):
    """Await *aws* concurrently without letting one failure abort the rest.

    Args:
        *aws: Awaitables handed to ``asyncio.gather``.
        as_list: When true (the default) failed awaitables are dropped, so
            the returned list may be shorter than ``aws`` and positions no
            longer line up with the inputs. When false the raw gather output
            is returned, with exception instances left in place.

    Returns:
        A list of results as described above.
    """
    out = await asyncio.gather(*aws, return_exceptions=True)
    if as_list:
        # asyncio.CancelledError derives from BaseException (Python 3.8+), so
        # test against BaseException to keep cancelled awaitables out too.
        return [x for x in out if not isinstance(x, BaseException)]
    return out
|
42 |
|
43 |
-
async def
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
return {
|
58 |
-
"
|
59 |
-
"
|
60 |
-
"
|
61 |
-
"
|
62 |
-
"
|
|
|
|
|
|
|
|
|
|
|
63 |
}
|
64 |
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
umls, fda, gene_dat, trials, variants = await asyncio.gather(
|
87 |
-
_gather_safely(*umls_f),
|
88 |
-
_gather_safely(*fda_f),
|
89 |
-
gene_bundle,
|
90 |
-
trials_task,
|
91 |
-
cbio_task,
|
92 |
-
)
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
97 |
|
98 |
return {
|
99 |
-
"
|
100 |
-
"
|
101 |
-
"
|
102 |
-
"
|
103 |
-
"
|
104 |
-
"genes" : gene_dat["ncbi"] + gene_dat["ensembl"] + gene_dat["mygene"],
|
105 |
-
"mesh_defs" : gene_dat["mesh"],
|
106 |
-
"gene_disease" : gene_dat["ot_assoc"],
|
107 |
-
"clinical_trials" : trials,
|
108 |
-
"variants" : variants or [],
|
109 |
}
|
110 |
|
111 |
-
|
112 |
-
async def answer_ai_question(question: str,
|
113 |
-
"""
|
|
|
|
|
114 |
_, qa_fn, _ = _llm_router(llm)
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
""
|
121 |
-
async def _wrapper():
|
122 |
-
try:
|
123 |
-
return await fn(*args)
|
124 |
-
except Exception as exc:
|
125 |
-
log.warning("background task %s failed: %s", fn.__name__, exc)
|
126 |
-
return RuntimeError(str(exc))
|
127 |
-
return asyncio.create_task(_wrapper())
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# mcp/orchestrator.py
|
3 |
+
|
4 |
"""
|
5 |
+
MedGenesis – dual-LLM orchestrator (v4)
|
6 |
+
---------------------------------------
|
7 |
+
• Accepts llm="openai" | "gemini" (defaults to OpenAI)
|
8 |
+
• Safely runs all data-source calls in parallel
|
9 |
+
• Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
|
10 |
+
• Returns one dict that the Streamlit UI can rely on
|
11 |
"""
|
12 |
+
|
13 |
from __future__ import annotations
|
14 |
import asyncio, itertools, logging
|
15 |
+
from typing import Dict, Any, List
|
16 |
+
|
17 |
+
from mcp.arxiv import fetch_arxiv
|
18 |
+
from mcp.pubmed import fetch_pubmed
|
19 |
+
from mcp.ncbi import search_gene, get_mesh_definition
|
20 |
+
from mcp.mygene import fetch_gene_info
|
21 |
+
from mcp.ensembl import fetch_ensembl
|
22 |
+
from mcp.opentargets import fetch_ot
|
23 |
+
from mcp.umls import lookup_umls
|
24 |
+
from mcp.openfda import fetch_drug_safety
|
25 |
+
from mcp.disgenet import disease_to_genes
|
26 |
+
from mcp.clinicaltrials import fetch_clinical_trials
|
27 |
+
from mcp.cbio import fetch_cbio_variants
|
28 |
+
from mcp.openai_utils import ai_summarize, ai_qa
|
29 |
+
from mcp.gemini import gemini_summarize, gemini_qa
|
30 |
|
31 |
log = logging.getLogger(__name__)
|
32 |
+
_DEFAULT_LLM = "openai"
|
33 |
|
34 |
|
35 |
+
def _llm_router(engine: str = _DEFAULT_LLM):
    """Pick the LLM helper functions for *engine*.

    ``"gemini"`` (matched case-insensitively) selects the Gemini helpers;
    every other value falls back to the OpenAI helpers.

    Returns:
        A ``(summarize_fn, qa_fn, engine_name)`` tuple.
    """
    wants_gemini = engine.lower() == "gemini"
    if not wants_gemini:
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
+
async def _safe_gather(*tasks, return_exceptions: bool = False):
    """Await *tasks* concurrently, logging every failure instead of raising.

    Args:
        *tasks: Awaitables passed to ``asyncio.gather``.
        return_exceptions: When true, a failed task's exception object is
            kept at that task's position, so the output lines up with
            ``tasks``. When false (the default) failed tasks are omitted, so
            the output may be shorter than ``tasks``.

    Returns:
        A list of results, in the order the tasks were given.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned = []
    for idx, res in enumerate(results):
        # BaseException rather than Exception: a cancelled task comes back as
        # asyncio.CancelledError, which has derived from BaseException since
        # Python 3.8 and would otherwise be mistaken for a real result.
        if isinstance(res, BaseException):
            log.warning("Task %d failed: %s", idx, res)
            if not return_exceptions:
                continue
        cleaned.append(res)
    return cleaned
|
57 |
+
|
58 |
+
|
59 |
+
def _seed_keywords(papers: List[dict], limit: int = 10) -> List[str]:
    """Collect up to *limit* distinct alphabetic words from paper summaries.

    Only the first 500 characters of each summary are scanned. Words keep
    first-seen order, so the same papers always yield the same seeds.
    """
    ordered: Dict[str, None] = {}
    for paper in papers:
        # `or ""` also covers records whose "summary" is present but None.
        for word in (paper.get("summary") or "")[:500].split():
            if word.isalpha():
                ordered.setdefault(word, None)
    return list(ordered)[:limit]


def _unique_variants(variants: Any) -> List[Any]:
    """Drop repeated variant dicts, keyed on (chromosome, start, ref, alt).

    Records that are not dicts cannot be keyed and are passed through as-is.
    """
    seen = set()
    unique: List[Any] = []
    for var in variants or []:
        if isinstance(var, dict):
            key = (
                var.get("chromosome"),
                var.get("startPosition"),
                var.get("referenceAllele"),
                var.get("variantAllele"),
            )
            if key in seen:
                continue
            seen.add(key)
        unique.append(var)
    return unique


def _settled(result: Any, fallback: Any, source: str) -> Any:
    """Return *result*, or *fallback* (after a warning) if it is an exception."""
    if isinstance(result, BaseException):
        log.warning("%s lookup failed: %s", source, result)
        return fallback
    return result


async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point for MedGenesis UI.

    Fetches literature for *query*, derives keyword seeds from the paper
    summaries, fans out the enrichment lookups for those seeds, and asks the
    selected LLM for a summary. A failing data source or LLM call is logged
    and replaced by an empty value (or an "LLM error: ..." string) instead of
    aborting the whole search.

    Returns a dict with:
      - papers, umls, drug_safety, clinical_trials, variants
      - genes, mesh_defs, gene_disease
      - ai_summary, llm_used
    """
    # 1) Literature (PubMed + arXiv in parallel); a failed source is skipped.
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    papers = list(
        itertools.chain.from_iterable(batch or [] for batch in papers_raw)
    )[:30]

    # 2) Keyword seeds from the summaries (deterministic, first-seen order).
    seeds = _seed_keywords(papers)

    # 3) Fan-out all bio-enrichment tasks.
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )

    # return_exceptions=True so that one failing source (for example the
    # trials or cBioPortal call) cannot discard the results of the others.
    umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_enrich_t,
        trials_t,
        cbio_t,
        return_exceptions=True,
    )
    empty_genes = {
        name: [] for name in ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")
    }
    umls_list = _settled(umls_list, [], "UMLS")
    fda_list = _settled(fda_list, [], "openFDA")
    gene_data = _settled(gene_data, empty_genes, "Gene enrichment")
    trials = _settled(trials, [], "Clinical trials")
    variants = _settled(variants, [], "cBioPortal")

    # 4) Dedupe variants by (chrom, pos, ref, alt).
    unique_vars = _unique_variants(variants)

    # 5) LLM summary; a failure degrades to an error string, mirroring
    #    answer_ai_question, so the fetched data still reaches the caller.
    summarize_fn, _, engine_used = _llm_router(llm)
    long_text = " ".join(p.get("summary") or "" for p in papers)
    try:
        ai_summary = await summarize_fn(long_text[:12000])
    except Exception as exc:
        log.exception("LLM summary failed (llm=%s)", engine_used)
        ai_summary = f"LLM error: {exc}"

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, BaseException)],
        "drug_safety": list(
            itertools.chain.from_iterable(
                dfa for dfa in fda_list if isinstance(dfa, list)
            )
        ),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }
|
133 |
|
134 |
|
135 |
+
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Run every gene-centric lookup for each seed key.

    Per key, five coroutines are scheduled in a fixed order: search_gene,
    get_mesh_definition, fetch_gene_info, fetch_ensembl and fetch_ot.
    Lookups that failed are left out of the output.

    Returns a dict with the keys "ncbi", "mesh", "mygene", "ensembl" and
    "ot_assoc", each holding the list of successful results of the
    corresponding lookup.
    """
    lookups = (
        search_gene,
        get_mesh_definition,
        fetch_gene_info,
        fetch_ensembl,
        fetch_ot,
    )
    pending = [
        asyncio.create_task(lookup(key))
        for key in keys
        for lookup in lookups
    ]

    outcomes = await _safe_gather(*pending, return_exceptions=True)

    # Tasks were queued key by key, so stepping through the outcomes with a
    # stride of len(lookups) picks out the results of one lookup.
    stride = len(lookups)
    labels = ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")
    grouped = {}
    for offset, label in enumerate(labels):
        grouped[label] = [
            item
            for item in outcomes[offset::stride]
            if not isinstance(item, Exception)
        ]
    return grouped
|
168 |
|
169 |
+
|
170 |
+
async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Follow-up QA: wraps the chosen LLM's QA function.

    Args:
        question: The user's question.
        context: Background text placed in the prompt after the question.
        llm: Engine name handed to _llm_router.

    Returns:
        {"answer": text}. If the QA call raises, text is
        "LLM error: <exception>" rather than the exception propagating.
    """
    _, qa_fn, _ = _llm_router(llm)
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        answer = await qa_fn(prompt)
    except Exception as exc:
        # The caller still gets an error string, but the traceback is logged
        # so the failure is not reduced to a one-line message in the UI.
        log.exception("LLM QA call failed (llm=%s)", llm)
        answer = f"LLM error: {exc}"
    return {"answer": answer}
|
|
|
|
|
|
|
|
|
|
|
|
|
|