from transformers import pipeline

from retrieval import get_relevant_pubmed_docs
from visualization import create_medical_graph

# Load the BioGPT model fine-tuned on PubMedQA once at import time so every
# query reuses the same pipeline.
model_name = "microsoft/BioGPT-Large-PubMedQA"
qa_pipeline = pipeline("text-generation", model=model_name)

# Cache retrieved documents per query so get_graph_html can reuse them
# without a second retrieval call.
docs_cache = {}


def process_medical_query(query: str):
    # Retrieve relevant PubMed documents and cache them for graph rendering.
    relevant_docs = get_relevant_pubmed_docs(query)
    docs_cache[query] = relevant_docs

    # Build a question-answering prompt from the retrieved context.
    context_text = "\n\n".join(relevant_docs)
    prompt = f"Question: {query}\nContext: {context_text}\nAnswer:"

    # return_full_text=False keeps the prompt out of the returned answer;
    # truncation guards against contexts longer than the model's window.
    generation = qa_pipeline(
        prompt, max_new_tokens=100, truncation=True, return_full_text=False
    )
    if generation and isinstance(generation, list):
        answer = generation[0]["generated_text"].strip()
    else:
        answer = "No answer found."
    return answer


def get_graph_html(query: str):
    # Reuse the documents cached by process_medical_query; fall back to an
    # empty list if this query has not been processed yet.
    relevant_docs = docs_cache.get(query, [])
    return create_medical_graph(query, relevant_docs)
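

# Minimal usage sketch for running this module directly. It assumes
# create_medical_graph returns an HTML string and uses a purely illustrative
# sample query; both are assumptions, not part of the original module.
if __name__ == "__main__":
    sample_query = "Does metformin reduce cardiovascular mortality?"  # hypothetical
    print(process_medical_query(sample_query))
    # get_graph_html reuses the documents cached by the call above.
    with open("graph.html", "w", encoding="utf-8") as f:
        f.write(get_graph_html(sample_query))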