Singularity666 committed on
Commit
17a5ce9
·
verified ·
1 Parent(s): eae59cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -86
app.py CHANGED
@@ -1,16 +1,10 @@
1
  from typing import List, Optional, Tuple, Union
2
-
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from torch.nn import CrossEntropyLoss
7
-
8
- from transformers import AutoConfig, AutoModelForCausalLM, \
9
- LlamaConfig, LlamaModel, LlamaForCausalLM, \
10
- CLIPVisionModel, CLIPImageProcessor
11
-
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
-
14
  import os, diffusers
15
 
16
  DEFAULT_IMAGE_TOKEN = "<image>"
@@ -26,10 +20,8 @@ class LlavaLlamaModel(LlamaModel):
26
 
27
  def __init__(self, config: LlamaConfig):
28
  super(LlavaLlamaModel, self).__init__(config)
29
-
30
  if hasattr(config, "mm_vision_tower"):
31
  self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
32
-
33
  if hasattr(config, "use_mm_proj"):
34
  self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
35
 
@@ -39,42 +31,29 @@ class LlavaLlamaModel(LlamaModel):
39
  vision_tower = vision_tower[0]
40
  return vision_tower
41
 
42
- def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
43
- pretrain_mm_mlp_adapter=None, fsdp=None):
44
  self.config.mm_vision_tower = vision_tower
45
-
46
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
47
-
48
  if not hasattr(self, 'vision_tower'):
49
  vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
50
  else:
51
  vision_tower = self.vision_tower[0]
52
  vision_tower.requires_grad_(False)
53
-
54
  if fsdp is not None and len(fsdp) > 0:
55
  self.vision_tower = [vision_tower]
56
  else:
57
  self.vision_tower = vision_tower
58
-
59
  vision_config = vision_tower.config
60
  num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
61
-
62
  self.config.use_mm_proj = True
63
  self.config.mm_hidden_size = vision_config.hidden_size
64
  self.config.mm_vision_select_layer = mm_vision_select_layer
65
-
66
  if not hasattr(self, 'mm_projector'):
67
  self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
68
-
69
  if pretrain_mm_mlp_adapter is not None:
70
  mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
71
  self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
72
-
73
- return dict(
74
- image_processor=image_processor,
75
- image_token_len=num_patches,
76
- vision_config=vision_config
77
- )
78
 
79
  def forward(
80
  self,
@@ -88,12 +67,9 @@ class LlavaLlamaModel(LlamaModel):
88
  images: Optional[torch.FloatTensor] = None,
89
  return_dict: Optional[bool] = None,
90
  ) -> Union[Tuple, BaseModelOutputWithPast]:
91
-
92
  orig_embeds_params = getattr(self, 'orig_embeds_params', None)
93
-
94
  if inputs_embeds is None:
95
  inputs_embeds = self.embed_tokens(input_ids)
96
-
97
  vision_tower = self.get_vision_tower()
98
  if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
99
  with torch.no_grad():
@@ -116,7 +92,6 @@ class LlavaLlamaModel(LlamaModel):
116
  image_features = self.mm_projector(image_features)
117
  dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
118
  dummy_image_features = self.mm_projector(dummy_image_features)
119
-
120
  new_input_embeds = []
121
  cur_image_idx = 0
122
  for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
@@ -158,7 +133,6 @@ class LlavaLlamaModel(LlamaModel):
158
  new_input_embeds.append(cur_new_input_embeds)
159
  cur_image_idx += 1
