"""PyTorch OpenAI GPT-2 model.""" |
|
|
|
from dataclasses import dataclass |
|
from typing import Callable, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
|
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer |
|
from transformers.utils import ( |
|
logging, |
|
) |
|
from transformers.utils.deprecation import deprecate_kwarg |
|
from transformers.models.gpt2.modeling_gpt2 import load_tf_weights_in_gpt2, eager_attention_forward, GPT2Block, GPT2MLP, GPT2SequenceSummary,GPT2PreTrainedModel,GPT2DoubleHeadsModelOuptut,GPT2DoubleHeadsModel, GPT2Model,GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification,GPT2ForTokenClassification,GPT2ForQuestionAnswering |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class GPT2Attention(nn.Module): |
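    """Multi-head attention for GPT-2, usable as either self-attention or cross-attention.

    Projections use GPT-2's `Conv1D` layers, the causal mask is kept in the `bias` buffer, and the
    module dispatches at runtime to the eager path, SDPA, or another registered attention implementation.
    """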
|
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()
|
def prune_heads(self, heads): |
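        """Prune the attention heads whose indices are listed in `heads` and update the layer's bookkeeping."""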
|
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)
|
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): |
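        """Eager attention with the weights computed in float32 via `baddbmm`, used when `reorder_and_upcast_attn` is set."""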
|
|
|
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate the float32 weight matrix so `baddbmm` writes into it directly.
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Compute QK^T in float32 with autocast disabled, for numerical stability.
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # Only the self-attention layers apply the causal mask.
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # The mask value must be a tensor with the same dtype and device as `attn_weights`.
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the additive attention mask.
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast back to the value dtype (no-op unless running in mixed precision).
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if requested.
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights
|
    @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
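        """Run self- or cross-attention over `hidden_states`, optionally updating `past_key_value`, and return `(attn_output, attn_weights)`."""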
|
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query_states = self.q_attn(hidden_states)
            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Reshape to (batch, num_heads, seq_len, head_dim). Keys/values are shaped from their own
        # tensors, since their sequence length differs from the query length in the cross-attention case.
        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        if past_key_value is not None:
            if isinstance(past_key_value, EncoderDecoderCache):
                if is_cross_attention:
                    past_key_value = past_key_value.cross_attention_cache
                else:
                    past_key_value = past_key_value.self_attention_cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
            )

        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

        # Select the attention implementation; SDPA cannot return attention probabilities or apply a
        # head mask, so fall back to the eager path in those cases.
        using_eager = self.config._attn_implementation == "eager"
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
                using_eager = True
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if using_eager and self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                head_mask=head_mask,
                dropout=self.attn_dropout.p if self.training else 0.0,
                is_causal=is_causal,
                **kwargs,
            )

        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights
|
|
__all__ = [
    "GPT2DoubleHeadsModel",
    "GPT2ForQuestionAnswering",
    "GPT2ForSequenceClassification",
    "GPT2ForTokenClassification",
    "GPT2LMHeadModel",
    "GPT2Model",
    "GPT2PreTrainedModel",
    "load_tf_weights_in_gpt2",
]
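

if __name__ == "__main__":
    # Minimal usage sketch (not part of the module API): assumes the standard "gpt2" checkpoint can be
    # downloaded and that `transformers` provides `AutoTokenizer`, as it does upstream.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    outputs = model(**inputs)
    print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)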
|
|