|
#pragma once |
|
|
|
#include "common.cuh" |
|
#include "convert.cuh" |
|
#include "vecdotq.cuh" |
|
|
|
#include <cstdint> |
|
|
|
#define FATTN_KQ_STRIDE 256 // K->ne[1] (the KV sequence length) must be padded to a multiple of this value.
|
#define HALF_MAX_HALF __float2half(65504.0f/2) // Half of the largest finite f16 value; its negative is used instead of -INFINITY to initialize KQ maxima without risking NaNs.
|
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exponentials with arguments below this threshold are flushed to zero (see the ftz masks below).
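
// Signature shared by all FlashAttention kernels. ne00..ne03 / nb01..nb03 are the dimensions and byte strides
// of Q, ne10..ne13 / nb11..nb13 those of K, nb21..nb23 the strides of V, ne31/nb31 belong to the mask,
// and ne0..ne3 are the dimensions of dst.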
|
|
|
typedef void (* fattn_kernel_t)( |
|
const char * __restrict__ Q, |
|
const char * __restrict__ K, |
|
const char * __restrict__ V, |
|
const char * __restrict__ mask, |
|
float * __restrict__ dst, |
|
float2 * __restrict__ dst_meta, |
|
const float scale, |
|
const float max_bias, |
|
const float m0, |
|
const float m1, |
|
const uint32_t n_head_log2, |
|
const float logit_softcap, |
|
const int ne00, |
|
const int ne01, |
|
const int ne02, |
|
const int ne03, |
|
const int ne10, |
|
const int ne11, |
|
const int ne12, |
|
const int ne13, |
|
const int ne31, |
|
const int nb31, |
|
const int nb01, |
|
const int nb02, |
|
const int nb03, |
|
const int nb11, |
|
const int nb12, |
|
const int nb13, |
|
const int nb21, |
|
const int nb22, |
|
const int nb23, |
|
const int ne0, |
|
const int ne1, |
|
const int ne2, |
|
const int ne3); |
|
|
|
typedef half (*vec_dot_KQ_f16_t)( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds);
|
typedef float (*vec_dot_KQ_f32_t)( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds);
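
// The vec_dot_fattn_vec_KQ_* functions below compute the dot product between one quantized K row and one Q column.
// For quantized K the Q column is expected in q8_1 format: Q_q8 holds the packed int8 values and Q_ds the per-block
// scale and sum (half2 or float2). For F16 K the unquantized values in Q_v are used instead. The row is strided
// across the WARP_SIZE lanes of a warp, so each lane returns only a partial sum.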
|
|
|
template<typename T, int D> |
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { |
|
|
|
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; |
|
GGML_UNUSED(Q_v); |
|
|
|
T sum = 0.0f; |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const int ib = k_KQ / QI8_1; |
|
const int iqs4 = k_KQ % QI4_0; |
|
const int shift = k_KQ & (QI8_1/2); |
|
|
|
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; |
|
const int u = Q_q8[k_KQ_0/WARP_SIZE]; |
|
|
|
const int sumi = ggml_cuda_dp4a(v, u, 0); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
const half2 * Q_ds = (const half2 *) Q_ds_v; |
|
|
|
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; |
|
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* the f32 branch multiplies by 8/QI8_1 == 1 */);
|
} else |
|
#endif |
|
{ |
|
const float2 * Q_ds = (const float2 *) Q_ds_v; |
|
|
|
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); |
|
} |
|
} |
|
|
|
return sum; |
|
} |
|
|
|
template<typename T, int D> |
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { |
|
|
|
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; |
|
GGML_UNUSED(Q_v); |
|
|
|
T sum = 0.0f; |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const int ib = k_KQ / QI8_1; |
|
const int iqs4 = k_KQ % QI4_1; |
|
const int shift = k_KQ & (QI8_1/2); |
|
|
|
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; |
|
const int u = Q_q8[k_KQ_0/WARP_SIZE]; |
|
|
|
const int sumi = ggml_cuda_dp4a(v, u, 0); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
const half2 * Q_ds = (const half2 *) Q_ds_v; |
|
|
|
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; |
|
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); |
|
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); |
|
} else |
|
#endif |
|
{ |
|
const float2 * Q_ds = (const float2 *) Q_ds_v; |
|
|
|
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; |
|
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; |
|
|
|
sum += (T) (sumid4d8 + m4s8scaled); |
|
} |
|
} |
|
|
|
return sum; |
|
} |
|
|
|
template<typename T, int D> |
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { |
|
|
|
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; |
|
GGML_UNUSED(Q_v); |
|
|
|
T sum = 0.0f; |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const int ib = k_KQ / QI8_1; |
|
const int iqs4 = k_KQ % QI5_0; |
|
const int iqs8 = k_KQ % QI8_1; |
|
const int shift = k_KQ & (QI8_1/2); |
|
|
|
int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; |
|
const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); |
|
v |= (vh << 4) & 0x00000010; |
|
v |= (vh << 11) & 0x00001000; |
|
v |= (vh << 18) & 0x00100000; |
|
v |= (vh << 25) & 0x10000000; |
|
|
|
const int u = Q_q8[k_KQ_0/WARP_SIZE]; |
|
|
|
const int sumi = ggml_cuda_dp4a(v, u, 0); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
const half2 * Q_ds = (const half2 *) Q_ds_v; |
|
|
|
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; |
|
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f) /* 16/QI8_1 == 2, cf. the f32 branch */);
|
} else |
|
#endif |
|
{ |
|
const float2 * Q_ds = (const float2 *) Q_ds_v; |
|
|
|
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); |
|
} |
|
} |
|
|
|
return sum; |
|
} |
|
|
|
template<typename T, int D> |
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { |
|
|
|
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; |
|
GGML_UNUSED(Q_v); |
|
|
|
T sum = 0.0f; |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const int ib = k_KQ / QI8_1; |
|
const int iqs4 = k_KQ % QI5_1; |
|
const int iqs8 = k_KQ % QI8_1; |
|
const int shift = k_KQ & (QI8_1/2); |
|
|
|
int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; |
|
const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); |
|
v |= (vh << 4) & 0x00000010; |
|
v |= (vh << 11) & 0x00001000; |
|
v |= (vh << 18) & 0x00100000; |
|
v |= (vh << 25) & 0x10000000; |
|
|
|
const int u = Q_q8[k_KQ_0/WARP_SIZE]; |
|
|
|
const int sumi = ggml_cuda_dp4a(v, u, 0); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
const half2 * Q_ds = (const half2 *) Q_ds_v; |
|
|
|
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; |
|
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); |
|
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); |
|
} else |
|
#endif |
|
{ |
|
const float2 * Q_ds = (const float2 *) Q_ds_v; |
|
|
|
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; |
|
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; |
|
|
|
sum += (T) (sumid5d8 + m5s8scaled); |
|
} |
|
} |
|
|
|
return sum; |
|
} |
|
|
|
template <typename T, int D> |
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { |
|
|
|
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; |
|
GGML_UNUSED(Q_v); |
|
|
|
T sum = 0.0f; |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const int ib = k_KQ / QI8_0; |
|
const int iqs = k_KQ % QI8_0; |
|
|
|
const int v = get_int_b2(K_q8_0[ib].qs, iqs); |
|
|
|
T Q_d; |
|
if (std::is_same<T, half>::value) { |
|
const half2 * Q_ds = (const half2 *) Q_ds_v; |
|
Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); |
|
} else { |
|
const float2 * Q_ds = (const float2 *) Q_ds_v; |
|
Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; |
|
} |
|
|
|
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); |
|
} |
|
|
|
return sum; |
|
} |
|
|
|
template <typename T, int D> |
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( |
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
|
const half2 * K_h2 = (const half2 *) K_c; |
|
GGML_UNUSED(Q_q8); |
|
GGML_UNUSED(Q_ds_v); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
const half2 * Q_h2 = (const half2 *) Q_v; |
|
|
|
half2 sum2 = make_half2(0.0f, 0.0f); |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const half2 K_ik = K_h2[k_KQ]; |
|
sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; |
|
} |
|
|
|
return __low2half(sum2) + __high2half(sum2); |
|
} |
|
#endif |
|
|
|
const float2 * Q_f2 = (const float2 *) Q_v; |
|
|
|
float sum = 0.0f; |
|
|
|
#pragma unroll |
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { |
|
const int k_KQ = k_KQ_0 + threadIdx.x; |
|
|
|
const half2 K_ik = K_h2[k_KQ]; |
|
sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; |
|
sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; |
|
} |
|
|
|
return sum; |
|
} |
|
|
|
template <typename Tds> |
|
static __device__ __forceinline__ void quantize_q8_1_to_shared( |
|
const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { |
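// Each thread quantizes 4 consecutive values of scale*x; every group of QI8_1 threads produces one q8_1 block.
// The packed int8 values are written to yq32, the block scale d and the sum of the scaled values to yds.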
|
|
|
float vals[sizeof(int)] = {0.0f}; |
|
#pragma unroll |
|
for (int l = 0; l < sizeof(int); ++l) { |
|
vals[l] = scale * x[4*threadIdx.x + l]; |
|
} |
|
|
|
float amax = fabsf(vals[0]); |
|
float sum = vals[0]; |
|
#pragma unroll |
|
for (int l = 1; l < sizeof(int); ++l) { |
|
amax = fmaxf(amax, fabsf(vals[l])); |
|
sum += vals[l]; |
|
} |
|
#pragma unroll |
|
for (int mask = QI8_1/2; mask > 0; mask >>= 1) { |
|
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); |
|
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); |
|
} |
|
|
|
const float d = amax / 127; |
|
int q32 = 0; |
|
int8_t * q8 = (int8_t *) &q32; |
|
|
|
if (d != 0.0f) { |
|
#pragma unroll |
|
for (int l = 0; l < sizeof(int); ++l) { |
|
q8[l] = roundf(vals[l] / d); |
|
} |
|
} |
|
|
|
yq32[threadIdx.x] = q32; |
|
if (threadIdx.x % QI8_1 == 0) { |
|
if (std::is_same<Tds, half2>::value) { |
|
((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); |
|
} else { |
|
((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum); |
|
} |
|
} |
|
} |
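
// dequantize_1_*: return a single element i of a quantized tensor as half or float; used to load individual V values.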
|
|
|
typedef half (*dequantize_1_f16_t)(const void *, const int64_t); |
|
typedef float (*dequantize_1_f32_t)(const void *, const int64_t); |
|
|
|
template <typename T> |
|
static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { |
|
const block_q4_0 * x = (const block_q4_0 *) vx; |
|
|
|
const int64_t ib = i / QK4_0; |
|
const int iqs = i % (QK4_0/2); |
|
const int shift = (i % QK4_0) / (QK4_0/2); |
|
|
|
const T d = x[ib].d; |
|
const int q0 = x[ib].qs[iqs]; |
|
const int q = ((q0 >> (4*shift)) & 0x0F) - 8; |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
return ((half) d)*((half) q); |
|
} |
|
#endif |
|
|
|
return ((float) d)*((float) q); |
|
} |
|
|
|
template <typename T> |
|
static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { |
|
const block_q4_1 * x = (const block_q4_1 *) vx; |
|
|
|
const int64_t ib = i / QK4_1; |
|
const int iqs = i % (QK4_1/2); |
|
const int shift = (i % QK4_1) / (QK4_1/2); |
|
|
|
const half2 dm = x[ib].dm; |
|
const int q0 = x[ib].qs[iqs]; |
|
const int q = ((q0 >> (4*shift)) & 0x0F); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
return __low2half(dm)*((half) q) + __high2half(dm); |
|
} |
|
#endif |
|
|
|
return __low2float(dm)*((float) q) + __high2float(dm); |
|
} |
|
|
|
template <typename T> |
|
static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { |
|
const block_q5_0 * x = (const block_q5_0 *) vx; |
|
|
|
const int64_t ib = i / QK5_0; |
|
const int idq = i % QK5_0; |
|
const int iqs = i % (QK5_0/2); |
|
const int shift = (i % QK5_0) / (QK5_0/2); |
|
|
|
const T d = x[ib].d; |
|
const int ql0 = x[ib].qs[iqs]; |
|
const int qh0 = get_int_b2(x[ib].qh, 0); |
|
const int ql = ((ql0 >> (4*shift)) & 0x0F); |
|
const int qh = ((qh0 >> idq) << 4) & 0x10; |
|
const int q = (ql | qh) - 16; |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
return ((half) d)*((half) q); |
|
} |
|
#endif |
|
|
|
return ((float) d)*((float) q); |
|
} |
|
|
|
template <typename T> |
|
static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { |
|
const block_q5_1 * x = (const block_q5_1 *) vx; |
|
|
|
const int64_t ib = i / QK5_1; |
|
const int idq = i % QK5_1; |
|
const int iqs = i % (QK5_1/2); |
|
const int shift = (i % QK5_1) / (QK5_1/2); |
|
|
|
const half2 dm = x[ib].dm; |
|
const int ql0 = x[ib].qs[iqs]; |
|
const int qh0 = get_int_b4(x[ib].qh, 0); |
|
const int ql = ((ql0 >> (4*shift)) & 0x0F); |
|
const int qh = ((qh0 >> idq) << 4) & 0x10; |
|
const int q = (ql | qh); |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
return __low2half(dm)*((half) q) + __high2half(dm); |
|
} |
|
#endif |
|
|
|
return __low2float(dm)*((float) q) + __high2float(dm); |
|
} |
|
|
|
template <typename T> |
|
static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { |
|
const block_q8_0 * x = (const block_q8_0 *) vx; |
|
|
|
const int64_t ib = i / QK8_0; |
|
const int iqs = i % QK8_0; |
|
|
|
const T d = x[ib].d; |
|
const int q = x[ib].qs[iqs]; |
|
|
|
#ifdef FP16_AVAILABLE |
|
if (std::is_same<T, half>::value) { |
|
return ((half) d)*((half) q); |
|
} |
|
#endif |
|
|
|
return ((float) d)*((float) q); |
|
} |
|
|
|
template <typename T> |
|
static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { |
|
const half * x = (const half *) vx; |
|
|
|
return x[i]; |
|
} |
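
// constexpr selection of the vec_dot / dequantize function matching the K or V cache type.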
|
|
|
template <int D> |
|
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { |
|
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> : |
|
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> : |
|
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> : |
|
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> : |
|
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> : |
|
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> : |
|
nullptr; |
|
} |
|
|
|
template <int D> |
|
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { |
|
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> : |
|
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> : |
|
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> : |
|
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> : |
|
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> : |
|
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> : |
|
nullptr; |
|
} |
|
|
|
constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { |
|
return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> : |
|
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> : |
|
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> : |
|
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> : |
|
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> : |
|
type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> : |
|
nullptr; |
|
} |
|
|
|
constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { |
|
return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> : |
|
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> : |
|
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> : |
|
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> : |
|
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> : |
|
type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> : |
|
nullptr; |
|
} |
|
|
|
|
|
#ifdef __clang__ |
|
#pragma clang diagnostic push |
|
#pragma clang diagnostic ignored "-Wpass-failed" // The #pragma unroll hints below cannot always be honored; silence the resulting warnings.
|
#endif |
|
|
|
template<int D, int ncols, int KQ_stride> |
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) |
|
__launch_bounds__(D, 1) |
|
#endif |
|
static __global__ void flash_attn_stream_k_fixup( |
|
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { |
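// With the stream-k decomposition a KQ tile can be split across multiple CUDA blocks.
// This kernel folds the partial results together: dst_fixup holds the per-block KQ maxima and softmax row sums,
// dst_fixup_data the corresponding partial VKQ accumulators.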
|
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); |
|
|
|
const int iter_k = ne11 / KQ_stride; |
|
const int iter_j = (ne01 + (ncols - 1)) / ncols; |
|
|
|
const int bidx0 = blockIdx.x; |
|
|
|
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x; |
|
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x; |
|
|
|
const bool did_not_have_any_data = kbc0 == kbc0_stop; |
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; |
|
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; |
|
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { |
|
return; |
|
} |
|
|
|
const int channel = kbc0 / (iter_k*iter_j); |
|
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; |
|
|
|
dst += jt*ncols*ne02*D + channel*D; |
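
// Load the partial result that needs the fixup: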
|
|
|
|
|
float dst_val[ncols] = {0.0f}; |
|
float max_val[ncols] = {0.0f}; |
|
float rowsum[ncols] = {0.0f}; |
|
#pragma unroll |
|
for (int j = 0; j < ncols; ++j) { |
|
if (jt*ncols + j >= ne01) { |
|
break; |
|
} |
|
dst_val[j] = dst[j*ne02*D + threadIdx.x]; |
|
|
|
const float2 tmp = dst_fixup[bidx0*ncols + j]; |
|
max_val[j] = tmp.x; |
|
rowsum[j] = tmp.y; |
|
} |
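
// Walk backwards over the previous CUDA blocks that contributed to the same tile and fold in their partial
// results, rescaling with the running softmax maximum.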
|
|
|
|
|
|
|
int bidx = bidx0 - 1; |
|
int kbc_stop = kbc0; |
|
while (true) {
|
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x; |
|
if (kbc == kbc_stop) { |
|
bidx--; |
|
kbc_stop = kbc; |
|
continue; |
|
} |
|
|
|
#pragma unroll |
|
for (int j = 0; j < ncols; ++j) { |
|
if (jt*ncols + j >= ne01) { |
|
break; |
|
} |
|
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x]; |
|
|
|
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j]; |
|
|
|
|
|
const float max_val_new = fmaxf(max_val[j], tmp.x); |
|
|
|
const float diff_val = max_val[j] - max_val_new; |
|
const float diff_add = tmp.x - max_val_new; |
|
|
|
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; |
|
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; |
|
|
|
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add; |
|
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y; |
|
|
|
max_val[j] = max_val_new; |
|
} |
|
|
|
|
|
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { |
|
break; |
|
} |
|
bidx--; |
|
kbc_stop = kbc; |
|
} |
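
// Write back the combined result, normalized by the accumulated row sum.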
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < ncols; ++j) { |
|
if (jt*ncols + j >= ne01) { |
|
return; |
|
} |
|
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j]; |
|
} |
|
} |
|
|
|
#ifdef __clang__ |
|
#pragma clang diagnostic pop |
|
#endif |
|
|
|
template<int D, int parallel_blocks> |
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) |
|
__launch_bounds__(D, 1) |
|
#endif |
|
static __global__ void flash_attn_combine_results( |
|
const float * __restrict__ VKQ_parts, |
|
const float2 * __restrict__ VKQ_meta, |
|
float * __restrict__ dst) { |
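// Each of the parallel_blocks blocks produced a partial VKQ result together with its local KQ maximum and
// softmax denominator (VKQ_meta); merge them with a numerically stable rescaling by the global maximum.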
|
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; |
|
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; |
|
dst += D * gridDim.y*blockIdx.x; |
|
|
|
const int tid = threadIdx.x; |
|
__builtin_assume(tid < D); |
|
|
|
__shared__ float2 meta[parallel_blocks]; |
|
if (tid < 2*parallel_blocks) { |
|
((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
|
} |
|
|
|
__syncthreads(); |
|
|
|
float kqmax = meta[0].x; |
|
#pragma unroll |
|
for (int l = 1; l < parallel_blocks; ++l) { |
|
kqmax = max(kqmax, meta[l].x); |
|
} |
|
|
|
float VKQ_numerator = 0.0f; |
|
float VKQ_denominator = 0.0f; |
|
#pragma unroll |
|
for (int l = 0; l < parallel_blocks; ++l) { |
|
const float diff = meta[l].x - kqmax; |
|
const float KQ_max_scale = expf(diff); |
|
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); |
|
*((uint32_t *) &KQ_max_scale) &= ftz_mask; |
|
|
|
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; |
|
VKQ_denominator += KQ_max_scale * meta[l].y; |
|
} |
|
|
|
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; |
|
} |
|
|
|
static void on_no_fattn_vec_case(const int D) { |
|
if (D == 64) { |
|
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); |
|
fprintf(stderr, "By default only f16 KV cache is supported.\n"); |
|
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); |
|
GGML_ABORT("fatal error"); |
|
} else if (D == 128) { |
|
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); |
|
fprintf(stderr, "Supported combinations:\n"); |
|
fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); |
|
fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); |
|
fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); |
|
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); |
|
GGML_ABORT("fatal error"); |
|
} else { |
|
fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); |
|
fprintf(stderr, "Only f16 is supported.\n"); |
|
GGML_ABORT("fatal error"); |
|
} |
|
} |
|
|
|
|
|
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride> |
|
void launch_fattn( |
|
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, |
|
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V |
|
) { |
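// Common launcher for the FlashAttention kernels: converts K/V to FP16 where required, picks the grid
// decomposition, launches fattn_kernel and, if needed, a combine or fixup kernel afterwards.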
|
const ggml_tensor * Q = dst->src[0]; |
|
const ggml_tensor * K = dst->src[1]; |
|
const ggml_tensor * V = dst->src[2]; |
|
|
|
const ggml_tensor * mask = dst->src[3]; |
|
|
|
ggml_tensor * KQV = dst; |
|
|
|
GGML_ASSERT(Q->type == GGML_TYPE_F32); |
|
GGML_ASSERT(KQV->type == GGML_TYPE_F32); |
|
|
|
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); |
|
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && |
|
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); |
|
|
|
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); |
|
|
|
GGML_ASSERT(Q->ne[3] == 1); |
|
|
|
ggml_cuda_pool & pool = ctx.pool(); |
|
cudaStream_t main_stream = ctx.stream(); |
|
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; |
|
|
|
ggml_cuda_pool_alloc<half> K_f16(pool); |
|
ggml_cuda_pool_alloc<half> V_f16(pool); |
|
ggml_cuda_pool_alloc<float> dst_tmp(pool); |
|
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); |
|
|
|
const char * K_data = (const char *) K->data; |
|
size_t nb11 = K->nb[1]; |
|
size_t nb12 = K->nb[2]; |
|
size_t nb13 = K->nb[3]; |
|
|
|
const char * V_data = (const char *) V->data; |
|
size_t nb21 = V->nb[1]; |
|
size_t nb22 = V->nb[2]; |
|
size_t nb23 = V->nb[3]; |
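
// If the kernel requires FP16 K/V but the cache is quantized, dequantize into temporary FP16 buffers
// and rescale the byte strides accordingly.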
|
|
|
if (need_f16_K && K->type != GGML_TYPE_F16) { |
|
K_f16.alloc(ggml_nelements(K)); |
|
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); |
|
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); |
|
K_data = (char *) K_f16.ptr; |
|
|
|
const size_t bs = ggml_blck_size(K->type); |
|
const size_t ts = ggml_type_size(K->type); |
|
|
|
nb11 = nb11*bs*sizeof(half)/ts; |
|
nb12 = nb12*bs*sizeof(half)/ts; |
|
nb13 = nb13*bs*sizeof(half)/ts; |
|
} |
|
|
|
if (need_f16_V && V->type != GGML_TYPE_F16) { |
|
V_f16.alloc(ggml_nelements(V)); |
|
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); |
|
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); |
|
V_data = (char *) V_f16.ptr; |
|
|
|
const size_t bs = ggml_blck_size(V->type); |
|
const size_t ts = ggml_type_size(V->type); |
|
|
|
nb21 = nb21*bs*sizeof(half)/ts; |
|
nb22 = nb22*bs*sizeof(half)/ts; |
|
nb23 = nb23*bs*sizeof(half)/ts; |
|
} |
|
|
|
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block); |
|
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3]; |
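
// parallel_blocks == 0 selects the stream-k decomposition: a fixed number of blocks iterates over the tiles
// and a fixup kernel combines fractional tiles afterwards. Otherwise each tile is split across parallel_blocks
// blocks whose results are combined at the end (when parallel_blocks > 1).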
|
|
|
const dim3 block_dim(WARP_SIZE, nwarps, 1); |
|
dim3 blocks_num; |
|
if (parallel_blocks == 0) { |
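// For short contexts it can be faster to let the SMs work on whole tiles because this skips the fixup.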
|
|
|
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm; |
|
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total; |
|
const bool short_context = K->ne[1] < 4096; |
|
|
|
const int nblocks_stream_k = 2*nsm; |
|
|
|
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k; |
|
blocks_num.y = 1; |
|
blocks_num.z = 1; |
|
|
|
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float)); |
|
} else { |
|
blocks_num.x = parallel_blocks*ntiles_x; |
|
blocks_num.y = Q->ne[2]; |
|
blocks_num.z = Q->ne[3]; |
|
|
|
if (parallel_blocks > 1) { |
|
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); |
|
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); |
|
} |
|
} |
|
|
|
|
|
float scale = 1.0f; |
|
float max_bias = 0.0f; |
|
float logit_softcap = 0.0f; |
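// scale, max_bias and logit_softcap are stored in the first three floats of op_params.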
|
|
|
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); |
|
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); |
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); |
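
// The kernels expect the division by logit_softcap to already be folded into scale when softcapping is enabled.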
|
|
|
if (logit_softcap != 0.0f) { |
|
scale /= logit_softcap; |
|
} |
|
|
|
const uint32_t n_head = Q->ne[2]; |
|
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); |
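
// Base values for the ALiBi slopes, derived from max_bias and the number of attention heads.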
|
|
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); |
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); |
|
|
|
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>( |
|
(const char *) Q->data, |
|
K_data, |
|
V_data, |
|
mask ? ((const char *) mask->data) : nullptr, |
|
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, |
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap, |
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], |
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3], |
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, |
|
Q->nb[1], Q->nb[2], Q->nb[3], |
|
nb11, nb12, nb13, |
|
nb21, nb22, nb23, |
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] |
|
); |
|
CUDA_CHECK(cudaGetLastError()); |
|
|
|
if constexpr (parallel_blocks == 0) { |
|
if (blocks_num.x % ntiles_total != 0) { |
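// At least one tile was split across CUDA blocks by the stream-k decomposition; run the fixup kernel to
// combine the partial results.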
|
const dim3 block_dim_combine(D, 1, 1); |
|
const dim3 blocks_num_combine = blocks_num; |
|
|
|
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride> |
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>> |
|
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); |
|
} |
|
} else if constexpr (parallel_blocks > 1) { |
|
const dim3 block_dim_combine(D, 1, 1); |
|
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); |
|
|
|
flash_attn_combine_results<D, parallel_blocks> |
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>> |
|
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); |
|
} |
|
CUDA_CHECK(cudaGetLastError()); |
|
} |
|
|