""" MedGenesis – dual-LLM orchestrator ---------------------------------- • Accepts `llm` arg ("openai" | "gemini") • Defaults to "openai" if arg omitted """ import asyncio, httpx from typing import Dict, Any, List from mcp.arxiv import fetch_arxiv from mcp.pubmed import fetch_pubmed from mcp.nlp import extract_keywords from mcp.umls import lookup_umls from mcp.openfda import fetch_drug_safety from mcp.ncbi import search_gene, get_mesh_definition from mcp.disgenet import disease_to_genes from mcp.clinicaltrials import search_trials from mcp.openai_utils import ai_summarize, ai_qa from mcp.gemini import gemini_summarize, gemini_qa # make sure gemini.py exists # ---------------- LLM router ---------------- def _get_llm(llm: str): if llm.lower() == "gemini": return gemini_summarize, gemini_qa return ai_summarize, ai_qa # default → OpenAI async def _enrich_genes_mesh_disg(keys: List[str]) -> Dict[str, Any]: jobs = [] for k in keys: jobs += [search_gene(k), get_mesh_definition(k), disease_to_genes(k)] res = await asyncio.gather(*jobs, return_exceptions=True) genes, meshes, disg = [], [], [] for i, r in enumerate(res): if isinstance(r, Exception): # skip failures quietly continue if i % 3 == 0: genes.extend(r) elif i % 3 == 1: meshes.append(r) else: disg.extend(r) return {"genes": genes, "meshes": meshes, "disgenet": disg} # ------------------------------------------------------------------ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]: """ Main orchestrator – returns dict for UI. """ # 1) Literature --------------------------------------------------- arxiv_f = asyncio.create_task(fetch_arxiv(query)) pubmed_f = asyncio.create_task(fetch_pubmed(query)) papers = sum(await asyncio.gather(arxiv_f, pubmed_f), []) # 2) Keywords ----------------------------------------------------- blob = " ".join(p["summary"] for p in papers) keys = extract_keywords(blob)[:8] # 3) Enrichment --------------------------------------------------- umls_f = [lookup_umls(k) for k in keys] fda_f = [fetch_drug_safety(k) for k in keys] genes_f = asyncio.create_task(_enrich_genes_mesh_disg(keys)) trials_f = asyncio.create_task(search_trials(query, max_studies=10)) umls, fda, genes, trials = await asyncio.gather( asyncio.gather(*umls_f, return_exceptions=True), asyncio.gather(*fda_f, return_exceptions=True), genes_f, trials_f, ) # 4) AI summary --------------------------------------------------- summarize, _ = _get_llm(llm) summary = await summarize(blob) return { "papers" : papers, "umls" : umls, "drug_safety" : fda, "ai_summary" : summary, "llm_used" : llm.lower(), "genes" : genes["genes"], "mesh_defs" : genes["meshes"], "gene_disease" : genes["disgenet"], "clinical_trials" : trials, } async def answer_ai_question(question: str, context: str, llm: str = "openai") -> Dict[str, str]: """One-shot follow-up Q-A via chosen engine.""" _, qa = _get_llm(llm) return {"answer": await qa(question, context)}