mgbam commited on
Commit
9c2f1fc
·
verified ·
1 Parent(s): 1f7d1c0

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +56 -27
mcp/orchestrator.py CHANGED
@@ -1,9 +1,11 @@
1
  # mcp/orchestrator.py
2
  import asyncio
3
- from typing import Dict, Any
 
4
  from mcp.arxiv import fetch_arxiv
5
  from mcp.pubmed import fetch_pubmed
6
  from mcp.nlp import extract_umls_concepts
 
7
  from mcp.umls_rel import fetch_relations
8
  from mcp.openfda import fetch_drug_safety
9
  from mcp.ncbi import search_gene, get_mesh_definition
@@ -14,37 +16,49 @@ from mcp.opentargets import ot
14
  from mcp.cbio import cbio
15
  from mcp.openai_utils import ai_summarize, ai_qa
16
  from mcp.gemini import gemini_summarize, gemini_qa
 
 
17
 
18
  def _get_llm(llm: str):
19
- return (gemini_summarize, gemini_qa) if llm.lower() == "gemini" else (ai_summarize, ai_qa)
 
 
 
 
 
 
20
 
21
  async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
22
- # 1) Parallel literature pulls
23
- arxiv_t, pubmed_t = fetch_arxiv(query), fetch_pubmed(query)
24
- papers = []
25
- for res in await asyncio.gather(arxiv_t, pubmed_t, return_exceptions=True):
 
 
 
 
 
 
26
  if isinstance(res, list):
27
  papers.extend(res)
28
 
29
- # 2) SpaCy→UMLS concept linking
30
- blob = " ".join(p.get("summary","") for p in papers)
31
  umls = await extract_umls_concepts(blob)
32
 
33
- # 3) Fetch UMLS relations in parallel
34
- rels = await asyncio.gather(
35
- *[fetch_relations(c["cui"]) for c in umls],
36
- return_exceptions=True
37
- )
38
 
39
- # 4) Enrich: OpenFDA, NCBI, DisGeNET, Trials, OpenTargets, cBioPortal
40
- keys = [c["name"] for c in umls]
41
- fda_tasks = [fetch_drug_safety(k) for k in keys]
42
- gene_task = search_gene(keys[0]) if keys else asyncio.sleep(0, result=[])
43
- mesh_task = get_mesh_definition(keys[0]) if keys else asyncio.sleep(0, result="")
44
- dis_task = disease_to_genes(keys[0]) if keys else asyncio.sleep(0, result=[])
45
- trials_task = search_trials(query)
46
- ot_task = ot.fetch(keys[0]) if keys else asyncio.sleep(0, result=[])
47
- cbio_task = cbio.fetch_variants(keys[0]) if keys else asyncio.sleep(0, result=[])
48
 
49
  fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
50
  asyncio.gather(*fda_tasks, return_exceptions=True),
@@ -53,17 +67,26 @@ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
53
  return_exceptions=False
54
  )
55
 
56
- # 5) AI summary
57
- summarize, _ = _get_llm(llm)
 
 
 
 
 
 
 
 
 
58
  try:
59
- ai_summary = await summarize(blob)
60
  except Exception:
61
  ai_summary = "LLM summary failed."
62
 
63
  return {
64
  "papers": papers,
65
  "umls": umls,
66
- "umls_relations": rels,
67
  "drug_safety": fda,
68
  "genes": [gene],
69
  "mesh_defs": [mesh],
@@ -71,11 +94,17 @@ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
71
  "clinical_trials": trials,
72
  "ot_associations": ot_assoc,
73
  "variants": variants,
 
 
74
  "ai_summary": ai_summary,
75
  "llm_used": llm.lower()
76
  }
77
 
78
- async def answer_ai_question(question: str, context: str = "", llm: str = "openai"):
 
 
 
 
79
  _, qa_fn = _get_llm(llm)
80
  try:
81
  answer = await qa_fn(question, context)
 
1
  # mcp/orchestrator.py
2
  import asyncio
3
+ from typing import Any, Dict, List
4
+
5
  from mcp.arxiv import fetch_arxiv
6
  from mcp.pubmed import fetch_pubmed
7
  from mcp.nlp import extract_umls_concepts
