# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
# Copyright 2024 Yanwei Li
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from typing import List, Optional, Tuple, Union

from transformers.utils import logging
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM

from model.arhead import AR_head
from model.liquid import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM

logger = logging.get_logger(__name__)


class MiniGeminiConfig(LlamaConfig):
    model_type = "mini_gemini"


class MiniGeminiLlamaModel(MiniGeminiMetaModel, LlamaModel):
    config_class = MiniGeminiConfig

    def __init__(self, config: LlamaConfig):
        super(MiniGeminiLlamaModel, self).__init__(config)


class MiniGeminiLlamaForCausalLM(LlamaForCausalLM, MiniGeminiMetaForCausalLM):
    config_class = MiniGeminiConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = MiniGeminiLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # AR head for predicting image codes; it exposes `num_codebooks`,
        # `sub_vocab_size`, the embedding list `codebooks`, and the projection
        # `linear_head` used in `forward` below.
        self.ar_head = AR_head(self.config, codebook_size=32768, num_codebooks=8)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        additional_image_indexs = None
        if inputs_embeds is None and past_key_values is None:  # skipped during inference, where inputs_embeds / past_key_values are supplied directly
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                data_types,
                additional_image_labels,
                additional_image_indexs
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                images_aux,
                data_types
            )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if self.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        text_loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            text_loss = loss_fct(shift_logits, shift_labels)
            num_text_tokens = (shift_labels != -100).sum().item()

        if additional_image_indexs is None:
            return CausalLMOutputWithPast(
                loss=text_loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        to_image_mask = data_types == 1  # marks which batch entries carry a text-to-image target, e.g. [True, False, False, True, ...]
        if len(additional_image_indexs) > 0 and len(to_image_mask) == len(hidden_states):  # image generation loss
            to_image_states = hidden_states[to_image_mask]
            # assert len(to_image_states) == len(additional_image_indexs)
            if len(to_image_states) != len(additional_image_indexs):
                print('to_image_mask', to_image_mask)
                print('additional_image_indexs', additional_image_indexs)
            # Shift so that tokens < n predict n: [bz, seq_len, hidden_dim]
            shift_image_states = torch.stack([state[start_id - 1:end_id - 1] for (start_id, end_id), state in
                                              zip(additional_image_indexs, to_image_states)])
            base_tokens = shift_image_states
            K = self.ar_head.num_codebooks
            B, L, C = base_tokens.shape
            base_tokens = base_tokens.reshape(B * L, 1, C)

            targets = torch.cat(additional_image_labels, dim=0)  # [B, K, L]
            image_code_labels = targets
            targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1]

            # Embed the first K-1 ground-truth codebook indices; together with the
            # LLM hidden state they condition the AR head's prediction of all K codes.
            index_embeddings = []
            for i in range(K - 1):
                index_embed = self.ar_head.codebooks[i](targets[:, i])
                index_embeddings.append(index_embed)
            index_embeddings = torch.stack(index_embeddings, dim=1)
            h = torch.cat((base_tokens, index_embeddings), dim=1)  # [B*L, K, C]

            multicode_embedding = self.ar_head(
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=h,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False,
                cache_position=None,
            )
            image_logits = self.ar_head.linear_head(multicode_embedding)
            image_logits = image_logits.reshape(B, L, K, -1).permute(0, 2, 1, 3)  # [B, K, L, sub_vocab_size]

            loss_fct = CrossEntropyLoss()
            image_logits = image_logits.reshape(-1, self.ar_head.sub_vocab_size)
            image_labels = image_code_labels.view(-1)
            image_labels = image_labels.to(image_logits.device)
            # z-loss-style regularizer: penalize large peak logits to keep the softmax stable.
            image_softmax_normalizer = image_logits.max(-1).values ** 2
            image_z_loss = 0.00005 * image_softmax_normalizer.mean()
            image_loss = loss_fct(image_logits, image_labels) + image_z_loss
            num_image_tokens = image_labels.shape[0]
        else:
            # No text-to-image samples in this batch: run the AR head on dummy
            # inputs so its parameters still take part in the backward pass,
            # while contributing a zero loss.
            if len(hidden_states) != len(to_image_mask):
                print('to_image_mask', to_image_mask)
                print('hidden_states', hidden_states.shape)
                print('inputs_embeds', inputs_embeds.shape)
                print('additional_image_indexs', additional_image_indexs)
            fake_ids = torch.ones(1, self.model.multi_embedder.num_codebooks - 1).to(inputs_embeds).long()
            index_embeddings = []
            for i in range(self.model.multi_embedder.num_codebooks - 1):
                index_embed = self.ar_head.codebooks[i](fake_ids[:, i])
                index_embeddings.append(index_embed)
            index_embeddings = torch.stack(index_embeddings, dim=1)
            multicode_embedding = self.ar_head(
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=index_embeddings,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False,
                cache_position=None,
            )
            image_logits = self.ar_head.linear_head(multicode_embedding)
            num_image_tokens = 0
            image_loss = (image_logits * 0).sum()  # + (base_tokens*0).sum()

        # Weight each objective by its share of supervised tokens (reaching this
        # point implies labels were provided, so text_loss and num_text_tokens exist).
        loss = image_loss * (num_image_tokens / (num_image_tokens + num_text_tokens)) + \
               text_loss * (num_text_tokens / (num_image_tokens + num_text_tokens))
        # t2i_ratio = to_image_mask.sum() / len(to_image_mask)
        # loss = image_loss * t2i_ratio + text_loss * (1 - t2i_ratio)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
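
    # ------------------------------------------------------------------
    # Note (added for clarity; not part of the original training code):
    # weighting each objective by its token share,
    #     loss = image_loss * n_img / (n_img + n_txt)
    #          + text_loss  * n_txt / (n_img + n_txt),
    # is equivalent, up to the small z-regularization term, to averaging
    # the per-token cross-entropy over all supervised text and image
    # tokens, assuming both CrossEntropyLoss terms use mean reduction and
    # the image code labels contain no ignore_index entries.
    # ------------------------------------------------------------------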

    def generate_mllm(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                images_aux
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                images_aux
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def test_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_multi_ids: torch.LongTensor = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if input_multi_ids is not None:
            input_multi_ids = input_multi_ids.unsqueeze(-1)  # [B, K, 1]
            input_ids = None  # [B, 1]
            inputs_embeds = self.model.multi_embedder(input_multi_ids)  # [B, 1, C]

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs

    def T2I_forward_nocache(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_multi_ids: torch.LongTensor = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_multi_ids is not None:
            inputs_text_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None  # [B, 1]
            inputs_image_embeds = self.model.multi_embedder(input_multi_ids)  # [B, 1, C]
            inputs_image_mask = torch.empty(inputs_image_embeds.shape[0], inputs_image_embeds.shape[1]).fill_(1).to(
                attention_mask)
            inputs_embeds = torch.cat([inputs_text_embeds, inputs_image_embeds], dim=1)
            attention_mask = torch.cat([attention_mask, inputs_image_mask], dim=1)
            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).repeat(
                inputs_embeds.shape[0], 1)
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs

    def T2I_forward_withcache(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_multi_ids: torch.LongTensor = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_multi_ids is not None:
            # Only the most recent image code is embedded; earlier steps are covered by the KV cache.
            inputs_image_embeds = self.model.multi_embedder(input_multi_ids[:, :, -1:])  # [B, 1, C]
            inputs_embeds = inputs_image_embeds
            input_ids = None  # [B, 1]
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs
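
    # Note (added for clarity; not part of the original file): for
    # text-to-image decoding, `T2I_forward_nocache` handles the first step
    # (prompt plus any already-generated image codes, no KV cache), while
    # `T2I_forward_withcache` embeds only the most recent code,
    # `input_multi_ids[:, :, -1:]`, and relies on `past_key_values` for the
    # earlier context; the per-step codebook indices themselves are presumably
    # sampled from `self.ar_head` by the caller.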

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        images_aux = kwargs.pop("images_aux", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        if images_aux is not None:
            _inputs['images_aux'] = images_aux
        return _inputs


AutoConfig.register("mini_gemini", MiniGeminiConfig)
AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiLlamaForCausalLM)
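

# ----------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). It assumes a locally saved checkpoint of this architecture with
# a matching tokenizer; the directory name below is a placeholder, and only
# the text-only branch of `generate` (images=None) is exercised.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/mini_gemini_checkpoint"  # hypothetical local path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = MiniGeminiLlamaForCausalLM.from_pretrained(checkpoint)
    model.eval()

    prompt = "Describe a sunset over the ocean."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        # With images=None, `generate` embeds the prompt tokens directly and
        # delegates to the standard HuggingFace generation loop.
        output_ids = model.generate(inputs=input_ids, images=None, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))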