# coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from transformers import PretrainedConfig from transformers.models.llama.modeling_llama import ( LlamaFlashAttention2, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, apply_rotary_pos_emb, repeat_kv, ) from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import logging logger = logging.get_logger(__name__) class GemmaConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma-7B. e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 256000): Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`GemmaModel`] hidden_size (`int`, *optional*, defaults to 3072): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 24576): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 28): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*, defaults to 16): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. head_dim (`int`, *optional*, defaults to 256): The attention head dimension. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The legacy activation function. It is overwritten by the `hidden_activation`. hidden_activation (`str` or `function`, *optional*): The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. eos_token_id (`int`, *optional*, defaults to 1): End of stream token id. bos_token_id (`int`, *optional*, defaults to 2): Beginning of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. ```python >>> from transformers import GemmaModel, GemmaConfig >>> # Initializing a Gemma gemma-7b style configuration >>> configuration = GemmaConfig() >>> # Initializing a model from the gemma-7b style configuration >>> model = GemmaModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "gemma" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_act="gelu_pytorch_tanh", hidden_activation=None, max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.head_dim = head_dim self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.hidden_activation = hidden_activation self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.zeros(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()) # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 output = output * (1.0 + self.weight.float()) return output.type_as(x) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) class GemmaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] self.inv_freq.to(x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class GemmaMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) if config.hidden_activation is None: logger.warning_once( "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" "`config.hidden_activation` if you want to override this behaviour.\n" "See https://github.com/huggingface/transformers/pull/29402 for more details." ) config.hidden_activation = "gelu_pytorch_tanh" hidden_activation = config.hidden_activation self.act_fn = ACT2FN[hidden_activation] def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class GemmaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.scaling = 1 / math.sqrt(config.head_dim) if self.hidden_size % self.num_heads != 0: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) self.rotary_emb = GemmaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention? class GemmaFlashAttention2(LlamaFlashAttention2): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) output_attentions = False bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (GemmaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, sliding_window=getattr(self, "sliding_window", None), is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class GemmaModel(LlamaModel): def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False # noqa: F841 if ( use_cache and not isinstance(past_key_values, Cache) and not self.training ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # embed positions hidden_states = inputs_embeds # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer return super().forward( causal_mask, position_ids, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, input_ids=None, inputs_embeds=hidden_states, ) # Example where we ony modify the docstring and call super class GemmaForCausalLM(LlamaForCausalLM): def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") >>> prompt = "What is your favorite condiment?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class GemmaForSequenceClassification(LlamaForSequenceClassification): pass class GemmaForTokenClassification(LlamaForTokenClassification): pass