zifei9 committed
Commit 8cc99f4 · verified · 1 Parent(s): 89a060e

Update modeling_gpt2.py


Updating based on transformers==4.52.4

Files changed (1)
  1. modeling_gpt2.py +81 -1892
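
The commit description pins the library release this file now tracks. As a quick way to exercise the updated code, one could install that exact version and load the repository's custom modeling file from the Hub. This is only a sketch: the repo id below is a hypothetical placeholder, not taken from this commit.

```python
# Sketch only: run the updated modeling_gpt2.py against the pinned release.
# Assumes the file ships in a Hub model repo loaded with trust_remote_code;
# "your-org/your-gpt2-repo" is a placeholder repo id.
#
#   pip install "transformers==4.52.4"

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-gpt2-repo"  # placeholder, not the actual repo

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # load the repo's modeling_gpt2.py instead of the built-in GPT-2 class
)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Pinning the exact version matters here because the diff below removes the per-implementation attention classes (`GPT2FlashAttention2`, `GPT2SdpaAttention`) and other helpers that earlier releases dispatched to.
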
modeling_gpt2.py CHANGED
@@ -15,119 +15,23 @@
15
  # limitations under the License.
16
  """PyTorch OpenAI GPT-2 model."""
17
 
18
- import math
19
- import os
20
- import warnings
21
  from dataclasses import dataclass
22
- from typing import Optional, Tuple, Union
23
 
24
  import torch
25
- import torch.utils.checkpoint
26
- from packaging import version
27
  from torch import nn
28
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
- from transformers.activations import ACT2FN
31
- from transformers.generation import GenerationMixin
32
- from transformers.modeling_attn_mask_utils import (
33
- _prepare_4d_attention_mask_for_sdpa,
34
- _prepare_4d_causal_attention_mask_for_sdpa,
35
- )
36
- from transformers.modeling_outputs import (
37
- BaseModelOutputWithPastAndCrossAttentions,
38
- CausalLMOutputWithCrossAttentions,
39
- QuestionAnsweringModelOutput,
40
- SequenceClassifierOutputWithPast,
41
- TokenClassifierOutput,
42
- )
43
- from transformers.modeling_utils import PreTrainedModel, SequenceSummary
44
- from transformers.pytorch_utils import (
45
- Conv1D,
46
- find_pruneable_heads_and_indices,
47
- prune_conv1d_layer,
48
- )
49
  from transformers.utils import (
50
- ModelOutput,
51
- add_code_sample_docstrings,
52
- add_start_docstrings,
53
- add_start_docstrings_to_model_forward,
54
- get_torch_version,
55
- is_flash_attn_2_available,
56
- is_flash_attn_greater_or_equal_2_10,
57
  logging,
58
- replace_return_docstrings,
59
  )
60
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
61
- from .configuration_gpt2 import GPT2Config
62
-
63
-
64
- if is_flash_attn_2_available():
65
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
66
-
67
 
68
  logger = logging.get_logger(__name__)
69
 
70
- _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
71
- _CONFIG_FOR_DOC = "GPT2Config"
72
-
73
-
74
- def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
75
- """Load tf checkpoints in a pytorch model"""
76
- try:
77
- import re
78
-
79
- import tensorflow as tf
80
- except ImportError:
81
- logger.error(
82
- "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
83
- "https://www.tensorflow.org/install/ for installation instructions."
84
- )
85
- raise
86
- tf_path = os.path.abspath(gpt2_checkpoint_path)
87
- logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
88
- # Load weights from TF model
89
- init_vars = tf.train.list_variables(tf_path)
90
- names = []
91
- arrays = []
92
- for name, shape in init_vars:
93
- logger.info(f"Loading TF weight {name} with shape {shape}")
94
- array = tf.train.load_variable(tf_path, name)
95
- names.append(name)
96
- arrays.append(array.squeeze())
97
-
98
- for name, array in zip(names, arrays):
99
- name = name[6:] # skip "model/"
100
- name = name.split("/")
101
- pointer = model
102
- for m_name in name:
103
- if re.fullmatch(r"[A-Za-z]+\d+", m_name):
104
- scope_names = re.split(r"(\d+)", m_name)
105
- else:
106
- scope_names = [m_name]
107
- if scope_names[0] == "w" or scope_names[0] == "g":
108
- pointer = getattr(pointer, "weight")
109
- elif scope_names[0] == "b":
110
- pointer = getattr(pointer, "bias")
111
- elif scope_names[0] == "wpe" or scope_names[0] == "wte":
112
- pointer = getattr(pointer, scope_names[0])
113
- pointer = getattr(pointer, "weight")
114
- else:
115
- pointer = getattr(pointer, scope_names[0])
116
- if len(scope_names) >= 2:
117
- num = int(scope_names[1])
118
- pointer = pointer[num]
119
- try:
120
- if pointer.shape != array.shape:
121
- raise ValueError(
122
- f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
123
- )
124
- except ValueError as e:
125
- e.args += (pointer.shape, array.shape)
126
- raise
127
- logger.info(f"Initialize PyTorch weight {name}")
128
- pointer.data = torch.from_numpy(array)
129
- return model
130
-
131
 
132
  class GPT2Attention(nn.Module):
133
  def __init__(self, config, is_cross_attention=False, layer_idx=None):
@@ -136,9 +40,9 @@ class GPT2Attention(nn.Module):
136
  max_positions = config.max_position_embeddings
137
  self.register_buffer(
138
  "bias",
139
- torch.tril(
140
- torch.ones((max_positions, max_positions), dtype=torch.bool)
141
- ).view(1, 1, max_positions, max_positions),
142
  persistent=False,
143
  )
144
  self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
@@ -177,88 +81,25 @@ class GPT2Attention(nn.Module):
177
  def prune_heads(self, heads):
178
  if len(heads) == 0:
179
  return
180
- heads, index = find_pruneable_heads_and_indices(
181
- heads, self.num_heads, self.head_dim, self.pruned_heads
182
- )
183
- index_attn = torch.cat(
184
- [index, index + self.split_size, index + (2 * self.split_size)]
185
- )
186
 
187
  # Prune conv1d layers
188
  self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
189
  self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
190
 
191
  # Update hyper params
192
- self.split_size = (self.split_size // self.num_heads) * (
193
- self.num_heads - len(heads)
194
- )
195
  self.num_heads = self.num_heads - len(heads)
196
  self.pruned_heads = self.pruned_heads.union(heads)
197
 
198
- def _attn(self, query, key, value, attention_mask=None, head_mask=None):
199
- attn_weights = torch.matmul(query, key.transpose(-1, -2))
200
-
201
- if self.scale_attn_weights:
202
- attn_weights = attn_weights / torch.full(
203
- [],
204
- value.size(-1) ** 0.5,
205
- dtype=attn_weights.dtype,
206
- device=attn_weights.device,
207
- )
208
-
209
- # Layer-wise attention scaling
210
- if self.scale_attn_by_inverse_layer_idx:
211
- attn_weights = attn_weights / float(self.layer_idx + 1)
212
-
213
- if not self.is_cross_attention:
214
- # if only "normal" attention layer implements causal mask
215
- query_length, key_length = query.size(-2), key.size(-2)
216
- causal_mask = self.bias[
217
- :, :, key_length - query_length : key_length, :key_length
218
- ]
219
- mask_value = torch.finfo(attn_weights.dtype).min
220
- # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
221
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
222
- mask_value = torch.full(
223
- [], mask_value, dtype=attn_weights.dtype, device=attn_weights.device
224
- )
225
- attn_weights = torch.where(
226
- causal_mask, attn_weights.to(attn_weights.dtype), mask_value
227
- )
228
-
229
- if attention_mask is not None:
230
- # Apply the attention mask
231
- attn_weights = attn_weights + attention_mask
232
-
233
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
234
-
235
- # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
236
- attn_weights = attn_weights.type(value.dtype)
237
- attn_weights = self.attn_dropout(attn_weights)
238
-
239
- # Mask heads if we want to
240
- if head_mask is not None:
241
- attn_weights = attn_weights * head_mask
242
-
243
- attn_output = torch.matmul(attn_weights, value)
244
-
245
- return attn_output, attn_weights
246
-
247
- def _upcast_and_reordered_attn(
248
- self, query, key, value, attention_mask=None, head_mask=None
249
- ):
250
  # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
251
  bsz, num_heads, q_seq_len, dk = query.size()
252
  _, _, k_seq_len, _ = key.size()
253
 
254
  # Preallocate attn_weights for `baddbmm`
255
- attn_weights = torch.empty(
256
- bsz * num_heads,
257
- q_seq_len,
258
- k_seq_len,
259
- dtype=torch.float32,
260
- device=query.device,
261
- )
262
 
263
  # Compute Scale Factor
264
  scale_factor = 1.0
@@ -270,26 +111,18 @@ class GPT2Attention(nn.Module):
270
 
271
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
272
  with torch.amp.autocast(query.device.type, enabled=False):
273
- q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
274
- -1, dk, k_seq_len
275
- )
276
- attn_weights = torch.baddbmm(
277
- attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
278
- )
279
  attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
280
 
281
  if not self.is_cross_attention:
282
  # if only "normal" attention layer implements causal mask
283
  query_length, key_length = query.size(-2), key.size(-2)
284
- causal_mask = self.bias[
285
- :, :, key_length - query_length : key_length, :key_length
286
- ]
287
  mask_value = torch.finfo(attn_weights.dtype).min
288
  # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
289
  # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
290
- mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
291
- attn_weights.device
292
- )
293
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
294
 
295
  if attention_mask is not None:
@@ -300,9 +133,7 @@ class GPT2Attention(nn.Module):
300
 
301
  # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
302
  if attn_weights.dtype != torch.float32:
303
- raise RuntimeError(
304
- "Error with upcasting, attn_weights does not have dtype torch.float32"
305
- )
306
  attn_weights = attn_weights.type(value.dtype)
307
  attn_weights = self.attn_dropout(attn_weights)
308
 
@@ -311,1744 +142,102 @@ class GPT2Attention(nn.Module):
311
  attn_weights = attn_weights * head_mask
312
 
313
  attn_output = torch.matmul(attn_weights, value)
314
 
315
  return attn_output, attn_weights
316
 
317
- def _split_heads(self, tensor, num_heads, attn_head_size):
318
- """
319
- Splits hidden_size dim into attn_head_size and num_heads
320
- """
321
- new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
322
- tensor = tensor.view(new_shape)
323
- return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
324
-
325
- def _merge_heads(self, tensor, num_heads, attn_head_size):
326
- """
327
- Merges attn_head_size dim and num_attn_heads dim into hidden_size
328
- """
329
- tensor = tensor.permute(0, 2, 1, 3).contiguous()
330
- new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
331
- return tensor.view(new_shape)
332
-
333
  def forward(
334
  self,
335
  hidden_states: Optional[Tuple[torch.FloatTensor]],
336
- layer_past: Optional[Tuple[torch.Tensor]] = None,
 
337
  attention_mask: Optional[torch.FloatTensor] = None,
338
  head_mask: Optional[torch.FloatTensor] = None,
339
  encoder_hidden_states: Optional[torch.Tensor] = None,
340
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
341
- use_cache: Optional[bool] = False,
342
  output_attentions: Optional[bool] = False,
 
343
  ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
344
- if encoder_hidden_states is not None:
 
345
  if not hasattr(self, "q_attn"):
346
  raise ValueError(
347
  "If class is used as cross attention, the weights `q_attn` have to be defined. "
348
  "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
349
  )
350
 
351
- query = self.q_attn(hidden_states)
352
- key, value = self.c_attn(encoder_hidden_states).split(
353
- self.split_size, dim=2
354
- )
355
  attention_mask = encoder_attention_mask
356
  else:
357
- query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
358
-
359
- query = self._split_heads(query, self.num_heads, self.head_dim)
360
- key = self._split_heads(key, self.num_heads, self.head_dim)
361
- value = self._split_heads(value, self.num_heads, self.head_dim)
362
 
363
- if layer_past is not None:
364
- past_key, past_value = layer_past
365
- key = torch.cat((past_key, key), dim=-2)
366
- value = torch.cat((past_value, value), dim=-2)
367
 
368
- if use_cache is True:
369
- present = (key, value)
370
- else:
371
- present = None
372
 
373
- if self.reorder_and_upcast_attn:
374
- attn_output, attn_weights = self._upcast_and_reordered_attn(
375
- query, key, value, attention_mask, head_mask
376
- )
377
- else:
378
- attn_output, attn_weights = self._attn(
379
- query, key, value, attention_mask, head_mask
380
  )
381
 
382
- attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
383
- attn_output = self.c_proj(attn_output)
384
- attn_output = self.resid_dropout(attn_output)
385
-
386
- outputs = (attn_output, present)
387
- if output_attentions:
388
- outputs += (attn_weights,)
389
-
390
- return outputs # a, present, (attentions)
391
-
392
-
393
- class GPT2FlashAttention2(GPT2Attention):
394
- """
395
- GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
396
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
397
- flash attention and deal with padding tokens in case the input contains any of them.
398
- """
399
 
400
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
401
- def __init__(self, *args, **kwargs):
402
- super().__init__(*args, **kwargs)
403
-
404
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
405
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
406
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
407
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
408
-
409
- def forward(
410
- self,
411
- hidden_states: Optional[Tuple[torch.FloatTensor]],
412
- layer_past: Optional[Tuple[torch.Tensor]] = None,
413
- attention_mask: Optional[torch.FloatTensor] = None,
414
- head_mask: Optional[torch.FloatTensor] = None,
415
- encoder_hidden_states: Optional[torch.Tensor] = None,
416
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
417
- use_cache: Optional[bool] = False,
418
- output_attentions: Optional[bool] = False,
419
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
420
- bsz, _, _ = hidden_states.size()
421
- if encoder_hidden_states is not None:
422
- if not hasattr(self, "q_attn"):
423
- raise ValueError(
424
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
425
- "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
426
  )
427
-
428
- query = self.q_attn(hidden_states)
429
- key, value = self.c_attn(encoder_hidden_states).split(
430
- self.split_size, dim=2
431
- )
432
- attention_mask = encoder_attention_mask
433
- else:
434
- query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
435
-
436
- query = self._split_heads(query, self.num_heads, self.head_dim)
437
- key = self._split_heads(key, self.num_heads, self.head_dim)
438
- value = self._split_heads(value, self.num_heads, self.head_dim)
439
-
440
- if layer_past is not None:
441
- past_key = layer_past[0]
442
- past_value = layer_past[1]
443
- key = torch.cat((past_key, key), dim=-2)
444
- value = torch.cat((past_value, value), dim=-2)
445
-
446
- present = None
447
- if use_cache is True:
448
- present = (key, value)
449
-
450
- query_length = query.shape[2]
451
- tgt_len = key.shape[2]
452
-
453
- # Flash attention requires the input to have the shape
454
- # batch_size x seq_length x head_dim x hidden_dim
455
- query = query.transpose(1, 2).view(
456
- bsz, query_length, self.num_heads, self.head_dim
457
- )
458
- key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
459
- value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
460
-
461
- attn_dropout = self.attn_dropout.p if self.training else 0.0
462
-
463
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
464
- # therefore the input hidden states gets silently casted in float32. Hence, we need
465
- # cast them back in the correct dtype just to be sure everything works as expected.
466
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
467
- # in fp32. (LlamaRMSNorm handles it correctly)
468
-
469
- if query.dtype == torch.float32:
470
- if torch.is_autocast_enabled():
471
- target_dtype = torch.get_autocast_gpu_dtype()
472
- # Handle the case where the model is quantized
473
- elif hasattr(self.config, "_pre_quantization_dtype"):
474
- target_dtype = self.config._pre_quantization_dtype
475
  else:
476
- target_dtype = self.c_proj.weight.dtype
477
-
478
- logger.warning_once(
479
- f"The input hidden states seems to be silently casted in float32, this might be related to"
480
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
481
- f" {target_dtype}."
482
- )
483
-
484
- query = query.to(target_dtype)
485
- key = key.to(target_dtype)
486
- value = value.to(target_dtype)
487
-
488
- attn_output = _flash_attention_forward(
489
- query,
490
- key,
491
- value,
492
- attention_mask,
493
- query_length,
494
- dropout=attn_dropout,
495
- is_causal=self.is_causal,
496
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
497
- )
498
-
499
- attn_weights_reshaped = attn_output.reshape(
500
- bsz, query_length, self.num_heads * self.head_dim
501
- )
502
- attn_output = self.c_proj(attn_weights_reshaped)
503
- attn_output = self.resid_dropout(attn_output)
504
 
505
- outputs = (attn_output, present)
506
- if output_attentions:
507
- outputs += (attn_weights_reshaped,)
508
-
509
- return outputs
510
-
511
-
512
- class GPT2SdpaAttention(GPT2Attention):
513
- """
514
- GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
515
- `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
516
- to adapt to the SDPA API.
517
- """
518
-
519
- def __init__(self, *args, **kwargs):
520
- super().__init__(*args, **kwargs)
521
-
522
- # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
523
- # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
524
- # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
525
- # Reference: https://github.com/pytorch/pytorch/issues/112577
526
- self.require_contiguous_qkv = version.parse(
527
- get_torch_version()
528
- ) < version.parse("2.2.0")
529
-
530
- def forward(
531
- self,
532
- hidden_states: Optional[Tuple[torch.FloatTensor]],
533
- layer_past: Optional[Tuple[torch.Tensor]] = None,
534
- attention_mask: Optional[torch.FloatTensor] = None,
535
- head_mask: Optional[torch.FloatTensor] = None,
536
- encoder_hidden_states: Optional[torch.Tensor] = None,
537
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
538
- use_cache: Optional[bool] = False,
539
- output_attentions: Optional[bool] = False,
540
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
541
- if output_attentions or head_mask is not None:
542
- logger.warning_once(
543
- "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
544
- "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
545
- "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
546
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
547
  )
548
- return super().forward(
549
- hidden_states=hidden_states,
550
- layer_past=layer_past,
551
- attention_mask=attention_mask,
552
  head_mask=head_mask,
553
- encoder_hidden_states=encoder_hidden_states,
554
- encoder_attention_mask=encoder_attention_mask,
555
- use_cache=use_cache,
556
- output_attentions=output_attentions,
557
- )
558
-
559
- bsz, q_len, _ = hidden_states.size()
560
-
561
- # Initial attention projections
562
- is_cross_attention = encoder_hidden_states is not None
563
- if is_cross_attention:
564
- if not hasattr(self, "q_attn"):
565
- raise ValueError(
566
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
567
- "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
568
- )
569
-
570
- query = self.q_attn(hidden_states)
571
- key, value = self.c_attn(encoder_hidden_states).split(
572
- self.split_size, dim=2
573
  )
574
- attention_mask = encoder_attention_mask
575
- else:
576
- query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
577
-
578
- query = self._split_heads(query, self.num_heads, self.head_dim)
579
- key = self._split_heads(key, self.num_heads, self.head_dim)
580
- value = self._split_heads(value, self.num_heads, self.head_dim)
581
-
582
- # Optional kv caching
583
- if layer_past is not None:
584
- past_key = layer_past[0]
585
- past_value = layer_past[1]
586
- key = torch.cat((past_key, key), dim=-2)
587
- value = torch.cat((past_value, value), dim=-2)
588
-
589
- present = None
590
- if use_cache is True:
591
- present = (key, value)
592
-
593
- # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
594
- if (
595
- self.require_contiguous_qkv
596
- and query.device.type == "cuda"
597
- and attention_mask is not None
598
- ):
599
- query = query.contiguous()
600
- key = key.contiguous()
601
- value = value.contiguous()
602
-
603
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
604
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
605
- is_causal = (
606
- True
607
- if attention_mask is None and q_len > 1 and not is_cross_attention
608
- else False
609
- )
610
 
611
- attn_output = torch.nn.functional.scaled_dot_product_attention(
612
- query,
613
- key,
614
- value,
615
- attn_mask=attention_mask,
616
- dropout_p=self.attn_dropout.p if self.training else 0.0,
617
- is_causal=is_causal,
618
- )
619
-
620
- # Reshape outputs
621
- attn_output = attn_output.transpose(1, 2).contiguous()
622
- attn_output = attn_output.view(bsz, q_len, self.embed_dim)
623
-
624
- # Final projection
625
  attn_output = self.c_proj(attn_output)
626
  attn_output = self.resid_dropout(attn_output)
627
 
628
- return attn_output, present, None
629
-
630
-
631
- class GPT2MLP(nn.Module):
632
- def __init__(self, intermediate_size, config):
633
- super().__init__()
634
- embed_dim = config.hidden_size
635
- self.c_fc = Conv1D(intermediate_size, embed_dim)
636
- self.c_proj = Conv1D(embed_dim, intermediate_size)
637
- self.act = ACT2FN[config.activation_function]
638
- self.dropout = nn.Dropout(config.resid_pdrop)
639
-
640
- def forward(
641
- self, hidden_states: Optional[Tuple[torch.FloatTensor]]
642
- ) -> torch.FloatTensor:
643
- hidden_states = self.c_fc(hidden_states)
644
- hidden_states = self.act(hidden_states)
645
- hidden_states = self.c_proj(hidden_states)
646
- hidden_states = self.dropout(hidden_states)
647
- return hidden_states
648
-
649
-
650
- GPT2_ATTENTION_CLASSES = {
651
- "eager": GPT2Attention,
652
- "flash_attention_2": GPT2FlashAttention2,
653
- "sdpa": GPT2SdpaAttention,
654
- }
655
-
656
-
657
- class GPT2Block(nn.Module):
658
- def __init__(self, config, layer_idx=None):
659
- super().__init__()
660
- hidden_size = config.hidden_size
661
- inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
662
- attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
663
-
664
- self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
665
- self.attn = attention_class(config=config, layer_idx=layer_idx)
666
- self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
667
-
668
- if config.add_cross_attention:
669
- self.crossattention = attention_class(
670
- config=config, is_cross_attention=True, layer_idx=layer_idx
671
- )
672
- self.ln_cross_attn = nn.LayerNorm(
673
- hidden_size, eps=config.layer_norm_epsilon
674
- )
675
-
676
- self.mlp = GPT2MLP(inner_dim, config)
677
-
678
- def forward(
679
- self,
680
- hidden_states: Optional[Tuple[torch.FloatTensor]],
681
- layer_past: Optional[Tuple[torch.Tensor]] = None,
682
- attention_mask: Optional[torch.FloatTensor] = None,
683
- head_mask: Optional[torch.FloatTensor] = None,
684
- encoder_hidden_states: Optional[torch.Tensor] = None,
685
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
686
- use_cache: Optional[bool] = False,
687
- output_attentions: Optional[bool] = False,
688
- ) -> Union[
689
- Tuple[torch.Tensor],
690
- Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
691
- ]:
692
- residual = hidden_states
693
- hidden_states = self.ln_1(hidden_states)
694
- attn_outputs = self.attn(
695
- hidden_states,
696
- layer_past=layer_past,
697
- attention_mask=attention_mask,
698
- head_mask=head_mask,
699
- use_cache=use_cache,
700
- output_attentions=output_attentions,
701
- )
702
- attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
703
- outputs = attn_outputs[1:]
704
- # residual connection
705
- hidden_states = attn_output + residual
706
-
707
- if encoder_hidden_states is not None:
708
- # add one self-attention block for cross-attention
709
- if not hasattr(self, "crossattention"):
710
- raise ValueError(
711
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
712
- "cross-attention layers by setting `config.add_cross_attention=True`"
713
- )
714
- residual = hidden_states
715
- hidden_states = self.ln_cross_attn(hidden_states)
716
- cross_attn_outputs = self.crossattention(
717
- hidden_states,
718
- attention_mask=attention_mask,
719
- head_mask=head_mask,
720
- encoder_hidden_states=encoder_hidden_states,
721
- encoder_attention_mask=encoder_attention_mask,
722
- output_attentions=output_attentions,
723
- )
724
- attn_output = cross_attn_outputs[0]
725
- # residual connection
726
- hidden_states = residual + attn_output
727
- outputs = (
728
- outputs + cross_attn_outputs[2:]
729
- ) # add cross attentions if we output attention weights
730
-
731
- residual = hidden_states
732
- hidden_states = self.ln_2(hidden_states)
733
- feed_forward_hidden_states = self.mlp(hidden_states)
734
- # residual connection
735
- hidden_states = residual + feed_forward_hidden_states
736
-
737
- if use_cache:
738
- outputs = (hidden_states,) + outputs
739
- else:
740
- outputs = (hidden_states,) + outputs[1:]
741
-
742
- return outputs # hidden_states, present, (attentions, cross_attentions)
743
-
744
-
745
- class GPT2PreTrainedModel(PreTrainedModel):
746
- """
747
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
748
- models.
749
- """
750
-
751
- config_class = GPT2Config
752
- load_tf_weights = load_tf_weights_in_gpt2
753
- base_model_prefix = "transformer"
754
- is_parallelizable = True
755
- supports_gradient_checkpointing = True
756
- _no_split_modules = ["GPT2Block"]
757
- _skip_keys_device_placement = "past_key_values"
758
- _supports_flash_attn_2 = True
759
- _supports_sdpa = True
760
-
761
- def __init__(self, *inputs, **kwargs):
762
- super().__init__(*inputs, **kwargs)
763
-
764
- def _init_weights(self, module):
765
- """Initialize the weights."""
766
- if isinstance(module, (nn.Linear, Conv1D)):
767
- # Slightly different from the TF version which uses truncated_normal for initialization
768
- # cf https://github.com/pytorch/pytorch/pull/5617
769
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
770
- if module.bias is not None:
771
- module.bias.data.zero_()
772
- elif isinstance(module, nn.Embedding):
773
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
774
- if module.padding_idx is not None:
775
- module.weight.data[module.padding_idx].zero_()
776
- elif isinstance(module, nn.LayerNorm):
777
- module.bias.data.zero_()
778
- module.weight.data.fill_(1.0)
779
-
780
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
781
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
782
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
783
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
784
- #
785
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
786
- for name, p in module.named_parameters():
787
- if name == "c_proj.weight":
788
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
789
- p.data.normal_(
790
- mean=0.0,
791
- std=(
792
- self.config.initializer_range
793
- / math.sqrt(2 * self.config.n_layer)
794
- ),
795
- )
796
-
797
-
798
- @dataclass
799
- class GPT2DoubleHeadsModelOutput(ModelOutput):
800
- """
801
- Base class for outputs of models predicting if two sentences are consecutive or not.
802
-
803
- Args:
804
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
805
- Language modeling loss.
806
- mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
807
- Multiple choice classification loss.
808
- logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
809
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
810
- mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
811
- Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
812
- past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
813
- Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
814
- sequence_length, embed_size_per_head)`).
815
-
816
- Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
817
- `past_key_values` input) to speed up sequential decoding.
818
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
819
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
820
- shape `(batch_size, sequence_length, hidden_size)`.
821
-
822
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
823
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
824
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
825
- sequence_length)`.
826
-
827
- GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
828
- self-attention heads.
829
- """
830
-
831
- loss: Optional[torch.FloatTensor] = None
832
- mc_loss: Optional[torch.FloatTensor] = None
833
- logits: torch.FloatTensor = None
834
- mc_logits: torch.FloatTensor = None
835
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
836
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
837
- attentions: Optional[Tuple[torch.FloatTensor]] = None
838
-
839
-
840
- GPT2_START_DOCSTRING = r"""
841
-
842
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
843
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
844
- etc.)
845
-
846
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
847
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
848
- and behavior.
849
-
850
- Parameters:
851
- config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
852
- Initializing with a config file does not load the weights associated with the model, only the
853
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
854
- """
855
-
856
- GPT2_INPUTS_DOCSTRING = r"""
857
- Args:
858
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
859
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
860
- `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
861
- sequence tokens in the vocabulary.
862
-
863
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
864
- `input_ids`.
865
-
866
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
867
- [`PreTrainedTokenizer.__call__`] for details.
868
-
869
- [What are input IDs?](../glossary#input-ids)
870
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
871
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
872
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
873
- their past given to this model should not be passed as `input_ids` as they have already been computed.
874
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
875
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
876
-
877
- - 1 for tokens that are **not masked**,
878
- - 0 for tokens that are **masked**.
879
-
880
- If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
881
- `past_key_values`. In other words, the `attention_mask` always has to have the length:
882
- `len(past_key_values) + len(input_ids)`
883
-
884
- [What are attention masks?](../glossary#attention-mask)
885
- token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
886
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
887
- 1]`:
888
-
889
- - 0 corresponds to a *sentence A* token,
890
- - 1 corresponds to a *sentence B* token.
891
-
892
- [What are token type IDs?](../glossary#token-type-ids)
893
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
894
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
895
- config.max_position_embeddings - 1]`.
896
-
897
- [What are position IDs?](../glossary#position-ids)
898
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
899
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
900
-
901
- - 1 indicates the head is **not masked**,
902
- - 0 indicates the head is **masked**.
903
-
904
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
905
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
906
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
907
- model's internal embedding lookup matrix.
908
-
909
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
910
- `past_key_values`).
911
- use_cache (`bool`, *optional*):
912
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
913
- `past_key_values`).
914
- output_attentions (`bool`, *optional*):
915
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
916
- tensors for more detail.
917
- output_hidden_states (`bool`, *optional*):
918
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
919
- more detail.
920
- return_dict (`bool`, *optional*):
921
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
922
- """
923
- PARALLELIZE_DOCSTRING = r"""
924
- This is an experimental feature and is a subject to change at a moment's notice.
925
-
926
- Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
927
- it will evenly distribute blocks across all devices.
928
-
929
- Args:
930
- device_map (`Dict[int, list]`, *optional*):
931
- A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
932
- automatically mapped to the first device (for esoteric reasons). That means that the first device should
933
- have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
934
- following number of attention modules:
935
-
936
- - openai-community/gpt2: 12
937
- - openai-community/gpt2-medium: 24
938
- - openai-community/gpt2-large: 36
939
- - openai-community/gpt2-xl: 48
940
-
941
- Example:
942
-
943
- ```python
944
- # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
945
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
946
- device_map = {
947
- 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
948
- 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
949
- 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
950
- 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
951
- }
952
- model.parallelize(device_map)
953
- ```
954
- """
955
- DEPARALLELIZE_DOCSTRING = r"""
956
- Moves the model to cpu from a model parallel state.
957
-
958
- Example:
959
-
960
- ```python
961
- # On a 4 GPU machine with openai-community/gpt2-large:
962
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
963
- device_map = {
964
- 0: [0, 1, 2, 3, 4, 5, 6, 7],
965
- 1: [8, 9, 10, 11, 12, 13, 14, 15],
966
- 2: [16, 17, 18, 19, 20, 21, 22, 23],
967
- 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
968
- }
969
- model.parallelize(device_map) # Splits the model across several devices
970
- model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
971
- ```
972
- """
973
-
974
-
975
- @add_start_docstrings(
976
- "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
977
- GPT2_START_DOCSTRING,
978
- )
979
- class GPT2Model(GPT2PreTrainedModel):
980
- _supports_param_buffer_assignment = False
981
-
982
- def __init__(self, config):
983
- super().__init__(config)
984
-
985
- self.embed_dim = config.hidden_size
986
-
987
- self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
988
- self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
989
-
990
- self.drop = nn.Dropout(config.embd_pdrop)
991
- self.h = nn.ModuleList(
992
- [GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
993
- )
994
- self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
995
-
996
- # Model parallel
997
- self.model_parallel = False
998
- self.device_map = None
999
- self.gradient_checkpointing = False
1000
- self._attn_implementation = config._attn_implementation
1001
-
1002
- # Initialize weights and apply final processing
1003
- self.post_init()
1004
-
1005
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
1006
- def parallelize(self, device_map=None):
1007
- # Check validity of device_map
1008
- warnings.warn(
1009
- "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
1010
- " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1011
- " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
1012
- " ...}",
1013
- FutureWarning,
1014
- )
1015
- self.device_map = (
1016
- get_device_map(len(self.h), range(torch.cuda.device_count()))
1017
- if device_map is None
1018
- else device_map
1019
- )
1020
- assert_device_map(self.device_map, len(self.h))
1021
- self.model_parallel = True
1022
- self.first_device = (
1023
- "cpu"
1024
- if "cpu" in self.device_map.keys()
1025
- else "cuda:" + str(min(self.device_map.keys()))
1026
- )
1027
- self.last_device = "cuda:" + str(max(self.device_map.keys()))
1028
- self.wte = self.wte.to(self.first_device)
1029
- self.wpe = self.wpe.to(self.first_device)
1030
- # Load onto devices
1031
- for k, v in self.device_map.items():
1032
- for block in v:
1033
- cuda_device = "cuda:" + str(k)
1034
- self.h[block] = self.h[block].to(cuda_device)
1035
- # ln_f to last
1036
- self.ln_f = self.ln_f.to(self.last_device)
1037
-
1038
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1039
- def deparallelize(self):
1040
- warnings.warn(
1041
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1042
- FutureWarning,
1043
- )
1044
- self.model_parallel = False
1045
- self.device_map = None
1046
- self.first_device = "cpu"
1047
- self.last_device = "cpu"
1048
- self.wte = self.wte.to("cpu")
1049
- self.wpe = self.wpe.to("cpu")
1050
- for index in range(len(self.h)):
1051
- self.h[index] = self.h[index].to("cpu")
1052
- self.ln_f = self.ln_f.to("cpu")
1053
- torch.cuda.empty_cache()
1054
-
1055
- def get_input_embeddings(self):
1056
- return self.wte
1057
-
1058
- def set_input_embeddings(self, new_embeddings):
1059
- self.wte = new_embeddings
1060
-
1061
- def _prune_heads(self, heads_to_prune):
1062
- """
1063
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
1064
- """
1065
- for layer, heads in heads_to_prune.items():
1066
- self.h[layer].attn.prune_heads(heads)
1067
-
1068
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1069
- @add_code_sample_docstrings(
1070
- checkpoint=_CHECKPOINT_FOR_DOC,
1071
- output_type=BaseModelOutputWithPastAndCrossAttentions,
1072
- config_class=_CONFIG_FOR_DOC,
1073
- )
1074
- def forward(
1075
- self,
1076
- input_ids: Optional[torch.LongTensor] = None,
1077
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1078
- attention_mask: Optional[torch.FloatTensor] = None,
1079
- token_type_ids: Optional[torch.LongTensor] = None,
1080
- position_ids: Optional[torch.LongTensor] = None,
1081
- head_mask: Optional[torch.FloatTensor] = None,
1082
- inputs_embeds: Optional[torch.FloatTensor] = None,
1083
- encoder_hidden_states: Optional[torch.Tensor] = None,
1084
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1085
- use_cache: Optional[bool] = None,
1086
- output_attentions: Optional[bool] = None,
1087
- output_hidden_states: Optional[bool] = None,
1088
- return_dict: Optional[bool] = None,
1089
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
1090
- output_attentions = (
1091
- output_attentions
1092
- if output_attentions is not None
1093
- else self.config.output_attentions
1094
- )
1095
- output_hidden_states = (
1096
- output_hidden_states
1097
- if output_hidden_states is not None
1098
- else self.config.output_hidden_states
1099
- )
1100
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1101
- return_dict = (
1102
- return_dict if return_dict is not None else self.config.use_return_dict
1103
- )
1104
-
1105
- if input_ids is not None and inputs_embeds is not None:
1106
- raise ValueError(
1107
- "You cannot specify both input_ids and inputs_embeds at the same time"
1108
- )
1109
- elif input_ids is not None:
1110
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1111
- input_shape = input_ids.size()
1112
- input_ids = input_ids.view(-1, input_shape[-1])
1113
- batch_size = input_ids.shape[0]
1114
- elif inputs_embeds is not None:
1115
- input_shape = inputs_embeds.size()[:-1]
1116
- batch_size = inputs_embeds.shape[0]
1117
- else:
1118
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1119
-
1120
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1121
-
1122
- if token_type_ids is not None:
1123
- token_type_ids = token_type_ids.view(-1, input_shape[-1])
1124
-
1125
- if past_key_values is None:
1126
- past_length = 0
1127
- past_key_values = tuple([None] * len(self.h))
1128
- else:
1129
- past_length = past_key_values[0][0].size(-2)
1130
- if position_ids is None:
1131
- position_ids = torch.arange(
1132
- past_length,
1133
- input_shape[-1] + past_length,
1134
- dtype=torch.long,
1135
- device=device,
1136
- )
1137
- position_ids = position_ids.unsqueeze(0)
1138
-
1139
- if inputs_embeds is None:
1140
- inputs_embeds = self.wte(input_ids)
1141
- position_embeds = self.wpe(position_ids)
1142
- hidden_states = inputs_embeds + position_embeds
1143
-
1144
- # Attention mask.
1145
- _use_sdpa = (
1146
- self._attn_implementation == "sdpa"
1147
- and output_attentions is False
1148
- and head_mask is None
1149
- )
1150
- attention_mask = (
1151
- attention_mask.view(batch_size, -1) if attention_mask is not None else None
1152
- )
1153
- if self._attn_implementation == "flash_attention_2":
1154
- attention_mask = (
1155
- attention_mask
1156
- if (attention_mask is not None and 0 in attention_mask)
1157
- else None
1158
- )
1159
- elif _use_sdpa:
1160
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1161
- attention_mask=attention_mask,
1162
- input_shape=(batch_size, input_shape[-1]),
1163
- inputs_embeds=inputs_embeds,
1164
- past_key_values_length=past_length,
1165
- )
1166
- else:
1167
- if attention_mask is not None:
1168
- # We create a 3D attention mask from a 2D tensor mask.
1169
- # Sizes are [batch_size, 1, 1, to_seq_length]
1170
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1171
- # this attention mask is more simple than the triangular masking of causal attention
1172
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
1173
- attention_mask = attention_mask[:, None, None, :]
1174
-
1175
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1176
- # masked positions, this operation will create a tensor which is 0.0 for
1177
- # positions we want to attend and the dtype's smallest value for masked positions.
1178
- # Since we are adding it to the raw scores before the softmax, this is
1179
- # effectively the same as removing these entirely.
1180
- attention_mask = attention_mask.to(
1181
- dtype=self.dtype
1182
- ) # fp16 compatibility
1183
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
1184
-
1185
- # If a 2D or 3D attention mask is provided for the cross-attention
1186
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1187
- if self.config.add_cross_attention and encoder_hidden_states is not None:
1188
- encoder_batch_size, encoder_sequence_length, _ = (
1189
- encoder_hidden_states.size()
1190
- )
1191
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1192
- if encoder_attention_mask is None:
1193
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1194
- if _use_sdpa:
1195
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1196
- mask=encoder_attention_mask,
1197
- dtype=inputs_embeds.dtype,
1198
- tgt_len=input_shape[-1],
1199
- )
1200
- elif not self._attn_implementation == "flash_attention_2":
1201
- encoder_attention_mask = self.invert_attention_mask(
1202
- encoder_attention_mask
1203
- )
1204
- else:
1205
- encoder_attention_mask = None
1206
-
1207
- # Prepare head mask if needed
1208
- # 1.0 in head_mask indicate we keep the head
1209
- # attention_probs has shape bsz x n_heads x N x N
1210
- # head_mask has shape n_layer x batch x n_heads x N x N
1211
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1212
-
1213
- if token_type_ids is not None:
1214
- token_type_embeds = self.wte(token_type_ids)
1215
- hidden_states = hidden_states + token_type_embeds
1216
-
1217
- hidden_states = self.drop(hidden_states)
1218
-
1219
- output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
1220
-
1221
- if self.gradient_checkpointing and self.training:
1222
- if use_cache:
1223
- logger.warning_once(
1224
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1225
- )
1226
- use_cache = False
1227
-
1228
- presents = () if use_cache else None
1229
- all_self_attentions = () if output_attentions else None
1230
- all_cross_attentions = (
1231
- () if output_attentions and self.config.add_cross_attention else None
1232
- )
1233
- all_hidden_states = () if output_hidden_states else None
1234
- for i in range(len(self.h)):
1235
- block, layer_past = self.h[i], past_key_values[i]
1236
- # Model parallel
1237
- if self.model_parallel:
1238
- torch.cuda.set_device(hidden_states.device)
1239
- # Ensure layer_past is on same device as hidden_states (might not be correct)
1240
- if layer_past is not None:
1241
- layer_past = tuple(
1242
- past_state.to(hidden_states.device) for past_state in layer_past
1243
- )
1244
- # Ensure that attention_mask is always on the same device as hidden_states
1245
- if attention_mask is not None:
1246
- attention_mask = attention_mask.to(hidden_states.device)
1247
- if isinstance(head_mask, torch.Tensor):
1248
- head_mask = head_mask.to(hidden_states.device)
1249
- if output_hidden_states:
1250
- all_hidden_states = all_hidden_states + (hidden_states,)
1251
-
1252
- if self.gradient_checkpointing and self.training:
1253
- outputs = self._gradient_checkpointing_func(
1254
- block.__call__,
1255
- hidden_states,
1256
- None,
1257
- attention_mask,
1258
- head_mask[i],
1259
- encoder_hidden_states,
1260
- encoder_attention_mask,
1261
- use_cache,
1262
- output_attentions,
1263
- )
1264
- else:
1265
- outputs = block(
1266
- hidden_states,
1267
- layer_past=layer_past,
1268
- attention_mask=attention_mask,
1269
- head_mask=head_mask[i],
1270
- encoder_hidden_states=encoder_hidden_states,
1271
- encoder_attention_mask=encoder_attention_mask,
1272
- use_cache=use_cache,
1273
- output_attentions=output_attentions,
1274
- )
1275
-
1276
- hidden_states = outputs[0]
1277
- if use_cache is True:
1278
- presents = presents + (outputs[1],)
1279
-
1280
- if output_attentions:
1281
- all_self_attentions = all_self_attentions + (
1282
- outputs[2 if use_cache else 1],
1283
- )
1284
- if self.config.add_cross_attention:
1285
- all_cross_attentions = all_cross_attentions + (
1286
- outputs[3 if use_cache else 2],
1287
- )
1288
-
1289
- # Model Parallel: If it's the last layer for that device, put things on the next device
1290
- if self.model_parallel:
1291
- for k, v in self.device_map.items():
1292
- if i == v[-1] and "cuda:" + str(k) != self.last_device:
1293
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
1294
-
1295
- hidden_states = self.ln_f(hidden_states)
1296
-
1297
- hidden_states = hidden_states.view(output_shape)
1298
- # Add last hidden state
1299
- if output_hidden_states:
1300
- all_hidden_states = all_hidden_states + (hidden_states,)
1301
-
1302
- if not return_dict:
1303
- return tuple(
1304
- v
1305
- for v in [
1306
- hidden_states,
1307
- presents,
1308
- all_hidden_states,
1309
- all_self_attentions,
1310
- all_cross_attentions,
1311
- ]
1312
- if v is not None
1313
- )
1314
-
1315
- return BaseModelOutputWithPastAndCrossAttentions(
1316
- last_hidden_state=hidden_states,
1317
- past_key_values=presents,
1318
- hidden_states=all_hidden_states,
1319
- attentions=all_self_attentions,
1320
- cross_attentions=all_cross_attentions,
1321
- )
1322
-
1323
-
1324
- @add_start_docstrings(
1325
- """
1326
- The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
1327
- embeddings).
1328
- """,
1329
- GPT2_START_DOCSTRING,
1330
- )
1331
- class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
1332
- _tied_weights_keys = ["lm_head.weight"]
1333
-
1334
- def __init__(self, config):
1335
- super().__init__(config)
1336
- self.transformer = GPT2Model(config)
1337
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1338
-
1339
- # Model parallel
1340
- self.model_parallel = False
1341
- self.device_map = None
1342
-
1343
- # Initialize weights and apply final processing
1344
- self.post_init()
1345
-
1346
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
1347
- def parallelize(self, device_map=None):
1348
- warnings.warn(
1349
- "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1350
- " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1351
- " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1352
- " 0, 'transformer.h.1': 1, ...}",
1353
- FutureWarning,
1354
- )
1355
- self.device_map = (
1356
- get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1357
- if device_map is None
1358
- else device_map
1359
- )
1360
- assert_device_map(self.device_map, len(self.transformer.h))
1361
- self.transformer.parallelize(self.device_map)
1362
- self.lm_head = self.lm_head.to(self.transformer.first_device)
1363
- self.model_parallel = True
1364
-
1365
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1366
- def deparallelize(self):
1367
- warnings.warn(
1368
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1369
- FutureWarning,
1370
- )
1371
- self.transformer.deparallelize()
1372
- self.transformer = self.transformer.to("cpu")
1373
- self.lm_head = self.lm_head.to("cpu")
1374
- self.model_parallel = False
1375
- torch.cuda.empty_cache()
1376
-
1377
- def get_output_embeddings(self):
1378
- return self.lm_head
1379
-
1380
- def set_output_embeddings(self, new_embeddings):
1381
- self.lm_head = new_embeddings
1382
-
1383
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1384
- @add_code_sample_docstrings(
1385
- checkpoint=_CHECKPOINT_FOR_DOC,
1386
- output_type=CausalLMOutputWithCrossAttentions,
1387
- config_class=_CONFIG_FOR_DOC,
1388
- )
1389
- def forward(
1390
- self,
1391
- input_ids: Optional[torch.LongTensor] = None,
1392
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1393
- attention_mask: Optional[torch.FloatTensor] = None,
1394
- token_type_ids: Optional[torch.LongTensor] = None,
1395
- position_ids: Optional[torch.LongTensor] = None,
1396
- head_mask: Optional[torch.FloatTensor] = None,
1397
- inputs_embeds: Optional[torch.FloatTensor] = None,
1398
- encoder_hidden_states: Optional[torch.Tensor] = None,
1399
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1400
- labels: Optional[torch.LongTensor] = None,
1401
- use_cache: Optional[bool] = None,
1402
- output_attentions: Optional[bool] = None,
1403
- output_hidden_states: Optional[bool] = None,
1404
- return_dict: Optional[bool] = None,
1405
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1406
- r"""
1407
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1408
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1409
- `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1410
- are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1411
- """
1412
- return_dict = (
1413
- return_dict if return_dict is not None else self.config.use_return_dict
1414
- )
1415
-
1416
- transformer_outputs = self.transformer(
1417
- input_ids,
1418
- past_key_values=past_key_values,
1419
- attention_mask=attention_mask,
1420
- token_type_ids=token_type_ids,
1421
- position_ids=position_ids,
1422
- head_mask=head_mask,
1423
- inputs_embeds=inputs_embeds,
1424
- encoder_hidden_states=encoder_hidden_states,
1425
- encoder_attention_mask=encoder_attention_mask,
1426
- use_cache=use_cache,
1427
- output_attentions=output_attentions,
1428
- output_hidden_states=output_hidden_states,
1429
- return_dict=return_dict,
1430
- )
1431
- hidden_states = transformer_outputs[0]
1432
-
1433
- # Set device for model parallelism
1434
- if self.model_parallel:
1435
- torch.cuda.set_device(self.transformer.first_device)
1436
- hidden_states = hidden_states.to(self.lm_head.weight.device)
1437
-
1438
- lm_logits = self.lm_head(hidden_states)
1439
-
1440
- loss = None
1441
- if labels is not None:
1442
- # move labels to correct device to enable model parallelism
1443
- labels = labels.to(lm_logits.device)
1444
- # Shift so that tokens < n predict n
1445
- shift_logits = lm_logits[..., :-1, :].contiguous()
1446
- shift_labels = labels[..., 1:].contiguous()
1447
- # Flatten the tokens
1448
- loss_fct = CrossEntropyLoss()
1449
- loss = loss_fct(
1450
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1451
- )
1452
-
1453
- if not return_dict:
1454
- output = (lm_logits,) + transformer_outputs[1:]
1455
- return ((loss,) + output) if loss is not None else output
1456
-
1457
- return CausalLMOutputWithCrossAttentions(
1458
- loss=loss,
1459
- logits=lm_logits,
1460
- past_key_values=transformer_outputs.past_key_values,
1461
- hidden_states=transformer_outputs.hidden_states,
1462
- attentions=transformer_outputs.attentions,
1463
- cross_attentions=transformer_outputs.cross_attentions,
1464
- )
1465
-
1466
- @staticmethod
1467
- def _reorder_cache(
1468
- past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1469
- ) -> Tuple[Tuple[torch.Tensor]]:
1470
- """
1471
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1472
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1473
- beam_idx at every generation step.
1474
- """
1475
- return tuple(
1476
- tuple(
1477
- past_state.index_select(0, beam_idx.to(past_state.device))
1478
- for past_state in layer_past
1479
- )
1480
- for layer_past in past_key_values
1481
- )
1482
-
1483
-
1484
- @add_start_docstrings(
1485
- """
1486
- The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1487
- RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1488
- input embeddings, the classification head takes as input the input of a specified classification token index in the
1489
- input sequence).
1490
- """,
1491
- GPT2_START_DOCSTRING,
1492
- )
1493
- class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
1494
- _tied_weights_keys = ["lm_head.weight"]
1495
-
1496
- def __init__(self, config):
1497
- super().__init__(config)
1498
- config.num_labels = 1
1499
- self.transformer = GPT2Model(config)
1500
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1501
- self.multiple_choice_head = SequenceSummary(config)
1502
-
1503
- # Model parallel
1504
- self.model_parallel = False
1505
- self.device_map = None
1506
-
1507
- # Initialize weights and apply final processing
1508
- self.post_init()
1509
-
1510
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
1511
- def parallelize(self, device_map=None):
1512
- warnings.warn(
1513
- "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1514
- " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1515
- " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1516
- " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1517
- FutureWarning,
1518
- )
1519
- self.device_map = (
1520
- get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1521
- if device_map is None
1522
- else device_map
1523
- )
1524
- assert_device_map(self.device_map, len(self.transformer.h))
1525
- self.transformer.parallelize(self.device_map)
1526
- self.lm_head = self.lm_head.to(self.transformer.first_device)
1527
- self.multiple_choice_head = self.multiple_choice_head.to(
1528
- self.transformer.first_device
1529
- )
1530
- self.model_parallel = True
1531
-
1532
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1533
- def deparallelize(self):
1534
- warnings.warn(
1535
- "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1536
- FutureWarning,
1537
- )
1538
- self.transformer.deparallelize()
1539
- self.transformer = self.transformer.to("cpu")
1540
- self.lm_head = self.lm_head.to("cpu")
1541
- self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1542
- self.model_parallel = False
1543
- torch.cuda.empty_cache()
1544
-
1545
- def get_output_embeddings(self):
1546
- return self.lm_head
1547
-
1548
- def set_output_embeddings(self, new_embeddings):
1549
- self.lm_head = new_embeddings
1550
-
1551
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1552
- @replace_return_docstrings(
1553
- output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
1554
- )
1555
- def forward(
1556
- self,
1557
- input_ids: Optional[torch.LongTensor] = None,
1558
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1559
- attention_mask: Optional[torch.FloatTensor] = None,
1560
- token_type_ids: Optional[torch.LongTensor] = None,
1561
- position_ids: Optional[torch.LongTensor] = None,
1562
- head_mask: Optional[torch.FloatTensor] = None,
1563
- inputs_embeds: Optional[torch.FloatTensor] = None,
1564
- mc_token_ids: Optional[torch.LongTensor] = None,
1565
- labels: Optional[torch.LongTensor] = None,
1566
- mc_labels: Optional[torch.LongTensor] = None,
1567
- use_cache: Optional[bool] = None,
1568
- output_attentions: Optional[bool] = None,
1569
- output_hidden_states: Optional[bool] = None,
1570
- return_dict: Optional[bool] = None,
1571
- **kwargs,
1572
- ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
1573
- r"""
1574
- mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
1575
- Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1576
- 1]`.
1577
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1578
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1579
- `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1580
- `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1581
- mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1582
- Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
1583
- where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1584
-
1585
- Return:
1586
-
1587
- Example:
1588
-
1589
- ```python
1590
- >>> import torch
1591
- >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1592
-
1593
- >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
1594
- >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
1595
-
1596
- >>> # Add a [CLS] to the vocabulary (we should train it also!)
1597
- >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1598
- >>> # Update the model embeddings with the new vocabulary size
1599
- >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1600
-
1601
- >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1602
- >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1603
- >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1604
-
1605
- >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1606
- >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1607
-
1608
- >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1609
- >>> lm_logits = outputs.logits
1610
- >>> mc_logits = outputs.mc_logits
1611
- ```"""
1612
- return_dict = (
1613
- return_dict if return_dict is not None else self.config.use_return_dict
1614
- )
1615
-
1616
- transformer_outputs = self.transformer(
1617
- input_ids,
1618
- past_key_values=past_key_values,
1619
- attention_mask=attention_mask,
1620
- token_type_ids=token_type_ids,
1621
- position_ids=position_ids,
1622
- head_mask=head_mask,
1623
- inputs_embeds=inputs_embeds,
1624
- use_cache=use_cache,
1625
- output_attentions=output_attentions,
1626
- output_hidden_states=output_hidden_states,
1627
- return_dict=return_dict,
1628
- )
1629
-
1630
- hidden_states = transformer_outputs[0]
1631
-
1632
- # Set device for model parallelism
1633
- if self.model_parallel:
1634
- torch.cuda.set_device(self.transformer.first_device)
1635
- hidden_states = hidden_states.to(self.lm_head.weight.device)
1636
-
1637
- lm_logits = self.lm_head(hidden_states)
1638
- mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1639
-
1640
- mc_loss = None
1641
- if mc_labels is not None:
1642
- loss_fct = CrossEntropyLoss()
1643
- mc_loss = loss_fct(
1644
- mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1645
- )
1646
- lm_loss = None
1647
- if labels is not None:
1648
- labels = labels.to(lm_logits.device)
1649
- shift_logits = lm_logits[..., :-1, :].contiguous()
1650
- shift_labels = labels[..., 1:].contiguous()
1651
- loss_fct = CrossEntropyLoss()
1652
- lm_loss = loss_fct(
1653
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1654
- )
1655
-
1656
- if not return_dict:
1657
- output = (lm_logits, mc_logits) + transformer_outputs[1:]
1658
- if mc_loss is not None:
1659
- output = (mc_loss,) + output
1660
- return ((lm_loss,) + output) if lm_loss is not None else output
1661
-
1662
- return GPT2DoubleHeadsModelOutput(
1663
- loss=lm_loss,
1664
- mc_loss=mc_loss,
1665
- logits=lm_logits,
1666
- mc_logits=mc_logits,
1667
- past_key_values=transformer_outputs.past_key_values,
1668
- hidden_states=transformer_outputs.hidden_states,
1669
- attentions=transformer_outputs.attentions,
1670
- )
1671
-
1672
- @staticmethod
1673
- def _reorder_cache(
1674
- past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1675
- ) -> Tuple[Tuple[torch.Tensor]]:
1676
- """
1677
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1678
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1679
- beam_idx at every generation step.
1680
- """
1681
- return tuple(
1682
- tuple(
1683
- past_state.index_select(0, beam_idx.to(past_state.device))
1684
- for past_state in layer_past
1685
- )
1686
- for layer_past in past_key_values
1687
- )
1688
-
1689
-
1690
- @add_start_docstrings(
1691
- """
1692
- The GPT2 Model transformer with a sequence classification head on top (linear layer).
1693
-
1694
- [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1695
- (e.g. GPT-1) do.
1696
-
1697
- Since it does classification on the last token, it requires to know the position of the last token. If a
1698
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1699
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1700
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1701
- each row of the batch).
1702
- """,
1703
- GPT2_START_DOCSTRING,
1704
- )
1705
- class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1706
- def __init__(self, config):
1707
- super().__init__(config)
1708
- self.num_labels = config.num_labels
1709
- self.transformer = GPT2Model(config)
1710
- self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1711
-
1712
- # Model parallel
1713
- self.model_parallel = False
1714
- self.device_map = None
1715
-
1716
- # Initialize weights and apply final processing
1717
- self.post_init()
1718
-
1719
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1720
- @add_code_sample_docstrings(
1721
- checkpoint="microsoft/DialogRPT-updown",
1722
- output_type=SequenceClassifierOutputWithPast,
1723
- config_class=_CONFIG_FOR_DOC,
1724
- )
1725
- def forward(
1726
- self,
1727
- input_ids: Optional[torch.LongTensor] = None,
1728
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1729
- attention_mask: Optional[torch.FloatTensor] = None,
1730
- token_type_ids: Optional[torch.LongTensor] = None,
1731
- position_ids: Optional[torch.LongTensor] = None,
1732
- head_mask: Optional[torch.FloatTensor] = None,
1733
- inputs_embeds: Optional[torch.FloatTensor] = None,
1734
- labels: Optional[torch.LongTensor] = None,
1735
- use_cache: Optional[bool] = None,
1736
- output_attentions: Optional[bool] = None,
1737
- output_hidden_states: Optional[bool] = None,
1738
- return_dict: Optional[bool] = None,
1739
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1740
- r"""
1741
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1742
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1743
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1744
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1745
- """
1746
- return_dict = (
1747
- return_dict if return_dict is not None else self.config.use_return_dict
1748
- )
1749
-
1750
- transformer_outputs = self.transformer(
1751
- input_ids,
1752
- past_key_values=past_key_values,
1753
- attention_mask=attention_mask,
1754
- token_type_ids=token_type_ids,
1755
- position_ids=position_ids,
1756
- head_mask=head_mask,
1757
- inputs_embeds=inputs_embeds,
1758
- use_cache=use_cache,
1759
- output_attentions=output_attentions,
1760
- output_hidden_states=output_hidden_states,
1761
- return_dict=return_dict,
1762
- )
1763
- hidden_states = transformer_outputs[0]
1764
- logits = self.score(hidden_states)
1765
-
1766
- if input_ids is not None:
1767
- batch_size, sequence_length = input_ids.shape[:2]
1768
- else:
1769
- batch_size, sequence_length = inputs_embeds.shape[:2]
1770
-
1771
- assert (
1772
- self.config.pad_token_id is not None or batch_size == 1
1773
- ), "Cannot handle batch sizes > 1 if no padding token is defined."
1774
- if self.config.pad_token_id is None:
1775
- sequence_lengths = -1
1776
- else:
1777
- if input_ids is not None:
1778
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1779
- sequence_lengths = (
1780
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1781
- )
1782
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1783
- sequence_lengths = sequence_lengths.to(logits.device)
1784
- else:
1785
- sequence_lengths = -1
1786
- logger.warning_once(
1787
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1788
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1789
- )
1790
-
1791
- pooled_logits = logits[
1792
- torch.arange(batch_size, device=logits.device), sequence_lengths
1793
- ]
1794
-
1795
- loss = None
1796
- if labels is not None:
1797
- if self.config.problem_type is None:
1798
- if self.num_labels == 1:
1799
- self.config.problem_type = "regression"
1800
- elif self.num_labels > 1 and (
1801
- labels.dtype == torch.long or labels.dtype == torch.int
1802
- ):
1803
- self.config.problem_type = "single_label_classification"
1804
- else:
1805
- self.config.problem_type = "multi_label_classification"
1806
-
1807
- if self.config.problem_type == "regression":
1808
- loss_fct = MSELoss()
1809
- if self.num_labels == 1:
1810
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1811
- else:
1812
- loss = loss_fct(pooled_logits, labels)
1813
- elif self.config.problem_type == "single_label_classification":
1814
- loss_fct = CrossEntropyLoss()
1815
- loss = loss_fct(
1816
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1817
- )
1818
- elif self.config.problem_type == "multi_label_classification":
1819
- loss_fct = BCEWithLogitsLoss()
1820
- loss = loss_fct(pooled_logits, labels)
1821
- if not return_dict:
1822
- output = (pooled_logits,) + transformer_outputs[1:]
1823
- return ((loss,) + output) if loss is not None else output
1824
-
1825
- return SequenceClassifierOutputWithPast(
1826
- loss=loss,
1827
- logits=pooled_logits,
1828
- past_key_values=transformer_outputs.past_key_values,
1829
- hidden_states=transformer_outputs.hidden_states,
1830
- attentions=transformer_outputs.attentions,
1831
- )
1832
-
1833
-
1834
- @add_start_docstrings(
1835
- """
1836
- GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1837
- Named-Entity-Recognition (NER) tasks.
1838
- """,
1839
- GPT2_START_DOCSTRING,
1840
- )
1841
- class GPT2ForTokenClassification(GPT2PreTrainedModel):
1842
- def __init__(self, config):
1843
- super().__init__(config)
1844
- self.num_labels = config.num_labels
1845
-
1846
- self.transformer = GPT2Model(config)
1847
- if (
1848
- hasattr(config, "classifier_dropout")
1849
- and config.classifier_dropout is not None
1850
- ):
1851
- classifier_dropout = config.classifier_dropout
1852
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1853
- classifier_dropout = config.hidden_dropout
1854
- else:
1855
- classifier_dropout = 0.1
1856
- self.dropout = nn.Dropout(classifier_dropout)
1857
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1858
-
1859
- # Model parallel
1860
- self.model_parallel = False
1861
- self.device_map = None
1862
-
1863
- # Initialize weights and apply final processing
1864
- self.post_init()
1865
-
1866
- @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1867
- # fmt: off
1868
- @add_code_sample_docstrings(
1869
- checkpoint="brad1141/gpt2-finetuned-comp2",
1870
- output_type=TokenClassifierOutput,
1871
- config_class=_CONFIG_FOR_DOC,
1872
- expected_loss=0.25,
1873
- expected_output=[
1874
- "Lead",
1875
- "Lead",
1876
- "Lead",
1877
- "Position",
1878
- "Lead",
1879
- "Lead",
1880
- "Lead",
1881
- "Lead",
1882
- "Lead",
1883
- "Lead",
1884
- "Lead",
1885
- "Lead",
1886
- ],
1887
- )
1888
- # fmt: on
1889
- def forward(
1890
- self,
1891
- input_ids: Optional[torch.LongTensor] = None,
1892
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1893
- attention_mask: Optional[torch.FloatTensor] = None,
1894
- token_type_ids: Optional[torch.LongTensor] = None,
1895
- position_ids: Optional[torch.LongTensor] = None,
1896
- head_mask: Optional[torch.FloatTensor] = None,
1897
- inputs_embeds: Optional[torch.FloatTensor] = None,
1898
- labels: Optional[torch.LongTensor] = None,
1899
- use_cache: Optional[bool] = None,
1900
- output_attentions: Optional[bool] = None,
1901
- output_hidden_states: Optional[bool] = None,
1902
- return_dict: Optional[bool] = None,
1903
- ) -> Union[Tuple, TokenClassifierOutput]:
1904
- r"""
1905
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1906
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1907
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1908
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1909
- """
1910
- return_dict = (
1911
- return_dict if return_dict is not None else self.config.use_return_dict
1912
- )
1913
-
1914
- transformer_outputs = self.transformer(
1915
- input_ids,
1916
- past_key_values=past_key_values,
1917
- attention_mask=attention_mask,
1918
- token_type_ids=token_type_ids,
1919
- position_ids=position_ids,
1920
- head_mask=head_mask,
1921
- inputs_embeds=inputs_embeds,
1922
- use_cache=use_cache,
1923
- output_attentions=output_attentions,
1924
- output_hidden_states=output_hidden_states,
1925
- return_dict=return_dict,
1926
- )
1927
-
1928
- hidden_states = transformer_outputs[0]
1929
- hidden_states = self.dropout(hidden_states)
1930
- logits = self.classifier(hidden_states)
1931
-
1932
- loss = None
1933
- if labels is not None:
1934
- labels = labels.to(logits.device)
1935
- loss_fct = CrossEntropyLoss()
1936
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1937
-
1938
- if not return_dict:
1939
- output = (logits,) + transformer_outputs[2:]
1940
- return ((loss,) + output) if loss is not None else output
1941
-
1942
- return TokenClassifierOutput(
1943
- loss=loss,
1944
- logits=logits,
1945
- hidden_states=transformer_outputs.hidden_states,
1946
- attentions=transformer_outputs.attentions,
1947
- )
1948
-
1949
-
1950
- @add_start_docstrings(
1951
- """
1952
- The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like
1953
- SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1954
- """,
1955
- GPT2_START_DOCSTRING,
1956
- )
1957
- class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
1958
- def __init__(self, config):
1959
- super().__init__(config)
1960
- self.num_labels = config.num_labels
1961
- self.transformer = GPT2Model(config)
1962
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1963
-
1964
- # Model parallel
1965
- self.model_parallel = False
1966
- self.device_map = None
1967
-
1968
- # Initialize weights and apply final processing
1969
- self.post_init()
1970
-
1971
- @add_start_docstrings_to_model_forward(
1972
- GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1973
- )
1974
- @add_code_sample_docstrings(
1975
- checkpoint=_CHECKPOINT_FOR_DOC,
1976
- output_type=QuestionAnsweringModelOutput,
1977
- config_class=_CONFIG_FOR_DOC,
1978
- real_checkpoint=_CHECKPOINT_FOR_DOC,
1979
- )
1980
- def forward(
1981
- self,
1982
- input_ids: Optional[torch.LongTensor] = None,
1983
- attention_mask: Optional[torch.FloatTensor] = None,
1984
- token_type_ids: Optional[torch.LongTensor] = None,
1985
- position_ids: Optional[torch.LongTensor] = None,
1986
- head_mask: Optional[torch.FloatTensor] = None,
1987
- inputs_embeds: Optional[torch.FloatTensor] = None,
1988
- start_positions: Optional[torch.LongTensor] = None,
1989
- end_positions: Optional[torch.LongTensor] = None,
1990
- output_attentions: Optional[bool] = None,
1991
- output_hidden_states: Optional[bool] = None,
1992
- return_dict: Optional[bool] = None,
1993
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1994
- r"""
1995
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1996
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1997
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1998
- are not taken into account for computing the loss.
1999
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2000
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
2001
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
2002
- are not taken into account for computing the loss.
2003
- """
2004
- return_dict = (
2005
- return_dict if return_dict is not None else self.config.use_return_dict
2006
- )
2007
-
2008
- outputs = self.transformer(
2009
- input_ids,
2010
- attention_mask=attention_mask,
2011
- token_type_ids=token_type_ids,
2012
- position_ids=position_ids,
2013
- head_mask=head_mask,
2014
- inputs_embeds=inputs_embeds,
2015
- output_attentions=output_attentions,
2016
- output_hidden_states=output_hidden_states,
2017
- return_dict=return_dict,
2018
- )
2019
-
2020
- sequence_output = outputs[0]
2021
-
2022
- logits = self.qa_outputs(sequence_output)
2023
- start_logits, end_logits = logits.split(1, dim=-1)
2024
- start_logits = start_logits.squeeze(-1).contiguous()
2025
- end_logits = end_logits.squeeze(-1).contiguous()
2026
-
2027
- total_loss = None
2028
- if start_positions is not None and end_positions is not None:
2029
- # If we are on multi-GPU, split add a dimension
2030
- if len(start_positions.size()) > 1:
2031
- start_positions = start_positions.squeeze(-1).to(start_logits.device)
2032
- if len(end_positions.size()) > 1:
2033
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
2034
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
2035
- ignored_index = start_logits.size(1)
2036
- start_positions = start_positions.clamp(0, ignored_index)
2037
- end_positions = end_positions.clamp(0, ignored_index)
2038
-
2039
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2040
- start_loss = loss_fct(start_logits, start_positions)
2041
- end_loss = loss_fct(end_logits, end_positions)
2042
- total_loss = (start_loss + end_loss) / 2
2043
-
2044
- if not return_dict:
2045
- output = (start_logits, end_logits) + outputs[2:]
2046
- return ((total_loss,) + output) if total_loss is not None else output
2047
 
2048
- return QuestionAnsweringModelOutput(
2049
- loss=total_loss,
2050
- start_logits=start_logits,
2051
- end_logits=end_logits,
2052
- hidden_states=outputs.hidden_states,
2053
- attentions=outputs.attentions,
2054
- )
 
 
 
 
15
  # limitations under the License.
16
  """PyTorch OpenAI GPT-2 model."""
17
 
 
 
 
18
  from dataclasses import dataclass
19
+ from typing import Callable, Optional, Tuple, Union
20
 
21
  import torch
 
 
22
  from torch import nn
 
23
 
24
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
26
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
27
  from transformers.utils import (
28
  logging,
 
29
  )
30
+ from transformers.utils.deprecation import deprecate_kwarg
31
+ from transformers.models.gpt2.modeling_gpt2 import load_tf_weights_in_gpt2, eager_attention_forward, GPT2Block, GPT2MLP, GPT2SequenceSummary, GPT2PreTrainedModel, GPT2DoubleHeadsModelOutput, GPT2DoubleHeadsModel, GPT2Model, GPT2LMHeadModel, GPT2ForSequenceClassification, GPT2ForTokenClassification, GPT2ForQuestionAnswering
32
 
33
  logger = logging.get_logger(__name__)
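The rewritten `GPT2Attention.forward` below resolves its kernel through `ALL_ATTENTION_FUNCTIONS` (imported above), so the backend is chosen with the standard `attn_implementation` argument at load time. A minimal sketch against the stock `openai-community/gpt2` checkpoint; nothing here is specific to this repository:

```python
# Sketch: picking the attention backend that the dispatch in GPT2Attention.forward will use.
# "eager" and "sdpa" are always available; "flash_attention_2" additionally needs flash-attn installed.
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained(
    "openai-community/gpt2",
    attn_implementation="sdpa",
)
print(model.config._attn_implementation)  # -> "sdpa"
```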
34
35
 
36
  class GPT2Attention(nn.Module):
37
  def __init__(self, config, is_cross_attention=False, layer_idx=None):
 
40
  max_positions = config.max_position_embeddings
41
  self.register_buffer(
42
  "bias",
43
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
44
+ 1, 1, max_positions, max_positions
45
+ ),
46
  persistent=False,
47
  )
48
  self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
 
81
  def prune_heads(self, heads):
82
  if len(heads) == 0:
83
  return
84
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
85
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
 
 
 
 
86
 
87
  # Prune conv1d layers
88
  self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
89
  self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
90
 
91
  # Update hyper params
92
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
 
 
93
  self.num_heads = self.num_heads - len(heads)
94
  self.pruned_heads = self.pruned_heads.union(heads)
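`prune_heads` is normally reached through the model-level pruning API rather than called directly; a short sketch with arbitrary layer and head indices:

```python
# Sketch: dropping attention heads via the public PreTrainedModel.prune_heads API.
from transformers import GPT2Model

model = GPT2Model.from_pretrained("openai-community/gpt2")
model.prune_heads({0: [0, 2], 5: [7]})  # {layer_index: [head indices to remove]}
print(model.h[0].attn.num_heads)  # 12 - 2 = 10
```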
95
 
96
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
97
  # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
98
  bsz, num_heads, q_seq_len, dk = query.size()
99
  _, _, k_seq_len, _ = key.size()
100
 
101
  # Preallocate attn_weights for `baddbmm`
102
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
103
 
104
  # Compute Scale Factor
105
  scale_factor = 1.0
 
111
 
112
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
113
  with torch.amp.autocast(query.device.type, enabled=False):
114
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
115
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
 
 
 
 
116
  attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
117
 
118
  if not self.is_cross_attention:
119
  # if only "normal" attention layer implements causal mask
120
  query_length, key_length = query.size(-2), key.size(-2)
121
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
 
 
122
  mask_value = torch.finfo(attn_weights.dtype).min
123
  # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
124
  # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
125
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
 
 
126
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
127
 
128
  if attention_mask is not None:
 
133
 
134
  # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
135
  if attn_weights.dtype != torch.float32:
136
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
 
 
137
  attn_weights = attn_weights.type(value.dtype)
138
  attn_weights = self.attn_dropout(attn_weights)
139
 
 
142
  attn_weights = attn_weights * head_mask
143
 
144
  attn_output = torch.matmul(attn_weights, value)
145
+ attn_output = attn_output.transpose(1, 2)
146
 
147
  return attn_output, attn_weights
148
 
149
+ @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
150
  def forward(
151
  self,
152
  hidden_states: Optional[Tuple[torch.FloatTensor]],
153
+ past_key_value: Optional[Cache] = None,
154
+ cache_position: Optional[torch.LongTensor] = None,
155
  attention_mask: Optional[torch.FloatTensor] = None,
156
  head_mask: Optional[torch.FloatTensor] = None,
157
  encoder_hidden_states: Optional[torch.Tensor] = None,
158
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
 
159
  output_attentions: Optional[bool] = False,
160
+ **kwargs,
161
  ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
162
+ is_cross_attention = encoder_hidden_states is not None
163
+ if is_cross_attention:
164
  if not hasattr(self, "q_attn"):
165
  raise ValueError(
166
  "If class is used as cross attention, the weights `q_attn` have to be defined. "
167
  "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
168
  )
169
 
170
+ query_states = self.q_attn(hidden_states)
171
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
 
 
172
  attention_mask = encoder_attention_mask
173
  else:
174
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
 
 
 
175
 
176
+ shape_q = (query_states.shape[0], query_states.shape[1], -1, self.head_dim)
177
+ shape_kv = (key_states.shape[0], key_states.shape[1], -1, self.head_dim)
 
 
178
 
179
+ query_states = query_states.view(shape_q).transpose(1, 2)
180
+ key_states = key_states.view(shape_kv).transpose(1, 2)
181
+ value_states = value_states.view(shape_kv).transpose(1, 2)
 
182
 
183
+ if past_key_value is not None:
184
+ if isinstance(past_key_value, EncoderDecoderCache):
185
+ if is_cross_attention:
186
+ past_key_value = past_key_value.cross_attention_cache
187
+ else:
188
+ past_key_value = past_key_value.self_attention_cache
189
+ cache_kwargs = {"cache_position": cache_position}
190
+ key_states, value_states = past_key_value.update(
191
+ key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
192
  )
193
 
194
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
195
 
196
+ using_eager = self.config._attn_implementation == "eager"
197
+ attention_interface: Callable = eager_attention_forward
198
+ if self.config._attn_implementation != "eager":
199
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
200
+ using_eager = True
201
+ logger.warning_once(
202
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
203
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
204
  )
205
  else:
206
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
207
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
208
+ # not necessarily to eager (if mentioned options are provided).
209
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
210
 
211
+ if using_eager and self.reorder_and_upcast_attn:
212
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
213
+ query_states, key_states, value_states, attention_mask, head_mask
214
  )
215
+ else:
216
+ attn_output, attn_weights = attention_interface(
217
+ self,
218
+ query_states,
219
+ key_states,
220
+ value_states,
221
+ attention_mask,
222
  head_mask=head_mask,
223
+ dropout=self.attn_dropout.p if self.training else 0.0,
224
+ is_causal=is_causal,
225
+ **kwargs,
226
  )
227
 
228
+ attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1).contiguous()
229
  attn_output = self.c_proj(attn_output)
230
  attn_output = self.resid_dropout(attn_output)
231
 
232
+ return attn_output, attn_weights
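Because the rewritten forward consumes `Cache` objects instead of legacy key/value tuples, here is a minimal sketch of incremental decoding with a `DynamicCache`, using the stock checkpoint and only public transformers APIs:

```python
# Sketch: threading a DynamicCache through two decoding steps (Cache API, transformers >= 4.52).
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers.cache_utils import DynamicCache

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
past = DynamicCache()  # filled in place by each layer's past_key_value.update(...)

with torch.no_grad():
    out = model(**inputs, past_key_values=past, use_cache=True)
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

# Second step: feed only the new token; the cached keys/values cover the prefix.
with torch.no_grad():
    out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
print(out.past_key_values.get_seq_length())  # prompt length + 1
```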
233
 
234
+ __all__ = [
235
+ "GPT2DoubleHeadsModel",
236
+ "GPT2ForQuestionAnswering",
237
+ "GPT2ForSequenceClassification",
238
+ "GPT2ForTokenClassification",
239
+ "GPT2LMHeadModel",
240
+ "GPT2Model",
241
+ "GPT2PreTrainedModel",
242
+ "load_tf_weights_in_gpt2",
243
+ ]
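Finally, a hedged sketch of how a checkpoint that ships this `modeling_gpt2.py` would typically be loaded; the repository id below is a placeholder, and the `auto_map` wiring in the repository's `config.json` is assumed:

```python
# Sketch only: "your-org/your-gpt2-repo" is a placeholder, not a real checkpoint.
# trust_remote_code=True lets transformers import this modeling_gpt2.py from the Hub,
# assuming the repository's config.json points to it via auto_map.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-gpt2-repo",
    trust_remote_code=True,
    torch_dtype=torch.float32,
)

prompt = tokenizer("Hello, my dog is", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**prompt, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```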