# UniTok/model/language_model/mini_gemini_llama.py
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
# Copyright 2024 Yanwei Li
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from typing import List, Optional, Tuple, Union
from transformers.utils import logging
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
from model.arhead import AR_head
from model.liquid import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
logger = logging.get_logger(__name__)


class MiniGeminiConfig(LlamaConfig):
    model_type = "mini_gemini"


class MiniGeminiLlamaModel(MiniGeminiMetaModel, LlamaModel):
    config_class = MiniGeminiConfig
def __init__(self, config: LlamaConfig):
        super(MiniGeminiLlamaModel, self).__init__(config)


class MiniGeminiLlamaForCausalLM(LlamaForCausalLM, MiniGeminiMetaForCausalLM):
    config_class = MiniGeminiConfig
def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)  # bypass LlamaForCausalLM.__init__ so we can install the multimodal model below
self.model = MiniGeminiLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
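        # Depth-wise AR head that predicts the multi-codebook image tokens from
        # the backbone's hidden state; the sizes (32768 codes x 8 codebooks)
        # presumably match the UniTok visual tokenizer paired with this model.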
self.ar_head = AR_head(self.config, codebook_size=32768, num_codebooks=8)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
data_types: torch.LongTensor = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
additional_image_indexs = None
        if inputs_embeds is None and past_key_values is None:  # training / first forward pass; skipped during cached inference
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
data_types,
additional_image_labels,
additional_image_indexs
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
images_aux,
data_types
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
text_loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
text_loss = loss_fct(shift_logits, shift_labels)
            num_text_tokens = (shift_labels != -100).sum().item()  # supervised text tokens, used for loss weighting below
        if additional_image_indexs is None:  # multimodal preprocessing was skipped: return the plain LM output
return CausalLMOutputWithPast(
loss=text_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
        to_image_mask = data_types == 1  # which batch entries carry a text-to-image target, e.g. [True, False, False, True, ...]
if len(additional_image_indexs) > 0 and len(to_image_mask) == len(hidden_states): # image generation loss
to_image_states = hidden_states[to_image_mask]
            if len(to_image_states) != len(additional_image_indexs):
                logger.warning('to_image_mask: %s', to_image_mask)
                logger.warning('additional_image_indexs: %s', additional_image_indexs)
            # Take the hidden state one step before each image position, so that
            # state t predicts image token t (the usual next-token shift): [bz, seq_len, hidden_dim].
            shift_image_states = torch.stack([
                state[start_id - 1:end_id - 1]
                for (start_id, end_id), state in zip(additional_image_indexs, to_image_states)
            ])
base_tokens = shift_image_states
K = self.ar_head.num_codebooks
B, L, C = base_tokens.shape
            base_tokens = base_tokens.reshape(B * L, 1, C)  # one depth-sequence per image position
            targets = torch.cat(additional_image_labels, dim=0)  # [B, K, L] ground-truth codes
            image_code_labels = targets
            targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1]  # codes 0..K-2 are fed back as inputs
index_embeddings = []
for i in range(K - 1):
index_embed = self.ar_head.codebooks[i](targets[:, i])
index_embeddings.append(index_embed)
            index_embeddings = torch.stack(index_embeddings, dim=1)  # [B*L, K-1, C]
            # Teacher forcing along the codebook ("depth") axis: slot k sees codes < k.
            h = torch.cat((base_tokens, index_embeddings), dim=1)  # [B*L, K, C]
multicode_embedding = self.ar_head(
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=h,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
cache_position=None,
)
image_logits = self.ar_head.linear_head(multicode_embedding)
image_logits = image_logits.reshape(B, L, K, -1).permute(0, 2, 1, 3) # [B, K, L, sub_vocab_size]
loss_fct = CrossEntropyLoss()
image_logits = image_logits.reshape(-1, self.ar_head.sub_vocab_size)
image_labels = image_code_labels.view(-1)
image_labels = image_labels.to(image_logits.device)
            # z-loss-style regularizer: penalize large peak logits for numerical stability.
            image_softmax_normalizer = image_logits.max(-1).values ** 2
            image_z_loss = 0.00005 * image_softmax_normalizer.mean()
image_loss = loss_fct(image_logits, image_labels) + image_z_loss
num_image_tokens = image_labels.shape[0]
        else:
            # No text-to-image sample reached this branch: run a zero-weighted dummy
            # pass so ar_head parameters still receive gradients (keeps DDP in sync
            # across ranks that did get image batches).
            if len(hidden_states) != len(to_image_mask):
                logger.warning('to_image_mask: %s', to_image_mask)
                logger.warning('hidden_states: %s', hidden_states.shape)
                logger.warning('inputs_embeds: %s', inputs_embeds.shape)
                logger.warning('additional_image_indexs: %s', additional_image_indexs)
fake_ids = torch.ones(1, self.model.multi_embedder.num_codebooks - 1).to(inputs_embeds).long()
index_embeddings = []
for i in range(self.model.multi_embedder.num_codebooks - 1):
index_embed = self.ar_head.codebooks[i](fake_ids[:, i])
index_embeddings.append(index_embed)
index_embeddings = torch.stack(index_embeddings, dim=1)
multicode_embedding = self.ar_head(
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=index_embeddings,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
cache_position=None,
)
image_logits = self.ar_head.linear_head(multicode_embedding)
num_image_tokens = 0
            image_loss = (image_logits * 0).sum()  # exactly zero, but touches every ar_head parameter
        # Mix the two losses in proportion to their supervised token counts (this
        # point assumes labels were provided, so text_loss and num_text_tokens are set).
        loss = image_loss * (num_image_tokens / (num_image_tokens + num_text_tokens)) + \
               text_loss * (num_text_tokens / (num_image_tokens + num_text_tokens))
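        # Worked example of the mixing above: 300 supervised text tokens at loss
        # 2.0 plus 100 image codes at loss 6.0 give
        # 6.0 * 100/400 + 2.0 * 300/400 = 3.0, i.e. the token-weighted mean.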
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
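
    # Shape walk-through of the image branch in `forward`: with B text-to-image
    # samples of L image positions each and K = 8 codebooks, the backbone states
    # become base_tokens [B*L, 1, C]; the first K-1 ground-truth codes embed to
    # [B*L, K-1, C]; their concatenation h [B*L, K, C] lets ar_head slot k
    # predict code k from codes < k, and linear_head produces logits of shape
    # [B, K, L, sub_vocab_size] for the cross-entropy above.
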
    @torch.no_grad()
    def generate_mllm(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        # Identical to `generate`; kept as a separate multimodal entry point.
        return self.generate(inputs=inputs, images=images, images_aux=images_aux, **kwargs)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
**kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        # Multimodal generation: fold `images` into `inputs_embeds`, then defer
        # to the standard HF `generate` on embeddings.
        position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
images_aux
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
def test_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
input_multi_ids: torch.LongTensor = None,
data_types: torch.LongTensor = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
        # Single-step decode forward: embed the current multi-codebook token and
        # run the backbone once.
        if input_multi_ids is not None:
            input_multi_ids = input_multi_ids.unsqueeze(-1)  # [B, K, 1]
            input_ids = None
            inputs_embeds = self.model.multi_embedder(input_multi_ids)  # [B, 1, C]
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return outputs
def T2I_forward_nocache(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
input_multi_ids: torch.LongTensor = None,
data_types: torch.LongTensor = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
        # Cache-free T2I forward: re-encodes the full prompt plus every image
        # token generated so far on each call.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_multi_ids is not None:
            inputs_text_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None
            inputs_image_embeds = self.model.multi_embedder(input_multi_ids)  # [B, t, C]
            inputs_image_mask = torch.ones(
                inputs_image_embeds.shape[0], inputs_image_embeds.shape[1]
            ).to(attention_mask)  # extend the attention mask over the image positions
            inputs_embeds = torch.cat([inputs_text_embeds, inputs_image_embeds], dim=1)
            attention_mask = torch.cat([attention_mask, inputs_image_mask], dim=1)
            position_ids = torch.arange(
                0, inputs_embeds.shape[1], device=inputs_embeds.device
            ).unsqueeze(0).repeat(inputs_embeds.shape[0], 1)
else:
inputs_embeds = self.get_model().embed_tokens(input_ids)
input_ids = None
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return outputs
def T2I_forward_withcache(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
input_multi_ids: torch.LongTensor = None,
data_types: torch.LongTensor = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
        # Cached T2I forward: with a KV cache, only the newest image token (the
        # last position of `input_multi_ids`) needs to be embedded.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_multi_ids is not None:
            inputs_image_embeds = self.model.multi_embedder(input_multi_ids[:, :, -1:])  # [B, 1, C]
            inputs_embeds = inputs_image_embeds
            input_ids = None
else:
inputs_embeds = self.get_model().embed_tokens(input_ids)
input_ids = None
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return outputs
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
images_aux = kwargs.pop("images_aux", None)
_inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
_inputs['images'] = images
if images_aux is not None:
_inputs['images_aux'] = images_aux
return _inputs
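

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model) of the depth-wise teacher forcing
# that `forward` runs through `ar_head`. All sizes are invented for the example,
# and a bare embedding table plus linear head stand in for the real AR head;
# only the tensor bookkeeping mirrors the training branch above.
def _sketch_depthwise_teacher_forcing():
    B, L, K, V, C = 2, 4, 8, 32, 16  # batch, image positions, codebooks, sub-vocab, width
    codebooks = nn.ModuleList(nn.Embedding(V, C) for _ in range(K - 1))
    linear_head = nn.Linear(C, V)

    base = torch.randn(B * L, 1, C)          # one backbone state per image position
    codes = torch.randint(0, V, (B, K, L))   # ground-truth multi-codebook labels
    prev = codes.permute(0, 2, 1).reshape(B * L, K)[:, :-1]  # codes 0..K-2 become inputs

    prev_embeds = torch.stack([codebooks[i](prev[:, i]) for i in range(K - 1)], dim=1)
    h = torch.cat([base, prev_embeds], dim=1)  # [B*L, K, C]: slot k conditions on codes < k

    # The real ar_head runs a small causal transformer over dim=1; an identity
    # mapping is enough here to check shapes and the loss plumbing.
    logits = linear_head(h).reshape(B, L, K, V).permute(0, 2, 1, 3)  # [B, K, L, V]
    return F.cross_entropy(logits.reshape(-1, V), codes.reshape(-1))
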
AutoConfig.register("mini_gemini", MiniGeminiConfig)
AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiLlamaForCausalLM)
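

# ---------------------------------------------------------------------------
# Hypothetical decode driver (an assumption for illustration, not the repo's
# actual inference script) showing how `T2I_forward_nocache` composes: each step
# re-encodes the prompt plus all image codes generated so far, and a sampler
# built around `ar_head` (passed in here as the invented callable
# `sample_codes_with_arhead`) turns the last hidden state into the next K codes.
# The 256-token image length is likewise an assumption.
@torch.no_grad()
def _sketch_t2i_decode(model, input_ids, attention_mask, sample_codes_with_arhead,
                       num_image_tokens=256):
    generated = None  # [B, K, t] codes produced so far
    for _ in range(num_image_tokens):
        outputs = model.T2I_forward_nocache(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_multi_ids=generated,
            use_cache=False,
            return_dict=True,
        )
        last_hidden = outputs.last_hidden_state[:, -1:]  # [B, 1, C]
        # Hypothetical helper: decode K codes from one hidden state via ar_head.
        next_codes = sample_codes_with_arhead(model.ar_head, last_hidden)  # [B, K, 1]
        generated = next_codes if generated is None else torch.cat([generated, next_codes], dim=-1)
    return generated  # [B, K, num_image_tokens], input to the visual detokenizer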