gpt2 / modeling_gpt2.py
zifei9's picture
Update modeling_gpt2.py
8cc99f4 verified
raw
history blame
11.6 kB
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union
import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.utils import (
logging,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.models.gpt2.modeling_gpt2 import load_tf_weights_in_gpt2, eager_attention_forward, GPT2Block, GPT2MLP, GPT2SequenceSummary,GPT2PreTrainedModel,GPT2DoubleHeadsModelOuptut,GPT2DoubleHeadsModel, GPT2Model,GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification,GPT2ForTokenClassification,GPT2ForQuestionAnswering
# Module-level logger from the `transformers` logging utility; `GPT2Attention.forward`
# uses it to emit a one-time warning when it falls back to eager attention.
logger = logging.get_logger(__name__)
class GPT2Attention(nn.Module):
    """GPT-2 multi-head attention layer, usable for self-attention or cross-attention.

    For self-attention a single `c_attn` projection produces query, key and value. When
    constructed with `is_cross_attention=True`, `q_attn` projects the query from the decoder
    hidden states and `c_attn` projects key/value from `encoder_hidden_states`.

    Args:
        config: Model configuration. The attributes read here are `max_position_embeddings`,
            `hidden_size`, `num_attention_heads`, `scale_attn_weights`,
            `scale_attn_by_inverse_layer_idx`, `reorder_and_upcast_attn`, `attn_pdrop`,
            `resid_pdrop` and (in `forward`) `_attn_implementation`.
        is_cross_attention: Whether to create the separate `q_attn` projection needed for
            cross-attention.
        layer_idx: Index of this layer; passed to the cache on update and used for the optional
            inverse-layer-index scaling.

    Raises:
        ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config

        max_positions = config.max_position_embeddings
        # Lower-triangular boolean mask; sliced in `_upcast_and_reordered_attn` to apply causality.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Width of each of the q/k/v chunks in the fused `c_attn` output; shrinks when heads are pruned.
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            # Key/value come from the encoder (2 chunks); the query gets its own projection.
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # Fused query/key/value projection (3 chunks).
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the `c_attn` and `c_proj` projections.

        Args:
            heads: Collection of head indices to prune. An empty collection is a no-op.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # The same feature indices must be kept in each of the q, k and v chunks of `c_attn`.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Compute attention with the score matmul and softmax carried out in float32.

        Args:
            query: 4-D tensor unpacked as `(bsz, num_heads, q_seq_len, dk)`.
            key: 4-D tensor whose third dimension is the key sequence length.
            value: Tensor multiplied by the attention probabilities; its dtype is the output dtype.
            attention_mask: Optional additive mask added to the raw scores.
            head_mask: Optional multiplicative mask applied to the attention probabilities.

        Returns:
            Tuple `(attn_output, attn_weights)`; `attn_output` has dimensions 1 and 2 swapped
            relative to `attn_weights @ value`.

        Raises:
            RuntimeError: If the softmax output is unexpectedly not float32.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            # beta=0 means the (uninitialised) preallocated buffer contributes nothing to the result.
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            # Rows correspond to the last `query_length` positions of the key sequence.
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run attention over `hidden_states` (self-attention) or against `encoder_hidden_states`.

        Args:
            hidden_states: 3-D input; it is projected and split along dimension 2.
            past_key_value: Optional cache. If it is an `EncoderDecoderCache`, the cross- or
                self-attention sub-cache is selected to match the current mode. The cache is
                updated with this layer's key/value states and returns the states to attend over.
            cache_position: Forwarded to the cache update as `cache_position`.
            attention_mask: Optional mask for self-attention; replaced by
                `encoder_attention_mask` when cross-attending.
            head_mask: Optional per-head mask forwarded to the attention function.
            encoder_hidden_states: If given, the layer runs as cross-attention and key/value are
                projected from this tensor.
            encoder_attention_mask: Mask used in place of `attention_mask` for cross-attention.
            output_attentions: With the "sdpa" implementation, a truthy value forces the eager
                fallback.
            **kwargs: Extra arguments forwarded to the selected attention function.

        Returns:
            Tuple `(attn_output, attn_weights)`: the projected attention output after residual
            dropout, and the attention weights as returned by the attention function used.

        Raises:
            ValueError: If cross-attention is requested but the layer has no `q_attn` projection.
        """
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query_states = self.q_attn(hidden_states)
            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

        shape_q = (query_states.shape[0], query_states.shape[1], -1, self.head_dim)
        # Key/value must be reshaped with their OWN sequence length: in cross-attention they come
        # from the encoder, whose length can differ from the query's. Using the query length here
        # would either fail in `.view` or silently fold sequence positions into the head dimension.
        shape_kv = (key_states.shape[0], key_states.shape[1], -1, self.head_dim)

        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        if past_key_value is not None:
            if isinstance(past_key_value, EncoderDecoderCache):
                if is_cross_attention:
                    past_key_value = past_key_value.cross_attention_cache
                else:
                    past_key_value = past_key_value.self_attention_cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
            )

        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

        using_eager = self.config._attn_implementation == "eager"
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            # Note: the fallback is triggered by `output_attentions` OR a provided `head_mask`.
            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
                using_eager = True
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
                # not necessarily to eager (if mentioned options are provided).
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if using_eager and self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                head_mask=head_mask,
                dropout=self.attn_dropout.p if self.training else 0.0,
                is_causal=is_causal,
                **kwargs,
            )

        # Merge the trailing (head, head_dim) dimensions back into a single feature dimension.
        attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights
# Public names exported by a star-import of this module. Every entry is a name imported
# above from `transformers.models.gpt2.modeling_gpt2`; the `GPT2Attention` class defined
# in this file is not listed here.
__all__ = [
"GPT2DoubleHeadsModel",
"GPT2ForQuestionAnswering",
"GPT2ForSequenceClassification",
"GPT2ForTokenClassification",
"GPT2LMHeadModel",
"GPT2Model",
"GPT2PreTrainedModel",
"load_tf_weights_in_gpt2",
]