Spaces:
Sleeping
Sleeping
// Shorthand for PyTorch's half-precision scalar type.
typedef at::Half fp16;

// cast() below reinterprets at::Half storage as __half; that is only meaningful if both types
// occupy the same 16 bits.
static_assert(sizeof(fp16) == sizeof(__half), "at::Half and __half must have the same size");

// Reinterpret a PyTorch at::Half pointer as a CUDA __half pointer so it can be handed to kernels
// written against cuda_fp16.h. No value conversion happens; the bits are reused as-is.
// NOTE(review): relies on at::Half and __half sharing the same bit layout (IEEE binary16) — only
// the size is checked above.
__half *cast(fp16 *ptr) {
    return reinterpret_cast<__half *>(ptr);
}
// Forward pass of a WKV-style recurrence (NOTE(review): appears to be the RWKV-4 "wkv" operator,
// inferred from the name). One thread handles one (batch b, channel c) pair and walks all T time
// steps sequentially, so the launch must provide at least B*C threads on a 1D grid; extra threads exit.
//
// Layout (fixed by the indexing below, assuming contiguous tensors):
//   k, v, y    : (B, T, C); element (b, t, c) lives at b*T*C + t*C + c
//   w, u       : (C,); per-channel parameters
//   aa, bb, pp : (B, C); recurrent state, read on entry and overwritten on exit
//
// Numerical stability uses the max-subtraction trick: pp holds the running maximum exponent and
// aa / bb are the numerator / denominator accumulators scaled by exp(-pp). Every exp() argument
// below is therefore <= 0. All math is done in float; k and v are converted from F on load, and
// y is converted back to F on store.
template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
                                   const float *__restrict__ const _w, const float *__restrict__ const _u,
                                   const F *__restrict__ const _k, const F *__restrict__ const _v,
                                   F *__restrict__ const _y, float *__restrict__ const _aa,
                                   float *__restrict__ const _bb, float *__restrict__ const _pp) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Tail guard: threads past the last (b, c) pair must not touch memory. This keeps the kernel
    // correct for any grid covering at least B*C threads, and skips idx / C entirely when C == 0.
    if (idx >= B * C) {
        return;
    }
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;
    const int _state_offset = _b * C + _c;

    float u = _u[_c];
    float w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];

    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // distance between consecutive time steps of the same channel
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        // Output for step i: combine the stored state with the current token, whose key gets the
        // per-channel bonus u added in the exponent.
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));

        // State update: the old state is scaled by exp(w) (w is added to the exponent pp;
        // presumably w <= 0 so this is a decay — confirm with caller) and the current token is
        // folded in with weight exp(kk).
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }

    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}
// Host launcher for kernel_wkv_forward: one thread per (batch, channel) pair, B*C threads in total.
//
// The block size is at most 32 and is chosen to divide B*C exactly. Every launched thread then
// maps to a valid (b, c) pair and no pair is left unprocessed, even when B*C is not a multiple
// of 32. An assert alone would not enforce this in NDEBUG builds. When 32 (or C, if C < 32)
// already divides B*C, the launch shape is exactly min(C, 32) threads per block.
//
// No stream argument is passed, so the kernel runs on the default stream. Launch errors are not
// checked here. NOTE(review): a PyTorch extension usually launches on the current CUDA stream —
// confirm against the caller.
template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    const int total = B * C;
    if (total <= 0) {
        // Nothing to compute. A zero-sized grid is an invalid launch configuration, and C == 0
        // would make the block-size search below divide by zero.
        return;
    }
    int threads = min(C, 32);
    while (total % threads != 0) {
        --threads;  // terminates at 1 at the latest, which divides everything
    }
    dim3 threadsPerBlock(threads);
    dim3 numBlocks(total / threads);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}
