Commit: ceaeef3
Parent(s): 7934a8e

update: FigureAnnotatorFromPageImage
medrag_multi_modal/assistant/__init__.py CHANGED

@@ -1,5 +1,5 @@
-from .figure_annotation import
+from .figure_annotation import FigureAnnotatorFromPageImage
 from .llm_client import ClientType, LLMClient
 from .medqa_assistant import MedQAAssistant
 
-__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "
+__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
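With the updated package __init__, the annotator becomes part of the public interface and can be imported alongside the existing exports:

from medrag_multi_modal.assistant import (
    ClientType,
    FigureAnnotatorFromPageImage,
    LLMClient,
    MedQAAssistant,
)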
medrag_multi_modal/assistant/figure_annotation.py CHANGED

@@ -10,7 +10,7 @@ from ..utils import get_wandb_artifact, read_jsonl_file
 from .llm_client import LLMClient
 
 
-class
+class FigureAnnotatorFromPageImage(weave.Model):
     llm_client: LLMClient
 
     @weave.op()
@@ -24,6 +24,7 @@ You are presented with a page from a scientific textbook.
 You are to first identify the number of figures in the image.
 Then you are to identify the figure IDs associated with each figure in the image.
 Then, you are to extract the exact figure descriptions from the image.
+You need to output the figure IDs and descriptions in a structured manner as a JSON object.
 
 Here are some clues you need to follow:
 1. Figure IDs are unique identifiers for each figure in the image.
@@ -33,6 +34,8 @@ Here are some clues you need to follow:
 5. The text in the image is written in English and is present in a two-column format.
 6. There is a clear distinction between the figure caption and the regular text in the image in the form of extra white space.
 7. There might be multiple figures present in the image.
+8. The figures may or may not have a distinct border against a white background.
+9. There might be multiple figures present in the image. You are to carefully identify all the figures in the image.
 """,
 user_prompt=[page_image],
 )
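A minimal usage sketch of the renamed annotator. The llm_client field and the page_image argument come from the diff; the predict method name, the image loading, and the file path are assumptions for illustration only:

from PIL import Image

from medrag_multi_modal.assistant import FigureAnnotatorFromPageImage, LLMClient

# weave.Model subclasses are pydantic models, so fields are set by keyword.
annotator = FigureAnnotatorFromPageImage(
    llm_client=LLMClient(model_name="gemini-1.5-flash")  # client type inferred; see llm_client.py below
)

# Hypothetical call: the weave op shown above sends the page image together with
# the figure-annotation system prompt and expects a JSON object mapping figure
# IDs to their descriptions.
page_image = Image.open("page_042.png")  # placeholder path
annotations = annotator.predict(page_image=page_image)
print(annotations)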
medrag_multi_modal/assistant/llm_client.py CHANGED

@@ -14,11 +14,59 @@ class ClientType(str, Enum):
     MISTRAL = "mistral"
 
 
+GOOGLE_MODELS = [
+    "gemini-1.0-pro-latest",
+    "gemini-1.0-pro",
+    "gemini-pro",
+    "gemini-1.0-pro-001",
+    "gemini-1.0-pro-vision-latest",
+    "gemini-pro-vision",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-exp-0801",
+    "gemini-1.5-pro-exp-0827",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-001-tuning",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-exp-0827",
+    "gemini-1.5-flash-002",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-flash-8b-001",
+    "gemini-1.5-flash-8b-latest",
+    "gemini-1.5-flash-8b-exp-0827",
+    "gemini-1.5-flash-8b-exp-0924",
+]
+
+MISTRAL_MODELS = [
+    "ministral-3b-latest",
+    "ministral-8b-latest",
+    "mistral-large-latest",
+    "mistral-small-latest",
+    "codestral-latest",
+    "pixtral-12b-2409",
+    "open-mistral-nemo",
+    "open-codestral-mamba",
+    "open-mistral-7b",
+    "open-mixtral-8x7b",
+    "open-mixtral-8x22b",
+]
+
+
 class LLMClient(weave.Model):
     model_name: str
-    client_type: ClientType
+    client_type: Optional[ClientType]
 
-    def __init__(self, model_name: str, client_type: ClientType):
+    def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
+        if client_type is None:
+            if model_name in GOOGLE_MODELS:
+                client_type = ClientType.GEMINI
+            elif model_name in MISTRAL_MODELS:
+                client_type = ClientType.MISTRAL
+            else:
+                raise ValueError(f"Invalid model name: {model_name}")
         super().__init__(model_name=model_name, client_type=client_type)
 
     @weave.op()
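A short sketch of how the reworked constructor behaves; the model names are taken from the lists above, and the explicit-override call mirrors the previous signature:

from medrag_multi_modal.assistant import ClientType, LLMClient

# client_type is now optional and is inferred from the model name.
gemini_client = LLMClient(model_name="gemini-1.5-flash")    # resolves to ClientType.GEMINI
mistral_client = LLMClient(model_name="pixtral-12b-2409")   # resolves to ClientType.MISTRAL

# Passing client_type explicitly still works exactly as before.
explicit_client = LLMClient(model_name="gemini-1.5-pro", client_type=ClientType.GEMINI)

# A model name that is in neither GOOGLE_MODELS nor MISTRAL_MODELS and has no
# explicit client_type raises ValueError("Invalid model name: ...").
try:
    LLMClient(model_name="unknown-model")
except ValueError as err:
    print(err)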