Make tests work

- tests/__init__.py +0 -0
- tests/padding.py +53 -0
- tests/test_flash_attn.py +12 -15
- tests/test_util.py +348 -0
tests/__init__.py
ADDED
File without changes (empty file).
tests/padding.py
ADDED
@@ -0,0 +1,53 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange


def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices.
    return (
        rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
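For reference, a minimal round-trip sketch of how these two helpers are meant to be used together (illustrative only, not part of the commit; the toy shapes and the integer mask are assumptions):

import torch

from padding import pad_input, unpad_input

batch, seqlen, dim = 2, 8, 16
hidden = torch.randn(batch, seqlen, dim)
# 1 marks a valid token, 0 marks padding (first sequence has length 5, second is full).
mask = torch.tensor([[1] * 5 + [0] * 3, [1] * 8], dtype=torch.int32)

unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden, mask)
assert unpadded.shape == (int(mask.sum()), dim)   # (total_nnz, dim)
assert cu_seqlens.tolist() == [0, 5, 13]          # cumulative lengths, one boundary per batch entry
assert max_seqlen == 8 and seqused.tolist() == [5, 8]

repadded = pad_input(unpadded, indices, batch, seqlen)
# Valid positions round-trip exactly; padded positions come back as zeros.
assert torch.equal(repadded[mask.bool()], hidden[mask.bool()])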
tests/test_flash_attn.py
CHANGED

@@ -8,10 +8,7 @@ import torch.nn.functional as F
 from torch._C import parse_schema

 from einops import rearrange, repeat
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb
-except ImportError:
-    apply_rotary_emb = None
+apply_rotary_emb = None

 from padding import pad_input, unpad_input
 from test_util import (
@@ -20,10 +17,10 @@ from test_util import (
     generate_random_padding_mask,
 )

-
-from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata
+import kernels

-
+flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")
+ops = flash_attn3._ops


 DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
@@ -195,7 +192,7 @@ def test_flash_attn_output(
     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        out, lse = flash_attn_func(
+        out, lse = flash_attn3.flash_attn_func(
             q,
             k,
             v,
@@ -462,7 +459,7 @@ def test_flash_attn_varlen_output(
     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        out_unpad, lse = flash_attn_varlen_func(
+        out_unpad, lse = flash_attn3.flash_attn_varlen_func(
             q_unpad,
             k_unpad,
             v_unpad,
@@ -856,7 +853,7 @@ def test_flash_attn_kvcache(
     precompute_metadata_vals = [False, True]
     for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):
         if precompute_metadata:
-            scheduler_metadata = get_scheduler_metadata(
+            scheduler_metadata = flash_attn3.get_scheduler_metadata(
                 batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
                 cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
@@ -874,7 +871,7 @@
         else:
             k_cache_paged.copy_(k_cache_saved)
             v_cache_paged.copy_(v_cache_saved)
-        out, lse, *rest = flash_attn_with_kvcache(
+        out, lse, *rest = flash_attn3.flash_attn_with_kvcache(
             q if not varlen_q else q_unpad,
             k_cache if page_size is None else k_cache_paged,
             v_cache if page_size is None else v_cache_paged,
@@ -1008,7 +1005,7 @@ def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
     for _ in range(100):
-        flash_attn_func(q, k, v, causal=causal)
+        flash_attn3.flash_attn_func(q, k, v, causal=causal)


 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -1052,7 +1049,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
     torch.random.manual_seed(42)
-    out0, lse0 = flash_attn_func(q, k, v, causal=causal)
+    out0, lse0 = flash_attn3.flash_attn_func(q, k, v, causal=causal)
     g = torch.randn_like(out0)
     dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
     # Numerical error if we just do any arithmetic on dq
@@ -1060,7 +1057,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):

     for i in range(1000):
         torch.random.manual_seed(42)
-        out, lse = flash_attn_func(q, k, v, causal=causal)
+        out, lse = flash_attn3.flash_attn_func(q, k, v, causal=causal)
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)

@@ -1111,7 +1108,7 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
     lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads]  # To test non-contiguous tensor
     # To test short-circuiting based on num_splits
     lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf")
-    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
+    out, lse = flash_attn3.flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
     out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)
     out_pt = out_ref.to(dtype)
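The substantive change above is that the tests no longer import flash_attn3 as a locally installed package; they fetch the compiled kernel from the Hub through the kernels library and call the same functions under that namespace. A small sketch of the pattern, assuming the kernels package is installed and a GPU supported by the flash-attn3 build is available (the shapes are illustrative):

import torch
import kernels

# Downloads (or reuses a cached copy of) the compiled kernel and exposes its Python API.
flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")

# (batch, seqlen, nheads, head_dim) layout, as in the tests above.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

# Same call sites as before, now namespaced under the fetched module.
out, lse = flash_attn3.flash_attn_func(q, k, v, causal=True)
print(out.shape, lse.shape)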
tests/test_util.py
ADDED
@@ -0,0 +1,348 @@
import math

import torch
from einops import rearrange, repeat

from padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False,
    query_unused_mask=None, key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
        )


def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    # Subtract remainder instead of divide and then multiply to take care of negative values
    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
    return torch.logical_or(
        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
    )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),  # -1 means infinite window size
    attention_chunk=0,
    sink_token_length=0,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim_v)
        qv: (batch_size, seqlen_q, nheads, head_dim_v)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    local_mask = None
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
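As a rough illustration (not part of the commit) of how these reference helpers fit together in the tests, here is a small CPU-only sketch; the shapes, dtypes, and the choice of mask modes are assumptions:

import torch

from test_util import attention_ref, generate_random_padding_mask

batch, seqlen_q, seqlen_k, nheads, d = 2, 16, 32, 4, 64
q = torch.randn(batch, seqlen_q, nheads, d, dtype=torch.float16)
k = torch.randn(batch, seqlen_k, nheads, d, dtype=torch.float16)
v = torch.randn(batch, seqlen_k, nheads, d, dtype=torch.float16)

query_padding_mask = generate_random_padding_mask(seqlen_q, batch, "cpu", mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch, "cpu", mode="full")

# Reference attention computed in fp32 (upcast=True by default); in the tests the fused
# kernel's output is compared against a result produced this way.
out_ref, attn_ref = attention_ref(
    q, k, v,
    query_padding_mask=query_padding_mask,
    key_padding_mask=key_padding_mask,
    causal=True,
)
assert out_ref.shape == (batch, seqlen_q, nheads, d)
assert attn_ref.shape == (batch, nheads, seqlen_q, seqlen_k)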