#!/usr/bin/env python3
"""
MedGenesis – dual-LLM orchestrator (v5)
---------------------------------------
• No external 'pytrials' dependency.
• Uses direct HTTP for clinical trials.
• Clean async fan-out, dual-LLM support.
• A failing upstream service empties only its own section of the result
  instead of aborting the whole search.
"""

from __future__ import annotations

import asyncio
import itertools
import logging
from typing import Any, Dict, List, Tuple

from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import fetch_clinical_trials
from mcp.cbio import fetch_cbio
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa

log = logging.getLogger(__name__)

_DEFAULT_LLM = "openai"

# Keys of the dict returned by _gene_enrichment, in the same order as the
# fetchers it fans out to.
_GENE_KEYS: Tuple[str, ...] = ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")


def _llm_router(engine: str = _DEFAULT_LLM) -> Tuple[Any, Any, str]:
    """
    Choose the async summarization and QA functions for an engine name.

    Returns ``(summarize_fn, qa_fn, engine_used)``. Matching ignores case and
    surrounding whitespace: ``"gemini"`` selects the Gemini helpers, and any
    other value (including ``None`` or ``""``) falls back to OpenAI.
    """
    if (engine or "").strip().lower() == "gemini":
        return gemini_summarize, gemini_qa, "gemini"
    return ai_summarize, ai_qa, "openai"


async def _safe_gather(*tasks, return_exceptions: bool = False) -> List[Any]:
    """
    Await *tasks* concurrently so that one failure never aborts the others.

    Every failure is logged at WARNING level. With ``return_exceptions=True``
    the exception objects stay in place, so the output lines up positionally
    with *tasks*. Otherwise failed entries are dropped and the output may be
    shorter than the input.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned: List[Any] = []
    for r in results:
        # BaseException rather than Exception: since Python 3.8 a cancelled
        # child yields CancelledError (a BaseException), which must be treated
        # as a failure, not passed through as if it were a real result.
        if isinstance(r, BaseException):
            log.warning("Task failed: %s", r)
            if not return_exceptions:
                continue
        cleaned.append(r)
    return cleaned


async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan-out gene-related endpoints for each seed keyword:
      - NCBI gene lookup      -> "ncbi"
      - MeSH definition       -> "mesh"
      - MyGene.info           -> "mygene"
      - Ensembl cross-refs    -> "ensembl"
      - OpenTargets assoc.    -> "ot_assoc"

    Returns a dict mapping each key above to the list of successful results
    (one entry per keyword whose call succeeded); failed calls are omitted.
    """
    # Resolved at call time and listed in _GENE_KEYS order.
    fetchers = (search_gene, get_mesh_definition, fetch_gene_info, fetch_ensembl, fetch_ot)
    jobs = [asyncio.create_task(fetch(k)) for k in keys for fetch in fetchers]
    results = await _safe_gather(*jobs, return_exceptions=True)

    # return_exceptions=True keeps results aligned with jobs, so every
    # `stride`-th entry starting at offset i came from fetchers[i].
    stride = len(fetchers)
    return {
        name: [r for r in results[i::stride] if not isinstance(r, BaseException)]
        for i, name in enumerate(_GENE_KEYS)
    }


def _paper_summary(paper: Dict[str, Any]) -> str:
    """Return a paper's "summary" text, treating a missing or null value as ''."""
    return paper.get("summary") or ""


def _extract_seeds(papers: List[Dict[str, Any]], limit: int = 10) -> List[str]:
    """
    Return up to *limit* distinct alphabetic words, in first-seen order, taken
    from the first 500 characters of each paper summary.

    ``dict.fromkeys`` de-duplicates while preserving insertion order, so the
    same papers always produce the same seeds. Iterating a ``set`` of str
    would vary between interpreter runs because of hash randomization.
    """
    words = (
        word
        for paper in papers
        for word in _paper_summary(paper)[:500].split()
        if word.isalpha()
    )
    return list(dict.fromkeys(words))[:limit]


def _dedupe_variants(variants: Any) -> List[Dict[str, Any]]:
    """
    Drop variants whose (chromosome, startPosition, referenceAllele,
    variantAllele) tuple matches an earlier variant's; the first wins.
    """
    seen: set = set()
    unique: List[Dict[str, Any]] = []
    for v in variants or []:
        key = (
            v.get("chromosome"),
            v.get("startPosition"),
            v.get("referenceAllele"),
            v.get("variantAllele"),
        )
        if key not in seen:
            seen.add(key)
            unique.append(v)
    return unique


def _or_default(value: Any, default: Any) -> Any:
    """Return *value*, or *default* when it is an exception captured by gather."""
    return default if isinstance(value, BaseException) else value


async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point. Performs:
      1. Literature fetch (PubMed + arXiv)
      2. Keyword seed extraction
      3. Concurrent bio-enrichment (UMLS, OpenFDA, gene services),
         clinical trials lookup and cBioPortal variants
      4. Variant de-duplication by genomic coordinates
      5. AI LLM summary

    Upstream failures are logged and the affected section falls back to an
    empty result, or to an ``"LLM error: ..."`` string for the summary. One
    service outage therefore does not discard everything else.

    Returns a unified dict for the UI.
    """
    # 1) Literature
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    # `or []` tolerates a fetcher that returns None instead of a list.
    papers = list(
        itertools.chain.from_iterable(batch or [] for batch in papers_raw)
    )[:30]

    # 2) Seed keywords
    seeds = _extract_seeds(papers)

    # 3) Bio-enrichment fan-out
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_task = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )

    # return_exceptions=True keeps the five results positional, so a failure
    # in one source neither aborts the search nor leaves sibling tasks running.
    umls_list, fda_list, gene_data, trials, variants = await _safe_gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_task,
        trials_t,
        cbio_t,
        return_exceptions=True,
    )
    umls_list = _or_default(umls_list, [])
    fda_list = _or_default(fda_list, [])
    if isinstance(gene_data, BaseException):
        gene_data = {key: [] for key in _GENE_KEYS}
    trials = _or_default(trials, [])
    variants = _or_default(variants, [])

    # 4) Deduplicate variants by genomic coordinates
    unique_vars = _dedupe_variants(variants)

    # 5) LLM-driven summary
    summarize_fn, _, engine_used = _llm_router(llm)
    combined = " ".join(_paper_summary(p) for p in papers)
    try:
        ai_summary = await summarize_fn(combined[:12000])
    except Exception as e:
        # Same reporting convention as answer_ai_question; the fetched data
        # is still returned.
        log.warning("LLM summary failed: %s", e)
        ai_summary = f"LLM error: {e}"

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, BaseException)],
        "drug_safety": list(
            itertools.chain.from_iterable(
                batch for batch in fda_list if isinstance(batch, list)
            )
        ),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }


async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Follow-up QA: ask the chosen LLM's QA function *question* about *context*.

    LLM failures are logged and returned as an ``"LLM error: ..."`` answer
    rather than raised.
    """
    _, qa_fn, _ = _llm_router(llm)
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        answer = await qa_fn(prompt)
    except Exception as e:
        log.warning("LLM QA failed: %s", e)
        answer = f"LLM error: {e}"
    return {"answer": answer}