saicharan1234 committed
Commit 142f387 · verified · 1 Parent(s): f1dd03d

Update minigpt4/models/mini_gpt4.py

Files changed (1)
  1. minigpt4/models/mini_gpt4.py +264 -264
minigpt4/models/mini_gpt4.py CHANGED
@@ -1,264 +1,264 @@
 """
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 import logging
 import random
 import os
 import torch
 from torch.cuda.amp import autocast as autocast
 import torch.nn as nn
 
 from minigpt4.common.registry import registry
 from minigpt4.models.blip2 import Blip2Base, disabled_train
 from minigpt4.models.modeling_llama import LlamaForCausalLM
 from transformers import LlamaTokenizer
 
 
 @registry.register_model("mini_gpt4")
 class MiniGPT4(Blip2Base):
     """
     BLIP2 GPT-LLAMA model.
     """
 
     PRETRAINED_MODEL_CONFIG_DICT = {
         "pretrain_vicuna": "configs/models/minigpt4.yaml",
     }
 
     def __init__(
         self,
         vit_model="eva_clip_g",
         q_former_model="blip2_pretrained_flant5xxl.pth",
         img_size=224,
         drop_path_rate=0,
         use_grad_checkpoint=False,
         vit_precision="fp16",
         freeze_vit=True,
         freeze_qformer=True,
         num_query_token=32,
         llama_model="",
         llama_cache_dir='',
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
         end_sym='\n',
     ):
         super().__init__()
 
         self.tokenizer = self.init_tokenizer()
 
         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
             vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
         )
         if freeze_vit:
             for name, param in self.visual_encoder.named_parameters():
                 param.requires_grad = False
             self.visual_encoder = self.visual_encoder.eval()
             self.visual_encoder.train = disabled_train
             for name, param in self.ln_vision.named_parameters():
                 param.requires_grad = False
             self.ln_vision = self.ln_vision.eval()
             self.ln_vision.train = disabled_train
             logging.info("freeze vision encoder")
         print('Loading VIT Done')
 
         print('Loading Q-Former')
         self.Qformer, self.query_tokens = self.init_Qformer(
             num_query_token, self.visual_encoder.num_features
         )
         self.Qformer.cls = None
         self.Qformer.bert.embeddings.word_embeddings = None
         self.Qformer.bert.embeddings.position_embeddings = None
         for layer in self.Qformer.bert.encoder.layer:
             layer.output = None
             layer.intermediate = None
         self.load_from_pretrained(url_or_filename=q_former_model)
 
         if freeze_qformer:
             for name, param in self.Qformer.named_parameters():
                 param.requires_grad = False
             self.Qformer = self.Qformer.eval()
             self.Qformer.train = disabled_train
             self.query_tokens.requires_grad = False
             logging.info("freeze Qformer")
         print('Loading Q-Former Done')
 
         print('Loading LLAMA')
         self.llama_tokenizer = LlamaTokenizer.from_pretrained('Vision-CAIR/vicuna-7b', use_fast=False, use_auth_token=True)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
 
 
         if llama_cache_dir:
             self.llama_model = LlamaForCausalLM.from_pretrained(
-                'meta-llama/Llama-2-7b', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
+                'Vision-CAIR/vicuna-7b', load_in_4bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
             )
         else:
             self.llama_model = LlamaForCausalLM.from_pretrained(
-                '', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
+                'Vision-CAIR/vicuna-7b', load_in_4bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
             )
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
 
         self.llama_proj = nn.Linear(
             self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
         )
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
 
         if prompt_path:
             with open(prompt_path, 'r') as f:
                 raw_prompts = f.read().splitlines()
             filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
             self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
             print('Load {} training prompts'.format(len(self.prompt_list)))
             print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
         else:
             self.prompt_list = []
 
     def vit_to_cpu(self):
         self.ln_vision.to("cpu")
         self.ln_vision.float()
         self.visual_encoder.to("cpu")
         self.visual_encoder.float()
 
     def encode_img(self, image):
         device = image.device
         self.vit_to_cpu()
         image = image.to("cpu")
         with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
 
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
                 encoder_hidden_states=image_embeds,
                 encoder_attention_mask=image_atts,
                 return_dict=True,
             )
 
             inputs_llama = self.llama_proj(query_output.last_hidden_state)
             atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
         return inputs_llama, atts_llama
 
     def prompt_wrap(self, img_embeds, atts_img, prompt):
         if prompt:
             batch_size = img_embeds.shape[0]
             p_before, p_after = prompt.split('<ImageHere>')
             p_before_tokens = self.llama_tokenizer(
                 p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
             p_after_tokens = self.llama_tokenizer(
                 p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
             p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
             p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
             wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
             wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
             return wrapped_img_embeds, wrapped_atts_img
         else:
             return img_embeds, atts_img
 
     def forward(self, samples):
         image = samples["image"]
         img_embeds, atts_img = self.encode_img(image)
         if hasattr(samples, 'question_split'):  # VQA dataset
             print('VQA Batch')
             vqa_prompt = '###Human: <Img><ImageHere></Img> '
             img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
         elif self.prompt_list:
             prompt = random.choice(self.prompt_list)
             img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
 
         self.llama_tokenizer.padding_side = "right"
 
         text = [t + self.end_sym for t in samples["text_input"]]
 
         to_regress_tokens = self.llama_tokenizer(
             text,
             return_tensors="pt",
             padding="longest",
             truncation=True,
             max_length=self.max_txt_len,
             add_special_tokens=False
         ).to(image.device)
 
         targets = to_regress_tokens.input_ids.masked_fill(
             to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
         )
 
         empty_targets = (
             torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                        dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
         )
         targets = torch.cat([empty_targets, targets], dim=1)
 
         batch_size = img_embeds.shape[0]
         bos = torch.ones([batch_size, 1],
                          dtype=to_regress_tokens.input_ids.dtype,
                          device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
         bos_embeds = self.llama_model.model.embed_tokens(bos)
         atts_bos = atts_img[:, :1]
 
         to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
         inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
         attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
 
         with self.maybe_autocast():
             outputs = self.llama_model(
                 inputs_embeds=inputs_embeds,
                 attention_mask=attention_mask,
                 return_dict=True,
                 labels=targets,
             )
         loss = outputs.loss
 
         return {"loss": loss}
 
     @classmethod
     def from_config(cls, cfg):
         vit_model = cfg.get("vit_model", "eva_clip_g")
         q_former_model = cfg.get("q_former_model", "blip2_pretrained_flant5xxl.pth")
         img_size = cfg.get("image_size")
         num_query_token = cfg.get("num_query_token")
         llama_model = cfg.get("llama_model")
 
         drop_path_rate = cfg.get("drop_path_rate", 0)
         use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
         llama_cache_dir = cfg.get("llama_cache_dir", "")
 
         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
         max_txt_len = cfg.get("max_txt_len", 32)
         end_sym = cfg.get("end_sym", '\n')
 
         model = cls(
             vit_model=vit_model,
             q_former_model=q_former_model,
             img_size=img_size,
             drop_path_rate=drop_path_rate,
             use_grad_checkpoint=use_grad_checkpoint,
             vit_precision=vit_precision,
             freeze_vit=freeze_vit,
             freeze_qformer=freeze_qformer,
             llama_cache_dir=llama_cache_dir,
             num_query_token=num_query_token,
             llama_model=llama_model,
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
             end_sym=end_sym
         )
 
         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
         if ckpt_path:
             print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
             ckpt = torch.load(ckpt_path, map_location="cpu")
             msg = model.load_state_dict(ckpt['model'], strict=False)
 
         return model
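
For reference, a minimal usage sketch (not part of the commit), assuming the usual MiniGPT-4 OmegaConf-based config flow; the config keys and values below are illustrative stand-ins for what configs/models/minigpt4.yaml would normally provide, not values taken from this repository:

from omegaconf import OmegaConf  # assumption: the project's standard config library
from minigpt4.models.mini_gpt4 import MiniGPT4

# Illustrative config; real runs read these keys from configs/models/minigpt4.yaml.
cfg = OmegaConf.create({
    "image_size": 224,
    "num_query_token": 32,
    "llama_model": "Vision-CAIR/vicuna-7b",  # note: this revision hard-codes the LLaMA checkpoint anyway
})

# Builds the ViT, Q-Former and LLaMA components; the 4-bit loading in this revision
# needs bitsandbytes installed, and use_auth_token=True implies a Hugging Face login
# with access to the referenced repositories.
model = MiniGPT4.from_config(cfg)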