|
#include "common.cuh" |
|
#include "cp-async.cuh" |
|
#include "mma.cuh" |
|
#include "fattn-common.cuh" |
|
|
|
using namespace ggml_cuda_mma; |
|
|
|
typedef tile<16, 8, half2> tile_A; |
|
typedef tile< 8, 8, half2> tile_B; |
|
typedef tile<16, 8, half2> tile_B_16; |
|
typedef tile<16, 8, float> tile_C_KQ; |
|
typedef tile<16, 16, float> tile_C_KQ_16; |
|
typedef tile<16, 4, half2> tile_C_VKQ; |
|
typedef tile<16, 8, half2> tile_C_VKQ_16; |
|
|
|
template<int D, int nwarps, int KQ_per_iter> |
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( |
|
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { |
|
constexpr int D2_padded = D/2 + 4; |
|
|
|
|
|
#ifdef CP_ASYNC_AVAILABLE |
|
static_assert(D >= 64 && D < 512, "bad D"); |
|
constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); |
|
|
|
const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); |
|
|
|
constexpr int preload = 64; |
|
constexpr int h2_per_chunk = 16/sizeof(half2); |
|
constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; |
|
constexpr int stride_i = WARP_SIZE / chunks_per_row; |
|
#pragma unroll |
|
for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { |
|
const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); |
|
const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; |
|
|
|
cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); |
|
} |
|
#else |
|
constexpr int k0_sync_start = 0; |
|
#endif |
|
static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); |
|
|
|
|
|
|
|
static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); |
|
#pragma unroll |
|
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { |
|
const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); |
|
const int k0_stop = D/2 - (D/2) % (1*stride_k); |
|
const int stride_i = WARP_SIZE / stride_k; |
|
|
|
if (k0_start == k0_stop || k0_stop <= k0_sync_start) { |
|
continue; |
|
} |
|
|
|
#pragma unroll |
|
for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { |
|
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
|
|
|
#pragma unroll |
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
|
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
|
|
|
tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; |
|
} |
|
} |
|
} |
|
} |
|
|
|
template<int ncols1, int nwarps, int KQ_per_iter> |
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( |
|
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { |
|
static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); |
|
#ifdef CP_ASYNC_AVAILABLE |
|
constexpr int preload = KQ_per_iter * sizeof(half); |
|
constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; |
|
constexpr int stride_j = nwarps * cols_per_warp; |
|
|
|
const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); |
|
|
|
#pragma unroll |
|
for (int j0 = 0; j0 < ncols1; j0 += stride_j) { |
|
const int j = j0 + threadIdx.y*cols_per_warp + |
|
(KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); |
|
|
|
if (j0 + stride_j > ncols1 && j >= ncols1) { |
|
break; |
|
} |
|
|
|
const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); |
|
|
|
cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); |
|
} |
|
#else |
|
constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; |
|
constexpr int stride_j = nwarps * cols_per_warp; |
|
#pragma unroll |
|
for (int j0 = 0; j0 < ncols1; j0 += stride_j) { |
|
const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); |
|
|
|
if (j0 + stride_j > ncols1 && j >= ncols1) { |
|
break; |
|
} |
|
|
|
const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); |
|
|
|
tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; |
|
} |
|
#endif |
|
} |
|
|
|
template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter> |
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter( |
|
const float2 * const __restrict__ Q_f2, |
|
const half2 * const __restrict__ K_h2, |
|
const half2 * const __restrict__ V_h2, |
|
const half2 * const __restrict__ mask_h2, |
|
float2 * const __restrict__ dstk, |
|
float2 * const __restrict__ dstk_fixup, |
|
const float scale, |
|
const float slope, |
|
const float logit_softcap, |
|
const int ne01, |
|
const int ne02, |
|
const int stride_KV, |
|
const int stride_mask, |
|
const int jt, |
|
half2 * const __restrict__ tile_K, |
|
half2 * const __restrict__ tile_V, |
|
half2 * const __restrict__ tile_mask, |
|
const tile_B * const __restrict__ Q_B, |
|
tile_C_VKQ * const __restrict__ VKQ_C, |
|
float * const __restrict__ KQ_max, |
|
float * const __restrict__ KQ_rowsum, |
|
const int kb0) { |
|
#ifdef NEW_MMA_AVAILABLE |
|
constexpr int cols_per_warp = ntiles * tile_B::I; |
|
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; |
|
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; |
|
constexpr int D2_padded = D/2 + 4; |
|
|
|
const int k_VKQ_0 = kb0 * KQ_per_iter; |
|
tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; |
|
|
|
|
|
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; |
|
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; |
|
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; |
|
|
|
#ifdef CP_ASYNC_AVAILABLE |
|
cp_async_wait_all(); |
|
__syncthreads(); |
|
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); |
|
#else |
|
if (ncols2 > 1 || mask_h2) { |
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); |
|
} |
|
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); |
|
__syncthreads(); |
|
#endif |
|
|
|
|
|
#pragma unroll |
|
for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { |
|
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; |
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { |
|
tile_A K_A; |
|
load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); |
|
if (ntiles == 1) { |
|
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); |
|
} else { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
|
|
mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); |
|
} |
|
} |
|
} |
|
} |
|
|
|
#ifndef CP_ASYNC_AVAILABLE |
|
__syncthreads(); |
|
#endif |
|
|
|
if (use_logit_softcap) { |
|
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); |
|
#pragma unroll |
|
for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_KQ::ne; ++l) { |
|
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); |
|
} |
|
} |
|
} |
|
|
|
float KQ_max_new[cols_per_thread]; |
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
KQ_max_new[col] = KQ_max[col]; |
|
} |
|
float KQ_rowsum_add[cols_per_thread] = {0.0f}; |
|
|
|
if (ntiles == 1) { |
|
if (ncols2 > 1 || mask_h2) { |
|
#pragma unroll |
|
for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { |
|
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_KQ::ne; ++l) { |
|
const int i = i0 + tile_C_KQ::get_i(l); |
|
const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; |
|
|
|
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * |
|
__half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); |
|
#pragma unroll |
|
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_KQ::ne; ++l) { |
|
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); |
|
} |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
#pragma unroll |
|
for (int offset = 16; offset >= 4; offset >>= 1) { |
|
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); |
|
} |
|
} |
|
|
|
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); |
|
|
|
#pragma unroll |
|
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_KQ::ne; ++l) { |
|
KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); |
|
|
|
KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; |
|
} |
|
} |
|
} else { |
|
if (ncols2 > 1 || mask_h2) { |
|
#pragma unroll |
|
for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { |
|
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
#pragma unroll |
|
for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { |
|
const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; |
|
const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; |
|
|
|
const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); |
|
const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; |
|
KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; |
|
KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); |
|
#pragma unroll |
|
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_KQ_16::ne; ++l) { |
|
const int KQ_index = 2*t + (l/2) % 2; |
|
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); |
|
} |
|
} |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
#pragma unroll |
|
for (int offset = 2; offset >= 1; offset >>= 1) { |
|
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); |
|
} |
|
} |
|
|
|
static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); |
|
#pragma unroll |
|
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_KQ_16::ne; ++l) { |
|
const int KQ_index = 2*t + (l/2) % 2; |
|
|
|
KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); |
|
|
|
KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; |
|
} |
|
} |
|
} |
|
} |
|
|
|
{ |
|
float KQ_max_scale[cols_per_thread]; |
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); |
|
KQ_max[col] = KQ_max_new[col]; |
|
|
|
|
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; |
|
} |
|
|
|
if (ntiles == 1) { |
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); |
|
#pragma unroll |
|
for (int i = 0; i < D/tile_C_VKQ::I; ++i) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_VKQ::ne; ++l) { |
|
VKQ_C[i].x[l] *= KQ_max_scale_h2; |
|
} |
|
} |
|
} else { |
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); |
|
#pragma unroll |
|
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { |
|
#pragma unroll |
|
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { |
|
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; |
|
tile_B_16 * B_16 = (tile_B_16 *) B; |
|
static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); |
|
if (ntiles == 1) { |
|
#pragma unroll |
|
for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { |
|
B[k] = get_transposed(get_half2(KQ_C[k])); |
|
} |
|
} else { |
|
for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); |
|
} |
|
} |
|
} |
|
|
|
#ifdef CP_ASYNC_AVAILABLE |
|
|
|
cp_async_wait_all(); |
|
__syncthreads(); |
|
if (!last_iter) { |
|
if (ncols2 > 1 || mask_h2) { |
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); |
|
} |
|
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); |
|
} |
|
#else |
|
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); |
|
__syncthreads(); |
|
#endif |
|
|
|
|
|
#pragma unroll |
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { |
|
static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); |
|
#pragma unroll |
|
for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { |
|
const int k0 = k00 + (threadIdx.y % np)*tile_A::J; |
|
|
|
tile_A A; |
|
load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); |
|
if (ntiles == 1) { |
|
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); |
|
} else { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
|
|
mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); |
|
} |
|
} |
|
} |
|
} |
|
|
|
#ifndef CP_ASYNC_AVAILABLE |
|
__syncthreads(); |
|
#endif |
|
|
|
#else |
|
NO_DEVICE_CODE; |
|
#endif |
|
} |
|
|
|
template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup> |
|
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( |
|
const float2 * const __restrict__ Q_f2, |
|
const half2 * const __restrict__ K_h2, |
|
const half2 * const __restrict__ V_h2, |
|
const half2 * const __restrict__ mask_h2, |
|
float2 * const __restrict__ dstk, |
|
float2 * const __restrict__ dstk_fixup, |
|
const float scale, |
|
const float slope, |
|
const float logit_softcap, |
|
const int ne01, |
|
const int ne02, |
|
const int stride_Q1, |
|
const int stride_Q2, |
|
const int stride_KV, |
|
const int stride_mask, |
|
const int jt, |
|
const int kb0_start, |
|
const int kb0_stop) { |
|
#ifdef NEW_MMA_AVAILABLE |
|
|
|
|
|
constexpr int ncols = ncols1 * ncols2; |
|
constexpr int cols_per_warp = ntiles * tile_B::I; |
|
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; |
|
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; |
|
|
|
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); |
|
|
|
static_assert(D % nwarps == 0, "bad D"); |
|
static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); |
|
|
|
constexpr int D2_padded = D/2 + 4; |
|
|
|
|
|
extern __shared__ half2 tile_K[]; |
|
#ifdef CP_ASYNC_AVAILABLE |
|
half2 * tile_V = tile_K + KQ_per_iter*D2_padded; |
|
#else |
|
half2 * tile_V = tile_K; |
|
#endif |
|
half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; |
|
|
|
tile_B Q_B[D/(2*tile_B::J) * ntiles]; |
|
tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; |
|
|
|
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; |
|
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; |
|
|
|
float KQ_rowsum[cols_per_thread] = {0.0f}; |
|
float KQ_max[cols_per_thread]; |
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
KQ_max[col] = -FLT_MAX/2.0f; |
|
} |
|
|
|
|
|
|
|
const half2 scale_h2 = make_half2(scale, scale); |
|
#pragma unroll |
|
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { |
|
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); |
|
const int k0_stop = D/2 - (D/2) % (1*stride_k); |
|
const int stride_jc = WARP_SIZE / stride_k; |
|
|
|
if (k0_start == k0_stop) { |
|
continue; |
|
} |
|
|
|
#pragma unroll |
|
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { |
|
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
|
|
|
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { |
|
break; |
|
} |
|
|
|
const int j = jc / ncols2; |
|
const int c = jc % ncols2; |
|
|
|
if (jt*ncols1 + j < ne01) { |
|
#pragma unroll |
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
|
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
|
|
|
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; |
|
tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); |
|
} |
|
} else { |
|
#pragma unroll |
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
|
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
|
|
|
tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); |
|
} |
|
} |
|
} |
|
} |
|
|
|
__syncthreads(); |
|
|
|
{ |
|
const int j0 = (threadIdx.y / np) * cols_per_warp; |
|
|
|
#pragma unroll |
|
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { |
|
if (ntiles == 1) { |
|
load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); |
|
} else { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], |
|
tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); |
|
} |
|
} |
|
} |
|
} |
|
|
|
__syncthreads(); |
|
|
|
|
|
#ifdef CP_ASYNC_AVAILABLE |
|
if (ncols2 > 1 || mask_h2) { |
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); |
|
} |
|
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); |
|
#endif |
|
|
|
|
|
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { |
|
constexpr bool last_iter = false; |
|
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> |
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, |
|
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); |
|
} |
|
{ |
|
constexpr bool last_iter = true; |
|
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> |
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, |
|
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); |
|
} |
|
|
|
|
|
|
|
#ifdef CP_ASYNC_AVAILABLE |
|
if (nwarps*cols_per_warp > KQ_per_iter) { |
|
__syncthreads(); |
|
} |
|
#endif |
|
|
|
|
|
|
|
{ |
|
constexpr int offset_first = ntiles == 1 ? 16 : 2; |
|
constexpr int offset_last = ntiles == 1 ? 4 : 1; |
|
#pragma unroll |
|
for (int col = 0; col < cols_per_thread; ++col) { |
|
#pragma unroll |
|
for (int offset = offset_first; offset >= offset_last; offset >>= 1) { |
|
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
if (ntiles == 1) { |
|
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); |
|
#pragma unroll |
|
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { |
|
const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); |
|
|
|
#pragma unroll |
|
for (int l = 0; l < tile_B::ne; ++l) { |
|
const int k = k0 + tile_B::get_j(l); |
|
|
|
tile_K[jc_cwd*D2_padded + k] = B.x[l]; |
|
} |
|
} |
|
} else { |
|
#pragma unroll |
|
for (int t = 0; t < ntiles/2; ++t) { |
|
const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; |
|
#pragma unroll |
|
for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { |
|
#pragma unroll |
|
for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { |
|
const int j = j0 + tile_C_VKQ_16::get_i(l); |
|
const int k = k0 + tile_C_VKQ_16::get_j(l); |
|
|
|
tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; |
|
} |
|
} |
|
} |
|
} |
|
|
|
if constexpr (ntiles == 1) { |
|
const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; |
|
const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; |
|
const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); |
|
|
|
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { |
|
|
|
((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; |
|
} |
|
|
|
__syncthreads(); |
|
|
|
if (np == 1) { |
|
|
|
if (needs_fixup && threadIdx.x < tile_B::I) { |
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; |
|
dstk_fixup_meta[jc_cwm] = KQ_cmr; |
|
} |
|
if (is_fixup && threadIdx.x < tile_B::I) { |
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; |
|
dstk_fixup_meta[jc_cwm] = KQ_cmr; |
|
} |
|
} |
|
} else { |
|
static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); |
|
const int jc_cwm = threadIdx.y*cols_per_warp |
|
+ (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) |
|
+ tile_C_VKQ_16::get_i(threadIdx.x % 4); |
|
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); |
|
|
|
if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { |
|
|
|
((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; |
|
} |
|
|
|
__syncthreads(); |
|
|
|
if (np == 1) { |
|
|
|
if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { |
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; |
|
dstk_fixup_meta[jc_cwm] = KQ_cmr; |
|
} |
|
if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { |
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; |
|
dstk_fixup_meta[jc_cwm] = KQ_cmr; |
|
} |
|
} |
|
} |
|
|
|
static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); |
|
if (np > 1 && threadIdx.y % np == 0) { |
|
|
|
|
|
|
|
|
|
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; |
|
|
|
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); |
|
float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; |
|
float2 meta[nmeta]; |
|
#pragma unroll |
|
for (int imeta = 0; imeta < nmeta; ++imeta) { |
|
meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; |
|
} |
|
|
|
float KQ_cmn = meta[0].x; |
|
#pragma unroll |
|
for (int imeta = 1; imeta < nmeta; ++imeta) { |
|
KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); |
|
} |
|
#pragma unroll |
|
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { |
|
if (offset >= WARP_SIZE) { |
|
continue; |
|
} |
|
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); |
|
} |
|
|
|
float KQ_cms[nmeta]; |
|
#pragma unroll |
|
for (int imeta = 0; imeta < nmeta; ++imeta) { |
|
KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); |
|
} |
|
|
|
float KQ_crs = KQ_cms[0]*meta[0].y; |
|
#pragma unroll |
|
for (int imeta = 1; imeta < nmeta; ++imeta) { |
|
KQ_crs += KQ_cms[imeta]*meta[imeta].y; |
|
} |
|
#pragma unroll |
|
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { |
|
if (offset >= WARP_SIZE) { |
|
continue; |
|
} |
|
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int imeta = 0; imeta < nmeta; ++imeta) { |
|
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { |
|
|
|
meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); |
|
} |
|
} |
|
|
|
|
|
static_assert(cols_per_warp <= WARP_SIZE); |
|
if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { |
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; |
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); |
|
} |
|
if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { |
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; |
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); |
|
} |
|
} |
|
|
|
if (np > 1) { |
|
__syncthreads(); |
|
} |
|
|
|
if (np == 1 || threadIdx.y % np == 0) { |
|
|
|
|
|
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); |
|
|
|
#pragma unroll |
|
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { |
|
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); |
|
const int k0_stop = D/2 - (D/2) % (1*stride_k); |
|
const int stride_jc = WARP_SIZE / stride_k; |
|
|
|
if (k0_start == k0_stop) { |
|
continue; |
|
} |
|
|
|
#pragma unroll |
|
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { |
|
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); |
|
|
|
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { |
|
break; |
|
} |
|
|
|
const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; |
|
|
|
const int j_dst = jc_dst / ncols2; |
|
const int c_dst = jc_dst % ncols2; |
|
|
|
if (!is_fixup && jt*ncols1 + j_dst >= ne01) { |
|
continue; |
|
} |
|
|
|
const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; |
|
#pragma unroll |
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { |
|
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); |
|
|
|
float2 dstk_val = make_float2(0.0f, 0.0f); |
|
#pragma unroll |
|
for (int ip = 0; ip < np; ++ip) { |
|
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; |
|
const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); |
|
dstk_val.x += dstk_val_add.x*KQ_crs; |
|
dstk_val.y += dstk_val_add.y*KQ_crs; |
|
} |
|
|
|
if (!needs_fixup && !is_fixup) { |
|
const float KQ_rowsum_j = meta_j[1]; |
|
dstk_val.x /= KQ_rowsum_j; |
|
dstk_val.y /= KQ_rowsum_j; |
|
} |
|
|
|
if (is_fixup) { |
|
dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; |
|
} else { |
|
dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
if (np > 1) { |
|
__syncthreads(); |
|
} |
|
#else |
|
NO_DEVICE_CODE; |
|
#endif |
|
} |
|
|
|
template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap> |
|
__launch_bounds__(nwarps*WARP_SIZE, 2) |
|
static __global__ void flash_attn_ext_f16( |
|
const char * __restrict__ Q, |
|
const char * __restrict__ K, |
|
const char * __restrict__ V, |
|
const char * __restrict__ mask, |
|
float * __restrict__ dst, |
|
float2 * __restrict__ dst_meta, |
|
const float scale, |
|
const float max_bias, |
|
const float m0, |
|
const float m1, |
|
const uint32_t n_head_log2, |
|
const float logit_softcap, |
|
const int ne00, |
|
const int ne01, |
|
const int ne02, |
|
const int ne03, |
|
const int ne10, |
|
const int ne11, |
|
const int ne12, |
|
const int ne13, |
|
const int ne31, |
|
const int nb31, |
|
const int nb01, |
|
const int nb02, |
|
const int nb03, |
|
const int nb11, |
|
const int nb12, |
|
const int nb13, |
|
const int nb21, |
|
const int nb22, |
|
const int nb23, |
|
const int ne0, |
|
const int ne1, |
|
const int ne2, |
|
const int ne3) { |
|
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) |
|
|
|
|
|
if (use_logit_softcap && !(D == 128 || D == 256)) { |
|
NO_DEVICE_CODE; |
|
return; |
|
} |
|
|
|
static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); |
|
|
|
const int gqa_ratio = ne02 / ne12; |
|
|
|
const int stride_Q1 = nb01 / sizeof(float2); |
|
const int stride_Q2 = nb02 / sizeof(float2); |
|
const int stride_KV = nb11 / sizeof(half2); |
|
const int stride_mask = nb31 / sizeof(half2); |
|
|
|
const int iter_k = ne11 / FATTN_KQ_STRIDE; |
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; |
|
|
|
constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; |
|
|
|
|
|
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; |
|
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; |
|
|
|
|
|
|
|
|
|
|
|
|
|
int kb0_start = kbc % iter_k; |
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); |
|
while (kbc < kbc_stop && kb0_stop == iter_k) { |
|
const int channel = kbc / (iter_k*iter_j); |
|
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; |
|
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); |
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); |
|
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); |
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; |
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); |
|
|
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; |
|
|
|
const int kb0_start_kernel = kb0_start * kb_niter; |
|
const int kb0_stop_kernel = kb0_stop * kb_niter; |
|
|
|
constexpr bool is_fixup = false; |
|
if (kb0_start == 0) { |
|
constexpr bool needs_fixup = false; |
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> |
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, |
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |
|
} else { |
|
constexpr bool needs_fixup = true; |
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> |
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, |
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |
|
} |
|
|
|
kbc += iter_k; |
|
kbc -= kbc % iter_k; |
|
|
|
kb0_start = 0; |
|
kb0_stop = min(iter_k, kbc_stop - kbc); |
|
} |
|
|
|
if (kbc >= kbc_stop) { |
|
return; |
|
} |
|
|
|
const int channel = kbc / (iter_k*iter_j); |
|
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; |
|
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); |
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); |
|
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); |
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; |
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); |
|
|
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; |
|
|
|
const int kb0_start_kernel = kb0_start * kb_niter; |
|
const int kb0_stop_kernel = kb0_stop * kb_niter; |
|
|
|
constexpr bool is_fixup = true; |
|
constexpr bool needs_fixup = false; |
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> |
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, |
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |
|
#else |
|
NO_DEVICE_CODE; |
|
#endif |
|
} |
|
|
|
template <int D, int ncols1, int ncols2> |
|
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
|
constexpr int ncols = ncols1 * ncols2; |
|
constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; |
|
constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; |
|
constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); |
|
constexpr int cols_per_warp = ntiles * tile_B::I; |
|
|
|
static_assert(D % tile_B::J == 0, "bad D"); |
|
static_assert(ncols % cols_per_warp == 0, "bad ncols"); |
|
|
|
const ggml_tensor * KQV = dst; |
|
const int id = ggml_cuda_get_device(); |
|
const int cc = ggml_cuda_info().devices[id].cc; |
|
|
|
const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; |
|
|
|
const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); |
|
const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); |
|
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); |
|
|
|
const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); |
|
|
|
float logit_softcap; |
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); |
|
|
|
fattn_kernel_t fattn_kernel; |
|
if (logit_softcap == 0.0f) { |
|
constexpr bool use_logit_softcap = false; |
|
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>; |
|
} else { |
|
constexpr bool use_logit_softcap = true; |
|
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>; |
|
} |
|
|
|
launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true); |
|
} |
|
|
|
|
|
#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ |
|
template void ggml_cuda_flash_attn_ext_mma_f16_case \ |
|
<D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ |
|
|
|
#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ |
|
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ |
|
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ |
|
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ |
|
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ |
|
|
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8); |
|
|
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16); |
|
|
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32); |
|
|
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64); |
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|