MCP_Res / mcp /orchestrator.py
mgbam's picture
Update mcp/orchestrator.py
12007d6 verified
"""
MedGenesis – dual-LLM orchestrator
----------------------------------
• Accepts `llm` arg ("openai" | "gemini")
• Defaults to "openai" if arg omitted
"""
import asyncio, httpx
from typing import Dict, Any, List
from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.nlp import extract_keywords
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import search_trials
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa # make sure gemini.py exists
# ---------------- LLM router ----------------
def _get_llm(llm: str):
    """
    Return the ``(summarize, qa)`` callable pair for the requested LLM engine.

    ``"gemini"`` (case-insensitive, surrounding whitespace ignored) selects the
    Gemini helpers. Any other value, including ``None`` or ``""``, falls back
    to the OpenAI helpers, so a missing or unexpected UI value never crashes
    the router.

    Args:
        llm: engine name, e.g. "openai" or "gemini".

    Returns:
        Tuple ``(summarize, qa)``. Callers in this module ``await`` the result
        of calling each one.
    """
    # `or ""` guards against None before .strip()/.lower().
    if (llm or "").strip().lower() == "gemini":
        return gemini_summarize, gemini_qa
    return ai_summarize, ai_qa  # default -> OpenAI
async def _enrich_genes_mesh_disg(keys: List[str]) -> Dict[str, Any]:
    """
    Run gene, MeSH and DisGeNET lookups for every keyword concurrently.

    For each keyword, three lookups are scheduled in a fixed order:
    ``search_gene``, ``get_mesh_definition`` and ``disease_to_genes``.
    Enrichment is best-effort: a lookup that fails (or is cancelled) is
    skipped rather than aborting the others.

    Args:
        keys: keywords to enrich. An empty list yields empty results.

    Returns:
        dict with
          - "genes": the concatenated results of successful ``search_gene`` calls,
          - "meshes": one entry per successful ``get_mesh_definition`` call,
          - "disgenet": the concatenated results of successful
            ``disease_to_genes`` calls.
        (Presumably ``search_gene``/``disease_to_genes`` return lists; a
        ``None``/empty result contributes nothing.)
    """
    jobs = []
    for k in keys:
        jobs += [search_gene(k), get_mesh_definition(k), disease_to_genes(k)]
    results = await asyncio.gather(*jobs, return_exceptions=True)

    def _succeeded(batch: List[Any]) -> List[Any]:
        # gather(return_exceptions=True) hands failures back as values. A
        # cancelled child comes back as CancelledError, which is a
        # BaseException (not an Exception) on Python 3.8+, so test the base.
        return [r for r in batch if not isinstance(r, BaseException)]

    # jobs were queued as (gene, mesh, disgenet) triples, so stride-3 slices
    # recover each lookup kind.
    genes: List[Any] = []
    for hit in _succeeded(results[0::3]):
        genes.extend(hit or ())  # tolerate a None "no result" reply
    meshes = _succeeded(results[1::3])
    disg: List[Any] = []
    for hit in _succeeded(results[2::3]):
        disg.extend(hit or ())
    return {"genes": genes, "meshes": meshes, "disgenet": disg}
# ------------------------------------------------------------------
async def _best_effort(aw: Any, default: Any) -> Any:
    """Await *aw*, returning *default* instead of raising if it fails.

    Only ``Exception`` subclasses are absorbed, so cancellation
    (``CancelledError``, a ``BaseException`` on Python 3.8+) still propagates.
    """
    try:
        return await aw
    except Exception:
        return default


def _without_failures(results: List[Any]) -> List[Any]:
    """Drop the exception objects ``gather(return_exceptions=True)`` returns in place of results."""
    return [r for r in results if not isinstance(r, BaseException)]


async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
    """
    Main orchestrator – returns dict for UI.

    Steps:
      1. Fetch papers from arXiv and PubMed concurrently. A failure here
         propagates to the caller.
      2. Extract up to 8 keywords from the concatenated paper summaries.
      3. Concurrently enrich those keywords (UMLS, openFDA, gene/MeSH/DisGeNET)
         and search clinical trials for *query*. This step is best-effort:
         failed lookups are dropped instead of aborting the whole search.
      4. Summarise the paper text with the selected LLM.

    Args:
        query: free-text search string.
        llm: "gemini" (case-insensitive) selects Gemini; any other value,
            including None, uses OpenAI.

    Returns:
        dict with keys "papers", "umls", "drug_safety", "ai_summary",
        "llm_used" (the engine actually used: "openai" or "gemini"), "genes",
        "mesh_defs", "gene_disease" and "clinical_trials".
    """
    # 1) Literature ---------------------------------------------------
    arxiv_papers, pubmed_papers = await asyncio.gather(
        fetch_arxiv(query), fetch_pubmed(query)
    )
    papers = [*arxiv_papers, *pubmed_papers]

    # 2) Keywords -----------------------------------------------------
    # .get() + `or ""`: one paper lacking a summary must not break the join.
    blob = " ".join(p.get("summary") or "" for p in papers)
    keys = extract_keywords(blob)[:8]

    # 3) Enrichment ---------------------------------------------------
    umls_raw, fda_raw, enrich, trials = await asyncio.gather(
        asyncio.gather(*(lookup_umls(k) for k in keys), return_exceptions=True),
        asyncio.gather(*(fetch_drug_safety(k) for k in keys), return_exceptions=True),
        _best_effort(
            _enrich_genes_mesh_disg(keys),
            {"genes": [], "meshes": [], "disgenet": []},
        ),
        # assumes the UI treats an empty list as "no trials" — TODO confirm
        _best_effort(search_trials(query, max_studies=10), []),
    )
    # Don't hand raw exception objects to the UI; skip failures like the
    # gene/MeSH/DisGeNET enrichment does.
    umls = _without_failures(umls_raw)
    fda = _without_failures(fda_raw)

    # 4) AI summary ---------------------------------------------------
    # Resolve the engine once so "llm_used" reports what actually ran
    # (unknown names fall back to OpenAI).
    engine = "gemini" if (llm or "").strip().lower() == "gemini" else "openai"
    summarize, _ = _get_llm(engine)
    summary = await summarize(blob)

    return {
        "papers"          : papers,
        "umls"            : umls,
        "drug_safety"     : fda,
        "ai_summary"      : summary,
        "llm_used"        : engine,
        "genes"           : enrich["genes"],
        "mesh_defs"       : enrich["meshes"],
        "gene_disease"    : enrich["disgenet"],
        "clinical_trials" : trials,
    }
async def answer_ai_question(question: str, context: str, llm: str = "openai") -> Dict[str, str]:
    """
    Answer a single follow-up question using the chosen LLM engine.

    Args:
        question: the user's follow-up question.
        context: text passed to the engine alongside the question.
        llm: engine name routed through ``_get_llm`` ("gemini" or OpenAI default).

    Returns:
        ``{"answer": <awaited result of the engine's QA callable>}``.
    """
    qa_fn = _get_llm(llm)[1]
    reply = await qa_fn(question, context)
    return {"answer": reply}