Alex Birch
committed on
gradient checkpointing for multi-query attention
Browse files
- attention.py +64 -5
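MultiheadAttention already had a gradient-checkpointing path (visible in the context lines of the first three hunks); this commit switches its call sites over to assigning an AttnFnOutput named attn_fn_out and unpacking it into context and attn_weights, and adds the same checkpointed path to MultiQueryAttention, with True passed for the trailing multiquery flag. The idea is to run the attention function through torch.utils.checkpoint.checkpoint so its intermediate activations are recomputed during backward rather than kept in memory. Below is a minimal, self-contained sketch of that pattern; the toy attn_fn, shapes, and argument list are illustrative stand-ins rather than the repo's actual signatures, and it assumes PyTorch >= 1.11 for the use_reentrant keyword.

import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-in for the repo's attn_fn; the real one also takes n_heads,
# softmax_scale, attn_bias, key_padding_mask and is_causal.
def attn_fn(query, key, value, dropout_p, training, needs_weights, multiquery):
    scale = query.size(-1) ** -0.5
    weights = (query @ key.transpose(-2, -1) * scale).softmax(dim=-1)
    weights = torch.nn.functional.dropout(weights, p=dropout_p, training=training)
    context = weights @ value
    return context, weights

def create_custom_forward(fn):
    # Pin only the trailing multiquery flag in the closure, mirroring the diff;
    # every other argument flows through checkpoint() positionally.
    def custom_forward(query, key, value, dropout_p, training, needs_weights):
        return fn(query, key, value, dropout_p, training, needs_weights, True)
    return custom_forward

query = torch.randn(2, 16, 64, requires_grad=True)
key = torch.randn(2, 16, 64, requires_grad=True)
value = torch.randn(2, 16, 64, requires_grad=True)

# The forward pass keeps no intermediate activations from attn_fn; they are
# recomputed when backward() runs.
context, weights = checkpoint(
    create_custom_forward(attn_fn),
    query, key, value, 0.0, True, True,
    use_reentrant=False,
)
context.sum().backward()
assert query.grad is not None and key.grad is not None and value.grad is not None

Passing the tensors positionally into checkpoint() rather than capturing them in the closure matches the diff, and it is what allows the older reentrant checkpoint implementation to produce gradients for them.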
attention.py CHANGED
@@ -316,7 +316,7 @@ class MultiheadAttention(nn.Module, Attn):
                         False, # multiquery
                     )
                 return custom_forward
-
+            attn_fn_out: AttnFnOutput = checkpoint(
                 create_custom_forward(self.attn_fn),
                 query,
                 key,
@@ -332,7 +332,7 @@ class MultiheadAttention(nn.Module, Attn):
                 **ckpt_kwargs,
             )
         else:
-
+            attn_fn_out: AttnFnOutput = self.attn_fn(
                 query,
                 key,
                 value,
@@ -345,7 +345,7 @@ class MultiheadAttention(nn.Module, Attn):
                 training=self.training,
                 needs_weights=needs_weights,
             )
-        context, attn_weights =
+        context, attn_weights = attn_fn_out
         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
 class MultiQueryAttention(nn.Module, Attn):
@@ -413,8 +413,67 @@ class MultiQueryAttention(nn.Module, Attn):
         past_key_value = PastKeyValue(key, value)
         if attn_bias is not None:
             attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-
-
+        if self.training and self.gradient_checkpointing:
+            ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
+            def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
+                def custom_forward(
+                    query: torch.Tensor,
+                    key: torch.Tensor,
+                    value: torch.Tensor,
+                    n_heads: int,
+                    softmax_scale: Optional[float],
+                    attn_bias: Optional[torch.Tensor],
+                    key_padding_mask: Optional[torch.ByteTensor],
+                    is_causal: bool,
+                    dropout_p: float,
+                    training: bool,
+                    needs_weights: bool,
+                ):
+                    return attn_fn(
+                        query,
+                        key,
+                        value,
+                        n_heads,
+                        softmax_scale,
+                        attn_bias,
+                        key_padding_mask,
+                        is_causal,
+                        dropout_p,
+                        training,
+                        needs_weights,
+                        True, # multiquery
+                    )
+                return custom_forward
+            attn_fn_out: AttnFnOutput = checkpoint(
+                create_custom_forward(self.attn_fn),
+                query,
+                key,
+                value,
+                self.n_heads,
+                self.softmax_scale,
+                attn_bias,
+                key_padding_mask,
+                is_causal,
+                self.attn_dropout_p,
+                self.training,
+                needs_weights,
+                **ckpt_kwargs,
+            )
+        else:
+            attn_fn_out: AttnFnOutput = self.attn_fn(
+                query,
+                key,
+                value,
+                self.n_heads,
+                softmax_scale=self.softmax_scale,
+                attn_bias=attn_bias,
+                key_padding_mask=key_padding_mask,
+                is_causal=is_causal,
+                dropout_p=self.attn_dropout_p,
+                training=self.training,
+                needs_weights=needs_weights,
+            )
+        context, attn_weights = attn_fn_out
         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
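One detail in the added block: the ckpt_kwargs guard only passes use_reentrant=False when the running PyTorch accepts that keyword, which was added to torch.utils.checkpoint.checkpoint in 1.11. The is_torch_version helper itself is defined or imported outside these hunks; purely as an illustration of what the guard does, a dependency-light equivalent using packaging could look like the sketch below (the helper shown here is an assumption, not the repo's own code).

import operator
from typing import Any, Dict

import torch
from packaging import version

# Illustrative stand-in for the is_torch_version helper used in the diff
# (the repo's own definition is not shown in these hunks).
def is_torch_version(op: str, required: str) -> bool:
    ops = {'>=': operator.ge, '>': operator.gt, '<=': operator.le, '<': operator.lt, '==': operator.eq}
    current = version.parse(version.parse(torch.__version__).base_version)
    return ops[op](current, version.parse(required))

# Older torch rejects the use_reentrant keyword, so it is only passed when
# supported; otherwise checkpoint() falls back to its default implementation.
ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}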