d-Matrix committed
Commit 5042902 · verified · 1 Parent(s): d16c53a

Update modeling_opt.py

Files changed (1)
  1. modeling_opt.py +614 -313
modeling_opt.py CHANGED
@@ -13,15 +13,16 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  """ PyTorch OPT model."""
16
- import random
17
  from typing import List, Optional, Tuple, Union
18
 
19
  import torch
 
20
  import torch.utils.checkpoint
21
  from torch import nn
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
 
24
  from transformers.activations import ACT2FN
 
25
  from transformers.modeling_outputs import (
26
  BaseModelOutputWithPast,
27
  CausalLMOutputWithPast,
@@ -33,18 +34,23 @@ from transformers.utils import (
33
  add_code_sample_docstrings,
34
  add_start_docstrings,
35
  add_start_docstrings_to_model_forward,
 
 
36
  logging,
37
  replace_return_docstrings,
38
  )
39
  from .configuration_opt import OPTConfig
40
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
41
 
42
 
43
  logger = logging.get_logger(__name__)
44
 
45
  _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
46
  _CONFIG_FOR_DOC = "OPTConfig"
47
- _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
48
 
49
  # Base model docstring
50
  _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
@@ -65,36 +71,45 @@ OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
65
  # See all OPT models at https://huggingface.co/models?filter=opt
66
  ]
67
 
68
- def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
69
- """
70
- Make causal mask used for bi-directional self-attention.
71
- """
72
- bsz, tgt_len = input_ids_shape
73
- mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
74
- mask_cond = torch.arange(mask.size(-1))
75
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
76
- mask = mask.to(dtype)
77
 
78
- if past_key_values_length > 0:
79
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
80
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
81
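For readers following the diff: the deleted `_make_causal_mask` built the additive causal mask by hand, filling the strictly upper triangle with the dtype's most negative value and prepending zero columns for cached positions. A minimal standalone sketch of that construction (illustrative only, not the helper being removed):

```python
import torch

def causal_mask_sketch(tgt_len: int, dtype=torch.float32, past_len: int = 0) -> torch.Tensor:
    # Future positions (j > i) get the most negative representable value,
    # so they vanish after the softmax; allowed positions stay at 0.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    if past_len > 0:
        # Cached (past) key positions are always visible: prepend zero columns.
        mask = torch.cat([torch.zeros(tgt_len, past_len, dtype=dtype), mask], dim=-1)
    return mask  # shape (tgt_len, past_len + tgt_len), added to attention scores

print(causal_mask_sketch(3))
```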
 
82
 
83
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
- """
85
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
86
- """
87
- bsz, src_len = mask.size()
88
- tgt_len = tgt_len if tgt_len is not None else src_len
89
 
90
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
 
91
 
92
- inverted_mask = 1.0 - expanded_mask
 
93
 
94
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
95
 
96
 
97
- class OPTLearnedPositionalEmbedding(nn.Embedding):
98
  """
99
  This module learns positional embeddings up to a fixed maximum size.
100
  """
@@ -102,20 +117,25 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
102
  def __init__(self, num_embeddings: int, embedding_dim: int):
103
  # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
104
  # and adjust num_embeddings appropriately. Other models don't have this hack
 
105
  self.offset = 2
106
- super().__init__(num_embeddings + self.offset, embedding_dim)
107
 
108
- def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
 
 
109
  """`input_ids_shape` is expected to be [bsz x seqlen]."""
110
  attention_mask = attention_mask.long()
111
 
112
  # create positions depending on attention_mask
113
- positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
 
 
114
 
115
  # cut positions if `past_key_values_length` is > 0
116
  positions = positions[:, past_key_values_length:]
117
 
118
- return super().forward(positions + self.offset)
119
 
120
 
121
  class OPTAttention(nn.Module):
@@ -123,36 +143,64 @@ class OPTAttention(nn.Module):
123
 
124
  def __init__(
125
  self,
126
- embed_dim: int,
127
- num_heads: int,
128
- dropout: float = 0.0,
129
  is_decoder: bool = False,
130
- bias: bool = True,
131
  ):
132
  super().__init__()
133
- self.embed_dim = embed_dim
134
- self.num_heads = num_heads
135
- # self.dropout = dropout
136
- self.head_dim = embed_dim // num_heads
137
 
138
- if (self.head_dim * num_heads) != self.embed_dim:
139
  raise ValueError(
140
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
141
- f" and `num_heads`: {num_heads})."
142
  )
143
  self.scaling = self.head_dim**-0.5
144
  self.is_decoder = is_decoder
145
 
146
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
147
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
148
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
149
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
150
-
151
- self.softmax = nn.Softmax(dim=-1)
152
- self.dropout = nn.Dropout(p=dropout)
153
 
154
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
155
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
156
 
157
  def forward(
158
  self,
@@ -222,15 +270,25 @@ class OPTAttention(nn.Module):
222
  raise ValueError(
223
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
224
  )
225
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask.to(attn_weights.device)
226
- attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
227
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
228
 
229
  # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
230
  if attn_weights.dtype == torch.float16:
231
- attn_weights = self.softmax(attn_weights.float()).to(torch.float16)
 
 
232
  else:
233
- attn_weights = self.softmax(attn_weights)
234
 
235
  if layer_head_mask is not None:
236
  if layer_head_mask.size() != (self.num_heads,):
@@ -238,7 +296,9 @@ class OPTAttention(nn.Module):
238
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
239
  f" {layer_head_mask.size()}"
240
  )
241
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
 
 
242
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
243
 
244
  if output_attentions:
@@ -246,12 +306,19 @@ class OPTAttention(nn.Module):
246
  # make sure that attn_weights keeps its gradient.
247
  # In order to do so, attn_weights have to be reshaped
248
  # twice and have to be reused in the following
249
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
250
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
251
  else:
252
  attn_weights_reshaped = None
253
 
254
- attn_probs = self.dropout(attn_weights)
 
 
 
255
  attn_output = torch.bmm(attn_probs, value_states)
256
 
257
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
@@ -272,36 +339,296 @@ class OPTAttention(nn.Module):
272
  return attn_output, attn_weights_reshaped, past_key_value
273
 
274
 
275
  class OPTDecoderLayer(nn.Module):
276
  def __init__(self, config: OPTConfig):
277
  super().__init__()
278
  self.embed_dim = config.hidden_size
279
- self.self_attn = OPTAttention(
280
- embed_dim=self.embed_dim,
281
- num_heads=config.num_attention_heads,
282
- dropout=config.attention_dropout,
283
- is_decoder=True,
284
  )
 
285
  self.do_layer_norm_before = config.do_layer_norm_before
286
- # self.dropout = config.dropout
287
  self.activation_fn = ACT2FN[config.activation_function]
288
 
289
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
290
- self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
291
- self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
292
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
293
-
294
- self.dropout = nn.Dropout(p=config.dropout)
 
 
295
 
296
  def forward(
297
  self,
298
  hidden_states: torch.Tensor,
299
  attention_mask: Optional[torch.Tensor] = None,
300
  layer_head_mask: Optional[torch.Tensor] = None,
 
301
  output_attentions: Optional[bool] = False,
302
  use_cache: Optional[bool] = False,
303
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
304
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
305
  """
306
  Args:
307
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -332,7 +659,9 @@ class OPTDecoderLayer(nn.Module):
332
  layer_head_mask=layer_head_mask,
333
  output_attentions=output_attentions,
334
  )
335
- hidden_states = self.dropout(hidden_states)
 
 
336
  hidden_states = residual + hidden_states
337
 
338
  # 350m applies layer norm AFTER attention
@@ -352,7 +681,9 @@ class OPTDecoderLayer(nn.Module):
352
  hidden_states = self.activation_fn(hidden_states)
353
 
354
  hidden_states = self.fc2(hidden_states)
355
- hidden_states = self.dropout(hidden_states)
 
 
356
 
357
  hidden_states = (residual + hidden_states).view(hidden_states_shape)
358
 
@@ -375,11 +706,9 @@ OPT_START_DOCSTRING = r"""
375
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
376
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
377
  etc.)
378
-
379
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
380
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
381
  and behavior.
382
-
383
  Parameters:
384
  config ([`OPTConfig`]):
385
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -397,7 +726,7 @@ class OPTPreTrainedModel(PreTrainedModel):
397
  base_model_prefix = "model"
398
  supports_gradient_checkpointing = True
399
  _no_split_modules = ["OPTDecoderLayer"]
400
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
401
 
402
  def _init_weights(self, module):
403
  std = self.config.init_std
@@ -410,52 +739,37 @@ class OPTPreTrainedModel(PreTrainedModel):
410
  if module.padding_idx is not None:
411
  module.weight.data[module.padding_idx].zero_()
412
 
413
- def _set_gradient_checkpointing(self, module, value=False):
414
- if isinstance(module, (OPTDecoder)):
415
- module.gradient_checkpointing = value
416
-
417
 
418
  OPT_INPUTS_DOCSTRING = r"""
419
  Args:
420
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
421
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
422
  it.
423
-
424
- Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
425
  [`PreTrainedTokenizer.__call__`] for details.
426
-
427
  [What are input IDs?](../glossary#input-ids)
428
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
429
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
430
-
431
  - 1 for tokens that are **not masked**,
432
  - 0 for tokens that are **masked**.
433
-
434
  [What are attention masks?](../glossary#attention-mask)
435
-
436
- Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
437
  [`PreTrainedTokenizer.__call__`] for details.
438
-
439
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
440
  `past_key_values`).
441
-
442
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
443
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
444
  information on the default strategy.
445
  head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
446
  Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
447
-
448
  - 1 indicates the head is **not masked**,
449
  - 0 indicates the head is **masked**.
450
-
451
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
452
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
453
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
454
  `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
455
-
456
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
457
  blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
458
-
459
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
460
  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
461
  `decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -480,7 +794,6 @@ OPT_INPUTS_DOCSTRING = r"""
480
  class OPTDecoder(OPTPreTrainedModel):
481
  """
482
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
483
-
484
  Args:
485
  config: OPTConfig
486
  """
@@ -493,16 +806,25 @@ class OPTDecoder(OPTPreTrainedModel):
493
  self.max_target_positions = config.max_position_embeddings
494
  self.vocab_size = config.vocab_size
495
 
496
- self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
497
- self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
 
498
 
499
  if config.word_embed_proj_dim != config.hidden_size:
500
- self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
 
 
501
  else:
502
  self.project_out = None
503
 
504
  if config.word_embed_proj_dim != config.hidden_size:
505
- self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
 
 
506
  else:
507
  self.project_in = None
508
 
@@ -510,11 +832,17 @@ class OPTDecoder(OPTPreTrainedModel):
510
  # with checkpoints that have been fine-tuned before transformers v4.20.1
511
  # see https://github.com/facebookresearch/metaseq/pull/164
512
  if config.do_layer_norm_before and not config._remove_final_layer_norm:
513
- self.final_layer_norm = nn.LayerNorm(config.hidden_size)
 
 
 
514
  else:
515
  self.final_layer_norm = None
516
 
517
- self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
 
518
 
519
  self.gradient_checkpointing = False
520
  # Initialize weights and apply final processing
@@ -526,29 +854,6 @@ class OPTDecoder(OPTPreTrainedModel):
526
  def set_input_embeddings(self, value):
527
  self.embed_tokens = value
528
 
529
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
530
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
531
- # create causal mask
532
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
533
- combined_attention_mask = None
534
- if input_shape[-1] > 1:
535
- combined_attention_mask = _make_causal_mask(
536
- input_shape,
537
- inputs_embeds.dtype,
538
- past_key_values_length=past_key_values_length,
539
- )
540
-
541
- if attention_mask is not None:
542
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
543
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
544
- inputs_embeds.device
545
- )
546
- combined_attention_mask = (
547
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask.to(expanded_attn_mask.device)
548
- )
549
-
550
- return combined_attention_mask
551
-
552
  def forward(
553
  self,
554
  input_ids: torch.LongTensor = None,
@@ -566,35 +871,26 @@ class OPTDecoder(OPTPreTrainedModel):
566
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
567
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
568
  provide it.
569
-
570
- Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
571
  [`PreTrainedTokenizer.__call__`] for details.
572
-
573
  [What are input IDs?](../glossary#input-ids)
574
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
575
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
576
-
577
  - 1 for tokens that are **not masked**,
578
  - 0 for tokens that are **masked**.
579
-
580
  [What are attention masks?](../glossary#attention-mask)
581
  head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
582
  Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
583
-
584
  - 1 indicates the head is **not masked**,
585
  - 0 indicates the head is **masked**.
586
-
587
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
588
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
589
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
590
-
591
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
592
  cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
593
-
594
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
595
  that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
596
  all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
597
-
598
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
599
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
600
  This is useful if you want more control over how to convert `input_ids` indices into associated vectors
@@ -608,44 +904,89 @@ class OPTDecoder(OPTPreTrainedModel):
608
  return_dict (`bool`, *optional*):
609
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
610
  """
611
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
612
  output_hidden_states = (
613
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
614
  )
615
  use_cache = use_cache if use_cache is not None else self.config.use_cache
616
 
617
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
618
 
619
  # retrieve input_ids and inputs_embeds
620
  if input_ids is not None and inputs_embeds is not None:
621
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
 
 
622
  elif input_ids is not None:
623
  input_shape = input_ids.size()
624
  input_ids = input_ids.view(-1, input_shape[-1])
625
  elif inputs_embeds is not None:
626
  input_shape = inputs_embeds.size()[:-1]
627
  else:
628
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
629
-
630
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
631
 
632
  if inputs_embeds is None:
633
  inputs_embeds = self.embed_tokens(input_ids)
634
 
 
635
  # embed positions
636
- if attention_mask is None:
637
- attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
638
- pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
639
 
640
- attention_mask = self._prepare_decoder_attention_mask(
641
- attention_mask, input_shape, inputs_embeds, past_key_values_length
642
- )
643
 
644
  if self.project_in is not None:
645
  inputs_embeds = self.project_in(inputs_embeds)
646
 
647
  hidden_states = inputs_embeds + pos_embeds
648
 
649
  # decoder layers
650
  all_hidden_states = () if output_hidden_states else None
651
  all_self_attns = () if output_attentions else None
@@ -665,39 +1006,29 @@ class OPTDecoder(OPTPreTrainedModel):
665
  if output_hidden_states:
666
  all_hidden_states += (hidden_states,)
667
 
668
- dropout_probability = random.uniform(0, 1)
669
- if self.training and (dropout_probability < self.layerdrop):
670
- continue
 
671
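The removed lines implemented LayerDrop: while training, each decoder layer is skipped entirely with probability `self.layerdrop`. A standalone sketch of that behaviour with a hypothetical stack of layers (not the model's own code):

```python
import random
import torch
from torch import nn

def forward_with_layerdrop(hidden_states, layers, layerdrop: float, training: bool):
    for layer in layers:
        # Skip the whole layer with probability `layerdrop`, but only during training.
        if training and random.uniform(0, 1) < layerdrop:
            continue
        hidden_states = layer(hidden_states)
    return hidden_states

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
out = forward_with_layerdrop(torch.randn(2, 8), layers, layerdrop=0.5, training=True)
```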
 
672
- past_key_value = past_key_values[idx] if past_key_values is not None else None
 
 
673
 
674
  if self.gradient_checkpointing and self.training:
675
-
676
- if use_cache:
677
- logger.warning(
678
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
679
- )
680
- use_cache = False
681
-
682
- def create_custom_forward(module):
683
- def custom_forward(*inputs):
684
- # None for past_key_value
685
- return module(*inputs, output_attentions, None)
686
-
687
- return custom_forward
688
-
689
- layer_outputs = torch.utils.checkpoint.checkpoint(
690
- create_custom_forward(decoder_layer),
691
  hidden_states,
692
- attention_mask,
693
  head_mask[idx] if head_mask is not None else None,
694
  None,
 
 
695
  )
696
  else:
697
-
698
  layer_outputs = decoder_layer(
699
  hidden_states,
700
- attention_mask=attention_mask,
701
  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
702
  past_key_value=past_key_value,
703
  output_attentions=output_attentions,
@@ -712,12 +1043,6 @@ class OPTDecoder(OPTPreTrainedModel):
712
  if output_attentions:
713
  all_self_attns += (layer_outputs[1],)
714
 
715
- # Model Parallel: If it's the last layer for that device, put things on the next device
716
- if self.model_parallel:
717
- for k, v in self.device_map.items():
718
- if idx == v[-1] and "cuda:" + str(k) != self.last_device:
719
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
720
-
721
  if self.final_layer_norm is not None:
722
  hidden_states = self.final_layer_norm(hidden_states)
723
 
@@ -730,7 +1055,11 @@ class OPTDecoder(OPTPreTrainedModel):
730
 
731
  next_cache = next_decoder_cache if use_cache else None
732
  if not return_dict:
733
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
734
  return BaseModelOutputWithPast(
735
  last_hidden_state=hidden_states,
736
  past_key_values=next_cache,
@@ -747,46 +1076,9 @@ class OPTModel(OPTPreTrainedModel):
747
  def __init__(self, config: OPTConfig):
748
  super().__init__(config)
749
  self.decoder = OPTDecoder(config)
750
-
751
- # Model parallel
752
- self.decoder.model_parallel = False
753
- self.decoder.device_map = None
754
- self.decoder.gradient_checkpointing = False
755
-
756
  # Initialize weights and apply final processing
757
  self.post_init()
758
 
759
- def parallelize(self, device_map=None):
760
- # Check validity of device_map
761
- self.decoder.device_map = (
762
- get_device_map(len(self.decoder.layers), range(torch.cuda.device_count())) if device_map is None else device_map
763
- )
764
- assert_device_map(self.decoder.device_map, len(self.decoder.layers))
765
- self.decoder.model_parallel = True
766
- self.decoder.first_device = "cpu" if "cpu" in self.decoder.device_map.keys() else "cuda:" + str(min(self.decoder.device_map.keys()))
767
- self.decoder.last_device = "cuda:" + str(max(self.decoder.device_map.keys()))
768
- self.decoder.embed_tokens = self.decoder.embed_tokens.to(self.decoder.first_device)
769
- self.decoder.embed_positions = self.decoder.embed_positions.to(self.decoder.first_device)
770
- # Load onto devices
771
- for k, v in self.decoder.device_map.items():
772
- for block in v:
773
- cuda_device = "cuda:" + str(k)
774
- self.decoder.layers[block] = self.decoder.layers[block].to(cuda_device)
775
- # final_layer_norm to last
776
- self.decoder.final_layer_norm = self.decoder.final_layer_norm.to(self.decoder.last_device)
777
-
778
- def deparallelize(self):
779
- self.decoder.model_parallel = False
780
- self.decoder.device_map = None
781
- self.decoder.first_device = "cpu"
782
- self.decoder.last_device = "cpu"
783
- self.decoder.embed_tokens = self.decoder.embed_tokens.to("cpu")
784
- self.decoder.embed_positions = self.decoder.embed_positions.to("cpu")
785
- for index in range(len(self.decoder)):
786
- self.decoder.layers[index] = self.decoder.layers[index].to("cpu")
787
- self.decoder.final_layer_norm = self.decoder.final_layer_norm.to("cpu")
788
- torch.cuda.empty_cache()
789
-
790
  def get_input_embeddings(self):
791
  return self.decoder.embed_tokens
792
 
@@ -798,7 +1090,6 @@ class OPTModel(OPTPreTrainedModel):
798
 
799
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
800
  @add_code_sample_docstrings(
801
- processor_class=_TOKENIZER_FOR_DOC,
802
  checkpoint=_CHECKPOINT_FOR_DOC,
803
  output_type=BaseModelOutputWithPast,
804
  config_class=_CONFIG_FOR_DOC,
@@ -816,13 +1107,20 @@ class OPTModel(OPTPreTrainedModel):
816
  output_hidden_states: Optional[bool] = None,
817
  return_dict: Optional[bool] = None,
818
  ) -> Union[Tuple, BaseModelOutputWithPast]:
819
-
820
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
821
  output_hidden_states = (
822
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
823
  )
824
  use_cache = use_cache if use_cache is not None else self.config.use_cache
825
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
826
 
827
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
828
  decoder_outputs = self.decoder(
@@ -849,40 +1147,20 @@ class OPTModel(OPTPreTrainedModel):
849
 
850
 
851
  class OPTForCausalLM(OPTPreTrainedModel):
852
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
853
 
854
  def __init__(self, config):
855
  super().__init__(config)
856
  self.model = OPTModel(config)
857
 
858
  # the lm_head weight is automatically tied to the embed tokens weight
859
- self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
860
-
861
- # Model parallel
862
- self.model_parallel = False
863
- self.device_map = None
864
 
865
  # Initialize weights and apply final processing
866
  self.post_init()
867
 
868
- def parallelize(self, device_map=None):
869
- self.model.decoder.device_map = (
870
- get_device_map(len(self.model.decoder.layers), range(torch.cuda.device_count()))
871
- if device_map is None
872
- else device_map
873
- )
874
- assert_device_map(self.model.decoder.device_map, len(self.model.decoder.layers))
875
- self.model.parallelize(self.model.decoder.device_map)
876
- self.lm_head = self.lm_head.to(self.model.decoder.first_device)
877
- self.model_parallel = True
878
-
879
- def deparallelize(self):
880
- self.model.deparallelize()
881
- self.model = self.model.to("cpu")
882
- self.lm_head = self.lm_head.to("cpu")
883
- self.model_parallel = False
884
- torch.cuda.empty_cache()
885
-
886
  def get_input_embeddings(self):
887
  return self.model.decoder.embed_tokens
888
 
@@ -901,7 +1179,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
901
  def get_decoder(self):
902
  return self.model.decoder
903
 
904
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
905
  def forward(
906
  self,
907
  input_ids: torch.LongTensor = None,
@@ -920,33 +1200,25 @@ class OPTForCausalLM(OPTPreTrainedModel):
920
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
921
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
922
  provide it.
923
-
924
- Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
925
  [`PreTrainedTokenizer.__call__`] for details.
926
-
927
  [What are input IDs?](../glossary#input-ids)
928
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
929
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
930
-
931
  - 1 for tokens that are **not masked**,
932
  - 0 for tokens that are **masked**.
933
-
934
  [What are attention masks?](../glossary#attention-mask)
935
  head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
936
  Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
937
-
938
  - 1 indicates the head is **not masked**,
939
  - 0 indicates the head is **masked**.
940
-
941
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
942
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
943
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
944
  shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
945
  tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
946
-
947
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
948
  cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
949
-
950
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
951
  that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
952
  all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -969,31 +1241,33 @@ class OPTForCausalLM(OPTPreTrainedModel):
969
  for more detail.
970
  return_dict (`bool`, *optional*):
971
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
972
-
973
  Returns:
974
-
975
  Example:
976
-
977
  ```python
978
- >>> from transformers import GPT2Tokenizer, OPTForCausalLM
979
-
980
  >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
981
- >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
982
-
983
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
984
  >>> inputs = tokenizer(prompt, return_tensors="pt")
985
-
986
  >>> # Generate
987
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
988
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
989
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
990
  ```"""
991
 
992
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
993
  output_hidden_states = (
994
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  )
996
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
997
 
998
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
999
  outputs = self.model.decoder(
@@ -1008,11 +1282,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
1008
  return_dict=return_dict,
1009
  )
1010
 
1011
- # Set device for model parallelism
1012
- if self.model.decoder.model_parallel:
1013
- torch.cuda.set_device(self.model.decoder.first_device)
1014
-
1015
- logits = self.lm_head(outputs[0].to(self.lm_head.weight.device)).contiguous()
1016
 
1017
  loss = None
1018
  if labels is not None:
@@ -1023,7 +1293,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
1023
  shift_labels = labels[..., 1:].contiguous()
1024
  # Flatten the tokens
1025
  loss_fct = CrossEntropyLoss()
1026
- loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
 
1027
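The loss block follows the usual causal-LM recipe: logits at position `t` are scored against the label at `t + 1`, then both are flattened for `CrossEntropyLoss`. A small self-contained illustration with toy shapes (not tied to the model):

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Drop the last logit and the first label so position t predicts token t + 1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())
```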
 
1028
  if not return_dict:
1029
  output = (logits,) + outputs[1:]
@@ -1037,36 +1309,59 @@ class OPTForCausalLM(OPTPreTrainedModel):
1037
  attentions=outputs.attentions,
1038
  )
1039
 
1040
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
1041
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1042
- if attention_mask is None:
1043
- attention_mask = input_ids.new_ones(input_ids.shape)
1044
-
1045
- if past:
1046
- input_ids = input_ids[:, -1:]
1047
- # first step, decoder_cached_states are empty
1048
- return {
1049
- "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
1050
- "attention_mask": attention_mask,
1051
- "past_key_values": past,
1052
- "use_cache": use_cache,
1053
- }
 
1054
 
1055
  @staticmethod
1056
- def _reorder_cache(past, beam_idx):
1057
  reordered_past = ()
1058
- for layer_past in past:
1059
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
 
  return reordered_past
1061
 
1062
 
1063
  @add_start_docstrings(
1064
  """
1065
  The OPT Model transformer with a sequence classification head on top (linear layer).
1066
-
1067
  [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1068
  (e.g. GPT-2) do.
1069
-
1070
  Since it does classification on the last token, it requires to know the position of the last token. If a
1071
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1072
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
@@ -1076,8 +1371,6 @@ class OPTForCausalLM(OPTPreTrainedModel):
1076
  OPT_START_DOCSTRING,
1077
  )
1078
  class OPTForSequenceClassification(OPTPreTrainedModel):
1079
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1080
-
1081
  def __init__(self, config: OPTConfig):
1082
  super().__init__(config)
1083
  self.num_labels = config.num_labels
@@ -1089,7 +1382,6 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1089
 
1090
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1091
  @add_code_sample_docstrings(
1092
- processor_class=_TOKENIZER_FOR_DOC,
1093
  checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1094
  output_type=SequenceClassifierOutputWithPast,
1095
  config_class=_CONFIG_FOR_DOC,
@@ -1115,7 +1407,9 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1115
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1116
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1117
  """
1118
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1119
 
1120
  transformer_outputs = self.model(
1121
  input_ids,
@@ -1140,7 +1434,12 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1140
  sequence_lengths = -1
1141
  else:
1142
  if input_ids is not None:
1143
- sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
 
1144
  else:
1145
  sequence_lengths = -1
1146
  logger.warning(
@@ -1148,14 +1447,18 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1148
  "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1149
  )
1150
 
1151
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1152
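For classification the model pools the logits of the last non-padding token of each row; when `pad_token_id` is configured, that index is recovered by counting non-pad tokens. A toy sketch of the pooling step (hypothetical values):

```python
import torch

pad_token_id = 1
input_ids = torch.tensor([[5, 6, 7, pad_token_id, pad_token_id],
                          [8, 9, pad_token_id, pad_token_id, pad_token_id]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# Index of the last non-pad token in every row: tensor([2, 1]).
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
pooled_logits = logits[torch.arange(logits.size(0)), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```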
 
1153
  loss = None
1154
  if labels is not None:
1155
  if self.config.problem_type is None:
1156
  if self.num_labels == 1:
1157
  self.config.problem_type = "regression"
1158
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1159
  self.config.problem_type = "single_label_classification"
1160
  else:
1161
  self.config.problem_type = "multi_label_classification"
@@ -1168,7 +1471,9 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1168
  loss = loss_fct(pooled_logits, labels)
1169
  elif self.config.problem_type == "single_label_classification":
1170
  loss_fct = CrossEntropyLoss()
1171
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1172
  elif self.config.problem_type == "multi_label_classification":
1173
  loss_fct = BCEWithLogitsLoss()
1174
  loss = loss_fct(pooled_logits, labels)
@@ -1199,8 +1504,6 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1199
  OPT_START_DOCSTRING,
1200
  )
1201
  class OPTForQuestionAnswering(OPTPreTrainedModel):
1202
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1203
-
1204
  def __init__(self, config: OPTConfig):
1205
  super().__init__(config)
1206
  self.model = OPTModel(config)
@@ -1210,7 +1513,9 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
1210
  self.post_init()
1211
 
1212
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1213
- @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
 
 
1214
  def forward(
1215
  self,
1216
  input_ids: Optional[torch.LongTensor] = None,
@@ -1234,37 +1539,33 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
1234
  Labels for position (index) of the end of the labelled span for computing the token classification loss.
1235
  Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1236
  are not taken into account for computing the loss.
1237
-
1238
  Returns:
1239
-
1240
  Example:
1241
-
1242
  ```python
1243
- >>> from transformers import GPT2Tokenizer, OPTForQuestionAnswering
1244
  >>> import torch
1245
-
1246
  >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
1247
- >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
1248
-
1249
  >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
1250
  >>> # so the head will be randomly initialized, hence the predictions will be random
1251
  >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
1252
-
1253
  >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
1254
-
1255
  >>> inputs = tokenizer(question, text, return_tensors="pt")
1256
  >>> with torch.no_grad():
1257
  ... outputs = model(**inputs)
1258
-
1259
  >>> answer_start_index = outputs.start_logits.argmax()
1260
  >>> answer_end_index = outputs.end_logits.argmax()
1261
-
1262
- >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
 
 
1263
  >>> predicted = tokenizer.decode(predict_answer_tokens)
1264
  >>> predicted
1265
- ' Henson?'
1266
  ```"""
1267
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1268
 
1269
  transformer_outputs = self.model(
1270
  input_ids,
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  """ PyTorch OPT model."""
 
16
  from typing import List, Optional, Tuple, Union
17
 
18
  import torch
19
+ import torch.nn.functional as F
20
  import torch.utils.checkpoint
21
  from torch import nn
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
 
24
  from transformers.activations import ACT2FN
25
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
26
  from transformers.modeling_outputs import (
27
  BaseModelOutputWithPast,
28
  CausalLMOutputWithPast,
 
34
  add_code_sample_docstrings,
35
  add_start_docstrings,
36
  add_start_docstrings_to_model_forward,
37
+ is_flash_attn_2_available,
38
+ is_flash_attn_greater_or_equal_2_10,
39
  logging,
40
  replace_return_docstrings,
41
  )
42
  from .configuration_opt import OPTConfig
43
+
44
+
45
+ if is_flash_attn_2_available():
46
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
47
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
48
 
49
 
50
  logger = logging.get_logger(__name__)
51
 
52
  _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
53
  _CONFIG_FOR_DOC = "OPTConfig"
 
54
 
55
  # Base model docstring
56
  _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
 
71
  # See all OPT models at https://huggingface.co/models?filter=opt
72
  ]
73
 
 
74
 
75
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
76
+ def _get_unpad_data(attention_mask):
77
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
78
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
79
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
80
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
81
+ return (
82
+ indices,
83
+ cu_seqlens,
84
+ max_seqlen_in_batch,
85
+ )
86
 
87
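The added `_get_unpad_data` converts a padding mask into the flat token indices and cumulative sequence lengths (`cu_seqlens`) that the varlen flash-attention kernels expect. A small worked example of the same computation on a toy mask:

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
max_seqlen_in_batch = seqlens_in_batch.max().item()                 # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

print(indices, cu_seqlens, max_seqlen_in_batch)
```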
 
88
+ # class OPTLearnedPositionalEmbedding(nn.Embedding):
89
+ # """
90
+ # This module learns positional embeddings up to a fixed maximum size.
91
+ # """
92
+
93
+ # def __init__(self, num_embeddings: int, embedding_dim: int):
94
+ # # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
95
+ # # and adjust num_embeddings appropriately. Other models don't have this hack
96
+ # self.offset = 2
97
+ # super().__init__(num_embeddings + self.offset, embedding_dim)
98
+
99
+ # def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
100
+ # """`input_ids_shape` is expected to be [bsz x seqlen]."""
101
+ # attention_mask = attention_mask.long()
102
 
103
+ # # create positions depending on attention_mask
104
+ # positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
105
 
106
+ # # cut positions if `past_key_values_length` is > 0
107
+ # positions = positions[:, past_key_values_length:]
108
 
109
+ # return super().forward(positions + self.offset)
110
 
111
 
112
+ class OPTLearnedPositionalEmbedding(nn.Module):
113
  """
114
  This module learns positional embeddings up to a fixed maximum size.
115
  """
 
117
  def __init__(self, num_embeddings: int, embedding_dim: int):
118
  # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
119
  # and adjust num_embeddings appropriately. Other models don't have this hack
120
+ super().__init__()
121
  self.offset = 2
122
+ self.embeddings = nn.Embedding(num_embeddings + self.offset, embedding_dim)
123
 
124
+ def forward(
125
+ self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
126
+ ):
127
  """`input_ids_shape` is expected to be [bsz x seqlen]."""
128
  attention_mask = attention_mask.long()
129
 
130
  # create positions depending on attention_mask
131
+ positions = (
132
+ torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
133
+ ).long() - 1
134
 
135
  # cut positions if `past_key_values_length` is > 0
136
  positions = positions[:, past_key_values_length:]
137
 
138
+ return self.embeddings(positions + self.offset)
139
 
140
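The rewritten `OPTLearnedPositionalEmbedding` wraps an `nn.Embedding` instead of subclassing it, but keeps OPT's historical quirk: positions are derived from the attention mask (padding ends up at -1) and every id is shifted by 2 before the lookup. A quick sketch of the position computation on a hypothetical left-padded mask:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens
offset = 2

# cumsum over the mask gives 1-based positions for real tokens; padding becomes -1.
positions = (torch.cumsum(attention_mask, dim=1) * attention_mask).long() - 1
print(positions)           # tensor([[-1, -1,  0,  1,  2]])
print(positions + offset)  # ids actually looked up: tensor([[1, 1, 2, 3, 4]])
```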
 
141
  class OPTAttention(nn.Module):
 
143
 
144
  def __init__(
145
  self,
146
+ config: OPTConfig,
 
 
147
  is_decoder: bool = False,
148
+ **kwargs,
149
  ):
150
  super().__init__()
151
+ self.config = config
152
+
153
+ def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
154
+ """
155
+ If the deprecated argument `fn_arg_name` is passed, raise a deprecation
156
+ warning and return that value, otherwise take the equivalent config.config_arg_name
157
+ """
158
+ val = None
159
+ if fn_arg_name in kwargs:
160
+ logging.warning(
161
+ "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
162
+ "v4.39. Please set it in the config instead"
163
+ )
164
+ val = kwargs.pop(fn_arg_name)
165
+ else:
166
+ val = getattr(config, config_arg_name)
167
+ return val
168
 
169
+ self.embed_dim = _handle_deprecated_argument(
170
+ "hidden_size", config, "embed_dim", kwargs
171
+ )
172
+ self.num_heads = _handle_deprecated_argument(
173
+ "num_attention_heads", config, "num_heads", kwargs
174
+ )
175
+ self.dropout = _handle_deprecated_argument(
176
+ "attention_dropout", config, "dropout", kwargs
177
+ )
178
+ self.enable_bias = _handle_deprecated_argument(
179
+ "enable_bias", config, "bias", kwargs
180
+ )
181
+
182
+ self.head_dim = self.embed_dim // self.num_heads
183
+ self.is_causal = True
184
+
185
+ if (self.head_dim * self.num_heads) != self.embed_dim:
186
  raise ValueError(
187
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
188
+ f" and `num_heads`: {self.num_heads})."
189
  )
190
  self.scaling = self.head_dim**-0.5
191
  self.is_decoder = is_decoder
192
 
193
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
194
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
195
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
196
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
 
 
 
197
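`OPTAttention.__init__` now reads its hyperparameters from the config and only honours the old keyword arguments after emitting a deprecation warning. A minimal sketch of that fallback pattern, using a hypothetical `SimpleNamespace` config instead of `OPTConfig`:

```python
import warnings
from types import SimpleNamespace

def resolve_arg(config, config_attr: str, fn_arg_name: str, kwargs: dict):
    # Prefer an explicitly passed (deprecated) kwarg, otherwise read the config attribute.
    if fn_arg_name in kwargs:
        warnings.warn(f"`{fn_arg_name}` is deprecated; set `{config_attr}` on the config instead.")
        return kwargs.pop(fn_arg_name)
    return getattr(config, config_attr)

config = SimpleNamespace(hidden_size=768, num_attention_heads=12)
embed_dim = resolve_arg(config, "hidden_size", "embed_dim", {"embed_dim": 512})  # warns, returns 512
num_heads = resolve_arg(config, "num_attention_heads", "num_heads", {})          # returns 12
```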
 
198
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
199
+ return (
200
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
201
+ .transpose(1, 2)
202
+ .contiguous()
203
+ )
204
 
205
  def forward(
206
  self,
 
270
  raise ValueError(
271
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
272
  )
273
+ attn_weights = (
274
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
275
+ + attention_mask
276
+ )
277
+ attn_weights = torch.max(
278
+ attn_weights,
279
+ torch.tensor(
280
+ torch.finfo(attn_weights.dtype).min, device=attn_weights.device
281
+ ),
282
+ )
283
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
284
 
285
  # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
286
  if attn_weights.dtype == torch.float16:
287
+ attn_weights = nn.functional.softmax(
288
+ attn_weights, dim=-1, dtype=torch.float32
289
+ ).to(torch.float16)
290
  else:
291
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
292
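The branch above computes the softmax in float32 whenever the attention scores arrive in float16, then casts the probabilities back; this avoids overflow/underflow in half precision. A tiny standalone illustration of the same trick:

```python
import torch
from torch import nn

scores = torch.randn(2, 4, 4, dtype=torch.float16) * 10

# Softmax in float32 for numerical stability, then return to float16.
probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(torch.float16)
print(probs.sum(dim=-1))  # each row sums to ~1
```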
 
293
  if layer_head_mask is not None:
294
  if layer_head_mask.size() != (self.num_heads,):
 
296
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
297
  f" {layer_head_mask.size()}"
298
  )
299
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
300
+ bsz, self.num_heads, tgt_len, src_len
301
+ )
302
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
303
 
304
  if output_attentions:
 
306
  # make sure that attn_weights keeps its gradient.
307
  # In order to do so, attn_weights have to be reshaped
308
  # twice and have to be reused in the following
309
+ attn_weights_reshaped = attn_weights.view(
310
+ bsz, self.num_heads, tgt_len, src_len
311
+ )
312
+ attn_weights = attn_weights_reshaped.view(
313
+ bsz * self.num_heads, tgt_len, src_len
314
+ )
315
  else:
316
  attn_weights_reshaped = None
317
 
318
+ attn_probs = nn.functional.dropout(
319
+ attn_weights, p=self.dropout, training=self.training
320
+ )
321
+
322
  attn_output = torch.bmm(attn_probs, value_states)
323
 
324
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
 
339
  return attn_output, attn_weights_reshaped, past_key_value
340
 
341
 
342
+ class OptFlashAttention2(OPTAttention):
343
+ """
344
+ OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
345
+ The only required change would be on the forward pass where it needs to correctly call the public API of flash
346
+ attention and deal with padding tokens in case the input contains any of them.
347
+ """
348
+
349
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
350
+ def __init__(self, *args, **kwargs):
351
+ super().__init__(*args, **kwargs)
352
+
353
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
354
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
355
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
356
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
357
+
358
+ def forward(
359
+ self,
360
+ hidden_states: torch.Tensor,
361
+ key_value_states: Optional[torch.Tensor] = None,
362
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
363
+ attention_mask: Optional[torch.Tensor] = None,
364
+ layer_head_mask: Optional[torch.Tensor] = None,
365
+ output_attentions: bool = False,
366
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
367
+ """Input shape: Batch x Time x Channel"""
368
+
369
+ # if key_value_states are provided this layer is used as a cross-attention layer
370
+ # for the decoder
371
+ is_cross_attention = key_value_states is not None
372
+
373
+ bsz, _, _ = hidden_states.size()
374
+
375
+ # get query proj
376
+ query_states = self.q_proj(hidden_states)
377
+ # get key, value proj
378
+ if is_cross_attention and past_key_value is not None:
379
+ # reuse k,v, cross_attentions
380
+ key_states = past_key_value[0]
381
+ value_states = past_key_value[1]
382
+ elif is_cross_attention:
383
+ # cross_attentions
384
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
385
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
386
+ elif past_key_value is not None:
387
+ # reuse k, v, self_attention
388
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
389
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
390
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
391
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
392
+ else:
393
+ # self_attention
394
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
395
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
396
+
397
+ if self.is_decoder:
398
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
399
+ # Further calls to cross_attention layer can then reuse all cross-attention
400
+ # key/value_states (first "if" case)
401
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
402
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
403
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
404
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
405
+ past_key_value = (key_states, value_states)
406
+
407
+ query_length = query_states.shape[1]
408
+ tgt_len = key_states.shape[-2]
409
+
410
+ # Flash attention requires the input to have the shape
411
+ # batch_size x seq_length x head_dim x hidden_dim
412
+ query_states = query_states.view(
413
+ bsz, query_length, self.num_heads, self.head_dim
414
+ )
415
+ key_states = key_states.transpose(1, 2).view(
416
+ bsz, tgt_len, self.num_heads, self.head_dim
417
+ )
418
+ value_states = value_states.transpose(1, 2).view(
419
+ bsz, tgt_len, self.num_heads, self.head_dim
420
+ )
421
+
422
+ attn_dropout = self.dropout if self.training else 0.0
423
+
424
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
425
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
426
+ # cast them back in float16 just to be sure everything works as expected.
427
+ input_dtype = query_states.dtype
428
+ if input_dtype == torch.float32:
429
+ if torch.is_autocast_enabled():
430
+ target_dtype = torch.get_autocast_gpu_dtype()
431
+ # Handle the case where the model is quantized
432
+ elif hasattr(self.config, "_pre_quantization_dtype"):
433
+ target_dtype = self.config._pre_quantization_dtype
434
+ else:
435
+ target_dtype = self.q_proj.weight.dtype
436
+
437
+ logger.warning_once(
438
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
439
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
440
+ f" {target_dtype}."
441
+ )
442
+
443
+ query_states = query_states.to(target_dtype)
444
+ key_states = key_states.to(target_dtype)
445
+ value_states = value_states.to(target_dtype)
446
+
447
+ attn_output = self._flash_attention_forward(
448
+ query_states,
449
+ key_states,
450
+ value_states,
451
+ attention_mask,
452
+ query_length,
453
+ dropout=attn_dropout,
454
+ )
455
+
456
+ attn_weights_reshaped = attn_output.reshape(
457
+ bsz, query_length, self.num_heads * self.head_dim
458
+ )
459
+ attn_output = self.out_proj(attn_weights_reshaped)
460
+
461
+ if not output_attentions:
462
+ attn_weights_reshaped = None
463
+
464
+ return attn_output, attn_weights_reshaped, past_key_value
465
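The float32 guard in the forward pass above casts q/k/v back to a working dtype when hidden states were silently promoted (for example because layer norms are kept in fp32 for PEFT training). A condensed sketch of how such a target dtype can be chosen; `pick_target_dtype` is a hypothetical helper and skips the quantization branch:

```python
import torch
from torch import nn

def pick_target_dtype(module: nn.Linear) -> torch.dtype:
    # Prefer the autocast dtype when autocast is active, otherwise fall back
    # to the dtype of the module's own weights (here standing in for q_proj).
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return module.weight.dtype

proj = nn.Linear(8, 8).half()
hidden = torch.randn(2, 8, dtype=torch.float32)   # silently upcasted input
hidden = hidden.to(pick_target_dtype(proj))       # cast back to float16
print(hidden.dtype)
```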
+
466
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
467
+ def _flash_attention_forward(
468
+ self,
469
+ query_states,
470
+ key_states,
471
+ value_states,
472
+ attention_mask,
473
+ query_length,
474
+ dropout=0.0,
475
+ softmax_scale=None,
476
+ ):
477
+ """
478
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
479
+ first unpad the input, then computes the attention scores and pad the final attention scores.
480
+ Args:
481
+ query_states (`torch.Tensor`):
482
+ Input query states to be passed to Flash Attention API
483
+ key_states (`torch.Tensor`):
484
+ Input key states to be passed to Flash Attention API
485
+ value_states (`torch.Tensor`):
486
+ Input value states to be passed to Flash Attention API
487
+ attention_mask (`torch.Tensor`):
488
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
489
+ position of padding tokens and 1 for the position of non-padding tokens.
490
+ dropout (`int`, *optional*):
491
+ Attention dropout
492
+ softmax_scale (`float`, *optional*):
493
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
494
+ """
495
+ if not self._flash_attn_uses_top_left_mask:
496
+ causal = self.is_causal
497
+ else:
498
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
499
+ causal = self.is_causal and query_length != 1
500
+
501
+ # Contains at least one padding token in the sequence
502
+ if attention_mask is not None:
503
+ batch_size = query_states.shape[0]
504
+ (
505
+ query_states,
506
+ key_states,
507
+ value_states,
508
+ indices_q,
509
+ cu_seq_lens,
510
+ max_seq_lens,
511
+ ) = self._upad_input(
512
+ query_states, key_states, value_states, attention_mask, query_length
513
+ )
514
+
515
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
516
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
517
+
518
+ attn_output_unpad = flash_attn_varlen_func(
519
+ query_states,
520
+ key_states,
521
+ value_states,
522
+ cu_seqlens_q=cu_seqlens_q,
523
+ cu_seqlens_k=cu_seqlens_k,
524
+ max_seqlen_q=max_seqlen_in_batch_q,
525
+ max_seqlen_k=max_seqlen_in_batch_k,
526
+ dropout_p=dropout,
527
+ softmax_scale=softmax_scale,
528
+ causal=causal,
529
+ )
530
+
531
+ attn_output = pad_input(
532
+ attn_output_unpad, indices_q, batch_size, query_length
533
+ )
534
+ else:
535
+ attn_output = flash_attn_func(
536
+ query_states,
537
+ key_states,
538
+ value_states,
539
+ dropout,
540
+ softmax_scale=softmax_scale,
541
+ causal=causal,
542
+ )
543
+
544
+ return attn_output
545
+
546
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
547
+ def _upad_input(
548
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
549
+ ):
550
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
551
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
552
+
553
+ key_layer = index_first_axis(
554
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
555
+ indices_k,
556
+ )
557
+ value_layer = index_first_axis(
558
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
559
+ indices_k,
560
+ )
561
+ if query_length == kv_seq_len:
562
+ query_layer = index_first_axis(
563
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
564
+ indices_k,
565
+ )
566
+ cu_seqlens_q = cu_seqlens_k
567
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
568
+ indices_q = indices_k
569
+ elif query_length == 1:
570
+ max_seqlen_in_batch_q = 1
571
+ cu_seqlens_q = torch.arange(
572
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
573
+ ) # There is a memcpy here, that is very bad.
574
+ indices_q = cu_seqlens_q[:-1]
575
+ query_layer = query_layer.squeeze(1)
576
+ else:
577
+ # The -q_len: slice assumes left padding.
578
+ attention_mask = attention_mask[:, -query_length:]
579
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
580
+ query_layer, attention_mask
581
+ )
582
+
583
+ return (
584
+ query_layer,
585
+ key_layer,
586
+ value_layer,
587
+ indices_q,
588
+ (cu_seqlens_q, cu_seqlens_k),
589
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
590
+ )
591
+
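To make the unpadding bookkeeping concrete: `_upad_input` relies on `_get_unpad_data` (defined earlier in this file) to turn the 2D padding mask into flattened indices, cumulative sequence lengths, and the batch maximum. A small hand-worked sketch of those quantities (illustrative values only):

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[0, 1, 1, 1],   # 3 real tokens (left padding)
                               [1, 1, 1, 1]])  # 4 real tokens
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 4])
max_seqlen_in_batch = seqlens_in_batch.max().item()                # 4
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 7], dtype=torch.int32) -> offsets into the unpadded token stream
```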
592
+
593
+ OPT_ATTENTION_CLASSES = {
594
+ "eager": OPTAttention,
595
+ "flash_attention_2": OptFlashAttention2,
596
+ }
597
+
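The key used to index `OPT_ATTENTION_CLASSES` comes from `config._attn_implementation`, which recent `transformers` releases set from the `attn_implementation` argument of `from_pretrained`. A hedged usage sketch (assumes `flash-attn` is installed and a CUDA device is available):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # omit to fall back to "eager" (OPTAttention)
)
```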
598
+
599
  class OPTDecoderLayer(nn.Module):
600
  def __init__(self, config: OPTConfig):
601
  super().__init__()
602
  self.embed_dim = config.hidden_size
603
+
604
+ self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
605
+ config=config, is_decoder=True
606
  )
607
+
608
  self.do_layer_norm_before = config.do_layer_norm_before
609
+ self.dropout = config.dropout
610
  self.activation_fn = ACT2FN[config.activation_function]
611
 
612
+ self.self_attn_layer_norm = nn.LayerNorm(
613
+ self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
614
+ )
615
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
616
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
617
+ self.final_layer_norm = nn.LayerNorm(
618
+ self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
619
+ )
620
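`do_layer_norm_before` chooses between pre-norm and post-norm ordering in the forward pass below; as of writing, the public facebook/opt-350m checkpoint is the post-norm outlier. A hedged check against the released configs:

```python
from transformers import OPTConfig

print(OPTConfig.from_pretrained("facebook/opt-350m").do_layer_norm_before)  # False (post-norm)
print(OPTConfig.from_pretrained("facebook/opt-125m").do_layer_norm_before)  # True  (pre-norm)
```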
 
621
  def forward(
622
  self,
623
  hidden_states: torch.Tensor,
624
  attention_mask: Optional[torch.Tensor] = None,
625
  layer_head_mask: Optional[torch.Tensor] = None,
626
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
627
  output_attentions: Optional[bool] = False,
628
  use_cache: Optional[bool] = False,
629
+ ) -> Tuple[
630
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
631
+ ]:
632
  """
633
  Args:
634
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
659
  layer_head_mask=layer_head_mask,
660
  output_attentions=output_attentions,
661
  )
662
+ hidden_states = nn.functional.dropout(
663
+ hidden_states, p=self.dropout, training=self.training
664
+ )
665
  hidden_states = residual + hidden_states
666
 
667
  # 350m applies layer norm AFTER attention
 
681
  hidden_states = self.activation_fn(hidden_states)
682
 
683
  hidden_states = self.fc2(hidden_states)
684
+ hidden_states = nn.functional.dropout(
685
+ hidden_states, p=self.dropout, training=self.training
686
+ )
687
 
688
  hidden_states = (residual + hidden_states).view(hidden_states_shape)
689
 
 
706
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
707
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
708
  etc.)
 
709
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
710
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
711
  and behavior.
 
712
  Parameters:
713
  config ([`OPTConfig`]):
714
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
726
  base_model_prefix = "model"
727
  supports_gradient_checkpointing = True
728
  _no_split_modules = ["OPTDecoderLayer"]
729
+ _supports_flash_attn_2 = True
730
 
731
  def _init_weights(self, module):
732
  std = self.config.init_std
 
739
  if module.padding_idx is not None:
740
  module.weight.data[module.padding_idx].zero_()
741
 
 
 
 
 
742
 
743
  OPT_INPUTS_DOCSTRING = r"""
744
  Args:
745
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
746
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
747
  it.
748
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
 
749
  [`PreTrainedTokenizer.__call__`] for details.
 
750
  [What are input IDs?](../glossary#input-ids)
751
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
752
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
753
  - 1 for tokens that are **not masked**,
754
  - 0 for tokens that are **masked**.
 
755
  [What are attention masks?](../glossary#attention-mask)
756
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
 
757
  [`PreTrainedTokenizer.__call__`] for details.
 
758
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
759
  `past_key_values`).
 
760
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
761
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
762
  information on the default strategy.
763
  head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
764
  Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
 
765
  - 1 indicates the head is **not masked**,
766
  - 0 indicates the head is **masked**.
 
767
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
768
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
769
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
770
  `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
 
771
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
772
  blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
 
773
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
774
  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
775
  `decoder_input_ids` of shape `(batch_size, sequence_length)`.
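For readers of the docstring above, the `input_ids`/`attention_mask` pair is normally produced by the tokenizer; a short, hedged sketch:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
batch = tokenizer(["Hello world", "Hi"], padding=True, return_tensors="pt")
print(batch.input_ids.shape)   # (batch_size, sequence_length)
print(batch.attention_mask)    # 1 = real token, 0 = padding
```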
 
794
  class OPTDecoder(OPTPreTrainedModel):
795
  """
796
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
 
797
  Args:
798
  config: OPTConfig
799
  """
 
806
  self.max_target_positions = config.max_position_embeddings
807
  self.vocab_size = config.vocab_size
808
 
809
+ self.embed_tokens = nn.Embedding(
810
+ config.vocab_size, config.word_embed_proj_dim, self.padding_idx
811
+ )
812
+ self._embed_positions = OPTLearnedPositionalEmbedding(
813
+ config.max_position_embeddings, config.hidden_size
814
+ )
815
+ self.embed_positions = self._embed_positions.embeddings
816
 
817
  if config.word_embed_proj_dim != config.hidden_size:
818
+ self.project_out = nn.Linear(
819
+ config.hidden_size, config.word_embed_proj_dim, bias=False
820
+ )
821
  else:
822
  self.project_out = None
823
 
824
  if config.word_embed_proj_dim != config.hidden_size:
825
+ self.project_in = nn.Linear(
826
+ config.word_embed_proj_dim, config.hidden_size, bias=False
827
+ )
828
  else:
829
  self.project_in = None
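`project_in`/`project_out` only exist when the embedding width differs from the hidden width; among the public checkpoints this is the facebook/opt-350m case (values below are from its released config and may differ for other sizes):

```python
from transformers import OPTConfig

cfg = OPTConfig.from_pretrained("facebook/opt-350m")
print(cfg.word_embed_proj_dim, cfg.hidden_size)  # 512 1024 -> project_in / project_out are created
```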
830
 
 
832
  # with checkpoints that have been fine-tuned before transformers v4.20.1
833
  # see https://github.com/facebookresearch/metaseq/pull/164
834
  if config.do_layer_norm_before and not config._remove_final_layer_norm:
835
+ self.final_layer_norm = nn.LayerNorm(
836
+ config.hidden_size,
837
+ elementwise_affine=config.layer_norm_elementwise_affine,
838
+ )
839
  else:
840
  self.final_layer_norm = None
841
 
842
+ self.layers = nn.ModuleList(
843
+ [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
844
+ )
845
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
846
 
847
  self.gradient_checkpointing = False
848
  # Initialize weights and apply final processing
 
854
  def set_input_embeddings(self, value):
855
  self.embed_tokens = value
856

857
  def forward(
858
  self,
859
  input_ids: torch.LongTensor = None,
 
871
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
872
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
873
  provide it.
874
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
 
875
  [`PreTrainedTokenizer.__call__`] for details.
 
876
  [What are input IDs?](../glossary#input-ids)
877
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
878
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
879
  - 1 for tokens that are **not masked**,
880
  - 0 for tokens that are **masked**.
 
881
  [What are attention masks?](../glossary#attention-mask)
882
  head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
883
  Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
 
884
  - 1 indicates the head is **not masked**,
885
  - 0 indicates the head is **masked**.
 
886
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
887
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
888
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
 
889
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
890
  cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
 
891
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
892
  that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
893
  all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
 
894
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
895
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
896
  This is useful if you want more control over how to convert `input_ids` indices into associated vectors
 
904
  return_dict (`bool`, *optional*):
905
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
906
  """
907
+ output_attentions = (
908
+ output_attentions
909
+ if output_attentions is not None
910
+ else self.config.output_attentions
911
+ )
912
  output_hidden_states = (
913
+ output_hidden_states
914
+ if output_hidden_states is not None
915
+ else self.config.output_hidden_states
916
  )
917
  use_cache = use_cache if use_cache is not None else self.config.use_cache
918
 
919
+ return_dict = (
920
+ return_dict if return_dict is not None else self.config.use_return_dict
921
+ )
922
 
923
  # retrieve input_ids and inputs_embeds
924
  if input_ids is not None and inputs_embeds is not None:
925
+ raise ValueError(
926
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
927
+ )
928
  elif input_ids is not None:
929
  input_shape = input_ids.size()
930
  input_ids = input_ids.view(-1, input_shape[-1])
931
  elif inputs_embeds is not None:
932
  input_shape = inputs_embeds.size()[:-1]
933
  else:
934
+ raise ValueError(
935
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
936
+ )
937
 
938
  if inputs_embeds is None:
939
  inputs_embeds = self.embed_tokens(input_ids)
940
 
941
+ batch_size, seq_length = input_shape
942
+ past_key_values_length = (
943
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
944
+ )
945
+ # required mask seq length can be calculated via length of past
946
+ mask_seq_length = past_key_values_length + seq_length
947
+
948
  # embed positions
949
+ if self._use_flash_attention_2:
950
+ # 2d mask is passed through the layers
951
+ causal_attention_mask = (
952
+ attention_mask
953
+ if (attention_mask is not None and 0 in attention_mask)
954
+ else None
955
+ )
956
+ attention_mask = (
957
+ torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
958
+ if attention_mask is None
959
+ else attention_mask
960
+ )
961
+ else:
962
+ # 4d mask is passed through the layers
963
+ if attention_mask is None:
964
+ attention_mask = torch.ones(
965
+ batch_size, mask_seq_length, device=inputs_embeds.device
966
+ )
967
+ elif attention_mask.shape[1] != mask_seq_length:
968
+ raise ValueError(
969
+ f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
970
+ f"{mask_seq_length} (sum of the lengths of current and past inputs)"
971
+ )
972
+ causal_attention_mask = _prepare_4d_causal_attention_mask(
973
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
974
+ )
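For the non-flash path just above, `_prepare_4d_causal_attention_mask` (from `transformers.modeling_attn_mask_utils` in recent releases) expands the 2D padding mask into an additive causal mask. A hedged sketch of its output shape:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

mask_2d = torch.tensor([[1, 1, 1, 1]])   # one unpadded sequence of length 4
dummy_embeds = torch.zeros(1, 4, 8)      # only dtype/device are used
mask_4d = _prepare_4d_causal_attention_mask(mask_2d, (1, 4), dummy_embeds, 0)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4]); future positions hold a large negative value
```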
975
 
976
+ pos_embeds = self._embed_positions(attention_mask, past_key_values_length)
 
 
977
 
978
  if self.project_in is not None:
979
  inputs_embeds = self.project_in(inputs_embeds)
980
 
981
  hidden_states = inputs_embeds + pos_embeds
982
 
983
+ if self.gradient_checkpointing and self.training:
984
+ if use_cache:
985
+ logger.warning_once(
986
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
987
+ )
988
+ use_cache = False
989
+
990
  # decoder layers
991
  all_hidden_states = () if output_hidden_states else None
992
  all_self_attns = () if output_attentions else None
 
1006
  if output_hidden_states:
1007
  all_hidden_states += (hidden_states,)
1008
 
1009
+ if self.training:
1010
+ dropout_probability = torch.rand([])
1011
+ if dropout_probability < self.layerdrop:
1012
+ continue
1013
 
1014
+ past_key_value = (
1015
+ past_key_values[idx] if past_key_values is not None else None
1016
+ )
1017
 
1018
  if self.gradient_checkpointing and self.training:
1019
+ layer_outputs = self._gradient_checkpointing_func(
1020
+ decoder_layer.__call__,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1021
  hidden_states,
1022
+ causal_attention_mask,
1023
  head_mask[idx] if head_mask is not None else None,
1024
  None,
1025
+ output_attentions,
1026
+ use_cache,
1027
  )
1028
  else:
 
1029
  layer_outputs = decoder_layer(
1030
  hidden_states,
1031
+ attention_mask=causal_attention_mask,
1032
  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1033
  past_key_value=past_key_value,
1034
  output_attentions=output_attentions,
 
1043
  if output_attentions:
1044
  all_self_attns += (layer_outputs[1],)
1045
 
 
 
 
 
 
 
1046
  if self.final_layer_norm is not None:
1047
  hidden_states = self.final_layer_norm(hidden_states)
1048
 
 
1055
 
1056
  next_cache = next_decoder_cache if use_cache else None
1057
  if not return_dict:
1058
+ return tuple(
1059
+ v
1060
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1061
+ if v is not None
1062
+ )
1063
  return BaseModelOutputWithPast(
1064
  last_hidden_state=hidden_states,
1065
  past_key_values=next_cache,
 
1076
  def __init__(self, config: OPTConfig):
1077
  super().__init__(config)
1078
  self.decoder = OPTDecoder(config)
 
 
 
 
 
 
1079
  # Initialize weights and apply final processing
1080
  self.post_init()
1081

1082
  def get_input_embeddings(self):
1083
  return self.decoder.embed_tokens
1084
 
 
1090
 
1091
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1092
  @add_code_sample_docstrings(
 
1093
  checkpoint=_CHECKPOINT_FOR_DOC,
1094
  output_type=BaseModelOutputWithPast,
1095
  config_class=_CONFIG_FOR_DOC,
 
1107
  output_hidden_states: Optional[bool] = None,
1108
  return_dict: Optional[bool] = None,
1109
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1110
+ output_attentions = (
1111
+ output_attentions
1112
+ if output_attentions is not None
1113
+ else self.config.output_attentions
1114
+ )
1115
  output_hidden_states = (
1116
+ output_hidden_states
1117
+ if output_hidden_states is not None
1118
+ else self.config.output_hidden_states
1119
  )
1120
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1121
+ return_dict = (
1122
+ return_dict if return_dict is not None else self.config.use_return_dict
1123
+ )
1124
 
1125
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1126
  decoder_outputs = self.decoder(
 
1147
 
1148
 
1149
  class OPTForCausalLM(OPTPreTrainedModel):
1150
+ _tied_weights_keys = ["lm_head.weight"]
1151
 
1152
  def __init__(self, config):
1153
  super().__init__(config)
1154
  self.model = OPTModel(config)
1155
 
1156
  # the lm_head weight is automatically tied to the embed tokens weight
1157
+ self.lm_head = nn.Linear(
1158
+ config.word_embed_proj_dim, config.vocab_size, bias=False
1159
+ )
 
 
1160
 
1161
  # Initialize weights and apply final processing
1162
  self.post_init()
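The tying noted in the comment above means `lm_head.weight` and the decoder input embeddings end up as the same parameter when `config.tie_word_embeddings` is left at its default; a hedged check with the stock `transformers` implementation this file tracks:

```python
from transformers import OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
print(model.lm_head.weight is model.model.decoder.embed_tokens.weight)  # expected: True
```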
1163

1164
  def get_input_embeddings(self):
1165
  return self.model.decoder.embed_tokens
1166
 
 
1179
  def get_decoder(self):
1180
  return self.model.decoder
1181
 
1182
+ @replace_return_docstrings(
1183
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1184
+ )
1185
  def forward(
1186
  self,
1187
  input_ids: torch.LongTensor = None,
 
1200
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1201
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1202
  provide it.
1203
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
 
1204
  [`PreTrainedTokenizer.__call__`] for details.
 
1205
  [What are input IDs?](../glossary#input-ids)
1206
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1207
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
1208
  - 1 for tokens that are **not masked**,
1209
  - 0 for tokens that are **masked**.
 
1210
  [What are attention masks?](../glossary#attention-mask)
1211
  head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
1212
  Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
 
1213
  - 1 indicates the head is **not masked**,
1214
  - 0 indicates the head is **masked**.
 
1215
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1216
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1217
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1218
  shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1219
  tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
 
1220
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1221
  cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
 
1222
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1223
  that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1224
  all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
 
1241
  for more detail.
1242
  return_dict (`bool`, *optional*):
1243
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 
1244
  Returns:
 
1245
  Example:
 
1246
  ```python
1247
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
 
1248
  >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
1249
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
1250
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
 
1251
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1252
  >>> # Generate
1253
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1254
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1255
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
1256
  ```"""
1257
 
1258
+ output_attentions = (
1259
+ output_attentions
1260
+ if output_attentions is not None
1261
+ else self.config.output_attentions
1262
+ )
1263
  output_hidden_states = (
1264
+ output_hidden_states
1265
+ if output_hidden_states is not None
1266
+ else self.config.output_hidden_states
1267
+ )
1268
+ return_dict = (
1269
+ return_dict if return_dict is not None else self.config.use_return_dict
1270
  )
 
1271
 
1272
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1273
  outputs = self.model.decoder(
 
1282
  return_dict=return_dict,
1283
  )
1284
 
1285
+ logits = self.lm_head(outputs[0]).contiguous()
 
 
 
 
1286
 
1287
  loss = None
1288
  if labels is not None:
 
1293
  shift_labels = labels[..., 1:].contiguous()
1294
  # Flatten the tokens
1295
  loss_fct = CrossEntropyLoss()
1296
+ loss = loss_fct(
1297
+ shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
1298
+ )
1299
 
1300
  if not return_dict:
1301
  output = (logits,) + outputs[1:]
 
1309
  attentions=outputs.attentions,
1310
  )
1311
 
1312
+ def prepare_inputs_for_generation(
1313
+ self,
1314
+ input_ids,
1315
+ past_key_values=None,
1316
+ attention_mask=None,
1317
+ inputs_embeds=None,
1318
+ **kwargs,
1319
+ ):
1320
+ if past_key_values is not None:
1321
+ past_length = past_key_values[0][0].shape[2]
1322
+
1323
+ # Some generation methods already pass only the last input ID
1324
+ if input_ids.shape[1] > past_length:
1325
+ remove_prefix_length = past_length
1326
+ else:
1327
+ # Default to old behavior: keep only final ID
1328
+ remove_prefix_length = input_ids.shape[1] - 1
1329
+
1330
+ input_ids = input_ids[:, remove_prefix_length:]
1331
+
1332
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1333
+ if inputs_embeds is not None and past_key_values is None:
1334
+ model_inputs = {"inputs_embeds": inputs_embeds}
1335
+ else:
1336
+ model_inputs = {"input_ids": input_ids}
1337
+
1338
+ model_inputs.update(
1339
+ {
1340
+ "past_key_values": past_key_values,
1341
+ "use_cache": kwargs.get("use_cache"),
1342
+ "attention_mask": attention_mask,
1343
+ }
1344
+ )
1345
+ return model_inputs
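The prefix-trimming above simply avoids re-encoding tokens whose key/value states are already cached. A toy illustration of the same arithmetic:

```python
import torch

input_ids = torch.tensor([[5, 6, 7, 8]])   # full sequence generated so far
past_length = 3                            # tokens already in the KV cache
remove_prefix_length = past_length if input_ids.shape[1] > past_length else input_ids.shape[1] - 1
print(input_ids[:, remove_prefix_length:])  # tensor([[8]]) -> only the new token is fed to the model
```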
1346
 
1347
  @staticmethod
1348
+ def _reorder_cache(past_key_values, beam_idx):
1349
  reordered_past = ()
1350
+ for layer_past in past_key_values:
1351
+ reordered_past += (
1352
+ tuple(
1353
+ past_state.index_select(0, beam_idx.to(past_state.device))
1354
+ for past_state in layer_past
1355
+ ),
1356
+ )
1357
  return reordered_past
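`_reorder_cache` gathers every cached key/value tensor along the batch dimension so the cache follows the beams chosen at each step of beam search. A minimal sketch with made-up shapes:

```python
import torch

layer_past = (torch.arange(6.).view(3, 1, 2, 1),   # (batch=3, heads=1, seq=2, head_dim=1)
              torch.arange(6.).view(3, 1, 2, 1))
beam_idx = torch.tensor([2, 0, 0])                 # beams selected at this step
reordered = tuple(p.index_select(0, beam_idx) for p in layer_past)
print(reordered[0][:, 0, :, 0])  # rows now come from batches 2, 0, 0
```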
1358
 
1359
 
1360
  @add_start_docstrings(
1361
  """
1362
  The OPT Model transformer with a sequence classification head on top (linear layer).
 
1363
  [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1364
  (e.g. GPT-2) do.
 
1365
  Since it does classification on the last token, it requires to know the position of the last token. If a
1366
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1367
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
 
1371
  OPT_START_DOCSTRING,
1372
  )
1373
  class OPTForSequenceClassification(OPTPreTrainedModel):
 
 
1374
  def __init__(self, config: OPTConfig):
1375
  super().__init__(config)
1376
  self.num_labels = config.num_labels
 
1382
 
1383
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1384
  @add_code_sample_docstrings(
 
1385
  checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1386
  output_type=SequenceClassifierOutputWithPast,
1387
  config_class=_CONFIG_FOR_DOC,
 
1407
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1408
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1409
  """
1410
+ return_dict = (
1411
+ return_dict if return_dict is not None else self.config.use_return_dict
1412
+ )
1413
 
1414
  transformer_outputs = self.model(
1415
  input_ids,
 
1434
  sequence_lengths = -1
1435
  else:
1436
  if input_ids is not None:
1437
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1438
+ sequence_lengths = (
1439
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1440
+ )
1441
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1442
+ sequence_lengths = sequence_lengths.to(logits.device)
1443
  else:
1444
  sequence_lengths = -1
1445
  logger.warning(
 
1447
  "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1448
  )
1449
 
1450
+ pooled_logits = logits[
1451
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1452
+ ]
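To see what the pad-token bookkeeping above computes, here is a hand-worked example (pad id 1 is just an assumed value); the modulo handles rows that contain no padding at all:

```python
import torch

pad_token_id = 1                                   # assumed for illustration
input_ids = torch.tensor([[5, 6, 7, 1, 1],         # padded row
                          [5, 6, 7, 8, 9]])        # no padding
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4]) -> index of the last real token in each row
```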
1453
 
1454
  loss = None
1455
  if labels is not None:
1456
  if self.config.problem_type is None:
1457
  if self.num_labels == 1:
1458
  self.config.problem_type = "regression"
1459
+ elif self.num_labels > 1 and (
1460
+ labels.dtype == torch.long or labels.dtype == torch.int
1461
+ ):
1462
  self.config.problem_type = "single_label_classification"
1463
  else:
1464
  self.config.problem_type = "multi_label_classification"
 
1471
  loss = loss_fct(pooled_logits, labels)
1472
  elif self.config.problem_type == "single_label_classification":
1473
  loss_fct = CrossEntropyLoss()
1474
+ loss = loss_fct(
1475
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1476
+ )
1477
  elif self.config.problem_type == "multi_label_classification":
1478
  loss_fct = BCEWithLogitsLoss()
1479
  loss = loss_fct(pooled_logits, labels)
 
1504
  OPT_START_DOCSTRING,
1505
  )
1506
  class OPTForQuestionAnswering(OPTPreTrainedModel):
 
 
1507
  def __init__(self, config: OPTConfig):
1508
  super().__init__(config)
1509
  self.model = OPTModel(config)
 
1513
  self.post_init()
1514
 
1515
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1516
+ @replace_return_docstrings(
1517
+ output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC
1518
+ )
1519
  def forward(
1520
  self,
1521
  input_ids: Optional[torch.LongTensor] = None,
 
1539
  Labels for position (index) of the end of the labelled span for computing the token classification loss.
1540
  Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1541
  are not taken into account for computing the loss.
 
1542
  Returns:
 
1543
  Example:
 
1544
  ```python
1545
+ >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
1546
  >>> import torch
 
1547
  >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
1548
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
 
1549
  >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
1550
  >>> # so the head will be randomly initialized, hence the predictions will be random
1551
  >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
 
1552
  >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
 
1553
  >>> inputs = tokenizer(question, text, return_tensors="pt")
1554
  >>> with torch.no_grad():
1555
  ... outputs = model(**inputs)
 
1556
  >>> answer_start_index = outputs.start_logits.argmax()
1557
  >>> answer_end_index = outputs.end_logits.argmax()
1558
+ >>> answer_offset = len(tokenizer(question)[0])
1559
+ >>> predict_answer_tokens = inputs.input_ids[
1560
+ ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
1561
+ ... ]
1562
  >>> predicted = tokenizer.decode(predict_answer_tokens)
1563
  >>> predicted
1564
+ ' a nice puppet'
1565
  ```"""
1566
+ return_dict = (
1567
+ return_dict if return_dict is not None else self.config.use_return_dict
1568
+ )
1569
 
1570
  transformer_outputs = self.model(
1571
  input_ids,