# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
# Copyright 2024 Yanwei Li
# ------------------------------------------------------------------------

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging

from model.arhead import AR_head
from model.liquid import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM

logger = logging.get_logger(__name__)


class MiniGeminiConfig(LlamaConfig):
    model_type = "mini_gemini"


class MiniGeminiLlamaModel(MiniGeminiMetaModel, LlamaModel):
    config_class = MiniGeminiConfig

    def __init__(self, config: LlamaConfig):
        super(MiniGeminiLlamaModel, self).__init__(config)


class MiniGeminiLlamaForCausalLM(LlamaForCausalLM, MiniGeminiMetaForCausalLM):
    config_class = MiniGeminiConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)

        self.model = MiniGeminiLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Autoregressive head that predicts the multi-codebook codes of image tokens.
        self.ar_head = AR_head(self.config, codebook_size=32768, num_codebooks=8)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        additional_image_indexs = None
        if inputs_embeds is None and past_key_values is None:  # skipped during cached inference
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                data_types,
                additional_image_labels,
                additional_image_indexs,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                images_aux,
                data_types,
            )
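        # At this point `inputs_embeds` already contains the spliced-in image
        # embeddings. `additional_image_indexs` holds one (start, end) span per
        # text-to-image sample, marking where the target image tokens sit in
        # the sequence, and `additional_image_labels` holds the matching
        # multi-codebook targets (concatenated below into a [B, K, L] tensor:
        # K codebooks over L image positions).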
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        text_loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            text_loss = loss_fct(shift_logits, shift_labels)
            num_text_tokens = (shift_labels != -100).sum().item()

        if additional_image_indexs is None:
            # Inference path (no multimodal preparation ran): return the text loss only.
            return CausalLMOutputWithPast(
                loss=text_loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
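        # ------------------------------------------------------------------
        # Image-generation loss. For samples whose data_type marks them as
        # text-to-image, the LLM hidden state at each image position becomes
        # the "base token" that conditions the AR head, which predicts the K
        # codebook indices of the next position (teacher-forced with the first
        # K - 1 ground-truth codes). This path assumes `labels` was provided,
        # so `text_loss` and `num_text_tokens` are defined.
        # ------------------------------------------------------------------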
        to_image_mask = data_types == 1  # which samples in the batch get a t2i loss, e.g. [True, False, False, True, ...]
        if len(additional_image_indexs) > 0 and len(to_image_mask) == len(hidden_states):
            to_image_states = hidden_states[to_image_mask]
            if len(to_image_states) != len(additional_image_indexs):
                print('to_image_mask', to_image_mask)
                print('additional_image_indexs', additional_image_indexs)
            # Shift so that hidden states at positions < n predict the codes at n. [bz, seq_len, hidden_dim]
            shift_image_states = torch.stack([
                state[start_id - 1:end_id - 1]
                for (start_id, end_id), state in zip(additional_image_indexs, to_image_states)
            ])
            base_tokens = shift_image_states
            K = self.ar_head.num_codebooks
            B, L, C = base_tokens.shape
            base_tokens = base_tokens.reshape(B * L, 1, C)

            targets = torch.cat(additional_image_labels, dim=0)  # [B, K, L]
            image_code_labels = targets
            # Teacher forcing: embed the first K - 1 ground-truth codes as AR head inputs.
            targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1]
            index_embeddings = []
            for i in range(K - 1):
                index_embed = self.ar_head.codebooks[i](targets[:, i])
                index_embeddings.append(index_embed)
            index_embeddings = torch.stack(index_embeddings, dim=1)
            h = torch.cat((base_tokens, index_embeddings), dim=1)  # [B*L, K, C]

            multicode_embedding = self.ar_head(
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=h,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False,
                cache_position=None,
            )
            image_logits = self.ar_head.linear_head(multicode_embedding)
            image_logits = image_logits.reshape(B, L, K, -1).permute(0, 2, 1, 3)  # [B, K, L, sub_vocab_size]

            loss_fct = CrossEntropyLoss()
            image_logits = image_logits.reshape(-1, self.ar_head.sub_vocab_size)
            image_labels = image_code_labels.view(-1)
            image_labels = image_labels.to(image_logits.device)
            # z-loss on the max logit keeps the softmax normalizer from drifting.
            image_softmax_normalizer = image_logits.max(-1).values ** 2
            image_z_loss = 0.00005 * image_softmax_normalizer.mean()
            image_loss = loss_fct(image_logits, image_labels) + image_z_loss
            num_image_tokens = image_labels.shape[0]
        else:
            # No t2i sample in this batch (or a shape mismatch): run the AR head
            # on dummy ids and add a zero loss so all of its parameters still
            # receive gradients, which keeps DDP from hanging on unused parameters.
            if len(hidden_states) != len(to_image_mask):
                print('to_image_mask', to_image_mask)
                print('hidden_states', hidden_states.shape)
                print('inputs_embeds', inputs_embeds.shape)
                print('additional_image_indexs', additional_image_indexs)
            fake_ids = torch.ones(1, self.model.multi_embedder.num_codebooks - 1).to(inputs_embeds).long()
            index_embeddings = []
            for i in range(self.model.multi_embedder.num_codebooks - 1):
                index_embed = self.ar_head.codebooks[i](fake_ids[:, i])
                index_embeddings.append(index_embed)
            index_embeddings = torch.stack(index_embeddings, dim=1)
            multicode_embedding = self.ar_head(
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=index_embeddings,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False,
                cache_position=None,
            )
            image_logits = self.ar_head.linear_head(multicode_embedding)
            num_image_tokens = 0
            image_loss = (image_logits * 0).sum()

        # Mix the two losses in proportion to their token counts so every token
        # contributes equally regardless of modality.
        loss = image_loss * (num_image_tokens / (num_image_tokens + num_text_tokens)) + \
            text_loss * (num_text_tokens / (num_image_tokens + num_text_tokens))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate_mllm(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                images_aux,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                images_aux,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
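    # ----------------------------------------------------------------------
    # Inference helpers for image generation. `T2I_forward_nocache` runs the
    # prefill pass over the full prompt plus all image tokens generated so
    # far; `T2I_forward_withcache` embeds only the newest image token and
    # relies on the KV cache; `test_forward` embeds a single multi-codebook
    # id and runs one decoding step.
    # ----------------------------------------------------------------------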
    def test_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_multi_ids: torch.LongTensor = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if input_multi_ids is not None:
            input_multi_ids = input_multi_ids.unsqueeze(-1)  # [B, K, 1]
            input_ids = None  # [B, 1]
            inputs_embeds = self.model.multi_embedder(input_multi_ids)  # [B, 1, C]

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs

    def T2I_forward_nocache(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_multi_ids: torch.LongTensor = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_multi_ids is not None:
            inputs_text_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None
            inputs_image_embeds = self.model.multi_embedder(input_multi_ids)  # [B, L_img, C]
            # Extend the attention mask to cover the appended image embeddings.
            inputs_image_mask = torch.ones(
                inputs_image_embeds.shape[0], inputs_image_embeds.shape[1]).to(attention_mask)
            inputs_embeds = torch.cat([inputs_text_embeds, inputs_image_embeds], dim=1)
            attention_mask = torch.cat([attention_mask, inputs_image_mask], dim=1)
            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).repeat(
                inputs_embeds.shape[0], 1)
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs
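    # With a populated KV cache only the most recent image token has to be
    # embedded and fed through the model; all earlier positions are cached.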
    def T2I_forward_withcache(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_multi_ids: torch.LongTensor = None,
        data_types: torch.LongTensor = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_multi_ids is not None:
            # Embed only the newest multi-codebook id; earlier steps live in the cache.
            inputs_embeds = self.model.multi_embedder(input_multi_ids[:, :, -1:])  # [B, 1, C]
            input_ids = None
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)
            input_ids = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        images_aux = kwargs.pop("images_aux", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        if images_aux is not None:
            _inputs['images_aux'] = images_aux
        return _inputs


AutoConfig.register("mini_gemini", MiniGeminiConfig)
AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiLlamaForCausalLM)
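# ------------------------------------------------------------------------
# Usage sketch (not part of the original file). Once the registrations above
# have run, a checkpoint saved with model_type "mini_gemini" can be loaded
# through the Auto classes; the checkpoint path and input tensors below are
# placeholders.
#
# from transformers import AutoModelForCausalLM
#
# model = AutoModelForCausalLM.from_pretrained("/path/to/mini_gemini_checkpoint")
# output_ids = model.generate(
#     input_ids,              # tokenized prompt, [B, T]
#     images=images,          # preprocessed image tensor, or None for text-only
#     images_aux=images_aux,  # auxiliary high-resolution images, or None
#     attention_mask=attention_mask,
#     max_new_tokens=128,
# )
# ------------------------------------------------------------------------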