mgbam committed on
Commit
24a46bd
·
verified ·
1 Parent(s): 2a8cf8d

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +39 -23
mcp/orchestrator.py CHANGED
@@ -1,4 +1,5 @@
1
  import asyncio
 
2
  from typing import Any, Dict, List, Literal, Union
3
 
4
  from mcp.pubmed import fetch_pubmed
@@ -16,6 +17,23 @@ from mcp.gemini import gemini_summarize, gemini_qa
16
  from mcp.openai_utils import ai_summarize, ai_qa
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
20
  """
21
  Await a list of asyncio.Tasks and return their results in order.
@@ -50,47 +68,46 @@ async def orchestrate_search(
50
  max_trials: int = 10,
51
  ) -> Dict[str, Any]:
52
  """
53
- Perform a comprehensive biomedical search pipeline:
54
  - Literature (PubMed + arXiv)
55
  - Entity extraction (UMLS)
56
  - Drug safety, gene & variant info, disease-gene mapping
57
  - Clinical trials, cBioPortal data
58
  - AI-driven summary
59
 
60
- Returns a dict with keys:
61
- papers, genes, umls, gene_disease, mesh_defs,
62
- drug_safety, clinical_trials, variants, ai_summary, llm_used
63
  """
64
- # Launch parallel tasks
65
  tasks = {
66
  'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
67
  'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
68
  'umls': asyncio.create_task(
69
  asyncio.to_thread(extract_umls_concepts, query)
70
  ),
71
- 'drug_safety': asyncio.create_task(fetch_drug_safety(query)),
72
- 'ncbi_gene': asyncio.create_task(search_gene(query)),
73
- 'mygene': asyncio.create_task(fetch_gene_info(query)),
74
- 'ensembl': asyncio.create_task(fetch_ensembl(query)),
75
- 'opentargets': asyncio.create_task(fetch_ot(query)),
76
- 'mesh': asyncio.create_task(get_mesh_definition(query)),
77
- 'trials': asyncio.create_task(search_trials(query, max_studies=max_trials)),
78
- 'cbio': asyncio.create_task(fetch_cbio(query)),
79
- 'disgenet': asyncio.create_task(disease_to_genes(query)),
80
  }
81
 
82
- # Await all
83
  results = await _gather_tasks(list(tasks.values()))
84
  data = dict(zip(tasks.keys(), results))
85
 
86
- # Process gene sources
87
  gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
88
  genes = _flatten_unique(gene_sources)
89
 
90
- # Combine literature
91
  papers = (data['pubmed'] or []) + (data['arxiv'] or [])
92
 
93
- # AI-driven summary selection
94
  summaries = " ".join(p.get('summary', '') for p in papers)
95
  if llm == 'gemini':
96
  ai_summary = await gemini_summarize(summaries)
@@ -103,11 +120,11 @@ async def orchestrate_search(
103
  'papers': papers,
104
  'genes': genes,
105
  'umls': data['umls'] or [],
106
- 'gene_disease': data['disgenet'] if isinstance(data['disgenet'], list) else [],
107
- 'mesh_defs': [data['mesh']] if isinstance(data['mesh'], str) and data['mesh'] else [],
108
  'drug_safety': data['drug_safety'] or [],
109
  'clinical_trials': data['trials'] or [],
110
- 'variants': data['cbio'] if isinstance(data['cbio'], list) else [],
111
  'ai_summary': ai_summary,
112
  'llm_used': llm_used,
113
  }
@@ -120,8 +137,7 @@ async def answer_ai_question(
120
  ) -> Dict[str, str]:
121
  """
122
  Answer a free-text question using the specified LLM, with fallback.
123
-
124
- Returns a dict {'answer': <text>}.
125
  """
126
  try:
127
  if llm == 'gemini':
 
1
  import asyncio
2
+ import httpx
3
  from typing import Any, Dict, List, Literal, Union
4
 
5
  from mcp.pubmed import fetch_pubmed
 
17
  from mcp.openai_utils import ai_summarize, ai_qa
18
 
19
 
20
+ def _safe_call(
21
+ func: Any,
22
+ *args,
23
+ default: Any = None,
24
+ **kwargs,
25
+ ) -> Any:
26
+ """
27
+ Safely call an async function, returning a default on HTTP or other failures.
28
+ """
29
+ try:
30
+ return await func(*args, **kwargs) # type: ignore
31
+ except httpx.HTTPStatusError:
32
+ return default
33
+ except Exception:
34
+ return default
35
+
36
+
37
  async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
38
  """
39
  Await a list of asyncio.Tasks and return their results in order.
 
68
  max_trials: int = 10,
69
  ) -> Dict[str, Any]:
70
  """
71
+ Perform a comprehensive biomedical search pipeline with fault tolerance:
72
  - Literature (PubMed + arXiv)
73
  - Entity extraction (UMLS)
74
  - Drug safety, gene & variant info, disease-gene mapping
75
  - Clinical trials, cBioPortal data
76
  - AI-driven summary
77
 
78
+ Individual fetch functions that fail with an HTTP error will return an empty default,
79
+ ensuring the pipeline always completes.
 
80
  """
81
+ # Launch parallel tasks with safe wrapper for potential HTTP errors
82
  tasks = {
83
  'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
84
  'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
85
  'umls': asyncio.create_task(
86
  asyncio.to_thread(extract_umls_concepts, query)
87
  ),
88
+ 'drug_safety': asyncio.create_task(_safe_call(fetch_drug_safety, query, default=[])),
89
+ 'ncbi_gene': asyncio.create_task(_safe_call(search_gene, query, default=[])),
90
+ 'mygene': asyncio.create_task(_safe_call(fetch_gene_info, query, default=[])),
91
+ 'ensembl': asyncio.create_task(_safe_call(fetch_ensembl, query, default=[])),
92
+ 'opentargets': asyncio.create_task(_safe_call(fetch_ot, query, default=[])),
93
+ 'mesh': asyncio.create_task(_safe_call(get_mesh_definition, query, default="")),
94
+ 'trials': asyncio.create_task(_safe_call(search_trials, query, default=[], max_studies=max_trials)),
95
+ 'cbio': asyncio.create_task(_safe_call(fetch_cbio, query, default=[])),
96
+ 'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
97
  }
98
 
99
+ # Await all tasks
100
  results = await _gather_tasks(list(tasks.values()))
101
  data = dict(zip(tasks.keys(), results))
102
 
103
+ # Consolidate gene sources
104
  gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
105
  genes = _flatten_unique(gene_sources)
106
 
107
+ # Merge literature results
108
  papers = (data['pubmed'] or []) + (data['arxiv'] or [])
109
 
110
+ # AI-driven summary
111
  summaries = " ".join(p.get('summary', '') for p in papers)
112
  if llm == 'gemini':
113
  ai_summary = await gemini_summarize(summaries)
 
120
  'papers': papers,
121
  'genes': genes,
122
  'umls': data['umls'] or [],
123
+ 'gene_disease': data['disgenet'] or [],
124
+ 'mesh_defs': [data['mesh']] if data['mesh'] else [],
125
  'drug_safety': data['drug_safety'] or [],
126
  'clinical_trials': data['trials'] or [],
127
+ 'variants': data['cbio'] or [],
128
  'ai_summary': ai_summary,
129
  'llm_used': llm_used,
130
  }
 
137
  ) -> Dict[str, str]:
138
  """
139
  Answer a free-text question using the specified LLM, with fallback.
140
+ Returns {'answer': text}.
 
141
  """
142
  try:
143
  if llm == 'gemini':