"""Backend for the medical QA app: BioGPT-based answering over PubMed retrieval."""
from transformers import pipeline
from retrieval import get_relevant_pubmed_docs
from visualization import create_medical_graph
# Use Microsoft BioGPT-Large-PubMedQA
# NOTE: the model weights are downloaded/loaded at import time, so importing
# this module is slow on first run and requires network access for the fetch.
model_name = "microsoft/BioGPT-Large-PubMedQA"
qa_pipeline = pipeline("text-generation", model=model_name)
# Simple in-memory cache for docs so the graph route can fetch them
# NOTE(review): keyed by the raw query string and never evicted — grows
# unboundedly across the process lifetime; consider a bounded cache if
# this serves long-lived traffic.
docs_cache = {}
def process_medical_query(query: str) -> str:
    """Answer a medical question using retrieved PubMed documents as context.

    Retrieves relevant documents, caches them under the raw query string
    (so ``get_graph_html`` can reuse them), builds a QA prompt, and
    generates an answer with the BioGPT pipeline.

    Args:
        query: Free-text medical question.

    Returns:
        The generated answer string, or ``"No answer found."`` when the
        pipeline produced no usable output.
    """
    # Retrieve relevant PubMed documents and cache them for the graph route.
    relevant_docs = get_relevant_pubmed_docs(query)
    docs_cache[query] = relevant_docs

    # Combine docs as context for the prompt.
    context_text = "\n\n".join(relevant_docs)
    prompt = f"Question: {query}\nContext: {context_text}\nAnswer:"

    # max_new_tokens generates beyond the prompt; truncation keeps the
    # input within the model's context window.
    generation = qa_pipeline(prompt, max_new_tokens=100, truncation=True)

    if generation and isinstance(generation, list):
        full_text = generation[0].get("generated_text", "")
        # Bug fix: text-generation pipelines echo the prompt at the start of
        # "generated_text" by default, so the old code returned the question
        # and context along with the answer. Strip the prompt prefix so the
        # caller receives only the generated answer.
        if full_text.startswith(prompt):
            answer = full_text[len(prompt):].strip()
        else:
            answer = full_text.strip()
        if not answer:
            answer = "No answer found."
    else:
        answer = "No answer found."
    return answer
def get_graph_html(query: str):
    """Build the knowledge-graph HTML for a previously processed query.

    Looks up the documents cached by ``process_medical_query``; an unseen
    query falls back to an empty document list.
    """
    cached_docs = docs_cache.get(query, [])
    return create_medical_graph(query, cached_docs)