mgbam committed on
Commit
2a8cf8d
·
verified ·
1 Parent(s): 4f7b321

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +109 -56
mcp/orchestrator.py CHANGED
@@ -1,6 +1,6 @@
1
- # mcp/orchestrator.py
2
-
3
  import asyncio
 
 
4
  from mcp.pubmed import fetch_pubmed
5
  from mcp.arxiv import fetch_arxiv
6
  from mcp.umls import extract_umls_concepts
@@ -11,70 +11,123 @@ from mcp.ensembl import fetch_ensembl
11
  from mcp.opentargets import fetch_ot
12
  from mcp.clinicaltrials import search_trials
13
  from mcp.cbio import fetch_cbio
 
14
  from mcp.gemini import gemini_summarize, gemini_qa
15
  from mcp.openai_utils import ai_summarize, ai_qa
16
- from mcp.disgenet import disease_to_genes
17
 
18
- async def orchestrate_search(query, llm="openai"):
19
- # --- Literature: PubMed + arXiv
20
- pubmed_task = asyncio.create_task(fetch_pubmed(query, max_results=7))
21
- arxiv_task = asyncio.create_task(fetch_arxiv(query, max_results=7))
22
- # --- UMLS, OpenFDA, Gene, Mesh
23
- umls_task = asyncio.create_task(extract_umls_concepts(query))
24
- fda_task = asyncio.create_task(fetch_drug_safety(query))
25
- gene_ncbi_task = asyncio.create_task(search_gene(query))
26
- mygene_task = asyncio.create_task(fetch_gene_info(query))
27
- ensembl_task = asyncio.create_task(fetch_ensembl(query))
28
- ot_task = asyncio.create_task(fetch_ot(query))
29
- mesh_task = asyncio.create_task(get_mesh_definition(query))
30
- # --- Trials, cBio, DisGeNET
31
- trials_task = asyncio.create_task(search_trials(query, max_studies=10))
32
- cbio_task = asyncio.create_task(fetch_cbio(query))
33
- disgenet_task = asyncio.create_task(disease_to_genes(query))
34
-
35
- # Run
36
- pubmed, arxiv, umls, fda, ncbi, mygene, ensembl, ot, mesh, trials, cbio, disgenet = await asyncio.gather(
37
- pubmed_task, arxiv_task, umls_task, fda_task, gene_ncbi_task,
38
- mygene_task, ensembl_task, ot_task, mesh_task, trials_task, cbio_task, disgenet_task
39
- )
40
- # Genes: flatten and deduplicate
41
- genes = []
42
- for g in (ncbi, mygene, ensembl, ot):
43
- if isinstance(g, list):
44
- genes.extend(g)
45
- elif isinstance(g, dict) and g:
46
- genes.append(g)
47
- genes = [g for i, g in enumerate(genes) if g and genes.index(g) == i] # dedup
48
-
49
- # --- AI summary (LLM engine select)
50
- papers = (pubmed or []) + (arxiv or [])
51
- if llm == "gemini":
52
- ai_summary = await gemini_summarize(" ".join([p.get("summary", "") for p in papers]))
53
- llm_used = "gemini"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  else:
55
- ai_summary = await ai_summarize(" ".join([p.get("summary", "") for p in papers]))
56
- llm_used = "openai"
57
 
58
  return {
59
- "papers": papers,
60
- "genes": genes,
61
- "umls": umls or [],
62
- "gene_disease": disgenet if isinstance(disgenet, list) else [],
63
- "mesh_defs": [mesh] if isinstance(mesh, str) and mesh else [],
64
- "drug_safety": fda or [],
65
- "clinical_trials": trials or [],
66
- "variants": cbio if isinstance(cbio, list) else [],
67
- "ai_summary": ai_summary,
68
- "llm_used": llm_used
69
  }
70
 
71
- async def answer_ai_question(question, context="", llm="openai"):
72
- # Gemini fallback if OpenAI quota is exceeded
 
 
 
 
 
 
 
 
 
73
  try:
74
- if llm == "gemini":
75
  answer = await gemini_qa(question, context)
76
  else:
77
  answer = await ai_qa(question, context)
78
  except Exception as e:
79
- answer = f"LLM unavailable or quota exceeded. ({e})"
80
- return {"answer": answer}
 
 
 
1
  import asyncio
2
+ from typing import Any, Dict, List, Literal, Union
3
+
4
  from mcp.pubmed import fetch_pubmed
5
  from mcp.arxiv import fetch_arxiv
6
  from mcp.umls import extract_umls_concepts
 
11
  from mcp.opentargets import fetch_ot
12
  from mcp.clinicaltrials import search_trials
13
  from mcp.cbio import fetch_cbio
14
+ from mcp.disgenet import disease_to_genes
15
  from mcp.gemini import gemini_summarize, gemini_qa
16
  from mcp.openai_utils import ai_summarize, ai_qa
 
17
 
18
+
19
async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
    """
    Wait for every task in *tasks* to finish and hand back their results.

    The returned list is positionally aligned with *tasks*: element ``i`` is
    the result of ``tasks[i]``. ``asyncio.gather`` is called with its default
    settings, so the first exception raised by any task propagates to the
    caller.
    """
    outcomes = await asyncio.gather(*tasks)
    return outcomes
