Update mcp/orchestrator.py
Browse files- mcp/orchestrator.py +39 -23
mcp/orchestrator.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import asyncio
|
|
|
2 |
from typing import Any, Dict, List, Literal, Union
|
3 |
|
4 |
from mcp.pubmed import fetch_pubmed
|
@@ -16,6 +17,23 @@ from mcp.gemini import gemini_summarize, gemini_qa
|
|
16 |
from mcp.openai_utils import ai_summarize, ai_qa
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
|
20 |
"""
|
21 |
Await a list of asyncio.Tasks and return their results in order.
|
@@ -50,47 +68,46 @@ async def orchestrate_search(
|
|
50 |
max_trials: int = 10,
|
51 |
) -> Dict[str, Any]:
|
52 |
"""
|
53 |
-
Perform a comprehensive biomedical search pipeline:
|
54 |
- Literature (PubMed + arXiv)
|
55 |
- Entity extraction (UMLS)
|
56 |
- Drug safety, gene & variant info, disease-gene mapping
|
57 |
- Clinical trials, cBioPortal data
|
58 |
- AI-driven summary
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
drug_safety, clinical_trials, variants, ai_summary, llm_used
|
63 |
"""
|
64 |
-
# Launch parallel tasks
|
65 |
tasks = {
|
66 |
'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
|
67 |
'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
|
68 |
'umls': asyncio.create_task(
|
69 |
asyncio.to_thread(extract_umls_concepts, query)
|
70 |
),
|
71 |
-
'drug_safety': asyncio.create_task(fetch_drug_safety
|
72 |
-
'ncbi_gene': asyncio.create_task(search_gene
|
73 |
-
'mygene': asyncio.create_task(fetch_gene_info
|
74 |
-
'ensembl': asyncio.create_task(fetch_ensembl
|
75 |
-
'opentargets': asyncio.create_task(fetch_ot
|
76 |
-
'mesh': asyncio.create_task(get_mesh_definition
|
77 |
-
'trials': asyncio.create_task(search_trials
|
78 |
-
'cbio': asyncio.create_task(fetch_cbio
|
79 |
-
'disgenet': asyncio.create_task(disease_to_genes
|
80 |
}
|
81 |
|
82 |
-
# Await all
|
83 |
results = await _gather_tasks(list(tasks.values()))
|
84 |
data = dict(zip(tasks.keys(), results))
|
85 |
|
86 |
-
#
|
87 |
gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
|
88 |
genes = _flatten_unique(gene_sources)
|
89 |
|
90 |
-
#
|
91 |
papers = (data['pubmed'] or []) + (data['arxiv'] or [])
|
92 |
|
93 |
-
# AI-driven summary
|
94 |
summaries = " ".join(p.get('summary', '') for p in papers)
|
95 |
if llm == 'gemini':
|
96 |
ai_summary = await gemini_summarize(summaries)
|
@@ -103,11 +120,11 @@ async def orchestrate_search(
|
|
103 |
'papers': papers,
|
104 |
'genes': genes,
|
105 |
'umls': data['umls'] or [],
|
106 |
-
'gene_disease': data['disgenet']
|
107 |
-
'mesh_defs': [data['mesh']] if
|
108 |
'drug_safety': data['drug_safety'] or [],
|
109 |
'clinical_trials': data['trials'] or [],
|
110 |
-
'variants': data['cbio']
|
111 |
'ai_summary': ai_summary,
|
112 |
'llm_used': llm_used,
|
113 |
}
|
@@ -120,8 +137,7 @@ async def answer_ai_question(
|
|
120 |
) -> Dict[str, str]:
|
121 |
"""
|
122 |
Answer a free-text question using the specified LLM, with fallback.
|
123 |
-
|
124 |
-
Returns a dict {'answer': <text>}.
|
125 |
"""
|
126 |
try:
|
127 |
if llm == 'gemini':
|
|
|
1 |
import asyncio
|
2 |
+
import httpx
|
3 |
from typing import Any, Dict, List, Literal, Union
|
4 |
|
5 |
from mcp.pubmed import fetch_pubmed
|
|
|
17 |
from mcp.openai_utils import ai_summarize, ai_qa
|
18 |
|
19 |
|
20 |
+
def _safe_call(
|
21 |
+
func: Any,
|
22 |
+
*args,
|
23 |
+
default: Any = None,
|
24 |
+
**kwargs,
|
25 |
+
) -> Any:
|
26 |
+
"""
|
27 |
+
Safely call an async function, returning a default on HTTP or other failures.
|
28 |
+
"""
|
29 |
+
try:
|
30 |
+
return await func(*args, **kwargs) # type: ignore
|
31 |
+
except httpx.HTTPStatusError:
|
32 |
+
return default
|
33 |
+
except Exception:
|
34 |
+
return default
|
35 |
+
|
36 |
+
|
37 |
async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
|
38 |
"""
|
39 |
Await a list of asyncio.Tasks and return their results in order.
|
|
|
68 |
max_trials: int = 10,
|
69 |
) -> Dict[str, Any]:
|
70 |
"""
|
71 |
+
Perform a comprehensive biomedical search pipeline with fault tolerance:
|
72 |
- Literature (PubMed + arXiv)
|
73 |
- Entity extraction (UMLS)
|
74 |
- Drug safety, gene & variant info, disease-gene mapping
|
75 |
- Clinical trials, cBioPortal data
|
76 |
- AI-driven summary
|
77 |
|
78 |
+
Individual fetch functions that fail with an HTTP error will return an empty default,
|
79 |
+
ensuring the pipeline always completes.
|
|
|
80 |
"""
|
81 |
+
# Launch parallel tasks with safe wrapper for potential HTTP errors
|
82 |
tasks = {
|
83 |
'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
|
84 |
'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
|
85 |
'umls': asyncio.create_task(
|
86 |
asyncio.to_thread(extract_umls_concepts, query)
|
87 |
),
|
88 |
+
'drug_safety': asyncio.create_task(_safe_call(fetch_drug_safety, query, default=[])),
|
89 |
+
'ncbi_gene': asyncio.create_task(_safe_call(search_gene, query, default=[])),
|
90 |
+
'mygene': asyncio.create_task(_safe_call(fetch_gene_info, query, default=[])),
|
91 |
+
'ensembl': asyncio.create_task(_safe_call(fetch_ensembl, query, default=[])),
|
92 |
+
'opentargets': asyncio.create_task(_safe_call(fetch_ot, query, default=[])),
|
93 |
+
'mesh': asyncio.create_task(_safe_call(get_mesh_definition, query, default="")),
|
94 |
+
'trials': asyncio.create_task(_safe_call(search_trials, query, default=[], max_studies=max_trials)),
|
95 |
+
'cbio': asyncio.create_task(_safe_call(fetch_cbio, query, default=[])),
|
96 |
+
'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
|
97 |
}
|
98 |
|
99 |
+
# Await all tasks
|
100 |
results = await _gather_tasks(list(tasks.values()))
|
101 |
data = dict(zip(tasks.keys(), results))
|
102 |
|
103 |
+
# Consolidate gene sources
|
104 |
gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
|
105 |
genes = _flatten_unique(gene_sources)
|
106 |
|
107 |
+
# Merge literature results
|
108 |
papers = (data['pubmed'] or []) + (data['arxiv'] or [])
|
109 |
|
110 |
+
# AI-driven summary
|
111 |
summaries = " ".join(p.get('summary', '') for p in papers)
|
112 |
if llm == 'gemini':
|
113 |
ai_summary = await gemini_summarize(summaries)
|
|
|
120 |
'papers': papers,
|
121 |
'genes': genes,
|
122 |
'umls': data['umls'] or [],
|
123 |
+
'gene_disease': data['disgenet'] or [],
|
124 |
+
'mesh_defs': [data['mesh']] if data['mesh'] else [],
|
125 |
'drug_safety': data['drug_safety'] or [],
|
126 |
'clinical_trials': data['trials'] or [],
|
127 |
+
'variants': data['cbio'] or [],
|
128 |
'ai_summary': ai_summary,
|
129 |
'llm_used': llm_used,
|
130 |
}
|
|
|
137 |
) -> Dict[str, str]:
|
138 |
"""
|
139 |
Answer a free-text question using the specified LLM, with fallback.
|
140 |
+
Returns {'answer': text}.
|
|
|
141 |
"""
|
142 |
try:
|
143 |
if llm == 'gemini':
|