File size: 3,957 Bytes
57eeff5
 
9c2f1fc
 
bd2d9e0
 
 
9c2f1fc
bd2d9e0
 
 
 
 
 
 
 
 
 
9c2f1fc
 
bd2d9e0
 
9c2f1fc
2417938
9c2f1fc
 
 
 
 
bd2d9e0
bc178da
9c2f1fc
2417938
 
9c2f1fc
2417938
9c2f1fc
 
2417938
9c2f1fc
2417938
bd2d9e0
 
 
2417938
9c2f1fc
bc178da
bd2d9e0
2417938
 
 
bc178da
2417938
9c2f1fc
 
 
 
 
 
 
 
bd2d9e0
2417938
bc178da
 
 
 
bd2d9e0
9958236
 
2417938
9c2f1fc
 
2417938
 
 
 
9c2f1fc
2417938
9c2f1fc
2417938
9c2f1fc
bd2d9e0
9c2f1fc
bc178da
 
9958236
0bd4f6b
9958236
bd2d9e0
9c2f1fc
bd2d9e0
bc178da
bd2d9e0
 
 
 
 
2417938
9c2f1fc
bc178da
bd2d9e0
2a8cf8d
bd2d9e0
9c2f1fc
 
 
2417938
9c2f1fc
bc178da
bd2d9e0
2417938
bc178da
2417938
d7bf01e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# mcp/orchestrator.py
import asyncio
import logging
from typing import Any, Dict, List

from mcp.arxiv          import fetch_arxiv
from mcp.pubmed         import fetch_pubmed
from mcp.nlp            import extract_umls_concepts
from mcp.umls           import lookup_umls
from mcp.umls_rel       import fetch_relations
from mcp.openfda        import fetch_drug_safety
from mcp.ncbi           import search_gene, get_mesh_definition
from mcp.disgenet       import disease_to_genes
from mcp.clinicaltrials import search_trials
from mcp.mygene         import mygene
from mcp.opentargets    import ot
from mcp.cbio           import cbio
from mcp.openai_utils   import ai_summarize, ai_qa
from mcp.gemini         import gemini_summarize, gemini_qa
from mcp.embeddings     import embed_texts, cluster_embeddings


def _get_llm(llm: str):
    """
    Pick the (summarize, qa) coroutine-function pair for the requested engine.

    ``"gemini"`` in any letter case selects the Gemini helpers; every other
    value falls back to the OpenAI helpers.
    """
    engines = {"gemini": (gemini_summarize, gemini_qa)}
    return engines.get(llm.lower(), (ai_summarize, ai_qa))


def _ok(value: Any, default: Any, label: str) -> Any:
    """Return *value*, or *default* (after logging) if it is an exception captured by ``gather``."""
    if isinstance(value, BaseException):
        logging.getLogger(__name__).warning("%s failed: %r", label, value)
        return default
    return value


