mgbam committed on
Commit bc40121 · verified · 1 Parent(s): e5ff04a

Update mcp/orchestrator.py

Files changed (1)
  1. mcp/orchestrator.py +91 -115
mcp/orchestrator.py CHANGED
@@ -1,138 +1,114 @@
  """
- MedGenesis – multi-API orchestrator
- ──────────────────────────────────
- • Supports OpenAI or Gemini (pass llm="openai" | "gemini")
- • Falls back between redundant data sources whenever possible
- • All network I/O is async & individually time-bounded
  """

- from __future__ import annotations
- import asyncio, textwrap
- from typing import Any, Dict, List, Tuple

- # ── 1. Literature helpers ────────────────────────────────────────────
  from mcp.arxiv import fetch_arxiv
  from mcp.pubmed import fetch_pubmed

- # ── 2. Gene / disease / expression helpers ───────────────────────────
- from mcp.gene_hub import resolve_gene # smart dispatcher
- from mcp.mygene import fetch_gene_info
- from mcp.ensembl import fetch_ensembl
- from mcp.opentargets import fetch_ot # tractability, constraint
- from mcp.cbio import fetch_cbio
-
- # ── 3. Safety, trials, concepts ──────────────────────────────────────
- from mcp.openfda import fetch_drug_safety
- from mcp.clinicaltrials import search_trials
  from mcp.umls import lookup_umls
  from mcp.disgenet import disease_to_genes

- # ── 4. Chem & drug metadata ──────────────────────────────────────────
- from mcp.drugcentral_ext import fetch_drugcentral
- from mcp.pubchem_ext import fetch_compound
-
- # ── 5. LLM utils (OpenAI & Gemini) ───────────────────────────────────
  from mcp.openai_utils import ai_summarize, ai_qa
  from mcp.gemini import gemini_summarize, gemini_qa

- ###############################################################################
- # Internal routing helpers
- ###############################################################################
- _DEFAULT_LLM = "openai"
-
- def _llm_router(choice: str) -> Tuple:
-     """
-     Return (summary_fn, qa_fn, tag) for the requested engine.
-     """
-     if str(choice).lower() == "gemini":
-         return gemini_summarize, gemini_qa, "gemini"
-     return ai_summarize, ai_qa, "openai"
-
- ###############################################################################
- # High-level enrichment helpers
- ###############################################################################
- async def _keyword_enrichment(keywords: List[str]) -> Dict[str, Any]:
-     """
-     Fan-out to UMLS, Drug Safety, and probes gene/Disease APIs in parallel.
-     """
-     umls_tasks = [lookup_umls(k) for k in keywords]
-     fda_tasks = [fetch_drug_safety(k) for k in keywords]
-     gene_tasks = [resolve_gene(k) for k in keywords]
-
-     # gather protects against individual failures
-     umls, fda, genes = await asyncio.gather(
-         asyncio.gather(*umls_tasks, return_exceptions=True),
-         asyncio.gather(*fda_tasks, return_exceptions=True),
-         asyncio.gather(*gene_tasks, return_exceptions=True),
      )
-     # flatten & sanitise
-     return {
-         "umls" : [u for u in umls if not isinstance(u, Exception)],
-         "fda" : [d for d in fda if not isinstance(d, Exception)],
-         "genes": [g for g in genes if not isinstance(g, Exception)],
-     }

