mgbam committed on
Commit 9958236 · verified · 1 Parent(s): ac3658e

Update mcp/orchestrator.py

Files changed (1)
  1. mcp/orchestrator.py +157 -104
mcp/orchestrator.py CHANGED
@@ -1,127 +1,180 @@
  """
- MedGenesis – dual-LLM orchestrator
- ----------------------------------
- • Accepts llm = "openai" | "gemini"  (falls back to OpenAI)
- • Returns one unified dict the UI can rely on.
  """
  from __future__ import annotations
  import asyncio, itertools, logging
- from typing import Dict, Any, List, Tuple
-
- from mcp.arxiv import fetch_arxiv
- from mcp.pubmed import fetch_pubmed
- from mcp.ncbi import search_gene, get_mesh_definition
- from mcp.mygene import fetch_gene_info
- from mcp.ensembl import fetch_ensembl
- from mcp.opentargets import fetch_ot
- from mcp.umls import lookup_umls
- from mcp.openfda import fetch_drug_safety
- from mcp.disgenet import disease_to_genes
- from mcp.clinicaltrials import search_trials
- from mcp.cbio import fetch_cbio
- from mcp.openai_utils import ai_summarize, ai_qa
- from mcp.gemini import gemini_summarize, gemini_qa

  log = logging.getLogger(__name__)
- _DEF = "openai"  # default engine


- # ─────────────────────────────────── helpers ───────────────────────────────────
- def _llm_router(engine: str = _DEF) -> Tuple:
      if engine.lower() == "gemini":
          return gemini_summarize, gemini_qa, "gemini"
      return ai_summarize, ai_qa, "openai"

- async def _gather_safely(*aws, as_list: bool = True):
-     """await gather() that converts Exception → RuntimeError placeholder"""
-     out = await asyncio.gather(*aws, return_exceptions=True)
-     if as_list:
-         # filter exceptions – keep structure but drop failures
-         return [x for x in out if not isinstance(x, Exception)]
-     return out

- async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
-     jobs = []
-     for k in keys:
-         jobs += [
-             search_gene(k),           # basic gene info
-             get_mesh_definition(k),   # MeSH definitions
-             fetch_gene_info(k),       # MyGene
-             fetch_ensembl(k),         # Ensembl x-refs
-             fetch_ot(k),              # Open Targets associations
-         ]
-     res = await _gather_safely(*jobs, as_list=False)
-
-     # slice & compress five-way fan-out
-     combo = lambda idx: [r for i, r in enumerate(res) if i % 5 == idx and r]
      return {
-         "ncbi"     : combo(0),
-         "mesh"     : combo(1),
-         "mygene"   : combo(2),
-         "ensembl"  : combo(3),
-         "ot_assoc" : combo(4),
      }


- # ───────────────────────────────── orchestrator ────────────────────────────────
- async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
-     """Main entry – returns dict for the Streamlit UI"""
-     # 1 Literature – run in parallel
-     arxiv_task  = asyncio.create_task(fetch_arxiv(query))
-     pubmed_task = asyncio.create_task(fetch_pubmed(query))
-     papers_raw  = await _gather_safely(arxiv_task, pubmed_task)
-     papers = list(itertools.chain.from_iterable(papers_raw))[:30]  # keep ≤30
-
-     # 2 Keyword extraction (very light – only from abstracts)
-     kws = {w for p in papers for w in (p["summary"][:500].split()) if w.isalpha()}
-     kws = list(kws)[:10]  # coarse, fast -> 10 seeds
-
-     # 3 Bio-enrichment fan-out
-     umls_f = [_safe_task(lookup_umls, k) for k in kws]
-     fda_f  = [_safe_task(fetch_drug_safety, k) for k in kws]
-     gene_bundle = asyncio.create_task(_gene_enrichment(kws))
-     trials_task = asyncio.create_task(search_trials(query, max_studies=20))
-     cbio_task   = asyncio.create_task(fetch_cbio(kws[0] if kws else ""))
-
-     umls, fda, gene_dat, trials, variants = await asyncio.gather(
-         _gather_safely(*umls_f),
-         _gather_safely(*fda_f),
-         gene_bundle,
-         trials_task,
-         cbio_task,
-     )

-     # 4 LLM summary
-     summarise_fn, _, engine = _llm_router(llm)
-     summary = await summarise_fn(" ".join(p["summary"] for p in papers)[:12000])

      return {
-         "papers"          : papers,
-         "umls"            : umls,
-         "drug_safety"     : fda,
-         "ai_summary"      : summary,
-         "llm_used"        : engine,
-         "genes"           : gene_dat["ncbi"] + gene_dat["ensembl"] + gene_dat["mygene"],
-         "mesh_defs"       : gene_dat["mesh"],
-         "gene_disease"    : gene_dat["ot_assoc"],
-         "clinical_trials" : trials,
-         "variants"        : variants or [],
      }

- # ─────────────────────────────── follow-up QA ─────────────────────────────────
- async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
-     """Follow-up QA using chosen LLM."""
      _, qa_fn, _ = _llm_router(llm)
-     return {"answer": await qa_fn(f"Q: {question}\nContext: {context}\nA:")}
-
-
- # ─────────────────────────── internal util ───────────────────────────────────
- def _safe_task(fn, *args):
-     """Helper to wrap callable → Task returning RuntimeError on exception."""
-     async def _wrapper():
-         try:
-             return await fn(*args)
-         except Exception as exc:
-             log.warning("background task %s failed: %s", fn.__name__, exc)
-             return RuntimeError(str(exc))
-     return asyncio.create_task(_wrapper())
 
+ #!/usr/bin/env python3
+ # mcp/orchestrator.py
+
  """
+ MedGenesis – dual-LLM orchestrator (v4)
+ ---------------------------------------
+ • Accepts llm="openai" | "gemini" (defaults to OpenAI)
+ • Safely runs all data-source calls in parallel
+ • Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
+ • Returns one dict that the Streamlit UI can rely on
  """
+
  from __future__ import annotations
  import asyncio, itertools, logging
+ from typing import Dict, Any, List
+
+ from mcp.arxiv import fetch_arxiv
+ from mcp.pubmed import fetch_pubmed
+ from mcp.ncbi import search_gene, get_mesh_definition
+ from mcp.mygene import fetch_gene_info
+ from mcp.ensembl import fetch_ensembl
+ from mcp.opentargets import fetch_ot
+ from mcp.umls import lookup_umls
+ from mcp.openfda import fetch_drug_safety
+ from mcp.disgenet import disease_to_genes
+ from mcp.clinicaltrials import fetch_clinical_trials
+ from mcp.cbio import fetch_cbio_variants
+ from mcp.openai_utils import ai_summarize, ai_qa
+ from mcp.gemini import gemini_summarize, gemini_qa

  log = logging.getLogger(__name__)
+ _DEFAULT_LLM = "openai"


+ def _llm_router(engine: str = _DEFAULT_LLM):
+     """Returns (summarize_fn, qa_fn, engine_name)."""
      if engine.lower() == "gemini":
          return gemini_summarize, gemini_qa, "gemini"
      return ai_summarize, ai_qa, "openai"


+ async def _safe_gather(*tasks, return_exceptions: bool = False):
+     """
+     Wrapper around asyncio.gather that logs failures
+     and optionally returns exceptions as results.
+     """
+     results = await asyncio.gather(*tasks, return_exceptions=True)
+     cleaned = []
+     for idx, res in enumerate(results):
+         if isinstance(res, Exception):
+             log.warning("Task %d failed: %s", idx, res)
+             if return_exceptions:
+                 cleaned.append(res)
+         else:
+             cleaned.append(res)
+     return cleaned
+
+
+ async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
+     """
+     Main entry point for MedGenesis UI.
+     Returns a dict with:
+       - papers, umls, drug_safety, clinical_trials, variants
+       - genes, mesh_defs, gene_disease
+       - ai_summary, llm_used
+     """
+     # 1) Literature (PubMed + arXiv in parallel)
+     pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
+     arxiv_t  = asyncio.create_task(fetch_arxiv(query, max_results=7))
+     papers_raw = await _safe_gather(pubmed_t, arxiv_t)
+     papers = list(itertools.chain.from_iterable(papers_raw))[:30]
+
+     # 2) Keyword seeds from abstracts (first 500 chars, split on whitespace)
+     seeds = {
+         w.strip()
+         for p in papers
+         for w in p.get("summary", "")[:500].split()
+         if w.isalpha()
+     }
+     seeds = list(seeds)[:10]
+
+     # 3) Fan-out all bio-enrichment tasks safely
+     umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
+     fda_tasks  = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
+     gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
+     trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
+     cbio_t = asyncio.create_task(
+         fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
+     )
+
+     umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
+         _safe_gather(*umls_tasks, return_exceptions=True),
+         _safe_gather(*fda_tasks, return_exceptions=True),
+         gene_enrich_t,
+         trials_t,
+         cbio_t,
+     )
+
+     # 4) Deduplicate and flatten genes
+     genes = {
+         g["symbol"]
+         for source in (gene_data["ncbi"], gene_data["mygene"], gene_data["ensembl"], gene_data["ot_assoc"])
+         for g in source if isinstance(g, dict) and g.get("symbol")
+     }
+     genes = list(genes)
+
+     # 5) Dedupe variants by (chrom, pos, ref, alt) if returned as dicts
+     seen = set()
+     unique_vars: List[dict] = []
+     for var in variants or []:
+         key = (var.get("chromosome"), var.get("startPosition"), var.get("referenceAllele"), var.get("variantAllele"))
+         if key not in seen:
+             seen.add(key)
+             unique_vars.append(var)
+
+     # 6) LLM summary
+     summarize_fn, _, engine_used = _llm_router(llm)
+     long_text = " ".join(p.get("summary", "") for p in papers)
+     ai_summary = await summarize_fn(long_text[:12000])
+
      return {
+         "papers": papers,
+         "umls": [u for u in umls_list if not isinstance(u, Exception)],
+         "drug_safety": list(itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))),
+         "clinical_trials": trials or [],
+         "variants": unique_vars,
+         "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
+         "mesh_defs": gene_data["mesh"],
+         "gene_disease": gene_data["ot_assoc"],
+         "ai_summary": ai_summary,
+         "llm_used": engine_used,
      }


+ async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
+     """
+     Fan-out gene-related tasks for each seed key:
+       - NCBI gene lookup
+       - MeSH definition
+       - MyGene.info
+       - Ensembl xrefs
+       - OpenTargets associations
+     Returns a dict of lists.
+     """
+     jobs = []
+     for k in keys:
+         jobs.extend([
+             asyncio.create_task(search_gene(k)),
+             asyncio.create_task(get_mesh_definition(k)),
+             asyncio.create_task(fetch_gene_info(k)),
+             asyncio.create_task(fetch_ensembl(k)),
+             asyncio.create_task(fetch_ot(k)),
+         ])

+     results = await _safe_gather(*jobs, return_exceptions=True)
+
+     # Group back into 5 buckets
+     def bucket(idx: int):
+         return [r for i, r in enumerate(results) if i % 5 == idx and not isinstance(r, Exception)]

      return {
+         "ncbi": bucket(0),
+         "mesh": bucket(1),
+         "mygene": bucket(2),
+         "ensembl": bucket(3),
+         "ot_assoc": bucket(4),
      }

+
+ async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
+     """
+     Follow-up QA: wraps the chosen LLM's QA function.
+     """
      _, qa_fn, _ = _llm_router(llm)
+     prompt = f"Q: {question}\nContext: {context}\nA:"
+     try:
+         answer = await qa_fn(prompt)
+     except Exception as e:
+         answer = f"LLM error: {e}"
+     return {"answer": answer}
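
For reviewers who want to exercise the new entry points, here is a minimal caller sketch. It is not part of this commit and assumes the mcp package is importable and that the credentials the fetchers expect (OpenAI/Gemini, NCBI, etc.) are configured; the query string is just an example.

# demo_orchestrator.py – hypothetical standalone caller, not included in the repo
import asyncio

from mcp.orchestrator import orchestrate_search, answer_ai_question


async def main() -> None:
    # One-shot search; the dict keys follow orchestrate_search's docstring.
    res = await orchestrate_search("glioblastoma immunotherapy", llm="openai")
    print(res["llm_used"], len(res["papers"]), "papers,", len(res["clinical_trials"]), "trials")
    print(res["ai_summary"])

    # Follow-up question grounded in the freshly generated summary.
    qa = await answer_ai_question(
        "Which genes look most relevant?",
        context=res["ai_summary"],
        llm="openai",
    )
    print(qa["answer"])


if __name__ == "__main__":
    asyncio.run(main())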
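
The i % 5 regrouping in _gene_enrichment is easy to misread in diff form. The toy below (my own illustration, using fake coroutines in place of the real fetchers so it runs without any mcp modules or API keys) shows how five interleaved jobs per seed key fold back into per-source buckets:

# bucket_demo.py – standalone sketch of the fan-out / i % 5 regroup pattern
import asyncio
from typing import Any, Dict, List


async def fake_fetch(source: str, key: str) -> str:
    await asyncio.sleep(0)  # stand-in for a real API call
    return f"{source}:{key}"


async def enrich(keys: List[str]) -> Dict[str, List[Any]]:
    sources = ["ncbi", "mesh", "mygene", "ensembl", "ot_assoc"]
    # Five jobs per key, in a fixed source order – same layout as _gene_enrichment.
    jobs = [fake_fetch(src, k) for k in keys for src in sources]
    results = await asyncio.gather(*jobs, return_exceptions=True)

    # Every 5th result belongs to the same source, mirroring bucket() in the commit.
    def bucket(idx: int) -> List[Any]:
        return [r for i, r in enumerate(results)
                if i % 5 == idx and not isinstance(r, Exception)]

    return {src: bucket(i) for i, src in enumerate(sources)}


print(asyncio.run(enrich(["TP53", "BRCA1"])))
# {'ncbi': ['ncbi:TP53', 'ncbi:BRCA1'], 'mesh': ['mesh:TP53', 'mesh:BRCA1'], ...}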