davda54 commited on
Commit
34b6faa
1 Parent(s): 805a541

Key-value caching

Browse files
Files changed (1) hide show
  1. modeling_nort5.py +209 -97
modeling_nort5.py CHANGED
@@ -1,19 +1,17 @@
1
- from __future__ import absolute_import, division, print_function, unicode_literals
2
-
3
  import math
4
  from typing import List, Optional, Tuple, Union
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
- from torch import _softmax_backward_data as _softmax_backward_data
10
  from torch.utils import checkpoint
11
 
12
  from configuration_nort5 import NorT5Config
13
  from transformers.modeling_utils import PreTrainedModel
14
  from transformers.activations import gelu_new
15
  from transformers.modeling_outputs import (
16
- Seq2SeqModelOutput, Seq2SeqLMOutput, BaseModelOutput
17
  )
18
 
19
 
@@ -58,18 +56,37 @@ class Decoder(nn.Module):
58
  layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
59
  layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
60
 
61
- def forward(self, x, encoder_output, encoder_padding_mask):
 
 
62
  self_relative_embedding = self.self_relative_embedding()
63
  cross_relative_embedding = self.cross_relative_embedding()
64
 
65
- autoreg_mask = torch.triu(
66
- torch.full((x.size(0), x.size(0)), True, device=x.device),
67
- diagonal=1
68
- )
 
 
 
69
 
70
- for layer in self.layers:
71
- x = layer(x, autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding)
72
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  class MaskClassifier(nn.Module):
@@ -95,11 +112,11 @@ class MaskClassifier(nn.Module):
95
  class EncoderLayer(nn.Module):
96
  def __init__(self, config):
97
  super().__init__()
98
- self.attention = Attention(config)
99
  self.mlp = FeedForward(config)
100
 
101
  def forward(self, x, padding_mask, relative_embedding):
102
- attention_output, attention_probs = self.attention(x, x, padding_mask, relative_embedding)
103
  x = x + attention_output
104
  x = x + self.mlp(x)
105
  return x, attention_probs
@@ -108,15 +125,26 @@ class EncoderLayer(nn.Module):
108
  class DecoderLayer(nn.Module):
109
  def __init__(self, config):
110
  super().__init__()
111
- self.self_attention = Attention(config)
112
- self.cross_attention = Attention(config)
113
  self.mlp = FeedForward(config)
114
 
115
- def forward(self, x, autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding):
116
- x = x + self.self_attention(x, x, autoreg_mask, self_relative_embedding)[0]
117
- x = x + self.cross_attention(x, encoder_output, encoder_padding_mask, cross_relative_embedding)[0]
 
 
 
 
 
 
 
 
 
 
118
  x = x + self.mlp(x)
119
- return x
 
120
 
121
 
122
  class GeGLU(nn.Module):
@@ -152,24 +180,27 @@ class MaskedSoftmax(torch.autograd.Function):
152
  @staticmethod
153
  def forward(self, x, mask, dim):
154
  self.dim = dim
155
- x.masked_fill_(mask, float('-inf'))
 
156
  x = torch.softmax(x, self.dim)
157
- x.masked_fill_(mask, 0.0)
 
158
  self.save_for_backward(x)
159
  return x
160
 
161
  @staticmethod
162
  def backward(self, grad_output):
163
  output, = self.saved_tensors
164
- inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
165
- return inputGrad, None, None
166
 
167
 
168
  class Attention(nn.Module):
169
- def __init__(self, config):
170
  super().__init__()
171
 
172
  self.config = config
 
173
 
174
  if config.hidden_size % config.num_attention_heads != 0:
175
  raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
@@ -186,9 +217,9 @@ class Attention(nn.Module):
186
  self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
187
  self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
188
 
189
- position_indices = torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(1) \
190
- - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
191
- position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
192
  position_indices = config.position_bucket_size - 1 + position_indices
193
  self.register_buffer("position_indices", position_indices, persistent=True)
194
 
@@ -215,59 +246,67 @@ class Attention(nn.Module):
215
  self.in_proj_v.bias.data.zero_()
216
  self.out_proj.bias.data.zero_()
