#include "common.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"
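
// FlashAttention kernel for F16 K/V data based on tensor core MMA instructions (NEW_MMA_AVAILABLE).
// Each CUDA block processes a tile of ncols Q columns of one attention head against the K/V data,
// KQ_stride K/V rows at a time, using an online softmax so that only a single pass over K/V is needed.
// Blocks that cover only part of the K/V range of a tile write unnormalized partial results and
// metadata to a fixup buffer; these are combined into the final result in a separate pass.
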
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        const float2 * const __restrict__ Q_f2,
        const half2  * const __restrict__ K_h2,
        const half2  * const __restrict__ V_h2,
        const half   * const __restrict__ maskh,
        float2       * const __restrict__ dstk,
        float2       * const __restrict__ dstk_fixup,
        const float scale,
        const float slope,
        const float logit_softcap,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int nb21,
        const int nb22,
        const int nb23,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3,
        const int jt,
        const int kb0_start,
        const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
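
    // Fragment types for the two matrix multiplications:
    //   KQ   = K*Q^T          : A operand = K rows, B operand = Q columns, float accumulator.
    //   VKQ += V^T*softmax(KQ): A operand = transposed V, B operand = the KQ probabilities, half2 accumulator.
    // np is the number of warps that cooperate on the same group of mma_B::J Q columns. Rows of the
    // shared memory tile are padded to D2_padded half2 values (to avoid bank conflicts and to hold the
    // per-column metadata during the combine step at the end).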
    typedef mma_A_I16K8<half2> mma_A;
    typedef mma_B_J8K8<half2>  mma_B;
    typedef mma_C_I16J8<float> mma_C_KQ;
    typedef mma_C_I16J8<half2> mma_C_VKQ;

    static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
    constexpr int np = nwarps*mma_B::J / ncols;

    static_assert(D % nwarps == 0, "bad D");
    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");

    constexpr int D2_padded = D/2 + 4;
    extern __shared__ half2 tile_KV[];

    const int stride_Q    = nb01 / sizeof(float2);
    const int stride_KV   = nb11 / sizeof(half2);
    const int stride_mask = nb31 / sizeof(half);

    mma_B     Q_B[D/(2*mma_B::K)];
    mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];

    float2    KQ_rowsum = {0.0f, 0.0f};
    float2       KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
    float2 KQ_max_scale = {0.0f, 0.0f};
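
    // Load the Q tile into shared memory with the softmax scale already applied, then keep it in
    // registers (Q_B) for the whole tile. Rows past ne01 are zero-padded. The copy is done with
    // decreasing stride granularity over D so that head sizes that are not a multiple of WARP_SIZE
    // are still loaded with wide, contiguous accesses.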
    const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
        const int k0_stop  = D/2 - (D/2) % (1*stride_k);
        const int stride_j = WARP_SIZE / stride_k;

        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
            break;
        }

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

            if (jt*ncols + j < ne01) {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
                    tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
                }
            } else {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
                }
            }
        }
    }

    __syncthreads();

    {
        const int j0 = (threadIdx.y / np) * mma_B::J;

#pragma unroll
        for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
            Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
        }
    }

    __syncthreads();
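
    // Iterate over the K/V blocks assigned to this CUDA block, KQ_stride rows at a time.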
    for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
        const int k_VKQ_0 = kb0*KQ_stride;
        mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
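
        // Load the next K tile into shared memory, again with decreasing stride granularity over D.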
        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
            const int k0_stop  = D/2 - (D/2) % (1*stride_k);
            const int stride_i = WARP_SIZE / stride_k;

#pragma unroll
            for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
                const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

#pragma unroll
                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
                    const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
                }
            }
        }

        __syncthreads();
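
        // KQ = K*Q^T: the np warps assigned to the same Q columns split the KQ_stride rows of the
        // score tile among themselves.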
#pragma unroll
        for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
            const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
                mma_A K_A;
                K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
                KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
            }
        }

        __syncthreads();

        if (use_logit_softcap) {
            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
            for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
#pragma unroll
                for (int l = 0; l < mma_C_KQ::ne; ++l) {
                    KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
                }
            }
        }

        if (maskh) {
            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
            static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
#pragma unroll
            for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
                const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
#pragma unroll
                for (int l = 0; l < mma_C_KQ::ne; ++l) {
                    const int i = i0 + mma_C_KQ::get_i(l);
                    const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);

                    KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
                }
            }
        }
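
        // Online softmax update for this K tile: with the new per-column maximum the running state is
        // rescaled by exp(KQ_max_old - KQ_max_new) before the new exp(KQ - KQ_max_new) terms are added
        // to the rowsum and (further below) to the VKQ accumulators. exp() results whose argument is
        // below SOFTMAX_FTZ_THRESHOLD are flushed to zero.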
        float2 KQ_max_new = KQ_max;
        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
            for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
                KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
                KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
            }
        }

#pragma unroll
        for (int offset = 16; offset > 2; offset >>= 1) {
            KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
            KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
        }

        {
            const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
            KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
            if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
                KQ_max_scale.x = 0.0f;
            }
            if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
                KQ_max_scale.y = 0.0f;
            }
            KQ_max = KQ_max_new;
        }

        float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
            for (int l = 0; l < mma_C_KQ::ne; ++l) {
                const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
                const float diff = KQ_C[k].x[l] - KQ_max_l;
                KQ_C[k].x[l] = expf(diff);
                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                    KQ_C[k].x[l] = 0.0f;
                }

                if (l % 2 == 0) {
                    KQ_rowsum_add.x += KQ_C[k].x[l];
                } else {
                    KQ_rowsum_add.y += KQ_C[k].x[l];
                }
            }
        }

        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;

        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
        for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
#pragma unroll
            for (int l = 0; l < mma_C_VKQ::ne; ++l) {
                VKQ_C[i].x[l] *= KQ_max_scale_h2;
            }
        }
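
        // Convert the softmax-weighted KQ tiles into B operands for the V^T*KQ multiplication.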
        mma_B B[KQ_stride/(np*2*mma_B::K)];
        static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
#pragma unroll
        for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
            B[k] = KQ_C[k].to_mma_B();
        }
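
        // Load the corresponding V tile into shared memory, reusing the storage of the K tile.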
        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
        for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
            const int i0_stop  = D/2 - (D/2) % (1*stride_i);
            const int stride_k = WARP_SIZE / stride_i;

#pragma unroll
            for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
                const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);

#pragma unroll
                for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
                    const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);

                    tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
                }
            }
        }

        __syncthreads();
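
        // VKQ += V^T * softmax(KQ): V is loaded in transposed form (load_ldmatrix_trans) so that the
        // accumulator's row dimension runs over the head dimension D.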
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
            static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
#pragma unroll
            for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
                const int k0 = k00 + (threadIdx.y % np)*mma_A::K;

                mma_A A;
                A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
                VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
            }
        }

        __syncthreads();
    }
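
    // Each KQ column's rowsum is still spread across several threads of the warp; combine the partial
    // sums with shuffles (a full warp reduction is not needed).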
#pragma unroll
    for (int offset = 16; offset > 2; offset >>= 1) {
        KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
        KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
    }
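
    // Write the VKQ accumulators back to shared memory, one mma_B::J-column slice per warp, and store
    // the per-column metadata (KQ max and KQ rowsum) in the row padding so that the partial results of
    // the np warps working on the same columns can be combined below.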
    const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
#pragma unroll
    for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
        const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B();

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int k = k0 + mma_B::get_k(l);

            tile_KV[j_cwd*D2_padded + k] = B.x[l];
        }
    }

    const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
    const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
    const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum

    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
        // Use the padding at the end of each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
        ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
    }

    __syncthreads();

    static_assert(np == 1 || np == 2 || np == 4, "bad np");
    if (np == 1) {
        // No combination across warps is needed, the meta data can be written directly.
        if (needs_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
            dstk_fixup_meta[j_cwm] = KQ_cmr;
        }
        if (is_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
            dstk_fixup_meta[j_cwm] = KQ_cmr;
        }
    } else if (threadIdx.y % np == 0) {
        // Combine the meta data of the np warps that share the same columns:
        // determine the overall max, rescale the rowsums accordingly and sum them up.
        float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;

        float KQ_cm = -FLT_MAX/2; // KQ combine max, per parallel warp
        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
            KQ_cm = meta_j[0];
        }

        float KQ_cmn = KQ_cm; // KQ combine max new, max over all parallel warps
#pragma unroll
        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
        }

        const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale, per warp
        float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum over all parallel warps
        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
            KQ_crs = KQ_cms*meta_j[1];
        }
#pragma unroll
        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
        }

        // Write back the combined meta data:
        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
            meta_j[0] = KQ_cmn;
            meta_j[1] = KQ_crs;
            meta_j[2] = KQ_cms;
        }
        if (needs_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
        }
        if (is_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
        }
    }

    if (np > 1) {
        __syncthreads();
    }

    if (np == 1 || threadIdx.y % np == 0) {
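        // Combine the np partial VKQ slices per column and write the result: if this block computed the
        // whole K range of the tile, the values are normalized by the KQ rowsum and written to dst;
        // otherwise the unnormalized partial sums go to dst or to the fixup buffer and are finalized later.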
        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));

#pragma unroll
        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
            const int k0_stop  = D/2 - (D/2) % (1*stride_k);
            const int stride_j = WARP_SIZE / stride_k;

            if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
                break;
            }

#pragma unroll
            for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
                const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
                const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;

                if (!is_fixup && jt*ncols + j_dst >= ne01) {
                    continue;
                }
                const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
                    for (int ip = 0; ip < np; ++ip) {
                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
                        const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
                        dstk_val.x += dstk_val_add.x*KQ_crs;
                        dstk_val.y += dstk_val_add.y*KQ_crs;
                    }

                    if (!needs_fixup && !is_fixup) {
                        const float KQ_rowsum_j = meta_j[1];
                        dstk_val.x /= KQ_rowsum_j;
                        dstk_val.y /= KQ_rowsum_j;
                    }

                    if (is_fixup) {
                        dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
                    } else {
                        dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
                    }
                }
            }
        }
    }

    if (np > 1) {
        __syncthreads();
    }
#else
    NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif
static __global__ void flash_attn_ext_f16(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int nb21,
        const int nb22,
        const int nb23,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
        return;
    }

    static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

    const int iter_k = ne11 / KQ_stride;
    const int iter_j = (ne01 + (ncols - 1)) / ncols;
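
    // The work is distributed over the CUDA blocks in a stream-k fashion: each block gets a contiguous
    // range [kbc, kbc_stop) of the flattened (head, Q tile, K tile) index space, so a block may start
    // or stop in the middle of a Q tile's K range. gqa_ratio maps each Q head to the K/V head it shares.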
    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
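
    // When a Q tile's K range is split between two blocks, the block computing the later part writes
    // its result to dst with needs_fixup == true, while the block computing the earlier part writes its
    // unnormalized partial result and metadata to the fixup buffer with is_fixup == true; the two are
    // combined afterwards outside this kernel.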
    int kb0_start = kbc % iter_k;
    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
    while (kbc < kbc_stop && kb0_stop == iter_k) {
        const int channel = kbc / (iter_k*iter_j); // channel == attention head
        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current Q tile

        const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
        const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
        const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have the same shape
        const half   * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
        float2       * dstk  = ((float2 *) dst) + channel*(D/2);

        const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);

        constexpr bool is_fixup = false; // All but (potentially) the last iteration write their data to dst rather than the fixup buffer.
        if (kb0_start == 0) {
            constexpr bool needs_fixup = false; // CUDA block covers the entire K range of the tile.
            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
                 ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
                 jt, kb0_start, kb0_stop);
        } else {
            constexpr bool needs_fixup = true; // The beginning of the tile is computed by a different CUDA block.
            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
                 ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
                 jt, kb0_start, kb0_stop);
        }

        kbc += iter_k;
        kbc -= kbc % iter_k;

        kb0_start = 0;
        kb0_stop  = min(iter_k, kbc_stop - kbc);
    }

    if (kbc >= kbc_stop) {
        return;
    }

    const int channel = kbc / (iter_k*iter_j); // channel == attention head
    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current Q tile

    const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
    const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
    const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have the same shape
    const half   * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
    float2       * dstk  = ((float2 *) dst) + channel*(D/2);

    const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);

    constexpr bool is_fixup    = true;  // This block only computed a partial tile, write it to the fixup buffer.
    constexpr bool needs_fixup = false;
    flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
        (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
         ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
         jt, kb0_start, kb0_stop);
}

template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    typedef mma_A_I16K8<half2> mma_A;
    typedef mma_B_J8K8<half2>  mma_B;

    static_assert(D              % mma_B::K == 0, "bad D");
    static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");

    const ggml_tensor * KQV = dst;

    constexpr int KQ_stride = D <= 128 ? 64 : 32;
    constexpr int nwarps    = (KQ_stride == 32 && cols_per_block <= 16) ?
        cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
    constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    fattn_kernel_t fattn_kernel;
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
    } else {
        constexpr bool use_logit_softcap = true;
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
    }
    launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
}
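
// Usage sketch (illustrative only, not part of this header): a host-side dispatcher would pick the
// explicit instantiation matching the head size, here with a fixed cols_per_block of 16. The wrapper
// name and the assumption that Q is dst->src[0] are hypothetical:
//
//   static void flash_attn_ext_mma_f16_dispatch(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
//       const ggml_tensor * Q = dst->src[0];
//       switch (Q->ne[0]) { // head size D
//           case  64: ggml_cuda_flash_attn_ext_mma_f16_case< 64, 16>(ctx, dst); break;
//           case  80: ggml_cuda_flash_attn_ext_mma_f16_case< 80, 16>(ctx, dst); break;
//           case  96: ggml_cuda_flash_attn_ext_mma_f16_case< 96, 16>(ctx, dst); break;
//           case 112: ggml_cuda_flash_attn_ext_mma_f16_case<112, 16>(ctx, dst); break;
//           case 128: ggml_cuda_flash_attn_ext_mma_f16_case<128, 16>(ctx, dst); break;
//           case 256: ggml_cuda_flash_attn_ext_mma_f16_case<256, 16>(ctx, dst); break;
//           default:  GGML_ABORT("unsupported head size for the MMA FlashAttention kernel");
//       }
//   }
//
// The DECL_FATTN_MMA_F16_CASE declarations below refer to explicit instantiations of
// ggml_cuda_flash_attn_ext_mma_f16_case that are expected to be compiled in separate translation units.
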
#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block)                          \
    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
    <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

extern DECL_FATTN_MMA_F16_CASE( 64, 8);
extern DECL_FATTN_MMA_F16_CASE( 80, 8);
extern DECL_FATTN_MMA_F16_CASE( 96, 8);
extern DECL_FATTN_MMA_F16_CASE(112, 8);
extern DECL_FATTN_MMA_F16_CASE(128, 8);
extern DECL_FATTN_MMA_F16_CASE(256, 8);

extern DECL_FATTN_MMA_F16_CASE( 64, 16);
extern DECL_FATTN_MMA_F16_CASE( 80, 16);
extern DECL_FATTN_MMA_F16_CASE( 96, 16);
extern DECL_FATTN_MMA_F16_CASE(112, 16);
extern DECL_FATTN_MMA_F16_CASE(128, 16);
extern DECL_FATTN_MMA_F16_CASE(256, 16);

extern DECL_FATTN_MMA_F16_CASE( 64, 32);
extern DECL_FATTN_MMA_F16_CASE( 80, 32);
extern DECL_FATTN_MMA_F16_CASE( 96, 32);
extern DECL_FATTN_MMA_F16_CASE(112, 32);
extern DECL_FATTN_MMA_F16_CASE(128, 32);
extern DECL_FATTN_MMA_F16_CASE(256, 32);

extern DECL_FATTN_MMA_F16_CASE( 64, 64);
extern DECL_FATTN_MMA_F16_CASE( 80, 64);
extern DECL_FATTN_MMA_F16_CASE( 96, 64);
extern DECL_FATTN_MMA_F16_CASE(112, 64);
extern DECL_FATTN_MMA_F16_CASE(128, 64);
extern DECL_FATTN_MMA_F16_CASE(256, 64);