import math
import torch
import triton
import triton.language as tl


def self_extend_flash_forward_triton(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):
    # Inputs are laid out as (bsz, num_heads, seq_len, head_dim). Causal masking is
    # only applied when there is no KV-cache offset (q_len == kv_seq_len, the prefill case).
    o = _self_extend_flash_forward_triton(q=neighbor_query_states,
                                          k=neighbor_key_states,
                                          q1=group_query_states,
                                          k1=group_key_states,
                                          v=value_states,
                                          causal=(q_len == kv_seq_len),
                                          sm_scale=1. / math.sqrt(neighbor_query_states.shape[-1]),
                                          window=group_size_2)
    # Return (bsz, q_len, num_heads, head_dim) to match the flash-attention output layout.
    o = o.transpose(1, 2).contiguous()
    return o


def _self_extend_flash_forward_triton(q, k, q1, k1, v, causal, sm_scale, window):
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    device = torch.cuda.device_of(q)
    with torch.cuda.device(device):
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 32
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
        # Per-row log-sum-exp buffer: one row per (batch * head), one entry per query position.
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _fwd_kernel[grid](
            q,
            k,
            q1,
            k1,
            v,
            sm_scale,
            L,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            k.shape[2],
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            IS_CAUSAL=causal,
            WINDOW=window,
            # EVEN_M / EVEN_N: whether the sequence lengths divide evenly into the
            # tile sizes, letting the kernel skip boundary checks on loads/stores.
            EVEN_M=(q.shape[2] % BLOCK_M == 0),
            EVEN_N=(k.shape[2] % BLOCK_N == 0),
            num_warps=8,
            num_stages=2)
    return o
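

# The kernel below follows the FlashAttention-style streaming softmax: each
# program handles one BLOCK_M x BLOCK_DMODEL tile of queries, iterates over
# BLOCK_N-wide key/value tiles, and keeps a running row maximum (m_i), running
# softmax denominator (l_i), and output accumulator (acc), so the full attention
# matrix is never materialized. The self-extend specific part is that each tile
# computes two score matrices -- neighbor scores from (Q, K) and grouped scores
# from (Q1, K1) -- and the WINDOW mask selects which score each (query, key) pair uses.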
@triton.jit
def _fwd_kernel(
    Q,
    K,
    Q1,
    K1,
    V,
    sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_on,
    Z,
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    WINDOW: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    q_offset = off_hz * stride_qh
    vk_offset = off_hz * stride_kh

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize running max (m_i), softmax denominator (l_i) and output accumulator (acc)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x instead of exp in the loop,
    # because CSE and LICM don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.4426950408889634
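    # exp(x) == 2 ** (x * log2(e)), so applying the log2(e) factor to the scores once
    # here lets the inner loop use exp2 while still computing the same softmax.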
    # load q: it will stay in SRAM throughout
    if EVEN_M:
        q = tl.load(Q_block_ptr)
        q1 = tl.load(Q1_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(1, 0))
        q1 = tl.load(Q1_block_ptr, boundary_check=(1, 0))
    q = (q * qk_scale).to(tl.bfloat16)
    q1 = (q1 * qk_scale).to(tl.bfloat16)
    # Dot-I trick: multiplying by an identity matrix converts q and q1 into the
    # MMA layout up front, which saves shared memory in the main loop.
    # Building the identity with tl.where avoids a cast from bool.
    offs_k = tl.arange(0, BLOCK_DMODEL)
    I = tl.where(offs_k[:, None] == offs_k,
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=tl.bfloat16),
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=tl.bfloat16))
    q = tl.dot(q, I).to(tl.bfloat16)
    q1 = tl.dot(q1, I).to(tl.bfloat16)
    # loop over k, v and update accumulator
    lo = 0
    if IS_CAUSAL:
        hi = tl.minimum(KV_CTX, (start_m + 1) * BLOCK_M)
    else:
        hi = KV_CTX
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        if EVEN_N:
            k = tl.load(K_block_ptr)
            k1 = tl.load(K1_block_ptr)
            v = tl.load(V_block_ptr)
        else:
            k = tl.load(K_block_ptr, boundary_check=(1, 0))
            k1 = tl.load(K1_block_ptr, boundary_check=(1, 0))
            v = tl.load(V_block_ptr, boundary_check=(1, 0))
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # Window masking
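        # Self-extend rule: a query's absolute position is (KV_CTX - Q_CTX + row index);
        # keys at least WINDOW positions behind it use the grouped scores (q1 . k1),
        # while nearer keys use the neighbor scores (q . k).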
        mask = (KV_CTX - Q_CTX + offs_m[:, None]) >= (start_n + offs_n[None, :] + WINDOW)
        qk += tl.where(mask, tl.dot(q1, tl.trans(k1)), tl.dot(q, tl.trans(k)))
        # if not EVEN_N:
        #     mask = (start_n + offs_n) < KV_CTX
        #     qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround for a compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.bfloat16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
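        # Invariant: acc now holds sum_j exp2(s_j - m_i) * v_j and l_i holds
        # sum_j exp2(s_j - m_i) over all key tiles seen so far, so acc / l_i at
        # the end equals the softmax-weighted average of the value rows.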
        # advance pointers to the next key/value tile
        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
        K1_block_ptr = tl.advance(K1_block_ptr, (BLOCK_N, 0))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # normalize the accumulator and write back the per-row log-sum-exp
    acc = acc * (1.0 / l_i[:, None])
    l_ptrs = L + off_hz * Q_CTX + offs_m
    mask_m = offs_m < Q_CTX
    l_i = m_i + tl.math.log2(l_i)
    if EVEN_M:
        tl.store(l_ptrs, l_i)
    else:
        tl.store(l_ptrs, l_i, mask=mask_m)
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    if EVEN_M:
        tl.store(O_block_ptr, acc.to(tl.bfloat16))
    else:
        tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(1, 0))
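

# A minimal usage sketch, assuming a CUDA device and bfloat16 inputs; the shapes
# and window size below are illustrative placeholders, not values taken from the
# original caller.
if __name__ == "__main__":
    torch.manual_seed(0)
    bsz, heads, seq_len, head_dim = 1, 4, 512, 64
    shape = (bsz, heads, seq_len, head_dim)
    q = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    q1 = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    k1 = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    out = _self_extend_flash_forward_triton(q, k, q1, k1, v,
                                            causal=True,
                                            sm_scale=1. / math.sqrt(head_dim),
                                            window=256)
    print(out.shape)  # (bsz, heads, seq_len, head_dim) -> torch.Size([1, 4, 512, 64])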