24
+
25
+
26
def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
    """
    Flatten one level of nesting and drop duplicates, keeping first-seen order.

    Each element of *items* is either a list (its members are taken
    individually) or a single value (taken as-is). Top-level ``None``
    entries are skipped; members of nested lists are kept unchanged.

    Duplicates are detected by equality. Hashable values are tracked in a
    set; unhashable values (for example ``dict`` records) cannot be stored in
    a set, so they are tracked in a list and compared with ``==`` instead of
    raising ``TypeError``.
    """
    flat: List[Any] = []
    seen_hashable = set()
    seen_unhashable: List[Any] = []

    for elem in items:
        if isinstance(elem, list):
            candidates = elem
        elif elem is not None:
            candidates = [elem]
        else:
            continue

        for value in candidates:
            try:
                # Set membership raises TypeError for unhashable values.
                if value in seen_hashable:
                    continue
                seen_hashable.add(value)
            except TypeError:
                # Fall back to a linear equality scan for dicts, lists, etc.
                if value in seen_unhashable:
                    continue
                seen_unhashable.append(value)
            flat.append(value)
    return flat
44
+
45
+
46
async def orchestrate_search(
    query: str,
    llm: Literal['openai', 'gemini'] = 'openai',
    max_papers: int = 7,
    max_trials: int = 10,
) -> Dict[str, Any]:
    """
    Query every data source for *query* concurrently, then summarise the papers.

    Args:
        query: Search text passed unchanged to every source.
        llm: ``'gemini'`` selects ``gemini_summarize``; any other value
            selects ``ai_summarize`` and is reported as ``'openai'``.
        max_papers: Passed as ``max_results`` to both ``fetch_pubmed`` and
            ``fetch_arxiv`` (so up to twice this many papers may be returned).
        max_trials: Passed as ``max_studies`` to ``search_trials``.

    Returns:
        A dict with the keys ``papers``, ``genes``, ``umls``,
        ``gene_disease``, ``mesh_defs``, ``drug_safety``,
        ``clinical_trials``, ``variants``, ``ai_summary`` and ``llm_used``.
        Sources that return nothing (or an unexpected type, for the
        type-checked keys) are reported as empty lists.

    Raises:
        Any exception raised by a source task or by the summariser
        propagates to the caller; nothing is caught here.
    """
    # Launch every lookup up front so they run in parallel.
    tasks = {
        'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
        'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
        # NOTE(review): to_thread assumes extract_umls_concepts is a plain
        # blocking function; if it is still a coroutine function this yields
        # an un-awaited coroutine object instead of concepts -- confirm
        # against mcp.umls.
        'umls': asyncio.create_task(
            asyncio.to_thread(extract_umls_concepts, query)
        ),
        'drug_safety': asyncio.create_task(fetch_drug_safety(query)),
        'ncbi_gene': asyncio.create_task(search_gene(query)),
        'mygene': asyncio.create_task(fetch_gene_info(query)),
        'ensembl': asyncio.create_task(fetch_ensembl(query)),
        'opentargets': asyncio.create_task(fetch_ot(query)),
        'mesh': asyncio.create_task(get_mesh_definition(query)),
        'trials': asyncio.create_task(search_trials(query, max_studies=max_trials)),
        'cbio': asyncio.create_task(fetch_cbio(query)),
        'disgenet': asyncio.create_task(disease_to_genes(query)),
    }

    # Await all; results come back in the dict's insertion order.
    results = await _gather_tasks(list(tasks.values()))
    data = dict(zip(tasks, results))

    # Merge the gene sources. Empty results (None, [], {}) are dropped before
    # and after flattening so that a source with no hit never contributes a
    # blank "gene" entry.
    gene_keys = ('ncbi_gene', 'mygene', 'ensembl', 'opentargets')
    gene_sources = [data[key] for key in gene_keys if data[key]]
    genes = [gene for gene in _flatten_unique(gene_sources) if gene]

    # Combine literature
    papers = (data['pubmed'] or []) + (data['arxiv'] or [])

    # A paper may carry an explicit ``summary: None``; ``or ''`` keeps the
    # join from failing on it.
    summaries = " ".join(p.get('summary') or '' for p in papers)
    if llm == 'gemini':
        ai_summary = await gemini_summarize(summaries)
        llm_used = 'gemini'
    else:
        ai_summary = await ai_summarize(summaries)
        llm_used = 'openai'

    return {
        'papers': papers,
        'genes': genes,
        'umls': data['umls'] or [],
        'gene_disease': data['disgenet'] if isinstance(data['disgenet'], list) else [],
        'mesh_defs': [data['mesh']] if isinstance(data['mesh'], str) and data['mesh'] else [],
        'drug_safety': data['drug_safety'] or [],
        'clinical_trials': data['trials'] or [],
        'variants': data['cbio'] if isinstance(data['cbio'], list) else [],
        'ai_summary': ai_summary,
        'llm_used': llm_used,
    }
114
 
115
+
116
async def answer_ai_question(
    question: str,
    context: str = "",
    llm: Literal['openai', 'gemini'] = 'openai',
) -> Dict[str, str]:
    """
    Ask the selected LLM a free-text question and wrap the reply in a dict.

    ``llm == 'gemini'`` routes to ``gemini_qa``; every other value routes to
    ``ai_qa``. No second model is tried on failure: if the call raises, the
    exception is caught and its text is returned inside the answer string.

    Returns a dict {'answer': <text>}.
    """
    try:
        responder = gemini_qa if llm == 'gemini' else ai_qa
        return {'answer': await responder(question, context)}
    except Exception as exc:
        # Best-effort: report the failure to the caller instead of raising.
        return {'answer': f"LLM unavailable or quota exceeded: {exc}"}