# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import PretrainedConfig
from transformers.models.llama.modeling_llama import (
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    apply_rotary_pos_emb,
    repeat_kv,
)

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging


logger = logging.get_logger(__name__)

class GemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`GemmaModel`].
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`. (An illustrative GQA-style configuration is sketched after this class.)
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
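
# A minimal illustrative sketch of the GQA note in the docstring above: with 8 query heads and
# 2 key/value heads, four query heads share each key/value head. Toy sizes, chosen only for the
# example.
#
# >>> cfg = GemmaConfig(
# ...     hidden_size=64,
# ...     intermediate_size=128,
# ...     num_hidden_layers=2,
# ...     num_attention_heads=8,
# ...     num_key_value_heads=2,
# ...     head_dim=8,
# ... )
# >>> cfg.num_attention_heads // cfg.num_key_value_heads  # query heads per key/value head
# 4
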
class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
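
# A minimal usage sketch (illustrative only): GemmaRMSNorm normalizes over the last dimension in
# float32 and applies the (1 + weight) scaling *before* casting back, which is the ordering the
# comment in `forward` refers to; the output dtype always matches the input dtype.
#
# >>> norm = GemmaRMSNorm(dim=4)
# >>> x = torch.randn(2, 3, 4, dtype=torch.float16)
# >>> norm(x).dtype
# torch.float16
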
class GemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
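
# A minimal usage sketch (illustrative only): cos/sin are computed in float32 regardless of the
# input dtype (see the comment in `forward`), have shape [batch_size, seq_len, dim], and are cast
# back to `x.dtype` on return.
#
# >>> rope = GemmaRotaryEmbedding(dim=8)
# >>> x = torch.randn(1, 2, 5, 8)              # [bs, num_attention_heads, seq_len, head_size]
# >>> position_ids = torch.arange(5)[None, :]  # [bs, seq_len]
# >>> cos, sin = rope(x, position_ids)
# >>> cos.shape
# torch.Size([1, 5, 8])
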
class GemmaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        if config.hidden_activation is None:
            logger.warning_once(
                "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
                "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
                "`config.hidden_activation` if you want to override this behaviour.\n"
                "See https://github.com/huggingface/transformers/pull/29402 for more details."
            )
            config.hidden_activation = "gelu_pytorch_tanh"
        hidden_activation = config.hidden_activation
        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
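
# A minimal usage sketch (illustrative only): the MLP is the gated form
# down_proj(act_fn(gate_proj(x)) * up_proj(x)), so the hidden size is preserved end to end.
# The toy config below only needs the fields GemmaMLP actually reads.
#
# >>> cfg = GemmaConfig(hidden_size=16, intermediate_size=32)
# >>> mlp = GemmaMLP(cfg)
# >>> mlp(torch.randn(2, 5, 16)).shape
# torch.Size([2, 5, 16])
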
class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
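
# A minimal shape sketch (illustrative only; no attention mask is passed, so this only checks
# shapes, not causality): with 4 query heads and 2 key/value heads the K/V projections produce
# fewer heads, and `repeat_kv` duplicates each of them `num_key_value_groups` times so the
# matmul against the queries lines up.
#
# >>> cfg = GemmaConfig(hidden_size=32, num_attention_heads=4, num_key_value_heads=2, head_dim=8)
# >>> attn = GemmaAttention(cfg, layer_idx=0)
# >>> attn.num_key_value_groups
# 2
# >>> out, _, _ = attn(torch.randn(1, 6, 32), position_ids=torch.arange(6)[None, :])
# >>> out.shape
# torch.Size([1, 6, 32])
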
# TODO felix: does this inheritance really work out in the end, with GemmaFlashAttention2 inheriting from GemmaAttention?
class GemmaFlashAttention2(LlamaFlashAttention2):
    """
    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim,
        # but we first build the usual [batch_size, num_heads, seq_length, head_dim] layout
        # for RoPE and the cache, and transpose back below.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transposes are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons;
        # therefore the input hidden states get silently casted in float32. Hence, we need to
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (GemmaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
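
# A minimal usage sketch (illustrative only): this forward path is only taken when the model is
# loaded with `attn_implementation="flash_attention_2"`, and, per the guard at the top of
# `forward`, it must be combined with a dynamic cache rather than a static one.
#
# >>> model = GemmaForCausalLM.from_pretrained(
# ...     "google/gemma-7b",
# ...     torch_dtype=torch.bfloat16,
# ...     attn_implementation="flash_attention_2",
# ... )  # doctest: +SKIP
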
class GemmaModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False  # noqa: F841
        if (
            use_cache and not isinstance(past_key_values, Cache) and not self.training
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True  # noqa: F841
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # Gemma downcasts the below to half precision (bfloat16 on the released checkpoints),
        # causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        return super().forward(
            causal_mask,
            position_ids,
            past_key_values,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
            input_ids=None,
            inputs_embeds=hidden_states,
        )
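
# A minimal numeric sketch (illustrative only) of the normalizer comment in `GemmaModel.forward`:
# materialising sqrt(hidden_size) in a half-precision dtype rounds the constant before the
# multiply.
#
# >>> torch.tensor(3072**0.5)                        # float32 keeps ~55.4256
# tensor(55.4256)
# >>> torch.tensor(3072**0.5, dtype=torch.bfloat16)  # bfloat16 rounds to 55.5
# tensor(55.5000, dtype=torch.bfloat16)
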
# Example where we only modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
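
# A minimal illustrative sketch of the label shifting above: position i predicts token i + 1, so
# the last logit and the first label are dropped before the flattened cross-entropy.
#
# >>> logits = torch.randn(1, 5, 256000)           # [batch, seq_len, vocab]
# >>> labels = torch.randint(0, 256000, (1, 5))
# >>> logits[..., :-1, :].shape, labels[..., 1:].shape
# (torch.Size([1, 4, 256000]), torch.Size([1, 4]))
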
class GemmaForSequenceClassification(LlamaForSequenceClassification):
    pass


class GemmaForTokenClassification(LlamaForTokenClassification):
    pass