// Explicit instantiations of the WKV launcher for half-precision (at::Half) and single-precision
// k / v / y, so both versions are emitted in this translation unit. NOTE(review): presumably
// called from a separate host/binding translation unit — confirm.
template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);
// Batched int8-weight matmul, float activations:
//   y[row, col] = sum_j x[row, j] * ((w[j, col] + 0.5) * rx[col] * ry[j] + mx[col] + my[j])
// x is indexed as (B, N) with row stride x_stride, w as (N, M) with row stride w_stride, and
// y as (B, M) with row stride y_stride.
// Launch: 2D grid; the x dimension selects row in [0, B) and the y dimension selects col in
// [0, M). Threads outside that range return without writing.
__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= B || col >= M) {
        return;
    }

    const float *const x_row = x + row * x_stride;
    // The per-column dequantisation parameters are the same for every step of the reduction.
    const float col_scale = rx[col];
    const float col_shift = mx[col];

    float acc = 0.0f;
    for (int r = 0; r < N; ++r) {
        acc += x_row[r] * (
            (float(w[r * w_stride + col]) + 0.5f)
            * col_scale * ry[r] + col_shift + my[r]
        );
    }
    y[row * y_stride + col] = acc;
}
// Host launcher for the batched int8-weight matmul (B input rows x N -> B output rows x M).
// Declaration only: the primary template has no body, and each supported element type F
// provides an explicit specialization.
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx, F *my, F *ry,
                  F *y, int y_stride);
// float specialization of the batched int8-weight matmul launcher.
// Launch shape: blocks of 1 x 128 threads; the grid covers B rows along x and ceil(M / 128)
// column tiles along y. Runs on the default stream (no stream argument) without an error check.
template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    // An empty output needs no work, and B == 0 or M == 0 would produce a zero grid dimension.
    // That is an invalid launch configuration and would leave a pending CUDA error for an
    // unrelated later call. N == 0 is still launched so that y is written with zeros.
    if (B <= 0 || M <= 0) {
        return;
    }
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}
// Batched int8-weight matmul, __half activations:
//   y[row, col] = sum_j x[row, j] * ((w[j, col] + 0.5) * rx[col] * ry[j] + mx[col] + my[j])
// x is indexed as (B, N) with row stride x_stride, w as (N, M) with row stride w_stride, and
// y as (B, M) with row stride y_stride. Inputs are widened to float and the running sum is kept
// in float; the result is rounded to __half once, on store.
// Launch: 2D grid; the x dimension selects row in [0, B) and the y dimension selects col in
// [0, M). Threads outside that range return without writing.
__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= B || col >= M) {
        return;
    }

    const __half *const x_row = x + row * x_stride;
    // The per-column dequantisation parameters are the same for every step of the reduction.
    const float col_scale = __half2float(rx[col]);
    const float col_shift = __half2float(mx[col]);

    float acc = 0.0f;
    for (int r = 0; r < N; ++r) {
        acc += __half2float(x_row[r]) * (
            (float(w[r * w_stride + col]) + 0.5f)
            * col_scale * __half2float(ry[r])
            + col_shift + __half2float(my[r])
        );
    }
    y[row * y_stride + col] = __float2half(acc);
}
// fp16 (at::Half) specialization of the batched int8-weight matmul launcher. at::Half pointers
// are reinterpreted as __half via cast() before being passed to the kernel.
// Launch shape: blocks of 1 x 128 threads; the grid covers B rows along x and ceil(M / 128)
// column tiles along y. Runs on the default stream (no stream argument) without an error check.
template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    // An empty output needs no work, and B == 0 or M == 0 would produce a zero grid dimension.
    // That is an invalid launch configuration and would leave a pending CUDA error for an
    // unrelated later call. N == 0 is still launched so that y is written with zeros.
    if (B <= 0 || M <= 0) {
        return;
    }
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
        B, N, M, cast(x), x_stride, w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}
