katuni4ka committed on
Commit 5c955cb · verified · 1 Parent(s): f0b37fb

Upload 18 files

Files changed (1)
  1. modeling_vlm.py +83 -1
modeling_vlm.py CHANGED
@@ -27,12 +27,14 @@ from transformers import (
     PreTrainedModel,
     GenerationMixin
 )
+import numpy as np
 from transformers.configuration_utils import PretrainedConfig
 
 from .clip_encoder import CLIPVisionTower
 from .siglip_vit import create_siglip_vit
 from .projector import MlpProjector
 from .configuration_vlm import AttrDict, MultiModalityConfig, VisionConfig, AlignerConfig, GenVisionConfig, GenHeadConfig, GenAlignerConfig
+from .vq_model import VQ_models
 
 
 class vision_head(torch.nn.Module):
@@ -61,7 +63,7 @@ def model_name_to_cls(cls_name):
         cls = CLIPVisionTower
 
     elif "VQ" in cls_name:
-        from janus.models.vq_model import VQ_models
+        from .vq_model import VQ_models
 
         cls = VQ_models[cls_name]
     elif "vision_head" in cls_name:
@@ -193,7 +195,87 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs)
         return self.language_model.generate(inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, **kwargs)
 
+    @torch.no_grad()
+    def generate_image(
+        self,
+        processor,
+        prompt: str,
+        temperature: float = 1,
+        parallel_size: int = 16,
+        cfg_weight: float = 5,
+        image_token_num_per_image: int = 576,
+        img_size: int = 384,
+        patch_size: int = 16,
+        generator=None
+    ):
+        from PIL import Image
+
+        conversation = [
+            {
+                "role": "User",
+                "content": prompt,
+            },
+            {"role": "Assistant", "content": ""},
+        ]
+
+        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=conversation,
+            sft_format=processor.sft_format,
+            system_prompt="",
+        )
+        prompt = sft_format + processor.image_start_tag
+        input_ids = processor.tokenizer.encode(prompt)
+        input_ids = torch.LongTensor(input_ids)
+
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = processor.pad_id
+
+        inputs_embeds = self.language_model.get_input_embeddings()(tokens)
+
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int)
+        past_key_values = None
+
+        for i in range(image_token_num_per_image):
+            outputs = self.language_model.model.forward(
+                input_ids=None,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                past_key_values=past_key_values,
+            )
+            hidden_states = outputs.last_hidden_state
+            past_key_values = outputs.past_key_values
+            logits = self.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+
+            next_token = torch.multinomial(probs, num_samples=1) if generator is None else torch.multinomial(probs, num_samples=1, generator=generator)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = self.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+        dec = self.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int), [parallel_size, 8, img_size // patch_size, img_size // patch_size]
+        )
+        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+
+        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
+        visual_img[:, :, :] = dec
+
+        images = []
+
+        for i in range(parallel_size):
+            images.append(Image.fromarray(visual_img[i]))
 
+        return images
 
 
 AutoConfig.register("vision", VisionConfig)
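
The added generate_image method builds paired conditional/unconditional prompt copies, applies classifier-free guidance to the gen_head logits at each of the image_token_num_per_image steps, and decodes the sampled VQ tokens with gen_vision_model.decode_code into PIL images. A minimal usage sketch follows; the checkpoint id, loading through AutoModelForCausalLM with trust_remote_code=True, and obtaining the processor via VLChatProcessor from the upstream janus package are assumptions for illustration, not something this commit confirms.

```python
# Hypothetical usage of the generate_image method added in this commit.
# Assumed (not shown in the diff): the checkpoint id, loading via
# AutoModelForCausalLM with trust_remote_code=True, and VLChatProcessor
# coming from the upstream `janus` package.
from transformers import AutoModelForCausalLM
from janus.models import VLChatProcessor  # assumed processor implementation

model_id = "deepseek-ai/Janus-Pro-1B"  # placeholder checkpoint id
processor = VLChatProcessor.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

# Returns a list of parallel_size PIL.Image objects.
images = model.generate_image(
    processor,
    prompt="A watercolor painting of a lighthouse at dawn",
    parallel_size=4,   # number of images sampled in parallel
    cfg_weight=5,      # classifier-free guidance scale
    temperature=1.0,
)
for i, img in enumerate(images):
    img.save(f"sample_{i}.png")
```

The defaults in the method signature (img_size=384, patch_size=16, image_token_num_per_image=576) correspond to a 24x24 grid of VQ tokens per 384x384 image, so overriding one without the others will break the decode_code reshape.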