Commit 6c6905f · Parent: e197ad0

update: MedQAAssistant + FigureAnnotatorFromPageImage
medrag_multi_modal/assistant/figure_annotation.py CHANGED

````diff
@@ -1,12 +1,11 @@
 import os
 from glob import glob
-from typing import Union
+from typing import Optional, Union
 
 import cv2
 import weave
 from PIL import Image
 from pydantic import BaseModel
-from rich.progress import track
 
 from ..utils import get_wandb_artifact, read_jsonl_file
 from .llm_client import LLMClient
@@ -23,7 +22,8 @@ class FigureAnnotations(BaseModel):
 
 class FigureAnnotatorFromPageImage(weave.Model):
     """
-    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    figures from a page image of a scientific textbook.
 
     !!! example "Example Usage"
         ```python
@@ -39,19 +39,35 @@ class FigureAnnotatorFromPageImage(weave.Model):
         figure_annotator = FigureAnnotatorFromPageImage(
             figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
             structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
         )
-        annotations = figure_annotator.predict(
-            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6"
-        )
+        annotations = figure_annotator.predict(page_idx=34)
         ```
 
-
-        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
-
+    Args:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
+            from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
+            annotations into a structured format.
+        image_artifact_address (Optional[str]): The address of the image artifact containing the
+            page images.
     """
 
     figure_extraction_llm_client: LLMClient
     structured_output_llm_client: LLMClient
+    _artifact_dir: str
+
+    def __init__(
+        self,
+        figure_extraction_llm_client: LLMClient,
+        structured_output_llm_client: LLMClient,
+        image_artifact_address: Optional[str] = None,
+    ):
+        super().__init__(
+            figure_extraction_llm_client=figure_extraction_llm_client,
+            structured_output_llm_client=structured_output_llm_client,
+        )
+        self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
 
     @weave.op()
     def annotate_figures(
@@ -92,7 +108,7 @@ Here are some clues you need to follow:
         )
 
     @weave.op()
-    def predict(self, page_idx: int,
+    def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]:
         """
         Predicts figure annotations for a specific page in a document.
 
@@ -105,22 +121,23 @@ Here are some clues you need to follow:
 
         Args:
             page_idx (int): The index of the page to annotate.
-            image_artifact_address (str): The address of the image artifact containing the
+            image_artifact_address (str): The address of the image artifact containing the
+                page images.
 
         Returns:
-            dict: A dictionary containing the page index as the key and the extracted figure
-
+            dict: A dictionary containing the page index as the key and the extracted figure
+                annotations as the value.
         """
-
-        metadata = read_jsonl_file(os.path.join(
+
+        metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
         annotations = {}
-        for item in
+        for item in metadata:
             if item["page_idx"] == page_idx:
                 page_image_file = os.path.join(
-
+                    self._artifact_dir, f"page{item['page_idx']}.png"
                 )
                 figure_image_files = glob(
-                    os.path.join(
+                    os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
                 )
                 if len(figure_image_files) > 0:
                     page_image = cv2.imread(page_image_file)
````
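Taken together, this change moves artifact resolution out of `predict`: the W&B image artifact is downloaded once in `__init__` via `get_wandb_artifact` and cached in `_artifact_dir`, and `predict` now takes only a page index. Below is a minimal before/after sketch of the call-site migration; the pre-commit `predict` signature is truncated in the diff view above, so the commented-out "before" call is inferred from the removed docstring lines, not confirmed:

```python
from medrag_multi_modal.assistant import FigureAnnotatorFromPageImage, LLMClient

figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
    structured_output_llm_client=LLMClient(model_name="gpt-4o"),
    # New in this commit: the artifact address is passed to the constructor,
    # which resolves it once and caches the local path in `_artifact_dir`.
    image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)

# Before (inferred; the old signature is truncated in the diff view):
# annotations = figure_annotator.predict(
#     page_idx=34,
#     image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
# )

# After: only the page index is needed; the result is a dict keyed by page index.
annotations = figure_annotator.predict(page_idx=34)
```

One loose end worth noting: the new `predict` docstring still documents an `image_artifact_address` argument even though that parameter was removed from the signature, which looks like a leftover from the refactor.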
medrag_multi_modal/assistant/medqa_assistant.py CHANGED

````diff
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import weave
 
 from ..retrieval import SimilarityMetric
@@ -8,7 +6,50 @@ from .llm_client import LLMClient
 
 
 class MedQAAssistant(weave.Model):
-    """
+    """
+    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
+    language model client, a retriever model, and a figure annotator.
+
+    !!! example "Usage Example"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage,
+            LLMClient,
+            MedQAAssistant,
+        )
+        from medrag_multi_modal.retrieval import MedCPTRetriever
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+
+        retriever = MedCPTRetriever.from_wandb_artifact(
+            chunk_dataset_name="grays-anatomy-chunks:v0",
+            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+        )
+
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        medqa_assistant = MedQAAssistant(
+            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
+        )
+        medqa_assistant.predict(query="What is ribosome?")
+        ```
+
+    Args:
+        llm_client (LLMClient): The language model client used to generate responses.
+        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
+        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
+        top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
+        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
+    """
 
     llm_client: LLMClient
     retriever: weave.Model
@@ -17,7 +58,25 @@ class MedQAAssistant(weave.Model):
     retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
 
     @weave.op()
-    def predict(self, query: str
+    def predict(self, query: str) -> str:
+        """
+        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
+        from a medical document and using a language model to generate the final response.
+
+        This function performs the following steps:
+        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
+        2. Extracts the text and page indices from the retrieved chunks.
+        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
+        4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
+        5. Uses the language model client to generate a response based on the constructed prompts.
+        6. Appends the source information (page numbers) to the generated response.
+
+        Args:
+            query (str): The medical query to be answered.
+
+        Returns:
+            str: The generated response to the query, including source information.
+        """
         retrieved_chunks = self.retriever.predict(
             query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
         )
@@ -29,14 +88,13 @@ class MedQAAssistant(weave.Model):
             page_indices.add(int(chunk["page_idx"]))
 
         figure_descriptions = []
-
-
-
-
-
-
-
-            ]
+        for page_idx in page_indices:
+            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
+                page_idx
+            ]
+            figure_descriptions += [
+                item["figure_description"] for item in figure_annotations
+            ]
 
         system_prompt = """
         You are an expert in medical science. You are given a query and a list of chunks from a medical document.
@@ -46,5 +104,5 @@ class MedQAAssistant(weave.Model):
             user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
         )
         page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
-        response += f"\n\n**Source:** {'Pages' if len(
+        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
        return response
````
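The new figure-description loop is what ties the two files together: for each page index surfaced by the retriever, `MedQAAssistant.predict` calls `FigureAnnotatorFromPageImage.predict(page_idx=...)` and indexes the returned dict by that same page index. The sketch below traces that data flow plus the completed source footer; the annotation payload is hypothetical except for the `figure_description` key, which is the only field the commit relies on:

```python
# Hypothetical output of figure_annotator.predict(page_idx=34); the commit's
# loop only depends on the dict-keyed-by-page-index shape and the
# "figure_description" key of each item.
annotations = {
    34: [
        {"figure_description": "Diagram of a ribosome attached to the rough ER."},
    ]
}

page_indices = {34}
figure_descriptions = []
for page_idx in page_indices:
    figure_annotations = annotations[page_idx]  # same indexing as in the commit
    figure_descriptions += [item["figure_description"] for item in figure_annotations]

# The source footer completed by this commit (page_idx is zero-based, hence
# the +1 when rendering human-readable page numbers):
page_numbers = ", ".join(str(int(page_idx) + 1) for page_idx in page_indices)
footer = f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
assert footer == "\n\n**Source:** Page 35 from Gray's Anatomy"
```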