// Single-vector int8-weight matmul with a split reduction, float activations:
//   y[col] += sum_{j in slice} x[j] * ((w[j, col] + 0.5) * rx[col] * ry[j] + mx[col] + my[j])
// w is indexed as (N, M) with row stride w_stride.
// Launch: the grid x dimension picks one of the reduction slices, each
// ceil(N / MM8_ONE_JSPLIT) rows long (clamped to N); the y dimension walks output columns.
// Each thread adds its partial sum to y[col] with atomicAdd. y must therefore already hold the
// value to accumulate onto (typically zero), and the order in which slices are summed is not
// fixed.
__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {
    const int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (col >= M) {
        return;
    }

    const auto slice_len = (N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT;
    const int row_begin = min(N, blockIdx.x * slice_len);
    const int row_end = min(N, (blockIdx.x + 1) * slice_len);
    // The per-column dequantisation parameters are the same for every row in the slice.
    const float col_scale = rx[col];
    const float col_shift = mx[col];

    float partial = 0.0f;
    for (int r = row_begin; r < row_end; ++r) {
        partial += x[r] * (
            (float(w[r * w_stride + col]) + 0.5f)
            * col_scale * ry[r] + col_shift + my[r]
        );
    }
    atomicAdd(&y[col], partial);
}
// Host launcher for the single-vector int8-weight matmul (N inputs -> M outputs). The output
// is always float (y), independent of the activation type F.
// Declaration only: the primary template has no body, and each supported element type F
// provides an explicit specialization.
template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx, F *my, F *ry,
                  float *y);
// float specialization of the single-vector int8-weight matmul launcher.
// Launch shape: MM8_ONE_JSPLIT blocks along x (one per reduction slice) by
// ceil(M / MM8_ONE_TILE) blocks along y, each block 1 x MM8_ONE_TILE threads. The kernel adds
// into y with atomicAdd, so the caller must initialise y (typically to zero) beforehand.
// Runs on the default stream (no stream argument) without an error check.
template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    // With M == 0, the y grid dimension would be 0, which is an invalid launch configuration and
    // leaves a pending CUDA error. With N == 0, the kernel would only add zeros to y.
    if (N <= 0 || M <= 0) {
        return;
    }
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}
// Single-vector int8-weight matmul with a split reduction, __half activations, float output:
//   y[col] += sum_{j in slice} x[j] * ((w[j, col] + 0.5) * rx[col] * ry[j] + mx[col] + my[j])
// w is indexed as (N, M) with row stride w_stride. Inputs are widened to float, and partial
// sums stay in float.
// Launch: the grid x dimension picks one of the reduction slices, each
// ceil(N / MM8_ONE_JSPLIT) rows long (clamped to N); the y dimension walks output columns.
// Each thread adds its partial sum to y[col] with atomicAdd. y must therefore already hold the
// value to accumulate onto (typically zero), and the order in which slices are summed is not
// fixed.
__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {
    const int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (col >= M) {
        return;
    }

    const auto slice_len = (N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT;
    const int row_begin = min(N, blockIdx.x * slice_len);
    const int row_end = min(N, (blockIdx.x + 1) * slice_len);
    // The per-column dequantisation parameters are the same for every row in the slice.
    const float col_scale = __half2float(rx[col]);
    const float col_shift = __half2float(mx[col]);

    float partial = 0.0f;
    for (int r = row_begin; r < row_end; ++r) {
        partial += __half2float(x[r]) * (
            (float(w[r * w_stride + col]) + 0.5f)
            * col_scale * __half2float(ry[r])
            + col_shift + __half2float(my[r])
        );
    }
    atomicAdd(&y[col], partial);
}
// fp16 (at::Half) specialization of the single-vector int8-weight matmul launcher. at::Half
// pointers are reinterpreted as __half via cast(); the output y stays float.
// Launch shape: MM8_ONE_JSPLIT blocks along x (one per reduction slice) by
// ceil(M / MM8_ONE_TILE) blocks along y, each block 1 x MM8_ONE_TILE threads. The kernel adds
// into y with atomicAdd, so the caller must initialise y (typically to zero) beforehand.
// Runs on the default stream (no stream argument) without an error check.
template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    // With M == 0, the y grid dimension would be 0, which is an invalid launch configuration and
    // leaves a pending CUDA error. With N == 0, the kernel would only add zeros to y.
    if (N <= 0 || M <= 0) {
        return;
    }
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
        N, M, cast(x), w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), y);
}