217
 
218
- def compute_attention_scores(self, q, kv, relative_embedding):
219
  key_len, batch_size, _ = kv.size()
220
  query_len, _, _ = q.size()
221
 
222
- if self.position_indices.size(0) < query_len:
223
- position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
224
- - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
226
  position_indices = self.config.position_bucket_size - 1 + position_indices
227
  self.register_buffer("position_indices", position_indices.to(q.device), persistent=True)
228
 
229
- kv = self.pre_layer_norm(kv)
230
  q = self.pre_layer_norm(q)
231
-
232
  query = self.in_proj_q(q) # shape: [T, B, D]
233
- key = self.in_proj_k(kv) # shape: [T, B, D]
234
- value = self.in_proj_v(kv) # shape: [T, B, D]
235
-
236
- query_pos = self.in_proj_q(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
237
- query_pos = F.embedding(self.position_indices[:query_len, :key_len], query_pos) # shape: [T, T, 2D]
238
- query_pos = query_pos.view(query_len, key_len, self.num_heads, self.head_size)
239
-
240
- key_pos = self.in_proj_k(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
241
- key_pos = F.embedding(self.position_indices[:query_len, :key_len], key_pos) # shape: [T, T, 2D]
242
- key_pos = key_pos.view(query_len, key_len, self.num_heads, self.head_size)
243
-
244
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
245
- key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
246
- value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
247
 
248
  attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
249
-
250
- query = query.view(batch_size, self.num_heads, query_len, self.head_size)
251
- key = key.view(batch_size, self.num_heads, key_len, self.head_size)
 
 
 
 
 
 
 
 
 
 
 
 
252
  attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
253
- attention_scores.add_(torch.einsum("bhqd,qkhd->bhqk", query, key_pos * self.scale))
254
- attention_scores.add_(torch.einsum("bhkd,qkhd->bhqk", key * self.scale, query_pos))
255
 
256
- return attention_scores, value
257
 
258
- def compute_output(self, attention_probs, value):
259
  attention_probs = self.dropout(attention_probs)
260
  context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
261
  context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
262
  context = self.out_proj(context)
263
  context = self.post_layer_norm(context)
264
  context = self.dropout(context)
265
- return context
266
 
267
- def forward(self, q, kv, attention_mask, relative_embedding):
268
- attention_scores, value = self.compute_attention_scores(q, kv, relative_embedding)
269
- attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
270
- return self.compute_output(attention_probs, value), attention_probs.detach()
271
 
272
 
273
  class WordEmbedding(nn.Module):
@@ -348,8 +387,8 @@ class NorT5Model(NorT5PreTrainedModel):
348
  return self.get_encoder_output
349
 
350
  def get_decoder(self):
351
- return self.decoder
352
-
353
  def set_decoder_special_tokens(self, target_id):
354
  target_id.masked_fill_(target_id == self.cls_token_id, self.bos_token_id)
355
  target_id.masked_fill_(target_id == self.sep_token_id, self.eos_token_id)
@@ -359,12 +398,13 @@ class NorT5Model(NorT5PreTrainedModel):
359
  shifted_input_ids = input_ids.new_zeros(input_ids.shape)
360
  shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
361
  shifted_input_ids[..., 0] = self.bos_token_id
 
362
 
363
  return shifted_input_ids
364
 
365
  def get_encoder_output(
366
  self,
367
- input_ids: Optional[torch.Tensor] = None,
368
  attention_mask: Optional[torch.Tensor] = None,
369
  output_hidden_states: Optional[bool] = None,
370
  output_attentions: Optional[bool] = None,
@@ -394,16 +434,28 @@ class NorT5Model(NorT5PreTrainedModel):
394
  ]
395
 
396
  if not return_dict:
397
- return last_layer, contextualized_embeddings, attention_probs
398
-
 
 
 
 
399
  return BaseModelOutput(
400
  last_hidden_state=last_layer,
401
- hidden_states=contextualized_embeddings,
402
- attentions=attention_probs
403
  )
404
 
405
  def get_decoder_output(
406
- self, target_ids, encoder_output, attention_mask
 
 
 
 
 
 
 
 
407
  ):
408
  batch_size, seq_length, _ = encoder_output.shape
409
  device = target_ids.device
@@ -414,11 +466,37 @@ class NorT5Model(NorT5PreTrainedModel):
414
  attention_mask = ~attention_mask.bool()
415
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
416
 
417
- return self.decoder(
418
  self.embedding(target_ids.t()),
419
  encoder_output.transpose(0, 1),
420
- attention_mask
421
- ).transpose(0, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
423
  def forward(
424
  self,
@@ -426,28 +504,45 @@ class NorT5Model(NorT5PreTrainedModel):
426
  attention_mask: Optional[torch.FloatTensor] = None,
427
  decoder_input_ids: Optional[torch.LongTensor] = None,
428
  decoder_attention_mask: Optional[torch.BoolTensor] = None,
429
- return_dict: Optional[bool] = None,
 
 
 
 
 
430
  ):
431
 
432
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
433
 
434
  decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
435
 
436
- encoder_outputs, encoder_contextualized_embeddings, encoder_attention_probs = self.get_encoder_output(input_ids, attention_mask)
437
- decoder_outputs = self.get_decoder_output(decoder_input_ids, encoder_outputs, attention_mask)
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
  if not return_dict:
440
- return (decoder_outputs, encoder_outputs)
441
-
442
  return Seq2SeqModelOutput(
443
- last_hidden_state=decoder_outputs,
444
- past_key_values=None,
445
- decoder_hidden_states=None,
446
- decoder_attentions=None,
447
- cross_attentions=None,
448
- encoder_last_hidden_state=encoder_outputs,
449
- encoder_hidden_states=encoder_contextualized_embeddings,
450
- encoder_attentions=encoder_attention_probs,
451
  )
452
 
453
 
@@ -475,12 +570,19 @@ class NorT5ForConditionalGeneration(NorT5Model):
475
  output_hidden_states: Optional[bool] = None,
476
  return_dict: Optional[bool] = None,
477
  ):
478
-
479
- use_cache = False
480
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
481
 
482
  if encoder_outputs is None:
483
- encoder_outputs = self.get_encoder_output(input_ids, attention_mask, return_dict=True)
 
 
 
 
 
 
 
 
484
 
485
  if labels is not None:
486
  labels = self.set_decoder_special_tokens(labels)
@@ -490,24 +592,28 @@ class NorT5ForConditionalGeneration(NorT5Model):
490
  elif decoder_input_ids is not None:
491
  decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
492
 
493
- decoder_outputs = self.get_decoder_output(decoder_input_ids, encoder_outputs.last_hidden_state, attention_mask)
494
- lm_logits = self.classifier(decoder_outputs)
 
 
495
 
496
  loss = None
497
  if labels is not None:
498
- loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
 
499
  loss = loss_fct(lm_logits.flatten(0, 1), labels.flatten())
500
 
501
  if not return_dict:
502
- output = (lm_logits,) + encoder_outputs
503
  return ((loss,) + output) if loss is not None else output
504
 
505
  return Seq2SeqLMOutput(
506
  loss=loss,
507
  logits=lm_logits,
508
- decoder_hidden_states=decoder_outputs,
509
- decoder_attentions=None,
510
- cross_attentions=None,
 
511
  encoder_last_hidden_state=encoder_outputs.last_hidden_state,
512
  encoder_hidden_states=encoder_outputs.hidden_states,
513
  encoder_attentions=encoder_outputs.attentions,
@@ -525,6 +631,9 @@ class NorT5ForConditionalGeneration(NorT5Model):
525
  encoder_outputs=None,
526
  **kwargs,
527
  ):
 
 
 
528
  return {
529
  "decoder_input_ids": input_ids,
530
  "past_key_values": past_key_values,
@@ -553,9 +662,10 @@ class NorT5ForConditionalGeneration(NorT5Model):
553
  reordered_layer_past_states = ()
554
  for layer_past_state in layer_past_states:
555
  # need to set correct `past` for each of the four key / value states
556
- reordered_layer_past_states = reordered_layer_past_states + (
557
- layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
558
- )
 
559
 
560
  assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
561
  assert len(reordered_layer_past_states) == len(layer_past_states)
@@ -578,4 +688,6 @@ class NorT5Encoder(NorT5Model):
578
  ):
579
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
580
 
581
- return self.get_encoder_output(input_ids, attention_mask, return_dict=return_dict)
 
 
 
 
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from transformers.pytorch_utils import softmax_backward_data
8
  from torch.utils import checkpoint
9
 
10
  from configuration_nort5 import NorT5Config
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.activations import gelu_new
13
  from transformers.modeling_outputs import (
14
+ Seq2SeqModelOutput, Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
15
  )
16
 
17
 
 
56
  layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
57
  layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
58
 
59
+ self.activation_checkpointing = activation_checkpointing
60
+
61
+ def forward(self, x, encoder_output, encoder_padding_mask, past_key_values=None):
62
  self_relative_embedding = self.self_relative_embedding()
63
  cross_relative_embedding = self.cross_relative_embedding()
64
 
65
+ if past_key_values is not None:
66
+ autoreg_mask = torch.triu(
67
+ torch.full((x.size(0), x.size(0)), True, device=x.device),
68
+ diagonal=1
69
+ )
70
+ else:
71
+ autoreg_mask = None
72
 
73
+ # initialize past_key_values with `None` if past does not exist
74
+ if past_key_values is None:
75
+ past_key_values = [None] * len(self.layers)
76
+
77
+ hidden_states, self_attention_probs, cross_attention_probs, key_value_states = [x], [], [], []
78
+ for layer, past_key_value in zip(self.layers, past_key_values):
79
+ if self.activation_checkpointing:
80
+ hidden_state, self_attention_p, cross_attention_p, key_value_state = checkpoint.checkpoint(layer, hidden_states[-1], autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=None)
81
+ else:
82
+ hidden_state, self_attention_p, cross_attention_p, key_value_state = layer(hidden_states[-1], autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=past_key_value)
83
+
84
+ hidden_states.append(hidden_state)
85
+ self_attention_probs.append(self_attention_p)
86
+ cross_attention_probs.append(cross_attention_p)
87
+ key_value_states.append(key_value_state)
88
+
89
+ return hidden_states, self_attention_probs, cross_attention_probs, key_value_states
90
 
91
 
92
  class MaskClassifier(nn.Module):
 
112
  class EncoderLayer(nn.Module):
113
  def __init__(self, config):
114
  super().__init__()
115
+ self.attention = Attention(config, is_cross_attention=False)
116
  self.mlp = FeedForward(config)
117
 
118
  def forward(self, x, padding_mask, relative_embedding):
119
+ attention_output, attention_probs, _ = self.attention(x, x, padding_mask, relative_embedding)
120
  x = x + attention_output
121
  x = x + self.mlp(x)
122
  return x, attention_probs
 
125
  class DecoderLayer(nn.Module):
126
  def __init__(self, config):
127
  super().__init__()
128
+ self.self_attention = Attention(config, is_cross_attention=False)
129
+ self.cross_attention = Attention(config, is_cross_attention=True)
130
  self.mlp = FeedForward(config)
131
 
132
+ def forward(self, x, autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=None):
133
+ query_offset = 0
134
+ if past_key_value is not None:
135
+ self_attn_past_key_value = past_key_value[:2]
136
+ cross_attn_past_key_value = past_key_value[2:]
137
+ query_offset = self_attn_past_key_value[0].size(1)
138
+ else:
139
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
140
+
141
+ x_, self_attention_probs, self_key_value_state = self.self_attention(x, x, autoreg_mask, self_relative_embedding, past_key_value=self_attn_past_key_value, query_offset=query_offset)
142
+ x = x + x_
143
+ x_, cross_attention_probs, cross_key_value_state = self.cross_attention(x, encoder_output, encoder_padding_mask, cross_relative_embedding, past_key_value=cross_attn_past_key_value, query_offset=query_offset)
144
+ x = x + x_
145
  x = x + self.mlp(x)
146
+
147
+ return x, self_attention_probs, cross_attention_probs, self_key_value_state + cross_key_value_state
148
 
149
 
150
  class GeGLU(nn.Module):
 
180
  @staticmethod
181
  def forward(self, x, mask, dim):
182
  self.dim = dim
183
+ if mask is not None:
184
+ x.masked_fill_(mask, float('-inf'))
185
  x = torch.softmax(x, self.dim)
186
+ if mask is not None:
187
+ x.masked_fill_(mask, 0.0)
188
  self.save_for_backward(x)
189
  return x
190
 
191
  @staticmethod
192
  def backward(self, grad_output):
193
  output, = self.saved_tensors
194
+ input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
195
+ return input_grad, None, None
196
 
197
 
198
  class Attention(nn.Module):
199
+ def __init__(self, config, is_cross_attention=False):
200
  super().__init__()
201
 
202
  self.config = config
203
+ self.is_cross_attention = is_cross_attention
204
 
205
  if config.hidden_size % config.num_attention_heads != 0:
206
  raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
 
217
  self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
218
  self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
219
 
220
+ position_indices = torch.arange(512, dtype=torch.long).unsqueeze(1) \
221
+ - torch.arange(512, dtype=torch.long).unsqueeze(0)
222
+ position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
223
  position_indices = config.position_bucket_size - 1 + position_indices
224
  self.register_buffer("position_indices", position_indices, persistent=True)
225
 
 
246
  self.in_proj_v.bias.data.zero_()
247
  self.out_proj.bias.data.zero_()
248
 
249
+ def forward(self, q, kv, attention_mask, relative_embedding, past_key_value=None, query_offset=0):
250
  key_len, batch_size, _ = kv.size()
251
  query_len, _, _ = q.size()
252
 
253
+ if not self.is_cross_attention or past_key_value is None or past_key_value[0].size(1) != kv.size(0):
254
+ kv = self.pre_layer_norm(kv)
255
+ key = self.in_proj_k(kv) # shape: [T, B, D]
256
+ value = self.in_proj_v(kv) # shape: [T, B, D]
257
+ key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1) # shape: [BxH, T, D]
258
+ value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1) # shape: [BxH, T, D]
259
+
260
+ if past_key_value is not None:
261
+ if not self.is_cross_attention:
262
+ key = torch.cat([past_key_value[0], key], dim=1)
263
+ value = torch.cat([past_key_value[1], value], dim=1)
264
+ key_len = key.size(1)
265
+ elif past_key_value[0].size(1) == kv.size(0):
266
+ key = past_key_value[0]
267
+ value = past_key_value[1]
268
+
269
+ if self.position_indices.size(0) < max(query_len, key_len):
270
+ position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
271
+ - torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
272
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
273
  position_indices = self.config.position_bucket_size - 1 + position_indices
