# Minimal usage example: run the Triton flash-attention kernel on random inputs.
import torch
from triton_flash_atn import _attention
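
# Guard (assumption: a CUDA GPU is required, since the inputs below are
# created on 'cuda' and Triton kernels execute on the GPU).
assert torch.cuda.is_available(), "This example needs a CUDA device."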

# Problem sizes: tensors use the (batch, heads, seq_len, head_dim) layout.
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Flash-attention kernels generally expect fp16 (or bf16) inputs on a CUDA device.
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')

# No causal mask; scale the attention scores by 1/sqrt(head_dim), as in
# standard scaled dot-product attention.
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)

# _attention is a torch.autograd.Function, so it is called through .apply.
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)

print(output)
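
# Optional sanity check (a sketch, assuming PyTorch >= 2.1 for the `scale`
# kwarg): compare against PyTorch's reference scaled_dot_product_attention.
# A loose absolute tolerance is used because inputs and math are fp16.
import torch.nn.functional as F

ref = F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
print(torch.allclose(output, ref, atol=2e-2, rtol=0.0))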