160
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
161
-
162
  return super(LlavaLlamaModel, self).forward(
163
  input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
164
  inputs_embeds=inputs_embeds, use_cache=use_cache,
@@ -169,7 +143,6 @@ class LlavaLlamaModel(LlamaModel):
169
  class EditMapper(nn.Module):
170
  def __init__(self):
171
  super().__init__()
172
-
173
  self.llm2hid = nn.Linear(4096, 512)
174
  self.query = nn.Parameter(torch.randn(1, 77, 512))
175
  self.mapper = nn.Transformer(batch_first=True, norm_first=True,
@@ -178,10 +151,9 @@ class EditMapper(nn.Module):
178
  self.hid2feat = nn.Linear(512, 768)
179
 
180
  def forward(self, llm, emb):
181
- hid = self.llm2hid(llm + emb)
182
  hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
183
  feat = self.hid2feat(hid)
184
-
185
  return feat
186
 
187
  class LlavaLlamaForCausalLM(LlamaForCausalLM):
@@ -190,11 +162,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
190
  def __init__(self, config):
191
  super(LlamaForCausalLM, self).__init__(config)
192
  self.model = LlavaLlamaModel(config)
193
-
194
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
195
-
196
  self.edit_head = EditMapper()
197
-
198
  self.scheduler, self.vae, self.unet = [
199
  diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
200
  diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
@@ -207,14 +176,17 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
207
  conv.weight.zero_()
208
  conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
209
  self.unet.conv_in = conv
210
-
211
  self.post_init()
212
 
213
  def get_model(self):
214
  return self.model
215
 
216
  def get_vision_tower(self):
217
- return self.get_model().get_vision_tower()
 
 
 
 
218
 
219
  def forward(
220
  self,
@@ -231,9 +203,10 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
231
  p2p_inp=None, p2p_ans=None
232
  ) -> Union[Tuple, CausalLMOutputWithPast]:
233
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
234
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
235
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
236
-
237
  outputs = self.model(
238
  input_ids=input_ids,
239
  attention_mask=attention_mask,
@@ -245,10 +218,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
245
  return_dict=return_dict,
246
  images=images
247
  )
248
-
249
  hidden_states = outputs[0]
250
  logits = self.lm_head(hidden_states)
251
-
252
  loss = None
253
  if labels is not None:
254
  shift_logits = logits[..., :-1, :].contiguous()
@@ -258,65 +229,47 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
258
  shift_labels = shift_labels.view(-1)
259
  shift_labels = shift_labels.to(shift_logits.device)
260
  loss = loss_fct(shift_logits, shift_labels)
261
-
262
  if labels is not None:
263
  llm = []
264
  for i in range(labels.shape[0]):
265
- try:
266
- p = labels[i].data.cpu().tolist().index(32003) - 1
267
- except:
268
- p = len(labels[i]) - 9
269
- p = min(len(hidden_states[i]) - 9, p)
270
- llm.append(hidden_states[i][p:p + 8].unsqueeze(0))
271
  llm = torch.cat(llm, dim=0)
272
  hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
273
-
274
  B, DROP = labels.shape[0], 0.05
275
-
276
- hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
277
- self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
278
-
279
  with torch.no_grad():
280
- lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample() * self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
281
- lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
282
- torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
283
-
284
  noise = torch.randn_like(lat_ans)
285
  ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device).long()
286
  lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
287
-
288
  prob = torch.rand(B, device=lat_ans.device)
289
- mask = (prob < (DROP * 2)).reshape(B, 1, 1)
290
  hid_edit = torch.where(mask, hid_null, hid_edit)
291
- mask = (1.0 - ((prob >= DROP).to(lat_inp.dtype) * (prob < (DROP * 3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
292
  lat_inp *= mask
293
-
294
  # Progressive Feature Blending
295
  beta_1, beta_2 = 0.7, 0.3
296
- visual_features = lat_inp # Assuming lat_inp represents the visual features
297
  B_1 = beta_1 * hid_edit + (1 - beta_1) * visual_features
298
  B_2 = beta_2 * hid_edit + (1 - beta_2) * visual_features
299
-
300
  # Cross-Attention Masking
301
  attention_scores = torch.matmul(hid_edit, hid_edit.transpose(-1, -2))
302
  mask = torch.zeros_like(hid_edit)
303
- mask[:, 3:5] = 1.0 # Emphasize central elements (e.g., "hat", "blue")
304
  masked_attention_scores = attention_scores * mask
305
  hid_edit = torch.matmul(F.softmax(masked_attention_scores, dim=-1), hid_edit)
306
-
307
- # Use blended features in subsequent processing
308
  hid_edit = B_1 + B_2
309
-
310
  out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
311
-
312
  loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
313
- if int(os.environ.get('LOCAL_RANK', 0)) == 0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
314
  loss = loss_ce + loss_edit * 0.5
315
-
316
  if not return_dict:
317
  output = (logits,) + outputs[1:]
318
  return (loss,) + output if loss is not None else output
319
-
320
  return CausalLMOutputWithPast(
321
  loss=loss,
322
  logits=logits,
@@ -325,17 +278,13 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
325
  attentions=outputs.attentions,
326
  )
327
 
328
- def prepare_inputs_for_generation(
329
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
330
- ):
331
  if past_key_values:
332
  input_ids = input_ids[:, -1:]
333
-
334
  if inputs_embeds is not None and past_key_values is None:
335
  model_inputs = {"inputs_embeds": inputs_embeds}
336
  else:
337
  model_inputs = {"input_ids": input_ids}
338
-
339
  model_inputs.update(
340
  {
341
  "past_key_values": past_key_values,
@@ -346,35 +295,28 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
346
  )
347
  return model_inputs
348
 
349
- def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
350
- tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
351
  vision_config = self.get_vision_tower().config
352
  vision_config.use_im_start_end = mm_use_im_start_end
353
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
354
  self.resize_token_embeddings(len(tokenizer))
355
-
356
  if mm_use_im_start_end:
357
  num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
358
  self.resize_token_embeddings(len(tokenizer))
359
  vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
360
-
361
  if num_new_tokens > 0:
362
  input_embeddings = self.get_input_embeddings().weight.data
363
  output_embeddings = self.get_output_embeddings().weight.data
364
-
365
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
366
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
367
-
368
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
369
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
370
-
371
  if tune_mm_mlp_adapter:
372
  self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
373
  for p in self.get_input_embeddings().parameters():
374
  p.requires_grad = True
375
  for p in self.get_output_embeddings().parameters():
376
  p.requires_grad = False
377
-
378
  if pretrain_mm_mlp_adapter:
379
  mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
380
  embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
@@ -384,8 +326,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
384
  elif embed_tokens_weight.shape[0] == num_new_tokens:
385
  input_embeddings[-num_new_tokens:] = embed_tokens_weight
386
  else:
387
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
388
-
389
  vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
390
 
391
  AutoConfig.register("llava", LlavaConfig)
 
1
  from typing import List, Optional, Tuple, Union
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  from torch.nn import CrossEntropyLoss
6
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM, CLIPVisionModel, CLIPImageProcessor
 
 
 
 
7
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
8
  import os, diffusers
9
 
10
  DEFAULT_IMAGE_TOKEN = "<image>"
 
20
 
21
def __init__(self, config: LlamaConfig):
    """Initialize the base LlamaModel and, when the config carries
    multimodal settings, attach the CLIP vision tower and projector.

    Args:
        config: a LlamaConfig; may additionally define ``mm_vision_tower``
            (pretrained CLIP checkpoint id/path), ``use_mm_proj``,
            ``mm_hidden_size`` and ``hidden_size`` for the projector.
    """
    super(LlavaLlamaModel, self).__init__(config)
    # The tower is stored inside a one-element list, which keeps its
    # parameters out of this module's own parameter registry.
    if hasattr(config, "mm_vision_tower"):
        self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
    # Linear projector mapping CLIP features into the LLM embedding space.
    if hasattr(config, "use_mm_proj"):
        self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
27
 
 
31
  vision_tower = vision_tower[0]
32
  return vision_tower
33
 
34
def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, pretrain_mm_mlp_adapter=None, fsdp=None):
    """Load the CLIP vision tower and projector and record them in the config.

    Args:
        vision_tower: pretrained CLIP checkpoint id/path.
        mm_vision_select_layer: which CLIP hidden layer to take features from.
        pretrain_mm_mlp_adapter: optional checkpoint with projector weights.
        fsdp: optional list of FSDP module names; when non-empty the tower is
            kept in a list so it is not sharded with this module.

    Returns:
        dict with ``image_processor``, ``image_token_len`` (patches per image)
        and ``vision_config``.
    """
    self.config.mm_vision_tower = vision_tower
    processor = CLIPImageProcessor.from_pretrained(vision_tower)

    # Reuse an already-attached tower; otherwise load it fresh.
    if hasattr(self, 'vision_tower'):
        tower = self.vision_tower[0]
    else:
        tower = CLIPVisionModel.from_pretrained(vision_tower)
    tower.requires_grad_(False)  # the tower is frozen during training

    # List-wrapping hides the tower from FSDP parameter flattening.
    self.vision_tower = [tower] if (fsdp is not None and len(fsdp) > 0) else tower

    vision_config = tower.config
    patches_per_image = (vision_config.image_size // vision_config.patch_size) ** 2

    self.config.use_mm_proj = True
    self.config.mm_hidden_size = vision_config.hidden_size
    self.config.mm_vision_select_layer = mm_vision_select_layer

    if not hasattr(self, 'mm_projector'):
        self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)

    if pretrain_mm_mlp_adapter is not None:
        adapter_state = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
        # Keep only the last component of each key (e.g. 'weight', 'bias').
        self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in adapter_state.items()})

    return dict(image_processor=processor, image_token_len=patches_per_image, vision_config=vision_config)
 
 
 
 
 
57
 
58
  def forward(
59
  self,
 
67
  images: Optional[torch.FloatTensor] = None,
68
  return_dict: Optional[bool] = None,
69
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
70
  orig_embeds_params = getattr(self, 'orig_embeds_params', None)
 
71
  if inputs_embeds is None:
72
  inputs_embeds = self.embed_tokens(input_ids)
 
73
  vision_tower = self.get_vision_tower()
74
  if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
75
  with torch.no_grad():
 
92
  image_features = self.mm_projector(image_features)
93
  dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
94
  dummy_image_features = self.mm_projector(dummy_image_features)
 
95
  new_input_embeds = []
96
  cur_image_idx = 0
97
  for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
 
133
  new_input_embeds.append(cur_new_input_embeds)
134
  cur_image_idx += 1
135
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
 
136
  return super(LlavaLlamaModel, self).forward(
137
  input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
138
  inputs_embeds=inputs_embeds, use_cache=use_cache,
 
143
  class EditMapper(nn.Module):
144
  def __init__(self):
145
  super().__init__()
 
146
  self.llm2hid = nn.Linear(4096, 512)
147
  self.query = nn.Parameter(torch.randn(1, 77, 512))
148
  self.mapper = nn.Transformer(batch_first=True, norm_first=True,
 
151
  self.hid2feat = nn.Linear(512, 768)
152
 
153
def forward(self, llm, emb):
    """Map LLM hidden states (plus token embeddings) to 768-d edit features.

    Args:
        llm: (batch, 8, 4096) hidden states — assumed shape; TODO confirm.
        emb: tensor broadcast-added to ``llm`` (same trailing dims).

    Returns:
        (batch, 77, 768) feature tensor for the diffusion UNet conditioning.
    """
    fused = self.llm2hid(llm + emb)
    # One learned query sequence, replicated across the batch.
    queries = self.query.repeat(llm.shape[0], 1, 1)
    mapped = self.mapper(fused, queries)
    return self.hid2feat(mapped)
158
 
159
  class LlavaLlamaForCausalLM(LlamaForCausalLM):
 
162
  def __init__(self, config):
163
  super(LlamaForCausalLM, self).__init__(config)
164
  self.model = LlavaLlamaModel(config)
 
165
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
166
  self.edit_head = EditMapper()
 
167
  self.scheduler, self.vae, self.unet = [
168
  diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
169
  diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
 
176
  conv.weight.zero_()
177
  conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
178
  self.unet.conv_in = conv
 
179
  self.post_init()
180
 
181
  def get_model(self):
182
  return self.model
183
 
184
  def get_vision_tower(self):
185
+ model = self.get_model()
186
+ vision_tower = model.vision_tower
187
+ if type(vision_tower) is list:
188
+ vision_tower = vision_tower[0]
189
+ return vision_tower
190
 
191
  def forward(
192
  self,
 
203
  p2p_inp=None, p2p_ans=None
204
  ) -> Union[Tuple, CausalLMOutputWithPast]:
205
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
206
+ output_hidden_states = (
207
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
208
+ )
209
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
210
  outputs = self.model(
211
  input_ids=input_ids,
212
  attention_mask=attention_mask,
 
218
  return_dict=return_dict,
219
  images=images
220
  )
 
221
  hidden_states = outputs[0]
222
  logits = self.lm_head(hidden_states)
 
223
  loss = None
224
  if labels is not None:
225
  shift_logits = logits[..., :-1, :].contiguous()
 
229
  shift_labels = shift_labels.view(-1)
230
  shift_labels = shift_labels.to(shift_logits.device)
231
  loss = loss_fct(shift_logits, shift_labels)
 
232
  if labels is not None:
233
  llm = []
234
  for i in range(labels.shape[0]):
235
+ try: p = labels[i].data.cpu().tolist().index(32003)-1
236
+ except: p = len(labels[i])-9
237
+ p = min(len(hidden_states[i])-9, p)
238
+ llm.append(hidden_states[i][p:p+8].unsqueeze(0))
 
 
239
  llm = torch.cat(llm, dim=0)
240
  hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
 
241
  B, DROP = labels.shape[0], 0.05
242
+ hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device), self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
 
 
 
243
  with torch.no_grad():
244
+ lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
245
+ lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device), torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
 
 
246
  noise = torch.randn_like(lat_ans)
247
  ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device).long()
248
  lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
 
249
  prob = torch.rand(B, device=lat_ans.device)
250
+ mask = (prob < (DROP*2)).reshape(B, 1, 1)
251
  hid_edit = torch.where(mask, hid_null, hid_edit)
252
+ mask = (1.0 - ((prob >= DROP).to(lat_inp.dtype) * (prob < (DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
253
  lat_inp *= mask
 
254
  # Progressive Feature Blending
255
  beta_1, beta_2 = 0.7, 0.3
256
+ visual_features = lat_inp
257
  B_1 = beta_1 * hid_edit + (1 - beta_1) * visual_features
258
  B_2 = beta_2 * hid_edit + (1 - beta_2) * visual_features
 
259
  # Cross-Attention Masking
260
  attention_scores = torch.matmul(hid_edit, hid_edit.transpose(-1, -2))
261
  mask = torch.zeros_like(hid_edit)
262
+ mask[:, 3:5] = 1.0
263
  masked_attention_scores = attention_scores * mask
264
  hid_edit = torch.matmul(F.softmax(masked_attention_scores, dim=-1), hid_edit)
 
 
265
  hid_edit = B_1 + B_2
 
266
  out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
 
267
  loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
268
+ if int(os.environ['LOCAL_RANK']) == 0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
269
  loss = loss_ce + loss_edit * 0.5
 
270
  if not return_dict:
271
  output = (logits,) + outputs[1:]
272
  return (loss,) + output if loss is not None else output
 
273
  return CausalLMOutputWithPast(
274
  loss=loss,
275
  logits=logits,
 
278
  attentions=outputs.attentions,
279
  )
280
 
281
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
 
 
282
  if past_key_values:
283
  input_ids = input_ids[:, -1:]
 
284
  if inputs_embeds is not None and past_key_values is None:
285
  model_inputs = {"inputs_embeds": inputs_embeds}
286
  else:
287
  model_inputs = {"input_ids": input_ids}
 
288
  model_inputs.update(
289
  {
290
  "past_key_values": past_key_values,
 
295
  )
296
  return model_inputs
297
 
298
+ def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device, tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
 
299
  vision_config = self.get_vision_tower().config
300
  vision_config.use_im_start_end = mm_use_im_start_end
301
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
302
  self.resize_token_embeddings(len(tokenizer))
 
303
  if mm_use_im_start_end:
304
  num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
305
  self.resize_token_embeddings(len(tokenizer))
306
  vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
307
  if num_new_tokens > 0:
308
  input_embeddings = self.get_input_embeddings().weight.data
309
  output_embeddings = self.get_output_embeddings().weight.data
 
310
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
311
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
 
312
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
313
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
 
314
  if tune_mm_mlp_adapter:
315
  self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
316
  for p in self.get_input_embeddings().parameters():
317
  p.requires_grad = True
318
  for p in self.get_output_embeddings().parameters():
319
  p.requires_grad = False
 
320
  if pretrain_mm_mlp_adapter:
321
  mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
322
  embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
 
326
  elif embed_tokens_weight.shape[0] == num_new_tokens:
327
  input_embeddings[-num_new_tokens:] = embed_tokens_weight
328
  else:
329
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
 
330
  vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
331
 
332
  AutoConfig.register("llava", LlavaConfig)