Update minigpt4/models/mini_gpt4.py
minigpt4/models/mini_gpt4.py  CHANGED  +264 -264
@@ -1,264 +1,264 @@
 """
  Copyright (c) 2023, salesforce.com, inc.
  All rights reserved.
  SPDX-License-Identifier: BSD-3-Clause
  For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 import logging
 import random
 import os
 import torch
 from torch.cuda.amp import autocast as autocast
 import torch.nn as nn

 from minigpt4.common.registry import registry
 from minigpt4.models.blip2 import Blip2Base, disabled_train
 from minigpt4.models.modeling_llama import LlamaForCausalLM
 from transformers import LlamaTokenizer


 @registry.register_model("mini_gpt4")
 class MiniGPT4(Blip2Base):
     """
     BLIP2 GPT-LLAMA model.
     """

     PRETRAINED_MODEL_CONFIG_DICT = {
         "pretrain_vicuna": "configs/models/minigpt4.yaml",
     }

     def __init__(
         self,
         vit_model="eva_clip_g",
         q_former_model="blip2_pretrained_flant5xxl.pth",
         img_size=224,
         drop_path_rate=0,
         use_grad_checkpoint=False,
         vit_precision="fp16",
         freeze_vit=True,
         freeze_qformer=True,
         num_query_token=32,
         llama_model="",
         llama_cache_dir='',
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
             vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
         )
         if freeze_vit:
             for name, param in self.visual_encoder.named_parameters():
                 param.requires_grad = False
             self.visual_encoder = self.visual_encoder.eval()
             self.visual_encoder.train = disabled_train
             for name, param in self.ln_vision.named_parameters():
                 param.requires_grad = False
             self.ln_vision = self.ln_vision.eval()
             self.ln_vision.train = disabled_train
             logging.info("freeze vision encoder")
         print('Loading VIT Done')

         print('Loading Q-Former')
         self.Qformer, self.query_tokens = self.init_Qformer(
             num_query_token, self.visual_encoder.num_features
         )
         self.Qformer.cls = None
         self.Qformer.bert.embeddings.word_embeddings = None
         self.Qformer.bert.embeddings.position_embeddings = None
         for layer in self.Qformer.bert.encoder.layer:
             layer.output = None
             layer.intermediate = None
         self.load_from_pretrained(url_or_filename=q_former_model)

         if freeze_qformer:
             for name, param in self.Qformer.named_parameters():
                 param.requires_grad = False
             self.Qformer = self.Qformer.eval()
             self.Qformer.train = disabled_train
             self.query_tokens.requires_grad = False
             logging.info("freeze Qformer")
         print('Loading Q-Former Done')

         print('Loading LLAMA')
         self.llama_tokenizer = LlamaTokenizer.from_pretrained('Vision-CAIR/vicuna-7b', use_fast=False, use_auth_token=True)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token


         if llama_cache_dir:
             self.llama_model = LlamaForCausalLM.from_pretrained(
-                '
+                'Vision-CAIR/vicuna-7b', load_in_4bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
             )
         else:
             self.llama_model = LlamaForCausalLM.from_pretrained(
-                '',
+                'Vision-CAIR/vicuna-7b', load_in_4bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
             )
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')

         self.llama_proj = nn.Linear(
             self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
         )
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym

         if prompt_path:
             with open(prompt_path, 'r') as f:
                 raw_prompts = f.read().splitlines()
             filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
             self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
             print('Load {} training prompts'.format(len(self.prompt_list)))
             print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
         else:
             self.prompt_list = []

     def vit_to_cpu(self):
         self.ln_vision.to("cpu")
         self.ln_vision.float()
         self.visual_encoder.to("cpu")
         self.visual_encoder.float()

     def encode_img(self, image):
         device = image.device
         self.vit_to_cpu()
         image = image.to("cpu")
         with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
                 encoder_hidden_states=image_embeds,
                 encoder_attention_mask=image_atts,
                 return_dict=True,
             )

             inputs_llama = self.llama_proj(query_output.last_hidden_state)
             atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
         return inputs_llama, atts_llama

     def prompt_wrap(self, img_embeds, atts_img, prompt):
         if prompt:
             batch_size = img_embeds.shape[0]
             p_before, p_after = prompt.split('<ImageHere>')
             p_before_tokens = self.llama_tokenizer(
                 p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
             p_after_tokens = self.llama_tokenizer(
                 p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
             p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
             p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
             wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
             wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
             return wrapped_img_embeds, wrapped_atts_img
         else:
             return img_embeds, atts_img

     def forward(self, samples):
         image = samples["image"]
         img_embeds, atts_img = self.encode_img(image)
         if hasattr(samples, 'question_split'):  # VQA dataset
             print('VQA Batch')
             vqa_prompt = '###Human: <Img><ImageHere></Img> '
             img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
         elif self.prompt_list:
             prompt = random.choice(self.prompt_list)
             img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)

         self.llama_tokenizer.padding_side = "right"

         text = [t + self.end_sym for t in samples["text_input"]]

         to_regress_tokens = self.llama_tokenizer(
             text,
             return_tensors="pt",
             padding="longest",
             truncation=True,
             max_length=self.max_txt_len,
             add_special_tokens=False
         ).to(image.device)

         targets = to_regress_tokens.input_ids.masked_fill(
             to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
         )

         empty_targets = (
             torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                        dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
         )
         targets = torch.cat([empty_targets, targets], dim=1)

         batch_size = img_embeds.shape[0]
         bos = torch.ones([batch_size, 1],
                          dtype=to_regress_tokens.input_ids.dtype,
                          device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
         bos_embeds = self.llama_model.model.embed_tokens(bos)
         atts_bos = atts_img[:, :1]

         to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
         inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
         attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

         with self.maybe_autocast():
             outputs = self.llama_model(
                 inputs_embeds=inputs_embeds,
                 attention_mask=attention_mask,
                 return_dict=True,
                 labels=targets,
             )
         loss = outputs.loss

         return {"loss": loss}

     @classmethod
     def from_config(cls, cfg):
         vit_model = cfg.get("vit_model", "eva_clip_g")
         q_former_model = cfg.get("q_former_model", "blip2_pretrained_flant5xxl.pth")
         img_size = cfg.get("image_size")
         num_query_token = cfg.get("num_query_token")
         llama_model = cfg.get("llama_model")

         drop_path_rate = cfg.get("drop_path_rate", 0)
         use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
         llama_cache_dir = cfg.get("llama_cache_dir", "")

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
         max_txt_len = cfg.get("max_txt_len", 32)
         end_sym = cfg.get("end_sym", '\n')

         model = cls(
             vit_model=vit_model,
             q_former_model=q_former_model,
             img_size=img_size,
             drop_path_rate=drop_path_rate,
             use_grad_checkpoint=use_grad_checkpoint,
             vit_precision=vit_precision,
             freeze_vit=freeze_vit,
             freeze_qformer=freeze_qformer,
             llama_cache_dir=llama_cache_dir,
             num_query_token=num_query_token,
             llama_model=llama_model,
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
             end_sym=end_sym
         )

         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
         if ckpt_path:
             print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
             ckpt = torch.load(ckpt_path, map_location="cpu")
             msg = model.load_state_dict(ckpt['model'], strict=False)

         return model
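For context on what the updated lines need at runtime, below is a minimal standalone sketch of the same 4-bit loading path, outside the MiniGPT4 class. It is only an illustrative smoke test, not part of the repository: it assumes transformers with the bitsandbytes and accelerate integrations installed, a CUDA GPU, and a Hugging Face token authorized for the gated Vision-CAIR/vicuna-7b checkpoint; the prompt string and max_new_tokens value are arbitrary.

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL_ID = "Vision-CAIR/vicuna-7b"  # gated repo; requires an authorized HF token

# Tokenizer, loaded the same way mini_gpt4.py does (slow tokenizer, auth token).
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, use_fast=False, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token

# 4-bit quantized weights via bitsandbytes, placed across available devices by accelerate.
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
    use_auth_token=True,
)

# Smoke test: a short greedy generation to confirm the weights loaded and run.
inputs = tokenizer("###Human: Hello ###Assistant:", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Because device_map="auto" lets accelerate place the quantized weights, no explicit .to(device) call is made on the model itself; only the inputs are moved.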