Spaces:
Running
on
Zero
Running
on
Zero
smellslikeml
commited on
Commit
•
3e7a2b7
1
Parent(s):
1afbcbd
update
Browse files- README.md +6 -5
- app.py +197 -126
- barchart.jpeg +0 -0
- examples/warehouse_rgb.jpg +0 -0
- models/conversation.py +450 -0
- models/mllava/__init__.py +4 -0
- models/mllava/configuration_llava.py +134 -0
- models/mllava/modeling_llava.py +770 -0
- models/mllava/processing_llava.py +381 -0
- models/mllava/utils.py +188 -0
- requirements.txt +5 -4
README.md
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license:
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: SpaceMantis
|
3 |
+
emoji: 🌌
|
4 |
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.24.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
short_description: Multimodal Language Model specialized for spatial reasoning
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,132 +1,203 @@
|
|
1 |
-
"""SpaceLlama3.1 demo gradio app."""
|
2 |
-
|
3 |
-
import datetime
|
4 |
-
import logging
|
5 |
-
import os
|
6 |
-
|
7 |
import gradio as gr
|
8 |
-
import
|
9 |
-
import
|
10 |
-
|
11 |
-
from
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
"""
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
with gr.Blocks() as demo:
|
89 |
-
|
90 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
with gr.Row():
|
92 |
-
|
93 |
-
|
94 |
-
prompt = gr.Textbox(value="", label="Prompt", visible=True)
|
95 |
-
model_info = gr.Markdown(label="Model Info")
|
96 |
-
run = gr.Button("Run", variant="primary")
|
97 |
-
clear = gr.Button("Clear")
|
98 |
-
highlighted_text = gr.HighlightedText(value="", label="Output", visible=True)
|
99 |
-
|
100 |
-
# Button event handlers
|
101 |
-
run.click(
|
102 |
-
fn=compute,
|
103 |
-
inputs=[image, prompt],
|
104 |
-
outputs=highlighted_text, # Ensure this is the right output component
|
105 |
-
)
|
106 |
-
clear.click(fn=reset, inputs=None, outputs=[prompt, image])
|
107 |
-
|
108 |
-
# Status
|
109 |
-
status = gr.Markdown(f"Startup: {datetime.datetime.now()}")
|
110 |
-
gpu_kind = gr.Markdown(f"GPU=?")
|
111 |
-
demo.load(
|
112 |
-
fn=lambda: f"Model `{MODEL_LOCATION}` loaded.", # Ensure the output is a string
|
113 |
-
inputs=None,
|
114 |
-
outputs=model_info,
|
115 |
-
)
|
116 |
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
if __name__ == "__main__":
|
120 |
-
|
121 |
-
|
122 |
-
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
123 |
-
)
|
124 |
-
|
125 |
-
for k, v in os.environ.items():
|
126 |
-
logging.info('environ["%s"] = %r', k, v)
|
127 |
-
|
128 |
-
# Load the model once globally
|
129 |
-
load_model()
|
130 |
-
|
131 |
-
create_app().queue().launch()
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import spaces
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from PIL import Image
|
6 |
+
import functools
|
7 |
+
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration, chat_mllava
|
8 |
+
from models.conversation import conv_templates
|
9 |
+
from typing import List
|
10 |
+
processor = MLlavaProcessor.from_pretrained("remyxai/SpaceMantis")
|
11 |
+
model = LlavaForConditionalGeneration.from_pretrained("remyxai/SpaceMantis")
|
12 |
+
conv_template = conv_templates['llama_3']
|
13 |
+
|
14 |
+
@spaces.GPU
|
15 |
+
def generate_stream(text:str, images:List[Image.Image], history: List[dict], **kwargs):
|
16 |
+
global processor, model
|
17 |
+
model = model.to("cuda")
|
18 |
+
if not images:
|
19 |
+
images = None
|
20 |
+
for text, history in chat_mllava_stream(text, images, model, processor, history=history, **kwargs):
|
21 |
+
yield text
|
22 |
+
|
23 |
+
return text
|
24 |
+
|
25 |
+
@spaces.GPU
|
26 |
+
def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
|
27 |
+
global processor, model
|
28 |
+
model = model.to("cuda")
|
29 |
+
if not images:
|
30 |
+
images = None
|
31 |
+
generated_text, history = chat_mllava(text, images, model, processor, history=history, **kwargs)
|
32 |
+
return generated_text
|
33 |
+
|
34 |
+
def enable_next_image(uploaded_images, image):
|
35 |
+
uploaded_images.append(image)
|
36 |
+
return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)
|
37 |
+
|
38 |
+
def add_message(history, message):
|
39 |
+
if message["files"]:
|
40 |
+
for file in message["files"]:
|
41 |
+
history.append([(file,), None])
|
42 |
+
if message["text"]:
|
43 |
+
history.append([message["text"], None])
|
44 |
+
return history, gr.MultimodalTextbox(value=None)
|
45 |
+
|
46 |
+
def print_like_dislike(x: gr.LikeData):
|
47 |
+
print(x.index, x.value, x.liked)
|
48 |
+
|
49 |
+
|
50 |
+
def get_chat_history(history):
|
51 |
+
chat_history = []
|
52 |
+
user_role = conv_template.roles[0]
|
53 |
+
assistant_role = conv_template.roles[1]
|
54 |
+
for i, message in enumerate(history):
|
55 |
+
if isinstance(message[0], str):
|
56 |
+
chat_history.append({"role": user_role, "text": message[0]})
|
57 |
+
if i != len(history) - 1:
|
58 |
+
assert message[1], "The bot message is not provided, internal error"
|
59 |
+
chat_history.append({"role": assistant_role, "text": message[1]})
|
60 |
+
else:
|
61 |
+
assert not message[1], "the bot message internal error, get: {}".format(message[1])
|
62 |
+
chat_history.append({"role": assistant_role, "text": ""})
|
63 |
+
return chat_history
|
64 |
+
|
65 |
+
|
66 |
+
def get_chat_images(history):
|
67 |
+
images = []
|
68 |
+
for message in history:
|
69 |
+
if isinstance(message[0], tuple):
|
70 |
+
images.extend(message[0])
|
71 |
+
return images
|
72 |
+
|
73 |
+
|
74 |
+
def bot(history):
|
75 |
+
print(history)
|
76 |
+
cur_messages = {"text": "", "images": []}
|
77 |
+
for message in history[::-1]:
|
78 |
+
if message[1]:
|
79 |
+
break
|
80 |
+
if isinstance(message[0], str):
|
81 |
+
cur_messages["text"] = message[0] + " " + cur_messages["text"]
|
82 |
+
elif isinstance(message[0], tuple):
|
83 |
+
cur_messages["images"].extend(message[0])
|
84 |
+
cur_messages["text"] = cur_messages["text"].strip()
|
85 |
+
cur_messages["images"] = cur_messages["images"][::-1]
|
86 |
+
if not cur_messages["text"]:
|
87 |
+
raise gr.Error("Please enter a message")
|
88 |
+
if cur_messages['text'].count("<image>") < len(cur_messages['images']):
|
89 |
+
gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
|
90 |
+
cur_messages['text'] = "<image> "* (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
|
91 |
+
history[-1][0] = cur_messages["text"]
|
92 |
+
if cur_messages['text'].count("<image>") > len(cur_messages['images']):
|
93 |
+
gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
|
94 |
+
cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
|
95 |
+
history[-1][0] = cur_messages["text"]
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
chat_history = get_chat_history(history)
|
100 |
+
chat_images = get_chat_images(history)
|
101 |
+
|
102 |
+
generation_kwargs = {
|
103 |
+
"max_new_tokens": 4096,
|
104 |
+
"num_beams": 1,
|
105 |
+
"do_sample": False
|
106 |
+
}
|
107 |
+
|
108 |
+
response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
|
109 |
+
for _output in response:
|
110 |
+
history[-1][1] = _output
|
111 |
+
time.sleep(0.05)
|
112 |
+
yield history
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
def build_demo():
|
117 |
with gr.Blocks() as demo:
|
118 |
+
|
119 |
+
gr.Markdown(""" # SpaceMantis
|
120 |
+
Mantis is a multimodal conversational AI model fine-tuned from [Mantis-8B-siglip-llama3](https://huggingface.co/remyxai/SpaceMantis/blob/main/TIGER-Lab/Mantis-8B-siglip-llama3) for enhanced spatial reasoning. It's optimized for multi-image reasoning, where inverleaved text and images can be used to generate responses.
|
121 |
+
|
122 |
+
### [Github](https://github.com/remyxai/VQASynth) | [Model](https://huggingface.co/remyxai/SpaceMantis) | [Dataset](https://huggingface.co/datasets/remyxai/mantis-spacellava)
|
123 |
+
""")
|
124 |
+
|
125 |
+
gr.Markdown("""## Chat with SpaceMantis
|
126 |
+
SpaceMantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
|
127 |
+
The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
|
128 |
+
(The model currently serving is [🤗 remyxai/SpaceMantis](https://huggingface.co/remyxai/SpaceMantis))
|
129 |
+
""")
|
130 |
+
|
131 |
+
chatbot = gr.Chatbot(line_breaks=True)
|
132 |
+
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
|
133 |
+
|
134 |
+
chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
|
135 |
+
|
136 |
+
"""
|
137 |
+
with gr.Accordion(label='Advanced options', open=False):
|
138 |
+
temperature = gr.Slider(
|
139 |
+
label='Temperature',
|
140 |
+
minimum=0.1,
|
141 |
+
maximum=2.0,
|
142 |
+
step=0.1,
|
143 |
+
value=0.2,
|
144 |
+
interactive=True
|
145 |
+
)
|
146 |
+
top_p = gr.Slider(
|
147 |
+
label='Top-p',
|
148 |
+
minimum=0.05,
|
149 |
+
maximum=1.0,
|
150 |
+
step=0.05,
|
151 |
+
value=1.0,
|
152 |
+
interactive=True
|
153 |
+
)
|
154 |
+
"""
|
155 |
+
|
156 |
+
bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
|
157 |
+
|
158 |
+
chatbot.like(print_like_dislike, None, None)
|
159 |
+
|
160 |
with gr.Row():
|
161 |
+
send_button = gr.Button("Send")
|
162 |
+
clear_button = gr.ClearButton([chatbot, chat_input])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
+
send_button.click(
|
165 |
+
add_message, [chatbot, chat_input], [chatbot, chat_input]
|
166 |
+
).then(
|
167 |
+
bot, chatbot, chatbot, api_name="bot_response"
|
168 |
+
)
|
169 |
+
|
170 |
+
gr.Examples(
|
171 |
+
examples=[
|
172 |
+
{
|
173 |
+
"text": "Give me the height of the man in the red hat in feet.",
|
174 |
+
"files": ["./examples/warehouse_rgb.jpg"]
|
175 |
+
},
|
176 |
+
],
|
177 |
+
inputs=[chat_input],
|
178 |
+
)
|
179 |
+
|
180 |
+
gr.Markdown("""
|
181 |
+
## Citation
|
182 |
+
```
|
183 |
+
@article{chen2024spatialvlm,
|
184 |
+
title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities},
|
185 |
+
author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei},
|
186 |
+
journal = {arXiv preprint arXiv:2401.12168},
|
187 |
+
year = {2024},
|
188 |
+
url = {https://arxiv.org/abs/2401.12168},
|
189 |
+
}
|
190 |
+
|
191 |
+
@article{jiang2024mantis,
|
192 |
+
title={MANTIS: Interleaved Multi-Image Instruction Tuning},
|
193 |
+
author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
|
194 |
+
journal={arXiv preprint arXiv:2405.01483},
|
195 |
+
year={2024}
|
196 |
+
}
|
197 |
+
```""")
|
198 |
+
return demo
|
199 |
+
|
200 |
|
201 |
if __name__ == "__main__":
|
202 |
+
demo = build_demo()
|
203 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
barchart.jpeg
ADDED
examples/warehouse_rgb.jpg
ADDED
models/conversation.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
PLAIN = auto()
|
12 |
+
LLAMA_2 = auto()
|
13 |
+
LLAMA_3 = auto()
|
14 |
+
MFuyu = auto()
|
15 |
+
|
16 |
+
|
17 |
+
@dataclasses.dataclass
|
18 |
+
class Conversation:
|
19 |
+
"""A class that keeps all conversation history."""
|
20 |
+
system: str
|
21 |
+
roles: List[str]
|
22 |
+
messages: List[List[str]]
|
23 |
+
offset: int
|
24 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
25 |
+
sep: str = "###"
|
26 |
+
sep2: str = None
|
27 |
+
version: str = "Unknown"
|
28 |
+
|
29 |
+
skip_next: bool = False
|
30 |
+
|
31 |
+
def get_prompt(self):
|
32 |
+
messages = self.messages
|
33 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
34 |
+
|
35 |
+
messages = self.messages.copy()
|
36 |
+
init_role, init_msg = messages[0].copy()
|
37 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
38 |
+
if 'mmtag' in self.version:
|
39 |
+
messages[0] = (init_role, init_msg)
|
40 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
41 |
+
messages.insert(1, (self.roles[1], "Received."))
|
42 |
+
else:
|
43 |
+
messages[0] = (init_role, "<image>" + init_msg)
|
44 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
45 |
+
ret = self.system + self.sep
|
46 |
+
for role, message in messages:
|
47 |
+
if message:
|
48 |
+
if type(message) is tuple:
|
49 |
+
message, _, _ = message
|
50 |
+
ret += role + ": " + message + self.sep
|
51 |
+
else:
|
52 |
+
ret += role + ":"
|
53 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
54 |
+
seps = [self.sep, self.sep2]
|
55 |
+
ret = self.system + seps[0]
|
56 |
+
for i, (role, message) in enumerate(messages):
|
57 |
+
if message:
|
58 |
+
if type(message) is tuple:
|
59 |
+
message, _, _ = message
|
60 |
+
ret += role + ": " + message + seps[i % 2]
|
61 |
+
else:
|
62 |
+
ret += role + ":"
|
63 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
64 |
+
ret = self.system + self.sep
|
65 |
+
for role, message in messages:
|
66 |
+
if message:
|
67 |
+
if type(message) is tuple:
|
68 |
+
message, _, _ = message
|
69 |
+
ret += role + message + self.sep
|
70 |
+
else:
|
71 |
+
ret += role
|
72 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
73 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
74 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
75 |
+
ret = ""
|
76 |
+
|
77 |
+
for i, (role, message) in enumerate(messages):
|
78 |
+
if i == 0:
|
79 |
+
assert message, "first message should not be none"
|
80 |
+
assert role == self.roles[0], "first message should come from user"
|
81 |
+
if message:
|
82 |
+
if type(message) is tuple:
|
83 |
+
message, _, _ = message
|
84 |
+
if i == 0: message = wrap_sys(self.system) + message
|
85 |
+
if i % 2 == 0:
|
86 |
+
message = wrap_inst(message)
|
87 |
+
ret += self.sep + message
|
88 |
+
else:
|
89 |
+
ret += " " + message + " " + self.sep2
|
90 |
+
else:
|
91 |
+
ret += ""
|
92 |
+
ret = ret.lstrip(self.sep)
|
93 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
94 |
+
ret = self.system + self.sep
|
95 |
+
for role, message in messages:
|
96 |
+
if message:
|
97 |
+
if type(message) is tuple:
|
98 |
+
message, _, _ = message
|
99 |
+
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + message + self.sep
|
100 |
+
else:
|
101 |
+
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
102 |
+
elif self.sep_style == SeparatorStyle.MFuyu:
|
103 |
+
seps = [self.sep, self.sep2]
|
104 |
+
ret = self.system + "\n"
|
105 |
+
for i, (role, message) in enumerate(messages):
|
106 |
+
if message:
|
107 |
+
if type(message) is tuple:
|
108 |
+
message, _, _ = message
|
109 |
+
ret += role + ": " + message + seps[i % 2]
|
110 |
+
else:
|
111 |
+
ret += role + ":"
|
112 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
113 |
+
seps = [self.sep, self.sep2]
|
114 |
+
ret = self.system
|
115 |
+
for i, (role, message) in enumerate(messages):
|
116 |
+
if message:
|
117 |
+
if type(message) is tuple:
|
118 |
+
message, _, _ = message
|
119 |
+
ret += message + seps[i % 2]
|
120 |
+
else:
|
121 |
+
ret += ""
|
122 |
+
else:
|
123 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
124 |
+
|
125 |
+
return ret
|
126 |
+
|
127 |
+
def append_message(self, role, message):
|
128 |
+
self.messages.append([role, message])
|
129 |
+
|
130 |
+
def get_images(self, return_pil=False):
|
131 |
+
images = []
|
132 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
133 |
+
if i % 2 == 0:
|
134 |
+
if type(msg) is tuple:
|
135 |
+
import base64
|
136 |
+
from io import BytesIO
|
137 |
+
from PIL import Image
|
138 |
+
msg, image, image_process_mode = msg
|
139 |
+
if image_process_mode == "Pad":
|
140 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
141 |
+
width, height = pil_img.size
|
142 |
+
if width == height:
|
143 |
+
return pil_img
|
144 |
+
elif width > height:
|
145 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
146 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
147 |
+
return result
|
148 |
+
else:
|
149 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
150 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
151 |
+
return result
|
152 |
+
image = expand2square(image)
|
153 |
+
elif image_process_mode in ["Default", "Crop"]:
|
154 |
+
pass
|
155 |
+
elif image_process_mode == "Resize":
|
156 |
+
image = image.resize((336, 336))
|
157 |
+
else:
|
158 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
159 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
160 |
+
aspect_ratio = max_hw / min_hw
|
161 |
+
max_len, min_len = 800, 400
|
162 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
163 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
164 |
+
W, H = image.size
|
165 |
+
if longest_edge != max(image.size):
|
166 |
+
if H > W:
|
167 |
+
H, W = longest_edge, shortest_edge
|
168 |
+
else:
|
169 |
+
H, W = shortest_edge, longest_edge
|
170 |
+
image = image.resize((W, H))
|
171 |
+
if return_pil:
|
172 |
+
images.append(image)
|
173 |
+
else:
|
174 |
+
buffered = BytesIO()
|
175 |
+
image.save(buffered, format="PNG")
|
176 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
177 |
+
images.append(img_b64_str)
|
178 |
+
return images
|
179 |
+
|
180 |
+
def to_gradio_chatbot(self):
|
181 |
+
ret = []
|
182 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
183 |
+
if i % 2 == 0:
|
184 |
+
if type(msg) is tuple:
|
185 |
+
import base64
|
186 |
+
from io import BytesIO
|
187 |
+
msg, image, image_process_mode = msg
|
188 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
189 |
+
aspect_ratio = max_hw / min_hw
|
190 |
+
max_len, min_len = 800, 400
|
191 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
192 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
193 |
+
W, H = image.size
|
194 |
+
if H > W:
|
195 |
+
H, W = longest_edge, shortest_edge
|
196 |
+
else:
|
197 |
+
H, W = shortest_edge, longest_edge
|
198 |
+
image = image.resize((W, H))
|
199 |
+
buffered = BytesIO()
|
200 |
+
image.save(buffered, format="JPEG")
|
201 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
202 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
203 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
204 |
+
ret.append([msg, None])
|
205 |
+
else:
|
206 |
+
ret.append([msg, None])
|
207 |
+
else:
|
208 |
+
ret[-1][-1] = msg
|
209 |
+
return ret
|
210 |
+
|
211 |
+
def copy(self):
|
212 |
+
return Conversation(
|
213 |
+
system=self.system,
|
214 |
+
roles=self.roles,
|
215 |
+
messages=[[x, y] for x, y in self.messages],
|
216 |
+
offset=self.offset,
|
217 |
+
sep_style=self.sep_style,
|
218 |
+
sep=self.sep,
|
219 |
+
sep2=self.sep2,
|
220 |
+
version=self.version)
|
221 |
+
|
222 |
+
def dict(self):
|
223 |
+
if len(self.get_images()) > 0:
|
224 |
+
return {
|
225 |
+
"system": self.system,
|
226 |
+
"roles": self.roles,
|
227 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
228 |
+
"offset": self.offset,
|
229 |
+
"sep": self.sep,
|
230 |
+
"sep2": self.sep2,
|
231 |
+
}
|
232 |
+
return {
|
233 |
+
"system": self.system,
|
234 |
+
"roles": self.roles,
|
235 |
+
"messages": self.messages,
|
236 |
+
"offset": self.offset,
|
237 |
+
"sep": self.sep,
|
238 |
+
"sep2": self.sep2,
|
239 |
+
}
|
240 |
+
|
241 |
+
|
242 |
+
conv_vicuna_v0 = Conversation(
|
243 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
244 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
245 |
+
roles=("Human", "Assistant"),
|
246 |
+
messages=(
|
247 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
248 |
+
("Assistant",
|
249 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
250 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
251 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
252 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
253 |
+
"renewable and non-renewable energy sources:\n"
|
254 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
255 |
+
"energy sources are finite and will eventually run out.\n"
|
256 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
257 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
258 |
+
"and other negative effects.\n"
|
259 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
260 |
+
"have lower operational costs than non-renewable sources.\n"
|
261 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
262 |
+
"locations than non-renewable sources.\n"
|
263 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
264 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
265 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
266 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
267 |
+
),
|
268 |
+
offset=2,
|
269 |
+
sep_style=SeparatorStyle.SINGLE,
|
270 |
+
sep="###",
|
271 |
+
)
|
272 |
+
|
273 |
+
conv_vicuna_v1 = Conversation(
|
274 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
275 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
276 |
+
roles=("USER", "ASSISTANT"),
|
277 |
+
version="v1",
|
278 |
+
messages=(),
|
279 |
+
offset=0,
|
280 |
+
sep_style=SeparatorStyle.TWO,
|
281 |
+
sep=" ",
|
282 |
+
sep2="</s>",
|
283 |
+
)
|
284 |
+
|
285 |
+
conv_llama_2 = Conversation(
|
286 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
287 |
+
|
288 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
289 |
+
roles=("USER", "ASSISTANT"),
|
290 |
+
version="llama_v2",
|
291 |
+
messages=(),
|
292 |
+
offset=0,
|
293 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
294 |
+
sep="<s>",
|
295 |
+
sep2="</s>",
|
296 |
+
)
|
297 |
+
|
298 |
+
conv_llava_llama_2 = Conversation(
|
299 |
+
system="You are a helpful language and vision assistant. "
|
300 |
+
"You are able to understand the visual content that the user provides, "
|
301 |
+
"and assist the user with a variety of tasks using natural language.",
|
302 |
+
roles=("USER", "ASSISTANT"),
|
303 |
+
version="llama_v2",
|
304 |
+
messages=(),
|
305 |
+
offset=0,
|
306 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
307 |
+
sep="<s>",
|
308 |
+
sep2="</s>",
|
309 |
+
)
|
310 |
+
|
311 |
+
conv_mpt = Conversation(
|
312 |
+
system="""<|im_start|>system
|
313 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
314 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
315 |
+
version="mpt",
|
316 |
+
messages=(),
|
317 |
+
offset=0,
|
318 |
+
sep_style=SeparatorStyle.MPT,
|
319 |
+
sep="<|im_end|>",
|
320 |
+
)
|
321 |
+
|
322 |
+
conv_llava_plain = Conversation(
|
323 |
+
system="",
|
324 |
+
roles=("", ""),
|
325 |
+
messages=(
|
326 |
+
),
|
327 |
+
offset=0,
|
328 |
+
sep_style=SeparatorStyle.PLAIN,
|
329 |
+
sep="\n",
|
330 |
+
)
|
331 |
+
|
332 |
+
conv_llava_v0 = Conversation(
|
333 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
334 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
335 |
+
roles=("Human", "Assistant"),
|
336 |
+
messages=(
|
337 |
+
),
|
338 |
+
offset=0,
|
339 |
+
sep_style=SeparatorStyle.SINGLE,
|
340 |
+
sep="###",
|
341 |
+
)
|
342 |
+
|
343 |
+
conv_llava_v0_mmtag = Conversation(
|
344 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
345 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
346 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
347 |
+
roles=("Human", "Assistant"),
|
348 |
+
messages=(
|
349 |
+
),
|
350 |
+
offset=0,
|
351 |
+
sep_style=SeparatorStyle.SINGLE,
|
352 |
+
sep="###",
|
353 |
+
version="v0_mmtag",
|
354 |
+
)
|
355 |
+
|
356 |
+
conv_llava_v1 = Conversation(
|
357 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
358 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
359 |
+
roles=("USER", "ASSISTANT"),
|
360 |
+
version="v1",
|
361 |
+
messages=(),
|
362 |
+
offset=0,
|
363 |
+
sep_style=SeparatorStyle.TWO,
|
364 |
+
sep=" ",
|
365 |
+
sep2="</s>",
|
366 |
+
)
|
367 |
+
|
368 |
+
conv_llava_v1_mmtag = Conversation(
|
369 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
370 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
371 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
372 |
+
roles=("USER", "ASSISTANT"),
|
373 |
+
messages=(),
|
374 |
+
offset=0,
|
375 |
+
sep_style=SeparatorStyle.TWO,
|
376 |
+
sep=" ",
|
377 |
+
sep2="</s>",
|
378 |
+
version="v1_mmtag",
|
379 |
+
)
|
380 |
+
|
381 |
+
conv_mfuyu_v1 = Conversation(
|
382 |
+
system="You are a helpful language and vision assistant. "
|
383 |
+
"You are able to understand the visual content that the user provides, "
|
384 |
+
"and assist the user with a variety of tasks using natural language.",
|
385 |
+
roles=("USER", "ASSISTANT"),
|
386 |
+
version="v1",
|
387 |
+
messages=(),
|
388 |
+
offset=0,
|
389 |
+
sep_style=SeparatorStyle.MFuyu,
|
390 |
+
sep="<0x04>", # begin of answer token
|
391 |
+
sep2="|ENDOFTEXT|",
|
392 |
+
) # copied from conv_vicuna_v1
|
393 |
+
|
394 |
+
conv_mllava_v1_mmtag = Conversation(
|
395 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
396 |
+
"The assistant is able to understand the multiple visual contents that the user provides, and assist the user with a variety of tasks using natural language."
|
397 |
+
"Each visual content will be provided with the following format: <Image>visual content</Image>.",
|
398 |
+
roles=("USER", "ASSISTANT"),
|
399 |
+
messages=(),
|
400 |
+
offset=0,
|
401 |
+
sep_style=SeparatorStyle.SINGLE,
|
402 |
+
sep="</s>",
|
403 |
+
version="v1_mmtag",
|
404 |
+
)
|
405 |
+
|
406 |
+
conv_mllava_v1 = Conversation(
|
407 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
408 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
409 |
+
roles=("USER", "ASSISTANT"),
|
410 |
+
version="v1",
|
411 |
+
messages=(),
|
412 |
+
offset=0,
|
413 |
+
sep_style=SeparatorStyle.SINGLE,
|
414 |
+
sep="</s>",
|
415 |
+
)
|
416 |
+
|
417 |
+
conv_llama_3 = Conversation(
|
418 |
+
system="<|start_header_id|>system<|end_header_id|>\n\nYou are a pirate chatbot who always responds in pirate speak!",
|
419 |
+
roles=("user", "assistant"),
|
420 |
+
messages=(),
|
421 |
+
offset=0,
|
422 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
423 |
+
sep="<|eot_id|>",
|
424 |
+
)
|
425 |
+
|
426 |
+
default_conversation = conv_mfuyu_v1
|
427 |
+
conv_templates = {
|
428 |
+
"default": conv_vicuna_v0,
|
429 |
+
"v0": conv_vicuna_v0,
|
430 |
+
"v1": conv_vicuna_v1,
|
431 |
+
"vicuna_v1": conv_vicuna_v1,
|
432 |
+
"llama_2": conv_llama_2,
|
433 |
+
|
434 |
+
"plain": conv_llava_plain,
|
435 |
+
"v0_plain": conv_llava_plain,
|
436 |
+
"llava_v0": conv_llava_v0,
|
437 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
438 |
+
"llava_v1": conv_llava_v1,
|
439 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
440 |
+
"llava_llama_2": conv_llava_llama_2,
|
441 |
+
"llama_3": conv_llama_3,
|
442 |
+
"mllava_v1": conv_mllava_v1,
|
443 |
+
"mllava_v1_mmtag": conv_mllava_v1_mmtag,
|
444 |
+
|
445 |
+
"mpt": conv_mpt,
|
446 |
+
}
|
447 |
+
|
448 |
+
|
449 |
+
if __name__ == "__main__":
|
450 |
+
print(default_conversation.get_prompt())
|
models/mllava/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
|
2 |
+
from .processing_llava import MLlavaProcessor
|
3 |
+
from .configuration_llava import LlavaConfig
|
4 |
+
from .utils import chat_mllava, chat_mllava_stream
|
models/mllava/configuration_llava.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
""" Llava model configuration"""
|
15 |
+
|
16 |
+
|
17 |
+
# from ...configuration_utils import PretrainedConfig
|
18 |
+
# from ...utils import logging
|
19 |
+
# from ..auto import CONFIG_MAPPING
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
from transformers.models.auto import CONFIG_MAPPING
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
28 |
+
"llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class LlavaConfig(PretrainedConfig):
|
33 |
+
r"""
|
34 |
+
This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
|
35 |
+
Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
36 |
+
with the defaults will yield a similar configuration to that of the Llava-9B.
|
37 |
+
|
38 |
+
e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
|
39 |
+
|
40 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
41 |
+
documentation from [`PretrainedConfig`] for more information.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
vision_config (`LlavaVisionConfig`, *optional*):
|
45 |
+
Custom vision config or dict
|
46 |
+
text_config (`Union[AutoConfig, dict]`, *optional*):
|
47 |
+
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
|
48 |
+
ignore_index (`int`, *optional*, defaults to -100):
|
49 |
+
The ignore index for the loss function.
|
50 |
+
image_token_index (`int`, *optional*, defaults to 32000):
|
51 |
+
The image token index to encode the image prompt.
|
52 |
+
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
53 |
+
The activation function used by the multimodal projector.
|
54 |
+
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
55 |
+
The feature selection strategy used to select the vision feature from the CLIP backbone.
|
56 |
+
vision_feature_layer (`int`, *optional*, defaults to -2):
|
57 |
+
The index of the layer to select the vision feature.
|
58 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
59 |
+
Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
|
60 |
+
`inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
|
61 |
+
|
62 |
+
Example:
|
63 |
+
|
64 |
+
```python
|
65 |
+
>>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
|
66 |
+
|
67 |
+
>>> # Initializing a CLIP-vision config
|
68 |
+
>>> vision_config = CLIPVisionConfig()
|
69 |
+
|
70 |
+
>>> # Initializing a Llama config
|
71 |
+
>>> text_config = LlamaConfig()
|
72 |
+
|
73 |
+
>>> # Initializing a Llava llava-1.5-7b style configuration
|
74 |
+
>>> configuration = LlavaConfig(vision_config, text_config)
|
75 |
+
|
76 |
+
>>> # Initializing a model from the llava-1.5-7b style configuration
|
77 |
+
>>> model = LlavaForConditionalGeneration(configuration)
|
78 |
+
|
79 |
+
>>> # Accessing the model configuration
|
80 |
+
>>> configuration = model.config
|
81 |
+
```"""
|
82 |
+
|
83 |
+
model_type = "llava"
|
84 |
+
is_composition = False
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
vision_config=None,
|
89 |
+
text_config=None,
|
90 |
+
ignore_index=-100,
|
91 |
+
image_token_index=32000,
|
92 |
+
projector_hidden_act="gelu",
|
93 |
+
vision_feature_select_strategy="default",
|
94 |
+
vision_feature_layer=-2,
|
95 |
+
vocab_size=32000,
|
96 |
+
**kwargs,
|
97 |
+
):
|
98 |
+
self.ignore_index = ignore_index
|
99 |
+
self.image_token_index = image_token_index
|
100 |
+
self.projector_hidden_act = projector_hidden_act
|
101 |
+
self.vision_feature_select_strategy = vision_feature_select_strategy
|
102 |
+
self.vision_feature_layer = vision_feature_layer
|
103 |
+
self.vocab_size = vocab_size
|
104 |
+
|
105 |
+
self.vision_config = vision_config
|
106 |
+
|
107 |
+
if isinstance(self.vision_config, dict):
|
108 |
+
vision_config["model_type"] = (
|
109 |
+
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
|
110 |
+
)
|
111 |
+
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
112 |
+
elif vision_config is None:
|
113 |
+
self.vision_config = CONFIG_MAPPING["clip_vision_model"](
|
114 |
+
intermediate_size=4096,
|
115 |
+
hidden_size=1024,
|
116 |
+
patch_size=14,
|
117 |
+
image_size=336,
|
118 |
+
num_hidden_layers=24,
|
119 |
+
num_attention_heads=16,
|
120 |
+
vocab_size=32000,
|
121 |
+
projection_dim=768,
|
122 |
+
)
|
123 |
+
self.vocab_size = self.vocab_size
|
124 |
+
|
125 |
+
self.text_config = text_config
|
126 |
+
|
127 |
+
if isinstance(self.text_config, dict):
|
128 |
+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
129 |
+
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
130 |
+
self.vocab_size = self.text_config.vocab_size
|
131 |
+
elif text_config is None:
|
132 |
+
self.text_config = CONFIG_MAPPING["llama"]()
|
133 |
+
|
134 |
+
super().__init__(**kwargs)
|
models/mllava/modeling_llava.py
ADDED
@@ -0,0 +1,770 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch Llava model."""
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
# from ... import PreTrainedModel
|
24 |
+
# from ...activations import ACT2FN
|
25 |
+
# from ...cache_utils import Cache
|
26 |
+
# from ...modeling_outputs import ModelOutput
|
27 |
+
# from ...utils import (
|
28 |
+
# add_start_docstrings,
|
29 |
+
# add_start_docstrings_to_model_forward,
|
30 |
+
# logging,
|
31 |
+
# replace_return_docstrings,
|
32 |
+
# )
|
33 |
+
# from ..auto import AutoModel, AutoModelForCausalLM
|
34 |
+
|
35 |
+
from .configuration_llava import LlavaConfig
|
36 |
+
|
37 |
+
from transformers import PreTrainedModel
|
38 |
+
from transformers.activations import ACT2FN
|
39 |
+
from transformers.cache_utils import Cache
|
40 |
+
from transformers.modeling_outputs import ModelOutput
|
41 |
+
from transformers.utils import (
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
48 |
+
from .configuration_llava import LlavaConfig
|
49 |
+
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__)
|
52 |
+
|
53 |
+
_CONFIG_FOR_DOC = "LlavaConfig"
|
54 |
+
|
55 |
+
LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
56 |
+
"llava-hf/llava-1.5-7b-hf",
|
57 |
+
"llava-hf/llava-1.5-13b-hf",
|
58 |
+
"llava-hf/bakLlava-v1-hf",
|
59 |
+
# See all Llava models at https://huggingface.co/models?filter=llava
|
60 |
+
]
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
|
64 |
+
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
|
65 |
+
class LlavaCausalLMOutputWithPast(ModelOutput):
|
66 |
+
"""
|
67 |
+
Base class for Llava causal language model (or autoregressive) outputs.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
71 |
+
Language modeling loss (for next-token prediction).
|
72 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
73 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
74 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
75 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
76 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
77 |
+
|
78 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
79 |
+
`past_key_values` input) to speed up sequential decoding.
|
80 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
81 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
82 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
83 |
+
|
84 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
85 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
86 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
87 |
+
sequence_length)`.
|
88 |
+
|
89 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
90 |
+
heads.
|
91 |
+
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
92 |
+
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
|
93 |
+
sequence_length, hidden_size)`.
|
94 |
+
|
95 |
+
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
|
96 |
+
"""
|
97 |
+
|
98 |
+
loss: Optional[torch.FloatTensor] = None
|
99 |
+
logits: torch.FloatTensor = None
|
100 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
101 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
102 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
103 |
+
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
104 |
+
|
105 |
+
|
106 |
+
class LlavaMultiModalProjector(nn.Module):
|
107 |
+
def __init__(self, config: LlavaConfig):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
111 |
+
self.act = ACT2FN[config.projector_hidden_act]
|
112 |
+
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
113 |
+
|
114 |
+
def forward(self, image_features):
|
115 |
+
hidden_states = self.linear_1(image_features)
|
116 |
+
hidden_states = self.act(hidden_states)
|
117 |
+
hidden_states = self.linear_2(hidden_states)
|
118 |
+
return hidden_states
|
119 |
+
|
120 |
+
|
121 |
+
LLAVA_START_DOCSTRING = r"""
|
122 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
123 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
124 |
+
etc.)
|
125 |
+
|
126 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
127 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
128 |
+
and behavior.
|
129 |
+
|
130 |
+
Parameters:
|
131 |
+
config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
|
132 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
133 |
+
load the weights associated with the model, only the configuration. Check out the
|
134 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
135 |
+
"""
|
136 |
+
|
137 |
+
|
138 |
+
@add_start_docstrings(
|
139 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
140 |
+
LLAVA_START_DOCSTRING,
|
141 |
+
)
|
142 |
+
class LlavaPreTrainedModel(PreTrainedModel):
|
143 |
+
config_class = LlavaConfig
|
144 |
+
base_model_prefix = "model"
|
145 |
+
supports_gradient_checkpointing = True
|
146 |
+
_no_split_modules = ["LlavaVisionAttention"]
|
147 |
+
_skip_keys_device_placement = "past_key_values"
|
148 |
+
_supports_flash_attn_2 = True
|
149 |
+
|
150 |
+
def _init_weights(self, module):
|
151 |
+
# important: this ported version of Llava isn't meant for training from scratch - only
|
152 |
+
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
153 |
+
# https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
|
154 |
+
std = (
|
155 |
+
self.config.initializer_range
|
156 |
+
if hasattr(self.config, "initializer_range")
|
157 |
+
else self.config.text_config.initializer_range
|
158 |
+
)
|
159 |
+
|
160 |
+
if hasattr(module, "class_embedding"):
|
161 |
+
module.class_embedding.data.normal_(mean=0.0, std=std)
|
162 |
+
|
163 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
164 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
165 |
+
if module.bias is not None:
|
166 |
+
module.bias.data.zero_()
|
167 |
+
elif isinstance(module, nn.Embedding):
|
168 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
169 |
+
if module.padding_idx is not None:
|
170 |
+
module.weight.data[module.padding_idx].zero_()
|
171 |
+
|
172 |
+
@property
|
173 |
+
def _supports_sdpa(self):
|
174 |
+
"""
|
175 |
+
Retrieve language_model's attribute to check whether the model supports
|
176 |
+
SDPA or not.
|
177 |
+
"""
|
178 |
+
return self.language_model._supports_sdpa
|
179 |
+
|
180 |
+
|
181 |
+
LLAVA_INPUTS_DOCSTRING = r"""
|
182 |
+
Args:
|
183 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
184 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
185 |
+
it.
|
186 |
+
|
187 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
188 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
189 |
+
|
190 |
+
[What are input IDs?](../glossary#input-ids)
|
191 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
192 |
+
The tensors corresponding to the input images. Pixel values can be obtained using
|
193 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
|
194 |
+
[`CLIPImageProcessor`] for processing images).
|
195 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
196 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
197 |
+
|
198 |
+
- 1 for tokens that are **not masked**,
|
199 |
+
- 0 for tokens that are **masked**.
|
200 |
+
|
201 |
+
[What are attention masks?](../glossary#attention-mask)
|
202 |
+
|
203 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
204 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
205 |
+
|
206 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
207 |
+
`past_key_values`).
|
208 |
+
|
209 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
210 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
211 |
+
information on the default strategy.
|
212 |
+
|
213 |
+
- 1 indicates the head is **not masked**,
|
214 |
+
- 0 indicates the head is **masked**.
|
215 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
216 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
217 |
+
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
218 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
219 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
220 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
221 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
222 |
+
|
223 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
224 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
225 |
+
|
226 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
227 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
228 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
229 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
230 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
231 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
232 |
+
model's internal embedding lookup matrix.
|
233 |
+
use_cache (`bool`, *optional*):
|
234 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
235 |
+
`past_key_values`).
|
236 |
+
output_attentions (`bool`, *optional*):
|
237 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
238 |
+
tensors for more detail.
|
239 |
+
output_hidden_states (`bool`, *optional*):
|
240 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
241 |
+
more detail.
|
242 |
+
return_dict (`bool`, *optional*):
|
243 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
244 |
+
"""
|
245 |
+
|
246 |
+
|
247 |
+
@add_start_docstrings(
|
248 |
+
"""The LLAVA model which consists of a vision backbone and a language model.""",
|
249 |
+
LLAVA_START_DOCSTRING,
|
250 |
+
)
|
251 |
+
class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
252 |
+
def __init__(self, config: LlavaConfig, vision_tower=None, language_model=None):
|
253 |
+
super().__init__(config)
|
254 |
+
self.vision_tower = AutoModel.from_config(config.vision_config) if vision_tower is None else vision_tower
|
255 |
+
|
256 |
+
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
257 |
+
self.vocab_size = config.vocab_size
|
258 |
+
self.language_model = AutoModelForCausalLM.from_config(
|
259 |
+
config.text_config, attn_implementation=config._attn_implementation
|
260 |
+
) if language_model is None else language_model
|
261 |
+
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
262 |
+
self.post_init()
|
263 |
+
|
264 |
+
def get_input_embeddings(self):
|
265 |
+
return self.language_model.get_input_embeddings()
|
266 |
+
|
267 |
+
def set_input_embeddings(self, value):
|
268 |
+
self.language_model.set_input_embeddings(value)
|
269 |
+
|
270 |
+
def get_output_embeddings(self):
|
271 |
+
return self.language_model.get_output_embeddings()
|
272 |
+
|
273 |
+
def set_output_embeddings(self, new_embeddings):
|
274 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
275 |
+
|
276 |
+
def set_decoder(self, decoder):
|
277 |
+
self.language_model.set_decoder(decoder)
|
278 |
+
|
279 |
+
def get_decoder(self):
|
280 |
+
return self.language_model.get_decoder()
|
281 |
+
|
282 |
+
def tie_weights(self):
|
283 |
+
return self.language_model.tie_weights()
|
284 |
+
|
285 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
286 |
+
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
287 |
+
# update vocab size
|
288 |
+
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
289 |
+
self.config.vocab_size = model_embeds.num_embeddings
|
290 |
+
self.vocab_size = model_embeds.num_embeddings
|
291 |
+
return model_embeds
|
292 |
+
|
293 |
+
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
294 |
+
num_images, num_image_patches, embed_dim = image_features.shape
|
295 |
+
batch_size, sequence_length = input_ids.shape
|
296 |
+
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
297 |
+
# 1. Create a mask to know where special image tokens are
|
298 |
+
special_image_token_mask = input_ids == self.config.image_token_index
|
299 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
300 |
+
# Compute the maximum embed dimension
|
301 |
+
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
302 |
+
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
303 |
+
|
304 |
+
# 2. Compute the positions where text should be written
|
305 |
+
# Calculate new positions for text tokens in merged image-text sequence.
|
306 |
+
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
307 |
+
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
308 |
+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
309 |
+
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
310 |
+
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
311 |
+
if left_padding:
|
312 |
+
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
313 |
+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
314 |
+
|
315 |
+
# 3. Create the full embedding, already padded to the maximum position
|
316 |
+
final_embedding = torch.zeros(
|
317 |
+
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
318 |
+
)
|
319 |
+
final_attention_mask = torch.zeros(
|
320 |
+
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
321 |
+
)
|
322 |
+
if labels is not None:
|
323 |
+
final_labels = torch.full(
|
324 |
+
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
325 |
+
)
|
326 |
+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
327 |
+
# set the corresponding tensors into their correct target device.
|
328 |
+
target_device = inputs_embeds.device
|
329 |
+
batch_indices, non_image_indices, text_to_overwrite = (
|
330 |
+
batch_indices.to(target_device),
|
331 |
+
non_image_indices.to(target_device),
|
332 |
+
text_to_overwrite.to(target_device),
|
333 |
+
)
|
334 |
+
attention_mask = attention_mask.to(target_device)
|
335 |
+
|
336 |
+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
337 |
+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
338 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
339 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
340 |
+
if labels is not None:
|
341 |
+
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
342 |
+
|
343 |
+
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
344 |
+
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
345 |
+
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
346 |
+
|
347 |
+
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
348 |
+
raise ValueError(
|
349 |
+
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
350 |
+
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
351 |
+
)
|
352 |
+
|
353 |
+
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
354 |
+
final_attention_mask |= image_to_overwrite
|
355 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
356 |
+
|
357 |
+
if labels is None:
|
358 |
+
final_labels = None
|
359 |
+
|
360 |
+
return final_embedding, final_attention_mask, final_labels, position_ids
|
361 |
+
|
362 |
+
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
363 |
+
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
input_ids: torch.LongTensor = None,
|
367 |
+
pixel_values: torch.FloatTensor = None,
|
368 |
+
attention_mask: Optional[torch.Tensor] = None,
|
369 |
+
position_ids: Optional[torch.LongTensor] = None,
|
370 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
371 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
372 |
+
vision_feature_layer: Optional[int] = None,
|
373 |
+
vision_feature_select_strategy: Optional[str] = None,
|
374 |
+
labels: Optional[torch.LongTensor] = None,
|
375 |
+
use_cache: Optional[bool] = None,
|
376 |
+
output_attentions: Optional[bool] = None,
|
377 |
+
output_hidden_states: Optional[bool] = None,
|
378 |
+
return_dict: Optional[bool] = None,
|
379 |
+
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
380 |
+
r"""
|
381 |
+
Args:
|
382 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
383 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
384 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
385 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
386 |
+
|
387 |
+
Returns:
|
388 |
+
|
389 |
+
Example:
|
390 |
+
|
391 |
+
```python
|
392 |
+
>>> from PIL import Image
|
393 |
+
>>> import requests
|
394 |
+
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
395 |
+
|
396 |
+
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
397 |
+
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
398 |
+
|
399 |
+
>>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
|
400 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
401 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
402 |
+
|
403 |
+
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
404 |
+
|
405 |
+
>>> # Generate
|
406 |
+
>>> generate_ids = model.generate(**inputs, max_length=30)
|
407 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
408 |
+
"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
|
409 |
+
```"""
|
410 |
+
|
411 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
412 |
+
output_hidden_states = (
|
413 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
414 |
+
)
|
415 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
416 |
+
vision_feature_layer = (
|
417 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
418 |
+
)
|
419 |
+
vision_feature_select_strategy = (
|
420 |
+
vision_feature_select_strategy
|
421 |
+
if vision_feature_select_strategy is not None
|
422 |
+
else self.config.vision_feature_select_strategy
|
423 |
+
)
|
424 |
+
|
425 |
+
if inputs_embeds is None:
|
426 |
+
# 1. Extra the input embeddings
|
427 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
428 |
+
|
429 |
+
# 2. Merge text and images
|
430 |
+
if pixel_values is not None and input_ids.shape[1] != 1:
|
431 |
+
if isinstance(pixel_values, list):
|
432 |
+
pixel_values = torch.cat([x for x in pixel_values if x is not None], dim=0)
|
433 |
+
# for siglip, need to transform the pixel_values to the right data type
|
434 |
+
if pixel_values.dtype != self.vision_tower.dtype:
|
435 |
+
pixel_values = pixel_values.type(self.vision_tower.dtype)
|
436 |
+
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
437 |
+
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
|
438 |
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
439 |
+
|
440 |
+
if vision_feature_select_strategy == "default":
|
441 |
+
selected_image_feature = selected_image_feature[:, 1:]
|
442 |
+
elif vision_feature_select_strategy == "full":
|
443 |
+
selected_image_feature = selected_image_feature
|
444 |
+
else:
|
445 |
+
raise ValueError(
|
446 |
+
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
447 |
+
)
|
448 |
+
|
449 |
+
image_features = self.multi_modal_projector(selected_image_feature)
|
450 |
+
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
451 |
+
image_features, inputs_embeds, input_ids, attention_mask, labels
|
452 |
+
)
|
453 |
+
if labels is None:
|
454 |
+
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
455 |
+
else:
|
456 |
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
457 |
+
# generation with cache
|
458 |
+
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
459 |
+
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
460 |
+
# that are set to 0
|
461 |
+
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
462 |
+
|
463 |
+
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
464 |
+
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
465 |
+
|
466 |
+
# Get the target length
|
467 |
+
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
468 |
+
|
469 |
+
extended_attention_mask = torch.ones(
|
470 |
+
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
471 |
+
dtype=attention_mask.dtype,
|
472 |
+
device=attention_mask.device,
|
473 |
+
)
|
474 |
+
|
475 |
+
# Filter out only the tokens that can be un-attended, this can happen
|
476 |
+
# if one uses Llava + Fused modules where the cache on the
|
477 |
+
# first iteration is already big enough, or if one passes custom cache
|
478 |
+
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
479 |
+
new_batch_index = batch_index[valid_indices]
|
480 |
+
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
481 |
+
|
482 |
+
# Zero-out the places where we don't need to attend
|
483 |
+
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
484 |
+
|
485 |
+
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
486 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
487 |
+
|
488 |
+
outputs = self.language_model(
|
489 |
+
attention_mask=attention_mask,
|
490 |
+
position_ids=position_ids,
|
491 |
+
past_key_values=past_key_values,
|
492 |
+
inputs_embeds=inputs_embeds,
|
493 |
+
use_cache=use_cache,
|
494 |
+
output_attentions=output_attentions,
|
495 |
+
output_hidden_states=output_hidden_states,
|
496 |
+
return_dict=return_dict,
|
497 |
+
)
|
498 |
+
|
499 |
+
logits = outputs[0]
|
500 |
+
|
501 |
+
loss = None
|
502 |
+
if labels is not None:
|
503 |
+
# Shift so that tokens < n predict n
|
504 |
+
if attention_mask is not None:
|
505 |
+
shift_attention_mask = attention_mask[..., 1:]
|
506 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
507 |
+
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
508 |
+
else:
|
509 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
510 |
+
shift_labels = labels[..., 1:].contiguous()
|
511 |
+
# Flatten the tokens
|
512 |
+
loss_fct = nn.CrossEntropyLoss()
|
513 |
+
loss = loss_fct(
|
514 |
+
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
515 |
+
)
|
516 |
+
|
517 |
+
if not return_dict:
|
518 |
+
output = (logits,) + outputs[1:]
|
519 |
+
return (loss,) + output if loss is not None else output
|
520 |
+
|
521 |
+
return LlavaCausalLMOutputWithPast(
|
522 |
+
loss=loss,
|
523 |
+
logits=logits,
|
524 |
+
past_key_values=outputs.past_key_values,
|
525 |
+
hidden_states=outputs.hidden_states,
|
526 |
+
attentions=outputs.attentions,
|
527 |
+
)
|
528 |
+
|
529 |
+
def prepare_inputs_for_generation(
|
530 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
|
531 |
+
):
|
532 |
+
if past_key_values is not None:
|
533 |
+
if isinstance(past_key_values, Cache):
|
534 |
+
cache_length = past_key_values.get_seq_length()
|
535 |
+
past_length = past_key_values.seen_tokens
|
536 |
+
else:
|
537 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
538 |
+
|
539 |
+
# Keep only the unprocessed tokens:
|
540 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
541 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
542 |
+
# input)
|
543 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
544 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
545 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
546 |
+
# input_ids based on the past_length.
|
547 |
+
elif past_length < input_ids.shape[1]:
|
548 |
+
input_ids = input_ids[:, past_length:]
|
549 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
550 |
+
elif self.config.image_token_index in input_ids:
|
551 |
+
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
552 |
+
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
553 |
+
# older attention values, as their corresponding values are not part of the input.
|
554 |
+
if cache_length < past_length and attention_mask is not None:
|
555 |
+
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
556 |
+
|
557 |
+
position_ids = kwargs.get("position_ids", None)
|
558 |
+
if attention_mask is not None and position_ids is None:
|
559 |
+
# create position_ids on the fly for batch generation
|
560 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
561 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
562 |
+
if past_key_values:
|
563 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
564 |
+
|
565 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
566 |
+
if inputs_embeds is not None and past_key_values is None:
|
567 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
568 |
+
else:
|
569 |
+
model_inputs = {"input_ids": input_ids}
|
570 |
+
|
571 |
+
model_inputs.update(
|
572 |
+
{
|
573 |
+
"position_ids": position_ids,
|
574 |
+
"past_key_values": past_key_values,
|
575 |
+
"use_cache": kwargs.get("use_cache"),
|
576 |
+
"attention_mask": attention_mask,
|
577 |
+
"pixel_values": pixel_values,
|
578 |
+
}
|
579 |
+
)
|
580 |
+
return model_inputs
|
581 |
+
|
582 |
+
def _reorder_cache(self, *args, **kwargs):
|
583 |
+
return self.language_model._reorder_cache(*args, **kwargs)
|
584 |
+
|
585 |
+
|
586 |
+
|
587 |
+
|
588 |
+
from transformers.models.clip.modeling_clip import CLIPEncoderLayer, CLIPEncoder
|
589 |
+
@add_start_docstrings(
|
590 |
+
"""The MLLAVA model which consists of a vision backbone and a language model.""",
|
591 |
+
LLAVA_START_DOCSTRING,
|
592 |
+
)
|
593 |
+
class MLlavaForConditionalGeneration(LlavaForConditionalGeneration):
|
594 |
+
def __init__(self, config: LlavaConfig):
|
595 |
+
super().__init__(config)
|
596 |
+
config.vision_config.type_vocab_size = 144
|
597 |
+
self.image_type_embeddings = nn.Embedding(config.vision_config.type_vocab_size, config.vision_config.hidden_size)
|
598 |
+
# self.vision_xatten_layers = nn.ModuleList([CLIPEncoderLayer(config.vision_config) for _ in range(config.vision_config.num_hidden_layers)])
|
599 |
+
self.vision_xatten_layers = CLIPEncoder(config.vision_config)
|
600 |
+
|
601 |
+
|
602 |
+
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
603 |
+
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
604 |
+
def forward(
|
605 |
+
self,
|
606 |
+
input_ids: torch.LongTensor = None,
|
607 |
+
pixel_values: torch.FloatTensor = None,
|
608 |
+
attention_mask: Optional[torch.Tensor] = None,
|
609 |
+
position_ids: Optional[torch.LongTensor] = None,
|
610 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
611 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
612 |
+
vision_feature_layer: Optional[int] = None,
|
613 |
+
vision_feature_select_strategy: Optional[str] = None,
|
614 |
+
labels: Optional[torch.LongTensor] = None,
|
615 |
+
use_cache: Optional[bool] = None,
|
616 |
+
output_attentions: Optional[bool] = None,
|
617 |
+
output_hidden_states: Optional[bool] = None,
|
618 |
+
return_dict: Optional[bool] = None,
|
619 |
+
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
620 |
+
r"""
|
621 |
+
Args:
|
622 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
623 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
624 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
625 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
626 |
+
|
627 |
+
Returns:
|
628 |
+
|
629 |
+
Example:
|
630 |
+
|
631 |
+
```python
|
632 |
+
>>> from PIL import Image
|
633 |
+
>>> import requests
|
634 |
+
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
635 |
+
|
636 |
+
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
637 |
+
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
638 |
+
|
639 |
+
>>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
|
640 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
641 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
642 |
+
|
643 |
+
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
644 |
+
|
645 |
+
>>> # Generate
|
646 |
+
>>> generate_ids = model.generate(**inputs, max_length=30)
|
647 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
648 |
+
"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
|
649 |
+
```"""
|
650 |
+
|
651 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
652 |
+
output_hidden_states = (
|
653 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
654 |
+
)
|
655 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
656 |
+
vision_feature_layer = (
|
657 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
658 |
+
)
|
659 |
+
vision_feature_select_strategy = (
|
660 |
+
vision_feature_select_strategy
|
661 |
+
if vision_feature_select_strategy is not None
|
662 |
+
else self.config.vision_feature_select_strategy
|
663 |
+
)
|
664 |
+
|
665 |
+
if inputs_embeds is None:
|
666 |
+
# 1. Extra the input embeddings
|
667 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
668 |
+
|
669 |
+
# 2. Merge text and images
|
670 |
+
if pixel_values is not None and input_ids.shape[1] != 1:
|
671 |
+
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
672 |
+
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
|
673 |
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
674 |
+
|
675 |
+
if vision_feature_select_strategy == "default":
|
676 |
+
selected_image_feature = selected_image_feature[:, 1:]
|
677 |
+
elif vision_feature_select_strategy == "full":
|
678 |
+
selected_image_feature = selected_image_feature
|
679 |
+
else:
|
680 |
+
raise ValueError(
|
681 |
+
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
682 |
+
)
|
683 |
+
|
684 |
+
# added by Dongfu
|
685 |
+
num_images, num_image_patches, embed_dim = selected_image_feature.shape
|
686 |
+
image_type_embeddings = self.image_type_embeddings(torch.arange(num_images, device=selected_image_feature.device))
|
687 |
+
selected_image_feature += image_type_embeddings.unsqueeze(1)
|
688 |
+
xatten_output = self.vision_xatten_layers(selected_image_feature, attention_mask=None, causal_attention_mask=None)
|
689 |
+
selected_image_feature = xatten_output[0]
|
690 |
+
# end of added by Dongfu
|
691 |
+
|
692 |
+
image_features = self.multi_modal_projector(selected_image_feature)
|
693 |
+
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
694 |
+
image_features, inputs_embeds, input_ids, attention_mask, labels
|
695 |
+
)
|
696 |
+
if labels is None:
|
697 |
+
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
698 |
+
else:
|
699 |
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
700 |
+
# generation with cache
|
701 |
+
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
702 |
+
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
703 |
+
# that are set to 0
|
704 |
+
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
705 |
+
|
706 |
+
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
707 |
+
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
708 |
+
|
709 |
+
# Get the target length
|
710 |
+
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
711 |
+
|
712 |
+
extended_attention_mask = torch.ones(
|
713 |
+
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
714 |
+
dtype=attention_mask.dtype,
|
715 |
+
device=attention_mask.device,
|
716 |
+
)
|
717 |
+
|
718 |
+
# Filter out only the tokens that can be un-attended, this can happen
|
719 |
+
# if one uses Llava + Fused modules where the cache on the
|
720 |
+
# first iteration is already big enough, or if one passes custom cache
|
721 |
+
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
722 |
+
new_batch_index = batch_index[valid_indices]
|
723 |
+
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
724 |
+
|
725 |
+
# Zero-out the places where we don't need to attend
|
726 |
+
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
727 |
+
|
728 |
+
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
729 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
730 |
+
|
731 |
+
outputs = self.language_model(
|
732 |
+
attention_mask=attention_mask,
|
733 |
+
position_ids=position_ids,
|
734 |
+
past_key_values=past_key_values,
|
735 |
+
inputs_embeds=inputs_embeds,
|
736 |
+
use_cache=use_cache,
|
737 |
+
output_attentions=output_attentions,
|
738 |
+
output_hidden_states=output_hidden_states,
|
739 |
+
return_dict=return_dict,
|
740 |
+
)
|
741 |
+
|
742 |
+
logits = outputs[0]
|
743 |
+
|
744 |
+
loss = None
|
745 |
+
if labels is not None:
|
746 |
+
# Shift so that tokens < n predict n
|
747 |
+
if attention_mask is not None:
|
748 |
+
shift_attention_mask = attention_mask[..., 1:]
|
749 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
750 |
+
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
751 |
+
else:
|
752 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
753 |
+
shift_labels = labels[..., 1:].contiguous()
|
754 |
+
# Flatten the tokens
|
755 |
+
loss_fct = nn.CrossEntropyLoss()
|
756 |
+
loss = loss_fct(
|
757 |
+
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
758 |
+
)
|
759 |
+
|
760 |
+
if not return_dict:
|
761 |
+
output = (logits,) + outputs[1:]
|
762 |
+
return (loss,) + output if loss is not None else output
|
763 |
+
|
764 |
+
return LlavaCausalLMOutputWithPast(
|
765 |
+
loss=loss,
|
766 |
+
logits=logits,
|
767 |
+
past_key_values=outputs.past_key_values,
|
768 |
+
hidden_states=outputs.hidden_states,
|
769 |
+
attentions=outputs.attentions,
|
770 |
+
)
|
models/mllava/processing_llava.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Processor class for Llava.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
import json
|
21 |
+
from typing import List, Optional, Union, Dict
|
22 |
+
|
23 |
+
# from ...feature_extraction_utils import BatchFeature
|
24 |
+
# from ...image_utils import ImageInput
|
25 |
+
# from ...processing_utils import ProcessorMixin
|
26 |
+
# from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
27 |
+
# from ...utils import TensorType
|
28 |
+
|
29 |
+
from transformers.feature_extraction_sequence_utils import BatchFeature
|
30 |
+
from transformers.image_utils import ImageInput
|
31 |
+
from transformers.processing_utils import ProcessorMixin
|
32 |
+
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
33 |
+
from transformers.utils import TensorType
|
34 |
+
from transformers.processing_utils import transformers_module
|
35 |
+
from transformers.utils.hub import is_remote_url, download_url, cached_file, is_offline_mode
|
36 |
+
from transformers.utils import IMAGE_PROCESSOR_NAME
|
37 |
+
|
38 |
+
from PIL import Image
|
39 |
+
import logging
|
40 |
+
import torch
|
41 |
+
import numpy as np
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
|
44 |
+
class MLlavaProcessor(ProcessorMixin):
|
45 |
+
r"""
|
46 |
+
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
|
47 |
+
|
48 |
+
[`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
|
49 |
+
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
image_processor ([`CLIPImageProcessor`], *optional*):
|
53 |
+
The image processor is a required input.
|
54 |
+
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
55 |
+
The tokenizer is a required input.
|
56 |
+
"""
|
57 |
+
|
58 |
+
attributes = ["image_processor", "tokenizer"]
|
59 |
+
image_processor_class = ("CLIPImageProcessor", "SiglipImageProcessor")
|
60 |
+
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast", "PreTrainedTokenizerFast")
|
61 |
+
|
62 |
+
def __init__(self, image_processor=None, tokenizer=None):
|
63 |
+
super().__init__(image_processor, tokenizer)
|
64 |
+
|
65 |
+
def preprocess_interleaved_images_and_text(
|
66 |
+
self,
|
67 |
+
text,
|
68 |
+
images=None,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
text (`str`, `List[str]`):
|
73 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
74 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
75 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
76 |
+
text can contain <image> tokens as the placeholder for the image(s) to be inserted.
|
77 |
+
images (`PIL.Image.Image`, `List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`):
|
78 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
79 |
+
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
80 |
+
number of channels, H and W are image height and width.
|
81 |
+
the number of the images should match the number of <image> tokens in the text.
|
82 |
+
|
83 |
+
"""
|
84 |
+
assert text is not None, "text cannot be None."
|
85 |
+
|
86 |
+
if images is not None:
|
87 |
+
if isinstance(images, Image.Image):
|
88 |
+
images = [images]
|
89 |
+
if isinstance(images, list) and isinstance(images[0], Image.Image):
|
90 |
+
if isinstance(text, str):
|
91 |
+
images = [images]
|
92 |
+
elif isinstance(text, list):
|
93 |
+
if len(text) != len(images):
|
94 |
+
raise ValueError("Invalid input text. Number of texts does not match number of images.")
|
95 |
+
images = [[image] for image in images]
|
96 |
+
if isinstance(text, str):
|
97 |
+
num_images = len(images[0])
|
98 |
+
num_image_tokens = text.count("<image>")
|
99 |
+
if num_image_tokens < num_images:
|
100 |
+
# prepend empty image tokens to text
|
101 |
+
if "USER:" in text:
|
102 |
+
text = text.replace("USER:", "USER:" + "<image>" * (num_images - num_image_tokens), 1)
|
103 |
+
elif "Human:" in text:
|
104 |
+
text = text.replace("Human:", "Human:" + "<image>" * (num_images - num_image_tokens), 1)
|
105 |
+
elif "HUMAN:" in text:
|
106 |
+
text = text.replace("HUMAN:", "HUMAN:" + "<image>" * (num_images - num_image_tokens), 1)
|
107 |
+
else:
|
108 |
+
text = "<image>" * (num_images - num_image_tokens) + text
|
109 |
+
# logger.warning("Image Tokens <image> are not provided in the text. Automatically prepending them before the text. This might cause model to behave unexpectedly.")
|
110 |
+
elif num_image_tokens > num_images:
|
111 |
+
text = text.split("<image>")
|
112 |
+
for i, t in enumerate(text):
|
113 |
+
if i < num_images:
|
114 |
+
text[i] = t + "<image>"
|
115 |
+
text = "".join(text)
|
116 |
+
logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
|
117 |
+
# raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
|
118 |
+
texts = [text]
|
119 |
+
elif isinstance(text, list):
|
120 |
+
if not isinstance(text[0], str):
|
121 |
+
raise ValueError("Invalid input text. Each element of text must be a string.")
|
122 |
+
for i, t in enumerate(text):
|
123 |
+
num_image_tokens = t.count("<image>")
|
124 |
+
num_images = len(images[i])
|
125 |
+
if num_image_tokens < num_images:
|
126 |
+
# prepend empty image tokens to text
|
127 |
+
if "USER:" in t:
|
128 |
+
t = t.replace("USER:", "USER:" + "<image>" * (num_images - num_image_tokens), 1)
|
129 |
+
elif "Human:" in t:
|
130 |
+
t = t.replace("Human:", "Human:" + "<image>" * (num_images - num_image_tokens), 1)
|
131 |
+
elif "HUMAN:" in t:
|
132 |
+
t = t.replace("HUMAN:", "HUMAN:" + "<image>" * (num_images - num_image_tokens), 1)
|
133 |
+
else:
|
134 |
+
t = "<image>" * (num_images - num_image_tokens) + t
|
135 |
+
# logger.warning("Image Tokens <image> are not provided in the text. Automatically prepending them before the text. This might cause model to behave unexpectedly.")
|
136 |
+
elif num_image_tokens > num_images:
|
137 |
+
t = t.split("<image>")
|
138 |
+
for j, s in enumerate(t):
|
139 |
+
if j < num_images:
|
140 |
+
t[j] = s + "<image>"
|
141 |
+
t = "".join(t)
|
142 |
+
logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
|
143 |
+
# raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
|
144 |
+
text[i] = t
|
145 |
+
texts = text
|
146 |
+
else:
|
147 |
+
raise ValueError("Invalid input text. text must be a string or a list of strings.")
|
148 |
+
assert all([t.count("<image>") == len(images_per_text) for t, images_per_text in zip(texts, images)]), "Number of <image> tokens in text does not match number of images."
|
149 |
+
# add image denotation in text before each <image> as "(image {i}: <image>)"
|
150 |
+
for i, t in enumerate(texts):
|
151 |
+
for j in range(len(images[i])):
|
152 |
+
t = t.replace("<image>", f"(image {j+1}: <Image><IMAGE></Image>)", 1)
|
153 |
+
t = t.replace("<IMAGE>", "<image>")
|
154 |
+
texts[i] = t
|
155 |
+
|
156 |
+
# flatten images
|
157 |
+
images = [image for images_per_text in images for image in images_per_text]
|
158 |
+
else:
|
159 |
+
if isinstance(text, str):
|
160 |
+
texts = [text]
|
161 |
+
elif isinstance(text, list):
|
162 |
+
if not isinstance(text[0], str):
|
163 |
+
raise ValueError("Invalid input text. Each element of text must be a string.")
|
164 |
+
texts = text
|
165 |
+
else:
|
166 |
+
raise ValueError("Invalid input text. text must be a string or a list of strings.")
|
167 |
+
|
168 |
+
return texts, images
|
169 |
+
|
170 |
+
def __call__(
|
171 |
+
self,
|
172 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
173 |
+
images: ImageInput = None,
|
174 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
175 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
176 |
+
max_length=None,
|
177 |
+
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
178 |
+
add_image_ids: bool = True,
|
179 |
+
) -> BatchFeature:
|
180 |
+
"""
|
181 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
182 |
+
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
183 |
+
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
184 |
+
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
185 |
+
of the above two methods for more information.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
189 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
190 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
191 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
192 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
193 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
194 |
+
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
195 |
+
number of channels, H and W are image height and width.
|
196 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
197 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
198 |
+
index) among:
|
199 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
200 |
+
sequence if provided).
|
201 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
202 |
+
acceptable input length for the model if that argument is not provided.
|
203 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
204 |
+
lengths).
|
205 |
+
max_length (`int`, *optional*):
|
206 |
+
Maximum length of the returned list and optionally padding length (see above).
|
207 |
+
truncation (`bool`, *optional*):
|
208 |
+
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
209 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
210 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
211 |
+
|
212 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
213 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
214 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
215 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
219 |
+
|
220 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
221 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
222 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
223 |
+
`None`).
|
224 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
225 |
+
"""
|
226 |
+
if add_image_ids:
|
227 |
+
text, images = self.preprocess_interleaved_images_and_text(text, images)
|
228 |
+
if images is not None:
|
229 |
+
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] # [batch_size, num_channels, height, width], e.g. [1, 3, 336, 336]
|
230 |
+
else:
|
231 |
+
pixel_values = None
|
232 |
+
text_inputs = self.tokenizer(
|
233 |
+
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
234 |
+
)
|
235 |
+
# text_inputs:
|
236 |
+
# 1. input_ids: [batch_size, sequence_length], e.g. [1, 6]
|
237 |
+
# 2. attention_mask: [batch_size, sequence_length], e.g. [1, 6]
|
238 |
+
|
239 |
+
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
|
240 |
+
|
241 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
242 |
+
def batch_decode(self, *args, **kwargs):
|
243 |
+
"""
|
244 |
+
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
245 |
+
refer to the docstring of this method for more information.
|
246 |
+
"""
|
247 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
248 |
+
|
249 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
250 |
+
def decode(self, *args, **kwargs):
|
251 |
+
"""
|
252 |
+
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
253 |
+
the docstring of this method for more information.
|
254 |
+
"""
|
255 |
+
return self.tokenizer.decode(*args, **kwargs)
|
256 |
+
|
257 |
+
@property
|
258 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
259 |
+
def model_input_names(self):
|
260 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
261 |
+
image_processor_input_names = self.image_processor.model_input_names
|
262 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
263 |
+
|
264 |
+
def _right_pad_inputs_with_attention_mask(self, model_inputs: List[Dict]):
|
265 |
+
results = {}
|
266 |
+
assert len(model_inputs) == 1, "This method only supports a single input, but get {} inputs".format(len(model_inputs))
|
267 |
+
for k in model_inputs[0].keys():
|
268 |
+
if k == "pixel_values":
|
269 |
+
results[k] = [inputs[k] if inputs[k] is not None else None for inputs in model_inputs]
|
270 |
+
else:
|
271 |
+
results[k] = torch.cat([inputs[k] for inputs in model_inputs], dim=0)
|
272 |
+
return results
|
273 |
+
|
274 |
+
@classmethod
|
275 |
+
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
276 |
+
args = []
|
277 |
+
|
278 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
279 |
+
force_download = kwargs.pop("force_download", False)
|
280 |
+
resume_download = kwargs.pop("resume_download", False)
|
281 |
+
proxies = kwargs.pop("proxies", None)
|
282 |
+
token = kwargs.pop("token", None)
|
283 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
284 |
+
revision = kwargs.pop("revision", None)
|
285 |
+
subfolder = kwargs.pop("subfolder", "")
|
286 |
+
|
287 |
+
from_pipeline = kwargs.pop("_from_pipeline", None)
|
288 |
+
from_auto_class = kwargs.pop("_from_auto", False)
|
289 |
+
|
290 |
+
user_agent = {"file_type": "processor", "from_auto_class": from_auto_class}
|
291 |
+
if from_pipeline is not None:
|
292 |
+
user_agent["using_pipeline"] = from_pipeline
|
293 |
+
|
294 |
+
if is_offline_mode() and not local_files_only:
|
295 |
+
logger.info("Offline mode: forcing local_files_only=True")
|
296 |
+
local_files_only = True
|
297 |
+
|
298 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
299 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
300 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
301 |
+
processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
|
302 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
303 |
+
resolved_processor_file = pretrained_model_name_or_path
|
304 |
+
is_local = True
|
305 |
+
elif is_remote_url(pretrained_model_name_or_path):
|
306 |
+
processor_file = pretrained_model_name_or_path
|
307 |
+
resolved_processor_file = download_url(pretrained_model_name_or_path)
|
308 |
+
else:
|
309 |
+
processor_file = IMAGE_PROCESSOR_NAME
|
310 |
+
try:
|
311 |
+
# Load from local folder or from cache or download from model Hub and cache
|
312 |
+
resolved_processor_file = cached_file(
|
313 |
+
pretrained_model_name_or_path,
|
314 |
+
processor_file,
|
315 |
+
cache_dir=cache_dir,
|
316 |
+
force_download=force_download,
|
317 |
+
proxies=proxies,
|
318 |
+
resume_download=resume_download,
|
319 |
+
local_files_only=local_files_only,
|
320 |
+
token=token,
|
321 |
+
user_agent=user_agent,
|
322 |
+
revision=revision,
|
323 |
+
subfolder=subfolder,
|
324 |
+
_raise_exceptions_for_missing_entries=True,
|
325 |
+
)
|
326 |
+
except EnvironmentError:
|
327 |
+
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
328 |
+
# the original exception.
|
329 |
+
raise
|
330 |
+
except Exception:
|
331 |
+
# For any other exception, we throw a generic error.
|
332 |
+
raise EnvironmentError(
|
333 |
+
f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
334 |
+
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
335 |
+
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
336 |
+
f" directory containing a {IMAGE_PROCESSOR_NAME} file"
|
337 |
+
)
|
338 |
+
|
339 |
+
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
|
340 |
+
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
|
341 |
+
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
|
342 |
+
# However, for models added in the future, we won't get the expected error if this file is missing.
|
343 |
+
if resolved_processor_file is None:
|
344 |
+
image_processor_dict = {}
|
345 |
+
|
346 |
+
try:
|
347 |
+
# Load processor dict
|
348 |
+
with open(resolved_processor_file, "r", encoding="utf-8") as reader:
|
349 |
+
text = reader.read()
|
350 |
+
image_processor_dict = json.loads(text)
|
351 |
+
|
352 |
+
except json.JSONDecodeError:
|
353 |
+
raise EnvironmentError(
|
354 |
+
f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
|
355 |
+
)
|
356 |
+
|
357 |
+
for attribute_name in cls.attributes:
|
358 |
+
class_name = getattr(cls, f"{attribute_name}_class")
|
359 |
+
if isinstance(class_name, tuple):
|
360 |
+
if attribute_name == "tokenizer":
|
361 |
+
classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
|
362 |
+
use_fast = kwargs.get("use_fast", True)
|
363 |
+
if use_fast and classes[1] is not None:
|
364 |
+
attribute_class = classes[1]
|
365 |
+
else:
|
366 |
+
attribute_class = classes[0]
|
367 |
+
elif attribute_name == "image_processor":
|
368 |
+
image_processor_type = image_processor_dict.get("image_processor_type", None)
|
369 |
+
if image_processor_type is not None:
|
370 |
+
assert image_processor_type in class_name, f"Invalid image processor type: {image_processor_type}"
|
371 |
+
attribute_class = getattr(transformers_module, image_processor_type)
|
372 |
+
else:
|
373 |
+
attribute_class = getattr(transformers_module, class_name[0])
|
374 |
+
else:
|
375 |
+
raise ValueError(f"Invalid attribute name: {attribute_name}")
|
376 |
+
else:
|
377 |
+
attribute_class = getattr(transformers_module, class_name)
|
378 |
+
|
379 |
+
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
380 |
+
return args
|
381 |
+
|
models/mllava/utils.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import torch
|
3 |
+
from .modeling_llava import LlavaForConditionalGeneration
|
4 |
+
from .processing_llava import MLlavaProcessor
|
5 |
+
# from ..conversation import conv_mllava_v1_mmtag as default_conv
|
6 |
+
from ..conversation import conv_mllava_v1 as default_conv, conv_templates
|
7 |
+
|
8 |
+
from typing import List, Tuple, Union, Tuple
|
9 |
+
|
10 |
+
def chat_mllava(
|
11 |
+
text:str,
|
12 |
+
images: List[Union[PIL.Image.Image, str]],
|
13 |
+
model:LlavaForConditionalGeneration,
|
14 |
+
processor:MLlavaProcessor,
|
15 |
+
max_input_length:int=None,
|
16 |
+
history:List[dict]=None,
|
17 |
+
**kwargs) -> Tuple[str, List[dict]]:
|
18 |
+
"""
|
19 |
+
Chat with the Mllava model
|
20 |
+
Args:
|
21 |
+
text: str, the text to be sent to the model, where <image> will be the placeholder for the image
|
22 |
+
images: List[PIL.Image.Image], the images to be sent to the model, or None
|
23 |
+
model: LlavaForConditionalGeneration, the model to be used
|
24 |
+
processor: MLlavaProcessor, the processor to be used
|
25 |
+
max_input_length: int, the maximum input length
|
26 |
+
history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None, the conversation will start from scratch
|
27 |
+
kwargs: dict, the generation kwargs
|
28 |
+
Returns:
|
29 |
+
Tuple[str, List[dict]], the generated text and the history of the conversation
|
30 |
+
|
31 |
+
|
32 |
+
"""
|
33 |
+
if "llama-3" in model.language_model.name_or_path.lower():
|
34 |
+
conv = conv_templates['llama_3']
|
35 |
+
terminators = [
|
36 |
+
processor.tokenizer.eos_token_id,
|
37 |
+
processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
38 |
+
]
|
39 |
+
else:
|
40 |
+
conv = default_conv
|
41 |
+
terminators = None
|
42 |
+
kwargs["eos_token_id"] = terminators
|
43 |
+
conv = conv.copy()
|
44 |
+
conv.messages = []
|
45 |
+
if history is not None:
|
46 |
+
for message in history:
|
47 |
+
assert message["role"] in conv.roles
|
48 |
+
conv.append_message(message["role"], message["text"])
|
49 |
+
if text:
|
50 |
+
assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
|
51 |
+
conv.append_message(conv.roles[0], text)
|
52 |
+
conv.append_message(conv.roles[1], "")
|
53 |
+
history.append({"role": conv.roles[0], "text": text})
|
54 |
+
history.append({"role": conv.roles[1], "text": ""})
|
55 |
+
else:
|
56 |
+
if conv.messages[-1][0] == conv.roles[1]:
|
57 |
+
assert conv.messages[-1][1] == "", "No user message should be provided"
|
58 |
+
else:
|
59 |
+
assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
|
60 |
+
conv.append_message(conv.roles[0], "")
|
61 |
+
history.append({"role": conv.roles[0], "text": ""})
|
62 |
+
else:
|
63 |
+
history = []
|
64 |
+
history.append({"role": conv.roles[0], "text": text})
|
65 |
+
history.append({"role": conv.roles[1], "text": ""})
|
66 |
+
conv.append_message(conv.roles[0], text)
|
67 |
+
conv.append_message(conv.roles[1], "")
|
68 |
+
assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
|
69 |
+
assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
|
70 |
+
|
71 |
+
prompt = conv.get_prompt()
|
72 |
+
if images:
|
73 |
+
for i in range(len(images)):
|
74 |
+
if isinstance(images[i], str):
|
75 |
+
images[i] = PIL.Image.open(images[i]).convert("RGB")
|
76 |
+
|
77 |
+
inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
|
78 |
+
for k, v in inputs.items():
|
79 |
+
if v is not None:
|
80 |
+
if isinstance(v, torch.Tensor):
|
81 |
+
inputs[k] = v.to(model.device)
|
82 |
+
elif isinstance(v, list):
|
83 |
+
inputs[k] = [x.to(model.device) for x in v]
|
84 |
+
else:
|
85 |
+
raise ValueError(f"Invalid input type: {type(v)}")
|
86 |
+
|
87 |
+
|
88 |
+
output_ids = model.generate(**inputs, **kwargs)
|
89 |
+
output_ids = output_ids[0]
|
90 |
+
|
91 |
+
# remove the input tokens
|
92 |
+
generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
|
93 |
+
generated_text = processor.decode(generated_ids, skip_special_tokens=True)
|
94 |
+
|
95 |
+
history[-1]["text"] = generated_text
|
96 |
+
|
97 |
+
return generated_text, history
|
98 |
+
|
99 |
+
|
100 |
+
def chat_mllava_stream(
|
101 |
+
text:str,
|
102 |
+
images: List[Union[PIL.Image.Image, str]],
|
103 |
+
model:LlavaForConditionalGeneration,
|
104 |
+
processor:MLlavaProcessor,
|
105 |
+
max_input_length:int=None,
|
106 |
+
history:List[dict]=None,
|
107 |
+
**kwargs) -> Tuple[str, List[dict]]:
|
108 |
+
"""
|
109 |
+
Chat with the Mllava model
|
110 |
+
Args:
|
111 |
+
text: str, the text to be sent to the model, where <image> will be the placeholder for the image
|
112 |
+
images: List[PIL.Image.Image], the images to be sent to the model, or None
|
113 |
+
model: LlavaForConditionalGeneration, the model to be used
|
114 |
+
processor: MLlavaProcessor, the processor to be used
|
115 |
+
max_input_length: int, the maximum input length
|
116 |
+
history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None, the conversation will start from scratch
|
117 |
+
kwargs: dict, the generation kwargs
|
118 |
+
Returns:
|
119 |
+
Tuple[str, List[dict]], the generated text and the history of the conversation
|
120 |
+
|
121 |
+
|
122 |
+
"""
|
123 |
+
if "llama-3" in model.language_model.name_or_path.lower():
|
124 |
+
conv = conv_templates['llama_3']
|
125 |
+
terminators = [
|
126 |
+
processor.tokenizer.eos_token_id,
|
127 |
+
processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
128 |
+
]
|
129 |
+
else:
|
130 |
+
conv = default_conv
|
131 |
+
terminators = None
|
132 |
+
kwargs["eos_token_id"] = terminators
|
133 |
+
conv = conv.copy()
|
134 |
+
conv.messages = []
|
135 |
+
if history is not None:
|
136 |
+
for message in history:
|
137 |
+
assert message["role"] in conv.roles
|
138 |
+
conv.append_message(message["role"], message["text"])
|
139 |
+
if text:
|
140 |
+
assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
|
141 |
+
conv.append_message(conv.roles[0], text)
|
142 |
+
conv.append_message(conv.roles[1], "")
|
143 |
+
history.append({"role": conv.roles[0], "text": text})
|
144 |
+
history.append({"role": conv.roles[1], "text": ""})
|
145 |
+
else:
|
146 |
+
if conv.messages[-1][0] == conv.roles[1]:
|
147 |
+
assert conv.messages[-1][1] == "", "No user message should be provided"
|
148 |
+
else:
|
149 |
+
assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
|
150 |
+
conv.append_message(conv.roles[0], "")
|
151 |
+
history.append({"role": conv.roles[0], "text": ""})
|
152 |
+
else:
|
153 |
+
history = []
|
154 |
+
history.append({"role": conv.roles[0], "text": text})
|
155 |
+
history.append({"role": conv.roles[1], "text": ""})
|
156 |
+
conv.append_message(conv.roles[0], text)
|
157 |
+
conv.append_message(conv.roles[1], "")
|
158 |
+
assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
|
159 |
+
assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
|
160 |
+
|
161 |
+
prompt = conv.get_prompt()
|
162 |
+
if images:
|
163 |
+
for i in range(len(images)):
|
164 |
+
if isinstance(images[i], str):
|
165 |
+
images[i] = PIL.Image.open(images[i])
|
166 |
+
images[i] = images[i].convert("RGB")
|
167 |
+
|
168 |
+
inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
|
169 |
+
print(processor.tokenizer.decode(inputs["input_ids"][0]))
|
170 |
+
for k, v in inputs.items():
|
171 |
+
if v is not None:
|
172 |
+
if isinstance(v, torch.Tensor):
|
173 |
+
inputs[k] = v.to(model.device)
|
174 |
+
elif isinstance(v, list):
|
175 |
+
inputs[k] = [x.to(model.device) for x in v]
|
176 |
+
else:
|
177 |
+
raise ValueError(f"Invalid input type: {type(v)}")
|
178 |
+
|
179 |
+
from transformers import TextIteratorStreamer
|
180 |
+
from threading import Thread
|
181 |
+
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
|
182 |
+
kwargs["streamer"] = streamer
|
183 |
+
inputs.update(kwargs)
|
184 |
+
thread = Thread(target=model.generate, kwargs=inputs)
|
185 |
+
thread.start()
|
186 |
+
for _output in streamer:
|
187 |
+
history[-1]["text"] += _output
|
188 |
+
yield history[-1]["text"], history
|
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
-
gradio
|
2 |
-
Pillow
|
3 |
torch
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
torch
|
2 |
+
transformers>=4.41.0
|
3 |
+
Pillow
|
4 |
+
gradio
|
5 |
+
spaces
|
6 |
+
multiprocess
|