# ghost-8b-beta-128k/self_extend_patch/selfextend_flash_attn_triton.py
import math
import torch
import triton
import triton.language as tl
def self_extend_flash_forward_triton(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):
    """SelfExtend flash-attention forward pass, fused in a single Triton kernel.

    Keys within `group_size_2` positions of the query attend through the
    neighbor states; keys farther away attend through the grouped states.
    `model_self`, `query_position`, `attention_mask`, `bsz`, and `attn_dropout`
    are accepted for interface compatibility and are unused here.
    """
    o = _self_extend_flash_forward_triton(
        q=neighbor_query_states,
        k=neighbor_key_states,
        q1=group_query_states,
        k1=group_key_states,
        v=value_states,
        causal=(q_len == kv_seq_len),
        sm_scale=1.0 / math.sqrt(neighbor_query_states.shape[-1]),
        window=group_size_2,
    )
    # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads, head_dim)
    o = o.transpose(1, 2).contiguous()
    return o
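# Usage sketch (not part of the original file; shapes are hypothetical and
# assume contiguous bfloat16 CUDA tensors laid out (batch, heads, seq, dim)):
#
#   bsz, n_heads, q_len, head_dim = 1, 32, 4096, 128
#   mk = lambda: torch.randn(bsz, n_heads, q_len, head_dim,
#                            device="cuda", dtype=torch.bfloat16)
#   qn, kn, qg, kg, v = mk(), mk(), mk(), mk(), mk()
#   out = self_extend_flash_forward_triton(
#       None, None, 512, qn, kn, qg, kg, v, None, bsz, q_len, q_len, 0.0)
#   # out: (bsz, q_len, n_heads, head_dim)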
def _self_extend_flash_forward_triton(q, k, q1, k1, v, causal, sm_scale, window):
    # shape constraints: all inputs share one head dimension, which must be a
    # power-of-two size the kernel is specialized for
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
    # run on the device that holds q
    with torch.cuda.device_of(q):
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 32
        # one program per (query tile, flattened batch*head) pair
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
        # per-row log-sum-exp (base 2), written back by the kernel
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q,
k,
q1,
k1,
v,
sm_scale,
L,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
k.shape[2],
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
WINDOW=window,
num_warps=8,
num_stages=2)
return o
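# EVEN_M / EVEN_N record whether Q_CTX / KV_CTX divide the tile sizes exactly,
# letting the kernel skip boundary checks on its loads and stores.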
@triton.heuristics(
{
"EVEN_M": lambda args: args["Q_CTX"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["KV_CTX"] % args["BLOCK_N"] == 0,
}
)
@triton.jit
def _fwd_kernel(
Q,
K,
Q1,
K1,
V,
sm_scale,
L,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
Z,
H,
Q_CTX,
KV_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
WINDOW: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr
):
    start_m = tl.program_id(0)  # query-tile index along the sequence
    off_hz = tl.program_id(1)   # flattened (batch, head) index
    # off_hz * stride_qh indexes (z, h) directly because the layout is
    # contiguous, i.e. stride_qz == H * stride_qh
    q_offset = off_hz * stride_qh
    vk_offset = off_hz * stride_kh
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(Q_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + vk_offset,
shape=(KV_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
Q1_block_ptr = tl.make_block_ptr(
base=Q1 + q_offset,
shape=(Q_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K1_block_ptr = tl.make_block_ptr(
base=K1 + vk_offset,
shape=(KV_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
V_block_ptr = tl.make_block_ptr(
base=V + vk_offset,
shape=(KV_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.4426950408889634
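    # identity used here: exp(x * sm_scale) == exp2(x * sm_scale * log2(e))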
# load q: it will stay in SRAM throughout
    if EVEN_M:
        q = tl.load(Q_block_ptr)
        q1 = tl.load(Q1_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(1, 0))
        q1 = tl.load(Q1_block_ptr, boundary_check=(1, 0))
    # pre-scale the queries; the kernel assumes bfloat16 inputs throughout
    q = (q * qk_scale).to(tl.bfloat16)
    q1 = (q1 * qk_scale).to(tl.bfloat16)
    # Dot-I trick: multiplying by the identity converts q and q1 into the MMA
    # layout up front, which saves shared memory in the inner loop.
    # Building the identity with tl.where avoids a cast from bool.
    offs_k = tl.arange(0, BLOCK_DMODEL)
    I = tl.where(offs_k[:, None] == offs_k,
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=tl.bfloat16),
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=tl.bfloat16))
    q = tl.dot(q, I).to(tl.bfloat16)
    q1 = tl.dot(q1, I).to(tl.bfloat16)
# loop over k, v and update accumulator
lo = 0
if IS_CAUSAL:
hi = tl.minimum(KV_CTX, (start_m + 1) * BLOCK_M)
else:
hi = KV_CTX
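    # the causal branch stops at the diagonal: key tiles past
    # (start_m + 1) * BLOCK_M cannot contribute to this query tile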
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
if EVEN_N:
k = tl.load(K_block_ptr)
k1 = tl.load(K1_block_ptr)
v = tl.load(V_block_ptr)
        else:
            k = tl.load(K_block_ptr, boundary_check=(1, 0))
            k1 = tl.load(K1_block_ptr, boundary_check=(1, 0))
            v = tl.load(V_block_ptr, boundary_check=(1, 0))
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # SelfExtend window masking: keys at least WINDOW positions behind the
        # query use the grouped scores (q1, k1); nearer keys use the neighbor
        # scores (q, k)
        mask = (KV_CTX - Q_CTX + offs_m[:, None]) >= (start_n + offs_n[None, :] + WINDOW)
        qk += tl.where(mask, tl.dot(q1, tl.trans(k1)), tl.dot(q, tl.trans(k)))
        if not EVEN_N:
            # zero-padded key columns past KV_CTX must not receive softmax
            # weight, so push their scores to -inf
            mask = (start_n + offs_n) < KV_CTX
            qk = tl.where(mask[None, :], qk, float("-inf"))
        if IS_CAUSAL:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
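        # worked example for the window mask above: with Q_CTX == KV_CTX and
        # WINDOW == 512, the query at position 1000 takes grouped scores for
        # keys 0..488 and neighbor scores for keys 489..1000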
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)      # rescales the old accumulator
        p = tl.math.exp2(qk - m_i_new[:, None])  # unnormalized probabilities
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround for a compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.bfloat16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
K1_block_ptr = tl.advance(K1_block_ptr, (BLOCK_N, 0))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # normalize the accumulator, then write back the (base-2) log-sum-exp
    acc = acc * (1.0 / l_i[:, None])
    l_ptrs = L + off_hz * Q_CTX + offs_m
    mask_m = offs_m < Q_CTX
    l_i = m_i + tl.math.log2(l_i)
    if EVEN_M:
        tl.store(l_ptrs, l_i)
    else:
        tl.store(l_ptrs, l_i, mask=mask_m)
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + q_offset,
shape=(Q_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
    if EVEN_M:
        tl.store(O_block_ptr, acc.to(tl.bfloat16))
    else:
        tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(1, 0))
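if __name__ == "__main__":
    # Minimal smoke test, not part of the original file. It checks the fused
    # kernel against an eager-mode reference built from the same merge rule;
    # shapes are hypothetical and a CUDA device with bf16 support is assumed.
    torch.manual_seed(0)
    bsz, n_heads, seq, dim = 1, 4, 256, 64
    mk = lambda: torch.randn(bsz, n_heads, seq, dim, device="cuda", dtype=torch.bfloat16)
    qn, kn, qg, kg, v = mk(), mk(), mk(), mk(), mk()
    window = 32
    out = self_extend_flash_forward_triton(
        None, None, window, qn, kn, qg, kg, v, None, bsz, seq, seq, 0.0)

    # eager reference: grouped scores at distance >= window, neighbor otherwise
    scale = 1.0 / math.sqrt(dim)
    pos = torch.arange(seq, device="cuda")
    dist = pos[:, None] - pos[None, :]
    neighbor = torch.matmul(qn.float(), kn.float().transpose(-1, -2)) * scale
    grouped = torch.matmul(qg.float(), kg.float().transpose(-1, -2)) * scale
    scores = torch.where(dist >= window, grouped, neighbor)
    scores = scores.masked_fill(dist < 0, float("-inf"))  # causal
    ref = torch.matmul(torch.softmax(scores, dim=-1), v.float()).transpose(1, 2)
    print("max abs err:", (out.float() - ref).abs().max().item())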