Spaces:
Runtime error
Runtime error
HaoZhang534
commited on
Commit
β’
a65550c
1
Parent(s):
d260573
first
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- app.py +288 -0
- llava/__init__.py +1 -0
- llava/__pycache__/__init__.cpython-310.pyc +0 -0
- llava/__pycache__/constants.cpython-310.pyc +0 -0
- llava/__pycache__/conversation.cpython-310.pyc +0 -0
- llava/__pycache__/mm_utils.cpython-310.pyc +0 -0
- llava/__pycache__/utils.cpython-310.pyc +0 -0
- llava/constants.py +12 -0
- llava/conversation.py +554 -0
- llava/eval/evaluate_interleave.py +339 -0
- llava/eval/model_vqa.py +240 -0
- llava/mm_utils.py +381 -0
- llava/model/__init__.py +20 -0
- llava/model/__pycache__/__init__.cpython-310.pyc +0 -0
- llava/model/__pycache__/builder.cpython-310.pyc +0 -0
- llava/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
- llava/model/apply_delta.py +47 -0
- llava/model/builder.py +250 -0
- llava/model/consolidate.py +30 -0
- llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc +0 -0
- llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc +0 -0
- llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc +0 -0
- llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc +0 -0
- llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc +0 -0
- llava/model/language_model/llava_gemma.py +122 -0
- llava/model/language_model/llava_llama.py +131 -0
- llava/model/language_model/llava_mistral.py +127 -0
- llava/model/language_model/llava_mixtral.py +122 -0
- llava/model/language_model/llava_mpt.py +105 -0
- llava/model/language_model/llava_qwen.py +128 -0
- llava/model/language_model/llava_qwen_moe.py +128 -0
- llava/model/llava_arch.py +389 -0
- llava/model/make_delta.py +52 -0
- llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
- llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
- llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
- llava/model/multimodal_encoder/builder.py +14 -0
- llava/model/multimodal_encoder/clip_encoder.py +114 -0
- llava/model/multimodal_encoder/siglip_encoder.py +620 -0
- llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
- llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc +0 -0
- llava/model/multimodal_projector/builder.py +65 -0
- llava/model/multimodal_projector/pooler_projector.py +33 -0
- llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc +0 -0
- llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc +0 -0
- llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc +0 -0
- llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc +0 -0
- llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc +0 -0
- llava/model/multimodal_resampler/builder.py +34 -0
- llava/model/multimodal_resampler/masked_drop.py +80 -0
app.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# from .demo_modelpart import InferenceDemo
|
3 |
+
import gradio as gr
|
4 |
+
import os
|
5 |
+
# import time
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
|
9 |
+
# import copy
|
10 |
+
import torch
|
11 |
+
# import random
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from llava import conversation as conversation_lib
|
15 |
+
from llava.constants import DEFAULT_IMAGE_TOKEN
|
16 |
+
|
17 |
+
|
18 |
+
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
19 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
20 |
+
from llava.model.builder import load_pretrained_model
|
21 |
+
from llava.utils import disable_torch_init
|
22 |
+
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
23 |
+
|
24 |
+
from PIL import Image
|
25 |
+
|
26 |
+
import requests
|
27 |
+
from PIL import Image
|
28 |
+
from io import BytesIO
|
29 |
+
from transformers import TextStreamer
|
30 |
+
|
31 |
+
class InferenceDemo(object):
|
32 |
+
def __init__(self,args,model_path,tokenizer, model, image_processor, context_len) -> None:
|
33 |
+
disable_torch_init()
|
34 |
+
|
35 |
+
|
36 |
+
self.tokenizer, self.model, self.image_processor, self.context_len = tokenizer, model, image_processor, context_len
|
37 |
+
|
38 |
+
if "llama-2" in model_name.lower():
|
39 |
+
conv_mode = "llava_llama_2"
|
40 |
+
elif "v1" in model_name.lower():
|
41 |
+
conv_mode = "llava_v1"
|
42 |
+
elif "mpt" in model_name.lower():
|
43 |
+
conv_mode = "mpt"
|
44 |
+
elif 'qwen' in model_name.lower():
|
45 |
+
conv_mode = "qwen_1_5"
|
46 |
+
else:
|
47 |
+
conv_mode = "llava_v0"
|
48 |
+
|
49 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
50 |
+
print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
|
51 |
+
else:
|
52 |
+
args.conv_mode = conv_mode
|
53 |
+
self.conv_mode=conv_mode
|
54 |
+
self.conversation = conv_templates[args.conv_mode].copy()
|
55 |
+
self.num_frames = args.num_frames
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
def is_valid_video_filename(name):
|
60 |
+
video_extensions = ['avi', 'mp4', 'mov', 'mkv', 'flv', 'wmv', 'mjpeg']
|
61 |
+
|
62 |
+
ext = name.split('.')[-1].lower()
|
63 |
+
|
64 |
+
if ext in video_extensions:
|
65 |
+
return True
|
66 |
+
else:
|
67 |
+
return False
|
68 |
+
|
69 |
+
def sample_frames(video_file, num_frames) :
|
70 |
+
video = cv2.VideoCapture(video_file)
|
71 |
+
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
72 |
+
interval = total_frames // num_frames
|
73 |
+
frames = []
|
74 |
+
for i in range(total_frames):
|
75 |
+
ret, frame = video.read()
|
76 |
+
pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
77 |
+
if not ret:
|
78 |
+
continue
|
79 |
+
if i % interval == 0:
|
80 |
+
frames.append(pil_img)
|
81 |
+
video.release()
|
82 |
+
return frames
|
83 |
+
|
84 |
+
def load_image(image_file):
|
85 |
+
if image_file.startswith("http") or image_file.startswith("https"):
|
86 |
+
response = requests.get(image_file)
|
87 |
+
if response.status_code == 200:
|
88 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
89 |
+
else:
|
90 |
+
print('failed to load the image')
|
91 |
+
else:
|
92 |
+
print('Load image from local file')
|
93 |
+
print(image_file)
|
94 |
+
image = Image.open(image_file).convert("RGB")
|
95 |
+
|
96 |
+
return image
|
97 |
+
|
98 |
+
|
99 |
+
def clear_history(history):
|
100 |
+
|
101 |
+
our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
|
102 |
+
|
103 |
+
return None
|
104 |
+
def clear_response(history):
|
105 |
+
for index_conv in range(1, len(history)):
|
106 |
+
# loop until get a text response from our model.
|
107 |
+
conv = history[-index_conv]
|
108 |
+
if not (conv[0] is None):
|
109 |
+
break
|
110 |
+
question = history[-index_conv][0]
|
111 |
+
history = history[:-index_conv]
|
112 |
+
return history, question
|
113 |
+
|
114 |
+
def print_like_dislike(x: gr.LikeData):
|
115 |
+
print(x.index, x.value, x.liked)
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
def add_message(history, message):
|
120 |
+
# history=[]
|
121 |
+
global our_chatbot
|
122 |
+
if len(history)==0:
|
123 |
+
our_chatbot = InferenceDemo(args,model_path,tokenizer, model, image_processor, context_len)
|
124 |
+
|
125 |
+
for x in message["files"]:
|
126 |
+
history.append(((x,), None))
|
127 |
+
if message["text"] is not None:
|
128 |
+
history.append((message["text"], None))
|
129 |
+
return history, gr.MultimodalTextbox(value=None, interactive=False)
|
130 |
+
|
131 |
+
def bot(history):
|
132 |
+
text=history[-1][0]
|
133 |
+
images_this_term=[]
|
134 |
+
text_this_term=''
|
135 |
+
# import pdb;pdb.set_trace()
|
136 |
+
num_new_images = 0
|
137 |
+
for i,message in enumerate(history[:-1]):
|
138 |
+
if type(message[0]) is tuple:
|
139 |
+
images_this_term.append(message[0][0])
|
140 |
+
if is_valid_video_filename(message[0][0]):
|
141 |
+
num_new_images+=our_chatbot.num_frames
|
142 |
+
else:
|
143 |
+
num_new_images+=1
|
144 |
+
else:
|
145 |
+
num_new_images=0
|
146 |
+
|
147 |
+
# for message in history[-i-1:]:
|
148 |
+
# images_this_term.append(message[0][0])
|
149 |
+
|
150 |
+
assert len(images_this_term)>0, "must have an image"
|
151 |
+
# image_files = (args.image_file).split(',')
|
152 |
+
# image = [load_image(f) for f in images_this_term if f]
|
153 |
+
image_list=[]
|
154 |
+
for f in images_this_term:
|
155 |
+
if is_valid_video_filename(f):
|
156 |
+
image_list+=sample_frames(f, our_chatbot.num_frames)
|
157 |
+
else:
|
158 |
+
image_list.append(load_image(f))
|
159 |
+
image_tensor = [our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0].half().to(our_chatbot.model.device) for f in image_list]
|
160 |
+
|
161 |
+
image_tensor = torch.stack(image_tensor)
|
162 |
+
image_token = DEFAULT_IMAGE_TOKEN*num_new_images
|
163 |
+
# if our_chatbot.model.config.mm_use_im_start_end:
|
164 |
+
# inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
|
165 |
+
# else:
|
166 |
+
inp=text
|
167 |
+
inp = image_token+ "\n" + inp
|
168 |
+
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
|
169 |
+
# image = None
|
170 |
+
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
|
171 |
+
prompt = our_chatbot.conversation.get_prompt()
|
172 |
+
|
173 |
+
input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)
|
174 |
+
stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
|
175 |
+
keywords = [stop_str]
|
176 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, our_chatbot.tokenizer, input_ids)
|
177 |
+
streamer = TextStreamer(our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
178 |
+
# import pdb;pdb.set_trace()
|
179 |
+
with torch.inference_mode():
|
180 |
+
output_ids = our_chatbot.model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=False, stopping_criteria=[stopping_criteria])
|
181 |
+
|
182 |
+
outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
|
183 |
+
if outputs.endswith(stop_str):
|
184 |
+
outputs = outputs[:-len(stop_str)]
|
185 |
+
our_chatbot.conversation.messages[-1][-1] = outputs
|
186 |
+
|
187 |
+
history[-1]=[text,outputs]
|
188 |
+
|
189 |
+
return history
|
190 |
+
txt = gr.Textbox(
|
191 |
+
scale=4,
|
192 |
+
show_label=False,
|
193 |
+
placeholder="Enter text and press enter.",
|
194 |
+
container=False,
|
195 |
+
)
|
196 |
+
with gr.Blocks() as demo:
|
197 |
+
# Informations
|
198 |
+
title_markdown = ("""
|
199 |
+
# LLaVA-NeXT Interleave
|
200 |
+
[[Blog]](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/) [[Code]](https://github.com/LLaVA-VL/LLaVA-NeXT) [[Model]](https://huggingface.co/lmms-lab/llava-next-interleave-7b)
|
201 |
+
""")
|
202 |
+
tos_markdown = ("""
|
203 |
+
### TODO!. Terms of use
|
204 |
+
By using this service, users are required to agree to the following terms:
|
205 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
206 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
207 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
208 |
+
""")
|
209 |
+
learn_more_markdown = ("""
|
210 |
+
### TODO!. License
|
211 |
+
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
212 |
+
""")
|
213 |
+
models = [
|
214 |
+
"LLaVA-Interleave-7B",
|
215 |
+
]
|
216 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
217 |
+
gr.Markdown(title_markdown)
|
218 |
+
|
219 |
+
chatbot = gr.Chatbot(
|
220 |
+
[],
|
221 |
+
elem_id="chatbot",
|
222 |
+
bubble_full_width=False
|
223 |
+
)
|
224 |
+
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image","video"], placeholder="Enter message or upload file...", show_label=False)
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
with gr.Row():
|
229 |
+
upvote_btn = gr.Button(value="π Upvote", interactive=True)
|
230 |
+
downvote_btn = gr.Button(value="π Downvote", interactive=True)
|
231 |
+
flag_btn = gr.Button(value="β οΈ Flag", interactive=True)
|
232 |
+
#stop_btn = gr.Button(value="βΉοΈ Stop Generation", interactive=True)
|
233 |
+
regenerate_btn = gr.Button(value="π Regenerate", interactive=True)
|
234 |
+
clear_btn = gr.Button(value="ποΈ Clear history", interactive=True)
|
235 |
+
chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
|
236 |
+
bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
|
237 |
+
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
|
238 |
+
|
239 |
+
chatbot.like(print_like_dislike, None, None)
|
240 |
+
clear_btn.click(fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all")
|
241 |
+
with gr.Column():
|
242 |
+
gr.Examples(examples=[
|
243 |
+
[{"files": [f"{cur_dir}/examples/code1.jpeg",f"{cur_dir}/examples/code2.jpeg"], "text": "Please pay attention to the movement of the object from the first image to the second image, then write a HTML code to show this movement."}],
|
244 |
+
[{"files": [f"{cur_dir}/examples/shub.jpg",f"{cur_dir}/examples/shuc.jpg",f"{cur_dir}/examples/shud.jpg"], "text": "what is fun about the images?"}],
|
245 |
+
[{"files": [f"{cur_dir}/examples/iphone-15-price-1024x576.jpg",f"{cur_dir}/examples/dynamic-island-1024x576.jpg",f"{cur_dir}/examples/iphone-15-colors-1024x576.jpg",f"{cur_dir}/examples/Iphone-15-Usb-c-charger-1024x576.jpg",f"{cur_dir}/examples/A-17-processors-1024x576.jpg"], "text": "The images are the PPT of iPhone 15 review. can you summarize the main information?"}],
|
246 |
+
[{"files": [f"{cur_dir}/examples/fangao3.jpeg",f"{cur_dir}/examples/fangao2.jpeg",f"{cur_dir}/examples/fangao1.jpeg"], "text": "Do you kown who draw these paintings?"}],
|
247 |
+
[{"files": [f"{cur_dir}/examples/oprah-winfrey-resume.png",f"{cur_dir}/examples/steve-jobs-resume.jpg"], "text": "Hi, there are two candidates, can you provide a brief description for each of them for me?"}],
|
248 |
+
[{"files": [f"{cur_dir}/examples/original_bench.jpeg",f"{cur_dir}/examples/changed_bench.jpeg"], "text": "How to edit image1 to make it look like image2?"}],
|
249 |
+
[{"files": [f"{cur_dir}/examples/twitter2.jpeg",f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
|
250 |
+
# [{"files": [f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
|
251 |
+
# [{"files": [f"playground/demo/examples/lion1_.mp4",f"playground/demo/examples/lion2_.mp4"], "text": "The input contains two videos, the first half is the first video and the second half is the second video. What is the difference between the two videos?"}],
|
252 |
+
|
253 |
+
|
254 |
+
|
255 |
+
|
256 |
+
], inputs=[chat_input], label="Compare images: ")
|
257 |
+
|
258 |
+
demo.queue()
|
259 |
+
if __name__ == "__main__":
|
260 |
+
import argparse
|
261 |
+
argparser = argparse.ArgumentParser()
|
262 |
+
argparser.add_argument("--server_name", default="0.0.0.0", type=str)
|
263 |
+
argparser.add_argument("--port", default="6123", type=str)
|
264 |
+
argparser.add_argument("--model_path", default="", type=str)
|
265 |
+
# argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
266 |
+
argparser.add_argument("--model-base", type=str, default=None)
|
267 |
+
argparser.add_argument("--num-gpus", type=int, default=1)
|
268 |
+
argparser.add_argument("--conv-mode", type=str, default=None)
|
269 |
+
argparser.add_argument("--temperature", type=float, default=0.2)
|
270 |
+
argparser.add_argument("--max-new-tokens", type=int, default=512)
|
271 |
+
argparser.add_argument("--num_frames", type=int, default=16)
|
272 |
+
argparser.add_argument("--load-8bit", action="store_true")
|
273 |
+
argparser.add_argument("--load-4bit", action="store_true")
|
274 |
+
argparser.add_argument("--debug", action="store_true")
|
275 |
+
|
276 |
+
args = argparser.parse_args()
|
277 |
+
model_path = args.model_path
|
278 |
+
filt_invalid="cut"
|
279 |
+
model_name = get_model_name_from_path(args.model_path)
|
280 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
|
281 |
+
our_chatbot = None
|
282 |
+
# import pdb;pdb.set_trace()
|
283 |
+
try:
|
284 |
+
demo.launch(server_name=args.server_name, server_port=int(args.port),share=True)
|
285 |
+
except Exception as e:
|
286 |
+
args.port=int(args.port)+1
|
287 |
+
print(f"Port {args.port} is occupied, try port {args.port}")
|
288 |
+
demo.launch(server_name=args.server_name, server_port=int(args.port),share=True)
|
llava/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import LlavaLlamaForCausalLM
|
llava/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (212 Bytes). View file
|
|
llava/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (474 Bytes). View file
|
|
llava/__pycache__/conversation.cpython-310.pyc
ADDED
Binary file (13.3 kB). View file
|
|
llava/__pycache__/mm_utils.cpython-310.pyc
ADDED
Binary file (13.5 kB). View file
|
|
llava/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.3 kB). View file
|
|
llava/constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
llava/conversation.py
ADDED
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Any, Dict, Union, Tuple
|
4 |
+
import re
|
5 |
+
import base64
|
6 |
+
from io import BytesIO
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import AutoTokenizer
|
9 |
+
|
10 |
+
|
11 |
+
class SeparatorStyle(Enum):
|
12 |
+
"""Different separator style."""
|
13 |
+
|
14 |
+
SINGLE = auto()
|
15 |
+
TWO = auto()
|
16 |
+
MPT = auto()
|
17 |
+
PLAIN = auto()
|
18 |
+
CHATML = auto()
|
19 |
+
LLAMA_2 = auto()
|
20 |
+
LLAMA_3 = auto()
|
21 |
+
QWEN = auto()
|
22 |
+
GEMMA = auto()
|
23 |
+
|
24 |
+
|
25 |
+
@dataclasses.dataclass
|
26 |
+
class Conversation:
|
27 |
+
"""A class that keeps all conversation history."""
|
28 |
+
|
29 |
+
system: str
|
30 |
+
roles: List[str]
|
31 |
+
messages: List[List[str]]
|
32 |
+
offset: int
|
33 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
34 |
+
sep: str = "###"
|
35 |
+
sep2: str = None
|
36 |
+
version: str = "Unknown"
|
37 |
+
|
38 |
+
tokenizer_id: str = ""
|
39 |
+
tokenizer: Any = None
|
40 |
+
# Stop criteria (the default one is EOS token)
|
41 |
+
stop_str: Union[str, List[str]] = None
|
42 |
+
# Stops generation if meeting any token in this list
|
43 |
+
stop_token_ids: List[int] = None
|
44 |
+
|
45 |
+
skip_next: bool = False
|
46 |
+
|
47 |
+
def get_prompt(self):
|
48 |
+
messages = self.messages
|
49 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
50 |
+
messages = self.messages.copy()
|
51 |
+
init_role, init_msg = messages[0].copy()
|
52 |
+
init_msg = init_msg[0]
|
53 |
+
if "mmtag" in self.version:
|
54 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
55 |
+
messages[0] = (init_role, init_msg)
|
56 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
57 |
+
messages.insert(1, (self.roles[1], "Received."))
|
58 |
+
elif not init_msg.startswith("<image>"):
|
59 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
60 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
61 |
+
else:
|
62 |
+
messages[0] = (init_role, init_msg)
|
63 |
+
|
64 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
65 |
+
ret = self.system + self.sep
|
66 |
+
for role, message in messages:
|
67 |
+
if message:
|
68 |
+
if type(message) is tuple:
|
69 |
+
message, _, _ = message
|
70 |
+
ret += role + ": " + message + self.sep
|
71 |
+
else:
|
72 |
+
ret += role + ":"
|
73 |
+
|
74 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
75 |
+
seps = [self.sep, self.sep2]
|
76 |
+
ret = self.system + seps[0]
|
77 |
+
for i, (role, message) in enumerate(messages):
|
78 |
+
if message:
|
79 |
+
if type(message) is tuple:
|
80 |
+
message, _, _ = message
|
81 |
+
ret += role + ": " + message + seps[i % 2]
|
82 |
+
else:
|
83 |
+
ret += role + ":"
|
84 |
+
|
85 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
86 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
87 |
+
for role, message in messages:
|
88 |
+
if message:
|
89 |
+
if type(message) is tuple:
|
90 |
+
message, images = message
|
91 |
+
message = "<image>" * len(images) + message
|
92 |
+
ret += role + "\n" + message + self.sep + "\n"
|
93 |
+
else:
|
94 |
+
ret += role + "\n"
|
95 |
+
return ret
|
96 |
+
|
97 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
98 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
99 |
+
for role, message in messages:
|
100 |
+
if message:
|
101 |
+
if type(message) is tuple:
|
102 |
+
message, images = message
|
103 |
+
message = "<image>" * len(images) + message
|
104 |
+
chat_template_messages.append({"role": role, "content": message})
|
105 |
+
|
106 |
+
# print(chat_template_messages)
|
107 |
+
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
|
108 |
+
# ret = "" if self.system == "" else self.system + self.sep + "\n"
|
109 |
+
# for role, message in messages:
|
110 |
+
# if message:
|
111 |
+
# if type(message) is tuple:
|
112 |
+
# message, images = message
|
113 |
+
# message = "<image>" * len(images) + message
|
114 |
+
# ret += role + "\n" + message + self.sep + "\n"
|
115 |
+
# else:
|
116 |
+
# ret += role + "\n"
|
117 |
+
# return ret
|
118 |
+
|
119 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
120 |
+
ret = self.system + self.sep
|
121 |
+
for role, message in messages:
|
122 |
+
if message:
|
123 |
+
if type(message) is tuple:
|
124 |
+
message, _, _ = message
|
125 |
+
ret += role + message + self.sep
|
126 |
+
else:
|
127 |
+
ret += role
|
128 |
+
|
129 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
130 |
+
ret = ""
|
131 |
+
for i, (role, message) in enumerate(messages):
|
132 |
+
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
|
133 |
+
if message:
|
134 |
+
if type(message) is tuple:
|
135 |
+
message, _, _ = message
|
136 |
+
ret += role + message + self.sep
|
137 |
+
else:
|
138 |
+
ret += role
|
139 |
+
|
140 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
141 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
142 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
143 |
+
ret = ""
|
144 |
+
|
145 |
+
for i, (role, message) in enumerate(messages):
|
146 |
+
if i == 0:
|
147 |
+
assert message, "first message should not be none"
|
148 |
+
assert role == self.roles[0], "first message should come from user"
|
149 |
+
if message:
|
150 |
+
if type(message) is tuple:
|
151 |
+
message, _, _ = message
|
152 |
+
if i == 0:
|
153 |
+
message = wrap_sys(self.system) + message
|
154 |
+
if i % 2 == 0:
|
155 |
+
message = wrap_inst(message)
|
156 |
+
ret += self.sep + message
|
157 |
+
else:
|
158 |
+
ret += " " + message + " " + self.sep2
|
159 |
+
else:
|
160 |
+
ret += ""
|
161 |
+
ret = ret.lstrip(self.sep)
|
162 |
+
|
163 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
164 |
+
seps = [self.sep, self.sep2]
|
165 |
+
ret = self.system
|
166 |
+
for i, (role, message) in enumerate(messages):
|
167 |
+
if message:
|
168 |
+
if type(message) is tuple:
|
169 |
+
message, _, _ = message
|
170 |
+
ret += message + seps[i % 2]
|
171 |
+
else:
|
172 |
+
ret += ""
|
173 |
+
else:
|
174 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
175 |
+
|
176 |
+
return ret
|
177 |
+
|
178 |
+
def append_message(self, role, message):
|
179 |
+
self.messages.append([role, message])
|
180 |
+
|
181 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
|
182 |
+
if image_process_mode == "Pad":
|
183 |
+
|
184 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
185 |
+
width, height = pil_img.size
|
186 |
+
if width == height:
|
187 |
+
return pil_img
|
188 |
+
elif width > height:
|
189 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
190 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
191 |
+
return result
|
192 |
+
else:
|
193 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
194 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
195 |
+
return result
|
196 |
+
|
197 |
+
image = expand2square(image)
|
198 |
+
elif image_process_mode in ["Default", "Crop"]:
|
199 |
+
pass
|
200 |
+
elif image_process_mode == "Resize":
|
201 |
+
image = image.resize((336, 336))
|
202 |
+
else:
|
203 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
204 |
+
|
205 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
206 |
+
aspect_ratio = max_hw / min_hw
|
207 |
+
max_len, min_len = 672, 448
|
208 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
209 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
210 |
+
W, H = image.size
|
211 |
+
if H > W:
|
212 |
+
H, W = longest_edge, shortest_edge
|
213 |
+
else:
|
214 |
+
H, W = shortest_edge, longest_edge
|
215 |
+
image = image.resize((W, H))
|
216 |
+
if return_pil:
|
217 |
+
return image
|
218 |
+
else:
|
219 |
+
buffered = BytesIO()
|
220 |
+
image.save(buffered, format=image_format)
|
221 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
222 |
+
return img_b64_str
|
223 |
+
|
224 |
+
def get_images(self, return_pil=False):
|
225 |
+
images = []
|
226 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
227 |
+
if i % 2 == 0:
|
228 |
+
if type(msg) is tuple:
|
229 |
+
msg, image, image_process_mode = msg
|
230 |
+
if type(image) != list:
|
231 |
+
image = [image]
|
232 |
+
for img in image:
|
233 |
+
img = self.process_image(img, image_process_mode, return_pil=return_pil)
|
234 |
+
images.append(img)
|
235 |
+
return images
|
236 |
+
|
237 |
+
def to_gradio_chatbot(self):
|
238 |
+
ret = []
|
239 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
240 |
+
if i % 2 == 0:
|
241 |
+
if type(msg) is tuple:
|
242 |
+
msg, image, image_process_mode = msg
|
243 |
+
if type(image) != list:
|
244 |
+
image = [image]
|
245 |
+
if len(image) == 1:
|
246 |
+
msg = "<image>\n" + msg.replace("<image>", "").strip()
|
247 |
+
else:
|
248 |
+
msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
|
249 |
+
for img in image:
|
250 |
+
img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
|
251 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}"/>'
|
252 |
+
msg = msg.replace("<image>", img_str, 1).strip()
|
253 |
+
if len(msg) > 0:
|
254 |
+
ret.append([msg, None])
|
255 |
+
else:
|
256 |
+
ret.append([msg, None])
|
257 |
+
else:
|
258 |
+
ret[-1][-1] = msg
|
259 |
+
return ret
|
260 |
+
|
261 |
+
def copy(self):
|
262 |
+
return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
|
263 |
+
|
264 |
+
def dict(self):
|
265 |
+
if len(self.get_images()) > 0:
|
266 |
+
return {
|
267 |
+
"system": self.system,
|
268 |
+
"roles": self.roles,
|
269 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
270 |
+
"offset": self.offset,
|
271 |
+
"sep": self.sep,
|
272 |
+
"sep2": self.sep2,
|
273 |
+
}
|
274 |
+
return {
|
275 |
+
"system": self.system,
|
276 |
+
"roles": self.roles,
|
277 |
+
"messages": self.messages,
|
278 |
+
"offset": self.offset,
|
279 |
+
"sep": self.sep,
|
280 |
+
"sep2": self.sep2,
|
281 |
+
}
|
282 |
+
|
283 |
+
|
284 |
+
conv_vicuna_v0 = Conversation(
|
285 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
286 |
+
roles=("Human", "Assistant"),
|
287 |
+
messages=[
|
288 |
+
["Human", "What are the key differences between renewable and non-renewable energy sources?"],
|
289 |
+
[
|
290 |
+
"Assistant",
|
291 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
292 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
293 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
294 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
295 |
+
"renewable and non-renewable energy sources:\n"
|
296 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
297 |
+
"energy sources are finite and will eventually run out.\n"
|
298 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
299 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
300 |
+
"and other negative effects.\n"
|
301 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
302 |
+
"have lower operational costs than non-renewable sources.\n"
|
303 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
304 |
+
"locations than non-renewable sources.\n"
|
305 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
306 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
307 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
308 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
309 |
+
],
|
310 |
+
],
|
311 |
+
offset=2,
|
312 |
+
sep_style=SeparatorStyle.SINGLE,
|
313 |
+
sep="###",
|
314 |
+
)
|
315 |
+
|
316 |
+
conv_vicuna_v1 = Conversation(
|
317 |
+
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
318 |
+
roles=("USER", "ASSISTANT"),
|
319 |
+
version="v1",
|
320 |
+
messages=[],
|
321 |
+
offset=0,
|
322 |
+
sep_style=SeparatorStyle.TWO,
|
323 |
+
sep=" ",
|
324 |
+
sep2="</s>",
|
325 |
+
)
|
326 |
+
|
327 |
+
conv_llama_2 = Conversation(
|
328 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
329 |
+
|
330 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
331 |
+
roles=("USER", "ASSISTANT"),
|
332 |
+
version="llama_v2",
|
333 |
+
messages=[],
|
334 |
+
offset=0,
|
335 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
336 |
+
sep="<s>",
|
337 |
+
sep2="</s>",
|
338 |
+
)
|
339 |
+
|
340 |
+
conv_llava_llama_2 = Conversation(
|
341 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
342 |
+
roles=("USER", "ASSISTANT"),
|
343 |
+
version="llama_v2",
|
344 |
+
messages=[],
|
345 |
+
offset=0,
|
346 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
347 |
+
sep="<s>",
|
348 |
+
sep2="</s>",
|
349 |
+
)
|
350 |
+
|
351 |
+
try:
|
352 |
+
llama3_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
353 |
+
except Exception as e:
|
354 |
+
print("Error loading llama3 tokenizer")
|
355 |
+
print(e)
|
356 |
+
|
357 |
+
# conv_llava_llama_3 = Conversation(
|
358 |
+
# system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
359 |
+
# roles=("<|start_header_id|>user", "<|start_header_id|>assistant"),
|
360 |
+
# version="llama_v3",
|
361 |
+
# messages=[],
|
362 |
+
# offset=0,
|
363 |
+
# sep_style=SeparatorStyle.LLAMA_3,
|
364 |
+
# tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
|
365 |
+
# tokenizer=llama3_tokenizer,
|
366 |
+
# stop_token_ids=[128009],
|
367 |
+
# )
|
368 |
+
|
369 |
+
conv_mistral_instruct = Conversation(
|
370 |
+
system="",
|
371 |
+
roles=("USER", "ASSISTANT"),
|
372 |
+
version="llama_v2",
|
373 |
+
messages=[],
|
374 |
+
offset=0,
|
375 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
376 |
+
sep="",
|
377 |
+
sep2="</s>",
|
378 |
+
)
|
379 |
+
|
380 |
+
conv_llava_llama_2_simple = Conversation(
|
381 |
+
system="Answer the questions about the visual content that the user provides.",
|
382 |
+
roles=("USER", "ASSISTANT"),
|
383 |
+
version="llama_v2",
|
384 |
+
messages=[],
|
385 |
+
offset=0,
|
386 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
387 |
+
sep="<s>",
|
388 |
+
sep2="</s>",
|
389 |
+
)
|
390 |
+
|
391 |
+
conv_llava_llama_2_mmtag = Conversation(
|
392 |
+
system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
|
393 |
+
roles=("USER", "ASSISTANT"),
|
394 |
+
version="llama_v2_mmtag",
|
395 |
+
messages=[],
|
396 |
+
offset=0,
|
397 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
398 |
+
sep="<s>",
|
399 |
+
sep2="</s>",
|
400 |
+
)
|
401 |
+
|
402 |
+
conv_mpt = Conversation(
|
403 |
+
system="""<|im_start|>system
|
404 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
405 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
406 |
+
version="mpt",
|
407 |
+
messages=[],
|
408 |
+
offset=0,
|
409 |
+
sep_style=SeparatorStyle.MPT,
|
410 |
+
sep="<|im_end|>",
|
411 |
+
)
|
412 |
+
|
413 |
+
conv_qwen = Conversation(
|
414 |
+
system="""<|im_start|>system
|
415 |
+
You are a helpful assistant.""",
|
416 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
417 |
+
version="qwen",
|
418 |
+
messages=[],
|
419 |
+
offset=0,
|
420 |
+
sep_style=SeparatorStyle.CHATML,
|
421 |
+
sep="<|im_end|>",
|
422 |
+
)
|
423 |
+
|
424 |
+
conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
|
425 |
+
|
426 |
+
conv_llava_plain = Conversation(
|
427 |
+
system="",
|
428 |
+
roles=("", ""),
|
429 |
+
messages=[],
|
430 |
+
offset=0,
|
431 |
+
sep_style=SeparatorStyle.PLAIN,
|
432 |
+
sep="\n",
|
433 |
+
)
|
434 |
+
|
435 |
+
conv_llava_v0 = Conversation(
|
436 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
437 |
+
roles=("Human", "Assistant"),
|
438 |
+
messages=[],
|
439 |
+
offset=0,
|
440 |
+
sep_style=SeparatorStyle.SINGLE,
|
441 |
+
sep="###",
|
442 |
+
)
|
443 |
+
|
444 |
+
conv_llava_v0_mmtag = Conversation(
|
445 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
446 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
447 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
448 |
+
roles=("Human", "Assistant"),
|
449 |
+
messages=[],
|
450 |
+
offset=0,
|
451 |
+
sep_style=SeparatorStyle.SINGLE,
|
452 |
+
sep="###",
|
453 |
+
version="v0_mmtag",
|
454 |
+
)
|
455 |
+
|
456 |
+
conv_llava_v1 = Conversation(
|
457 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
458 |
+
roles=("USER", "ASSISTANT"),
|
459 |
+
version="v1",
|
460 |
+
messages=[],
|
461 |
+
offset=0,
|
462 |
+
sep_style=SeparatorStyle.TWO,
|
463 |
+
sep=" ",
|
464 |
+
sep2="</s>",
|
465 |
+
)
|
466 |
+
|
467 |
+
conv_llava_v1_mmtag = Conversation(
|
468 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
469 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
470 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
471 |
+
roles=("USER", "ASSISTANT"),
|
472 |
+
messages=[],
|
473 |
+
offset=0,
|
474 |
+
sep_style=SeparatorStyle.TWO,
|
475 |
+
sep=" ",
|
476 |
+
sep2="</s>",
|
477 |
+
version="v1_mmtag",
|
478 |
+
)
|
479 |
+
|
480 |
+
conv_mistral_orca = Conversation(
|
481 |
+
system="""<|im_start|>system
|
482 |
+
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
|
483 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
484 |
+
version="mpt",
|
485 |
+
messages=[],
|
486 |
+
offset=0,
|
487 |
+
sep_style=SeparatorStyle.MPT,
|
488 |
+
sep="<|im_end|>",
|
489 |
+
)
|
490 |
+
|
491 |
+
conv_mistral_zephyr = Conversation(
|
492 |
+
system="""<|system|>
|
493 |
+
You are a helpful AI assistant.""",
|
494 |
+
roles=("<|user|>\n", "<|assistant|>\n"),
|
495 |
+
version="mpt",
|
496 |
+
messages=[],
|
497 |
+
offset=0,
|
498 |
+
sep_style=SeparatorStyle.MPT,
|
499 |
+
sep="</s>",
|
500 |
+
)
|
501 |
+
|
502 |
+
conv_mistral_direct = Conversation(
|
503 |
+
system="""<|im_start|>system
|
504 |
+
Answer the questions.""",
|
505 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
506 |
+
version="mpt",
|
507 |
+
messages=[],
|
508 |
+
offset=0,
|
509 |
+
sep_style=SeparatorStyle.MPT,
|
510 |
+
sep="<|im_end|>",
|
511 |
+
)
|
512 |
+
|
513 |
+
conv_chatml_direct = Conversation(
|
514 |
+
system="""<|im_start|>system
|
515 |
+
Answer the questions.""",
|
516 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
517 |
+
version="mpt",
|
518 |
+
messages=[],
|
519 |
+
offset=0,
|
520 |
+
sep_style=SeparatorStyle.MPT,
|
521 |
+
sep="<|im_end|>",
|
522 |
+
)
|
523 |
+
|
524 |
+
default_conversation = conv_vicuna_v0
|
525 |
+
conv_templates = {
|
526 |
+
"default": conv_vicuna_v0,
|
527 |
+
"v0": conv_vicuna_v0,
|
528 |
+
"v1": conv_vicuna_v1,
|
529 |
+
"vicuna_v1": conv_vicuna_v1,
|
530 |
+
"llama_2": conv_llama_2,
|
531 |
+
"mistral_instruct": conv_mistral_instruct,
|
532 |
+
"mistral_orca": conv_mistral_orca,
|
533 |
+
"mistral_zephyr": conv_mistral_zephyr,
|
534 |
+
"mistral_direct": conv_mistral_direct,
|
535 |
+
"plain": conv_llava_plain,
|
536 |
+
"v0_plain": conv_llava_plain,
|
537 |
+
"chatml_direct": conv_chatml_direct,
|
538 |
+
"llava_v0": conv_llava_v0,
|
539 |
+
"llava_v0_mmtag": conv_llava_v0_mmtag,
|
540 |
+
"llava_v1": conv_llava_v1,
|
541 |
+
"llava_v1_mmtag": conv_llava_v1_mmtag,
|
542 |
+
"llava_llama_2": conv_llava_llama_2,
|
543 |
+
# "llava_llama_3": conv_llava_llama_3,
|
544 |
+
"llava_llama_2_simple": conv_llava_llama_2_simple,
|
545 |
+
"llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
|
546 |
+
"llava_mistral_instruct": conv_mistral_instruct,
|
547 |
+
"mpt": conv_mpt,
|
548 |
+
"qwen_1_5": conv_qwen,
|
549 |
+
"gemma_instruct": conv_gemma_instruct,
|
550 |
+
}
|
551 |
+
|
552 |
+
|
553 |
+
if __name__ == "__main__":
|
554 |
+
print(default_conversation.get_prompt())
|
llava/eval/evaluate_interleave.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from rouge import Rouge
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
+
|
10 |
+
|
11 |
+
spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
|
12 |
+
image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
|
13 |
+
visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
|
14 |
+
visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
|
15 |
+
text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
|
16 |
+
multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
|
17 |
+
|
18 |
+
puzzle = ["RAVEN"]
|
19 |
+
nlrv2 = ["NLVR2_Mantis"]
|
20 |
+
qbench = ["QBench"]
|
21 |
+
|
22 |
+
class Eval:
|
23 |
+
def __init__(self):
|
24 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
25 |
+
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
26 |
+
self.punct = [
|
27 |
+
";",
|
28 |
+
r"/",
|
29 |
+
"[",
|
30 |
+
"]",
|
31 |
+
'"',
|
32 |
+
"{",
|
33 |
+
"}",
|
34 |
+
"(",
|
35 |
+
")",
|
36 |
+
"=",
|
37 |
+
"+",
|
38 |
+
"\\",
|
39 |
+
"_",
|
40 |
+
"-",
|
41 |
+
">",
|
42 |
+
"<",
|
43 |
+
"@",
|
44 |
+
"`",
|
45 |
+
",",
|
46 |
+
"?",
|
47 |
+
"!",
|
48 |
+
]
|
49 |
+
|
50 |
+
def processPunctuation(self, inText):
|
51 |
+
outText = inText
|
52 |
+
for p in self.punct:
|
53 |
+
if (p + " " in inText or " " + p in inText) or (
|
54 |
+
re.search(self.commaStrip, inText) != None
|
55 |
+
):
|
56 |
+
outText = outText.replace(p, "")
|
57 |
+
else:
|
58 |
+
outText = outText.replace(p, " ")
|
59 |
+
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
60 |
+
return outText
|
61 |
+
|
62 |
+
def process(self, answer):
|
63 |
+
answer = answer.replace("\n", " ")
|
64 |
+
answer = answer.replace("\t", " ")
|
65 |
+
answer = answer.strip()
|
66 |
+
answer = self.processPunctuation(answer)
|
67 |
+
answer = answer.strip('\'')
|
68 |
+
answer = answer.strip('\"')
|
69 |
+
answer = answer.strip(')')
|
70 |
+
answer = answer.strip('(')
|
71 |
+
answer = answer.strip().lower()
|
72 |
+
return answer
|
73 |
+
|
74 |
+
def evaluate_rouge(self,preds):
|
75 |
+
rouge = Rouge()
|
76 |
+
acc = {'f': []}
|
77 |
+
eval_list = []
|
78 |
+
for i, res in enumerate(preds):
|
79 |
+
sample_id = res['sample_id']
|
80 |
+
# print(sample_id)
|
81 |
+
gt_ans = self.process(res["gt_response"])
|
82 |
+
pred_ans = self.process(res["pred_response"])
|
83 |
+
# assert gt_ans != ''
|
84 |
+
|
85 |
+
if gt_ans == '':
|
86 |
+
continue
|
87 |
+
|
88 |
+
if pred_ans == '':
|
89 |
+
s = 0
|
90 |
+
else:
|
91 |
+
if len(pred_ans) > 512:
|
92 |
+
pred_ans = pred_ans[0: 512]
|
93 |
+
s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
|
94 |
+
acc['f'].append(s)
|
95 |
+
eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
|
96 |
+
results = {'Rouge-L f': np.mean(acc['f'])}
|
97 |
+
return results,eval_list
|
98 |
+
|
99 |
+
|
100 |
+
def judge_multi_choice(self,sample):
|
101 |
+
sample_id = sample['sample_id']
|
102 |
+
gt_ans = sample["gt_response"]
|
103 |
+
pred_ans = sample["pred_response"]
|
104 |
+
|
105 |
+
if ":" in pred_ans:
|
106 |
+
a_list = pred_ans.split(":")
|
107 |
+
a_list = [a.strip() for a in a_list ]
|
108 |
+
for a in a_list:
|
109 |
+
if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
|
110 |
+
pred_ans = a
|
111 |
+
|
112 |
+
if pred_ans == gt_ans:
|
113 |
+
return 1
|
114 |
+
else:
|
115 |
+
return 0
|
116 |
+
|
117 |
+
def process_sample(self,sample):
|
118 |
+
sample["gt_response"] = self.process(sample["gt_response"])
|
119 |
+
sample["pred_response"] = self.process(sample["pred_response"])
|
120 |
+
|
121 |
+
def evaluate_multichoice(self, preditions):
|
122 |
+
correct = 0
|
123 |
+
eval_list = []
|
124 |
+
for i, sample in enumerate(preditions):
|
125 |
+
self.process_sample(sample)
|
126 |
+
score = self.judge_multi_choice(sample)
|
127 |
+
sample_id = sample['sample_id']
|
128 |
+
sample['result'] = score
|
129 |
+
eval_list.append({'id':str(sample_id),'score':str(score)})
|
130 |
+
correct+=score
|
131 |
+
return {'Accuracy':correct/len(preditions)},eval_list
|
132 |
+
|
133 |
+
def evaluate_multi_choice_image(self,preditions):
|
134 |
+
correct = 0
|
135 |
+
eval_list = []
|
136 |
+
for i,sample in enumerate(preditions):
|
137 |
+
gt_ans = self.process(sample["gt_response"])
|
138 |
+
pred_ans = self.process(sample["pred_response"])
|
139 |
+
sample_id = sample['sample_id']
|
140 |
+
|
141 |
+
if ":" in pred_ans:
|
142 |
+
a_list = pred_ans.split(":")
|
143 |
+
a_list = [a.strip() for a in a_list ]
|
144 |
+
for a in a_list:
|
145 |
+
if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
|
146 |
+
pred_ans = a
|
147 |
+
|
148 |
+
if gt_ans == pred_ans:
|
149 |
+
score = 1
|
150 |
+
else:
|
151 |
+
score = 0
|
152 |
+
sample_id = sample['sample_id']
|
153 |
+
sample['result'] = score
|
154 |
+
eval_list.append({'id':str(sample_id),'score':str(score)})
|
155 |
+
correct+=score
|
156 |
+
return {'Accuracy':correct/len(preditions)},eval_list
|
157 |
+
|
158 |
+
|
159 |
+
if __name__ == "__main__":
|
160 |
+
parser = argparse.ArgumentParser()
|
161 |
+
parser.add_argument('--result-dir', type=str, required=True)
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
|
165 |
+
result_file = os.path.join(args.result_dir, "result.jsonl")
|
166 |
+
|
167 |
+
if not os.path.exists(result_file):
|
168 |
+
print('No prediction file found')
|
169 |
+
exit(0)
|
170 |
+
with open(result_file, 'r') as f:
|
171 |
+
preds_all = [json.loads(line) for line in f]
|
172 |
+
|
173 |
+
preds_all_dict = dict()
|
174 |
+
for pred in preds_all:
|
175 |
+
if pred["dataset"] not in preds_all_dict:
|
176 |
+
preds_all_dict[pred["dataset"]] = list()
|
177 |
+
preds_all_dict[pred["dataset"]].append(pred)
|
178 |
+
|
179 |
+
image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
|
180 |
+
E = Eval()
|
181 |
+
|
182 |
+
eval_result_list = dict()
|
183 |
+
eval_result_list_detail = dict()
|
184 |
+
|
185 |
+
for dataset in preds_all_dict:
|
186 |
+
|
187 |
+
preds = preds_all_dict[dataset]
|
188 |
+
question_type = preds[0]["question_type"]
|
189 |
+
|
190 |
+
if question_type == 'open-ended':
|
191 |
+
eval_result, eval_list = E.evaluate_rouge(preds)
|
192 |
+
|
193 |
+
elif question_type == 'multi-choice' or dataset == 'nlrv2':
|
194 |
+
if dataset in image_choice_dataset_list:
|
195 |
+
eval_result, eval_list = E.evaluate_multi_choice_image(preds)
|
196 |
+
else:
|
197 |
+
eval_result, eval_list = E.evaluate_multichoice(preds)
|
198 |
+
|
199 |
+
else:
|
200 |
+
eval_result = 'Dataset not supported'
|
201 |
+
print('Dataset not supported')
|
202 |
+
exit(0)
|
203 |
+
|
204 |
+
print(dataset, end = ': ')
|
205 |
+
print(eval_result)
|
206 |
+
|
207 |
+
eval_result_list[dataset] = eval_result
|
208 |
+
eval_result_list_detail[dataset] = eval_list
|
209 |
+
|
210 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
211 |
+
with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
|
212 |
+
json.dump(eval_result_list, f, indent=4)
|
213 |
+
|
214 |
+
with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
|
215 |
+
json.dump(eval_result_list_detail, f, indent=4)
|
216 |
+
|
217 |
+
|
218 |
+
eval_cat_list = dict()
|
219 |
+
print()
|
220 |
+
|
221 |
+
# spot_the_diff
|
222 |
+
score = 0
|
223 |
+
count = 0
|
224 |
+
for dataset in eval_result_list:
|
225 |
+
if dataset in spot_the_diff:
|
226 |
+
count += 1
|
227 |
+
score += list(eval_result_list[dataset].values())[0]
|
228 |
+
if count > 0:
|
229 |
+
score /= count
|
230 |
+
eval_cat_list["spot_the_diff"] = score
|
231 |
+
print("spot_the_diff", end = ': ')
|
232 |
+
print('{:.2f}'.format(100 * score))
|
233 |
+
|
234 |
+
# image_edit_instruct
|
235 |
+
score = 0
|
236 |
+
count = 0
|
237 |
+
for dataset in eval_result_list:
|
238 |
+
if dataset in image_edit_instruct:
|
239 |
+
count += 1
|
240 |
+
score += list(eval_result_list[dataset].values())[0]
|
241 |
+
if count > 0:
|
242 |
+
score /= count
|
243 |
+
eval_cat_list["image_edit_instruct"] = score
|
244 |
+
print("image_edit_instruct", end = ': ')
|
245 |
+
print('{:.2f}'.format(100 * score))
|
246 |
+
|
247 |
+
# visual_story_telling
|
248 |
+
score = 0
|
249 |
+
count = 0
|
250 |
+
for dataset in eval_result_list:
|
251 |
+
if dataset in visual_story_telling:
|
252 |
+
count += 1
|
253 |
+
score += list(eval_result_list[dataset].values())[0]
|
254 |
+
if count > 0:
|
255 |
+
score /= count
|
256 |
+
eval_cat_list["visual_story_telling"] = score
|
257 |
+
print("visual_story_telling", end = ': ')
|
258 |
+
print('{:.2f}'.format(100 * score))
|
259 |
+
|
260 |
+
# visual_cloze
|
261 |
+
score = 0
|
262 |
+
count = 0
|
263 |
+
for dataset in eval_result_list:
|
264 |
+
if dataset in visual_cloze:
|
265 |
+
count += 1
|
266 |
+
score += list(eval_result_list[dataset].values())[0]
|
267 |
+
if count > 0:
|
268 |
+
score /= count
|
269 |
+
eval_cat_list["visual_cloze"] = score
|
270 |
+
print("visual_cloze", end = ': ')
|
271 |
+
print('{:.2f}'.format(100 * score))
|
272 |
+
|
273 |
+
# text_rich_vqa
|
274 |
+
score = 0
|
275 |
+
count = 0
|
276 |
+
for dataset in eval_result_list:
|
277 |
+
if dataset in text_rich_vqa:
|
278 |
+
count += 1
|
279 |
+
score += list(eval_result_list[dataset].values())[0]
|
280 |
+
if count > 0:
|
281 |
+
score /= count
|
282 |
+
eval_cat_list["text_rich_vqa"] = score
|
283 |
+
print("text_rich_vqa", end = ': ')
|
284 |
+
print('{:.2f}'.format(100 * score))
|
285 |
+
|
286 |
+
# multi_image_vqa
|
287 |
+
score = 0
|
288 |
+
count = 0
|
289 |
+
for dataset in eval_result_list:
|
290 |
+
if dataset in multi_image_vqa:
|
291 |
+
count += 1
|
292 |
+
score += list(eval_result_list[dataset].values())[0]
|
293 |
+
if count > 0:
|
294 |
+
score /= count
|
295 |
+
eval_cat_list["multi_image_vqa"] = score
|
296 |
+
print("multi_image_vqa", end = ': ')
|
297 |
+
print('{:.2f}'.format(100 * score))
|
298 |
+
|
299 |
+
# puzzle
|
300 |
+
score = 0
|
301 |
+
count = 0
|
302 |
+
for dataset in eval_result_list:
|
303 |
+
if dataset in puzzle:
|
304 |
+
count += 1
|
305 |
+
score += list(eval_result_list[dataset].values())[0]
|
306 |
+
if count > 0:
|
307 |
+
score /= count
|
308 |
+
eval_cat_list["puzzle"] = score
|
309 |
+
print("puzzle", end = ': ')
|
310 |
+
print('{:.2f}'.format(100 * score))
|
311 |
+
|
312 |
+
# nlrv2
|
313 |
+
score = 0
|
314 |
+
count = 0
|
315 |
+
for dataset in eval_result_list:
|
316 |
+
if dataset in nlrv2:
|
317 |
+
count += 1
|
318 |
+
score += list(eval_result_list[dataset].values())[0]
|
319 |
+
if count > 0:
|
320 |
+
score /= count
|
321 |
+
eval_cat_list["nlrv2"] = score
|
322 |
+
print("nlrv2", end = ': ')
|
323 |
+
print('{:.2f}'.format(100 * score))
|
324 |
+
|
325 |
+
# qbench
|
326 |
+
score = 0
|
327 |
+
count = 0
|
328 |
+
for dataset in eval_result_list:
|
329 |
+
if dataset in qbench:
|
330 |
+
count += 1
|
331 |
+
score += list(eval_result_list[dataset].values())[0]
|
332 |
+
if count > 0:
|
333 |
+
score /= count
|
334 |
+
eval_cat_list["qbench"] = score
|
335 |
+
print("qbench", end = ': ')
|
336 |
+
print('{:.2f}'.format(100 * score))
|
337 |
+
|
338 |
+
with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
|
339 |
+
json.dump(eval_cat_list, f, indent=4)
|
llava/eval/model_vqa.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import shortuuid
|
7 |
+
|
8 |
+
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
10 |
+
from llava.model.builder import load_pretrained_model
|
11 |
+
from llava.utils import disable_torch_init
|
12 |
+
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
13 |
+
|
14 |
+
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
|
15 |
+
from typing import Dict, Optional, Sequence, List
|
16 |
+
import transformers
|
17 |
+
import re
|
18 |
+
|
19 |
+
from PIL import Image
|
20 |
+
import math
|
21 |
+
|
22 |
+
|
23 |
+
def split_list(lst, n):
|
24 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
25 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
26 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
27 |
+
|
28 |
+
|
29 |
+
def get_chunk(lst, n, k):
|
30 |
+
chunks = split_list(lst, n)
|
31 |
+
return chunks[k]
|
32 |
+
|
33 |
+
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
|
34 |
+
roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
|
35 |
+
|
36 |
+
im_start, im_end = tokenizer.additional_special_tokens_ids
|
37 |
+
nl_tokens = tokenizer("\n").input_ids
|
38 |
+
_system = tokenizer("system").input_ids + nl_tokens
|
39 |
+
_user = tokenizer("user").input_ids + nl_tokens
|
40 |
+
_assistant = tokenizer("assistant").input_ids + nl_tokens
|
41 |
+
|
42 |
+
# Apply prompt templates
|
43 |
+
input_ids, targets = [], []
|
44 |
+
|
45 |
+
source = sources
|
46 |
+
if roles[source[0]["from"]] != roles["human"]:
|
47 |
+
source = source[1:]
|
48 |
+
|
49 |
+
input_id, target = [], []
|
50 |
+
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
|
51 |
+
input_id += system
|
52 |
+
target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
|
53 |
+
assert len(input_id) == len(target)
|
54 |
+
for j, sentence in enumerate(source):
|
55 |
+
role = roles[sentence["from"]]
|
56 |
+
if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
|
57 |
+
num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
|
58 |
+
texts = sentence["value"].split('<image>')
|
59 |
+
_input_id = tokenizer(role).input_ids + nl_tokens
|
60 |
+
for i,text in enumerate(texts):
|
61 |
+
_input_id += tokenizer(text).input_ids
|
62 |
+
if i<len(texts)-1:
|
63 |
+
_input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
|
64 |
+
_input_id += [im_end] + nl_tokens
|
65 |
+
assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
|
66 |
+
else:
|
67 |
+
if sentence["value"] is None:
|
68 |
+
_input_id = tokenizer(role).input_ids + nl_tokens
|
69 |
+
else:
|
70 |
+
_input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
|
71 |
+
input_id += _input_id
|
72 |
+
if role == "<|im_start|>user":
|
73 |
+
_target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
|
74 |
+
elif role == "<|im_start|>assistant":
|
75 |
+
_target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
|
76 |
+
else:
|
77 |
+
raise NotImplementedError
|
78 |
+
target += _target
|
79 |
+
|
80 |
+
input_ids.append(input_id)
|
81 |
+
targets.append(target)
|
82 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
83 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
84 |
+
return input_ids
|
85 |
+
|
86 |
+
def eval_model(args):
|
87 |
+
|
88 |
+
# Model
|
89 |
+
disable_torch_init()
|
90 |
+
model_path = os.path.expanduser(args.model_path)
|
91 |
+
model_name = get_model_name_from_path(model_path)
|
92 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
93 |
+
|
94 |
+
# Data
|
95 |
+
with open(os.path.expanduser(args.question_file)) as f:
|
96 |
+
questions = json.load(f)
|
97 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
98 |
+
answers_file = os.path.expanduser(args.answers_file)
|
99 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
100 |
+
ans_file = open(answers_file, "w")
|
101 |
+
|
102 |
+
for line in tqdm(questions):
|
103 |
+
idx = line["sample_id"]
|
104 |
+
question_type = line["metadata"]["question_type"]
|
105 |
+
dataset_name = line["metadata"]["dataset"]
|
106 |
+
gt = line["conversations"][1]["value"]
|
107 |
+
|
108 |
+
image_files = line["image"]
|
109 |
+
qs = line["conversations"][0]["value"]
|
110 |
+
cur_prompt = args.extra_prompt + qs
|
111 |
+
|
112 |
+
args.conv_mode = "qwen_1_5"
|
113 |
+
|
114 |
+
conv = conv_templates[args.conv_mode].copy()
|
115 |
+
conv.append_message(conv.roles[0], qs)
|
116 |
+
conv.append_message(conv.roles[1], None)
|
117 |
+
prompt = conv.get_prompt()
|
118 |
+
|
119 |
+
input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
|
120 |
+
img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
|
121 |
+
|
122 |
+
image_tensors = []
|
123 |
+
for image_file in image_files:
|
124 |
+
image = Image.open(os.path.join(args.image_folder, image_file))
|
125 |
+
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
|
126 |
+
image_tensors.append(image_tensor.half().cuda())
|
127 |
+
# image_tensors = torch.cat(image_tensors, dim=0)
|
128 |
+
|
129 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
130 |
+
keywords = [stop_str]
|
131 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
132 |
+
|
133 |
+
with torch.inference_mode():
|
134 |
+
output_ids = model.generate(
|
135 |
+
input_ids,
|
136 |
+
images=image_tensors,
|
137 |
+
do_sample=True if args.temperature > 0 else False,
|
138 |
+
temperature=args.temperature,
|
139 |
+
top_p=args.top_p,
|
140 |
+
num_beams=args.num_beams,
|
141 |
+
# no_repeat_ngram_size=3,
|
142 |
+
max_new_tokens=1024,
|
143 |
+
use_cache=True)
|
144 |
+
|
145 |
+
|
146 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
147 |
+
outputs = outputs.strip()
|
148 |
+
if outputs.endswith(stop_str):
|
149 |
+
outputs = outputs[:-len(stop_str)]
|
150 |
+
outputs = outputs.strip()
|
151 |
+
|
152 |
+
ans_id = shortuuid.uuid()
|
153 |
+
ans_file.write(json.dumps({
|
154 |
+
"dataset": dataset_name,
|
155 |
+
"sample_id": idx,
|
156 |
+
"prompt": cur_prompt,
|
157 |
+
"pred_response": outputs,
|
158 |
+
"gt_response": gt,
|
159 |
+
"shortuuid": ans_id,
|
160 |
+
"model_id": model_name,
|
161 |
+
"question_type": question_type,
|
162 |
+
}) + "\n")
|
163 |
+
ans_file.flush()
|
164 |
+
|
165 |
+
if len(line["conversations"]) > 2:
|
166 |
+
|
167 |
+
for i in range(2, len(line["conversations"]), 2):
|
168 |
+
input_ids = torch.cat((input_ids, output_ids), dim=1)
|
169 |
+
|
170 |
+
gt = line["conversations"][i + 1]["value"]
|
171 |
+
qs = line["conversations"][i]["value"]
|
172 |
+
cur_prompt = args.extra_prompt + qs
|
173 |
+
|
174 |
+
args.conv_mode = "qwen_1_5"
|
175 |
+
|
176 |
+
conv = conv_templates[args.conv_mode].copy()
|
177 |
+
conv.append_message(conv.roles[0], qs)
|
178 |
+
conv.append_message(conv.roles[1], None)
|
179 |
+
prompt = conv.get_prompt()
|
180 |
+
|
181 |
+
input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
|
182 |
+
input_ids = torch.cat((input_ids, input_ids_new), dim=1)
|
183 |
+
img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
|
184 |
+
|
185 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
186 |
+
keywords = [stop_str]
|
187 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
188 |
+
|
189 |
+
with torch.inference_mode():
|
190 |
+
output_ids = model.generate(
|
191 |
+
input_ids,
|
192 |
+
images=image_tensors,
|
193 |
+
do_sample=True if args.temperature > 0 else False,
|
194 |
+
temperature=args.temperature,
|
195 |
+
top_p=args.top_p,
|
196 |
+
num_beams=args.num_beams,
|
197 |
+
# no_repeat_ngram_size=3,
|
198 |
+
max_new_tokens=1024,
|
199 |
+
use_cache=True)
|
200 |
+
|
201 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
202 |
+
outputs = outputs.strip()
|
203 |
+
if outputs.endswith(stop_str):
|
204 |
+
outputs = outputs[:-len(stop_str)]
|
205 |
+
outputs = outputs.strip()
|
206 |
+
|
207 |
+
ans_id = shortuuid.uuid()
|
208 |
+
ans_file.write(json.dumps({
|
209 |
+
"dataset": dataset_name,
|
210 |
+
"sample_id": idx,
|
211 |
+
"prompt": cur_prompt,
|
212 |
+
"pred_response": outputs,
|
213 |
+
"gt_response": gt,
|
214 |
+
"shortuuid": ans_id,
|
215 |
+
"model_id": model_name,
|
216 |
+
"question_type": question_type,
|
217 |
+
}) + "\n")
|
218 |
+
ans_file.flush()
|
219 |
+
|
220 |
+
|
221 |
+
ans_file.close()
|
222 |
+
|
223 |
+
if __name__ == "__main__":
|
224 |
+
parser = argparse.ArgumentParser()
|
225 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
226 |
+
parser.add_argument("--model-base", type=str, default=None)
|
227 |
+
parser.add_argument("--image-folder", type=str, default="")
|
228 |
+
parser.add_argument("--extra-prompt", type=str, default="")
|
229 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
230 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
231 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
232 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
233 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
234 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
235 |
+
parser.add_argument("--top_p", type=float, default=None)
|
236 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
237 |
+
parser.add_argument("--test_size", type=int, default=10000000)
|
238 |
+
args = parser.parse_args()
|
239 |
+
|
240 |
+
eval_model(args)
|
llava/mm_utils.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
import math
|
5 |
+
import ast
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from transformers import StoppingCriteria
|
9 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
10 |
+
|
11 |
+
|
12 |
+
def resize_and_center_crop(image, shortest_edge_length):
|
13 |
+
# Calculate new dimensions and resize
|
14 |
+
aspect_ratio = float(image.width) / float(image.height)
|
15 |
+
if aspect_ratio > 1:
|
16 |
+
new_width = int(shortest_edge_length * aspect_ratio)
|
17 |
+
new_height = shortest_edge_length
|
18 |
+
else:
|
19 |
+
new_width = shortest_edge_length
|
20 |
+
new_height = int(shortest_edge_length / aspect_ratio)
|
21 |
+
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
|
22 |
+
|
23 |
+
# Calculate the position and perform the center crop
|
24 |
+
left = (new_width - shortest_edge_length) / 2
|
25 |
+
top = (new_height - shortest_edge_length) / 2
|
26 |
+
right = (new_width + shortest_edge_length) / 2
|
27 |
+
bottom = (new_height + shortest_edge_length) / 2
|
28 |
+
cropped_image = resized_image.crop((left, top, right, bottom))
|
29 |
+
|
30 |
+
return cropped_image
|
31 |
+
|
32 |
+
|
33 |
+
def auto_pad_images(image, grid_params):
|
34 |
+
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
|
35 |
+
assert len(grid_params) > 0, "Grid parameters should not be empty"
|
36 |
+
|
37 |
+
# Step 1: Calculate and find the closest aspect ratio
|
38 |
+
input_width, input_height = image.size
|
39 |
+
input_aspect_ratio = input_width / input_height
|
40 |
+
candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
|
41 |
+
closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
|
42 |
+
|
43 |
+
candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
|
44 |
+
|
45 |
+
target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
|
46 |
+
|
47 |
+
resize_width, resize_height = target_resolution
|
48 |
+
if input_width > input_height:
|
49 |
+
resize_height = int(resize_width / input_aspect_ratio)
|
50 |
+
else:
|
51 |
+
resize_width = int(resize_height * input_aspect_ratio)
|
52 |
+
resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
|
53 |
+
|
54 |
+
# Step 5: Pad the resized image if necessary to match the target resolution
|
55 |
+
pad_width = target_resolution[0] - resize_width
|
56 |
+
pad_height = target_resolution[1] - resize_height
|
57 |
+
padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
|
58 |
+
padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
|
59 |
+
|
60 |
+
return padded_image
|
61 |
+
|
62 |
+
|
63 |
+
def extract_patches(image, patch_size, overlap_ratio):
|
64 |
+
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
|
65 |
+
assert patch_size > 0, "Patch size should be greater than 0"
|
66 |
+
assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
|
67 |
+
|
68 |
+
W, H = image.size
|
69 |
+
patches = []
|
70 |
+
|
71 |
+
stride = int(patch_size * (1 - overlap_ratio))
|
72 |
+
|
73 |
+
num_patches_y = (H - patch_size) // stride + 1
|
74 |
+
num_patches_x = (W - patch_size) // stride + 1
|
75 |
+
|
76 |
+
y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
|
77 |
+
x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
|
78 |
+
|
79 |
+
for y in range(y_start, y_start + num_patches_y * stride, stride):
|
80 |
+
for x in range(x_start, x_start + num_patches_x * stride, stride):
|
81 |
+
patch = image.crop((x, y, x + patch_size, y + patch_size))
|
82 |
+
patches.append(patch)
|
83 |
+
|
84 |
+
return patches
|
85 |
+
|
86 |
+
|
87 |
+
def process_highres_image_crop_split(image, data_args, processor=None):
|
88 |
+
crop_resolution = data_args.image_crop_resolution
|
89 |
+
split_resolution = data_args.image_split_resolution
|
90 |
+
if processor is None:
|
91 |
+
processor = data_args.image_processor
|
92 |
+
image_crop = resize_and_center_crop(image, crop_resolution)
|
93 |
+
image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
|
94 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
95 |
+
return torch.stack(image_patches, dim=0)
|
96 |
+
|
97 |
+
|
98 |
+
def process_highres_image(image, processor, grid_pinpoints):
|
99 |
+
grid_params = [int(x) for x in grid_pinpoints.split(",")]
|
100 |
+
width_height = max(image.size)
|
101 |
+
fit_grid_params = [x for x in grid_params if x >= width_height]
|
102 |
+
if len(fit_grid_params) == 0:
|
103 |
+
select_size = max(grid_params)
|
104 |
+
else:
|
105 |
+
select_size = min(fit_grid_params)
|
106 |
+
# FIXME: always select the 448
|
107 |
+
select_size = max(grid_params)
|
108 |
+
image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
109 |
+
|
110 |
+
# FIXME: this seems to be a bug that it always resizes instead of padding
|
111 |
+
image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
|
112 |
+
image_padded = image_padded.resize((select_size, select_size))
|
113 |
+
image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
|
114 |
+
image_patches = [image_original_resize] + image_patches
|
115 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
116 |
+
return torch.stack(image_patches, dim=0)
|
117 |
+
|
118 |
+
|
119 |
+
def select_best_resolution(original_size, possible_resolutions):
|
120 |
+
"""
|
121 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
125 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
tuple: The best fit resolution in the format (width, height).
|
129 |
+
"""
|
130 |
+
original_width, original_height = original_size
|
131 |
+
best_fit = None
|
132 |
+
max_effective_resolution = 0
|
133 |
+
min_wasted_resolution = float("inf")
|
134 |
+
|
135 |
+
for width, height in possible_resolutions:
|
136 |
+
# Calculate the downscaled size to keep the aspect ratio
|
137 |
+
scale = min(width / original_width, height / original_height)
|
138 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
139 |
+
|
140 |
+
# Calculate effective and wasted resolutions
|
141 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
142 |
+
wasted_resolution = (width * height) - effective_resolution
|
143 |
+
|
144 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
145 |
+
max_effective_resolution = effective_resolution
|
146 |
+
min_wasted_resolution = wasted_resolution
|
147 |
+
best_fit = (width, height)
|
148 |
+
|
149 |
+
return best_fit
|
150 |
+
|
151 |
+
|
152 |
+
def resize_and_pad_image(image, target_resolution):
|
153 |
+
"""
|
154 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
image (PIL.Image.Image): The input image.
|
158 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
PIL.Image.Image: The resized and padded image.
|
162 |
+
"""
|
163 |
+
original_width, original_height = image.size
|
164 |
+
target_width, target_height = target_resolution
|
165 |
+
|
166 |
+
# Determine which dimension (width or height) to fill
|
167 |
+
scale_w = target_width / original_width
|
168 |
+
scale_h = target_height / original_height
|
169 |
+
|
170 |
+
if scale_w < scale_h:
|
171 |
+
# Width will be filled completely
|
172 |
+
new_width = target_width
|
173 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
174 |
+
else:
|
175 |
+
# Height will be filled completely
|
176 |
+
new_height = target_height
|
177 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
178 |
+
|
179 |
+
# Resize the image
|
180 |
+
resized_image = image.resize((new_width, new_height))
|
181 |
+
|
182 |
+
# Create a new image with the target size and paste the resized image onto it
|
183 |
+
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
|
184 |
+
paste_x = (target_width - new_width) // 2
|
185 |
+
paste_y = (target_height - new_height) // 2
|
186 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
187 |
+
|
188 |
+
return new_image
|
189 |
+
|
190 |
+
|
191 |
+
def divide_to_patches(image, patch_size):
|
192 |
+
"""
|
193 |
+
Divides an image into patches of a specified size.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
image (PIL.Image.Image): The input image.
|
197 |
+
patch_size (int): The size of each patch.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
201 |
+
"""
|
202 |
+
patches = []
|
203 |
+
width, height = image.size
|
204 |
+
for i in range(0, height, patch_size):
|
205 |
+
for j in range(0, width, patch_size):
|
206 |
+
box = (j, i, j + patch_size, i + patch_size)
|
207 |
+
patch = image.crop(box)
|
208 |
+
patches.append(patch)
|
209 |
+
|
210 |
+
return patches
|
211 |
+
|
212 |
+
|
213 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
214 |
+
"""
|
215 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
219 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
220 |
+
patch_size (int): The size of each image patch.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
224 |
+
"""
|
225 |
+
if isinstance(grid_pinpoints, str):
|
226 |
+
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
|
227 |
+
grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
|
228 |
+
grid_pinpoints = [[int(x) * patch_size for x in item.split(",")] for item in grid_pinpoints]
|
229 |
+
|
230 |
+
if type(grid_pinpoints) is list:
|
231 |
+
possible_resolutions = grid_pinpoints
|
232 |
+
else:
|
233 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
234 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
235 |
+
return width // patch_size, height // patch_size
|
236 |
+
|
237 |
+
|
238 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
239 |
+
"""
|
240 |
+
Process an image with variable resolutions.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
image (PIL.Image.Image): The input image to be processed.
|
244 |
+
processor: The image processor object.
|
245 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
torch.Tensor: A tensor containing the processed image patches.
|
249 |
+
"""
|
250 |
+
# Convert grid_pinpoints from string to list
|
251 |
+
if isinstance(grid_pinpoints, str):
|
252 |
+
vis_encoder_size = processor.size[0]
|
253 |
+
assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]"
|
254 |
+
grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
|
255 |
+
grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints]
|
256 |
+
|
257 |
+
if type(grid_pinpoints) is list:
|
258 |
+
possible_resolutions = grid_pinpoints
|
259 |
+
else:
|
260 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
261 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
262 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
263 |
+
|
264 |
+
patches = divide_to_patches(image_padded, processor.crop_size["height"])
|
265 |
+
|
266 |
+
# FIXME: this seems to be a bug that it resizes instead of pad.
|
267 |
+
# but to keep it consistent with previous, i will keep it as it is
|
268 |
+
# TODO: uncomment below to ablate with the padding
|
269 |
+
if isinstance(processor.size, dict):
|
270 |
+
shortest_edge = processor.size["shortest_edge"]
|
271 |
+
else:
|
272 |
+
shortest_edge = min(processor.size)
|
273 |
+
image_original_resize = image.resize((shortest_edge, shortest_edge))
|
274 |
+
# image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
|
275 |
+
# image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
|
276 |
+
|
277 |
+
image_patches = [image_original_resize] + patches
|
278 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
279 |
+
return torch.stack(image_patches, dim=0)
|
280 |
+
|
281 |
+
|
282 |
+
def load_image_from_base64(image):
|
283 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
284 |
+
|
285 |
+
|
286 |
+
def expand2square(pil_img, background_color):
|
287 |
+
width, height = pil_img.size
|
288 |
+
if width == height:
|
289 |
+
return pil_img
|
290 |
+
elif width > height:
|
291 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
292 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
293 |
+
return result
|
294 |
+
else:
|
295 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
296 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
297 |
+
return result
|
298 |
+
|
299 |
+
|
300 |
+
def process_images(images, image_processor, model_cfg):
|
301 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
302 |
+
new_images = []
|
303 |
+
if image_aspect_ratio == "highres":
|
304 |
+
for image in images:
|
305 |
+
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
306 |
+
new_images.append(image)
|
307 |
+
elif image_aspect_ratio == "anyres":
|
308 |
+
for image in images:
|
309 |
+
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
310 |
+
new_images.append(image)
|
311 |
+
elif image_aspect_ratio == "crop_split":
|
312 |
+
for image in images:
|
313 |
+
image = process_highres_image_crop_split(image, model_cfg, image_processor)
|
314 |
+
new_images.append(image)
|
315 |
+
elif image_aspect_ratio == "pad":
|
316 |
+
for image in images:
|
317 |
+
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
|
318 |
+
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
319 |
+
new_images.append(image)
|
320 |
+
else:
|
321 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
322 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
323 |
+
new_images = torch.stack(new_images, dim=0)
|
324 |
+
return new_images
|
325 |
+
|
326 |
+
|
327 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
328 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
329 |
+
|
330 |
+
def insert_separator(X, sep):
|
331 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
332 |
+
|
333 |
+
input_ids = []
|
334 |
+
offset = 0
|
335 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
336 |
+
offset = 1
|
337 |
+
input_ids.append(prompt_chunks[0][0])
|
338 |
+
|
339 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
340 |
+
input_ids.extend(x[offset:])
|
341 |
+
|
342 |
+
if return_tensors is not None:
|
343 |
+
if return_tensors == "pt":
|
344 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
345 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
346 |
+
return input_ids
|
347 |
+
|
348 |
+
|
349 |
+
def get_model_name_from_path(model_path):
|
350 |
+
model_path = model_path.strip("/")
|
351 |
+
model_paths = model_path.split("/")
|
352 |
+
if model_paths[-1].startswith("checkpoint-"):
|
353 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
354 |
+
else:
|
355 |
+
return model_paths[-1]
|
356 |
+
|
357 |
+
|
358 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
359 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
360 |
+
self.keywords = keywords
|
361 |
+
self.keyword_ids = []
|
362 |
+
for keyword in keywords:
|
363 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
364 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
365 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
366 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
367 |
+
self.tokenizer = tokenizer
|
368 |
+
self.start_len = input_ids.shape[1]
|
369 |
+
|
370 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
371 |
+
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
372 |
+
offset = min(output_ids.shape[1] - self.start_len, 3)
|
373 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
374 |
+
for keyword_id in self.keyword_ids:
|
375 |
+
if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
|
376 |
+
return True
|
377 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
378 |
+
for keyword in self.keywords:
|
379 |
+
if keyword in outputs:
|
380 |
+
return True
|
381 |
+
return False
|
llava/model/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
AVAILABLE_MODELS = {
|
4 |
+
"llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
|
5 |
+
"llava_gemma": "LlavaGemmaForCausalLM, LlavaGemmaConfig",
|
6 |
+
"llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
|
7 |
+
# "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
|
8 |
+
"llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
|
9 |
+
"llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
|
10 |
+
# Add other models as needed
|
11 |
+
}
|
12 |
+
|
13 |
+
for model_name, model_classes in AVAILABLE_MODELS.items():
|
14 |
+
try:
|
15 |
+
exec(f"from .language_model.{model_name} import {model_classes}")
|
16 |
+
except ImportError:
|
17 |
+
# import traceback
|
18 |
+
# traceback.print_exc()
|
19 |
+
print(f"Failed to import {model_name} from llava.language_model.{model_name}")
|
20 |
+
pass
|
llava/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (740 Bytes). View file
|
|
llava/model/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (6.9 kB). View file
|
|
llava/model/__pycache__/llava_arch.cpython-310.pyc
ADDED
Binary file (11.7 kB). View file
|
|
llava/model/apply_delta.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
+
from llava import LlavaLlamaForCausalLM
|
12 |
+
|
13 |
+
|
14 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
15 |
+
print("Loading base model")
|
16 |
+
base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading delta")
|
19 |
+
delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
20 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
21 |
+
|
22 |
+
print("Applying delta")
|
23 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data += base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
|
31 |
+
bparam = base.state_dict()[name]
|
32 |
+
param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
|
33 |
+
|
34 |
+
print("Saving target model")
|
35 |
+
delta.save_pretrained(target_model_path)
|
36 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
parser = argparse.ArgumentParser()
|
41 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
42 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
43 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
44 |
+
|
45 |
+
args = parser.parse_args()
|
46 |
+
|
47 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
llava/model/builder.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import os
|
17 |
+
import warnings
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
+
import torch
|
22 |
+
from llava.model import *
|
23 |
+
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
+
from llava.utils import rank0_print
|
25 |
+
|
26 |
+
|
27 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, **kwargs):
|
28 |
+
kwargs = {"device_map": device_map}
|
29 |
+
|
30 |
+
if load_8bit:
|
31 |
+
kwargs["load_in_8bit"] = True
|
32 |
+
elif load_4bit:
|
33 |
+
kwargs["load_in_4bit"] = True
|
34 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
|
35 |
+
else:
|
36 |
+
kwargs["torch_dtype"] = torch.float16
|
37 |
+
|
38 |
+
if customized_config is not None:
|
39 |
+
kwargs["config"] = customized_config
|
40 |
+
|
41 |
+
if "llava" in model_name.lower():
|
42 |
+
# Load LLaVA model
|
43 |
+
if "lora" in model_name.lower() and model_base is None:
|
44 |
+
warnings.warn(
|
45 |
+
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
|
46 |
+
)
|
47 |
+
if "lora" in model_name.lower() and model_base is not None:
|
48 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
49 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
50 |
+
rank0_print("Loading LLaVA from base model...")
|
51 |
+
if "mixtral" in model_name.lower():
|
52 |
+
from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
|
53 |
+
|
54 |
+
lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
56 |
+
model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
57 |
+
elif "mistral" in model_name.lower():
|
58 |
+
from llava.model.language_model.llava_mistral import LlavaMistralConfig
|
59 |
+
|
60 |
+
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
|
61 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
62 |
+
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
63 |
+
elif "gemma" in model_name.lower():
|
64 |
+
from llava.model.language_model.llava_gemma import LlavaGemmaConfig
|
65 |
+
|
66 |
+
lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
|
67 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
68 |
+
model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
69 |
+
else:
|
70 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
71 |
+
|
72 |
+
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
|
73 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
74 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
75 |
+
|
76 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
77 |
+
if model.lm_head.weight.shape[0] != token_num:
|
78 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
79 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
80 |
+
|
81 |
+
rank0_print("Loading additional LLaVA weights...")
|
82 |
+
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
|
83 |
+
non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
|
84 |
+
else:
|
85 |
+
# this is probably from HF Hub
|
86 |
+
from huggingface_hub import hf_hub_download
|
87 |
+
|
88 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
89 |
+
cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
|
90 |
+
return torch.load(cache_file, map_location="cpu")
|
91 |
+
|
92 |
+
non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
|
93 |
+
non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
|
94 |
+
if any(k.startswith("model.model.") for k in non_lora_trainables):
|
95 |
+
non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
|
96 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
97 |
+
|
98 |
+
from peft import PeftModel
|
99 |
+
|
100 |
+
rank0_print("Loading LoRA weights...")
|
101 |
+
model = PeftModel.from_pretrained(model, model_path)
|
102 |
+
rank0_print("Merging LoRA weights...")
|
103 |
+
model = model.merge_and_unload()
|
104 |
+
rank0_print("Model is loaded...")
|
105 |
+
elif model_base is not None:
|
106 |
+
# this may be mm projector only
|
107 |
+
rank0_print(f"Loading LLaVA from base model {model_base}...")
|
108 |
+
if "mixtral" in model_name.lower():
|
109 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
110 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
111 |
+
model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
112 |
+
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
|
113 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
114 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
115 |
+
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
116 |
+
elif "gemma" in model_name.lower():
|
117 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
118 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
119 |
+
model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
120 |
+
elif (
|
121 |
+
"wizardlm-2" in model_name.lower()
|
122 |
+
and "vicuna" in model_name.lower()
|
123 |
+
or "llama" in model_name.lower()
|
124 |
+
or "yi" in model_name.lower()
|
125 |
+
or "nous-hermes" in model_name.lower()
|
126 |
+
or "llava-v1.6-34b" in model_name.lower()
|
127 |
+
or "llava-v1.5" in model_name.lower()
|
128 |
+
):
|
129 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
130 |
+
|
131 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
132 |
+
if customized_config is None:
|
133 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
134 |
+
if "v1.5" in model_name.lower():
|
135 |
+
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
|
136 |
+
else:
|
137 |
+
llava_cfg = customized_config
|
138 |
+
|
139 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
140 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
141 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Model {model_name} not supported")
|
144 |
+
|
145 |
+
mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
|
146 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
147 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
148 |
+
else:
|
149 |
+
rank0_print(f"Loaded LLaVA model: {model_path}")
|
150 |
+
if "mixtral" in model_name.lower():
|
151 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
152 |
+
model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
|
153 |
+
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
|
154 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
155 |
+
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
|
156 |
+
elif (
|
157 |
+
"wizardlm-2" in model_name.lower()
|
158 |
+
and "vicuna" in model_name.lower()
|
159 |
+
or "llama" in model_name.lower()
|
160 |
+
or "yi" in model_name.lower()
|
161 |
+
or "nous-hermes" in model_name.lower()
|
162 |
+
or "llava-v1.6-34b" in model_name.lower()
|
163 |
+
or "llava-v1.5" in model_name.lower()
|
164 |
+
):
|
165 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
166 |
+
|
167 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
168 |
+
if customized_config is None:
|
169 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
170 |
+
if "v1.5" in model_name.lower():
|
171 |
+
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
|
172 |
+
else:
|
173 |
+
llava_cfg = customized_config
|
174 |
+
|
175 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
176 |
+
elif "qwen" in model_name.lower():
|
177 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
178 |
+
model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
|
179 |
+
elif "gemma" in model_name.lower():
|
180 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
181 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
182 |
+
model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
183 |
+
else:
|
184 |
+
rank0_print("\n\n\nWarning : No matching llava architecture, auto load llava_llama. If it is not intended, specify it in model_name\n\n\n")
|
185 |
+
try:
|
186 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
187 |
+
|
188 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
189 |
+
if customized_config is None:
|
190 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
191 |
+
if "v1.5" in model_path.lower():
|
192 |
+
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
|
193 |
+
else:
|
194 |
+
llava_cfg = customized_config
|
195 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
196 |
+
except:
|
197 |
+
raise ValueError(f"Model {model_name} not supported")
|
198 |
+
|
199 |
+
else:
|
200 |
+
# Load language model
|
201 |
+
if model_base is not None:
|
202 |
+
# PEFT model
|
203 |
+
from peft import PeftModel
|
204 |
+
|
205 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
206 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
|
207 |
+
print(f"Loading LoRA weights from {model_path}")
|
208 |
+
model = PeftModel.from_pretrained(model, model_path)
|
209 |
+
print(f"Merging weights")
|
210 |
+
model = model.merge_and_unload()
|
211 |
+
print("Convert to FP16...")
|
212 |
+
model.to(torch.float16)
|
213 |
+
else:
|
214 |
+
use_fast = False
|
215 |
+
if "mpt" in model_name.lower().replace("prompt", ""):
|
216 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
217 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
218 |
+
else:
|
219 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
220 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
221 |
+
|
222 |
+
rank0_print(f"Model Class: {model.__class__.__name__}")
|
223 |
+
image_processor = None
|
224 |
+
|
225 |
+
if "llava" in model_name.lower():
|
226 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
227 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
228 |
+
if mm_use_im_patch_token:
|
229 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
230 |
+
if mm_use_im_start_end:
|
231 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
232 |
+
model.resize_token_embeddings(len(tokenizer))
|
233 |
+
|
234 |
+
vision_tower = model.get_vision_tower()
|
235 |
+
if not vision_tower.is_loaded:
|
236 |
+
vision_tower.load_model(device_map=device_map)
|
237 |
+
if device_map != "auto":
|
238 |
+
vision_tower.to(device="cuda", dtype=torch.float16)
|
239 |
+
image_processor = vision_tower.image_processor
|
240 |
+
|
241 |
+
if hasattr(model.config, "max_sequence_length"):
|
242 |
+
context_len = model.config.max_sequence_length
|
243 |
+
elif hasattr(model.config, "max_position_embeddings"):
|
244 |
+
context_len = model.config.max_position_embeddings
|
245 |
+
elif hasattr(model.config, "tokenizer_model_max_length"):
|
246 |
+
context_len = model.config.tokenizer_model_max_length
|
247 |
+
else:
|
248 |
+
context_len = 2048
|
249 |
+
|
250 |
+
return tokenizer, model, image_processor, context_len
|
llava/model/consolidate.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from llava.model import *
|
11 |
+
from llava.model.utils import auto_upgrade
|
12 |
+
|
13 |
+
|
14 |
+
def consolidate_ckpt(src_path, dst_path):
|
15 |
+
print("Loading model")
|
16 |
+
auto_upgrade(src_path)
|
17 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
18 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
19 |
+
src_model.save_pretrained(dst_path)
|
20 |
+
src_tokenizer.save_pretrained(dst_path)
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
parser = argparse.ArgumentParser()
|
25 |
+
parser.add_argument("--src", type=str, required=True)
|
26 |
+
parser.add_argument("--dst", type=str, required=True)
|
27 |
+
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
consolidate_ckpt(args.src, args.dst)
|
llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc
ADDED
Binary file (3.79 kB). View file
|
|
llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc
ADDED
Binary file (3.98 kB). View file
|
|
llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc
ADDED
Binary file (4.04 kB). View file
|
|
llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc
ADDED
Binary file (3.86 kB). View file
|
|
llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc
ADDED
Binary file (3.9 kB). View file
|
|
llava/model/language_model/llava_gemma.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaGemmaConfig(GemmaConfig):
|
31 |
+
model_type = "llava_gemma"
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
|
35 |
+
config_class = LlavaGemmaConfig
|
36 |
+
|
37 |
+
def __init__(self, config: GemmaConfig):
|
38 |
+
super(LlavaGemmaModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
|
42 |
+
config_class = LlavaGemmaConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(GemmaForCausalLM, self).__init__(config)
|
46 |
+
self.model = LlavaGemmaModel(config)
|
47 |
+
|
48 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
|
50 |
+
# Initialize weights and apply final processing
|
51 |
+
self.post_init()
|
52 |
+
|
53 |
+
def get_model(self):
|
54 |
+
return self.model
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
input_ids: torch.LongTensor = None,
|
59 |
+
attention_mask: Optional[torch.Tensor] = None,
|
60 |
+
position_ids: Optional[torch.LongTensor] = None,
|
61 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
62 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
63 |
+
labels: Optional[torch.LongTensor] = None,
|
64 |
+
use_cache: Optional[bool] = None,
|
65 |
+
output_attentions: Optional[bool] = None,
|
66 |
+
output_hidden_states: Optional[bool] = None,
|
67 |
+
images: Optional[torch.FloatTensor] = None,
|
68 |
+
image_sizes: Optional[List[List[int]]] = None,
|
69 |
+
return_dict: Optional[bool] = None,
|
70 |
+
cache_position: Optional[torch.LongTensor] = None,
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
|
73 |
+
if inputs_embeds is None:
|
74 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
75 |
+
|
76 |
+
return super().forward(
|
77 |
+
input_ids=input_ids,
|
78 |
+
attention_mask=attention_mask,
|
79 |
+
position_ids=position_ids,
|
80 |
+
past_key_values=past_key_values,
|
81 |
+
inputs_embeds=inputs_embeds,
|
82 |
+
labels=labels,
|
83 |
+
use_cache=use_cache,
|
84 |
+
output_attentions=output_attentions,
|
85 |
+
output_hidden_states=output_hidden_states,
|
86 |
+
return_dict=return_dict,
|
87 |
+
cache_position=cache_position,
|
88 |
+
)
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def generate(
|
92 |
+
self,
|
93 |
+
inputs: Optional[torch.Tensor] = None,
|
94 |
+
images: Optional[torch.Tensor] = None,
|
95 |
+
image_sizes: Optional[torch.Tensor] = None,
|
96 |
+
**kwargs,
|
97 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
98 |
+
position_ids = kwargs.pop("position_ids", None)
|
99 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
100 |
+
if "inputs_embeds" in kwargs:
|
101 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
102 |
+
|
103 |
+
if images is not None:
|
104 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
105 |
+
else:
|
106 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
107 |
+
|
108 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
109 |
+
|
110 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
111 |
+
images = kwargs.pop("images", None)
|
112 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
113 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
114 |
+
if images is not None:
|
115 |
+
inputs["images"] = images
|
116 |
+
if image_sizes is not None:
|
117 |
+
inputs["image_sizes"] = image_sizes
|
118 |
+
return inputs
|
119 |
+
|
120 |
+
|
121 |
+
AutoConfig.register("llava_gemma", LlavaGemmaConfig)
|
122 |
+
AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
|
llava/model/language_model/llava_llama.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
|
22 |
+
|
23 |
+
# , LlamaModel, LlamaForCausalLM, GenerationConfig
|
24 |
+
# from .modeling_llama import LlamaModel, LlamaForCausalLM
|
25 |
+
from transformers import LlamaModel, LlamaForCausalLM
|
26 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
27 |
+
from transformers.generation.utils import GenerateOutput
|
28 |
+
|
29 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
30 |
+
|
31 |
+
|
32 |
+
class LlavaConfig(LlamaConfig):
|
33 |
+
model_type = "llava_llama"
|
34 |
+
temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
|
35 |
+
max_new_tokens: int = 1024
|
36 |
+
do_sample: bool = False
|
37 |
+
top_p: Optional[float] = None
|
38 |
+
rope_scaling: Optional[dict] = {}
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
|
42 |
+
config_class = LlavaConfig
|
43 |
+
|
44 |
+
def __init__(self, config: LlamaConfig):
|
45 |
+
super(LlavaLlamaModel, self).__init__(config)
|
46 |
+
|
47 |
+
|
48 |
+
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
|
49 |
+
config_class = LlavaConfig
|
50 |
+
|
51 |
+
def __init__(self, config):
|
52 |
+
LlamaForCausalLM.__init__(self, config)
|
53 |
+
|
54 |
+
# configure default generation settings
|
55 |
+
config.model_type = "llava_llama"
|
56 |
+
config.rope_scaling = None
|
57 |
+
|
58 |
+
self.model = LlavaLlamaModel(config)
|
59 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
60 |
+
# Initialize weights and apply final processing
|
61 |
+
self.post_init()
|
62 |
+
|
63 |
+
def get_model(self):
|
64 |
+
return self.model
|
65 |
+
|
66 |
+
def forward(
|
67 |
+
self,
|
68 |
+
input_ids: torch.LongTensor = None,
|
69 |
+
attention_mask: Optional[torch.Tensor] = None,
|
70 |
+
position_ids: Optional[torch.LongTensor] = None,
|
71 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
72 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
73 |
+
labels: Optional[torch.LongTensor] = None,
|
74 |
+
use_cache: Optional[bool] = None,
|
75 |
+
output_attentions: Optional[bool] = None,
|
76 |
+
output_hidden_states: Optional[bool] = None,
|
77 |
+
images: Optional[torch.FloatTensor] = None,
|
78 |
+
image_sizes: Optional[List[List[int]]] = None,
|
79 |
+
return_dict: Optional[bool] = None,
|
80 |
+
cache_position=None,
|
81 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
82 |
+
|
83 |
+
if inputs_embeds is None:
|
84 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
85 |
+
|
86 |
+
return super().forward(
|
87 |
+
input_ids=input_ids,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
position_ids=position_ids,
|
90 |
+
past_key_values=past_key_values,
|
91 |
+
inputs_embeds=inputs_embeds,
|
92 |
+
labels=labels,
|
93 |
+
use_cache=use_cache,
|
94 |
+
output_attentions=output_attentions,
|
95 |
+
output_hidden_states=output_hidden_states,
|
96 |
+
return_dict=return_dict,
|
97 |
+
)
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def generate(
|
101 |
+
self,
|
102 |
+
inputs: Optional[torch.Tensor] = None,
|
103 |
+
images: Optional[torch.Tensor] = None,
|
104 |
+
image_sizes: Optional[torch.Tensor] = None,
|
105 |
+
**kwargs,
|
106 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
107 |
+
position_ids = kwargs.pop("position_ids", None)
|
108 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
109 |
+
if "inputs_embeds" in kwargs:
|
110 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
111 |
+
|
112 |
+
if images is not None:
|
113 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
114 |
+
else:
|
115 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
116 |
+
|
117 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
118 |
+
|
119 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
120 |
+
images = kwargs.pop("images", None)
|
121 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
122 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
123 |
+
if images is not None:
|
124 |
+
inputs["images"] = images
|
125 |
+
if image_sizes is not None:
|
126 |
+
inputs["image_sizes"] = image_sizes
|
127 |
+
return inputs
|
128 |
+
|
129 |
+
|
130 |
+
AutoConfig.register("llava_llama", LlavaConfig)
|
131 |
+
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
|
llava/model/language_model/llava_mistral.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaMistralConfig(MistralConfig):
|
31 |
+
model_type = "llava_mistral"
|
32 |
+
temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
|
33 |
+
max_new_tokens: int = 1024
|
34 |
+
do_sample: bool = False
|
35 |
+
top_p: Optional[float] = None
|
36 |
+
|
37 |
+
|
38 |
+
class LlavaMistralModel(LlavaMetaModel, MistralModel):
|
39 |
+
config_class = LlavaMistralConfig
|
40 |
+
|
41 |
+
def __init__(self, config: MistralConfig):
|
42 |
+
super(LlavaMistralModel, self).__init__(config)
|
43 |
+
|
44 |
+
|
45 |
+
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
46 |
+
config_class = LlavaMistralConfig
|
47 |
+
|
48 |
+
def __init__(self, config):
|
49 |
+
super(MistralForCausalLM, self).__init__(config)
|
50 |
+
|
51 |
+
config.model_type = "llava_mistral"
|
52 |
+
config.rope_scaling = None
|
53 |
+
|
54 |
+
self.model = LlavaMistralModel(config)
|
55 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
56 |
+
# Initialize weights and apply final processing
|
57 |
+
self.post_init()
|
58 |
+
|
59 |
+
def get_model(self):
|
60 |
+
return self.model
|
61 |
+
|
62 |
+
def forward(
|
63 |
+
self,
|
64 |
+
input_ids: torch.LongTensor = None,
|
65 |
+
attention_mask: Optional[torch.Tensor] = None,
|
66 |
+
position_ids: Optional[torch.LongTensor] = None,
|
67 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
68 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
69 |
+
labels: Optional[torch.LongTensor] = None,
|
70 |
+
use_cache: Optional[bool] = None,
|
71 |
+
output_attentions: Optional[bool] = None,
|
72 |
+
output_hidden_states: Optional[bool] = None,
|
73 |
+
images: Optional[torch.FloatTensor] = None,
|
74 |
+
image_sizes: Optional[List[List[int]]] = None,
|
75 |
+
return_dict: Optional[bool] = None,
|
76 |
+
cache_position=None,
|
77 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
78 |
+
|
79 |
+
if inputs_embeds is None:
|
80 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
81 |
+
|
82 |
+
return super().forward(
|
83 |
+
input_ids=input_ids,
|
84 |
+
attention_mask=attention_mask,
|
85 |
+
position_ids=position_ids,
|
86 |
+
past_key_values=past_key_values,
|
87 |
+
inputs_embeds=inputs_embeds,
|
88 |
+
labels=labels,
|
89 |
+
use_cache=use_cache,
|
90 |
+
output_attentions=output_attentions,
|
91 |
+
output_hidden_states=output_hidden_states,
|
92 |
+
return_dict=return_dict,
|
93 |
+
)
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
def generate(
|
97 |
+
self,
|
98 |
+
inputs: Optional[torch.Tensor] = None,
|
99 |
+
images: Optional[torch.Tensor] = None,
|
100 |
+
image_sizes: Optional[torch.Tensor] = None,
|
101 |
+
**kwargs,
|
102 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
103 |
+
position_ids = kwargs.pop("position_ids", None)
|
104 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
105 |
+
if "inputs_embeds" in kwargs:
|
106 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
107 |
+
|
108 |
+
if images is not None:
|
109 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
110 |
+
else:
|
111 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
112 |
+
|
113 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
114 |
+
|
115 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
116 |
+
images = kwargs.pop("images", None)
|
117 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
118 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
119 |
+
if images is not None:
|
120 |
+
inputs["images"] = images
|
121 |
+
if image_sizes is not None:
|
122 |
+
inputs["image_sizes"] = image_sizes
|
123 |
+
return inputs
|
124 |
+
|
125 |
+
|
126 |
+
AutoConfig.register("llava_mistral", LlavaMistralConfig)
|
127 |
+
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
llava/model/language_model/llava_mixtral.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaMixtralConfig(MixtralConfig):
|
31 |
+
model_type = "llava_mixtral"
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
|
35 |
+
config_class = LlavaMixtralConfig
|
36 |
+
|
37 |
+
def __init__(self, config: MixtralConfig):
|
38 |
+
super(LlavaMixtralModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
|
42 |
+
config_class = LlavaMixtralConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(MixtralForCausalLM, self).__init__(config)
|
46 |
+
|
47 |
+
config.model_type = "llava_mixtral"
|
48 |
+
config.rope_scaling = None
|
49 |
+
self.model = LlavaMixtralModel(config)
|
50 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
51 |
+
# Initialize weights and apply final processing
|
52 |
+
self.post_init()
|
53 |
+
|
54 |
+
def get_model(self):
|
55 |
+
return self.model
|
56 |
+
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
input_ids: torch.LongTensor = None,
|
60 |
+
attention_mask: Optional[torch.Tensor] = None,
|
61 |
+
position_ids: Optional[torch.LongTensor] = None,
|
62 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
63 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
64 |
+
labels: Optional[torch.LongTensor] = None,
|
65 |
+
use_cache: Optional[bool] = None,
|
66 |
+
output_attentions: Optional[bool] = None,
|
67 |
+
output_hidden_states: Optional[bool] = None,
|
68 |
+
images: Optional[torch.FloatTensor] = None,
|
69 |
+
image_sizes: Optional[List[List[int]]] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
cache_position=None,
|
72 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
73 |
+
|
74 |
+
if inputs_embeds is None:
|
75 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
76 |
+
|
77 |
+
return super().forward(
|
78 |
+
input_ids=input_ids,
|
79 |
+
attention_mask=attention_mask,
|
80 |
+
position_ids=position_ids,
|
81 |
+
past_key_values=past_key_values,
|
82 |
+
inputs_embeds=inputs_embeds,
|
83 |
+
labels=labels,
|
84 |
+
use_cache=use_cache,
|
85 |
+
output_attentions=output_attentions,
|
86 |
+
output_hidden_states=output_hidden_states,
|
87 |
+
return_dict=return_dict,
|
88 |
+
)
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def generate(
|
92 |
+
self,
|
93 |
+
inputs: Optional[torch.Tensor] = None,
|
94 |
+
images: Optional[torch.Tensor] = None,
|
95 |
+
image_sizes: Optional[torch.Tensor] = None,
|
96 |
+
**kwargs,
|
97 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
98 |
+
position_ids = kwargs.pop("position_ids", None)
|
99 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
100 |
+
if "inputs_embeds" in kwargs:
|
101 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
102 |
+
|
103 |
+
if images is not None:
|
104 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
105 |
+
else:
|
106 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
107 |
+
|
108 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
109 |
+
|
110 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
111 |
+
images = kwargs.pop("images", None)
|
112 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
113 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
114 |
+
if images is not None:
|
115 |
+
inputs["images"] = images
|
116 |
+
if image_sizes is not None:
|
117 |
+
inputs["image_sizes"] = image_sizes
|
118 |
+
return inputs
|
119 |
+
|
120 |
+
|
121 |
+
AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
|
122 |
+
AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
|
llava/model/language_model/llava_mpt.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
|
21 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
22 |
+
|
23 |
+
|
24 |
+
class LlavaMptConfig(MptConfig):
|
25 |
+
model_type = "llava_mpt"
|
26 |
+
|
27 |
+
|
28 |
+
class LlavaMptModel(LlavaMetaModel, MptModel):
|
29 |
+
config_class = LlavaMptConfig
|
30 |
+
|
31 |
+
def __init__(self, config: MptConfig):
|
32 |
+
config.hidden_size = config.d_model
|
33 |
+
super(LlavaMptModel, self).__init__(config)
|
34 |
+
|
35 |
+
def embed_tokens(self, x):
|
36 |
+
return self.wte(x)
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
|
40 |
+
config_class = LlavaMptConfig
|
41 |
+
supports_gradient_checkpointing = True
|
42 |
+
|
43 |
+
def __init__(self, config):
|
44 |
+
super(MptForCausalLM, self).__init__(config)
|
45 |
+
|
46 |
+
config.model_type = "llava_mpt"
|
47 |
+
config.rope_scaling = None
|
48 |
+
self.generation_config = GenerationConfig(
|
49 |
+
temperature=0.0,
|
50 |
+
max_new_tokens=1024,
|
51 |
+
do_sample=False,
|
52 |
+
top_p=None,
|
53 |
+
)
|
54 |
+
|
55 |
+
self.transformer = LlavaMptModel(config)
|
56 |
+
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
57 |
+
|
58 |
+
# Initialize weights and apply final processing
|
59 |
+
self.post_init()
|
60 |
+
|
61 |
+
def get_model(self):
|
62 |
+
return self.transformer
|
63 |
+
|
64 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
65 |
+
if isinstance(module, LlavaMptModel):
|
66 |
+
module.gradient_checkpointing = value
|
67 |
+
|
68 |
+
def forward(
|
69 |
+
self,
|
70 |
+
input_ids: Optional[torch.LongTensor] = None,
|
71 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
72 |
+
attention_mask: Optional[torch.Tensor] = None,
|
73 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
74 |
+
labels: Optional[torch.Tensor] = None,
|
75 |
+
use_cache: Optional[bool] = None,
|
76 |
+
output_attentions: Optional[bool] = None,
|
77 |
+
output_hidden_states: Optional[bool] = None,
|
78 |
+
return_dict: Optional[bool] = None,
|
79 |
+
cache_position=None,
|
80 |
+
images=None,
|
81 |
+
):
|
82 |
+
|
83 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
|
84 |
+
|
85 |
+
return super().forward(
|
86 |
+
input_ids,
|
87 |
+
past_key_values=past_key_values,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
inputs_embeds=inputs_embeds,
|
90 |
+
labels=labels,
|
91 |
+
use_cache=use_cache,
|
92 |
+
output_attentions=output_attentions,
|
93 |
+
output_hidden_states=output_hidden_states,
|
94 |
+
return_dict=return_dict,
|
95 |
+
)
|
96 |
+
|
97 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
98 |
+
images = kwargs.pop("images", None)
|
99 |
+
_inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
100 |
+
_inputs["images"] = images
|
101 |
+
return _inputs
|
102 |
+
|
103 |
+
|
104 |
+
AutoConfig.register("llava_mpt", LlavaMptConfig)
|
105 |
+
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
|
llava/model/language_model/llava_qwen.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Hao Zhang
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union, Dict
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import CrossEntropyLoss
|
20 |
+
|
21 |
+
import transformers
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
28 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
29 |
+
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
|
30 |
+
|
31 |
+
# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
|
32 |
+
# from .qwen.configuration_qwen import QWenConfig
|
33 |
+
|
34 |
+
|
35 |
+
class LlavaQwenConfig(Qwen2Config):
|
36 |
+
model_type = "llava_qwen"
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
|
40 |
+
config_class = LlavaQwenConfig
|
41 |
+
|
42 |
+
def __init__(self, config: Qwen2Config):
|
43 |
+
super(LlavaQwenModel, self).__init__(config)
|
44 |
+
|
45 |
+
|
46 |
+
class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
|
47 |
+
config_class = LlavaQwenConfig
|
48 |
+
|
49 |
+
def __init__(self, config):
|
50 |
+
# super(Qwen2ForCausalLM, self).__init__(config)
|
51 |
+
Qwen2ForCausalLM.__init__(self, config)
|
52 |
+
config.model_type = "llava_qwen"
|
53 |
+
config.rope_scaling = None
|
54 |
+
|
55 |
+
self.model = LlavaQwenModel(config)
|
56 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
57 |
+
# Initialize weights and apply final processing
|
58 |
+
self.post_init()
|
59 |
+
|
60 |
+
def get_model(self):
|
61 |
+
return self.model
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
input_ids: torch.LongTensor = None,
|
66 |
+
attention_mask: Optional[torch.Tensor] = None,
|
67 |
+
position_ids: Optional[torch.LongTensor] = None,
|
68 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
69 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
70 |
+
labels: Optional[torch.LongTensor] = None,
|
71 |
+
use_cache: Optional[bool] = None,
|
72 |
+
output_attentions: Optional[bool] = None,
|
73 |
+
output_hidden_states: Optional[bool] = None,
|
74 |
+
images: Optional[torch.FloatTensor] = None,
|
75 |
+
image_sizes: Optional[List[List[int]]] = None,
|
76 |
+
return_dict: Optional[bool] = None,
|
77 |
+
cache_position=None,
|
78 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
79 |
+
|
80 |
+
if inputs_embeds is None:
|
81 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
82 |
+
|
83 |
+
return super().forward(
|
84 |
+
input_ids=input_ids,
|
85 |
+
attention_mask=attention_mask,
|
86 |
+
position_ids=position_ids,
|
87 |
+
past_key_values=past_key_values,
|
88 |
+
inputs_embeds=inputs_embeds,
|
89 |
+
labels=labels,
|
90 |
+
use_cache=use_cache,
|
91 |
+
output_attentions=output_attentions,
|
92 |
+
output_hidden_states=output_hidden_states,
|
93 |
+
return_dict=return_dict,
|
94 |
+
)
|
95 |
+
|
96 |
+
@torch.no_grad()
|
97 |
+
def generate(
|
98 |
+
self,
|
99 |
+
inputs: Optional[torch.Tensor] = None,
|
100 |
+
images: Optional[torch.Tensor] = None,
|
101 |
+
image_sizes: Optional[torch.Tensor] = None,
|
102 |
+
**kwargs,
|
103 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
104 |
+
position_ids = kwargs.pop("position_ids", None)
|
105 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
106 |
+
if "inputs_embeds" in kwargs:
|
107 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
108 |
+
|
109 |
+
if images is not None:
|
110 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
111 |
+
else:
|
112 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
113 |
+
|
114 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
115 |
+
|
116 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
117 |
+
images = kwargs.pop("images", None)
|
118 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
119 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
120 |
+
if images is not None:
|
121 |
+
inputs["images"] = images
|
122 |
+
if image_sizes is not None:
|
123 |
+
inputs["image_sizes"] = image_sizes
|
124 |
+
return inputs
|
125 |
+
|
126 |
+
|
127 |
+
AutoConfig.register("llava_qwen", LlavaQwenConfig)
|
128 |
+
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
|
llava/model/language_model/llava_qwen_moe.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Hao Zhang
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union, Dict
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import CrossEntropyLoss
|
20 |
+
|
21 |
+
import transformers
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
28 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
29 |
+
from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM
|
30 |
+
|
31 |
+
# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
|
32 |
+
# from .qwen.configuration_qwen import QWenConfig
|
33 |
+
|
34 |
+
|
35 |
+
class LlavaQwenMoeConfig(Qwen2MoeConfig):
|
36 |
+
model_type = "llava_qwen_moe"
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel):
|
40 |
+
config_class = LlavaQwenMoeConfig
|
41 |
+
|
42 |
+
def __init__(self, config: Qwen2MoeConfig):
|
43 |
+
super(LlavaQwenMoeModel, self).__init__(config)
|
44 |
+
|
45 |
+
|
46 |
+
class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM):
|
47 |
+
config_class = LlavaQwenMoeConfig
|
48 |
+
|
49 |
+
def __init__(self, config):
|
50 |
+
# super(Qwen2MoeForCausalLM, self).__init__(config)
|
51 |
+
Qwen2MoeForCausalLM.__init__(self, config)
|
52 |
+
config.model_type = "llava_qwen_moe"
|
53 |
+
config.rope_scaling = None
|
54 |
+
|
55 |
+
self.model = LlavaQwenMoeModel(config)
|
56 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
57 |
+
# Initialize weights and apply final processing
|
58 |
+
self.post_init()
|
59 |
+
|
60 |
+
def get_model(self):
|
61 |
+
return self.model
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
input_ids: torch.LongTensor = None,
|
66 |
+
attention_mask: Optional[torch.Tensor] = None,
|
67 |
+
position_ids: Optional[torch.LongTensor] = None,
|
68 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
69 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
70 |
+
labels: Optional[torch.LongTensor] = None,
|
71 |
+
use_cache: Optional[bool] = None,
|
72 |
+
output_attentions: Optional[bool] = None,
|
73 |
+
output_hidden_states: Optional[bool] = None,
|
74 |
+
images: Optional[torch.FloatTensor] = None,
|
75 |
+
image_sizes: Optional[List[List[int]]] = None,
|
76 |
+
return_dict: Optional[bool] = None,
|
77 |
+
cache_position=None,
|
78 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
79 |
+
|
80 |
+
if inputs_embeds is None:
|
81 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
82 |
+
|
83 |
+
return super().forward(
|
84 |
+
input_ids=input_ids,
|
85 |
+
attention_mask=attention_mask,
|
86 |
+
position_ids=position_ids,
|
87 |
+
past_key_values=past_key_values,
|
88 |
+
inputs_embeds=inputs_embeds,
|
89 |
+
labels=labels,
|
90 |
+
use_cache=use_cache,
|
91 |
+
output_attentions=output_attentions,
|
92 |
+
output_hidden_states=output_hidden_states,
|
93 |
+
return_dict=return_dict,
|
94 |
+
)
|
95 |
+
|
96 |
+
@torch.no_grad()
|
97 |
+
def generate(
|
98 |
+
self,
|
99 |
+
inputs: Optional[torch.Tensor] = None,
|
100 |
+
images: Optional[torch.Tensor] = None,
|
101 |
+
image_sizes: Optional[torch.Tensor] = None,
|
102 |
+
**kwargs,
|
103 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
104 |
+
position_ids = kwargs.pop("position_ids", None)
|
105 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
106 |
+
if "inputs_embeds" in kwargs:
|
107 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
108 |
+
|
109 |
+
if images is not None:
|
110 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
111 |
+
else:
|
112 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
113 |
+
|
114 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
115 |
+
|
116 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
117 |
+
images = kwargs.pop("images", None)
|
118 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
119 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
120 |
+
if images is not None:
|
121 |
+
inputs["images"] = images
|
122 |
+
if image_sizes is not None:
|
123 |
+
inputs["image_sizes"] = image_sizes
|
124 |
+
return inputs
|
125 |
+
|
126 |
+
|
127 |
+
AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig)
|
128 |
+
AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM)
|
llava/model/llava_arch.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from .multimodal_encoder.builder import build_vision_tower
|
22 |
+
from .multimodal_resampler.builder import build_vision_resampler
|
23 |
+
from .multimodal_projector.builder import build_vision_projector
|
24 |
+
|
25 |
+
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
26 |
+
|
27 |
+
from llava.mm_utils import get_anyres_image_grid_shape
|
28 |
+
from llava.utils import rank0_print
|
29 |
+
|
30 |
+
|
31 |
+
class LlavaMetaModel:
|
32 |
+
|
33 |
+
def __init__(self, config):
|
34 |
+
super(LlavaMetaModel, self).__init__(config)
|
35 |
+
|
36 |
+
if hasattr(config, "mm_vision_tower"):
|
37 |
+
delay_load = getattr(config, "delay_load", False)
|
38 |
+
self.vision_tower = build_vision_tower(config, delay_load=delay_load)
|
39 |
+
self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
|
40 |
+
self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
|
41 |
+
|
42 |
+
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
43 |
+
self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
|
44 |
+
|
45 |
+
def get_vision_tower(self):
|
46 |
+
vision_tower = getattr(self, "vision_tower", None)
|
47 |
+
if type(vision_tower) is list:
|
48 |
+
vision_tower = vision_tower[0]
|
49 |
+
return vision_tower
|
50 |
+
|
51 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
52 |
+
vision_tower = model_args.vision_tower
|
53 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
54 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
55 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
56 |
+
mm_patch_merge_type = model_args.mm_patch_merge_type
|
57 |
+
|
58 |
+
self.config.mm_vision_tower = vision_tower
|
59 |
+
self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
|
60 |
+
|
61 |
+
if self.get_vision_tower() is None:
|
62 |
+
vision_tower = build_vision_tower(model_args)
|
63 |
+
vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
|
64 |
+
for k, v in vision_resampler.config.items():
|
65 |
+
setattr(self.config, k, v)
|
66 |
+
|
67 |
+
if fsdp is not None and len(fsdp) > 0:
|
68 |
+
self.vision_tower = [vision_tower]
|
69 |
+
self.vision_resampler = [vision_resampler]
|
70 |
+
else:
|
71 |
+
self.vision_tower = vision_tower
|
72 |
+
self.vision_resampler = vision_resampler
|
73 |
+
else:
|
74 |
+
if fsdp is not None and len(fsdp) > 0:
|
75 |
+
vision_resampler = self.vision_resampler[0]
|
76 |
+
vision_tower = self.vision_tower[0]
|
77 |
+
else:
|
78 |
+
vision_resampler = self.vision_resampler
|
79 |
+
vision_tower = self.vision_tower
|
80 |
+
vision_tower.load_model()
|
81 |
+
|
82 |
+
# In case it is frozen by LoRA
|
83 |
+
for p in self.vision_resampler.parameters():
|
84 |
+
p.requires_grad = True
|
85 |
+
|
86 |
+
self.config.use_mm_proj = True
|
87 |
+
self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
|
88 |
+
self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
|
89 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
90 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
91 |
+
self.config.mm_patch_merge_type = mm_patch_merge_type
|
92 |
+
|
93 |
+
if getattr(self, "mm_projector", None) is None:
|
94 |
+
self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
|
95 |
+
|
96 |
+
if "unpad" in mm_patch_merge_type:
|
97 |
+
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
98 |
+
self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
|
99 |
+
else:
|
100 |
+
# In case it is frozen by LoRA
|
101 |
+
for p in self.mm_projector.parameters():
|
102 |
+
p.requires_grad = True
|
103 |
+
|
104 |
+
if pretrain_mm_mlp_adapter is not None:
|
105 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
|
106 |
+
|
107 |
+
def get_w(weights, keyword):
|
108 |
+
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
|
109 |
+
|
110 |
+
incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
|
111 |
+
rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
|
112 |
+
incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
|
113 |
+
rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
|
114 |
+
|
115 |
+
|
116 |
+
def unpad_image(tensor, original_size):
|
117 |
+
"""
|
118 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
122 |
+
original_size (tuple): The original size of the image (height, width).
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
torch.Tensor: The unpadded image tensor.
|
126 |
+
"""
|
127 |
+
original_width, original_height = original_size
|
128 |
+
current_height, current_width = tensor.shape[1:]
|
129 |
+
|
130 |
+
# Compute aspect ratios
|
131 |
+
original_aspect_ratio = original_width / original_height
|
132 |
+
current_aspect_ratio = current_width / current_height
|
133 |
+
|
134 |
+
# Determine padding size and direction
|
135 |
+
if original_aspect_ratio > current_aspect_ratio:
|
136 |
+
# Padding was added to the height
|
137 |
+
scale_factor = current_width / original_width
|
138 |
+
new_height = int(original_height * scale_factor)
|
139 |
+
padding = (current_height - new_height) // 2
|
140 |
+
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
141 |
+
else:
|
142 |
+
# Padding was added to the width
|
143 |
+
scale_factor = current_height / original_height
|
144 |
+
new_width = int(original_width * scale_factor)
|
145 |
+
padding = (current_width - new_width) // 2
|
146 |
+
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
147 |
+
|
148 |
+
return unpadded_tensor
|
149 |
+
|
150 |
+
|
151 |
+
class LlavaMetaForCausalLM(ABC):
|
152 |
+
|
153 |
+
@abstractmethod
|
154 |
+
def get_model(self):
|
155 |
+
pass
|
156 |
+
|
157 |
+
def get_vision_tower(self):
|
158 |
+
return self.get_model().get_vision_tower()
|
159 |
+
|
160 |
+
def encode_images(self, images):
|
161 |
+
image_features = self.get_model().get_vision_tower()(images)
|
162 |
+
image_features = self.get_model().vision_resampler(image_features, images=images)
|
163 |
+
image_features = self.get_model().mm_projector(image_features)
|
164 |
+
return image_features
|
165 |
+
|
166 |
+
def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None):
|
167 |
+
vision_tower = self.get_vision_tower()
|
168 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
169 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
170 |
+
|
171 |
+
if type(images) is list or images.ndim == 5:
|
172 |
+
if type(images) is list:
|
173 |
+
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
|
174 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
175 |
+
image_features = self.encode_images(concat_images)
|
176 |
+
split_sizes = [image.shape[0] for image in images]
|
177 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
178 |
+
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
|
179 |
+
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
|
180 |
+
if mm_patch_merge_type == "flat":
|
181 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
182 |
+
elif mm_patch_merge_type.startswith("spatial"):
|
183 |
+
new_image_features = []
|
184 |
+
for image_idx, image_feature in enumerate(image_features):
|
185 |
+
# FIXME: now assume the image is square, and split to 2x2 patches
|
186 |
+
# num_patches = h * w, where h = w = sqrt(num_patches)
|
187 |
+
# currently image_feature is a tensor of shape (4, num_patches, hidden_size)
|
188 |
+
# we want to first unflatten it to (2, 2, h, w, hidden_size)
|
189 |
+
|
190 |
+
if image_feature.shape[0] > 1:
|
191 |
+
base_image_feature = image_feature[0]
|
192 |
+
image_feature = image_feature[1:]
|
193 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
194 |
+
assert height * width == base_image_feature.shape[0]
|
195 |
+
if image_aspect_ratio == "anyres":
|
196 |
+
if hasattr(self.get_vision_tower(), "image_size"):
|
197 |
+
vision_tower_image_size = self.get_vision_tower().image_size
|
198 |
+
else:
|
199 |
+
raise ValueError("vision_tower_image_size is not found in the vision tower.")
|
200 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
|
201 |
+
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
202 |
+
else:
|
203 |
+
image_feature = image_feature.view(2, 2, height, width, -1)
|
204 |
+
if "maxpool2x2" in mm_patch_merge_type:
|
205 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
206 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
207 |
+
image_feature = nn.functional.max_pool2d(image_feature, 2)
|
208 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
209 |
+
elif "unpad" in mm_patch_merge_type:
|
210 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
211 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
212 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
213 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
|
214 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
215 |
+
else:
|
216 |
+
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
217 |
+
image_feature = image_feature.flatten(0, 3)
|
218 |
+
if "nobase" in mm_patch_merge_type:
|
219 |
+
pass
|
220 |
+
else:
|
221 |
+
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
222 |
+
else:
|
223 |
+
image_feature = image_feature[0]
|
224 |
+
if "unpad" in mm_patch_merge_type:
|
225 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
|
226 |
+
new_image_features.append(image_feature)
|
227 |
+
image_features = new_image_features
|
228 |
+
else:
|
229 |
+
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
230 |
+
else:
|
231 |
+
image_features = self.encode_images(images)
|
232 |
+
|
233 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
234 |
+
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
|
235 |
+
raise NotImplementedError
|
236 |
+
|
237 |
+
# Let's just add dummy tensors if they do not exist,
|
238 |
+
# it is a headache to deal with None all the time.
|
239 |
+
# But it is not ideal, and if you have a better idea,
|
240 |
+
# please open an issue / submit a PR, thanks.
|
241 |
+
_labels = labels
|
242 |
+
_position_ids = position_ids
|
243 |
+
_attention_mask = attention_mask
|
244 |
+
if attention_mask is None:
|
245 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
246 |
+
else:
|
247 |
+
attention_mask = attention_mask.bool()
|
248 |
+
if position_ids is None:
|
249 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
250 |
+
if labels is None:
|
251 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
252 |
+
|
253 |
+
# remove the padding using attention_mask -- FIXME
|
254 |
+
_input_ids = input_ids
|
255 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
256 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
257 |
+
|
258 |
+
new_input_embeds = []
|
259 |
+
new_labels = []
|
260 |
+
cur_image_idx = 0
|
261 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
262 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
263 |
+
if num_images == 0:
|
264 |
+
cur_image_features = image_features[cur_image_idx]
|
265 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
266 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
267 |
+
new_input_embeds.append(cur_input_embeds)
|
268 |
+
new_labels.append(labels[batch_idx])
|
269 |
+
cur_image_idx += 1
|
270 |
+
continue
|
271 |
+
|
272 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
273 |
+
cur_input_ids_noim = []
|
274 |
+
cur_labels = labels[batch_idx]
|
275 |
+
cur_labels_noim = []
|
276 |
+
for i in range(len(image_token_indices) - 1):
|
277 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
278 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
279 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
280 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
281 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
282 |
+
cur_new_input_embeds = []
|
283 |
+
cur_new_labels = []
|
284 |
+
|
285 |
+
for i in range(num_images + 1):
|
286 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
287 |
+
cur_new_labels.append(cur_labels_noim[i])
|
288 |
+
if i < num_images:
|
289 |
+
cur_image_features = image_features[cur_image_idx]
|
290 |
+
cur_image_idx += 1
|
291 |
+
cur_new_input_embeds.append(cur_image_features)
|
292 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
293 |
+
|
294 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
295 |
+
|
296 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
297 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
298 |
+
|
299 |
+
new_input_embeds.append(cur_new_input_embeds)
|
300 |
+
new_labels.append(cur_new_labels)
|
301 |
+
|
302 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
303 |
+
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
304 |
+
if tokenizer_model_max_length is not None:
|
305 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
306 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
307 |
+
|
308 |
+
# Combine them
|
309 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
310 |
+
batch_size = len(new_input_embeds)
|
311 |
+
|
312 |
+
new_input_embeds_padded = []
|
313 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
314 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
315 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
316 |
+
|
317 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
318 |
+
cur_len = cur_new_embed.shape[0]
|
319 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
320 |
+
new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
|
321 |
+
if cur_len > 0:
|
322 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
323 |
+
attention_mask[i, -cur_len:] = True
|
324 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
325 |
+
else:
|
326 |
+
new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
|
327 |
+
if cur_len > 0:
|
328 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
329 |
+
attention_mask[i, :cur_len] = True
|
330 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
331 |
+
|
332 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
333 |
+
|
334 |
+
if _labels is None:
|
335 |
+
new_labels = None
|
336 |
+
else:
|
337 |
+
new_labels = new_labels_padded
|
338 |
+
|
339 |
+
if _attention_mask is None:
|
340 |
+
attention_mask = None
|
341 |
+
else:
|
342 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
343 |
+
|
344 |
+
if _position_ids is None:
|
345 |
+
position_ids = None
|
346 |
+
|
347 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
348 |
+
|
349 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
350 |
+
if model_args.mm_use_im_patch_token:
|
351 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
352 |
+
self.resize_token_embeddings(len(tokenizer))
|
353 |
+
|
354 |
+
if model_args.mm_use_im_start_end:
|
355 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
356 |
+
self.resize_token_embeddings(len(tokenizer))
|
357 |
+
|
358 |
+
if num_new_tokens > 0:
|
359 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
360 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
361 |
+
|
362 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
363 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
364 |
+
|
365 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
366 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
367 |
+
|
368 |
+
if model_args.tune_mm_mlp_adapter:
|
369 |
+
for p in self.get_input_embeddings().parameters():
|
370 |
+
p.requires_grad = True
|
371 |
+
for p in self.get_output_embeddings().parameters():
|
372 |
+
p.requires_grad = False
|
373 |
+
|
374 |
+
if model_args.pretrain_mm_mlp_adapter:
|
375 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
|
376 |
+
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
|
377 |
+
assert num_new_tokens == 2
|
378 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
379 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
380 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
381 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
382 |
+
else:
|
383 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
384 |
+
elif model_args.mm_use_im_patch_token:
|
385 |
+
if model_args.tune_mm_mlp_adapter:
|
386 |
+
for p in self.get_input_embeddings().parameters():
|
387 |
+
p.requires_grad = False
|
388 |
+
for p in self.get_output_embeddings().parameters():
|
389 |
+
p.requires_grad = False
|
llava/model/make_delta.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
+
from llava.model.utils import auto_upgrade
|
12 |
+
|
13 |
+
|
14 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
15 |
+
print("Loading base model")
|
16 |
+
base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading target model")
|
19 |
+
auto_upgrade(target_model_path)
|
20 |
+
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
21 |
+
|
22 |
+
print("Calculating delta")
|
23 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data -= base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
|
31 |
+
bparam = base.state_dict()[name]
|
32 |
+
param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
|
33 |
+
|
34 |
+
print("Saving delta")
|
35 |
+
if hub_repo_id:
|
36 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
37 |
+
else:
|
38 |
+
kwargs = {}
|
39 |
+
target.save_pretrained(delta_path, **kwargs)
|
40 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
41 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
parser = argparse.ArgumentParser()
|
46 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
47 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
48 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
49 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
50 |
+
args = parser.parse_args()
|
51 |
+
|
52 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
|
llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (772 Bytes). View file
|
|
llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
ADDED
Binary file (4.4 kB). View file
|
|
llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc
ADDED
Binary file (21.8 kB). View file
|
|
llava/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
from .siglip_encoder import SigLipVisionTower
|
4 |
+
|
5 |
+
|
6 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
7 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
8 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
9 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
|
10 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
11 |
+
elif "siglip" in vision_tower:
|
12 |
+
return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
13 |
+
|
14 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
llava/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.is_loaded = False
|
12 |
+
|
13 |
+
self.vision_tower_name = vision_tower
|
14 |
+
self.select_layer = args.mm_vision_select_layer
|
15 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
20 |
+
# TODO: better detector is needed.
|
21 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
22 |
+
self.load_model()
|
23 |
+
else:
|
24 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
25 |
+
|
26 |
+
def load_model(self, device_map=None):
|
27 |
+
if self.is_loaded:
|
28 |
+
print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
|
29 |
+
return
|
30 |
+
|
31 |
+
# import pdb; pdb.set_trace()
|
32 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
33 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
34 |
+
self.vision_tower.requires_grad_(False)
|
35 |
+
|
36 |
+
self.is_loaded = True
|
37 |
+
|
38 |
+
def feature_select(self, image_forward_outs):
|
39 |
+
select_feature_type = self.select_feature
|
40 |
+
|
41 |
+
if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
|
42 |
+
select_every_k_layer = len(image_forward_outs.hidden_states) // 4
|
43 |
+
image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
|
44 |
+
select_feature_type = select_feature_type.replace("slicefour_", "")
|
45 |
+
elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
|
46 |
+
select_layers = [-2, -5, -8, -11, 6]
|
47 |
+
image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
|
48 |
+
select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
|
49 |
+
else:
|
50 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
51 |
+
|
52 |
+
if select_feature_type == "patch":
|
53 |
+
image_features = image_features[:, 1:]
|
54 |
+
elif select_feature_type == "cls_patch":
|
55 |
+
image_features = image_features
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Unexpected select feature: {select_feature_type}")
|
58 |
+
return image_features
|
59 |
+
|
60 |
+
def forward(self, images):
|
61 |
+
if type(images) is list:
|
62 |
+
image_features = []
|
63 |
+
for image in images:
|
64 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
65 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
66 |
+
image_features.append(image_feature)
|
67 |
+
else:
|
68 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
69 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
70 |
+
|
71 |
+
return image_features
|
72 |
+
|
73 |
+
@property
|
74 |
+
def dummy_feature(self):
|
75 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
76 |
+
|
77 |
+
@property
|
78 |
+
def dtype(self):
|
79 |
+
return self.vision_tower.dtype
|
80 |
+
|
81 |
+
@property
|
82 |
+
def device(self):
|
83 |
+
return self.vision_tower.device
|
84 |
+
|
85 |
+
@property
|
86 |
+
def config(self):
|
87 |
+
if self.is_loaded:
|
88 |
+
return self.vision_tower.config
|
89 |
+
else:
|
90 |
+
return self.cfg_only
|
91 |
+
|
92 |
+
@property
|
93 |
+
def hidden_size(self):
|
94 |
+
_hidden_size = self.config.hidden_size
|
95 |
+
if "slicefour" in self.select_feature:
|
96 |
+
_hidden_size *= 4
|
97 |
+
if "slice_m25811_f6" in self.select_feature:
|
98 |
+
_hidden_size *= 5
|
99 |
+
return _hidden_size
|
100 |
+
|
101 |
+
@property
|
102 |
+
def num_patches_per_side(self):
|
103 |
+
return self.config.image_size // self.config.patch_size
|
104 |
+
|
105 |
+
@property
|
106 |
+
def num_patches(self):
|
107 |
+
_num_patches = (self.config.image_size // self.config.patch_size) ** 2
|
108 |
+
if "cls_patch" in self.select_feature:
|
109 |
+
_num_patches += 1
|
110 |
+
return _num_patches
|
111 |
+
|
112 |
+
@property
|
113 |
+
def image_size(self):
|
114 |
+
return self.config.image_size
|
llava/model/multimodal_encoder/siglip_encoder.py
ADDED
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Optional, Tuple, Union, Dict
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from functools import partial, reduce
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from torch import nn
|
12 |
+
import os
|
13 |
+
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
14 |
+
from transformers.image_transforms import (
|
15 |
+
convert_to_rgb,
|
16 |
+
normalize,
|
17 |
+
rescale,
|
18 |
+
resize,
|
19 |
+
to_channel_dimension_format,
|
20 |
+
)
|
21 |
+
from transformers.image_utils import (
|
22 |
+
ChannelDimension,
|
23 |
+
PILImageResampling,
|
24 |
+
to_numpy_array,
|
25 |
+
)
|
26 |
+
from transformers.activations import ACT2FN
|
27 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
28 |
+
from transformers.modeling_utils import PreTrainedModel
|
29 |
+
from transformers import PretrainedConfig
|
30 |
+
from transformers.utils import ModelOutput
|
31 |
+
from llava.utils import rank0_print
|
32 |
+
|
33 |
+
|
34 |
+
class SigLipImageProcessor:
|
35 |
+
def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
|
36 |
+
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
|
37 |
+
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
38 |
+
|
39 |
+
self.image_mean = image_mean
|
40 |
+
self.image_std = image_std
|
41 |
+
self.size = size
|
42 |
+
self.resample = resample
|
43 |
+
self.rescale_factor = rescale_factor
|
44 |
+
self.data_format = data_format
|
45 |
+
self.crop_size = crop_size
|
46 |
+
|
47 |
+
def preprocess(self, images, return_tensors):
|
48 |
+
if isinstance(images, Image.Image):
|
49 |
+
images = [images]
|
50 |
+
else:
|
51 |
+
assert isinstance(images, list)
|
52 |
+
|
53 |
+
transforms = [
|
54 |
+
convert_to_rgb,
|
55 |
+
to_numpy_array,
|
56 |
+
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
|
57 |
+
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
|
58 |
+
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
|
59 |
+
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
|
60 |
+
]
|
61 |
+
|
62 |
+
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
|
63 |
+
data = {"pixel_values": images}
|
64 |
+
|
65 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
66 |
+
|
67 |
+
|
68 |
+
class SigLipVisionConfig(PretrainedConfig):
|
69 |
+
model_type = "siglip_vision_model"
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
hidden_size=1152,
|
74 |
+
image_mean=(0.5, 0.5, 0.5),
|
75 |
+
intermediate_size=4304,
|
76 |
+
num_hidden_layers=27,
|
77 |
+
num_attention_heads=16,
|
78 |
+
num_channels=3,
|
79 |
+
image_size=384,
|
80 |
+
patch_size=14,
|
81 |
+
hidden_act="gelu_pytorch_tanh",
|
82 |
+
layer_norm_eps=1e-6,
|
83 |
+
attention_dropout=0.0,
|
84 |
+
**kwargs,
|
85 |
+
):
|
86 |
+
super().__init__(**kwargs)
|
87 |
+
|
88 |
+
self.hidden_size = hidden_size
|
89 |
+
self.intermediate_size = intermediate_size
|
90 |
+
self.num_hidden_layers = num_hidden_layers
|
91 |
+
self.num_attention_heads = num_attention_heads
|
92 |
+
self.num_channels = num_channels
|
93 |
+
self.patch_size = patch_size
|
94 |
+
self.image_size = image_size
|
95 |
+
self.attention_dropout = attention_dropout
|
96 |
+
self.layer_norm_eps = layer_norm_eps
|
97 |
+
self.hidden_act = hidden_act
|
98 |
+
self.image_mean = image_mean
|
99 |
+
|
100 |
+
@classmethod
|
101 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
102 |
+
cls._set_token_in_kwargs(kwargs)
|
103 |
+
|
104 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
105 |
+
|
106 |
+
# get the vision config dict if we are loading from SigLipConfig
|
107 |
+
if config_dict.get("model_type") == "siglip":
|
108 |
+
config_dict = config_dict["vision_config"]
|
109 |
+
|
110 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
111 |
+
print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")
|
112 |
+
|
113 |
+
return cls.from_dict(config_dict, **kwargs)
|
114 |
+
|
115 |
+
|
116 |
+
@dataclass
|
117 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
|
118 |
+
class SigLipVisionModelOutput(ModelOutput):
|
119 |
+
"""
|
120 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
124 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
125 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
126 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
127 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
128 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
129 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
130 |
+
|
131 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
132 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
133 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
134 |
+
sequence_length)`.
|
135 |
+
|
136 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
137 |
+
heads.
|
138 |
+
"""
|
139 |
+
|
140 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
141 |
+
last_hidden_state: torch.FloatTensor = None
|
142 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
143 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
144 |
+
|
145 |
+
|
146 |
+
class SigLipVisionEmbeddings(nn.Module):
|
147 |
+
def __init__(self, config: SigLipVisionConfig):
|
148 |
+
super().__init__()
|
149 |
+
self.config = config
|
150 |
+
self.embed_dim = config.hidden_size
|
151 |
+
self.image_size = config.image_size
|
152 |
+
self.patch_size = config.patch_size
|
153 |
+
|
154 |
+
self.patch_embedding = nn.Conv2d(
|
155 |
+
in_channels=config.num_channels,
|
156 |
+
out_channels=self.embed_dim,
|
157 |
+
kernel_size=self.patch_size,
|
158 |
+
stride=self.patch_size,
|
159 |
+
padding="valid",
|
160 |
+
)
|
161 |
+
|
162 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
163 |
+
self.num_positions = self.num_patches
|
164 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
165 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
166 |
+
|
167 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
168 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
169 |
+
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
170 |
+
|
171 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
172 |
+
return embeddings
|
173 |
+
|
174 |
+
|
175 |
+
class SigLipAttention(nn.Module):
|
176 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
177 |
+
|
178 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
179 |
+
def __init__(self, config):
|
180 |
+
super().__init__()
|
181 |
+
self.config = config
|
182 |
+
self.embed_dim = config.hidden_size
|
183 |
+
self.num_heads = config.num_attention_heads
|
184 |
+
self.head_dim = self.embed_dim // self.num_heads
|
185 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
186 |
+
raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
|
187 |
+
self.scale = self.head_dim**-0.5
|
188 |
+
self.dropout = config.attention_dropout
|
189 |
+
|
190 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
191 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
192 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
193 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
194 |
+
|
195 |
+
def forward(
|
196 |
+
self,
|
197 |
+
hidden_states: torch.Tensor,
|
198 |
+
attention_mask: Optional[torch.Tensor] = None,
|
199 |
+
output_attentions: Optional[bool] = False,
|
200 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
201 |
+
"""Input shape: Batch x Time x Channel"""
|
202 |
+
|
203 |
+
batch_size, q_len, _ = hidden_states.size()
|
204 |
+
|
205 |
+
query_states = self.q_proj(hidden_states)
|
206 |
+
key_states = self.k_proj(hidden_states)
|
207 |
+
value_states = self.v_proj(hidden_states)
|
208 |
+
|
209 |
+
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
210 |
+
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
211 |
+
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
212 |
+
|
213 |
+
k_v_seq_len = key_states.shape[-2]
|
214 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
215 |
+
|
216 |
+
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
217 |
+
raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")
|
218 |
+
|
219 |
+
if attention_mask is not None:
|
220 |
+
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
221 |
+
raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
|
222 |
+
attn_weights = attn_weights + attention_mask
|
223 |
+
|
224 |
+
# upcast attention to fp32
|
225 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
226 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
227 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
228 |
+
|
229 |
+
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
|
230 |
+
raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
|
231 |
+
|
232 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
233 |
+
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
234 |
+
|
235 |
+
attn_output = self.out_proj(attn_output)
|
236 |
+
|
237 |
+
return attn_output, attn_weights
|
238 |
+
|
239 |
+
|
240 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
|
241 |
+
class SigLipMLP(nn.Module):
|
242 |
+
def __init__(self, config):
|
243 |
+
super().__init__()
|
244 |
+
self.config = config
|
245 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
246 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
247 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
248 |
+
|
249 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
250 |
+
hidden_states = self.fc1(hidden_states)
|
251 |
+
hidden_states = self.activation_fn(hidden_states)
|
252 |
+
hidden_states = self.fc2(hidden_states)
|
253 |
+
return hidden_states
|
254 |
+
|
255 |
+
|
256 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
|
257 |
+
class SigLipEncoderLayer(nn.Module):
|
258 |
+
def __init__(self, config: SigLipVisionConfig):
|
259 |
+
super().__init__()
|
260 |
+
self.embed_dim = config.hidden_size
|
261 |
+
self.self_attn = SigLipAttention(config)
|
262 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
263 |
+
self.mlp = SigLipMLP(config)
|
264 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
265 |
+
|
266 |
+
# Ignore copy
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
hidden_states: torch.Tensor,
|
270 |
+
attention_mask: torch.Tensor,
|
271 |
+
output_attentions: Optional[bool] = False,
|
272 |
+
) -> Tuple[torch.FloatTensor]:
|
273 |
+
"""
|
274 |
+
Args:
|
275 |
+
hidden_states (`torch.FloatTensor`):
|
276 |
+
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
277 |
+
attention_mask (`torch.FloatTensor`):
|
278 |
+
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
279 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
280 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
281 |
+
returned tensors for more detail.
|
282 |
+
"""
|
283 |
+
residual = hidden_states
|
284 |
+
|
285 |
+
hidden_states = self.layer_norm1(hidden_states)
|
286 |
+
hidden_states, attn_weights = self.self_attn(
|
287 |
+
hidden_states=hidden_states,
|
288 |
+
attention_mask=attention_mask,
|
289 |
+
output_attentions=output_attentions,
|
290 |
+
)
|
291 |
+
hidden_states = residual + hidden_states
|
292 |
+
|
293 |
+
residual = hidden_states
|
294 |
+
hidden_states = self.layer_norm2(hidden_states)
|
295 |
+
hidden_states = self.mlp(hidden_states)
|
296 |
+
hidden_states = residual + hidden_states
|
297 |
+
|
298 |
+
outputs = (hidden_states,)
|
299 |
+
|
300 |
+
if output_attentions:
|
301 |
+
outputs += (attn_weights,)
|
302 |
+
|
303 |
+
return outputs
|
304 |
+
|
305 |
+
|
306 |
+
class SigLipPreTrainedModel(PreTrainedModel):
|
307 |
+
"""
|
308 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
309 |
+
models.
|
310 |
+
"""
|
311 |
+
|
312 |
+
config_class = SigLipVisionConfig
|
313 |
+
base_model_prefix = "siglip"
|
314 |
+
supports_gradient_checkpointing = True
|
315 |
+
|
316 |
+
def _init_weights(self, module):
|
317 |
+
"""Initialize the weights"""
|
318 |
+
pass
|
319 |
+
|
320 |
+
|
321 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
|
322 |
+
class SigLipEncoder(nn.Module):
|
323 |
+
"""
|
324 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
325 |
+
[`SigLipEncoderLayer`].
|
326 |
+
|
327 |
+
Args:
|
328 |
+
config: SigLipVisionConfig
|
329 |
+
"""
|
330 |
+
|
331 |
+
def __init__(self, config: SigLipVisionConfig):
|
332 |
+
super().__init__()
|
333 |
+
self.config = config
|
334 |
+
self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
335 |
+
self.gradient_checkpointing = False
|
336 |
+
|
337 |
+
# Ignore copy
|
338 |
+
def forward(
|
339 |
+
self,
|
340 |
+
inputs_embeds,
|
341 |
+
attention_mask: Optional[torch.Tensor] = None,
|
342 |
+
output_attentions: Optional[bool] = None,
|
343 |
+
output_hidden_states: Optional[bool] = None,
|
344 |
+
return_dict: Optional[bool] = None,
|
345 |
+
) -> Union[Tuple, BaseModelOutput]:
|
346 |
+
r"""
|
347 |
+
Args:
|
348 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
349 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
350 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
351 |
+
than the model's internal embedding lookup matrix.
|
352 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
353 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
354 |
+
|
355 |
+
- 1 for tokens that are **not masked**,
|
356 |
+
- 0 for tokens that are **masked**.
|
357 |
+
|
358 |
+
[What are attention masks?](../glossary#attention-mask)
|
359 |
+
output_attentions (`bool`, *optional*):
|
360 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
361 |
+
returned tensors for more detail.
|
362 |
+
output_hidden_states (`bool`, *optional*):
|
363 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
364 |
+
for more detail.
|
365 |
+
return_dict (`bool`, *optional*):
|
366 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
367 |
+
"""
|
368 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
369 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
370 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
371 |
+
|
372 |
+
encoder_states = () if output_hidden_states else None
|
373 |
+
all_attentions = () if output_attentions else None
|
374 |
+
|
375 |
+
hidden_states = inputs_embeds
|
376 |
+
for encoder_layer in self.layers:
|
377 |
+
if output_hidden_states:
|
378 |
+
encoder_states = encoder_states + (hidden_states,)
|
379 |
+
if self.gradient_checkpointing and self.training:
|
380 |
+
layer_outputs = self._gradient_checkpointing_func(
|
381 |
+
encoder_layer.__call__,
|
382 |
+
hidden_states,
|
383 |
+
attention_mask,
|
384 |
+
output_attentions,
|
385 |
+
)
|
386 |
+
else:
|
387 |
+
layer_outputs = encoder_layer(
|
388 |
+
hidden_states,
|
389 |
+
attention_mask,
|
390 |
+
output_attentions=output_attentions,
|
391 |
+
)
|
392 |
+
|
393 |
+
hidden_states = layer_outputs[0]
|
394 |
+
|
395 |
+
if output_attentions:
|
396 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
397 |
+
|
398 |
+
if output_hidden_states:
|
399 |
+
encoder_states = encoder_states + (hidden_states,)
|
400 |
+
|
401 |
+
if not return_dict:
|
402 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
403 |
+
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
|
404 |
+
|
405 |
+
|
406 |
+
class SigLipVisionTransformer(nn.Module):
|
407 |
+
def __init__(self, config: SigLipVisionConfig):
|
408 |
+
super().__init__()
|
409 |
+
self.config = config
|
410 |
+
embed_dim = config.hidden_size
|
411 |
+
|
412 |
+
self.embeddings = SigLipVisionEmbeddings(config)
|
413 |
+
self.encoder = SigLipEncoder(config)
|
414 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
415 |
+
self.head = SigLipMultiheadAttentionPoolingHead(config)
|
416 |
+
|
417 |
+
def forward(
|
418 |
+
self,
|
419 |
+
pixel_values,
|
420 |
+
output_attentions: Optional[bool] = None,
|
421 |
+
output_hidden_states: Optional[bool] = None,
|
422 |
+
return_dict: Optional[bool] = None,
|
423 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
424 |
+
r"""
|
425 |
+
Returns:
|
426 |
+
|
427 |
+
"""
|
428 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
429 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
430 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
431 |
+
|
432 |
+
hidden_states = self.embeddings(pixel_values)
|
433 |
+
|
434 |
+
encoder_outputs = self.encoder(
|
435 |
+
inputs_embeds=hidden_states,
|
436 |
+
output_attentions=output_attentions,
|
437 |
+
output_hidden_states=output_hidden_states,
|
438 |
+
return_dict=return_dict,
|
439 |
+
)
|
440 |
+
|
441 |
+
last_hidden_state = encoder_outputs[0]
|
442 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
443 |
+
|
444 |
+
pooled_output = self.head(last_hidden_state)
|
445 |
+
|
446 |
+
if not return_dict:
|
447 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
448 |
+
|
449 |
+
return BaseModelOutputWithPooling(
|
450 |
+
last_hidden_state=last_hidden_state,
|
451 |
+
pooler_output=pooled_output,
|
452 |
+
hidden_states=encoder_outputs.hidden_states,
|
453 |
+
attentions=encoder_outputs.attentions,
|
454 |
+
)
|
455 |
+
|
456 |
+
|
457 |
+
class SigLipMultiheadAttentionPoolingHead(nn.Module):
|
458 |
+
"""Multihead Attention Pooling."""
|
459 |
+
|
460 |
+
def __init__(self, config: SigLipVisionConfig):
|
461 |
+
super().__init__()
|
462 |
+
|
463 |
+
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
464 |
+
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
|
465 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
466 |
+
self.mlp = SigLipMLP(config)
|
467 |
+
|
468 |
+
def forward(self, hidden_state):
|
469 |
+
batch_size = hidden_state.shape[0]
|
470 |
+
probe = self.probe.repeat(batch_size, 1, 1)
|
471 |
+
|
472 |
+
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
473 |
+
|
474 |
+
residual = hidden_state
|
475 |
+
hidden_state = self.layernorm(hidden_state)
|
476 |
+
hidden_state = residual + self.mlp(hidden_state)
|
477 |
+
|
478 |
+
return hidden_state[:, 0]
|
479 |
+
|
480 |
+
|
481 |
+
class SigLipVisionModel(SigLipPreTrainedModel):
|
482 |
+
config_class = SigLipVisionConfig
|
483 |
+
main_input_name = "pixel_values"
|
484 |
+
_no_split_modules = ["SigLipEncoderLayer"]
|
485 |
+
|
486 |
+
def __init__(self, config: SigLipVisionConfig):
|
487 |
+
super().__init__(config)
|
488 |
+
|
489 |
+
self.vision_model = SigLipVisionTransformer(config)
|
490 |
+
|
491 |
+
# Initialize weights and apply final processing
|
492 |
+
self.post_init()
|
493 |
+
|
494 |
+
def get_input_embeddings(self) -> nn.Module:
|
495 |
+
return self.vision_model.embeddings.patch_embedding
|
496 |
+
|
497 |
+
def forward(
|
498 |
+
self,
|
499 |
+
pixel_values,
|
500 |
+
output_attentions: Optional[bool] = None,
|
501 |
+
output_hidden_states: Optional[bool] = None,
|
502 |
+
return_dict: Optional[bool] = None,
|
503 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
504 |
+
r"""
|
505 |
+
Returns:
|
506 |
+
|
507 |
+
Examples:
|
508 |
+
|
509 |
+
```python
|
510 |
+
>>> from PIL import Image
|
511 |
+
>>> import requests
|
512 |
+
>>> from transformers import AutoProcessor, SigLipVisionModel
|
513 |
+
|
514 |
+
>>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
|
515 |
+
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
516 |
+
|
517 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
518 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
519 |
+
|
520 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
521 |
+
|
522 |
+
>>> outputs = model(**inputs)
|
523 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
524 |
+
>>> pooled_output = outputs.pooler_output # pooled features
|
525 |
+
```"""
|
526 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
527 |
+
|
528 |
+
return self.vision_model(
|
529 |
+
pixel_values=pixel_values.to(self.device),
|
530 |
+
output_attentions=output_attentions,
|
531 |
+
output_hidden_states=output_hidden_states,
|
532 |
+
return_dict=return_dict,
|
533 |
+
)
|
534 |
+
|
535 |
+
|
536 |
+
class SigLipVisionTower(nn.Module):
|
537 |
+
def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
|
538 |
+
super().__init__()
|
539 |
+
|
540 |
+
self.is_loaded = False
|
541 |
+
|
542 |
+
self.config = SigLipVisionConfig()
|
543 |
+
|
544 |
+
self.vision_tower_name = vision_tower
|
545 |
+
|
546 |
+
self.image_processor = SigLipImageProcessor()
|
547 |
+
|
548 |
+
if not delay_load:
|
549 |
+
rank0_print(f"Loading vision tower: {vision_tower}")
|
550 |
+
self.load_model()
|
551 |
+
elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
|
552 |
+
# TODO: better detector is needed.
|
553 |
+
rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
554 |
+
self.load_model()
|
555 |
+
elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
|
556 |
+
rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
|
557 |
+
self.load_model()
|
558 |
+
else:
|
559 |
+
self.cfg_only = self.config
|
560 |
+
|
561 |
+
def load_model(self, device_map=None):
|
562 |
+
if self.is_loaded:
|
563 |
+
return
|
564 |
+
|
565 |
+
self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
566 |
+
|
567 |
+
del self.vision_tower.vision_model.encoder.layers[-1:]
|
568 |
+
self.vision_tower.vision_model.head = nn.Identity()
|
569 |
+
self.vision_tower.requires_grad_(False)
|
570 |
+
self.vision_tower.eval()
|
571 |
+
|
572 |
+
self.is_loaded = True
|
573 |
+
|
574 |
+
@torch.no_grad()
|
575 |
+
def forward(self, images):
|
576 |
+
if type(images) is list:
|
577 |
+
image_features = []
|
578 |
+
for image in images:
|
579 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
580 |
+
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
581 |
+
assert image_features.shape[-2] == 729
|
582 |
+
image_features.append(image_feature)
|
583 |
+
else:
|
584 |
+
images=images.to(device=self.device, dtype=self.dtype)
|
585 |
+
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
586 |
+
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
587 |
+
assert image_features.shape[-2] == 729
|
588 |
+
|
589 |
+
return image_features
|
590 |
+
|
591 |
+
@property
|
592 |
+
def dummy_feature(self):
|
593 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
594 |
+
|
595 |
+
@property
|
596 |
+
def dtype(self):
|
597 |
+
for p in self.vision_tower.parameters():
|
598 |
+
return p.dtype
|
599 |
+
|
600 |
+
@property
|
601 |
+
def device(self):
|
602 |
+
for p in self.vision_tower.parameters():
|
603 |
+
return p.device
|
604 |
+
|
605 |
+
@property
|
606 |
+
def hidden_size(self):
|
607 |
+
return self.config.hidden_size
|
608 |
+
|
609 |
+
@property
|
610 |
+
def num_patches(self):
|
611 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
612 |
+
|
613 |
+
@property
|
614 |
+
def num_patches_per_side(self):
|
615 |
+
return self.config.image_size // self.config.patch_size
|
616 |
+
# return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
|
617 |
+
|
618 |
+
@property
|
619 |
+
def image_size(self):
|
620 |
+
return self.config.image_size
|
llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (2.39 kB). View file
|
|
llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc
ADDED
Binary file (1.46 kB). View file
|
|
llava/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import re
|
4 |
+
|
5 |
+
from .pooler_projector import PoolerProjector
|
6 |
+
|
7 |
+
|
8 |
+
class IdentityMap(nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
def forward(self, x, *args, **kwargs):
|
13 |
+
return x
|
14 |
+
|
15 |
+
@property
|
16 |
+
def config(self):
|
17 |
+
return {"mm_projector_type": "identity"}
|
18 |
+
|
19 |
+
|
20 |
+
class SimpleResBlock(nn.Module):
|
21 |
+
def __init__(self, channels):
|
22 |
+
super().__init__()
|
23 |
+
self.pre_norm = nn.LayerNorm(channels)
|
24 |
+
|
25 |
+
self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
x = self.pre_norm(x)
|
29 |
+
return x + self.proj(x)
|
30 |
+
|
31 |
+
|
32 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
33 |
+
projector_type = getattr(config, "mm_projector_type", "linear")
|
34 |
+
|
35 |
+
if projector_type == "linear":
|
36 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
37 |
+
|
38 |
+
if projector_type == "pooler":
|
39 |
+
return PoolerProjector(config, kwargs["vision_cfg"])
|
40 |
+
|
41 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
42 |
+
if mlp_gelu_match:
|
43 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
44 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
45 |
+
for _ in range(1, mlp_depth):
|
46 |
+
modules.append(nn.GELU())
|
47 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
48 |
+
return nn.Sequential(*modules)
|
49 |
+
|
50 |
+
mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
|
51 |
+
if mlp_gelu_resnet_match:
|
52 |
+
mlp_depth = int(mlp_gelu_resnet_match.group(1))
|
53 |
+
res_depth = int(mlp_gelu_resnet_match.group(2))
|
54 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
55 |
+
for _ in range(1, mlp_depth):
|
56 |
+
modules.append(nn.GELU())
|
57 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
58 |
+
for _ in range(res_depth):
|
59 |
+
modules.append(SimpleResBlock(config.hidden_size))
|
60 |
+
return nn.Sequential(*modules)
|
61 |
+
|
62 |
+
if projector_type == "identity":
|
63 |
+
return IdentityMap()
|
64 |
+
|
65 |
+
raise ValueError(f"Unknown projector type: {projector_type}")
|
llava/model/multimodal_projector/pooler_projector.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
from transformers.models.clip.modeling_clip import CLIPVisionModel
|
7 |
+
|
8 |
+
|
9 |
+
class PoolerProjector(nn.Module):
|
10 |
+
def __init__(self, config, vision_cfg):
|
11 |
+
super().__init__()
|
12 |
+
self._config = config
|
13 |
+
self.hw = vision_cfg.image_size // vision_cfg.patch_size
|
14 |
+
|
15 |
+
self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
|
16 |
+
|
17 |
+
self.proj = nn.Sequential(
|
18 |
+
nn.GELU(),
|
19 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
20 |
+
)
|
21 |
+
|
22 |
+
def forward(self, x, *args, **kwargs):
|
23 |
+
height = width = self.hw
|
24 |
+
assert height * width == x.shape[1]
|
25 |
+
x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
|
26 |
+
x = self.conv_pool(x)
|
27 |
+
x = x.flatten(2).transpose(1, 2)
|
28 |
+
x = self.proj(x)
|
29 |
+
return x
|
30 |
+
|
31 |
+
@property
|
32 |
+
def config(self):
|
33 |
+
return {"mm_projector_type": "pooler"}
|
llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (1.44 kB). View file
|
|
llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc
ADDED
Binary file (2.46 kB). View file
|
|
llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc
ADDED
Binary file (4.85 kB). View file
|
|
llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc
ADDED
Binary file (32.7 kB). View file
|
|
llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc
ADDED
Binary file (1.89 kB). View file
|
|
llava/model/multimodal_resampler/builder.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .masked_drop import MaskedDrop
|
4 |
+
from .spatial_pool import SpatialPool
|
5 |
+
from .perceiver import PerceiverResampler
|
6 |
+
from .qformer import Qformer
|
7 |
+
|
8 |
+
|
9 |
+
class IdentityMap(torch.nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
def forward(self, x, *args, **kwargs):
|
14 |
+
return x
|
15 |
+
|
16 |
+
@property
|
17 |
+
def config(self):
|
18 |
+
return {"mm_resampler_type": None}
|
19 |
+
|
20 |
+
|
21 |
+
def build_vision_resampler(model_args, delay_load=False, **kwargs):
|
22 |
+
resampler_type = getattr(model_args, "mm_resampler_type", None)
|
23 |
+
if resampler_type == "masked_drop":
|
24 |
+
return MaskedDrop(model_args)
|
25 |
+
elif resampler_type == "spatial_pool":
|
26 |
+
return SpatialPool(model_args, **kwargs)
|
27 |
+
elif resampler_type == "perceiver":
|
28 |
+
return PerceiverResampler(model_args, **kwargs)
|
29 |
+
elif resampler_type == "qformer":
|
30 |
+
return Qformer(model_args, **kwargs)
|
31 |
+
elif resampler_type is None:
|
32 |
+
return IdentityMap()
|
33 |
+
|
34 |
+
raise ValueError(f"Unknown resampler type: {resampler_type}")
|
llava/model/multimodal_resampler/masked_drop.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
import random
|
5 |
+
|
6 |
+
|
7 |
+
class MaskedDrop(nn.Module):
|
8 |
+
def __init__(self, model_args):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.mode = model_args.mm_mask_drop_mode
|
12 |
+
self.skip_percentage = model_args.mm_mask_drop_skip_percentage
|
13 |
+
self.ratio = model_args.mm_mask_drop_ratio
|
14 |
+
self.ratio_upper = model_args.mm_mask_drop_ratio_upper
|
15 |
+
self.ratio_lower = model_args.mm_mask_drop_ratio_lower
|
16 |
+
|
17 |
+
def forward(self, image_features, *args, **kwargs):
|
18 |
+
|
19 |
+
if not self.training:
|
20 |
+
return image_features
|
21 |
+
|
22 |
+
if self.skip_percentage > random.random():
|
23 |
+
return image_features
|
24 |
+
|
25 |
+
masked_features = []
|
26 |
+
|
27 |
+
for image_feature in image_features:
|
28 |
+
num_tokens = image_feature.shape[0]
|
29 |
+
if self.mode == "fixed":
|
30 |
+
num_keep = int(num_tokens * self.ratio)
|
31 |
+
masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
|
32 |
+
elif self.mode == "range":
|
33 |
+
num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
|
34 |
+
masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
|
35 |
+
elif self.mode == "cls_only":
|
36 |
+
masked_features.append(image_feature[0:1])
|
37 |
+
else:
|
38 |
+
raise ValueError(f"Unexpected masked drop mode: {self.mode}")
|
39 |
+
|
40 |
+
if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
|
41 |
+
masked_features = torch.stack(masked_features, dim=0)
|
42 |
+
|
43 |
+
return masked_features
|
44 |
+
|
45 |
+
@property
|
46 |
+
def config(self):
|
47 |
+
return {
|
48 |
+
"mm_resampler_type": "masked_drop",
|
49 |
+
"mm_mask_drop_mode": self.mode,
|
50 |
+
"mm_mask_drop_skip_percentage": self.skip_percentage,
|
51 |
+
"mm_mask_drop_ratio": self.ratio,
|
52 |
+
"mm_mask_drop_ratio_upper": self.ratio_upper,
|
53 |
+
"mm_mask_drop_ratio_lower": self.ratio_lower,
|
54 |
+
}
|
55 |
+
|
56 |
+
def random_masking(self, x, len_keep):
|
57 |
+
"""
|
58 |
+
Perform per-sample random masking by per-sample shuffling.
|
59 |
+
Per-sample shuffling is done by argsort random noise.
|
60 |
+
x: [N, L, D], sequence
|
61 |
+
"""
|
62 |
+
N, L, D = x.shape # batch, length, dim
|
63 |
+
|
64 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
65 |
+
|
66 |
+
# sort noise for each sample
|
67 |
+
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
68 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
69 |
+
|
70 |
+
# keep the first subset
|
71 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
72 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
73 |
+
|
74 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
75 |
+
mask = torch.ones([N, L], device=x.device)
|
76 |
+
mask[:, :len_keep] = 0
|
77 |
+
# unshuffle to get the binary mask
|
78 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
79 |
+
|
80 |
+
return x_masked, mask, ids_restore
|