File size: 3,450 Bytes
9965499
12007d6
 
 
 
9965499
3d539ca
12007d6
9965499
 
12007d6
 
 
 
 
 
 
9965499
12007d6
 
3d539ca
12007d6
 
 
 
 
3637999
12007d6
 
 
 
 
 
 
 
 
 
9965499
12007d6
 
 
 
9965499
 
12007d6
 
 
 
 
 
 
 
 
9965499
12007d6
 
 
9965499
12007d6
 
 
 
 
 
 
 
 
 
 
9965499
 
12007d6
 
 
9965499
3d539ca
9965499
 
 
 
12007d6
 
 
 
9965499
3d539ca
 
12007d6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
MedGenesis – dual-LLM orchestrator
----------------------------------
β€’ Accepts `llm` arg ("openai" | "gemini")
β€’ Defaults to "openai" if arg omitted
"""

import asyncio
import itertools
from typing import Any, Dict, List

import httpx

from mcp.arxiv          import fetch_arxiv
from mcp.pubmed         import fetch_pubmed
from mcp.nlp            import extract_keywords
from mcp.umls           import lookup_umls
from mcp.openfda        import fetch_drug_safety
from mcp.ncbi           import search_gene, get_mesh_definition
from mcp.disgenet       import disease_to_genes
from mcp.clinicaltrials import search_trials
from mcp.openai_utils   import ai_summarize, ai_qa
from mcp.gemini         import gemini_summarize, gemini_qa   # make sure gemini.py exists

# ---------------- LLM router ----------------
def _get_llm(llm: str):
    """Return the ``(summarize, qa)`` coroutine pair for the chosen engine.

    ``"gemini"`` (case-insensitive) selects the Gemini helpers; any other
    value falls back to the OpenAI pair.
    """
    engine = llm.lower()
    if engine == "gemini":
        return gemini_summarize, gemini_qa
    # Anything that isn't "gemini" defaults to OpenAI.
    return ai_summarize, ai_qa


async def _enrich_genes_mesh_disg(keys: List[str]) -> Dict[str, Any]:
    """Fan out gene / MeSH / DisGeNET lookups for each keyword.

    For every key three coroutines are scheduled in a fixed order
    (gene, mesh, disgenet); results are bucketed back by position.
    Individual lookup failures are dropped silently so one bad keyword
    cannot sink the whole enrichment step.
    """
    tasks: List[Any] = []
    for key in keys:
        tasks.extend((search_gene(key), get_mesh_definition(key), disease_to_genes(key)))
    results = await asyncio.gather(*tasks, return_exceptions=True)

    genes: List[Any] = []
    meshes: List[Any] = []
    disg: List[Any] = []
    # gather() preserves submission order, so idx % 3 recovers the bucket.
    for idx, item in enumerate(results):
        if isinstance(item, Exception):  # skip failures quietly
            continue
        slot = idx % 3
        if slot == 0:
            genes.extend(item)
        elif slot == 1:
            meshes.append(item)
        else:
            disg.extend(item)
    return {"genes": genes, "meshes": meshes, "disgenet": disg}


# ------------------------------------------------------------------
async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
    """
    Main orchestrator – returns dict for UI.

    Parameters
    ----------
    query : free-text search string from the UI.
    llm   : "openai" (default) or "gemini" – selects the summariser engine.

    Returns
    -------
    Dict with papers, UMLS hits, drug-safety records, the AI summary,
    gene/MeSH/DisGeNET enrichment and clinical-trial results.
    """
    # 1) Literature ---------------------------------------------------
    arxiv_f  = asyncio.create_task(fetch_arxiv(query))
    pubmed_f = asyncio.create_task(fetch_pubmed(query))
    # chain.from_iterable flattens in O(n); sum(lists, []) is quadratic.
    papers = list(itertools.chain.from_iterable(
        await asyncio.gather(arxiv_f, pubmed_f)
    ))

    # 2) Keywords -----------------------------------------------------
    # .get() guards against paper records that lack a "summary" field.
    blob = " ".join(p.get("summary", "") for p in papers)
    keys = extract_keywords(blob)[:8]

    # 3) Enrichment ---------------------------------------------------
    umls_f   = [lookup_umls(k)       for k in keys]
    fda_f    = [fetch_drug_safety(k) for k in keys]
    genes_f  = asyncio.create_task(_enrich_genes_mesh_disg(keys))
    trials_f = asyncio.create_task(search_trials(query, max_studies=10))

    # UMLS/FDA lookups may fail individually (return_exceptions=True);
    # the enrichment helper and trial search handle their own errors.
    umls, fda, genes, trials = await asyncio.gather(
        asyncio.gather(*umls_f,  return_exceptions=True),
        asyncio.gather(*fda_f,   return_exceptions=True),
        genes_f,
        trials_f,
    )

    # 4) AI summary ---------------------------------------------------
    summarize, _ = _get_llm(llm)
    summary = await summarize(blob)

    return {
        "papers"          : papers,
        "umls"            : umls,
        "drug_safety"     : fda,
        "ai_summary"      : summary,
        "llm_used"        : llm.lower(),
        "genes"           : genes["genes"],
        "mesh_defs"       : genes["meshes"],
        "gene_disease"    : genes["disgenet"],
        "clinical_trials" : trials,
    }


async def answer_ai_question(question: str, context: str, llm: str = "openai") -> Dict[str, str]:
    """One-shot follow-up Q-A via chosen engine."""
    _, qa_fn = _get_llm(llm)
    answer = await qa_fn(question, context)
    return {"answer": answer}