# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn

from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.models.gpt2.modeling_gpt2 import (
    GPT2Block,
    GPT2DoubleHeadsModel,
    GPT2DoubleHeadsModelOutput,
    GPT2ForQuestionAnswering,
    GPT2ForSequenceClassification,
    GPT2ForTokenClassification,
    GPT2LMHeadModel,
    GPT2MLP,
    GPT2Model,
    GPT2PreTrainedModel,
    GPT2SequenceSummary,
    eager_attention_forward,
    load_tf_weights_in_gpt2,
)


logger = logging.get_logger(__name__)


class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query_states = self.q_attn(hidden_states)
            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Note: key/value shapes must come from `key_states` (not `query_states`), since the
        # encoder sequence length can differ from the decoder one in cross-attention.
        shape_q = (query_states.shape[0], query_states.shape[1], -1, self.head_dim)
        shape_kv = (key_states.shape[0], key_states.shape[1], -1, self.head_dim)

        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        if past_key_value is not None:
            if isinstance(past_key_value, EncoderDecoderCache):
                if is_cross_attention:
                    past_key_value = past_key_value.cross_attention_cache
                else:
                    past_key_value = past_key_value.self_attention_cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
            )

        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

        using_eager = self.config._attn_implementation == "eager"
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
                using_eager = True
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
                # not necessarily to eager (if mentioned options are provided).
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if using_eager and self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                head_mask=head_mask,
                dropout=self.attn_dropout.p if self.training else 0.0,
                is_causal=is_causal,
                **kwargs,
            )

        attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights


__all__ = [
    "GPT2DoubleHeadsModel",
    "GPT2ForQuestionAnswering",
    "GPT2ForSequenceClassification",
    "GPT2ForTokenClassification",
    "GPT2LMHeadModel",
    "GPT2Model",
    "GPT2PreTrainedModel",
    "load_tf_weights_in_gpt2",
]
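

# Minimal smoke test for the attention module defined above. This is an illustrative
# sketch, not part of the upstream API: it assumes a small `GPT2Config` with default
# settings (eager attention, no reordering/upcasting) and random hidden states, and
# simply checks that a self-attention forward pass preserves the input shape.
if __name__ == "__main__":
    from transformers import GPT2Config

    # Hypothetical toy configuration: 64-dim embeddings, 4 heads, 128 max positions.
    config = GPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=128)
    attention = GPT2Attention(config, layer_idx=0).eval()

    batch_size, seq_len = 2, 16
    hidden_states = torch.randn(batch_size, seq_len, config.n_embd)

    with torch.no_grad():
        attn_output, attn_weights = attention(hidden_states)

    # Output keeps the (batch, seq_len, hidden_size) shape of the input.
    print(attn_output.shape)  # torch.Size([2, 16, 64])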