mgbam committed on
Commit
12007d6
·
verified ·
1 Parent(s): a4f7e5c

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +68 -66
mcp/orchestrator.py CHANGED
@@ -1,90 +1,92 @@
1
- # mcp/orchestrator.py
2
  """
3
- Orchestrates retrieval, enrichment, and AI synthesis for a user query.
 
 
 
4
  """
5
 
6
- import asyncio
7
  from typing import Dict, Any, List
8
 
9
- from mcp.arxiv import fetch_arxiv
10
- from mcp.pubmed import fetch_pubmed
11
- from mcp.nlp import extract_keywords
12
- from mcp.umls import lookup_umls
13
- from mcp.openfda import fetch_drug_safety
14
- from mcp.ncbi import search_gene, get_mesh_definition
15
- from mcp.disgenet import disease_to_genes
16
  from mcp.clinicaltrials import search_trials
17
- from mcp.openai_utils import ai_summarize, ai_qa
 
18
 
19
- # ---------------------------------------------------------------------
20
async def _gene_and_mesh_enrichment(keywords: List[str]) -> Dict[str, Any]:
    """Run the gene, MeSH and DisGeNET lookups for every keyword concurrently.

    Three lookups are issued per keyword (``search_gene``,
    ``get_mesh_definition`` and ``disease_to_genes``) and all of them are
    awaited together.

    Args:
        keywords: Terms to look up.

    Returns:
        A dict with three lists:

        * ``"genes"``    - every ``search_gene`` result, flattened.
        * ``"meshes"``   - one ``get_mesh_definition`` result per successful
          lookup.
        * ``"disgenet"`` - every ``disease_to_genes`` result, flattened.

        Lookups that failed (or were cancelled) are skipped silently.
    """
    tasks = []
    for kw in keywords:
        tasks.append(search_gene(kw))
        tasks.append(get_mesh_definition(kw))
        tasks.append(disease_to_genes(kw))
    results = await asyncio.gather(*tasks, return_exceptions=True)

    genes, meshes, disgen = [], [], []
    for i, res in enumerate(results):
        # Test against BaseException rather than Exception: with
        # return_exceptions=True a cancelled child shows up in the results as
        # CancelledError, which has not been an Exception subclass since
        # Python 3.8 and would otherwise be fed to extend()/append().
        if isinstance(res, BaseException):
            continue
        # Results come back in submission order: 0 gene, 1 mesh, 2 disgenet,
        # repeating for each keyword.
        mod = i % 3
        if mod == 0:
            genes.extend(res)
        elif mod == 1:
            meshes.append(res)
        else:
            disgen.extend(res)
    return {"genes": genes, "meshes": meshes, "disgenet": disgen}
42
 
43
- # ---------------------------------------------------------------------
44
async def orchestrate_search(query: str) -> Dict[str, Any]:
    """Gather literature, enrichment data and an AI summary for *query*.

    Args:
        query: The user's search string; it is sent unchanged to the arXiv,
            PubMed and clinical-trials lookups.

    Returns:
        A dict for the app UI with the keys ``"papers"``, ``"umls"``,
        ``"drug_safety"``, ``"ai_summary"``, ``"suggested_reading"``,
        ``"genes"``, ``"mesh_definitions"``, ``"gene_disease"`` and
        ``"clinical_trials"``.
    """
    # Literature: both sources are fetched concurrently, arXiv results first.
    arxiv_job = asyncio.create_task(fetch_arxiv(query))
    pubmed_job = asyncio.create_task(fetch_pubmed(query))
    from_arxiv, from_pubmed = await asyncio.gather(arxiv_job, pubmed_job)
    papers = from_arxiv + from_pubmed

    # Keywords: extracted from all summaries joined together; keep at most 8.
    corpus = " ".join(paper["summary"] for paper in papers)
    terms = extract_keywords(corpus)[:8]

    # Enrichment: one UMLS and one openFDA lookup per keyword, plus the
    # gene/MeSH/DisGeNET bundle and a clinical-trials search, all awaited
    # together.
    umls_lookups = [lookup_umls(term) for term in terms]
    fda_lookups = [fetch_drug_safety(term) for term in terms]
    enrichment_job = asyncio.create_task(_gene_and_mesh_enrichment(terms))
    trials_job = asyncio.create_task(search_trials(query, max_studies=10))

    umls, fda, enrichment, trials = await asyncio.gather(
        asyncio.gather(*umls_lookups),
        asyncio.gather(*fda_lookups),
        enrichment_job,
        trials_job,
    )

    # AI summary of the joined summaries, plus links to the first three papers.
    summary = await ai_summarize(corpus)
    reading_list = [paper["link"] for paper in papers[:3]]

    result: Dict[str, Any] = {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": summary,
        "suggested_reading": reading_list,
        "genes": enrichment["genes"],
        "mesh_definitions": enrichment["meshes"],
        "gene_disease": enrichment["disgenet"],
        "clinical_trials": trials,
    }
    return result
