mgbam commited on
Commit
2417938
·
verified ·
1 Parent(s): 4a6179c

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +24 -22
mcp/orchestrator.py CHANGED
@@ -21,7 +21,7 @@ from mcp.embeddings import embed_texts, cluster_embeddings
21
 
22
  def _get_llm(llm: str):
23
  """
24
- Router for LLM engines: returns (summarize_fn, qa_fn).
25
  """
26
  if llm.lower() == "gemini":
27
  return gemini_summarize, gemini_qa
@@ -30,27 +30,27 @@ def _get_llm(llm: str):
30
 
31
  async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
32
  """
33
- Main orchestrator: fetch literature, concepts, enrichments,
34
- embeddings, clusters, and AI summary.
35
  """
36
- # 1) Literature fetch
37
  arxiv_task = fetch_arxiv(query)
38
  pubmed_task = fetch_pubmed(query)
39
- results = await asyncio.gather(arxiv_task, pubmed_task, return_exceptions=True)
40
  papers: List[Dict] = []
41
- for res in results:
42
  if isinstance(res, list):
43
  papers.extend(res)
44
 
45
- # 2) UMLS concept linking via spaCy
46
  blob = " ".join(p.get("summary", "") for p in papers)
47
  umls = await extract_umls_concepts(blob)
48
 
49
- # 3) Fetch UMLS relations
50
- rels_tasks = [fetch_relations(c["cui"]) for c in umls]
51
- umls_relations = await asyncio.gather(*rels_tasks, return_exceptions=True)
52
 
53
- # 4) Data enrichment
54
  names = [c["name"] for c in umls]
55
  fda_tasks = [fetch_drug_safety(n) for n in names]
56
  gene_task = search_gene(names[0]) if names else asyncio.sleep(0, result=[])
@@ -60,6 +60,7 @@ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
60
  ot_task = ot.fetch(names[0]) if names else asyncio.sleep(0, result=[])
61
  cbio_task = cbio.fetch_variants(names[0]) if names else asyncio.sleep(0, result=[])
62
 
 
63
  fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
64
  asyncio.gather(*fda_tasks, return_exceptions=True),
65
  gene_task, mesh_task, dis_task,
@@ -67,16 +68,17 @@ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
67
  return_exceptions=False
68
  )
69
 
70
- # 5) Embeddings & clustering
71
  summaries = [p.get("summary", "") for p in papers]
72
  if summaries:
73
- embs = await embed_texts(summaries)
74
- clusters = await cluster_embeddings(embs, n_clusters=max(2, min(10, len(embs)//2)))
 
 
75
  else:
76
- embs = []
77
- clusters = []
78
 
79
- # 6) AI summary
80
  summarize_fn, _ = _get_llm(llm)
81
  try:
82
  ai_summary = await summarize_fn(blob)
@@ -94,7 +96,7 @@ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
94
  "clinical_trials": trials,
95
  "ot_associations": ot_assoc,
96
  "variants": variants,
97
- "embeddings": embs,
98
  "clusters": clusters,
99
  "ai_summary": ai_summary,
100
  "llm_used": llm.lower()
@@ -103,11 +105,11 @@ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
103
 
104
  async def answer_ai_question(question: str, context: str = "", llm: str = "openai") -> Dict[str, str]:
105
  """
106
- Follow-up Q&A using chosen LLM engine.
107
  """
108
  _, qa_fn = _get_llm(llm)
109
  try:
110
- answer = await qa_fn(question, context)
111
  except Exception:
112
- answer = "LLM follow-up failed."
113
- return {"answer": answer}
 
21
 
22
  def _get_llm(llm: str):
23
  """
24
+ Route summarization and QA to the chosen engine.
25
  """
26
  if llm.lower() == "gemini":
27
  return gemini_summarize, gemini_qa
 
30
 
31
  async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
32
  """
33
+ Fetch papers, extract concepts & relations, enrich data,
34
+ compute embeddings+clusters, and run LLM summary.
35
  """
36
+ # Gather literature
37
  arxiv_task = fetch_arxiv(query)
38
  pubmed_task = fetch_pubmed(query)
39
+ lit_results = await asyncio.gather(arxiv_task, pubmed_task, return_exceptions=True)
40
  papers: List[Dict] = []
41
+ for res in lit_results:
42
  if isinstance(res, list):
43
  papers.extend(res)
44
 
45
+ # Concept extraction
46
  blob = " ".join(p.get("summary", "") for p in papers)
47
  umls = await extract_umls_concepts(blob)
48
 
49
+ # Fetch UMLS relations
50
+ rel_tasks = [fetch_relations(c["cui"]) for c in umls]
51
+ umls_relations = await asyncio.gather(*rel_tasks, return_exceptions=True)
52
 
53
+ # Data enrichment tasks
54
  names = [c["name"] for c in umls]
55
  fda_tasks = [fetch_drug_safety(n) for n in names]
56
  gene_task = search_gene(names[0]) if names else asyncio.sleep(0, result=[])
 
60
  ot_task = ot.fetch(names[0]) if names else asyncio.sleep(0, result=[])
61
  cbio_task = cbio.fetch_variants(names[0]) if names else asyncio.sleep(0, result=[])
62
 
63
+ # Run enrichment
64
  fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
65
  asyncio.gather(*fda_tasks, return_exceptions=True),
66
  gene_task, mesh_task, dis_task,
 
68
  return_exceptions=False
69
  )
70
 
71
+ # Embeddings & clustering
72
  summaries = [p.get("summary", "") for p in papers]
73
  if summaries:
74
+ embeddings = await embed_texts(summaries)
75
+ clusters = await cluster_embeddings(
76
+ embeddings, n_clusters = max(2, min(10, len(embeddings)//2))
77
+ )
78
  else:
79
+ embeddings, clusters = [], []
 
80
 
81
+ # LLM summary
82
  summarize_fn, _ = _get_llm(llm)
83
  try:
84
  ai_summary = await summarize_fn(blob)
 
96
  "clinical_trials": trials,
97
  "ot_associations": ot_assoc,
98
  "variants": variants,
99
+ "embeddings": embeddings,
100
  "clusters": clusters,
101
  "ai_summary": ai_summary,
102
  "llm_used": llm.lower()
 
105
 
106
  async def answer_ai_question(question: str, context: str = "", llm: str = "openai") -> Dict[str, str]:
107
  """
108
+ Follow-up Q&A via chosen LLM.
109
  """
110
  _, qa_fn = _get_llm(llm)
111
  try:
112
+ ans = await qa_fn(question, context)
113
  except Exception:
114
+ ans = "LLM follow-up failed."
115
+ return {"answer": ans}