Spaces:
Running
Running
from typing import Optional | |
import weave | |
from ..retrieval import SimilarityMetric | |
from .figure_annotation import FigureAnnotatorFromPageImage | |
from .llm_client import LLMClient | |
class MedQAAssistant(weave.Model):
    """Answer medical questions with retrieval-augmented generation.

    For a query, the assistant retrieves the most relevant text chunks. It
    optionally gathers figure descriptions for the pages those chunks came
    from, then asks the LLM to answer using that context. The answer is
    suffixed with a citation listing the (1-based) source page numbers.

    Attributes:
        llm_client: Client used to generate the final answer.
        retriever: Model whose ``predict`` returns chunks. Each chunk is read
            via its ``"text"`` and ``"page_idx"`` keys.
        figure_annotator: Model whose ``predict`` returns figure annotations
            for a page. The result is indexed by page index and yields items
            with a ``"figure_description"`` key.
        top_k_chunks: Number of chunks to retrieve per query.
        retrieval_similarity_metric: Similarity metric passed to the retriever.
    """

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: FigureAnnotatorFromPageImage
    top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    def predict(self, query: str, image_artifact_address: Optional[str] = None) -> str:
        """Answer ``query`` using retrieved document context.

        Args:
            query: The user's medical question.
            image_artifact_address: Address of the page-image artifact passed
                to the figure annotator. When given, figure descriptions for
                every retrieved page are added to the LLM context. When
                ``None``, figures are skipped.

        Returns:
            The LLM response. When at least one chunk was retrieved, it is
            followed by a "**Source:**" line listing the 1-based page numbers
            (in ascending order) of the retrieved chunks.
        """
        retrieved_chunks = self.retriever.predict(
            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
        )

        # Single pass over the retriever output, in case it is a one-shot iterable.
        retrieved_chunk_texts = []
        unique_pages = set()
        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            unique_pages.add(int(chunk["page_idx"]))
        # Sort so figure lookups and the citation follow page order,
        # not arbitrary set-iteration order.
        page_indices = sorted(unique_pages)

        figure_descriptions = []
        if image_artifact_address is not None:
            for page_idx in page_indices:
                figure_annotations = self.figure_annotator.predict(
                    page_idx=page_idx, image_artifact_address=image_artifact_address
                )
                figure_descriptions += [
                    item["figure_description"] for item in figure_annotations[page_idx]
                ]

        system_prompt = """
You are an expert in medical science. You are given a query and a list of chunks from a medical document.
"""
        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
        )

        # Without retrieved pages there is nothing to cite; avoid emitting
        # a malformed "Page  from ..." line.
        if page_indices:
            page_numbers = ", ".join(str(page_idx + 1) for page_idx in page_indices)
            # Pluralize on the number of pages, not on the length of the
            # joined string (a single page "10" must stay singular).
            label = "Pages" if len(page_indices) > 1 else "Page"
            response += f"\n\n**Source:** {label} {page_numbers} from Gray's Anatomy"
        return response