8
+ from mcp.umls import lookup_umls
9
  from mcp.umls_rel import fetch_relations
10
  from mcp.openfda import fetch_drug_safety
11
  from mcp.ncbi import search_gene, get_mesh_definition
 
16
  from mcp.cbio import cbio
17
  from mcp.openai_utils import ai_summarize, ai_qa
18
  from mcp.gemini import gemini_summarize, gemini_qa
19
+ from mcp.embeddings import embed_texts, cluster_embeddings
20
+
21
 
22
  def _get_llm(llm: str):
23
+ """
24
+ Router for LLM engines: returns (summarize_fn, qa_fn).
25
+ """
26
+ if llm.lower() == "gemini":
27
+ return gemini_summarize, gemini_qa
28
+ return ai_summarize, ai_qa
29
+
30
 
31
  async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
32
+ """
33
+ Main orchestrator: fetch literature, concepts, enrichments,
34
+ embeddings, clusters, and AI summary.
35
+ """
36
+ # 1) Literature fetch
37
+ arxiv_task = fetch_arxiv(query)
38
+ pubmed_task = fetch_pubmed(query)
39
+ results = await asyncio.gather(arxiv_task, pubmed_task, return_exceptions=True)
40
+ papers: List[Dict] = []
41
+ for res in results:
42
  if isinstance(res, list):
43
  papers.extend(res)
44
 
45
+ # 2) UMLS concept linking via spaCy
46
+ blob = " ".join(p.get("summary", "") for p in papers)
47
  umls = await extract_umls_concepts(blob)
48
 
49
+ # 3) Fetch UMLS relations
50
+ rels_tasks = [fetch_relations(c["cui"]) for c in umls]
51
+ umls_relations = await asyncio.gather(*rels_tasks, return_exceptions=True)
 
 
52
 
53
+ # 4) Data enrichment
54
+ names = [c["name"] for c in umls]
55
+ fda_tasks = [fetch_drug_safety(n) for n in names]
56
+ gene_task = search_gene(names[0]) if names else asyncio.sleep(0, result=[])
57
+ mesh_task = get_mesh_definition(names[0]) if names else asyncio.sleep(0, result="")
58
+ dis_task = disease_to_genes(names[0]) if names else asyncio.sleep(0, result=[])
59
+ trials_task = search_trials(query)
60
+ ot_task = ot.fetch(names[0]) if names else asyncio.sleep(0, result=[])
61
+ cbio_task = cbio.fetch_variants(names[0]) if names else asyncio.sleep(0, result=[])
62
 
63
  fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
64
  asyncio.gather(*fda_tasks, return_exceptions=True),
 
67
  return_exceptions=False
68
  )
69
 
70
+ # 5) Embeddings & clustering
71
+ summaries = [p.get("summary", "") for p in papers]
72
+ if summaries:
73
+ embs = await embed_texts(summaries)
74
+ clusters = await cluster_embeddings(embs, n_clusters=max(2, min(10, len(embs)//2)))
75
+ else:
76
+ embs = []
77
+ clusters = []
78
+
79
+ # 6) AI summary
80
+ summarize_fn, _ = _get_llm(llm)
81
  try:
82
+ ai_summary = await summarize_fn(blob)
83
  except Exception:
84
  ai_summary = "LLM summary failed."
85
 
86
  return {
87
  "papers": papers,
88
  "umls": umls,
89
+ "umls_relations": umls_relations,
90
  "drug_safety": fda,
91
  "genes": [gene],
92
  "mesh_defs": [mesh],
 
94
  "clinical_trials": trials,
95
  "ot_associations": ot_assoc,
96
  "variants": variants,
97
+ "embeddings": embs,
98
+ "clusters": clusters,
99
  "ai_summary": ai_summary,
100
  "llm_used": llm.lower()
101
  }
102
 
103
+
104
+ async def answer_ai_question(question: str, context: str = "", llm: str = "openai") -> Dict[str, str]:
105
+ """
106
+ Follow-up Q&A using chosen LLM engine.
107
+ """
108
  _, qa_fn = _get_llm(llm)
109
  try:
110
  answer = await qa_fn(question, context)