geekyrakshit committed on
Commit
ceaeef3
·
1 Parent(s): 7934a8e

update: FigureAnnotatorFromPageImage

Browse files
medrag_multi_modal/assistant/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- from .figure_annotation import FigureAnnotator
2
  from .llm_client import ClientType, LLMClient
3
  from .medqa_assistant import MedQAAssistant
4
 
5
- __all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotator"]
 
1
+ from .figure_annotation import FigureAnnotatorFromPageImage
2
  from .llm_client import ClientType, LLMClient
3
  from .medqa_assistant import MedQAAssistant
4
 
5
+ __all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
medrag_multi_modal/assistant/figure_annotation.py CHANGED
@@ -10,7 +10,7 @@ from ..utils import get_wandb_artifact, read_jsonl_file
10
  from .llm_client import LLMClient
11
 
12
 
13
- class FigureAnnotator(weave.Model):
14
  llm_client: LLMClient
15
 
16
  @weave.op()
@@ -24,6 +24,7 @@ You are presented with a page from a scientific textbook.
24
  You are to first identify the number of figures in the image.
25
  Then you are to identify the figure IDs associated with each figure in the image.
26
  Then, you are to extract the exact figure descriptions from the image.
 
27
 
28
  Here are some clues you need to follow:
29
  1. Figure IDs are unique identifiers for each figure in the image.
@@ -33,6 +34,8 @@ Here are some clues you need to follow:
33
  5. The text in the image is written in English and is present in a two-column format.
34
  6. There is a clear distinction between the figure caption and the regular text in the image in the form of extra white space.
35
  7. There might be multiple figures present in the image.
 
 
36
  """,
37
  user_prompt=[page_image],
38
  )
 
10
  from .llm_client import LLMClient
11
 
12
 
13
+ class FigureAnnotatorFromPageImage(weave.Model):
14
  llm_client: LLMClient
15
 
16
  @weave.op()
 
24
  You are to first identify the number of figures in the image.
25
  Then you are to identify the figure IDs associated with each figure in the image.
26
  Then, you are to extract the exact figure descriptions from the image.
27
+ You need to output the figure IDs and descriptions in a structured manner as a JSON object.
28
 
29
  Here are some clues you need to follow:
30
  1. Figure IDs are unique identifiers for each figure in the image.
 
34
  5. The text in the image is written in English and is present in a two-column format.
35
  6. There is a clear distinction between the figure caption and the regular text in the image in the form of extra white space.
36
  7. There might be multiple figures present in the image.
37
+ 8. The figures may or may not have a distinct border against a white background.
38
+ 9. There might be multiple figures present in the image. You are to carefully identify all the figures in the image.
39
  """,
40
  user_prompt=[page_image],
41
  )
medrag_multi_modal/assistant/llm_client.py CHANGED
@@ -14,11 +14,59 @@ class ClientType(str, Enum):
14
  MISTRAL = "mistral"
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class LLMClient(weave.Model):
18
  model_name: str
19
- client_type: ClientType
20
 
21
- def __init__(self, model_name: str, client_type: ClientType):
 
 
 
 
 
 
 
22
  super().__init__(model_name=model_name, client_type=client_type)
23
 
24
  @weave.op()
 
14
  MISTRAL = "mistral"
15
 
16
 
17
+ GOOGLE_MODELS = [
18
+ "gemini-1.0-pro-latest",
19
+ "gemini-1.0-pro",
20
+ "gemini-pro",
21
+ "gemini-1.0-pro-001",
22
+ "gemini-1.0-pro-vision-latest",
23
+ "gemini-pro-vision",
24
+ "gemini-1.5-pro-latest",
25
+ "gemini-1.5-pro-001",
26
+ "gemini-1.5-pro-002",
27
+ "gemini-1.5-pro",
28
+ "gemini-1.5-pro-exp-0801",
29
+ "gemini-1.5-pro-exp-0827",
30
+ "gemini-1.5-flash-latest",
31
+ "gemini-1.5-flash-001",
32
+ "gemini-1.5-flash-001-tuning",
33
+ "gemini-1.5-flash",
34
+ "gemini-1.5-flash-exp-0827",
35
+ "gemini-1.5-flash-002",
36
+ "gemini-1.5-flash-8b",
37
+ "gemini-1.5-flash-8b-001",
38
+ "gemini-1.5-flash-8b-latest",
39
+ "gemini-1.5-flash-8b-exp-0827",
40
+ "gemini-1.5-flash-8b-exp-0924",
41
+ ]
42
+
43
+ MISTRAL_MODELS = [
44
+ "ministral-3b-latest",
45
+ "ministral-8b-latest",
46
+ "mistral-large-latest",
47
+ "mistral-small-latest",
48
+ "codestral-latest",
49
+ "pixtral-12b-2409",
50
+ "open-mistral-nemo",
51
+ "open-codestral-mamba",
52
+ "open-mistral-7b",
53
+ "open-mixtral-8x7b",
54
+ "open-mixtral-8x22b",
55
+ ]
56
+
57
+
58
  class LLMClient(weave.Model):
59
  model_name: str
60
+ client_type: Optional[ClientType]
61
 
62
+ def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
63
+ if client_type is None:
64
+ if model_name in GOOGLE_MODELS:
65
+ client_type = ClientType.GEMINI
66
+ elif model_name in MISTRAL_MODELS:
67
+ client_type = ClientType.MISTRAL
68
+ else:
69
+ raise ValueError(f"Invalid model name: {model_name}")
70
  super().__init__(model_name=model_name, client_type=client_type)
71
 
72
  @weave.op()