async def _embed_and_cluster(summaries: List[str]):
    """Embed *summaries* and cluster the vectors; each step degrades to ``[]`` on failure."""
    log = logging.getLogger(__name__)
    if not summaries:
        return [], []
    try:
        embeddings = await embed_texts(summaries)
    except Exception:
        log.exception("Embedding %d summaries failed", len(summaries))
        return [], []

    # len() rather than truthiness: embeddings may be an array-like object.
    n_samples = len(embeddings)
    if n_samples == 0:
        return embeddings, []
    # Aim for roughly two papers per cluster within [2, 10], but never ask for
    # more clusters than samples: partitioning clusterers (e.g. KMeans, which
    # cluster_embeddings presumably wraps -- TODO confirm) reject
    # n_clusters > n_samples, which the old formula produced for one paper.
    n_clusters = min(n_samples, max(2, min(10, n_samples // 2)))
    try:
        clusters = await cluster_embeddings(embeddings, n_clusters=n_clusters)
    except Exception:
        log.exception(
            "Clustering %d embeddings into %d clusters failed", n_samples, n_clusters
        )
        clusters = []
    return embeddings, clusters


async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
    """
    Fetch papers, extract concepts & relations, enrich data,
    compute embeddings+clusters, and run LLM summary.

    Every stage is best-effort: a source that raises is logged and degrades to
    an empty value instead of aborting the whole search. The result therefore
    contains only plain data, never exception objects.

    Args:
        query: Free-text query sent to arXiv, PubMed and the trials search.
        llm: ``"gemini"`` (case-insensitive) selects Gemini; anything else OpenAI.

    Returns:
        Dict with keys ``papers``, ``umls``, ``umls_relations``, ``drug_safety``,
        ``genes``, ``mesh_defs``, ``gene_disease``, ``clinical_trials``,
        ``ot_associations``, ``variants``, ``embeddings``, ``clusters``,
        ``ai_summary`` and ``llm_used``.
    """
    log = logging.getLogger(__name__)

    # Gather literature; a source that raises simply contributes no papers.
    lit_results = await asyncio.gather(
        fetch_arxiv(query), fetch_pubmed(query), return_exceptions=True
    )
    papers: List[Dict] = []
    for res in lit_results:
        if isinstance(res, list):
            papers.extend(res)
        elif isinstance(res, BaseException):
            log.warning("Literature source failed: %r", res)

    # Concept extraction over all abstracts at once.
    blob = " ".join(p.get("summary", "") for p in papers)
    try:
        umls = await extract_umls_concepts(blob)
    except Exception:
        log.exception("UMLS concept extraction failed")
        umls = []

    # UMLS relations, one request per concept. Results stay index-aligned with
    # ``umls``; failed lookups become [] rather than (non-serializable)
    # exception objects.
    rel_results = await asyncio.gather(
        *(fetch_relations(c["cui"]) for c in umls), return_exceptions=True
    )
    umls_relations = [_ok(r, [], "UMLS relation lookup") for r in rel_results]

    # Data enrichment tasks. Single-entity lookups key on the first concept only;
    # without concepts they resolve immediately to an empty value.
    names = [c["name"] for c in umls]
    fda_tasks   = [fetch_drug_safety(n) for n in names]
    gene_task   = search_gene(names[0]) if names else asyncio.sleep(0, result=[])
    mesh_task   = get_mesh_definition(names[0]) if names else asyncio.sleep(0, result="")
    dis_task    = disease_to_genes(names[0]) if names else asyncio.sleep(0, result=[])
    trials_task = search_trials(query)
    ot_task     = ot.fetch(names[0]) if names else asyncio.sleep(0, result=[])
    cbio_task   = cbio.fetch_variants(names[0]) if names else asyncio.sleep(0, result=[])

    # Run enrichment. return_exceptions=True so one failing service cannot
    # abort the search; with False the first error propagated and every other
    # service's result was discarded.
    fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
        asyncio.gather(*fda_tasks, return_exceptions=True),
        gene_task, mesh_task, dis_task,
        trials_task, ot_task, cbio_task,
        return_exceptions=True,
    )
    fda      = [_ok(x, [], "OpenFDA lookup") for x in _ok(fda, [], "OpenFDA batch")]
    gene     = _ok(gene, [], "Gene search")
    mesh     = _ok(mesh, "", "MeSH definition")
    dis      = _ok(dis, [], "DisGeNET lookup")
    trials   = _ok(trials, [], "Clinical trials search")
    ot_assoc = _ok(ot_assoc, [], "Open Targets lookup")
    variants = _ok(variants, [], "cBioPortal variants")

    # Embeddings & clustering
    summaries = [p.get("summary", "") for p in papers]
    embeddings, clusters = await _embed_and_cluster(summaries)

    # LLM summary
    summarize_fn, _ = _get_llm(llm)
    try:
        ai_summary = await summarize_fn(blob)
    except Exception:
        log.exception("LLM summary via %r failed", llm)
        ai_summary = "LLM summary failed."

    return {
        "papers": papers,
        "umls": umls,
        "umls_relations": umls_relations,
        "drug_safety": fda,
        "genes": [gene],
        "mesh_defs": [mesh],
        "gene_disease": dis,
        "clinical_trials": trials,
        "ot_associations": ot_assoc,
        "variants": variants,
        "embeddings": embeddings,
        "clusters": clusters,
        "ai_summary": ai_summary,
        "llm_used": llm.lower()
    }


async def answer_ai_question(question: str, context: str = "", llm: str = "openai") -> Dict[str, str]:
    """
    Answer a follow-up question with the selected LLM engine.

    Any exception raised by the engine call is replaced with a fixed fallback
    message, so callers always receive a dict with a single ``"answer"`` key.
    """
    qa_fn = _get_llm(llm)[1]
    try:
        answer = await qa_fn(question, context)
    except Exception:
        answer = "LLM follow-up failed."
    return {"answer": answer}