Spaces:
Sleeping
Sleeping
import weave | |
from ..retrieval import SimilarityMetric | |
from .figure_annotation import FigureAnnotatorFromPageImage | |
from .llm_client import LLMClient | |
class MedQAAssistant(weave.Model):
    """
    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
    language model client, a retriever model, and a figure annotator.

    !!! example "Usage Example"
        ```python
        import weave
        from dotenv import load_dotenv

        from medrag_multi_modal.assistant import (
            FigureAnnotatorFromPageImage,
            LLMClient,
            MedQAAssistant,
        )
        from medrag_multi_modal.retrieval import MedCPTRetriever

        load_dotenv()
        weave.init(project_name="ml-colabs/medrag-multi-modal")

        llm_client = LLMClient(model_name="gemini-1.5-flash")
        retriever = MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        )
        figure_annotator = FigureAnnotatorFromPageImage(
            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        )
        medqa_assistant = MedQAAssistant(
            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
        )
        medqa_assistant.predict(query="What is ribosome?")
        ```

    Args:
        llm_client (LLMClient): The language model client used to generate responses.
        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
        top_k_chunks (int): The number of top chunks to retrieve based on similarity metric. Defaults to 2.
        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
            Defaults to `SimilarityMetric.COSINE`.
    """

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: FigureAnnotatorFromPageImage
    top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    def predict(self, query: str) -> str:
        """
        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
        from a medical document and using a language model to generate the final response.

        This function performs the following steps:
        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
        2. Extracts the text and page indices from the retrieved chunks.
        3. Retrieves figure descriptions from the pages identified in the previous step using the figure
           annotator, visiting the pages in ascending order.
        4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
        5. Uses the language model client to generate a response based on the constructed prompts.
        6. Appends the source information (page numbers, in ascending order) to the generated response.
           The source line is omitted when no chunk was retrieved, since there is no page to cite.

        Args:
            query (str): The medical query to be answered.

        Returns:
            str: The generated response to the query, including source information
                whenever at least one page was retrieved.
        """
        retrieved_chunks = self.retriever.predict(
            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
        )

        # Single pass over the retriever output: keep every chunk text, but
        # de-duplicate the pages the chunks came from.
        retrieved_chunk_texts = []
        unique_page_indices = set()
        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            unique_page_indices.add(int(chunk["page_idx"]))

        # A set has no meaningful iteration order; sort once so that both the
        # prompt contents and the citation below are deterministic and ascending.
        page_indices = sorted(unique_page_indices)

        figure_descriptions = []
        for page_idx in page_indices:
            # The annotator's result is indexed by the requested page index.
            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
                page_idx
            ]
            figure_descriptions.extend(
                item["figure_description"] for item in figure_annotations
            )

        system_prompt = """
        You are an expert in medical science. You are given a query and a list of chunks from a medical document.
        """
        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
        )

        # Nothing was retrieved, so there is no page to cite; returning the bare
        # response avoids emitting a malformed "Page  from Gray's Anatomy" footer.
        if not page_indices:
            return response

        # Page indices are turned into page numbers by adding one.
        page_numbers = ", ".join(str(page_idx + 1) for page_idx in page_indices)
        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
        return response