import weave

from ..retrieval import SimilarityMetric
from .figure_annotation import FigureAnnotatorFromPageImage
from .llm_client import LLMClient


class MedQAAssistant(weave.Model):
"""
`MedQAAssistant` is a class designed to assist with medical queries by leveraging a
language model client, a retriever model, and a figure annotator.

    !!! example "Usage Example"
        ```python
import weave
from dotenv import load_dotenv
from medrag_multi_modal.assistant import (
FigureAnnotatorFromPageImage,
LLMClient,
MedQAAssistant,
)
        from medrag_multi_modal.retrieval import MedCPTRetriever

        load_dotenv()
        weave.init(project_name="ml-colabs/medrag-multi-modal")

        llm_client = LLMClient(model_name="gemini-1.5-flash")
        retriever = MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        )
        figure_annotator = FigureAnnotatorFromPageImage(
            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        )
medqa_assistant = MedQAAssistant(
llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
)
        response = medqa_assistant.predict(query="What is a ribosome?")
```

    Args:
llm_client (LLMClient): The language model client used to generate responses.
retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
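
    !!! example "Overriding retrieval defaults"
        A minimal sketch with illustrative values; `SimilarityMetric.COSINE` is simply the
        default made explicit, and `llm_client`, `retriever`, and `figure_annotator` are the
        objects built in the usage example above.

        ```python
        medqa_assistant = MedQAAssistant(
            llm_client=llm_client,
            retriever=retriever,
            figure_annotator=figure_annotator,
            top_k_chunks=5,
            retrieval_similarity_metric=SimilarityMetric.COSINE,
        )
        ```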
"""
llm_client: LLMClient
retriever: weave.Model
figure_annotator: FigureAnnotatorFromPageImage
top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    @weave.op()
def predict(self, query: str) -> str:
"""
        Generates a response to a medical query by retrieving relevant text chunks and figure
        descriptions from a medical document and passing them, together with the query, to a
        language model.

        This method performs the following steps:
1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
2. Extracts the text and page indices from the retrieved chunks.
3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
5. Uses the language model client to generate a response based on the constructed prompts.
6. Appends the source information (page numbers) to the generated response.

        Args:
            query (str): The medical query to be answered.

        Returns:
            str: The generated response to the query, including source information.
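
        !!! example "Usage Example"
            A minimal sketch, assuming `medqa_assistant` was constructed as in the
            class-level example above:

            ```python
            response = medqa_assistant.predict(query="What is a ribosome?")
            print(response)
            ```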
"""
retrieved_chunks = self.retriever.predict(
query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
)
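        # Step 2: collect the chunk texts and the (0-based) page indices they came from.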
retrieved_chunk_texts = []
page_indices = set()
for chunk in retrieved_chunks:
retrieved_chunk_texts.append(chunk["text"])
page_indices.add(int(chunk["page_idx"]))
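        # Step 3: fetch figure descriptions for every page that contributed a chunk.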
figure_descriptions = []
for page_idx in page_indices:
figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
page_idx
]
figure_descriptions += [
item["figure_description"] for item in figure_annotations
]
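        # Steps 4-5: build the prompts and generate the response with the LLM client.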
system_prompt = """
You are an expert in medical science. You are given a query and a list of chunks from a medical document.
"""
response = self.llm_client.predict(
system_prompt=system_prompt,
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
)
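        # Step 6: append source attribution (page indices are 0-based, hence the +1).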
page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
return response