from typing import List, Optional, Tuple, Union
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX
import pdb
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import numpy_to_pil
import numpy as np
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
class blip3oConfig(LlamaConfig):
model_type = "blip3o_llama"
class blip3oLlamaModel(blip3oMetaModel, LlamaModel):
config_class = blip3oConfig
def __init__(self, config: LlamaConfig):
super(blip3oLlamaModel, self).__init__(config)
class blip3oLlamaForCausalLM(LlamaForCausalLM, blip3oMetaForCausalLM):
config_class = blip3oConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = blip3oLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.dist = None
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
ids: Optional[list] = None,
i_s_pos: Optional[list] = None,
image_type: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
gen_image: Optional[torch.FloatTensor] = None,
und_image: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
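        # When raw token ids are given, splice the image embeddings into the sequence
        # and obtain the latents that serve as the flow-matching target for gen_image.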
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
latents
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
gen_image,
und_image,
i_s_pos,
image_sizes
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
total_loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# compute image loss
# target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
img_loss_funct = torch.nn.MSELoss()
# img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
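            # Gather the hidden states at the 64 latent-query positions of each sample
            # (starting at i_s_pos[b]) and project them into the DiT conditioning space.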
img_hidden_states = []
for b in range(hidden_states.shape[0]):
img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
img_hidden_states = torch.stack(img_hidden_states,dim=0)
img_hidden_states = self.get_model().down_projector(img_hidden_states)
# img_loss = 0.0
if latents is None:
                img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))  # dummy zero loss: keeps the projector in the graph when no target latents are provided
else:
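                # Rectified-flow / flow-matching training step: sample a timestep, blend
                # clean latents with noise as (1 - sigma) * x0 + sigma * noise, and train
                # the DiT to predict the velocity (noise - x0).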
bsz = latents.shape[0]
# device = latents.device
dtype = latents.dtype
noise = torch.randn_like(latents, device=latents.device)
u = torch.rand(size=(bsz,), device="cpu")
indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
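                # The DiT predicts the flow velocity; mask_drop presumably drops the
                # conditioning at random so classifier-free guidance works at inference.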
noise_pred = self.get_model().dit(
x=noisy_latents,
timestep=timesteps,
z_latents=self.mask_drop(img_hidden_states),
)
target = noise - latents
img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
print(f"img loss {img_loss}, text loss {loss}")
total_loss = img_loss
return CausalLMOutputWithPast(
loss=total_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
img_indicator,
_
) = self.prepare_inputs_labels_for_understanding(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
@torch.no_grad()
def generate_image(
self,
text: List[str],
tokenizer: AutoTokenizer,
image: Optional[torch.Tensor] = None,
max_var: Optional[float] = None,
# placeholder: str = DEFAULT_IMG_PLACEHOLDER,
):
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
vision_tower = self.get_vision_tower()
mm_projector = self.get_mm_projector()
N_QUERY = self.get_n_query()
if image is not None:
# image: [Batch, 3, 448, 448]
            prompt_image_embeds = vision_tower(image)  # fixed: `batch_images` was undefined; encode the provided `image` tensor
num_img, _, c = prompt_image_embeds.shape # [batch, 576, 1024]
all_image_embeds = torch.clone(prompt_image_embeds).detach()
prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
prompt_image_embeds = mm_projector(prompt_image_embeds)
inputs = tokenizer(text, padding="longest", return_tensors="pt")
device = self.get_model().device
attention_mask = inputs.attention_mask.to(device)
input_ids = inputs.input_ids.to(device) # B x N
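        # Append one extra token (id 198, likely the tokenizer's newline) before the latent queries.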
input_ids = torch.cat([input_ids, torch.tensor([[198]]).to(device)], dim=1)
# breakpoint()
text_embeds = self.get_model().embed_tokens(input_ids)
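        # Append the learnable latent queries after the text; their hidden states will
        # carry the image-generation conditioning.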
latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
outputs = self.model(
inputs_embeds=text_embeds,
# img_indicator=img_indicator,
# concept_indicator=concept_indicator if self.use_concept_token else None,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
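        # Take the last-layer hidden states at the N_QUERY latent-query positions and
        # project them to the DiT conditioning dimension.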
hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
img_hidden_states = self.get_model().down_projector(hidden_states)
output_img = self.sample_images(img_hidden_states, scheduler)
output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
return output_img
def sample_images(
self,
img_hidden_states,
scheduler,
guidance_scale: float = 3.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 30,
num_images_per_prompt: int = 1,
return_tensor=False,
**kwargs,
):
device = img_hidden_states.device
dtype = img_hidden_states.dtype
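        # Build a zeroed (null) conditioning and stack [uncond, cond] so a single DiT
        # forward pass covers both classifier-free guidance branches.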
img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
batch_size = img_hidden_states.shape[0]
latent_size = self.get_model().dit.config.input_size
latent_channels = self.get_model().dit.config.in_channels
latents = randn_tensor(
shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
generator=generator,
device=device,
dtype=dtype,
)
# set step values
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
# Repeat z_latents and conditions for each image per prompt
img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
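        # Euler flow-matching sampling loop: duplicate the latents each step to match
        # the [uncond, cond] conditioning batch.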
for t in scheduler.timesteps:
latent_model_input = latents.repeat(2, 1, 1, 1)
if hasattr(scheduler, "scale_model_input"):
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict noise model_output
noise_pred = self.get_model().dit(
x=latent_model_input,
timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
z_latents=img_hidden_states_input,
)
# perform guidance
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
# compute previous image: x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
# samples = self.decode_latents(latents, return_tensor=return_tensor)
return latents
def decode_latents(self, latents, normalize=True, return_tensor=False):
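        # AutoencoderKL latents must be rescaled (and shifted, if configured) back to
        # the VAE's latent space before decoding.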
if isinstance(self.get_model().vae, AutoencoderKL):
latents = latents / self.get_model().vae.config.scaling_factor
if self.get_model().vae.config.shift_factor is not None:
latents = latents + self.get_model().vae.config.shift_factor
latents = latents.to(dtype=torch.float32)
samples = self.get_model().vae.decode(latents).sample
else:
samples = self.get_model().vae.decode(latents)
if normalize:
samples = (samples / 2 + 0.5).clamp(0, 1)
else:
samples = samples.clamp(-1, 1)
if return_tensor:
return samples
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
samples = numpy_to_pil(samples)
return samples
def prepare_and_encode_inputs(
self,
        inputs: List[Union[str, Image.Image]],
tokenizer: AutoTokenizer,
do_classifier_free_guidance: bool = False,
):
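        # Interleave text and PIL images into a single prompt: images are replaced by
        # DEFAULT_IMAGE_TOKEN in the text and preprocessed into pixel tensors.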
# pdb.set_trace()
device = self.get_model().device
dtype = self.get_model().dtype
has_image, has_text = False, False
text_prompt, image_prompt = "", []
img_processor = self.get_vision_tower().image_processor
negative_prompt = {}
for x in inputs:
if isinstance(x, str):
has_text = True
text_prompt += x
else:
has_image = True
text_prompt += DEFAULT_IMAGE_TOKEN
image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
# pdb.set_trace()
if len(image_prompt) == 0:
image_prompt = None
else:
image_prompt = torch.cat(image_prompt)
image_prompt = image_prompt.type(dtype).to(device)
if has_image and not has_text:
prompt = self.encode_images(image_prompt)
# pdb.set_trace()
if do_classifier_free_guidance:
key = "[NULL_IMAGE]"
if key not in negative_prompt:
negative_image = torch.zeros_like(image_prompt)
negative_prompt[key] = self.encode_images(negative_image)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
else:
prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
if do_classifier_free_guidance:
key = ""
if key not in negative_prompt:
negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
gen_pooling = self.get_gen_pooling()
n_query = self.get_n_query()
num_img, _, c = prompt.shape
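        # Optional 2D average pooling of the prompt tokens: reshape to a square grid,
        # pool with the requested stride, then flatten back to a token sequence.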
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
stride = int(gen_pooling.split('_')[1])
sqrt_n = int(n_query**0.5)
prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
return prompt
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
AutoConfig.register("blip3o_llama", blip3oConfig)
AutoModelForCausalLM.register(blip3oConfig, blip3oLlamaForCausalLM)