- ###############################################################################
- # Public orchestration entry-points
- ###############################################################################
- async def orchestrate_search(query: str, *, llm: str=_DEFAULT_LLM,
-                              max_papers: int = 25,
-                              max_trials: int = 20) -> Dict[str, Any]:
-     """
-     Full pipeline:
-       1. Fetch literature (arXiv + PubMed)
-       2. Derive keywords (simple TF filtering)
-       3. Multi-API enrich (UMLS, safety, gene, trials, chem)
-       4. Summarise with LLM
-     """
-
-     # ── 1 literature (parallel) ───────────────────────────────────────
-     arxiv_task = asyncio.create_task(fetch_arxiv(query, max_results=max_papers//2))
-     pubmed_task = asyncio.create_task(fetch_pubmed(query, max_results=max_papers//2))
-     papers = sum(await asyncio.gather(arxiv_task, pubmed_task, return_exceptions=False), [])
-
-     # ── 2 keywords (top-8 by naive word-freq) ─────────────────────────
-     joined = " ".join(p["summary"] for p in papers)
-     tokens = [w for w in joined.split() if len(w) > 4]
-     freq = {}
-     for t in tokens: freq[t] = freq.get(t, 0) + 1
-     keywords = sorted(freq, key=freq.get, reverse=True)[:8]
-
-     # ── 3 enrichment ──────────────────────────────────────────────────
-     enrich_task = asyncio.create_task(_keyword_enrichment(keywords))
-     trials_task = asyncio.create_task(search_trials(query, max_studies=max_trials))
-     gene_dis_gen = asyncio.create_task(disease_to_genes(query)) # coarse disease string
-
-     enrich, trials, gene_dis = await asyncio.gather(enrich_task, trials_task, gene_dis_gen)
-
-     # ── 4 LLM summary & return ────────────────────────────────────────
-     summarise_fn, _, engine_tag = _llm_router(llm)
      try:
-         ai_summary = await summarise_fn(joined[:15000])
      except Exception:
-         ai_summary = "LLM unavailable or quota exceeded."

      return {
          "papers" : papers,
-         "keywords" : keywords,
-         "umls" : enrich["umls"],
-         "drug_safety" : enrich["fda"],
-         "genes" : enrich["genes"],
-         "gene_disease" : gene_dis,
          "clinical_trials" : trials,
-         "ai_summary" : ai_summary,
-         "llm_used" : engine_tag,
      }

-
- async def answer_ai_question(question: str, *, context: str,
-                              llm: str=_DEFAULT_LLM) -> Dict[str, str]:
-     """
-     Follow-up Q-A on demand.
-     """
-     _, qa_fn, _ = _llm_router(llm)
-     try:
-         answer = await qa_fn(question, context)
-     except Exception:
-         answer = "LLM unavailable or quota exceeded."
-     return {"answer": answer}
  """
+ MedGenesis – dual-LLM orchestrator (OpenAI + Gemini)
+ ----------------------------------------------------
+ Returns a single dict the UI expects. New keys:
+
+ • variants – mutation summaries from cBioPortal
+ • variant_count – quick count for empty-tab logic
  """

+ import asyncio
+ from typing import Dict, Any, List

+ # literature + NLP
  from mcp.arxiv import fetch_arxiv
  from mcp.pubmed import fetch_pubmed
+ from mcp.nlp import extract_keywords

+ # enrichment
  from mcp.umls import lookup_umls
+ from mcp.openfda import fetch_drug_safety
+ from mcp.ncbi import search_gene, get_mesh_definition
  from mcp.disgenet import disease_to_genes
+ from mcp.clinicaltrials import search_trials
+ from mcp.mygene import fetch_gene_info
+ from mcp.ensembl import fetch_ensembl
+ from mcp.opentargets import fetch_ot
+ from mcp.cbio import fetch_cbio  # NEW

+ # LLMs
  from mcp.openai_utils import ai_summarize, ai_qa
  from mcp.gemini import gemini_summarize, gemini_qa

+ _DEF = "openai"
+
+ def _llm_router(llm: str):
+     llm = (llm or _DEF).lower()
+     if llm == "gemini":
+         return ("gemini", gemini_summarize, gemini_qa)
+     return ("openai", ai_summarize, ai_qa)
+
+ # ---------------- gene meta helper ----------------
+ async def _resolve_gene(sym: str) -> Dict[str, Any]:
+     for fn in (fetch_gene_info, fetch_ensembl, fetch_ot):
+         try:
+             data = await fn(sym)
+             if data:
+                 return data
+         except Exception:
+             continue
+     return {}
+
+ # ---------------- orchestrator --------------------
+ async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
+     # 1 literature ---------------------------------------------------
+     arxiv_f = asyncio.create_task(fetch_arxiv(query))
+     pubmed_f = asyncio.create_task(fetch_pubmed(query))
+     papers = sum(await asyncio.gather(arxiv_f, pubmed_f), [])
+
+     # 2 keywords -----------------------------------------------------
+     blob = " ".join(p["summary"] for p in papers)
+     keys = extract_keywords(blob)[:8] if blob else []
+
+     # 3 parallel enrichment -----------------------------------------
+     umls_f = [lookup_umls(k) for k in keys]
+     fda_f = [fetch_drug_safety(k) for k in keys]
+     ncbi_f = [search_gene(k) for k in keys]
+     mesh_f = [get_mesh_definition(k) for k in keys]
+     gene_meta = [_resolve_gene(k) for k in keys[:3]]  # cheap
+     trials_f = asyncio.create_task(search_trials(query, max_studies=20))
+
+     # primary await
+     (
+         umls, fda, ncbi, meshes, gmeta, trials
+     ) = await asyncio.gather(
+         asyncio.gather(*umls_f, return_exceptions=True),
+         asyncio.gather(*fda_f, return_exceptions=True),
+         asyncio.gather(*ncbi_f, return_exceptions=True),
+         asyncio.gather(*mesh_f, return_exceptions=True),
+         asyncio.gather(*gene_meta, return_exceptions=True),
+         trials_f,
      )

+     # 4 variants (fire & forget; don't fail whole run) --------------
+     var_jobs = [fetch_cbio(g.get("symbol") or k)
+                 for g, k in zip(gmeta, keys[:len(gmeta)])]
      try:
+         variants = sum(await asyncio.gather(*var_jobs), [])
      except Exception:
+         variants = []
+
+     # 5 LLM summary -------------------------------------------------
+     _, summarise, _ = _llm_router(llm)
+     summary = await summarise(blob) if blob else "No abstracts found."

      return {
          "papers" : papers,
+         "umls" : umls,
+         "drug_safety" : fda,
+         "genes" : sum(ncbi, []),
+         "mesh_defs" : meshes,
+         "gene_meta" : gmeta,
+         "gene_disease" : await disease_to_genes(query) or [],
          "clinical_trials" : trials,
+         "variants" : variants,
+         "variant_count" : len(variants),
+         "ai_summary" : summary,
+         "llm_used" : llm.lower(),
      }

+ # ---------------- follow-up QA --------------------
+ async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
+     _, _, qa_fn = _llm_router(llm)
+     ans = await qa_fn(question, context)
+     return {"answer": ans}