Upload 18 files
modeling_vlm.py  CHANGED  (+83 −1)
@@ -27,12 +27,14 @@ from transformers import (
     PreTrainedModel,
     GenerationMixin
 )
+import numpy as np
 from transformers.configuration_utils import PretrainedConfig

 from .clip_encoder import CLIPVisionTower
 from .siglip_vit import create_siglip_vit
 from .projector import MlpProjector
 from .configuration_vlm import AttrDict, MultiModalityConfig, VisionConfig, AlignerConfig, GenVisionConfig, GenHeadConfig, GenAlignerConfig
+from .vq_model import VQ_models


 class vision_head(torch.nn.Module):
@@ -61,7 +63,7 @@ def model_name_to_cls(cls_name):
         cls = CLIPVisionTower

     elif "VQ" in cls_name:
-        from
+        from .vq_model import VQ_models

         cls = VQ_models[cls_name]
     elif "vision_head" in cls_name:
@@ -193,7 +195,87 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs)
         return self.language_model.generate(inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, **kwargs)

+    @torch.no_grad()
+    def generate_image(
+        self,
+        processor,
+        prompt: str,
+        temperature: float = 1,
+        parallel_size: int = 16,
+        cfg_weight: float = 5,
+        image_token_num_per_image: int = 576,
+        img_size: int = 384,
+        patch_size: int = 16,
+        generator=None
+    ):
+        from PIL import Image
+
+        conversation = [
+            {
+                "role": "User",
+                "content": prompt,
+            },
+            {"role": "Assistant", "content": ""},
+        ]
+
+        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=conversation,
+            sft_format=processor.sft_format,
+            system_prompt="",
+        )
+        prompt = sft_format + processor.image_start_tag
+        input_ids = processor.tokenizer.encode(prompt)
+        input_ids = torch.LongTensor(input_ids)
+
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = processor.pad_id
+
+        inputs_embeds = self.language_model.get_input_embeddings()(tokens)
+
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int)
+        past_key_values = None
+
+        for i in range(image_token_num_per_image):
+            outputs = self.language_model.model.forward(
+                input_ids=None,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                past_key_values=past_key_values,
+            )
+            hidden_states = outputs.last_hidden_state
+            past_key_values = outputs.past_key_values
+            logits = self.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+
+            next_token = torch.multinomial(probs, num_samples=1) if generator is None else torch.multinomial(probs, num_samples=1, generator=generator)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = self.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+        dec = self.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int), [parallel_size, 8, img_size // patch_size, img_size // patch_size]
+        )
+        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+
+        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
+        visual_img[:, :, :] = dec
+
+        images = []
+
+        for i in range(parallel_size):
+            images.append(Image.fromarray(visual_img[i]))

+        return images


 AutoConfig.register("vision", VisionConfig)
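For reference, below is a minimal usage sketch (not part of the diff) showing how the newly added generate_image method could be called once the checkpoint is loaded with trust_remote_code=True. The repository id and the processor import are assumptions: any processor object exposing apply_sft_template_for_multi_turn_prompts, sft_format, image_start_tag, pad_id, and tokenizer (e.g. the upstream Janus VLChatProcessor) should work, since those are the only attributes the method touches.

import torch
from transformers import AutoModelForCausalLM
from janus.models import VLChatProcessor  # assumption: the upstream Janus package is installed

model_path = "your-namespace/your-checkpoint"  # placeholder repository id
processor = VLChatProcessor.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.float32
).eval()  # kept on CPU so the int token buffers built inside generate_image share the model's device

images = model.generate_image(
    processor=processor,
    prompt="a photo of a red bicycle leaning against a brick wall",
    parallel_size=4,   # number of images sampled in one call
    cfg_weight=5.0,    # classifier-free guidance scale
    temperature=1.0,
)
for idx, img in enumerate(images):  # generate_image returns a list of PIL images
    img.save(f"sample_{idx}.png")

Internally, generate_image interleaves a conditional and an unconditional stream (odd rows of the token batch have the prompt replaced by pad_id), mixes their logits as logit_uncond + cfg_weight * (logit_cond - logit_uncond), samples image_token_num_per_image VQ codes autoregressively, and decodes them with gen_vision_model.decode_code into parallel_size images of img_size × img_size pixels.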