geekyrakshit committed · Commit e197ad0 · 1 Parent(s): b123ef7

update: integrate FigureAnnotatorFromPageImage into MedQAAssistant

medrag_multi_modal/assistant/figure_annotation.py CHANGED
@@ -92,44 +92,48 @@ Here are some clues you need to follow:
         )
 
     @weave.op()
-    def predict(self, image_artifact_address: str):
+    def predict(self, page_idx: int, image_artifact_address: str):
         """
-        Predicts figure annotations for images in a given artifact directory.
+        Predicts figure annotations for a specific page in a document.
 
-        This function retrieves an artifact directory using the provided image artifact address.
-        It reads metadata from a JSONL file in the artifact directory and iterates over each item in the metadata.
-        For each item, it constructs the file path for the page image and checks for the presence of figure image files.
-        If figure image files are found, it reads and converts the page image, then uses the `annotate_figures` method
-        to extract figure annotations from the page image. The extracted annotations are then structured using the
-        `extract_structured_output` method and appended to the annotations list.
+        This function retrieves the artifact directory from the given image artifact address,
+        reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata
+        to find the specified page index. If the page index matches, it reads the page image
+        and associated figure images, and then uses the `annotate_figures` method to extract
+        figure annotations from the page image. The extracted annotations are then structured
+        using the `extract_structured_output` method and returned as a dictionary.
 
         Args:
-            image_artifact_address (str): The address of the image artifact.
+            page_idx (int): The index of the page to annotate.
+            image_artifact_address (str): The address of the image artifact containing the page images.
 
         Returns:
-            list: A list of dictionaries containing page indices and their corresponding figure annotations.
+            dict: A dictionary containing the page index as the key and the extracted figure annotations
+                as the value.
         """
         artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
         metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
-        annotations = []
+        annotations = {}
         for item in track(metadata, description="Annotating images:"):
-            page_image_file = os.path.join(artifact_dir, f"page{item['page_idx']}.png")
-            figure_image_files = glob(
-                os.path.join(artifact_dir, f"page{item['page_idx']}_fig*.png")
-            )
-            if len(figure_image_files) > 0:
-                page_image = cv2.imread(page_image_file)
-                page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
-                page_image = Image.fromarray(page_image)
-                figure_extracted_annotations = self.annotate_figures(
-                    page_image=page_image
+            if item["page_idx"] == page_idx:
+                page_image_file = os.path.join(
+                    artifact_dir, f"page{item['page_idx']}.png"
                 )
-                annotations.append(
-                    {
-                        "page_idx": item["page_idx"],
-                        "annotations": self.extract_structured_output(
-                            figure_extracted_annotations["annotations"]
-                        ).model_dump(),
-                    }
+                figure_image_files = glob(
+                    os.path.join(artifact_dir, f"page{item['page_idx']}_fig*.png")
                 )
+                if len(figure_image_files) > 0:
+                    page_image = cv2.imread(page_image_file)
+                    page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
+                    page_image = Image.fromarray(page_image)
+                    figure_extracted_annotations = self.annotate_figures(
+                        page_image=page_image
+                    )
+                    figure_extracted_annotations = self.extract_structured_output(
+                        figure_extracted_annotations["annotations"]
+                    ).model_dump()
+                    annotations[item["page_idx"]] = figure_extracted_annotations[
+                        "annotations"
+                    ]
+                    break
         return annotations
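
With this change, `predict` annotates a single requested page instead of every page in the artifact, so `MedQAAssistant` can fetch annotations only for the pages it actually retrieved. A minimal usage sketch of the new signature follows; the constructor arguments (the class's fields are not shown in this diff) and the artifact address are illustrative placeholders, not values from this commit:

```python
from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage

# Hypothetical setup: the annotator's actual constructor fields (e.g. which
# LLM clients it wraps) are defined elsewhere in the class and not shown here.
figure_annotator = FigureAnnotatorFromPageImage()

# Annotate only page 12 of the images artifact. Per the new return shape,
# the result is a dict keyed by page index, mapping to the list of structured
# figure annotations (each item carries a "figure_description" field).
annotations = figure_annotator.predict(
    page_idx=12,
    image_artifact_address="entity/project/grays-anatomy-images:latest",  # placeholder
)
print(annotations.get(12, []))  # empty if the page has no extracted figures
```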
medrag_multi_modal/assistant/medqa_assistant.py CHANGED
@@ -1,6 +1,9 @@
+from typing import Optional
+
 import weave
 
 from ..retrieval import SimilarityMetric
+from .figure_annotation import FigureAnnotatorFromPageImage
 from .llm_client import LLMClient
 
 
@@ -9,11 +12,12 @@ class MedQAAssistant(weave.Model):
 
     llm_client: LLMClient
     retriever: weave.Model
+    figure_annotator: FigureAnnotatorFromPageImage
     top_k_chunks: int = 2
     retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
 
     @weave.op()
-    def predict(self, query: str) -> str:
+    def predict(self, query: str, image_artifact_address: Optional[str] = None) -> str:
         retrieved_chunks = self.retriever.predict(
             query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
         )
@@ -23,13 +27,24 @@ class MedQAAssistant(weave.Model):
         for chunk in retrieved_chunks:
             retrieved_chunk_texts.append(chunk["text"])
             page_indices.add(int(chunk["page_idx"]))
-        page_numbers = ", ".join(map(str, page_indices))
+
+        figure_descriptions = []
+        if image_artifact_address is not None:
+            for page_idx in page_indices:
+                figure_annotations = self.figure_annotator.predict(
+                    page_idx=page_idx, image_artifact_address=image_artifact_address
+                )
+                figure_descriptions += [
+                    item["figure_description"] for item in figure_annotations[page_idx]
+                ]
 
         system_prompt = """
         You are an expert in medical science. You are given a query and a list of chunks from a medical document.
         """
         response = self.llm_client.predict(
-            system_prompt=system_prompt, user_prompt=[query, *retrieved_chunk_texts]
+            system_prompt=system_prompt,
+            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
         )
+        page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
        response += f"\n\n**Source:** {'Pages' if len(page_numbers) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
         return response
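
Taken together, the integration wires up roughly as sketched below. Only the new required `figure_annotator` field and the optional `image_artifact_address` argument come from this diff; the pre-built `llm_client`, `retriever`, and `figure_annotator` instances and the artifact address are assumed for illustration:

```python
from medrag_multi_modal.assistant.medqa_assistant import MedQAAssistant

# Assumes llm_client, retriever, and figure_annotator were constructed
# elsewhere; their constructors are outside the scope of this commit.
assistant = MedQAAssistant(
    llm_client=llm_client,
    retriever=retriever,
    figure_annotator=figure_annotator,  # new required field added by this commit
)

# Without an image artifact, behavior is unchanged: answer from text chunks only.
text_only = assistant.predict(query="What is the structure of the hyoid bone?")

# With an artifact address, figure descriptions for each retrieved page are
# fetched via figure_annotator.predict(...) and appended to the user prompt.
with_figures = assistant.predict(
    query="What is the structure of the hyoid bone?",
    image_artifact_address="entity/project/grays-anatomy-images:latest",  # placeholder
)
```

Note that `page_numbers` is now computed after retrieval as one-based labels (`int(page_idx) + 1`), so the source citation reports human-readable page numbers rather than the zero-based page indices used internally.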