|
|
|
""" |
|
Orchestrates retrieval, enrichment, and AI synthesis for a user query. |
|
""" |
|
|
|
import asyncio |
|
from typing import Dict, Any, List |
|
|
|
from mcp.arxiv import fetch_arxiv |
|
from mcp.pubmed import fetch_pubmed |
|
from mcp.nlp import extract_keywords |
|
from mcp.umls import lookup_umls |
|
from mcp.openfda import fetch_drug_safety |
|
from mcp.ncbi import search_gene, get_mesh_definition |
|
from mcp.disgenet import disease_to_genes |
|
from mcp.clinicaltrials import search_trials |
|
from mcp.openai_utils import ai_summarize, ai_qa |
|
|
|
|
|
async def _gene_and_mesh_enrichment(keywords: List[str]) -> Dict[str, Any]: |
|
"""Run NCBI and DisGeNET on keywords in parallel.""" |
|
tasks = [] |
|
for kw in keywords: |
|
tasks.append(search_gene(kw)) |
|
tasks.append(get_mesh_definition(kw)) |
|
tasks.append(disease_to_genes(kw)) |
|
results = await asyncio.gather(*tasks, return_exceptions=True) |
|
|
|
genes, meshes, disgen = [], [], [] |
|
for i, res in enumerate(results): |
|
if isinstance(res, Exception): |
|
continue |
|
|
|
mod = i % 3 |
|
if mod == 0: |
|
genes.extend(res) |
|
elif mod == 1: |
|
meshes.append(res) |
|
else: |
|
disgen.extend(res) |
|
return {"genes": genes, "meshes": meshes, "disgenet": disgen} |
|
|
|
|
|
async def orchestrate_search(query: str) -> Dict[str, Any]:
    """Run the full search pipeline for *query* and return the combined result.

    Pipeline:
      1. Fetch papers from arXiv and PubMed concurrently and concatenate them
         (arXiv results first).
      2. Join the paper summaries into one text and keep at most the first
         eight extracted keywords.
      3. Concurrently run: a UMLS lookup and a drug-safety lookup per keyword,
         the gene/MeSH/DisGeNET enrichment, a clinical-trial search for the
         query (``max_studies=10``), and the AI summary of the joined text.

    Args:
        query: The user's search string.

    Returns:
        A dict with the keys ``"papers"``, ``"umls"``, ``"drug_safety"``,
        ``"ai_summary"``, ``"suggested_reading"`` (the ``"link"`` of the first
        three papers), ``"genes"``, ``"mesh_definitions"``, ``"gene_disease"``
        and ``"clinical_trials"``.

    Raises:
        Exceptions from the paper fetchers, the per-keyword UMLS and
        drug-safety lookups, the trial search or the summarizer are not
        caught here and propagate to the caller.
    """
    # gather() schedules bare coroutines itself, so no explicit create_task.
    arxiv_results, pubmed_results = await asyncio.gather(
        fetch_arxiv(query),
        fetch_pubmed(query),
    )
    papers = arxiv_results + pubmed_results

    # Assumes every paper is a mapping with a "summary" key (and a "link" key
    # for the first three) - TODO confirm against both fetchers.
    paper_text = " ".join(p["summary"] for p in papers)
    keywords = extract_keywords(paper_text)[:8]

    # The summary depends only on paper_text, so it runs alongside the
    # enrichment calls instead of waiting for all of them to finish first.
    umls, fda, enrich, trials, summary = await asyncio.gather(
        asyncio.gather(*(lookup_umls(k) for k in keywords)),
        asyncio.gather(*(fetch_drug_safety(k) for k in keywords)),
        _gene_and_mesh_enrichment(keywords),
        search_trials(query, max_studies=10),
        ai_summarize(paper_text),
    )

    links = [p["link"] for p in papers[:3]]

    return {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": summary,
        "suggested_reading": links,
        "genes": enrich["genes"],
        "mesh_definitions": enrich["meshes"],
        "gene_disease": enrich["disgenet"],
        "clinical_trials": trials,
    }
|
|
|
|
|
async def answer_ai_question(question: str, context: str = "") -> Dict[str, str]:
    """Answer a free-form question by delegating to ``ai_qa``.

    Args:
        question: The question to ask.
        context: Optional text passed through to ``ai_qa`` as its second
            argument; defaults to an empty string.

    Returns:
        A single-key dict, ``{"answer": <value returned by ai_qa>}``.
    """
    return {"answer": await ai_qa(question, context)}
|
|