274
  self.register_buffer("position_indices", position_indices.to(q.device), persistent=True)
275
 
 
276
  q = self.pre_layer_norm(q)
 
277
  query = self.in_proj_q(q) # shape: [T, B, D]
 
 
 
 
 
 
 
 
 
 
 
278
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
 
 
279
 
280
  attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
281
+
282
+ query_pos = self.in_proj_q(self.dropout(relative_embedding)) # shape: [2T-1, D]
283
+ query_pos = query_pos.view(-1, self.num_heads, self.head_size) # shape: [2T-1, H, D]
284
+ key_pos = self.in_proj_k(self.dropout(relative_embedding)) # shape: [2T-1, D]
285
+ key_pos = key_pos.view(-1, self.num_heads, self.head_size) # shape: [2T-1, H, D]
286
+
287
+ query_ = query.view(batch_size, self.num_heads, query_len, self.head_size)
288
+ key_ = key.view(batch_size, self.num_heads, key_len, self.head_size)
289
+
290
+ attention_c_p = torch.einsum("bhqd,khd->bhqk", query_, key_pos.squeeze(1) * self.scale)
291
+ attention_p_c = torch.einsum("bhkd,qhd->bhqk", key_ * self.scale, query_pos.squeeze(1))
292
+ position_indices = self.position_indices[query_offset:query_offset+query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
293
+ attention_c_p = attention_c_p.gather(3, position_indices)
294
+ attention_p_c = attention_p_c.gather(2, position_indices)
295
+
296
  attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
297
+ attention_scores.add_(attention_c_p)
298
+ attention_scores.add_(attention_p_c)
299
 
300
+ attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
301
 
 
302
  attention_probs = self.dropout(attention_probs)
303
  context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
304
  context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
305
  context = self.out_proj(context)
306
  context = self.post_layer_norm(context)
307
  context = self.dropout(context)
 
308
 
309
+ return context, attention_probs.detach(), (key.detach(), value.detach())
 
 
 
310
 
311
 
312
  class WordEmbedding(nn.Module):
 
387
  return self.get_encoder_output
388
 
389
  def get_decoder(self):
390
+ return self.get_decoder_output
391
+
392
  def set_decoder_special_tokens(self, target_id):
393
  target_id.masked_fill_(target_id == self.cls_token_id, self.bos_token_id)
394
  target_id.masked_fill_(target_id == self.sep_token_id, self.eos_token_id)
 
398
  shifted_input_ids = input_ids.new_zeros(input_ids.shape)
399
  shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
400
  shifted_input_ids[..., 0] = self.bos_token_id
401
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
402
 
403
  return shifted_input_ids
404
 
405
  def get_encoder_output(
406
  self,
407
+ input_ids: torch.Tensor = None,
408
  attention_mask: Optional[torch.Tensor] = None,
409
  output_hidden_states: Optional[bool] = None,
410
  output_attentions: Optional[bool] = None,
 
434
  ]
435
 
436
  if not return_dict:
437
+ return (
438
+ last_layer,
439
+ *([contextualized_embeddings] if output_hidden_states else []),
440
+ *([attention_probs] if output_attentions else [])
441
+ )
442
+
443
  return BaseModelOutput(
444
  last_hidden_state=last_layer,
445
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
446
+ attentions=attention_probs if output_attentions else None
447
  )
448
 
449
  def get_decoder_output(
450
+ self,
451
+ target_ids: torch.Tensor = None,
452
+ encoder_output: torch.Tensor = None,
453
+ attention_mask: Optional[torch.Tensor] = None,
454
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
455
+ use_cache: Optional[bool] = None,
456
+ output_hidden_states: Optional[bool] = None,
457
+ output_attentions: Optional[bool] = None,
458
+ return_dict = False
459
  ):
460
  batch_size, seq_length, _ = encoder_output.shape
461
  device = target_ids.device
 
466
  attention_mask = ~attention_mask.bool()
467
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
468
 
469
+ hidden_states, self_attention_p, cross_attention_p, key_value_states = self.decoder(
470
  self.embedding(target_ids.t()),
471
  encoder_output.transpose(0, 1),
472
+ attention_mask,
473
+ past_key_values
474
+ )
475
+
476
+ hidden_states = [e.transpose(0, 1) for e in hidden_states]
477
+ last_layer = hidden_states[-1]
478
+ hidden_states = [hidden_states[0]] + [
479
+ hidden_states[i] - hidden_states[i - 1]
480
+ for i in range(1, len(hidden_states))
481
+ ]
482
+
483
+ if not return_dict:
484
+ return (
485
+ last_layer,
486
+ *([key_value_states] if use_cache else []),
487
+ *([hidden_states] if output_hidden_states else []),
488
+ *([self_attention_p] if output_attentions else []),
489
+ *([cross_attention_p] if output_attentions else []),
490
+ )
491
+
492
+ return BaseModelOutputWithPastAndCrossAttentions(
493
+ last_hidden_state=last_layer,
494
+ past_key_values=key_value_states if use_cache else None,
495
+ hidden_states=hidden_states if output_hidden_states else None,
496
+ attentions=self_attention_p if output_attentions else None,
497
+ cross_attentions=cross_attention_p if output_attentions else None
498
+ )
499
+
500
 
501
  def forward(
502
  self,
 
504
  attention_mask: Optional[torch.FloatTensor] = None,
505
  decoder_input_ids: Optional[torch.LongTensor] = None,
506
  decoder_attention_mask: Optional[torch.BoolTensor] = None,
507
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
508
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
509
+ use_cache: Optional[bool] = None,
510
+ output_attentions: Optional[bool] = None,
511
+ output_hidden_states: Optional[bool] = None,
512
+ return_dict: Optional[bool] = None
513
  ):
514
 
515
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
516
 
517
  decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
518
 
519
+ if encoder_outputs is None:
520
+ encoder_outputs = self.get_encoder_output(
521
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict
522
+ )
523
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
524
+ encoder_outputs = BaseModelOutput(
525
+ last_hidden_state=encoder_outputs[0],
526
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
527
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
528
+ )
529
+
530
+ decoder_outputs = self.get_decoder_output(
531
+ decoder_input_ids, encoder_outputs[0], attention_mask, past_key_values, use_cache, output_hidden_states, output_attentions, return_dict
532
+ )
533
 
534
  if not return_dict:
535
+ return decoder_outputs + encoder_outputs
536
+
537
  return Seq2SeqModelOutput(
538
+ last_hidden_state=decoder_outputs.last_hidden_state,
539
+ past_key_values=decoder_outputs.past_key_values,
540
+ decoder_hidden_states=decoder_outputs.hidden_states,
541
+ decoder_attentions=decoder_outputs.attentions,
542
+ cross_attentions=decoder_outputs.cross_attentions,
543
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
544
+ encoder_hidden_states=encoder_outputs.hidden_states,
545
+ encoder_attentions=encoder_outputs.attentions,
546
  )
547
 
548
 
 
570
  output_hidden_states: Optional[bool] = None,
571
  return_dict: Optional[bool] = None,
572
  ):
