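"""Medical question answering over PubMed abstracts.

Retrieves relevant PubMed documents for a query, answers the query with
Microsoft's BioGPT-Large-PubMedQA model, and exposes a helper that renders
the retrieved documents as a graph for the web UI.
"""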
from transformers import pipeline
from retrieval import get_relevant_pubmed_docs
from visualization import create_medical_graph

# Load Microsoft's BioGPT-Large-PubMedQA model for biomedical question answering.
model_name = "microsoft/BioGPT-Large-PubMedQA"
qa_pipeline = pipeline("text-generation", model=model_name)

# Simple in-memory cache for docs so the graph route can fetch them
docs_cache = {}

def process_medical_query(query: str):
    # Retrieve relevant PubMed documents
    relevant_docs = get_relevant_pubmed_docs(query)
    docs_cache[query] = relevant_docs

    # Combine docs as context for the prompt
    context_text = "\n\n".join(relevant_docs)
    prompt = f"Question: {query}\nContext: {context_text}\nAnswer:"

    # Generate up to 100 new tokens, truncating over-long prompts to fit the
    # model's context window. return_full_text=False strips the prompt from
    # the output, so only the generated answer is returned.
    generation = qa_pipeline(
        prompt, max_new_tokens=100, truncation=True, return_full_text=False
    )
    if generation and isinstance(generation, list):
        answer = generation[0]["generated_text"].strip()
    else:
        answer = "No answer found."
    return answer

def get_graph_html(query: str):
    # Reuse the docs cached by process_medical_query; fall back to an empty
    # list if the query has not been processed yet.
    relevant_docs = docs_cache.get(query, [])
    return create_medical_graph(query, relevant_docs)
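
if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: assumes the
    # local retrieval and visualization modules are importable and that PubMed
    # access is configured. The sample query is a made-up example.
    sample_query = "What are the side effects of metformin?"
    print(process_medical_query(sample_query))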