It's kind of working!
Still unclear how to set input ids
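For reference, a minimal sketch of how position_ids can be lined up with input_ids; it simply mirrors the torch.arange(...).unsqueeze(0) logic in DecoderOnlyT5Stack.forward in the diff below (the helper name and defaults here are illustrative, not part of the commit):

```python
import torch

def default_position_ids(input_ids: torch.LongTensor, past_key_values_length: int = 0) -> torch.LongTensor:
    # Positions continue from however many tokens are already in the KV cache,
    # exactly like the torch.arange(...) call in the stack below.
    seq_length = input_ids.shape[1]
    return torch.arange(
        past_key_values_length,
        seq_length + past_key_values_length,
        dtype=torch.long,
        device=input_ids.device,
    ).unsqueeze(0)  # shape (1, seq_length), broadcast over the batch

# prompt step: 4 tokens, nothing cached -> tensor([[0, 1, 2, 3]])
# decode step: 1 new token, 4 cached    -> tensor([[4]])
```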
- decoder_only_t5/modeling.py +129 -291
decoder_only_t5/modeling.py
CHANGED
@@ -1,4 +1,5 @@
 import copy
+import math
 from typing import Optional, Tuple, Union

 import torch
@@ -19,6 +20,39 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "DecoderOnlyT5Config"


+class DecoderOnlyT5LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6, use_scale=True, center_scale_at_zero=False):
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        if use_scale:
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+        else:
+            assert not center_scale_at_zero
+            self.weight = None
+        self.center_scale_at_zero = center_scale_at_zero
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/layer_norm.py#L30
+
+        # layer norm should always be calculated in float32
+        mean2 = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(mean2 + self.variance_epsilon)
+
+        # convert into float16 if necessary
+        if self.weight is None:
+            return hidden_states
+        if self.weight.dtype == torch.float16:
+            hidden_states = hidden_states.to(torch.float16)
+        if self.center_scale_at_zero:
+            return (self.weight + 1.0) * hidden_states
+        else:
+            return self.weight * hidden_states
+
+
+
 class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
     def __init__(self, config: DecoderOnlyT5Config):
         super(modeling_t5.T5LayerFF, self).__init__()
@@ -28,7 +62,7 @@ class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
         self.DenseReluDense = modeling_t5.T5DenseActDense(config)

         if not config.parallel_layers:
-            self.layer_norm = modeling_t5.T5LayerNorm(
+            self.layer_norm = modeling_t5.DecoderOnlyT5LayerNorm(
                 config.d_model, eps=config.layer_norm_epsilon
             )
         else:
@@ -37,7 +71,7 @@ class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):


 # LlamaRotaryEmbedding
-class
+class DecoderOnlyT5RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()

@@ -139,25 +173,21 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
     def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
         super(modeling_t5.T5Attention, self).__init__()
         self.is_decoder = config.is_decoder
-
-
-        self.relative_attention_max_distance = config.relative_attention_max_distance
+        assert not has_relative_attention_bias
+        assert config.use_rotary_embedding
         self.d_model = config.d_model
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.inner_dim = self.
-        self.kv_inner_dim = self.
-
-        self.
-
-
-
-        )
-        else:
-            self.rotary_embedding = None
+        self.head_dim = config.d_kv
+        self.num_heads = config.num_heads
+        self.num_key_value_heads = 1 if config.multi_query_attention else self.n_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.attention_dropout = config.dropout_rate
+        self.inner_dim = self.num_heads * self.head_dim
+        self.kv_inner_dim = self.num_key_value_heads * self.head_dim
+        self.rotary_emb = DecoderOnlyT5RotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=config.relative_attention_max_distance,
+            base=config.rotary_embedding_max_timescale,
+        )

         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -165,179 +195,79 @@
         self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
         self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

-        if self.has_relative_attention_bias:
-            self.relative_attention_bias = nn.Embedding(
-                self.relative_attention_num_buckets, self.n_heads
-            )
         self.pruned_heads = set()
         self.gradient_checkpointing = False

     def forward(
         self,
-        hidden_states,
-        mask=None,
+        hidden_states: torch.Tensor,
         key_value_states=None,
         position_bias=None,
-
-        past_key_value=None,
+        mask: Optional[torch.Tensor] = None,
         layer_head_mask=None,
-
-
-        output_attentions=False,
-
-
-
-
-
-
-
-        batch_size, seq_length = hidden_states.shape[:2]
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert key_value_states is None
+        assert position_bias is None
+        assert layer_head_mask is None
+
+        bsz, q_len, _ = hidden_states.size()

-
+        query_states = self.q(hidden_states)
+        key_states = self.k(hidden_states)
+        value_states = self.v(hidden_states)

+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-
-
-
-            )
-            real_seq_length += (
-                past_key_value[0].shape[2] if query_length is None else query_length
-            )
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-
-
-
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)

-
-            """projection"""
-            return states.view(
-                batch_size, -1, n_heads, self.key_value_proj_dim
-            ).transpose(1, 2)
+        past_key_value = (key_states, value_states) if use_cache else None

-
-
-            return (
-                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-            )
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)

-
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states), self.n_kv_heads)
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
-            return hidden_states
-
-        def concat_past_key_value(hidden_states, past_key_value, key_value_states):
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_kv_heads, key_length, dim_per_head)
-                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
-            elif past_key_value.shape[2] != key_value_states.shape[1]:
-                # checking that the `sequence_length` of the `past_key_value` is the same as
-                # the provided `key_value_states` to support prefix tuning
-                # cross-attn
-                # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                raise NotImplementedError(
-                    "cross attention with RoPE and past KV is not implemented"
-                )
-                # hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
-            else:
-                # cross-attn
-                hidden_states = past_key_value
-            return hidden_states
-
-        # get query states
-        query_states = shape(
-            self.q(hidden_states), self.n_heads
-        ) # (batch_size, n_heads, seq_length, dim_per_head)
-
-        # get key/value states
-        key_states = project(hidden_states, self.k, key_value_states, past_key_value)
-        value_states = project(hidden_states, self.v, key_value_states, past_key_value)
-
-        # RoPE
-        if self.rotary_embedding is not None:
-            kv_seq_len = key_states.shape[-2]
-            if past_key_value:
-                kv_seq_len += past_key_value[0].shape[-2]
-            cos, sin = self.rotary_embedding(query_states, seq_len=kv_seq_len)
-            query_states, key_states = apply_rotary_pos_emb(
-                query_states, key_states, cos, sin, position_ids
-            )
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-
-
-
-
-                past_key_value[0],
-                key_value_states,
-            )
-            value_states = concat_past_key_value(
-                value_states,
-                past_key_value[1],
-                key_value_states,
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
            )

-
-
-
-
-        # compute scores
-        scores = torch.matmul(
-            query_states, key_states.transpose(3, 2)
-        ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-
-        if position_bias is None:
-            if not self.has_relative_attention_bias:
-                position_bias = torch.zeros(
-                    (1, self.n_heads, real_seq_length, key_length),
-                    device=scores.device,
-                    dtype=scores.dtype,
-                )
-                if self.gradient_checkpointing and self.training:
-                    position_bias.requires_grad = True
-            else:
-                position_bias = self.compute_bias(
-                    real_seq_length, key_length, device=scores.device
+        if mask is not None:
+            if mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {mask.size()}"
                )
+            attn_weights = attn_weights + mask

-
-
-
-
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout)
+        attn_output = torch.matmul(attn_weights, value_states)

-
-
-
-
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )

-
-
-            mask[list(self.pruned_heads)] = 0
-            position_bias_masked = position_bias[:, mask.bool()]
-        else:
-            position_bias_masked = position_bias
-
-        scores += position_bias_masked
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-            scores
-        ) # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training
-        ) # (batch_size, n_heads, seq_length, key_length)
-
-        # Mask heads if we want to
-        if layer_head_mask is not None:
-            attn_weights = attn_weights * layer_head_mask
-
-        attn_output = unshape(
-            torch.matmul(attn_weights, value_states)
-        ) # (batch_size, seq_length, dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.inner_dim)
         attn_output = self.o(attn_output)

         present_key_value_state = (
@@ -356,8 +286,11 @@ class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
         self.SelfAttention = DecoderOnlyT5Attention(
             config, has_relative_attention_bias=has_relative_attention_bias
         )
-        self.layer_norm =
-            config.d_model,
+        self.layer_norm = DecoderOnlyT5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            use_scale=True,
+            center_scale_at_zero=True,
         )
         self.dropout = nn.Dropout(config.dropout_rate)
         self.parallel_layers = config.parallel_layers
@@ -425,20 +358,19 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
         position_bias=None,
         position_ids=None,
         encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        encoder_decoder_position_bias=None,
         layer_head_mask=None,
-        cross_attn_layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
+        encoder_attention_mask=None,
+        encoder_decoder_position_bias=None,
+        cross_attn_layer_head_mask=None,
         return_dict=True,
     ):
+        assert encoder_attention_mask is None
+        assert encoder_decoder_position_bias is None
+        assert cross_attn_layer_head_mask is None
         if past_key_value is not None:
-            if not self.is_decoder:
-                logger.warning(
-                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
-                )
             expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

             if len(past_key_value) != expected_num_past_key_values:
@@ -447,11 +379,9 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )
-
            self_attn_past_key_value = past_key_value[:2]
-            cross_attn_past_key_value = past_key_value[2:]
        else:
-            self_attn_past_key_value
+            self_attn_past_key_value = None

        ff_layer = self.layer[-1]
        if self.parallel_layers:
@@ -490,45 +420,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
            and not self.is_decoder_only
            and encoder_hidden_states is not None
        )
-
-            # the actual query length is unknown for cross attention
-            # if using past key value states. Need to inject it here
-            if present_key_value_state is not None:
-                query_length = present_key_value_state[0].shape[2]
-            else:
-                query_length = None
-
-            cross_attention_outputs = self.layer[1](
-                x,
-                key_value_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                position_bias=encoder_decoder_position_bias,
-                # position_ids ?
-                layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=cross_attn_past_key_value,
-                query_length=query_length,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-            x = cross_attention_outputs[0]
-
-            # clamp inf values to enable fp16 training
-            if x.dtype == torch.float16:
-                clamp_value = torch.where(
-                    torch.isinf(x).any(),
-                    torch.finfo(x.dtype).max - 1000,
-                    torch.finfo(x.dtype).max,
-                )
-                x = torch.clamp(x, min=-clamp_value, max=clamp_value)
-
-            # Combine self attn and cross attn key value states
-            if present_key_value_state is not None:
-                present_key_value_state = (
-                    present_key_value_state + cross_attention_outputs[1]
-                )
-
-            # Keep cross-attention outputs and relative position weights
-            attention_outputs = attention_outputs + cross_attention_outputs[2:]
+        assert not do_cross_attention

        if self.parallel_layers:
            # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
@@ -577,12 +469,12 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                for i in range(config.num_layers)
            ]
        )
-
-
-
-
-
-
+        self.final_layer_norm = DecoderOnlyT5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            use_scale=False,
+            center_scale_at_zero=False,
+        )
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
@@ -654,8 +546,7 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0)
+            ).unsqueeze(0)

        if inputs_embeds is None:
            if self.embed_tokens is None:
@@ -683,18 +574,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
-        if (
-            self.is_decoder
-            and encoder_attention_mask is None
-            and encoder_hidden_states is not None
-        ):
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(
-                batch_size,
-                encoder_seq_length,
-                device=inputs_embeds.device,
-                dtype=torch.long,
-            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
@@ -706,25 +585,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
            attention_mask, input_shape
        )

-        # If a 2D or 3D attention mask is provided for the cross-attention
-        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.is_decoder and encoder_hidden_states is not None:
-            (
-                encoder_batch_size,
-                encoder_sequence_length,
-                _,
-            ) = encoder_hidden_states.size()
-            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-            if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device
-                )
-            encoder_extended_attention_mask = self.invert_attention_mask(
-                encoder_attention_mask
-            )
-        else:
-            encoder_extended_attention_mask = None
-
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
@@ -742,7 +602,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
-        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

@@ -758,25 +617,10 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
-                    position_bias = position_bias.to(hidden_states.device)
-                if encoder_hidden_states is not None:
-                    encoder_hidden_states = encoder_hidden_states.to(
-                        hidden_states.device
-                    )
-                if encoder_extended_attention_mask is not None:
-                    encoder_extended_attention_mask = (
-                        encoder_extended_attention_mask.to(hidden_states.device)
-                    )
-                if encoder_decoder_position_bias is not None:
-                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
-                        hidden_states.device
-                    )
+                    position_bias = position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
-
-                cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
-                    hidden_states.device
-                )
+
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

@@ -786,9 +630,9 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
-
-
-
+                    None,
+                    None,
+                    None,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None, # past_key_value is always None with gradient checkpointing
@@ -801,9 +645,9 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    position_ids=position_ids,
-                    encoder_hidden_states=
-                    encoder_attention_mask=
-                    encoder_decoder_position_bias=
+                    encoder_hidden_states=None,
+                    encoder_attention_mask=None,
+                    encoder_decoder_position_bias=None,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
@@ -822,10 +666,6 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
-            if self.is_decoder and encoder_hidden_states is not None:
-                encoder_decoder_position_bias = layer_outputs[
-                    4 if output_attentions else 3
-                ]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (
@@ -900,8 +740,6 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
    def _tie_weights(self):
        if not self.config.tie_word_embeddings:
            return
-        if self.encoder:
-            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
        if self.decoder:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

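The new attention path calls repeat_kv and apply_rotary_pos_emb, which are not shown in these hunks. As a rough sketch of what they are assumed to do (following the Llama-style helpers in transformers; the rotate_half helper and the exact cos/sin shapes are assumptions, not taken from this commit):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_key_value_heads, seq_len, head_dim) to
    # (batch, num_key_value_heads * n_rep, seq_len, head_dim) so a single
    # (or grouped) KV head can be shared across all query heads.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin are assumed to be (seq_len, head_dim) tables from the rotary
    # embedding module, gathered at position_ids and broadcast over heads.
    cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

With multi_query_attention enabled the stack keeps a single key/value head, and repeat_kv broadcasts it across all query heads; that is what `num_key_value_groups = num_heads // num_key_value_heads` in the new `__init__` encodes.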