573
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
 
574
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
575
 
576
  if encoder_outputs is None:
577
+ encoder_outputs = self.get_encoder_output(
578
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict
579
+ )
580
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
581
+ encoder_outputs = BaseModelOutput(
582
+ last_hidden_state=encoder_outputs[0],
583
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
584
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
585
+ )
586
 
587
  if labels is not None:
588
  labels = self.set_decoder_special_tokens(labels)
 
592
  elif decoder_input_ids is not None:
593
  decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
594
 
595
+ decoder_outputs = self.get_decoder_output(
596
+ decoder_input_ids, encoder_outputs[0], attention_mask, past_key_values, use_cache, output_hidden_states, output_attentions, return_dict
597
+ )
598
+ lm_logits = self.classifier(decoder_outputs[0])
599
 
600
  loss = None
601
  if labels is not None:
602
+ labels.masked_fill_(labels == self.pad_token_id, -100)
603
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
604
  loss = loss_fct(lm_logits.flatten(0, 1), labels.flatten())
605
 
606
  if not return_dict:
607
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
608
  return ((loss,) + output) if loss is not None else output
609
 
610
  return Seq2SeqLMOutput(
611
  loss=loss,
612
  logits=lm_logits,
613
+ past_key_values=decoder_outputs.past_key_values,
614
+ decoder_hidden_states=decoder_outputs.hidden_states,
615
+ decoder_attentions=decoder_outputs.attentions,
616
+ cross_attentions=decoder_outputs.cross_attentions,
617
  encoder_last_hidden_state=encoder_outputs.last_hidden_state,
618
  encoder_hidden_states=encoder_outputs.hidden_states,
619
  encoder_attentions=encoder_outputs.attentions,
 
631
  encoder_outputs=None,
632
  **kwargs,
633
  ):
634
+ if past_key_values is not None:
635
+ input_ids = input_ids[:, -1:]
636
+
637
  return {
638
  "decoder_input_ids": input_ids,
639
  "past_key_values": past_key_values,
 
662
  reordered_layer_past_states = ()
663
  for layer_past_state in layer_past_states:
664
  # need to set correct `past` for each of the four key / value states
665
+ layer_past_state = layer_past_state.unflatten(0, (-1, self.config.num_attention_heads))
666
+ layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
667
+ layer_past_state = layer_past_state.flatten(0, 1)
668
+ reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
669
 
670
  assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
671
  assert len(reordered_layer_past_states) == len(layer_past_states)
 
688
  ):
689
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
690
 
691
+ return self.get_encoder_output(
692
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
693
+ )