85
 
86
- # ---------------------------------------------------------------------
87
async def answer_ai_question(question: str, context: str = "") -> Dict[str, str]:
    """Answer a free-form question via ``ai_qa``.

    Args:
        question: The question to ask.
        context: Optional supporting text passed along with the question.

    Returns:
        A dict with the single key ``"answer"`` holding the ``ai_qa`` result.
    """
    return {"answer": await ai_qa(question, context)}
 
 
1
  """
2
+ MedGenesis dual-LLM orchestrator
3
+ ----------------------------------
4
+ • Accepts `llm` arg ("openai" | "gemini")
5
+ • Defaults to "openai" if arg omitted
6
  """
7
 
8
+ import asyncio, httpx
9
  from typing import Dict, Any, List
10
 
11
+ from mcp.arxiv import fetch_arxiv
12
+ from mcp.pubmed import fetch_pubmed
13
+ from mcp.nlp import extract_keywords
14
+ from mcp.umls import lookup_umls
15
+ from mcp.openfda import fetch_drug_safety
16
+ from mcp.ncbi import search_gene, get_mesh_definition
17
+ from mcp.disgenet import disease_to_genes
18
  from mcp.clinicaltrials import search_trials
19
+ from mcp.openai_utils import ai_summarize, ai_qa
20
+ from mcp.gemini import gemini_summarize, gemini_qa # make sure gemini.py exists
21
 
22
+ # ---------------- LLM router ----------------
23
def _get_llm(llm: str):
    """Select the ``(summarize, qa)`` helper pair for the requested engine.

    Args:
        llm: Engine name. ``"gemini"`` (any letter case, surrounding
            whitespace ignored) selects the Gemini helpers. Every other
            value, including ``None`` and the empty string, selects the
            OpenAI helpers.

    Returns:
        A ``(summarize, qa)`` tuple of callables.
    """
    # Normalise before comparing so None or padded input falls back to the
    # default engine instead of raising AttributeError on .lower().
    engine = (llm or "").strip().lower()
    if engine == "gemini":
        return gemini_summarize, gemini_qa
    return ai_summarize, ai_qa  # default → OpenAI
 
 
 
 
27
 
28
+
29
async def _enrich_genes_mesh_disg(keys: List[str]) -> Dict[str, Any]:
    """Run the gene, MeSH and DisGeNET lookups for every keyword concurrently.

    Args:
        keys: Keywords to look up; ``search_gene``, ``get_mesh_definition``
            and ``disease_to_genes`` are each called once per keyword.

    Returns:
        A dict with three lists:

        * ``"genes"``    - every ``search_gene`` result, flattened.
        * ``"meshes"``   - one ``get_mesh_definition`` result per successful
          lookup.
        * ``"disgenet"`` - every ``disease_to_genes`` result, flattened.

        Lookups that failed (or were cancelled) are skipped quietly.
    """
    jobs = []
    for k in keys:
        jobs += [search_gene(k), get_mesh_definition(k), disease_to_genes(k)]
    res = await asyncio.gather(*jobs, return_exceptions=True)

    genes, meshes, disg = [], [], []
    # Results arrive in submission order, so position i % 3 identifies the
    # lookup kind: 0 gene (flattened), 1 mesh (kept whole), 2 disgenet
    # (flattened).
    sinks = (genes.extend, meshes.append, disg.extend)
    for i, r in enumerate(res):
        # Test against BaseException rather than Exception: with
        # return_exceptions=True a cancelled child is returned as
        # CancelledError, which has not been an Exception subclass since
        # Python 3.8 and must not reach extend()/append().
        if isinstance(r, BaseException):
            continue
        sinks[i % 3](r)
    return {"genes": genes, "meshes": meshes, "disgenet": disg}
 
 
 
 
 
