pipyp committed on
Commit 246d445 · 1 Parent(s): 1d18d73
Files changed (1)
  1. mini_gpt4.py +263 -0
mini_gpt4.py ADDED
@@ -0,0 +1,263 @@
+ """
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+ import logging
+ import random
+ import os
+ import torch
+ from torch.cuda.amp import autocast as autocast
+ import torch.nn as nn
+
+ from minigpt4.common.registry import registry
+ from minigpt4.models.blip2 import Blip2Base, disabled_train
+ from minigpt4.models.modeling_llama import LlamaForCausalLM
+ from transformers import LlamaTokenizer
+
+
+ @registry.register_model("mini_gpt4")
+ class MiniGPT4(Blip2Base):
+     """
+     BLIP2 GPT-LLAMA model.
+     """
+
+     PRETRAINED_MODEL_CONFIG_DICT = {
+         "pretrain_vicuna": "configs/models/minigpt4.yaml",
+     }
+
+     def __init__(
+         self,
+         vit_model="eva_clip_g",
+         q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+         img_size=224,
+         drop_path_rate=0,
+         use_grad_checkpoint=False,
+         vit_precision="fp16",
+         freeze_vit=True,
+         freeze_qformer=True,
+         num_query_token=32,
+         llama_model="",
+         llama_cache_dir='',
+         prompt_path="",
+         prompt_template="",
+         max_txt_len=32,
+         end_sym='\n',
+     ):
+         super().__init__()
+
+         self.tokenizer = self.init_tokenizer()
+
+         print('Loading VIT')
+         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+             vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+         )
+         if freeze_vit:
+             for name, param in self.visual_encoder.named_parameters():
+                 param.requires_grad = False
+             self.visual_encoder = self.visual_encoder.eval()
+             self.visual_encoder.train = disabled_train
+             for name, param in self.ln_vision.named_parameters():
+                 param.requires_grad = False
+             self.ln_vision = self.ln_vision.eval()
+             self.ln_vision.train = disabled_train
+             logging.info("freeze vision encoder")
+         print('Loading VIT Done')
+
+         print('Loading Q-Former')
+         self.Qformer, self.query_tokens = self.init_Qformer(
+             num_query_token, self.visual_encoder.num_features
+         )
+         self.Qformer.cls = None
+         self.Qformer.bert.embeddings.word_embeddings = None
+         self.Qformer.bert.embeddings.position_embeddings = None
+         for layer in self.Qformer.bert.encoder.layer:
+             layer.output = None
+             layer.intermediate = None
+         self.load_from_pretrained(url_or_filename=q_former_model)
+
+         if freeze_qformer:
+             for name, param in self.Qformer.named_parameters():
+                 param.requires_grad = False
+             self.Qformer = self.Qformer.eval()
+             self.Qformer.train = disabled_train
+             self.query_tokens.requires_grad = False
+             logging.info("freeze Qformer")
+         print('Loading Q-Former Done')
+
+         print('Loading LLAMA')
+         self.llama_tokenizer = LlamaTokenizer.from_pretrained('camenduru/MiniGPT4', use_fast=False)
+         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+
+         # Note: both branches load the same 8-bit checkpoint from 'camenduru/MiniGPT4';
+         # llama_cache_dir is accepted for config compatibility but is not used here.
+         if llama_cache_dir:
+             self.llama_model = LlamaForCausalLM.from_pretrained(
+                 'camenduru/MiniGPT4', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto"
+             )
+         else:
+             self.llama_model = LlamaForCausalLM.from_pretrained(
+                 'camenduru/MiniGPT4', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto"
+             )
+         for name, param in self.llama_model.named_parameters():
+             param.requires_grad = False
+         print('Loading LLAMA Done')
+
+         self.llama_proj = nn.Linear(
+             self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
+         )
+         self.max_txt_len = max_txt_len
+         self.end_sym = end_sym
+
+         if prompt_path:
+             with open(prompt_path, 'r') as f:
+                 raw_prompts = f.read().splitlines()
+             filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
+             self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
+             print('Loaded {} training prompts'.format(len(self.prompt_list)))
+             print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+         else:
+             self.prompt_list = []
+
+     def vit_to_cpu(self):
+         self.ln_vision.to("cpu")
+         self.ln_vision.float()
+         self.visual_encoder.to("cpu")
+         self.visual_encoder.float()
+
+     def encode_img(self, image):
+         # Low-resource path: the ViT and its LayerNorm run on CPU in fp32, and the
+         # resulting image embeddings are moved back to the input image's device.
+         device = image.device
+         self.vit_to_cpu()
+         image = image.to("cpu")
+         with self.maybe_autocast():
+             image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+             query_output = self.Qformer.bert(
+                 query_embeds=query_tokens,
+                 encoder_hidden_states=image_embeds,
+                 encoder_attention_mask=image_atts,
+                 return_dict=True,
+             )
+
+             inputs_llama = self.llama_proj(query_output.last_hidden_state)
+             atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
+         return inputs_llama, atts_llama
+
+     def prompt_wrap(self, img_embeds, atts_img, prompt):
+         if prompt:
+             batch_size = img_embeds.shape[0]
+             p_before, p_after = prompt.split('<ImageHere>')
+             p_before_tokens = self.llama_tokenizer(
+                 p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+             p_after_tokens = self.llama_tokenizer(
+                 p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+             p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
+             p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
+             wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
+             wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
+             return wrapped_img_embeds, wrapped_atts_img
+         else:
+             return img_embeds, atts_img
+
+     def forward(self, samples):
+         image = samples["image"]
+         img_embeds, atts_img = self.encode_img(image)
+         if hasattr(samples, 'question_split'):  # VQA dataset
+             print('VQA Batch')
+             vqa_prompt = '###Human: <Img><ImageHere></Img> '
+             img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
+         elif self.prompt_list:
+             prompt = random.choice(self.prompt_list)
+             img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
+
+         self.llama_tokenizer.padding_side = "right"
+
+         text = [t + self.end_sym for t in samples["text_input"]]
+
+         to_regress_tokens = self.llama_tokenizer(
+             text,
+             return_tensors="pt",
+             padding="longest",
+             truncation=True,
+             max_length=self.max_txt_len,
+             add_special_tokens=False
+         ).to(image.device)
+
+         # Padding positions are masked with -100 so they are ignored by the LM loss.
+         targets = to_regress_tokens.input_ids.masked_fill(
+             to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
+         )
+
+         empty_targets = (
+             torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
+                        dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
+         )
+         targets = torch.cat([empty_targets, targets], dim=1)
+
+         batch_size = img_embeds.shape[0]
+         bos = torch.ones([batch_size, 1],
+                          dtype=to_regress_tokens.input_ids.dtype,
+                          device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
+         bos_embeds = self.llama_model.model.embed_tokens(bos)
+         atts_bos = atts_img[:, :1]
+
+         to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
+         inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
+         attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+
+         with self.maybe_autocast():
+             outputs = self.llama_model(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 return_dict=True,
+                 labels=targets,
+             )
+         loss = outputs.loss
+
+         return {"loss": loss}
+
+     @classmethod
+     def from_config(cls, cfg):
+         vit_model = cfg.get("vit_model", "eva_clip_g")
+         q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
+         img_size = cfg.get("image_size")
+         num_query_token = cfg.get("num_query_token")
+         llama_model = cfg.get("llama_model")
+
+         drop_path_rate = cfg.get("drop_path_rate", 0)
+         use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+         vit_precision = cfg.get("vit_precision", "fp16")
+         freeze_vit = cfg.get("freeze_vit", True)
+         freeze_qformer = cfg.get("freeze_qformer", True)
+         llama_cache_dir = cfg.get("llama_cache_dir", "")
+
+         prompt_path = cfg.get("prompt_path", "")
+         prompt_template = cfg.get("prompt_template", "")
+         max_txt_len = cfg.get("max_txt_len", 32)
+         end_sym = cfg.get("end_sym", '\n')
+
+         model = cls(
+             vit_model=vit_model,
+             q_former_model=q_former_model,
+             img_size=img_size,
+             drop_path_rate=drop_path_rate,
+             use_grad_checkpoint=use_grad_checkpoint,
+             vit_precision=vit_precision,
+             freeze_vit=freeze_vit,
+             freeze_qformer=freeze_qformer,
+             llama_cache_dir=llama_cache_dir,
+             num_query_token=num_query_token,
+             llama_model=llama_model,
+             prompt_path=prompt_path,
+             prompt_template=prompt_template,
+             max_txt_len=max_txt_len,
+             end_sym=end_sym
+         )
+
+         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
+         if ckpt_path:
+             print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
+             ckpt = torch.load(ckpt_path, map_location="cpu")
+             msg = model.load_state_dict(ckpt['model'], strict=False)
+
+         return model
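
For reference, a minimal sketch of how this class might be exercised end to end. It is not part of the commit: the OmegaConf config object, the flat `from mini_gpt4 import MiniGPT4` import path, the `.to("cuda")` step (taken from the upstream demo pattern), and the dummy batch are all assumptions for illustration, and actually running it requires downloading the EVA-CLIP, Q-Former, and camenduru/MiniGPT4 checkpoints plus a GPU large enough for the 8-bit LLaMA. The config keys mirror what from_config reads above; image_size, num_query_token, and llama_model have no defaults there, so they must be supplied.

    import torch
    from omegaconf import OmegaConf  # assumption: config is a plain OmegaConf dict, as in the upstream YAML files

    from mini_gpt4 import MiniGPT4  # assumption: this file is importable as-is from the working directory

    cfg = OmegaConf.create({
        "image_size": 224,                    # ViT input resolution
        "num_query_token": 32,                # number of Q-Former query tokens
        "llama_model": "camenduru/MiniGPT4",  # accepted for compatibility; this fork hard-codes the repo id
        "prompt_path": "",                    # no prompt templates for this smoke test
        "ckpt": "",                           # optionally a MiniGPT-4 checkpoint with the trained llama_proj weights
    })

    model = MiniGPT4.from_config(cfg)
    model = model.to("cuda")  # as in the upstream demo; encode_img moves the ViT back to CPU on each call

    # A training-style batch: an image tensor plus its target caption(s).
    samples = {
        "image": torch.zeros(1, 3, 224, 224, device="cuda"),
        "text_input": ["a plain black square"],
    }

    loss = model(samples)["loss"]
    print(loss.item())

Because the ViT, Q-Former, query tokens, and LLaMA are all frozen in __init__, the only trainable parameters here are in the llama_proj linear layer, which is the piece the MiniGPT-4 checkpoints actually train.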