Commit e197ad0
1 Parent(s): b123ef7

update: integrate FigureAnnotatorFromPageImage into MedQAAssistant
medrag_multi_modal/assistant/figure_annotation.py (CHANGED)
@@ -92,44 +92,48 @@ Here are some clues you need to follow:

FigureAnnotatorFromPageImage.predict previously took only image_artifact_address: str, looped over every page listed in metadata.jsonl, and accumulated entries of the form {"page_idx": ..., "annotations": ...}. It now also takes page_idx: int, annotates only the page matching that index, and returns a dict keyed by page index. The changed region now reads:

    @weave.op()
    def predict(self, page_idx: int, image_artifact_address: str):
        """
        Predicts figure annotations for a specific page in a document.

        This function retrieves the artifact directory from the given image artifact address,
        reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata
        to find the specified page index. If the page index matches, it reads the page image
        and associated figure images, and then uses the `annotate_figures` method to extract
        figure annotations from the page image. The extracted annotations are then structured
        using the `extract_structured_output` method and returned as a dictionary.

        Args:
            page_idx (int): The index of the page to annotate.
            image_artifact_address (str): The address of the image artifact containing the page images.

        Returns:
            dict: A dictionary containing the page index as the key and the extracted figure annotations
                as the value.
        """
        artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
        metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
        annotations = {}
        for item in track(metadata, description="Annotating images:"):
            if item["page_idx"] == page_idx:
                page_image_file = os.path.join(
                    artifact_dir, f"page{item['page_idx']}.png"
                )
                figure_image_files = glob(
                    os.path.join(artifact_dir, f"page{item['page_idx']}_fig*.png")
                )
                if len(figure_image_files) > 0:
                    page_image = cv2.imread(page_image_file)
                    page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
                    page_image = Image.fromarray(page_image)
                    figure_extracted_annotations = self.annotate_figures(
                        page_image=page_image
                    )
                    figure_extracted_annotations = self.extract_structured_output(
                        figure_extracted_annotations["annotations"]
                    ).model_dump()
                    annotations[item["page_idx"]] = figure_extracted_annotations[
                        "annotations"
                    ]
                break
        return annotations
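For orientation, here is a minimal consumption sketch of the new per-page API. The helper name annotate_single_page and the .get fallback for pages without figures are illustrative assumptions; the predict(page_idx, image_artifact_address) signature and the "figure_description" key come from this commit and its use in MedQAAssistant below.

```python
from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage


def annotate_single_page(
    figure_annotator: FigureAnnotatorFromPageImage,
    image_artifact_address: str,
    page_idx: int,
) -> list[str]:
    """Return the figure descriptions found on one page of the page-image artifact."""
    # predict now annotates only the requested page and returns a dict keyed by page index.
    annotations = figure_annotator.predict(
        page_idx=page_idx, image_artifact_address=image_artifact_address
    )
    # A page without figures produces no entry, hence the .get fallback (an assumption here).
    return [figure["figure_description"] for figure in annotations.get(page_idx, [])]
```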
medrag_multi_modal/assistant/medqa_assistant.py (CHANGED)
@@ -1,6 +1,9 @@

The assistant module now imports Optional and the new FigureAnnotatorFromPageImage:

from typing import Optional

import weave

from ..retrieval import SimilarityMetric
from .figure_annotation import FigureAnnotatorFromPageImage
from .llm_client import LLMClient

@@ -9,11 +12,12 @@ class MedQAAssistant(weave.Model):

MedQAAssistant gains a figure_annotator field, and predict (previously def predict(self, query: str) -> str) now accepts an optional image artifact address:

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: FigureAnnotatorFromPageImage
    top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    @weave.op()
    def predict(self, query: str, image_artifact_address: Optional[str] = None) -> str:
        retrieved_chunks = self.retriever.predict(
            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
        )

@@ -23,13 +27,24 @@ class MedQAAssistant(weave.Model):

When an image artifact address is given, figure descriptions for the retrieved pages are fetched from the figure annotator and passed to the LLM alongside the query and the retrieved chunk texts, and the page-number string for the source citation is now built just before it is appended to the response:

        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            page_indices.add(int(chunk["page_idx"]))

        figure_descriptions = []
        if image_artifact_address is not None:
            for page_idx in page_indices:
                figure_annotations = self.figure_annotator.predict(
                    page_idx=page_idx, image_artifact_address=image_artifact_address
                )
                figure_descriptions += [
                    item["figure_description"] for item in figure_annotations[page_idx]
                ]

        system_prompt = """
        You are an expert in medical science. You are given a query and a list of chunks from a medical document.
        """
        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
        )
        page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
        response += f"\n\n**Source:** {'Pages' if len(page_numbers) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
        return response
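To show how the pieces fit together after this commit, a hedged wiring sketch follows. Only MedQAAssistant's fields and the new predict signature are taken from this diff; how the LLM client, retriever, and figure annotator are themselves constructed is out of scope here, so they are passed in pre-built, and the helper names are hypothetical.

```python
from typing import Optional

from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
from medrag_multi_modal.assistant.llm_client import LLMClient
from medrag_multi_modal.assistant.medqa_assistant import MedQAAssistant


def build_assistant(
    llm_client: LLMClient,
    retriever,  # any weave.Model retriever whose predict accepts (query, top_k=..., metric=...)
    figure_annotator: FigureAnnotatorFromPageImage,
) -> MedQAAssistant:
    """Assemble the assistant from pre-built components."""
    # top_k_chunks and retrieval_similarity_metric keep their defaults (2 and COSINE).
    return MedQAAssistant(
        llm_client=llm_client,
        retriever=retriever,
        figure_annotator=figure_annotator,
    )


def answer(
    assistant: MedQAAssistant,
    query: str,
    image_artifact_address: Optional[str] = None,
) -> str:
    # Without an artifact address the assistant answers from retrieved text chunks alone;
    # with one, figure descriptions for the retrieved pages are added to the user prompt.
    return assistant.predict(query=query, image_artifact_address=image_artifact_address)
```

The retriever argument is left unannotated because the diff only constrains it to be a weave.Model with a compatible predict method.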