geekyrakshit committed on
Commit fd0aa67 · 1 Parent(s): 7302c8f

add: FigureAnnotatorFromPageImage.extract_structured_output

medrag_multi_modal/assistant/figure_annotation.py CHANGED
@@ -4,20 +4,31 @@ from typing import Union
 import cv2
 import weave
 from PIL import Image
+from pydantic import BaseModel
 from rich.progress import track
 
 from ..utils import get_wandb_artifact, read_jsonl_file
 from .llm_client import LLMClient
 
 
+class FigureAnnotation(BaseModel):
+    figure_id: str
+    figure_description: str
+
+
+class FigureAnnotations(BaseModel):
+    annotations: list[FigureAnnotation]
+
+
 class FigureAnnotatorFromPageImage(weave.Model):
-    llm_client: LLMClient
+    figure_extraction_llm_client: LLMClient
+    structured_output_llm_client: LLMClient
 
     @weave.op()
     def annotate_figures(
         self, page_image: Image.Image
     ) -> dict[str, Union[Image.Image, str]]:
-        annotation = self.llm_client.predict(
+        annotation = self.figure_extraction_llm_client.predict(
             system_prompt="""
 You are an expert in the domain of scientific textbooks, especially medical texts.
 You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy.
@@ -43,16 +54,27 @@ Here are some clues you need to follow:
         )
         return {"page_image": page_image, "annotations": annotation}
 
+    @weave.op
+    def extract_structured_output(self, annotations: str) -> FigureAnnotations:
+        return self.structured_output_llm_client.predict(
+            system_prompt="You are suppossed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.",
+            user_prompt=[annotations],
+            schema=FigureAnnotations,
+        )
+
     @weave.op()
     def predict(self, image_artifact_address: str):
         artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
         metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
         annotations = []
         for item in track(metadata, description="Annotating images:"):
-            page_image = cv2.imread(
-                os.path.join(artifact_dir, f"page{item['page_idx']}.png")
-            )
+            page_image_file = os.path.join(artifact_dir, f"page{item['page_idx']}.png")
+            page_image = cv2.imread(page_image_file)
             page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
             page_image = Image.fromarray(page_image)
-            annotations.append(self.annotate_figures(page_image=page_image))
+            figure_extracted_annotations = self.annotate_figures(page_image=page_image)
+            figure_extracted_annotations["annotations"] = self.extract_structured_output(
+                figure_extracted_annotations["annotations"]
+            ).model_dump()
+            annotations.append(figure_extracted_annotations)
         return annotations
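
The annotator now wires together two LLMClient instances: one extracts free-form figure descriptions from a page image, the other coerces that text into the FigureAnnotations schema via extract_structured_output, and predict() attaches the parsed result as a plain dict. A minimal usage sketch, assuming LLMClient is constructed from just a model name; the artifact address is a placeholder and the model names are illustrative picks from OPENAI_MODELS below:

from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
from medrag_multi_modal.assistant.llm_client import LLMClient

annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name="gpt-4o"),
    structured_output_llm_client=LLMClient(model_name="gpt-4o-mini"),
)

# predict() reads metadata.jsonl from the W&B artifact, annotates each page
# image, and stores the structured annotations via model_dump().
page_annotations = annotator.predict(
    image_artifact_address="entity/project/pages:v0"  # placeholder artifact address
)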
medrag_multi_modal/assistant/llm_client.py CHANGED
@@ -12,6 +12,7 @@ from ..utils import base64_encode_image
 class ClientType(str, Enum):
     GEMINI = "gemini"
     MISTRAL = "mistral"
+    OPENAI = "openai"
 
 
 GOOGLE_MODELS = [
@@ -54,6 +55,8 @@ MISTRAL_MODELS = [
     "open-mixtral-8x22b",
 ]
 
+OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]
+
 
 class LLMClient(weave.Model):
     model_name: str
@@ -65,6 +68,8 @@ class LLMClient(weave.Model):
             client_type = ClientType.GEMINI
         elif model_name in MISTRAL_MODELS:
            client_type = ClientType.MISTRAL
+        elif model_name in OPENAI_MODELS:
+            client_type = ClientType.OPENAI
         else:
             raise ValueError(f"Invalid model name: {model_name}")
         super().__init__(model_name=model_name, client_type=client_type)
@@ -139,6 +144,51 @@ class LLMClient(weave.Model):
         )
         return response.choices[0].message.content
 
+    @weave.op()
+    def execute_openai_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from openai import OpenAI
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        system_messages = [
+            {"role": "system", "content": prompt} for prompt in system_prompt
+        ]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": base64_encode_image(prompt, "image/png"),
+                        },
+                    },
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = system_messages + [{"role": "user", "content": user_messages}]
+
+        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+        if schema is None:
+            completion = client.chat.completions.create(
+                model=self.model_name, messages=messages
+            )
+            return completion.choices[0].message.content
+
+        completion = weave.op()(client.beta.chat.completions.parse)(
+            model=self.model_name, messages=messages, response_format=schema
+        )
+        return completion.choices[0].message.parsed
+
     @weave.op()
     def predict(
         self,
@@ -150,5 +200,7 @@ class LLMClient(weave.Model):
             return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
         elif self.client_type == ClientType.MISTRAL:
             return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.OPENAI:
+            return self.execute_openai_sdk(user_prompt, system_prompt, schema)
         else:
             raise ValueError(f"Invalid client type: {self.client_type}")
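
With the new OPENAI client type, any model name in OPENAI_MODELS routes predict() through execute_openai_sdk, which calls chat.completions.create for free-form text and beta.chat.completions.parse when a Pydantic schema is supplied. A rough usage sketch; the prompts are illustrative, OPENAI_API_KEY must be set in the environment, and a system prompt is passed in both calls since the new method builds system messages from it:

from medrag_multi_modal.assistant.figure_annotation import FigureAnnotations
from medrag_multi_modal.assistant.llm_client import LLMClient

client = LLMClient(model_name="gpt-4o-mini")  # resolves to ClientType.OPENAI

# Free-form completion: no schema, so the raw message content is returned as a string.
text = client.predict(
    user_prompt="FIG 1.2 shows the lobes of the human brain.",
    system_prompt="Describe the figures mentioned in the text.",
)

# Structured completion: the schema is forwarded as response_format and the
# parsed Pydantic object (here a FigureAnnotations instance) is returned.
structured = client.predict(
    user_prompt="FIG 1.2: lobes of the human brain; FIG 1.3: the brainstem.",
    system_prompt="Extract figure IDs and descriptions.",
    schema=FigureAnnotations,
)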
pyproject.toml CHANGED
@@ -43,6 +43,7 @@ dependencies = [
     "instructor>=1.6.3",
     "jsonlines>=4.0.0",
     "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
 ]
 
 [project.optional-dependencies]
@@ -71,6 +72,7 @@ core = [
     "instructor>=1.6.3",
     "jsonlines>=4.0.0",
     "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
 ]
 
 dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]
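
The openai>=1.52.2 pin backs the new execute_openai_sdk path. A quick sanity check after installing, as a sketch that only assumes the package is importable in the active environment:

from importlib.metadata import version

# The parse() helper used for structured outputs needs a recent openai SDK,
# hence the ">=1.52.2" constraint added to both dependency groups above.
print(version("openai"))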