# medrag_multi_modal/assistant/medqa_assistant.py
from typing import Optional

import weave

from ..retrieval import SimilarityMetric
from .figure_annotation import FigureAnnotatorFromPageImage
from .llm_client import LLMClient

class MedQAAssistant(weave.Model):
    """
    Assistant that answers medical queries via retrieval-augmented generation.

    For each query, the assistant retrieves the `top_k_chunks` text chunks most
    similar to the query, optionally collects descriptions of the figures on
    the pages those chunks come from, and prompts an LLM with the query, the
    chunks, and the figure descriptions. The source pages are appended to the
    final response.
    """

llm_client: LLMClient
retriever: weave.Model
figure_annotator: FigureAnnotatorFromPageImage
top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

@weave.op()
def predict(self, query: str, image_artifact_address: Optional[str] = None) -> str:
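        """
        Answer a medical query using retrieved document chunks and, optionally,
        descriptions of figures from the associated pages.

        Args:
            query (str): The medical question to answer.
            image_artifact_address (Optional[str]): Address of the page-image
                artifact used for figure annotation; figure descriptions are
                skipped when this is None.

        Returns:
            str: The LLM response with an appended source-page citation.
        """
        # Retrieve the chunks most similar to the query.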
retrieved_chunks = self.retriever.predict(
query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
)
retrieved_chunk_texts = []
page_indices = set()
for chunk in retrieved_chunks:
retrieved_chunk_texts.append(chunk["text"])
page_indices.add(int(chunk["page_idx"]))
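        # When a page-image artifact is available, gather descriptions of the
        # figures appearing on each page a retrieved chunk came from.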
figure_descriptions = []
if image_artifact_address is not None:
for page_idx in page_indices:
figure_annotations = self.figure_annotator.predict(
page_idx=page_idx, image_artifact_address=image_artifact_address
)
figure_descriptions += [
item["figure_description"] for item in figure_annotations[page_idx]
]
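        # Prompt the LLM with the query, the retrieved chunks, and any figure
        # descriptions gathered above.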
        system_prompt = """
You are an expert in medical science. You are given a query and a list of text
chunks from a medical document, possibly followed by descriptions of figures
from the relevant pages. Answer the query using only the provided context.
"""
response = self.llm_client.predict(
system_prompt=system_prompt,
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
)
        # Append a citation with the 1-indexed source pages; sort the indices
        # so the citation order is deterministic.
        page_numbers = ", ".join(str(page_idx + 1) for page_idx in sorted(page_indices))
        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
return response
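

# Usage sketch: the constructor arguments below are illustrative assumptions,
# since the concrete retriever, LLM client, and annotator configurations are
# defined elsewhere in the package.
#
#     retriever = ...  # any weave.Model whose predict() accepts (query, top_k=..., metric=...)
#     assistant = MedQAAssistant(
#         llm_client=LLMClient(...),
#         retriever=retriever,
#         figure_annotator=FigureAnnotatorFromPageImage(...),
#     )
#     answer = assistant.predict(
#         query="Which nerve innervates the diaphragm?",
#         image_artifact_address="<image-artifact-address>",
#     )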