jbochi committed
Commit 389b01f
1 Parent(s): 14afdcf

It's kind of working!

Still unclear how to set input ids

Files changed (1)
  1. decoder_only_t5/modeling.py +129 -291
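
The commit message leaves open how input ids should be fed to this model. As a rough orientation only (not this repo's confirmed API), a decoder-only stack in the T5Stack style can usually be driven directly with `input_ids`; the sketch below assumes `model` is an already-loaded DecoderOnlyT5Model that keeps the T5ForConditionalGeneration attributes `decoder` and `lm_head`, and it ignores details such as tied-embedding rescaling, which is exactly the part still marked as unclear.

import torch

input_ids = torch.tensor([[3, 42, 7, 99]])                # placeholder token ids
out = model.decoder(input_ids=input_ids, use_cache=True)  # T5Stack-style forward
logits = model.lm_head(out.last_hidden_state)             # (batch, seq_len, vocab)
next_id = logits[:, -1, :].argmax(-1, keepdim=True)

# one incremental step, reusing the self-attention key/value cache
step = model.decoder(
    input_ids=next_id,
    past_key_values=out.past_key_values,
    use_cache=True,
)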
decoder_only_t5/modeling.py CHANGED
@@ -1,4 +1,5 @@
 import copy
+import math
 from typing import Optional, Tuple, Union
 
 import torch
@@ -19,6 +20,39 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "DecoderOnlyT5Config"
 
 
+class DecoderOnlyT5LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6, use_scale=True, center_scale_at_zero=False):
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        if use_scale:
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+        else:
+            assert not center_scale_at_zero
+            self.weight = None
+        self.center_scale_at_zero = center_scale_at_zero
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/layer_norm.py#L30
+
+        # layer norm should always be calculated in float32
+        mean2 = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(mean2 + self.variance_epsilon)
+
+        # convert into float16 if necessary
+        if self.weight is None:
+            return hidden_states
+        if self.weight.dtype == torch.float16:
+            hidden_states = hidden_states.to(torch.float16)
+        if self.center_scale_at_zero:
+            return (self.weight + 1.0) * hidden_states
+        else:
+            return self.weight * hidden_states
+
+
 class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
     def __init__(self, config: DecoderOnlyT5Config):
         super(modeling_t5.T5LayerFF, self).__init__()
@@ -28,7 +62,7 @@ class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
         self.DenseReluDense = modeling_t5.T5DenseActDense(config)
 
         if not config.parallel_layers:
-            self.layer_norm = modeling_t5.T5LayerNorm(
+            self.layer_norm = DecoderOnlyT5LayerNorm(
                 config.d_model, eps=config.layer_norm_epsilon
             )
         else:
@@ -37,7 +71,7 @@ class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
 
 
 # LlamaRotaryEmbedding
-class T5DecoderOnlyRotaryEmbedding(nn.Module):
+class DecoderOnlyT5RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
@@ -139,25 +173,21 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
     def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
         super(modeling_t5.T5Attention, self).__init__()
         self.is_decoder = config.is_decoder
-        self.has_relative_attention_bias = has_relative_attention_bias
-        self.relative_attention_num_buckets = config.relative_attention_num_buckets
-        self.relative_attention_max_distance = config.relative_attention_max_distance
+        assert not has_relative_attention_bias
+        assert config.use_rotary_embedding
         self.d_model = config.d_model
-        self.key_value_proj_dim = config.d_kv
-        self.n_heads = config.num_heads
-        self.n_kv_heads = 1 if config.multi_query_attention else self.n_heads
-        self.n_kv_groups = self.n_heads // self.n_kv_heads
-        self.dropout = config.dropout_rate
-        self.inner_dim = self.n_heads * self.key_value_proj_dim
-        self.kv_inner_dim = self.n_kv_heads * self.key_value_proj_dim
-        if config.use_rotary_embedding:
-            self.rotary_embedding = T5DecoderOnlyRotaryEmbedding(
-                self.key_value_proj_dim,
-                max_position_embeddings=config.relative_attention_max_distance,
-                base=config.rotary_embedding_max_timescale,
-            )
-        else:
-            self.rotary_embedding = None
+        self.head_dim = config.d_kv
+        self.num_heads = config.num_heads
+        self.num_key_value_heads = 1 if config.multi_query_attention else self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.attention_dropout = config.dropout_rate
+        self.inner_dim = self.num_heads * self.head_dim
+        self.kv_inner_dim = self.num_key_value_heads * self.head_dim
+        self.rotary_emb = DecoderOnlyT5RotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=config.relative_attention_max_distance,
+            base=config.rotary_embedding_max_timescale,
+        )
 
         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -165,179 +195,79 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
         self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
         self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
 
-        if self.has_relative_attention_bias:
-            self.relative_attention_bias = nn.Embedding(
-                self.relative_attention_num_buckets, self.n_heads
-            )
         self.pruned_heads = set()
         self.gradient_checkpointing = False
 
     def forward(
         self,
-        hidden_states,
-        mask=None,
+        hidden_states: torch.Tensor,
         key_value_states=None,
         position_bias=None,
-        position_ids=None,
-        past_key_value=None,
+        mask: Optional[torch.Tensor] = None,
         layer_head_mask=None,
-        query_length=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
-        """
-        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
-        """
-        # Input is (batch_size, seq_length, dim)
-        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
-        # past_key_value[0] is (batch_size, n_kv_heads, q_len - 1, dim_per_head)
-        batch_size, seq_length = hidden_states.shape[:2]
-
-        real_seq_length = seq_length
-
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert key_value_states is None
+        assert position_bias is None
+        assert layer_head_mask is None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q(hidden_states)
+        key_states = self.k(hidden_states)
+        value_states = self.v(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            if len(past_key_value) != 2:
-                raise ValueError(
-                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
-                )
-            real_seq_length += (
-                past_key_value[0].shape[2] if query_length is None else query_length
-            )
-
-        key_length = (
-            real_seq_length if key_value_states is None else key_value_states.shape[1]
-        )
-
-        def shape(states, n_heads):
-            """projection"""
-            return states.view(
-                batch_size, -1, n_heads, self.key_value_proj_dim
-            ).transpose(1, 2)
-
-        def unshape(states):
-            """reshape"""
-            return (
-                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-            )
-
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states), self.n_kv_heads)
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
-            return hidden_states
-
-        def concat_past_key_value(hidden_states, past_key_value, key_value_states):
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_kv_heads, key_length, dim_per_head)
-                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
-            elif past_key_value.shape[2] != key_value_states.shape[1]:
-                # checking that the `sequence_length` of the `past_key_value` is the same as
-                # the provided `key_value_states` to support prefix tuning
-                # cross-attn
-                # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                raise NotImplementedError(
-                    "cross attention with RoPE and past KV is not implemented"
-                )
-                # hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
-            else:
-                # cross-attn
-                hidden_states = past_key_value
-            return hidden_states
-
-        # get query states
-        query_states = shape(
-            self.q(hidden_states), self.n_heads
-        )  # (batch_size, n_heads, seq_length, dim_per_head)
-
-        # get key/value states
-        key_states = project(hidden_states, self.k, key_value_states, past_key_value)
-        value_states = project(hidden_states, self.v, key_value_states, past_key_value)
-
-        # RoPE
-        if self.rotary_embedding is not None:
-            kv_seq_len = key_states.shape[-2]
-            if past_key_value:
-                kv_seq_len += past_key_value[0].shape[-2]
-            cos, sin = self.rotary_embedding(query_states, seq_len=kv_seq_len)
-            query_states, key_states = apply_rotary_pos_emb(
-                query_states, key_states, cos, sin, position_ids
-            )
-
-        # concat past
-        if past_key_value is not None:
-            key_states = concat_past_key_value(
-                key_states,
-                past_key_value[0],
-                key_value_states,
-            )
-            value_states = concat_past_key_value(
-                value_states,
-                past_key_value[1],
-                key_value_states,
-            )
-
-        # MultiQueryDotProductAttention
-        key_states = repeat_kv(key_states, self.n_kv_groups)
-        value_states = repeat_kv(value_states, self.n_kv_groups)
-
-        # compute scores
-        scores = torch.matmul(
-            query_states, key_states.transpose(3, 2)
-        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-
-        if position_bias is None:
-            if not self.has_relative_attention_bias:
-                position_bias = torch.zeros(
-                    (1, self.n_heads, real_seq_length, key_length),
-                    device=scores.device,
-                    dtype=scores.dtype,
-                )
-                if self.gradient_checkpointing and self.training:
-                    position_bias.requires_grad = True
-            else:
-                position_bias = self.compute_bias(
-                    real_seq_length, key_length, device=scores.device
-                )
-
-            # if key and values are already calculated
-            # we want only the last query position bias
-            if past_key_value is not None:
-                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
-
-            if mask is not None:
-                position_bias = (
-                    position_bias + mask
-                )  # (batch_size, n_heads, seq_length, key_length)
-
-        if self.pruned_heads:
-            mask = torch.ones(position_bias.shape[1])
-            mask[list(self.pruned_heads)] = 0
-            position_bias_masked = position_bias[:, mask.bool()]
-        else:
-            position_bias_masked = position_bias
-
-        scores += position_bias_masked
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-            scores
-        )  # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training
-        )  # (batch_size, n_heads, seq_length, key_length)
-
-        # Mask heads if we want to
-        if layer_head_mask is not None:
-            attn_weights = attn_weights * layer_head_mask
-
-        attn_output = unshape(
-            torch.matmul(attn_weights, value_states)
-        )  # (batch_size, seq_length, dim)
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if mask is not None:
+            if mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {mask.size()}"
+                )
+            attn_weights = attn_weights + mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.inner_dim)
         attn_output = self.o(attn_output)
 
         present_key_value_state = (
@@ -356,8 +286,11 @@ class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
         self.SelfAttention = DecoderOnlyT5Attention(
             config, has_relative_attention_bias=has_relative_attention_bias
         )
-        self.layer_norm = modeling_t5.T5LayerNorm(
-            config.d_model, eps=config.layer_norm_epsilon
+        self.layer_norm = DecoderOnlyT5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            use_scale=True,
+            center_scale_at_zero=True,
         )
         self.dropout = nn.Dropout(config.dropout_rate)
         self.parallel_layers = config.parallel_layers
@@ -425,20 +358,19 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
         position_bias=None,
         position_ids=None,
         encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        encoder_decoder_position_bias=None,
         layer_head_mask=None,
-        cross_attn_layer_head_mask=None,
        past_key_value=None,
         use_cache=False,
         output_attentions=False,
+        encoder_attention_mask=None,
+        encoder_decoder_position_bias=None,
+        cross_attn_layer_head_mask=None,
         return_dict=True,
     ):
+        assert encoder_attention_mask is None
+        assert encoder_decoder_position_bias is None
+        assert cross_attn_layer_head_mask is None
        if past_key_value is not None:
-            if not self.is_decoder:
-                logger.warning(
-                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
-                )
             expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
 
             if len(past_key_value) != expected_num_past_key_values:
@@ -447,11 +379,9 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
                     f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                     f"Got {len(past_key_value)} past key / value states"
                 )
-
             self_attn_past_key_value = past_key_value[:2]
-            cross_attn_past_key_value = past_key_value[2:]
         else:
-            self_attn_past_key_value, cross_attn_past_key_value = None, None
+            self_attn_past_key_value = None
 
         ff_layer = self.layer[-1]
         if self.parallel_layers:
@@ -490,45 +420,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
             and not self.is_decoder_only
            and encoder_hidden_states is not None
        )
-        if do_cross_attention:
-            # the actual query length is unknown for cross attention
-            # if using past key value states. Need to inject it here
-            if present_key_value_state is not None:
-                query_length = present_key_value_state[0].shape[2]
-            else:
-                query_length = None
-
-            cross_attention_outputs = self.layer[1](
-                x,
-                key_value_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                position_bias=encoder_decoder_position_bias,
-                # position_ids ?
-                layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=cross_attn_past_key_value,
-                query_length=query_length,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-            x = cross_attention_outputs[0]
-
-            # clamp inf values to enable fp16 training
-            if x.dtype == torch.float16:
-                clamp_value = torch.where(
-                    torch.isinf(x).any(),
-                    torch.finfo(x.dtype).max - 1000,
-                    torch.finfo(x.dtype).max,
-                )
-                x = torch.clamp(x, min=-clamp_value, max=clamp_value)
-
-            # Combine self attn and cross attn key value states
-            if present_key_value_state is not None:
-                present_key_value_state = (
-                    present_key_value_state + cross_attention_outputs[1]
-                )
-
-            # Keep cross-attention outputs and relative position weights
-            attention_outputs = attention_outputs + cross_attention_outputs[2:]
+        assert not do_cross_attention
 
         if self.parallel_layers:
             # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
@@ -577,12 +469,12 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                 for i in range(config.num_layers)
             ]
         )
-        if not config.parallel_layers:
-            self.final_layer_norm = modeling_t5.T5LayerNorm(
-                config.d_model, eps=config.layer_norm_epsilon
-            )
-        else:
-            self.final_layer_norm = nn.Identity()
+        self.final_layer_norm = DecoderOnlyT5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            use_scale=False,
+            center_scale_at_zero=False,
+        )
         self.dropout = nn.Dropout(config.dropout_rate)
 
         # Initialize weights and apply final processing
@@ -654,8 +546,7 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                 seq_length + past_key_values_length,
                 dtype=torch.long,
                 device=device,
-            )
-            position_ids = position_ids.unsqueeze(0)
+            ).unsqueeze(0)
 
         if inputs_embeds is None:
             if self.embed_tokens is None:
@@ -683,18 +574,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
             attention_mask = torch.ones(
                 batch_size, mask_seq_length, device=inputs_embeds.device
            )
-        if (
-            self.is_decoder
-            and encoder_attention_mask is None
-            and encoder_hidden_states is not None
-        ):
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(
-                batch_size,
-                encoder_seq_length,
-                device=inputs_embeds.device,
-                dtype=torch.long,
-            )
 
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
@@ -706,25 +585,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
             attention_mask, input_shape
        )
 
-        # If a 2D or 3D attention mask is provided for the cross-attention
-        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.is_decoder and encoder_hidden_states is not None:
-            (
-                encoder_batch_size,
-                encoder_sequence_length,
-                _,
-            ) = encoder_hidden_states.size()
-            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-            if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device
-                )
-            encoder_extended_attention_mask = self.invert_attention_mask(
-                encoder_attention_mask
-            )
-        else:
-            encoder_extended_attention_mask = None
-
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -742,7 +602,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
         all_attentions = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and self.is_decoder) else None
         position_bias = None
-        encoder_decoder_position_bias = None
 
         hidden_states = self.dropout(inputs_embeds)
 
@@ -758,25 +617,10 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                 if attention_mask is not None:
                     attention_mask = attention_mask.to(hidden_states.device)
                 if position_bias is not None:
                     position_bias = position_bias.to(hidden_states.device)
-                if encoder_hidden_states is not None:
-                    encoder_hidden_states = encoder_hidden_states.to(
-                        hidden_states.device
-                    )
-                if encoder_extended_attention_mask is not None:
-                    encoder_extended_attention_mask = (
-                        encoder_extended_attention_mask.to(hidden_states.device)
-                    )
-                if encoder_decoder_position_bias is not None:
-                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
-                        hidden_states.device
-                    )
                 if layer_head_mask is not None:
                     layer_head_mask = layer_head_mask.to(hidden_states.device)
-                if cross_attn_layer_head_mask is not None:
-                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
-                        hidden_states.device
-                    )
+
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -786,9 +630,9 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
-                    encoder_hidden_states,
-                    encoder_extended_attention_mask,
-                    encoder_decoder_position_bias,
+                    None,
+                    None,
+                    None,
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
@@ -801,9 +645,9 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                     attention_mask=extended_attention_mask,
                     position_bias=position_bias,
                     position_ids=position_ids,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_extended_attention_mask,
-                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    encoder_hidden_states=None,
+                    encoder_attention_mask=None,
+                    encoder_decoder_position_bias=None,
                     layer_head_mask=layer_head_mask,
                     cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                     past_key_value=past_key_value,
@@ -822,10 +666,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
             # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
             # (cross-attention position bias), (cross-attention weights)
             position_bias = layer_outputs[2]
-            if self.is_decoder and encoder_hidden_states is not None:
-                encoder_decoder_position_bias = layer_outputs[
-                    4 if output_attentions else 3
-                ]
             # append next layer key value states
             if use_cache:
                 present_key_value_states = present_key_value_states + (
@@ -900,8 +740,6 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
     def _tie_weights(self):
         if not self.config.tie_word_embeddings:
             return
-        if self.encoder:
-            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
         if self.decoder:
             self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
 
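
The rewritten attention forward above calls `apply_rotary_pos_emb` and `repeat_kv`, which this diff does not define; presumably they are imported from transformers' Llama modeling code or defined elsewhere in modeling.py. For reference, a sketch consistent with the call sites above (cos/sin caches indexed by `position_ids`; key/value heads repeated up to the full head count so multi-query attention can reuse the standard matmul path):

import torch


def rotate_half(x):
    # rotate half the hidden dims of the input
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin are (max_seq_len, head_dim) caches returned by the rotary embedding module
    cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states, n_rep):
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)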