zifei9 committed
Commit 6237f55 · verified
1 Parent(s): 8cc99f4

Update modeling_gpt2.py


updating based on transformers==4.52.4
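
For context, a minimal sketch of how a checkpoint that ships this custom `modeling_gpt2.py` could be loaded against the pinned library version. The repo id below is a placeholder (not taken from this commit), and it assumes the repo is set up for remote code (an `auto_map` entry in its config):

```python
# Sketch only: pin the transformers version this file targets, then load the
# repo's custom GPT-2 code via trust_remote_code.
#   pip install "transformers==4.52.4" torch

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-gpt2-repo"  # placeholder -- substitute the actual Hub repo

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

`trust_remote_code=True` is what makes `transformers` import the repo's `modeling_gpt2.py` instead of its built-in GPT-2 implementation, which is why this file has to track the pinned library API (4.52.4 here).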

Files changed (1)
  1. modeling_gpt2.py +1664 -23
modeling_gpt2.py CHANGED
@@ -15,20 +15,50 @@
15
  # limitations under the License.
16
  """PyTorch OpenAI GPT-2 model."""
17
 
18
  from dataclasses import dataclass
19
  from typing import Callable, Optional, Tuple, Union
20
 
21
  import torch
22
  from torch import nn
 
23
 
 
24
  from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
25
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
26
- from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
27
  from transformers.utils import (
28
  logging,
29
  )
30
  from transformers.utils.deprecation import deprecate_kwarg
31
- from transformers.models.gpt2.modeling_gpt2 import load_tf_weights_in_gpt2, eager_attention_forward, GPT2Block, GPT2MLP, GPT2SequenceSummary,GPT2PreTrainedModel,GPT2DoubleHeadsModelOuptut,GPT2DoubleHeadsModel, GPT2Model,GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification,GPT2ForTokenClassification,GPT2ForQuestionAnswering
32
 
33
  logger = logging.get_logger(__name__)
34
 
@@ -40,9 +70,9 @@ class GPT2Attention(nn.Module):
40
  max_positions = config.max_position_embeddings
41
  self.register_buffer(
42
  "bias",
43
- torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
44
- 1, 1, max_positions, max_positions
45
- ),
46
  persistent=False,
47
  )
48
  self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
@@ -81,25 +111,39 @@ class GPT2Attention(nn.Module):
81
  def prune_heads(self, heads):
82
  if len(heads) == 0:
83
  return
84
- heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
85
- index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
86
 
87
  # Prune conv1d layers
88
  self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
89
  self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
90
 
91
  # Update hyper params
92
- self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
93
  self.num_heads = self.num_heads - len(heads)
94
  self.pruned_heads = self.pruned_heads.union(heads)
95
 
96
- def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
97
  # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
98
  bsz, num_heads, q_seq_len, dk = query.size()
99
  _, _, k_seq_len, _ = key.size()
100
 
101
  # Preallocate attn_weights for `baddbmm`
102
- attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
103
 
104
  # Compute Scale Factor
105
  scale_factor = 1.0
@@ -111,18 +155,26 @@ class GPT2Attention(nn.Module):
111
 
112
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
113
  with torch.amp.autocast(query.device.type, enabled=False):
114
- q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
115
- attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
116
  attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
117
 
118
  if not self.is_cross_attention:
119
  # if only "normal" attention layer implements causal mask
120
  query_length, key_length = query.size(-2), key.size(-2)
121
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
122
  mask_value = torch.finfo(attn_weights.dtype).min
123
  # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
124
  # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
125
- mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
126
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
127
 
128
  if attention_mask is not None:
@@ -133,7 +185,9 @@ class GPT2Attention(nn.Module):
133
 
134
  # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
135
  if attn_weights.dtype != torch.float32:
136
- raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
137
  attn_weights = attn_weights.type(value.dtype)
138
  attn_weights = self.attn_dropout(attn_weights)
139
 
@@ -146,7 +200,12 @@ class GPT2Attention(nn.Module):
146
 
147
  return attn_output, attn_weights
148
 
149
- @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
150
  def forward(
151
  self,
152
  hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -168,13 +227,17 @@ class GPT2Attention(nn.Module):
168
  )
169
 
170
  query_states = self.q_attn(hidden_states)
171
- key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
172
  attention_mask = encoder_attention_mask
173
  else:
174
- query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
175
 
176
  shape_q = (query_states.shape[0],query_states.shape[1], -1, self.head_dim)
177
- shape_kv = (query_states.shape[0],query_states.shape[1], -1, self.head_dim)
178
 
179
  query_states = query_states.view(shape_q).transpose(1, 2)
180
  key_states = key_states.view(shape_kv).transpose(1, 2)
@@ -191,12 +254,18 @@ class GPT2Attention(nn.Module):
191
  key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
192
  )
193
 
194
- is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
195
 
196
  using_eager = self.config._attn_implementation == "eager"
197
  attention_interface: Callable = eager_attention_forward
198
  if self.config._attn_implementation != "eager":
199
- if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
200
  using_eager = True
201
  logger.warning_once(
202
  "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
@@ -206,7 +275,9 @@ class GPT2Attention(nn.Module):
206
  # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
207
  # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
208
  # not necessarily to eager (if mentioned options are provided).
209
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
210
 
211
  if using_eager and self.reorder_and_upcast_attn:
212
  attn_output, attn_weights = self._upcast_and_reordered_attn(
@@ -231,6 +302,1576 @@ class GPT2Attention(nn.Module):
231
 
232
  return attn_output, attn_weights
233
234
  __all__ = [
235
  "GPT2DoubleHeadsModel",
236
  "GPT2ForQuestionAnswering",
 
15
  # limitations under the License.
16
  """PyTorch OpenAI GPT-2 model."""
17
 
18
+ import math
19
+ import os
20
+ import warnings
21
  from dataclasses import dataclass
22
  from typing import Callable, Optional, Tuple, Union
23
 
24
  import torch
25
  from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
+ from transformers.activations import ACT2FN, get_activation
29
  from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import (
32
+ AttentionMaskConverter,
33
+ _prepare_4d_attention_mask_for_sdpa,
34
+ )
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutputWithPastAndCrossAttentions,
37
+ CausalLMOutputWithCrossAttentions,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutputWithPast,
40
+ TokenClassifierOutput,
41
+ )
42
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
+ from transformers.pytorch_utils import (
44
+ Conv1D,
45
+ find_pruneable_heads_and_indices,
46
+ prune_conv1d_layer,
47
+ )
48
  from transformers.utils import (
49
+ ModelOutput,
50
+ add_start_docstrings,
51
+ auto_docstring,
52
  logging,
53
  )
54
  from transformers.utils.deprecation import deprecate_kwarg
55
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
56
+ from .configuration_gpt2 import GPT2Config
57
+ from transformers.models.gpt2.modeling_gpt2 import (
58
+ load_tf_weights_in_gpt2,
59
+ eager_attention_forward,
60
+ )
61
+
62
 
63
  logger = logging.get_logger(__name__)
64
 
 
70
  max_positions = config.max_position_embeddings
71
  self.register_buffer(
72
  "bias",
73
+ torch.tril(
74
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
75
+ ).view(1, 1, max_positions, max_positions),
76
  persistent=False,
77
  )
78
  self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
 
111
  def prune_heads(self, heads):
112
  if len(heads) == 0:
113
  return
114
+ heads, index = find_pruneable_heads_and_indices(
115
+ heads, self.num_heads, self.head_dim, self.pruned_heads
116
+ )
117
+ index_attn = torch.cat(
118
+ [index, index + self.split_size, index + (2 * self.split_size)]
119
+ )
120
 
121
  # Prune conv1d layers
122
  self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
123
  self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
124
 
125
  # Update hyper params
126
+ self.split_size = (self.split_size // self.num_heads) * (
127
+ self.num_heads - len(heads)
128
+ )
129
  self.num_heads = self.num_heads - len(heads)
130
  self.pruned_heads = self.pruned_heads.union(heads)
131
 
132
+ def _upcast_and_reordered_attn(
133
+ self, query, key, value, attention_mask=None, head_mask=None
134
+ ):
135
  # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
136
  bsz, num_heads, q_seq_len, dk = query.size()
137
  _, _, k_seq_len, _ = key.size()
138
 
139
  # Preallocate attn_weights for `baddbmm`
140
+ attn_weights = torch.empty(
141
+ bsz * num_heads,
142
+ q_seq_len,
143
+ k_seq_len,
144
+ dtype=torch.float32,
145
+ device=query.device,
146
+ )
147
 
148
  # Compute Scale Factor
149
  scale_factor = 1.0
 
155
 
156
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
157
  with torch.amp.autocast(query.device.type, enabled=False):
158
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
159
+ -1, dk, k_seq_len
160
+ )
161
+ attn_weights = torch.baddbmm(
162
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
163
+ )
164
  attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
165
 
166
  if not self.is_cross_attention:
167
  # if only "normal" attention layer implements causal mask
168
  query_length, key_length = query.size(-2), key.size(-2)
169
+ causal_mask = self.bias[
170
+ :, :, key_length - query_length : key_length, :key_length
171
+ ]
172
  mask_value = torch.finfo(attn_weights.dtype).min
173
  # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
174
  # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
175
+ mask_value = torch.tensor(
176
+ mask_value, dtype=attn_weights.dtype, device=attn_weights.device
177
+ )
178
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
179
 
180
  if attention_mask is not None:
 
185
 
186
  # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
187
  if attn_weights.dtype != torch.float32:
188
+ raise RuntimeError(
189
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
190
+ )
191
  attn_weights = attn_weights.type(value.dtype)
192
  attn_weights = self.attn_dropout(attn_weights)
193
 
 
200
 
201
  return attn_output, attn_weights
202
 
203
+ @deprecate_kwarg(
204
+ "layer_past",
205
+ new_name="past_key_value",
206
+ version="4.53.0",
207
+ raise_if_both_names=True,
208
+ )
209
  def forward(
210
  self,
211
  hidden_states: Optional[Tuple[torch.FloatTensor]],
 
227
  )
228
 
229
  query_states = self.q_attn(hidden_states)
230
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(
231
+ self.split_size, dim=2
232
+ )
233
  attention_mask = encoder_attention_mask
234
  else:
235
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(
236
+ self.split_size, dim=2
237
+ )
238
 
239
  shape_q = (query_states.shape[0],query_states.shape[1], -1, self.head_dim)
240
+ shape_kv = (key_states.shape[0], key_states.shape[1],-1, self.head_dim)
241
 
242
  query_states = query_states.view(shape_q).transpose(1, 2)
243
  key_states = key_states.view(shape_kv).transpose(1, 2)
 
254
  key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
255
  )
256
 
257
+ is_causal = (
258
+ attention_mask is None
259
+ and query_states.shape[-2] > 1
260
+ and not is_cross_attention
261
+ )
262
 
263
  using_eager = self.config._attn_implementation == "eager"
264
  attention_interface: Callable = eager_attention_forward
265
  if self.config._attn_implementation != "eager":
266
+ if self.config._attn_implementation == "sdpa" and (
267
+ output_attentions or head_mask is not None
268
+ ):
269
  using_eager = True
270
  logger.warning_once(
271
  "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
 
275
  # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
276
  # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
277
  # not necessarily to eager (if mentioned options are provided).
278
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
279
+ self.config._attn_implementation
280
+ ]
281
 
282
  if using_eager and self.reorder_and_upcast_attn:
283
  attn_output, attn_weights = self._upcast_and_reordered_attn(
 
302
 
303
  return attn_output, attn_weights
304
 
305
+
306
+ class GPT2MLP(nn.Module):
307
+ def __init__(self, intermediate_size, config):
308
+ super().__init__()
309
+ embed_dim = config.hidden_size
310
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
311
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
312
+ self.act = ACT2FN[config.activation_function]
313
+ self.dropout = nn.Dropout(config.resid_pdrop)
314
+
315
+ def forward(
316
+ self, hidden_states: Optional[Tuple[torch.FloatTensor]]
317
+ ) -> torch.FloatTensor:
318
+ hidden_states = self.c_fc(hidden_states)
319
+ hidden_states = self.act(hidden_states)
320
+ hidden_states = self.c_proj(hidden_states)
321
+ hidden_states = self.dropout(hidden_states)
322
+ return hidden_states
323
+
324
+
325
+ class GPT2Block(nn.Module):
326
+ def __init__(self, config, layer_idx=None):
327
+ super().__init__()
328
+ hidden_size = config.hidden_size
329
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
330
+
331
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
332
+ self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
333
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
334
+
335
+ if config.add_cross_attention:
336
+ self.crossattention = GPT2Attention(
337
+ config=config, is_cross_attention=True, layer_idx=layer_idx
338
+ )
339
+ self.ln_cross_attn = nn.LayerNorm(
340
+ hidden_size, eps=config.layer_norm_epsilon
341
+ )
342
+
343
+ self.mlp = GPT2MLP(inner_dim, config)
344
+
345
+ @deprecate_kwarg(
346
+ "layer_past",
347
+ new_name="past_key_value",
348
+ version="4.53.0",
349
+ raise_if_both_names=True,
350
+ )
351
+ def forward(
352
+ self,
353
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
354
+ past_key_value: Optional[Cache] = None,
355
+ cache_position: Optional[torch.LongTensor] = None,
356
+ attention_mask: Optional[torch.FloatTensor] = None,
357
+ head_mask: Optional[torch.FloatTensor] = None,
358
+ encoder_hidden_states: Optional[torch.Tensor] = None,
359
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
360
+ use_cache: Optional[bool] = False,
361
+ output_attentions: Optional[bool] = False,
362
+ **kwargs,
363
+ ) -> Union[
364
+ Tuple[torch.Tensor],
365
+ Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
366
+ ]:
367
+ residual = hidden_states
368
+ hidden_states = self.ln_1(hidden_states)
369
+ attn_output, self_attn_weights = self.attn(
370
+ hidden_states,
371
+ past_key_value=past_key_value,
372
+ cache_position=cache_position,
373
+ attention_mask=attention_mask,
374
+ head_mask=head_mask,
375
+ use_cache=use_cache,
376
+ output_attentions=output_attentions,
377
+ **kwargs,
378
+ )
379
+ # residual connection
380
+ hidden_states = attn_output + residual
381
+
382
+ if encoder_hidden_states is not None:
383
+ # add one self-attention block for cross-attention
384
+ if not hasattr(self, "crossattention"):
385
+ raise ValueError(
386
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
387
+ "cross-attention layers by setting `config.add_cross_attention=True`"
388
+ )
389
+ residual = hidden_states
390
+ hidden_states = self.ln_cross_attn(hidden_states)
391
+ cross_attn_output, cross_attn_weights = self.crossattention(
392
+ hidden_states,
393
+ past_key_value=past_key_value,
394
+ attention_mask=attention_mask,
395
+ head_mask=head_mask,
396
+ encoder_hidden_states=encoder_hidden_states,
397
+ encoder_attention_mask=encoder_attention_mask,
398
+ output_attentions=output_attentions,
399
+ )
400
+ # residual connection
401
+ hidden_states = residual + cross_attn_output
402
+
403
+ residual = hidden_states
404
+ hidden_states = self.ln_2(hidden_states)
405
+ feed_forward_hidden_states = self.mlp(hidden_states)
406
+ # residual connection
407
+ hidden_states = residual + feed_forward_hidden_states
408
+
409
+ outputs = (hidden_states,)
410
+ if output_attentions:
411
+ outputs += (self_attn_weights,)
412
+ if encoder_hidden_states is not None:
413
+ outputs += (cross_attn_weights,)
414
+
415
+ return outputs
416
+
417
+
418
+ # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
419
+ class GPT2SequenceSummary(nn.Module):
420
+ r"""
421
+ Compute a single vector summary of a sequence hidden states.
422
+
423
+ Args:
424
+ config ([`GPT2Config`]):
425
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
426
+ config class of your model for the default values it uses):
427
+
428
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
429
+
430
+ - `"last"` -- Take the last token hidden state (like XLNet)
431
+ - `"first"` -- Take the first token hidden state (like Bert)
432
+ - `"mean"` -- Take the mean of all tokens hidden states
433
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
434
+ - `"attn"` -- Not implemented now, use multi-head attention
435
+
436
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
437
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
438
+ (otherwise to `config.hidden_size`).
439
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
440
+ another string or `None` will add no activation.
441
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
442
+ - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
443
+ """
444
+
445
+ def __init__(self, config: GPT2Config):
446
+ super().__init__()
447
+
448
+ self.summary_type = getattr(config, "summary_type", "last")
449
+ if self.summary_type == "attn":
450
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
451
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
452
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
453
+ raise NotImplementedError
454
+
455
+ self.summary = nn.Identity()
456
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
457
+ if (
458
+ hasattr(config, "summary_proj_to_labels")
459
+ and config.summary_proj_to_labels
460
+ and config.num_labels > 0
461
+ ):
462
+ num_classes = config.num_labels
463
+ else:
464
+ num_classes = config.hidden_size
465
+ self.summary = nn.Linear(config.hidden_size, num_classes)
466
+
467
+ activation_string = getattr(config, "summary_activation", None)
468
+ self.activation: Callable = (
469
+ get_activation(activation_string) if activation_string else nn.Identity()
470
+ )
471
+
472
+ self.first_dropout = nn.Identity()
473
+ if (
474
+ hasattr(config, "summary_first_dropout")
475
+ and config.summary_first_dropout > 0
476
+ ):
477
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
478
+
479
+ self.last_dropout = nn.Identity()
480
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
481
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
482
+
483
+ def forward(
484
+ self,
485
+ hidden_states: torch.FloatTensor,
486
+ cls_index: Optional[torch.LongTensor] = None,
487
+ ) -> torch.FloatTensor:
488
+ """
489
+ Compute a single vector summary of a sequence hidden states.
490
+
491
+ Args:
492
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
493
+ The hidden states of the last layer.
494
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
495
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
496
+
497
+ Returns:
498
+ `torch.FloatTensor`: The summary of the sequence hidden states.
499
+ """
500
+ if self.summary_type == "last":
501
+ output = hidden_states[:, -1]
502
+ elif self.summary_type == "first":
503
+ output = hidden_states[:, 0]
504
+ elif self.summary_type == "mean":
505
+ output = hidden_states.mean(dim=1)
506
+ elif self.summary_type == "cls_index":
507
+ if cls_index is None:
508
+ cls_index = torch.full_like(
509
+ hidden_states[..., :1, :],
510
+ hidden_states.shape[-2] - 1,
511
+ dtype=torch.long,
512
+ )
513
+ else:
514
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
515
+ cls_index = cls_index.expand(
516
+ (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
517
+ )
518
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
519
+ output = hidden_states.gather(-2, cls_index).squeeze(
520
+ -2
521
+ ) # shape (bsz, XX, hidden_size)
522
+ elif self.summary_type == "attn":
523
+ raise NotImplementedError
524
+
525
+ output = self.first_dropout(output)
526
+ output = self.summary(output)
527
+ output = self.activation(output)
528
+ output = self.last_dropout(output)
529
+
530
+ return output
531
+
532
+
533
+ @auto_docstring
534
+ class GPT2PreTrainedModel(PreTrainedModel):
535
+ config_class = GPT2Config
536
+ load_tf_weights = load_tf_weights_in_gpt2
537
+ base_model_prefix = "transformer"
538
+ is_parallelizable = True
539
+ supports_gradient_checkpointing = True
540
+ _no_split_modules = ["GPT2Block"]
541
+ _skip_keys_device_placement = "past_key_values"
542
+ _supports_flash_attn_2 = True
543
+ _supports_sdpa = True
544
+ _supports_attention_backend = True
545
+ _supports_cache_class = True
546
+ _supports_static_cache = True
547
+
548
+ def __init__(self, *inputs, **kwargs):
549
+ super().__init__(*inputs, **kwargs)
550
+
551
+ def _init_weights(self, module):
552
+ """Initialize the weights."""
553
+ if isinstance(module, (nn.Linear, Conv1D)):
554
+ # Slightly different from the TF version which uses truncated_normal for initialization
555
+ # cf https://github.com/pytorch/pytorch/pull/5617
556
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
557
+ if module.bias is not None:
558
+ module.bias.data.zero_()
559
+ elif isinstance(module, nn.Embedding):
560
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
561
+ if module.padding_idx is not None:
562
+ module.weight.data[module.padding_idx].zero_()
563
+ elif isinstance(module, nn.LayerNorm):
564
+ module.bias.data.zero_()
565
+ module.weight.data.fill_(1.0)
566
+
567
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
568
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
569
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
570
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
571
+ #
572
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
573
+ for name, p in module.named_parameters():
574
+ if name == "c_proj.weight":
575
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
576
+ p.data.normal_(
577
+ mean=0.0,
578
+ std=(
579
+ self.config.initializer_range
580
+ / math.sqrt(2 * self.config.n_layer)
581
+ ),
582
+ )
583
+
584
+
585
+ @dataclass
586
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
587
+ """
588
+ Base class for outputs of models predicting if two sentences are consecutive or not.
589
+
590
+ Args:
591
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
592
+ Language modeling loss.
593
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
594
+ Multiple choice classification loss.
595
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
596
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
597
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
598
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
599
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
600
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
601
+ sequence_length, embed_size_per_head)`).
602
+
603
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
604
+ `past_key_values` input) to speed up sequential decoding.
605
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
606
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
607
+ shape `(batch_size, sequence_length, hidden_size)`.
608
+
609
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
610
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
611
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
612
+ sequence_length)`.
613
+
614
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
615
+ self-attention heads.
616
+ """
617
+
618
+ loss: Optional[torch.FloatTensor] = None
619
+ mc_loss: Optional[torch.FloatTensor] = None
620
+ logits: Optional[torch.FloatTensor] = None
621
+ mc_logits: Optional[torch.FloatTensor] = None
622
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
623
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
624
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
625
+
626
+
627
+ PARALLELIZE_DOCSTRING = r"""
628
+ This is an experimental feature and is a subject to change at a moment's notice.
629
+
630
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
631
+ it will evenly distribute blocks across all devices.
632
+
633
+ Args:
634
+ device_map (`Dict[int, list]`, *optional*):
635
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
636
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
637
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
638
+ following number of attention modules:
639
+
640
+ - openai-community/gpt2: 12
641
+ - openai-community/gpt2-medium: 24
642
+ - openai-community/gpt2-large: 36
643
+ - openai-community/gpt2-xl: 48
644
+
645
+ Example:
646
+
647
+ ```python
648
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
649
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
650
+ device_map = {
651
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
652
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
653
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
654
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
655
+ }
656
+ model.parallelize(device_map)
657
+ ```
658
+ """
659
+ DEPARALLELIZE_DOCSTRING = r"""
660
+ Moves the model to cpu from a model parallel state.
661
+
662
+ Example:
663
+
664
+ ```python
665
+ # On a 4 GPU machine with openai-community/gpt2-large:
666
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
667
+ device_map = {
668
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
669
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
670
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
671
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
672
+ }
673
+ model.parallelize(device_map) # Splits the model across several devices
674
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
675
+ ```
676
+ """
677
+
678
+
679
+ @auto_docstring
680
+ class GPT2Model(GPT2PreTrainedModel):
681
+ _supports_param_buffer_assignment = False
682
+
683
+ def __init__(self, config):
684
+ super().__init__(config)
685
+
686
+ self.embed_dim = config.hidden_size
687
+
688
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
689
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
690
+
691
+ self.drop = nn.Dropout(config.embd_pdrop)
692
+ self.h = nn.ModuleList(
693
+ [GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
694
+ )
695
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
696
+
697
+ # Model parallel
698
+ self.model_parallel = False
699
+ self.device_map = None
700
+ self.gradient_checkpointing = False
701
+ self._attn_implementation = config._attn_implementation
702
+
703
+ # Initialize weights and apply final processing
704
+ self.post_init()
705
+
706
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
707
+ def parallelize(self, device_map=None):
708
+ # Check validity of device_map
709
+ warnings.warn(
710
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
711
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
712
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
713
+ " ...}",
714
+ FutureWarning,
715
+ )
716
+ self.device_map = (
717
+ get_device_map(len(self.h), range(torch.cuda.device_count()))
718
+ if device_map is None
719
+ else device_map
720
+ )
721
+ assert_device_map(self.device_map, len(self.h))
722
+ self.model_parallel = True
723
+ self.first_device = (
724
+ "cpu"
725
+ if "cpu" in self.device_map.keys()
726
+ else "cuda:" + str(min(self.device_map.keys()))
727
+ )
728
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
729
+ self.wte = self.wte.to(self.first_device)
730
+ self.wpe = self.wpe.to(self.first_device)
731
+ # Load onto devices
732
+ for k, v in self.device_map.items():
733
+ for block in v:
734
+ cuda_device = "cuda:" + str(k)
735
+ self.h[block] = self.h[block].to(cuda_device)
736
+ # ln_f to last
737
+ self.ln_f = self.ln_f.to(self.last_device)
738
+
739
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
740
+ def deparallelize(self):
741
+ warnings.warn(
742
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
743
+ FutureWarning,
744
+ )
745
+ self.model_parallel = False
746
+ self.device_map = None
747
+ self.first_device = "cpu"
748
+ self.last_device = "cpu"
749
+ self.wte = self.wte.to("cpu")
750
+ self.wpe = self.wpe.to("cpu")
751
+ for index in range(len(self.h)):
752
+ self.h[index] = self.h[index].to("cpu")
753
+ self.ln_f = self.ln_f.to("cpu")
754
+ torch.cuda.empty_cache()
755
+
756
+ def get_input_embeddings(self):
757
+ return self.wte
758
+
759
+ def set_input_embeddings(self, new_embeddings):
760
+ self.wte = new_embeddings
761
+
762
+ def _prune_heads(self, heads_to_prune):
763
+ """
764
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
765
+ """
766
+ for layer, heads in heads_to_prune.items():
767
+ self.h[layer].attn.prune_heads(heads)
768
+
769
+ @auto_docstring
770
+ def forward(
771
+ self,
772
+ input_ids: Optional[torch.LongTensor] = None,
773
+ past_key_values: Optional[Union[Tuple[Tuple[torch.Tensor]], Cache]] = None,
774
+ cache_position: Optional[torch.LongTensor] = None,
775
+ attention_mask: Optional[torch.FloatTensor] = None,
776
+ token_type_ids: Optional[torch.LongTensor] = None,
777
+ position_ids: Optional[torch.LongTensor] = None,
778
+ head_mask: Optional[torch.FloatTensor] = None,
779
+ inputs_embeds: Optional[torch.FloatTensor] = None,
780
+ encoder_hidden_states: Optional[torch.Tensor] = None,
781
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
782
+ use_cache: Optional[bool] = None,
783
+ output_attentions: Optional[bool] = None,
784
+ output_hidden_states: Optional[bool] = None,
785
+ return_dict: Optional[bool] = None,
786
+ **kwargs,
787
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
788
+ r"""
789
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
790
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
791
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
792
+ sequence tokens in the vocabulary.
793
+
794
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
795
+ `input_ids`.
796
+
797
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
798
+ [`PreTrainedTokenizer.__call__`] for details.
799
+
800
+ [What are input IDs?](../glossary#input-ids)
801
+ """
802
+ output_attentions = (
803
+ output_attentions
804
+ if output_attentions is not None
805
+ else self.config.output_attentions
806
+ )
807
+ output_hidden_states = (
808
+ output_hidden_states
809
+ if output_hidden_states is not None
810
+ else self.config.output_hidden_states
811
+ )
812
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
813
+ return_dict = (
814
+ return_dict if return_dict is not None else self.config.use_return_dict
815
+ )
816
+
817
+ if input_ids is not None and inputs_embeds is not None:
818
+ raise ValueError(
819
+ "You cannot specify both input_ids and inputs_embeds at the same time"
820
+ )
821
+ elif input_ids is not None:
822
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
823
+ input_shape = input_ids.size()
824
+ input_ids = input_ids.view(-1, input_shape[-1])
825
+ batch_size = input_ids.shape[0]
826
+ elif inputs_embeds is not None:
827
+ input_shape = inputs_embeds.size()[:-1]
828
+ batch_size = inputs_embeds.shape[0]
829
+ else:
830
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
831
+
832
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
833
+
834
+ if token_type_ids is not None:
835
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
836
+
837
+ if self.gradient_checkpointing and self.training:
838
+ if use_cache:
839
+ logger.warning_once(
840
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
841
+ )
842
+ use_cache = False
843
+
844
+ # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
845
+ return_legacy_cache = False
846
+ if use_cache:
847
+ if past_key_values is None:
848
+ return_legacy_cache = True
849
+ past_key_values = DynamicCache()
850
+ elif not isinstance(past_key_values, Cache):
851
+ return_legacy_cache = True
852
+ logger.warning_once(
853
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
854
+ "You should pass an instance of `Cache` instead, e.g. "
855
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
856
+ )
857
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
858
+
859
+ if self.config.add_cross_attention and not isinstance(
860
+ past_key_values, EncoderDecoderCache
861
+ ):
862
+ past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
863
+
864
+ if inputs_embeds is None:
865
+ inputs_embeds = self.wte(input_ids)
866
+
867
+ if cache_position is None:
868
+ past_seen_tokens = (
869
+ past_key_values.get_seq_length() if past_key_values is not None else 0
870
+ )
871
+ cache_position = torch.arange(
872
+ past_seen_tokens,
873
+ past_seen_tokens + inputs_embeds.shape[1],
874
+ device=inputs_embeds.device,
875
+ )
876
+ if position_ids is None:
877
+ position_ids = cache_position.unsqueeze(0)
878
+
879
+ position_embeds = self.wpe(position_ids)
880
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
881
+
882
+ # Attention mask.
883
+ # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
884
+ if attention_mask is not None and attention_mask.ndim < 4:
885
+ attention_mask = attention_mask.view(batch_size, -1)
886
+ causal_mask = self._update_causal_mask(
887
+ attention_mask,
888
+ inputs_embeds,
889
+ cache_position,
890
+ past_key_values,
891
+ output_attentions,
892
+ )
893
+
894
+ # If a 2D or 3D attention mask is provided for the cross-attention
895
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
896
+ _use_sdpa = (
897
+ self._attn_implementation == "sdpa"
898
+ and output_attentions is False
899
+ and head_mask is None
900
+ )
901
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
902
+ encoder_batch_size, encoder_sequence_length, _ = (
903
+ encoder_hidden_states.size()
904
+ )
905
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
906
+ if encoder_attention_mask is None:
907
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
908
+ if _use_sdpa:
909
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
910
+ mask=encoder_attention_mask,
911
+ dtype=inputs_embeds.dtype,
912
+ tgt_len=input_shape[-1],
913
+ )
914
+ elif not self._attn_implementation == "flash_attention_2":
915
+ encoder_attention_mask = self.invert_attention_mask(
916
+ encoder_attention_mask
917
+ )
918
+ else:
919
+ encoder_attention_mask = None
920
+
921
+ # Prepare head mask if needed
922
+ # 1.0 in head_mask indicate we keep the head
923
+ # attention_probs has shape bsz x n_heads x N x N
924
+ # head_mask has shape n_layer x batch x n_heads x N x N
925
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
926
+
927
+ if token_type_ids is not None:
928
+ token_type_embeds = self.wte(token_type_ids)
929
+ hidden_states = hidden_states + token_type_embeds
930
+
931
+ hidden_states = self.drop(hidden_states)
932
+
933
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
934
+
935
+ all_self_attentions = () if output_attentions else None
936
+ all_cross_attentions = (
937
+ () if output_attentions and self.config.add_cross_attention else None
938
+ )
939
+ all_hidden_states = () if output_hidden_states else None
940
+ for i, block in enumerate(self.h):
941
+ # Model parallel
942
+ if self.model_parallel:
943
+ torch.cuda.set_device(hidden_states.device)
944
+ # Ensure that attention_mask is always on the same device as hidden_states
945
+ if attention_mask is not None:
946
+ attention_mask = attention_mask.to(hidden_states.device)
947
+ if isinstance(head_mask, torch.Tensor):
948
+ head_mask = head_mask.to(hidden_states.device)
949
+ if output_hidden_states:
950
+ all_hidden_states = all_hidden_states + (hidden_states,)
951
+
952
+ if self.gradient_checkpointing and self.training:
953
+ outputs = self._gradient_checkpointing_func(
954
+ block.__call__,
955
+ hidden_states,
956
+ past_key_values,
957
+ cache_position,
958
+ causal_mask,
959
+ head_mask[i],
960
+ encoder_hidden_states,
961
+ encoder_attention_mask,
962
+ use_cache,
963
+ output_attentions,
964
+ )
965
+ else:
966
+ outputs = block(
967
+ hidden_states,
968
+ past_key_value=past_key_values,
969
+ cache_position=cache_position,
970
+ attention_mask=causal_mask,
971
+ head_mask=head_mask[i],
972
+ encoder_hidden_states=encoder_hidden_states,
973
+ encoder_attention_mask=encoder_attention_mask,
974
+ use_cache=use_cache,
975
+ output_attentions=output_attentions,
976
+ **kwargs,
977
+ )
978
+
979
+ hidden_states = outputs[0]
980
+
981
+ if output_attentions:
982
+ all_self_attentions = all_self_attentions + (outputs[1],)
983
+ if self.config.add_cross_attention:
984
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
985
+
986
+ # Model Parallel: If it's the last layer for that device, put things on the next device
987
+ if self.model_parallel:
988
+ for k, v in self.device_map.items():
989
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
990
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
991
+
992
+ hidden_states = self.ln_f(hidden_states)
993
+
994
+ hidden_states = hidden_states.view(output_shape)
995
+ # Add last hidden state
996
+ if output_hidden_states:
997
+ all_hidden_states = all_hidden_states + (hidden_states,)
998
+
999
+ past_key_values = past_key_values if use_cache else None
1000
+ if return_legacy_cache:
1001
+ past_key_values = (
1002
+ past_key_values.self_attention_cache.to_legacy_cache()
1003
+ if self.config.add_cross_attention
1004
+ else past_key_values.to_legacy_cache()
1005
+ )
1006
+ if not return_dict:
1007
+ return tuple(
1008
+ v
1009
+ for v in [
1010
+ hidden_states,
1011
+ past_key_values,
1012
+ all_hidden_states,
1013
+ all_self_attentions,
1014
+ all_cross_attentions,
1015
+ ]
1016
+ if v is not None
1017
+ )
1018
+
1019
+ return BaseModelOutputWithPastAndCrossAttentions(
1020
+ last_hidden_state=hidden_states,
1021
+ past_key_values=past_key_values,
1022
+ hidden_states=all_hidden_states,
1023
+ attentions=all_self_attentions,
1024
+ cross_attentions=all_cross_attentions,
1025
+ )
1026
+
1027
+ def _update_causal_mask(
1028
+ self,
1029
+ attention_mask: torch.Tensor,
1030
+ input_tensor: torch.Tensor,
1031
+ cache_position: torch.Tensor,
1032
+ past_key_values: Cache,
1033
+ output_attentions: bool,
1034
+ ):
1035
+ if self.config._attn_implementation == "flash_attention_2":
1036
+ if attention_mask is not None and 0.0 in attention_mask:
1037
+ return attention_mask
1038
+ return None
1039
+
1040
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1041
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1042
+ # to infer the attention mask.
1043
+ past_seen_tokens = (
1044
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1045
+ )
1046
+ using_static_cache = isinstance(past_key_values, StaticCache)
1047
+
1048
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1049
+ if (
1050
+ self.config._attn_implementation == "sdpa"
1051
+ and not using_static_cache
1052
+ and not output_attentions
1053
+ ):
1054
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1055
+ attention_mask,
1056
+ inputs_embeds=input_tensor,
1057
+ past_key_values_length=past_seen_tokens,
1058
+ is_training=self.training,
1059
+ ):
1060
+ return None
1061
+
1062
+ dtype = input_tensor.dtype
1063
+ sequence_length = input_tensor.shape[1]
1064
+ if using_static_cache:
1065
+ target_length = past_key_values.get_max_cache_shape()
1066
+ else:
1067
+ target_length = (
1068
+ attention_mask.shape[-1]
1069
+ if isinstance(attention_mask, torch.Tensor)
1070
+ else past_seen_tokens + sequence_length + 1
1071
+ )
1072
+
1073
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1074
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1075
+ attention_mask,
1076
+ sequence_length=sequence_length,
1077
+ target_length=target_length,
1078
+ dtype=dtype,
1079
+ cache_position=cache_position,
1080
+ batch_size=input_tensor.shape[0],
1081
+ )
1082
+
1083
+ if (
1084
+ self.config._attn_implementation == "sdpa"
1085
+ and attention_mask is not None
1086
+ and attention_mask.device.type == "cuda"
1087
+ and not output_attentions
1088
+ ):
1089
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1090
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1091
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1092
+ min_dtype = torch.finfo(dtype).min
1093
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1094
+ causal_mask, min_dtype
1095
+ )
1096
+
1097
+ return causal_mask
1098
+
1099
+ @staticmethod
1100
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1101
+ attention_mask: torch.Tensor,
1102
+ sequence_length: int,
1103
+ target_length: int,
1104
+ dtype: torch.dtype,
1105
+ cache_position: torch.Tensor,
1106
+ batch_size: int,
1107
+ **kwargs,
1108
+ ):
1109
+ """
1110
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1111
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1112
+
1113
+ Args:
1114
+ attention_mask (`torch.Tensor`):
1115
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1116
+ `(batch_size, 1, query_length, key_value_length)`.
1117
+ sequence_length (`int`):
1118
+ The sequence length being processed.
1119
+ target_length (`int`):
1120
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1121
+ to account for the 0 padding, the part of the cache that is not filled yet.
1122
+ dtype (`torch.dtype`):
1123
+ The dtype to use for the 4D attention mask.
1124
+ cache_position (`torch.Tensor`):
1125
+ Indices depicting the position of the input sequence tokens in the sequence.
1126
+ batch_size (`torch.Tensor`):
1127
+ Batch size.
1128
+ """
1129
+ if attention_mask is not None and attention_mask.dim() == 4:
1130
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1131
+ causal_mask = attention_mask
1132
+ else:
1133
+ min_dtype = torch.finfo(dtype).min
1134
+ causal_mask = torch.full(
1135
+ (sequence_length, target_length),
1136
+ fill_value=min_dtype,
1137
+ dtype=dtype,
1138
+ device=cache_position.device,
1139
+ )
1140
+ if sequence_length != 1:
1141
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1142
+ causal_mask *= torch.arange(
1143
+ target_length, device=cache_position.device
1144
+ ) > cache_position.reshape(-1, 1)
1145
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1146
+ if attention_mask is not None:
1147
+ causal_mask = (
1148
+ causal_mask.clone()
1149
+ ) # copy to contiguous memory for in-place edit
1150
+ mask_length = attention_mask.shape[-1]
1151
+ padding_mask = (
1152
+ causal_mask[:, :, :, :mask_length]
1153
+ + attention_mask[:, None, None, :]
1154
+ )
1155
+ padding_mask = padding_mask == 0
1156
+ causal_mask[:, :, :, :mask_length] = causal_mask[
1157
+ :, :, :, :mask_length
1158
+ ].masked_fill(padding_mask, min_dtype)
1159
+
1160
+ return causal_mask
1161
+
1162
+
1163
+ @auto_docstring(
1164
+ custom_intro="""
1165
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
1166
+ embeddings).
1167
+ """
1168
+ )
1169
+ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
1170
+ _tied_weights_keys = ["lm_head.weight"]
1171
+
1172
+ def __init__(self, config):
1173
+ super().__init__(config)
1174
+ self.transformer = GPT2Model(config)
1175
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1176
+
1177
+ # Model parallel
1178
+ self.model_parallel = False
1179
+ self.device_map = None
1180
+
1181
+ # Initialize weights and apply final processing
1182
+ self.post_init()
1183
+
1184
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1185
+ def parallelize(self, device_map=None):
1186
+ warnings.warn(
1187
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1188
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1189
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1190
+ " 0, 'transformer.h.1': 1, ...}",
1191
+ FutureWarning,
1192
+ )
1193
+ self.device_map = (
1194
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1195
+ if device_map is None
1196
+ else device_map
1197
+ )
1198
+ assert_device_map(self.device_map, len(self.transformer.h))
1199
+ self.transformer.parallelize(self.device_map)
1200
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1201
+ self.model_parallel = True
1202
+
1203
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1204
+ def deparallelize(self):
1205
+ warnings.warn(
1206
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1207
+ FutureWarning,
1208
+ )
1209
+ self.transformer.deparallelize()
1210
+ self.transformer = self.transformer.to("cpu")
1211
+ self.lm_head = self.lm_head.to("cpu")
1212
+ self.model_parallel = False
1213
+ torch.cuda.empty_cache()
1214
+
1215
+ def get_output_embeddings(self):
1216
+ return self.lm_head
1217
+
1218
+ def set_output_embeddings(self, new_embeddings):
1219
+ self.lm_head = new_embeddings
1220
+
1221
+ @auto_docstring
1222
+ def forward(
1223
+ self,
1224
+ input_ids: Optional[torch.LongTensor] = None,
1225
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1226
+ cache_position: Optional[torch.LongTensor] = None,
1227
+ attention_mask: Optional[torch.FloatTensor] = None,
1228
+ token_type_ids: Optional[torch.LongTensor] = None,
1229
+ position_ids: Optional[torch.LongTensor] = None,
1230
+ head_mask: Optional[torch.FloatTensor] = None,
1231
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1232
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1233
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1234
+ labels: Optional[torch.LongTensor] = None,
1235
+ use_cache: Optional[bool] = None,
1236
+ output_attentions: Optional[bool] = None,
1237
+ output_hidden_states: Optional[bool] = None,
1238
+ return_dict: Optional[bool] = None,
1239
+ **kwargs,
1240
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1241
+ r"""
1242
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
1243
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
1244
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
1245
+ sequence tokens in the vocabulary.
1246
+
1247
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
1248
+ `input_ids`.
1249
+
1250
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1251
+ [`PreTrainedTokenizer.__call__`] for details.
1252
+
1253
+ [What are input IDs?](../glossary#input-ids)
1254
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
1255
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1256
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1257
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1258
+ """
1259
+ return_dict = (
1260
+ return_dict if return_dict is not None else self.config.use_return_dict
1261
+ )
1262
+
1263
+ transformer_outputs = self.transformer(
1264
+ input_ids,
1265
+ past_key_values=past_key_values,
1266
+ attention_mask=attention_mask,
1267
+ cache_position=cache_position,
1268
+ token_type_ids=token_type_ids,
1269
+ position_ids=position_ids,
1270
+ head_mask=head_mask,
1271
+ inputs_embeds=inputs_embeds,
1272
+ encoder_hidden_states=encoder_hidden_states,
1273
+ encoder_attention_mask=encoder_attention_mask,
1274
+ use_cache=use_cache,
1275
+ output_attentions=output_attentions,
1276
+ output_hidden_states=output_hidden_states,
1277
+ return_dict=return_dict,
1278
+ )
1279
+ hidden_states = transformer_outputs[0]
1280
+
1281
+ # Set device for model parallelism
1282
+ if self.model_parallel:
1283
+ torch.cuda.set_device(self.transformer.first_device)
1284
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1285
+
1286
+ lm_logits = self.lm_head(hidden_states)
1287
+
1288
+ loss = None
1289
+ if labels is not None:
1290
+ # Flatten the tokens
1291
+ loss = self.loss_function(
1292
+ lm_logits,
1293
+ labels,
1294
+ vocab_size=self.config.vocab_size,
1295
+ **kwargs,
1296
+ )
1297
+
1298
+ if not return_dict:
1299
+ output = (lm_logits,) + transformer_outputs[1:]
1300
+ return ((loss,) + output) if loss is not None else output
1301
+
1302
+ return CausalLMOutputWithCrossAttentions(
1303
+ loss=loss,
1304
+ logits=lm_logits,
1305
+ past_key_values=transformer_outputs.past_key_values,
1306
+ hidden_states=transformer_outputs.hidden_states,
1307
+ attentions=transformer_outputs.attentions,
1308
+ cross_attentions=transformer_outputs.cross_attentions,
1309
+ )
1310
+
1311
+
1312
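For reference, a minimal usage sketch of the language-modeling head defined above, written against the upstream `transformers` class of the same name; the checkpoint name and prompt are illustrative and not part of this diff:

```python
# Illustrative only: exercises GPT2LMHeadModel as described in the docstring above.
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Passing labels=input_ids is enough: the model shifts the labels internally,
# so position i is trained to predict token i + 1.
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)          # scalar cross-entropy loss
print(outputs.logits.shape)  # (1, sequence_length, vocab_size)
```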
+ @auto_docstring(
1313
+ custom_intro="""
1314
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1315
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1316
+ input embeddings; the classification head takes as input the hidden state of a specified classification token in the
1317
+ input sequence.
1318
+ """
1319
+ )
1320
+ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
1321
+ _tied_weights_keys = ["lm_head.weight"]
1322
+
1323
+ def __init__(self, config):
1324
+ super().__init__(config)
1325
+ config.num_labels = 1
1326
+ self.transformer = GPT2Model(config)
1327
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1328
+ self.multiple_choice_head = GPT2SequenceSummary(config)
1329
+
1330
+ # Model parallel
1331
+ self.model_parallel = False
1332
+ self.device_map = None
1333
+
1334
+ # Initialize weights and apply final processing
1335
+ self.post_init()
1336
+
1337
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1338
+ def parallelize(self, device_map=None):
1339
+ warnings.warn(
1340
+ "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1341
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1342
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1343
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1344
+ FutureWarning,
1345
+ )
1346
+ self.device_map = (
1347
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1348
+ if device_map is None
1349
+ else device_map
1350
+ )
1351
+ assert_device_map(self.device_map, len(self.transformer.h))
1352
+ self.transformer.parallelize(self.device_map)
1353
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1354
+ self.multiple_choice_head = self.multiple_choice_head.to(
1355
+ self.transformer.first_device
1356
+ )
1357
+ self.model_parallel = True
1358
+
1359
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1360
+ def deparallelize(self):
1361
+ warnings.warn(
1362
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1363
+ FutureWarning,
1364
+ )
1365
+ self.transformer.deparallelize()
1366
+ self.transformer = self.transformer.to("cpu")
1367
+ self.lm_head = self.lm_head.to("cpu")
1368
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1369
+ self.model_parallel = False
1370
+ torch.cuda.empty_cache()
1371
+
1372
+ def get_output_embeddings(self):
1373
+ return self.lm_head
1374
+
1375
+ def set_output_embeddings(self, new_embeddings):
1376
+ self.lm_head = new_embeddings
1377
+
1378
+ @auto_docstring
1379
+ def forward(
1380
+ self,
1381
+ input_ids: Optional[torch.LongTensor] = None,
1382
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1383
+ cache_position: Optional[torch.LongTensor] = None,
1384
+ attention_mask: Optional[torch.FloatTensor] = None,
1385
+ token_type_ids: Optional[torch.LongTensor] = None,
1386
+ position_ids: Optional[torch.LongTensor] = None,
1387
+ head_mask: Optional[torch.FloatTensor] = None,
1388
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1389
+ mc_token_ids: Optional[torch.LongTensor] = None,
1390
+ labels: Optional[torch.LongTensor] = None,
1391
+ mc_labels: Optional[torch.LongTensor] = None,
1392
+ use_cache: Optional[bool] = None,
1393
+ output_attentions: Optional[bool] = None,
1394
+ output_hidden_states: Optional[bool] = None,
1395
+ return_dict: Optional[bool] = None,
1396
+ **kwargs,
1397
+ ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
1398
+ r"""
1399
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
1400
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
1401
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
1402
+ sequence tokens in the vocabulary.
1403
+
1404
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
1405
+ `input_ids`.
1406
+
1407
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1408
+ [`PreTrainedTokenizer.__call__`] for details.
1409
+
1410
+ [What are input IDs?](../glossary#input-ids)
1411
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
1412
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1413
+ 1]`.
1414
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
1415
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1416
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1417
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1418
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1419
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`,
1420
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1421
+
1422
+ Example:
1423
+
1424
+ ```python
1425
+ >>> import torch
1426
+ >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1427
+
1428
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
1429
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
1430
+
1431
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1432
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1433
+ >>> # Update the model embeddings with the new vocabulary size
1434
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1435
+
1436
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1437
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1438
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1439
+
1440
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1441
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1442
+
1443
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1444
+ >>> lm_logits = outputs.logits
1445
+ >>> mc_logits = outputs.mc_logits
1446
+ ```"""
1447
+ return_dict = (
1448
+ return_dict if return_dict is not None else self.config.use_return_dict
1449
+ )
1450
+
1451
+ transformer_outputs = self.transformer(
1452
+ input_ids,
1453
+ past_key_values=past_key_values,
1454
+ cache_position=cache_position,
1455
+ attention_mask=attention_mask,
1456
+ token_type_ids=token_type_ids,
1457
+ position_ids=position_ids,
1458
+ head_mask=head_mask,
1459
+ inputs_embeds=inputs_embeds,
1460
+ use_cache=use_cache,
1461
+ output_attentions=output_attentions,
1462
+ output_hidden_states=output_hidden_states,
1463
+ return_dict=return_dict,
1464
+ )
1465
+
1466
+ hidden_states = transformer_outputs[0]
1467
+
1468
+ # Set device for model parallelism
1469
+ if self.model_parallel:
1470
+ torch.cuda.set_device(self.transformer.first_device)
1471
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1472
+
1473
+ lm_logits = self.lm_head(hidden_states)
1474
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1475
+
1476
+ mc_loss = None
1477
+ if mc_labels is not None:
1478
+ loss_fct = CrossEntropyLoss()
1479
+ mc_loss = loss_fct(
1480
+ mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1481
+ )
1482
+ lm_loss = None
1483
+ if labels is not None:
1484
+ labels = labels.to(lm_logits.device)
1485
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1486
+ shift_labels = labels[..., 1:].contiguous()
1487
+ loss_fct = CrossEntropyLoss()
1488
+ lm_loss = loss_fct(
1489
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1490
+ )
1491
+
1492
+ if not return_dict:
1493
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1494
+ if mc_loss is not None:
1495
+ output = (mc_loss,) + output
1496
+ return ((lm_loss,) + output) if lm_loss is not None else output
1497
+
1498
+ return GPT2DoubleHeadsModelOutput(
1499
+ loss=lm_loss,
1500
+ mc_loss=mc_loss,
1501
+ logits=lm_logits,
1502
+ mc_logits=mc_logits,
1503
+ past_key_values=transformer_outputs.past_key_values,
1504
+ hidden_states=transformer_outputs.hidden_states,
1505
+ attentions=transformer_outputs.attentions,
1506
+ )
1507
+
1508
+ @staticmethod
1509
+ def _reorder_cache(
1510
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1511
+ ) -> Tuple[Tuple[torch.Tensor]]:
1512
+ """
1513
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1514
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1515
+ beam_idx at every generation step.
1516
+ """
1517
+ return tuple(
1518
+ tuple(
1519
+ past_state.index_select(0, beam_idx.to(past_state.device))
1520
+ for past_state in layer_past
1521
+ )
1522
+ for layer_past in past_key_values
1523
+ )
1524
+
1525
+
1526
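The docstring example above stops at the forward pass; the sketch below extends it, under the same assumptions (upstream `transformers` class of the same name, `openai-community/gpt2` checkpoint), to show how `mc_labels` — the 0-based index of the correct choice — produces `mc_loss`:

```python
# Hedged sketch of the multiple-choice head; choices and labels are made up.
import torch
from transformers import AutoTokenizer, GPT2DoubleHeadsModel

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
tokenizer.add_special_tokens({"cls_token": "[CLS]"})
model.resize_token_embeddings(len(tokenizer))

choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
encoded = [tokenizer.encode(c) for c in choices]
cls_locations = [tokens.index(tokenizer.cls_token_id) for tokens in encoded]

input_ids = torch.tensor(encoded).unsqueeze(0)   # (1, num_choices, seq_len)
mc_token_ids = torch.tensor([cls_locations])     # (1, num_choices)
mc_labels = torch.tensor([0])                    # index of the correct choice per batch item

outputs = model(input_ids, mc_token_ids=mc_token_ids, mc_labels=mc_labels)
print(outputs.mc_loss)          # scalar cross-entropy over the choices
print(outputs.mc_logits.shape)  # (1, num_choices)
```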
+ @auto_docstring(
1527
+ custom_intro="""
1528
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
1529
+
1530
+ [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1531
+ (e.g. GPT-1) do.
1532
+
1533
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1534
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1535
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1536
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1537
+ each row of the batch).
1538
+ """
1539
+ )
1540
+ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1541
+ def __init__(self, config):
1542
+ super().__init__(config)
1543
+ self.num_labels = config.num_labels
1544
+ self.transformer = GPT2Model(config)
1545
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1546
+
1547
+ # Model parallel
1548
+ self.model_parallel = False
1549
+ self.device_map = None
1550
+
1551
+ # Initialize weights and apply final processing
1552
+ self.post_init()
1553
+
1554
+ @auto_docstring
1555
+ def forward(
1556
+ self,
1557
+ input_ids: Optional[torch.LongTensor] = None,
1558
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1559
+ attention_mask: Optional[torch.FloatTensor] = None,
1560
+ token_type_ids: Optional[torch.LongTensor] = None,
1561
+ position_ids: Optional[torch.LongTensor] = None,
1562
+ head_mask: Optional[torch.FloatTensor] = None,
1563
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1564
+ labels: Optional[torch.LongTensor] = None,
1565
+ use_cache: Optional[bool] = None,
1566
+ output_attentions: Optional[bool] = None,
1567
+ output_hidden_states: Optional[bool] = None,
1568
+ return_dict: Optional[bool] = None,
1569
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1570
+ r"""
1571
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
1572
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
1573
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
1574
+ sequence tokens in the vocabulary.
1575
+
1576
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
1577
+ `input_ids`.
1578
+
1579
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1580
+ [`PreTrainedTokenizer.__call__`] for details.
1581
+
1582
+ [What are input IDs?](../glossary#input-ids)
1583
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1584
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1585
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
1586
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1587
+ """
1588
+ return_dict = (
1589
+ return_dict if return_dict is not None else self.config.use_return_dict
1590
+ )
1591
+
1592
+ transformer_outputs = self.transformer(
1593
+ input_ids,
1594
+ past_key_values=past_key_values,
1595
+ attention_mask=attention_mask,
1596
+ token_type_ids=token_type_ids,
1597
+ position_ids=position_ids,
1598
+ head_mask=head_mask,
1599
+ inputs_embeds=inputs_embeds,
1600
+ use_cache=use_cache,
1601
+ output_attentions=output_attentions,
1602
+ output_hidden_states=output_hidden_states,
1603
+ return_dict=return_dict,
1604
+ )
1605
+ hidden_states = transformer_outputs[0]
1606
+ logits = self.score(hidden_states)
1607
+
1608
+ if input_ids is not None:
1609
+ batch_size, sequence_length = input_ids.shape[:2]
1610
+ else:
1611
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1612
+
1613
+ if self.config.pad_token_id is None and batch_size != 1:
1614
+ raise ValueError(
1615
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1616
+ )
1617
+ if self.config.pad_token_id is None:
1618
+ last_non_pad_token = -1
1619
+ elif input_ids is not None:
1620
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1621
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(
1622
+ logits.device, torch.int32
1623
+ )
1624
+ token_indices = torch.arange(
1625
+ input_ids.shape[-1], device=logits.device, dtype=torch.int32
1626
+ )
1627
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1628
+ else:
1629
+ last_non_pad_token = -1
1630
+ logger.warning_once(
1631
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1632
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1633
+ )
1634
+
1635
+ pooled_logits = logits[
1636
+ torch.arange(batch_size, device=logits.device), last_non_pad_token
1637
+ ]
1638
+
1639
+ loss = None
1640
+ if labels is not None:
1641
+ if self.config.problem_type is None:
1642
+ if self.num_labels == 1:
1643
+ self.config.problem_type = "regression"
1644
+ elif self.num_labels > 1 and (
1645
+ labels.dtype == torch.long or labels.dtype == torch.int
1646
+ ):
1647
+ self.config.problem_type = "single_label_classification"
1648
+ else:
1649
+ self.config.problem_type = "multi_label_classification"
1650
+
1651
+ if self.config.problem_type == "regression":
1652
+ loss_fct = MSELoss()
1653
+ if self.num_labels == 1:
1654
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1655
+ else:
1656
+ loss = loss_fct(pooled_logits, labels)
1657
+ elif self.config.problem_type == "single_label_classification":
1658
+ loss_fct = CrossEntropyLoss()
1659
+ loss = loss_fct(
1660
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1661
+ )
1662
+ elif self.config.problem_type == "multi_label_classification":
1663
+ loss_fct = BCEWithLogitsLoss()
1664
+ loss = loss_fct(pooled_logits, labels)
1665
+ if not return_dict:
1666
+ output = (pooled_logits,) + transformer_outputs[1:]
1667
+ return ((loss,) + output) if loss is not None else output
1668
+
1669
+ return SequenceClassifierOutputWithPast(
1670
+ loss=loss,
1671
+ logits=pooled_logits,
1672
+ past_key_values=transformer_outputs.past_key_values,
1673
+ hidden_states=transformer_outputs.hidden_states,
1674
+ attentions=transformer_outputs.attentions,
1675
+ )
1676
+
1677
+
1678
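To make the pooling step in `GPT2ForSequenceClassification.forward` above easier to follow, here is a small self-contained sketch of the rightmost-non-pad-token selection it performs; the tensor values and `pad_token_id` are made up for illustration:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[15, 27, 42, 0, 0],   # right-padded row
                          [0, 0, 9, 8, 7]])     # left-padded row

non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
# argmax over (index * mask) picks the rightmost non-pad position in each row
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])
```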
+ @auto_docstring
1679
+ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1680
+ def __init__(self, config):
1681
+ super().__init__(config)
1682
+ self.num_labels = config.num_labels
1683
+
1684
+ self.transformer = GPT2Model(config)
1685
+ if (
1686
+ hasattr(config, "classifier_dropout")
1687
+ and config.classifier_dropout is not None
1688
+ ):
1689
+ classifier_dropout = config.classifier_dropout
1690
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1691
+ classifier_dropout = config.hidden_dropout
1692
+ else:
1693
+ classifier_dropout = 0.1
1694
+ self.dropout = nn.Dropout(classifier_dropout)
1695
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1696
+
1697
+ # Model parallel
1698
+ self.model_parallel = False
1699
+ self.device_map = None
1700
+
1701
+ # Initialize weights and apply final processing
1702
+ self.post_init()
1703
+
1704
+ @auto_docstring
1705
+ def forward(
1706
+ self,
1707
+ input_ids: Optional[torch.LongTensor] = None,
1708
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1709
+ attention_mask: Optional[torch.FloatTensor] = None,
1710
+ token_type_ids: Optional[torch.LongTensor] = None,
1711
+ position_ids: Optional[torch.LongTensor] = None,
1712
+ head_mask: Optional[torch.FloatTensor] = None,
1713
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1714
+ labels: Optional[torch.LongTensor] = None,
1715
+ use_cache: Optional[bool] = None,
1716
+ output_attentions: Optional[bool] = None,
1717
+ output_hidden_states: Optional[bool] = None,
1718
+ return_dict: Optional[bool] = None,
1719
+ ) -> Union[Tuple, TokenClassifierOutput]:
1720
+ r"""
1721
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
1722
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
1723
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
1724
+ sequence tokens in the vocabulary.
1725
+
1726
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
1727
+ `input_ids`.
1728
+
1729
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1730
+ [`PreTrainedTokenizer.__call__`] for details.
1731
+
1732
+ [What are input IDs?](../glossary#input-ids)
1733
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1734
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1735
+ config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed for each token;
1736
+ labels set to `-100` are ignored (masked).
1737
+ """
1738
+ return_dict = (
1739
+ return_dict if return_dict is not None else self.config.use_return_dict
1740
+ )
1741
+
1742
+ transformer_outputs = self.transformer(
1743
+ input_ids,
1744
+ past_key_values=past_key_values,
1745
+ attention_mask=attention_mask,
1746
+ token_type_ids=token_type_ids,
1747
+ position_ids=position_ids,
1748
+ head_mask=head_mask,
1749
+ inputs_embeds=inputs_embeds,
1750
+ use_cache=use_cache,
1751
+ output_attentions=output_attentions,
1752
+ output_hidden_states=output_hidden_states,
1753
+ return_dict=return_dict,
1754
+ )
1755
+
1756
+ hidden_states = transformer_outputs[0]
1757
+ hidden_states = self.dropout(hidden_states)
1758
+ logits = self.classifier(hidden_states)
1759
+
1760
+ loss = None
1761
+ if labels is not None:
1762
+ labels = labels.to(logits.device)
1763
+ loss_fct = CrossEntropyLoss()
1764
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1765
+
1766
+ if not return_dict:
1767
+ output = (logits,) + transformer_outputs[2:]
1768
+ return ((loss,) + output) if loss is not None else output
1769
+
1770
+ return TokenClassifierOutput(
1771
+ loss=loss,
1772
+ logits=logits,
1773
+ hidden_states=transformer_outputs.hidden_states,
1774
+ attentions=transformer_outputs.attentions,
1775
+ )
1776
+
1777
+
1778
+ @auto_docstring
1779
+ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
1780
+ def __init__(self, config):
1781
+ super().__init__(config)
1782
+ self.num_labels = config.num_labels
1783
+ self.transformer = GPT2Model(config)
1784
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1785
+
1786
+ # Model parallel
1787
+ self.model_parallel = False
1788
+ self.device_map = None
1789
+
1790
+ # Initialize weights and apply final processing
1791
+ self.post_init()
1792
+
1793
+ @auto_docstring
1794
+ def forward(
1795
+ self,
1796
+ input_ids: Optional[torch.LongTensor] = None,
1797
+ attention_mask: Optional[torch.FloatTensor] = None,
1798
+ token_type_ids: Optional[torch.LongTensor] = None,
1799
+ position_ids: Optional[torch.LongTensor] = None,
1800
+ head_mask: Optional[torch.FloatTensor] = None,
1801
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1802
+ start_positions: Optional[torch.LongTensor] = None,
1803
+ end_positions: Optional[torch.LongTensor] = None,
1804
+ output_attentions: Optional[bool] = None,
1805
+ output_hidden_states: Optional[bool] = None,
1806
+ return_dict: Optional[bool] = None,
1807
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1808
+ r"""
1809
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
1810
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
1811
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
1812
+ sequence tokens in the vocabulary.
1813
+
1814
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
1815
+ `input_ids`.
1816
+
1817
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1818
+ [`PreTrainedTokenizer.__call__`] for details.
1819
+
1820
+ [What are input IDs?](../glossary#input-ids)
1821
+ """
1822
+ return_dict = (
1823
+ return_dict if return_dict is not None else self.config.use_return_dict
1824
+ )
1825
+
1826
+ outputs = self.transformer(
1827
+ input_ids,
1828
+ attention_mask=attention_mask,
1829
+ token_type_ids=token_type_ids,
1830
+ position_ids=position_ids,
1831
+ head_mask=head_mask,
1832
+ inputs_embeds=inputs_embeds,
1833
+ output_attentions=output_attentions,
1834
+ output_hidden_states=output_hidden_states,
1835
+ return_dict=return_dict,
1836
+ )
1837
+
1838
+ sequence_output = outputs[0]
1839
+
1840
+ logits = self.qa_outputs(sequence_output)
1841
+ start_logits, end_logits = logits.split(1, dim=-1)
1842
+ start_logits = start_logits.squeeze(-1).contiguous()
1843
+ end_logits = end_logits.squeeze(-1).contiguous()
1844
+
1845
+ total_loss = None
1846
+ if start_positions is not None and end_positions is not None:
1847
+ # If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
1848
+ if len(start_positions.size()) > 1:
1849
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1850
+ if len(end_positions.size()) > 1:
1851
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1852
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1853
+ ignored_index = start_logits.size(1)
1854
+ start_positions = start_positions.clamp(0, ignored_index)
1855
+ end_positions = end_positions.clamp(0, ignored_index)
1856
+
1857
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1858
+ start_loss = loss_fct(start_logits, start_positions)
1859
+ end_loss = loss_fct(end_logits, end_positions)
1860
+ total_loss = (start_loss + end_loss) / 2
1861
+
1862
+ if not return_dict:
1863
+ output = (start_logits, end_logits) + outputs[2:]
1864
+ return ((total_loss,) + output) if total_loss is not None else output
1865
+
1866
+ return QuestionAnsweringModelOutput(
1867
+ loss=total_loss,
1868
+ start_logits=start_logits,
1869
+ end_logits=end_logits,
1870
+ hidden_states=outputs.hidden_states,
1871
+ attentions=outputs.attentions,
1872
+ )
1873
+
1874
+
1875
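Finally, a usage sketch for the question-answering head above, once more assuming the upstream class of the same name; note that `qa_outputs` is freshly initialized for the base `openai-community/gpt2` checkpoint, so the decoded span only demonstrates the API shape, not a meaningful answer:

```python
# Illustrative only: greedy span decoding over the per-token start/end logits.
from transformers import AutoTokenizer, GPT2ForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = GPT2ForQuestionAnswering.from_pretrained("openai-community/gpt2")

question = "Who released GPT-2?"
context = "GPT-2 was released by OpenAI in 2019."
inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)

start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)
```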
  __all__ = [
1876
  "GPT2DoubleHeadsModel",
1877
  "GPT2ForQuestionAnswering",