|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
|
|
# Optional dependency: the Triton kernel and the pad/unpad helpers live in
# sibling modules that may fail to import (e.g. Triton is not installed).
# The failure is reported and swallowed so this module can still be imported;
# the names below are then undefined and any later use raises NameError.
try:
    from .triton_flash_atn import _attention
    from .triton_bert_pading import pad_input, unpad_input
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit propagate, while any import-time failure of the kernel
    # modules is still tolerated as before.
    print("FlashAttention is not installed.")
|
|
|
|
|
class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    The computation is delegated to the fused ``_attention`` autograd
    function, which operates on packed (unpadded) token sequences.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0). Only applied while the module is
                           in training mode.
        device: Accepted for signature compatibility; not used here.
        dtype: Accepted for signature compatibility; not used here.
    """

    def __init__(
        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
    ):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def _attend(self, qkv, cu_seqlens, max_s, causal):
        """Call the fused kernel on packed qkv; dropout is disabled in eval mode."""
        dropout_p = self.dropout_p if self.training else 0.0
        return _attention.apply(
            qkv,
            cu_seqlens,
            max_s,
            dropout_p,
            self.softmax_scale,
            causal,
        )

    def forward(
        self,
        qkv,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) when ``cu_seqlens`` is None (padded input);
                (nnz, 3, h, d) when ``cu_seqlens`` is given (already unpadded).
            key_padding_mask: a bool tensor of shape (B, S). Only consulted
                when ``cu_seqlens`` is None.
            causal: Forwarded unchanged to the kernel.
            cu_seqlens: Sequence offsets for already-unpadded input. When
                given, ``max_s`` must be given too and ``qkv`` is passed to
                the kernel as-is.
            max_s: Maximum sequence length; required with ``cu_seqlens`` and
                derived from the input otherwise.
            need_weights: Must be False; attention weights are never produced.

        Returns
        -------
            ``(output, None)``. For padded input the output is reshaped back
            to a leading ``(B, S, ...)`` layout; the second element is a
            placeholder for the attention weights.

        Raises
        ------
            AssertionError: if ``need_weights`` is set, ``qkv`` is not
                fp16/bf16, ``qkv`` is not a CUDA tensor, or ``cu_seqlens`` is
                given without ``max_s``.
        """
        assert not need_weights, "FlashAttention does not return attention weights"
        assert qkv.dtype in [torch.float16, torch.bfloat16], (
            f"qkv must be float16 or bfloat16, got {qkv.dtype}"
        )
        assert qkv.is_cuda, "qkv must be a CUDA tensor"

        # Caller already supplies packed sequences: hand them straight to the kernel.
        if cu_seqlens is not None:
            assert max_s is not None, "max_s is required when cu_seqlens is given"
            return self._attend(qkv, cu_seqlens, max_s, causal), None

        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]

        if key_padding_mask is None:
            # No padding: every sequence has length `seqlen`, so the packed
            # layout is a plain flatten and the offsets are multiples of it.
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            cu_seqlens = torch.arange(
                0,
                (batch_size + 1) * seqlen,
                step=seqlen,
                dtype=torch.int32,
                device=qkv.device,
            )
            output = self._attend(qkv, cu_seqlens, seqlen, causal)
            output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            return output, None

        # Padded batch with a mask: unpad_input works on a 3-D (b, s, features)
        # tensor, so fold (three, h, d) into one axis first and restore it after.
        # NOTE(review): unpad_input is presumed to strip masked-out tokens and
        # return the offsets the kernel needs; confirm against its definition.
        nheads = qkv.shape[-2]
        flat = rearrange(qkv, "b s three h d -> b s (three h d)")
        flat_unpad, indices, cu_seqlens, max_s = unpad_input(flat, key_padding_mask)
        qkv_unpad = rearrange(
            flat_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = self._attend(qkv_unpad, cu_seqlens, max_s, causal)

        # Scatter the packed result back into the padded (b, s) layout.
        padded = pad_input(
            rearrange(output_unpad, "nnz h d -> nnz (h d)"),
            indices,
            batch_size,
            seqlen,
        )
        output = rearrange(padded, "b s (h d) -> b s h d", h=nheads)
        return output, None
|
|
|
|
|
|