43
 
 
 
 
 
 
 
 
 
44
 
45
+ # ------------------------------------------------------------------
46
async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
    """Gather literature, enrichment data and an AI summary for *query*.

    Args:
        query: The user's search string; it is sent unchanged to the arXiv,
            PubMed and clinical-trials lookups.
        llm: Engine used for the summary. ``"gemini"`` (any letter case)
            selects Gemini; any other value, including ``None``, selects
            OpenAI.

    Returns:
        A dict for the UI with the keys ``"papers"``, ``"umls"``,
        ``"drug_safety"``, ``"ai_summary"``, ``"llm_used"``, ``"genes"``,
        ``"mesh_defs"``, ``"gene_disease"`` and ``"clinical_trials"``.
        ``"llm_used"`` is always ``"openai"`` or ``"gemini"`` and names the
        engine that produced the summary. The ``"umls"`` and
        ``"drug_safety"`` lists hold one entry per keyword; an entry is the
        raised exception object when that lookup failed.
    """
    # Resolve the engine first so a bad value cannot fail after all the
    # network work is done, and so "llm_used" reports the engine that ran
    # rather than echoing whatever the caller typed.
    engine = (llm or "").strip().lower()
    if engine != "gemini":
        engine = "openai"

    # 1) Literature ---------------------------------------------------
    arxiv_f = asyncio.create_task(fetch_arxiv(query))
    pubmed_f = asyncio.create_task(fetch_pubmed(query))
    arxiv_papers, pubmed_papers = await asyncio.gather(arxiv_f, pubmed_f)
    papers = [*arxiv_papers, *pubmed_papers]

    # 2) Keywords -----------------------------------------------------
    # Extracted from all summaries joined together; keep at most 8.
    blob = " ".join(p["summary"] for p in papers)
    keys = extract_keywords(blob)[:8]

    # 3) Enrichment ---------------------------------------------------
    umls_f = [lookup_umls(k) for k in keys]
    fda_f = [fetch_drug_safety(k) for k in keys]
    genes_f = asyncio.create_task(_enrich_genes_mesh_disg(keys))
    trials_f = asyncio.create_task(search_trials(query, max_studies=10))

    # return_exceptions=True keeps one failed keyword lookup from aborting
    # the whole search; the failure is returned in that keyword's position.
    umls, fda, genes, trials = await asyncio.gather(
        asyncio.gather(*umls_f, return_exceptions=True),
        asyncio.gather(*fda_f, return_exceptions=True),
        genes_f,
        trials_f,
    )

    # 4) AI summary ---------------------------------------------------
    summarize, _ = _get_llm(engine)
    summary = await summarize(blob)

    return {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": summary,
        "llm_used": engine,
        "genes": genes["genes"],
        "mesh_defs": genes["meshes"],
        "gene_disease": genes["disgenet"],
        "clinical_trials": trials,
    }
87
 
88
+
89
async def answer_ai_question(question: str, context: str = "", llm: str = "openai") -> Dict[str, str]:
    """Answer a one-shot follow-up question with the chosen engine.

    Args:
        question: The question to ask.
        context: Optional supporting text passed along with the question.
            Defaults to the empty string so callers may pass only
            ``question``.
        llm: Engine name forwarded to ``_get_llm``.

    Returns:
        A dict with the single key ``"answer"`` holding the engine's reply.
    """
    _, qa = _get_llm(llm)
    return {"answer": await qa(question, context)}