// Spaces: Sleeping, Sleeping  (web-page status banner captured with this file; not part of the source)
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// Compile-time index constants with the values 0, 1 and 2.
// NOTE(review): no use of these constants is visible in this file; the names
// suggest they select the "total", "head" and "head-dim" axes of a tensor —
// confirm against the callers.
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
//////////////////////////////////////////////////////////////////////////////////////////////////// | |
// Plain data holder describing the Q, K and V inputs: one untyped pointer per
// matrix, the batch/row/head strides of each, and the head counts.
// The struct defines no methods; whoever fills it in is responsible for every field.
struct Qkv_params {
    // Integer type used for every stride in this struct (64-bit).
    using index_t = int64_t;

    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The strides of the Q, K and V matrices between batches, rows and heads.
    // NOTE(review): the unit (elements vs. bytes) is not visible here — confirm with the kernels.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads: h for the queries, h_k for the keys/values.
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k (h_k)
    // could be different from nheads (h).
    int h;
    int h_k;

    // Precomputed h / h_k. This struct does not compute it; the code that fills
    // in h and h_k must set it as well.
    int h_h_k_ratio;
};
//////////////////////////////////////////////////////////////////////////////////////////////////// | |
struct Flash_fwd_params : public Qkv_params { | |
// The O matrix (output). | |
void * __restrict__ o_ptr; | |
void * __restrict__ oaccum_ptr; | |
// The stride between rows of O. | |
index_t o_batch_stride; | |
index_t o_row_stride; | |
index_t o_head_stride; | |
// The pointer to the P matrix. | |
void * __restrict__ p_ptr; | |
// The pointer to the softmax sum. | |
void * __restrict__ softmax_lse_ptr; | |
void * __restrict__ softmax_lseaccum_ptr; | |
// The dimensions. | |
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; | |
// The scaling factors for the kernel. | |
float scale_softmax; | |
float scale_softmax_log2; | |
// array of length b+1 holding starting offset of each sequence. | |
int * __restrict__ cu_seqlens_q; | |
int * __restrict__ cu_seqlens_k; | |
// If provided, the actual length of each k sequence. | |
int * __restrict__ seqused_k; | |
int *__restrict__ blockmask; | |
// The K_new and V_new matrices. | |
void * __restrict__ knew_ptr; | |
void * __restrict__ vnew_ptr; | |
// The stride between rows of the Q, K and V matrices. | |
index_t knew_batch_stride; | |
index_t vnew_batch_stride; | |
index_t knew_row_stride; | |
index_t vnew_row_stride; | |
index_t knew_head_stride; | |
index_t vnew_head_stride; | |
// The cos and sin matrices for rotary embedding. | |
void * __restrict__ rotary_cos_ptr; | |
void * __restrict__ rotary_sin_ptr; | |
// The indices to index into the KV cache. | |
int * __restrict__ cache_batch_idx; | |
// Paged KV cache | |
int * __restrict__ block_table; | |
index_t block_table_batch_stride; | |
int page_block_size; | |
// The dropout probability (probability of keeping an activation). | |
float p_dropout; | |
// uint32_t p_dropout_in_uint; | |
// uint16_t p_dropout_in_uint16_t; | |
uint8_t p_dropout_in_uint8_t; | |
// Scale factor of 1 / (1 - p_dropout). | |
float rp_dropout; | |
float scale_softmax_rp_dropout; | |
// Local window size | |
int window_size_left, window_size_right; | |
// Random state. | |
at::PhiloxCudaState philox_args; | |
// Pointer to the RNG seed (idx 0) and offset (idx 1). | |
uint64_t * rng_state; | |
bool is_bf16; | |
bool is_causal; | |
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. | |
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. | |
bool is_seqlens_k_cumulative; | |
bool is_rotary_interleaved; | |
int num_splits; // For split-KV version | |
void * __restrict__ alibi_slopes_ptr; | |
index_t alibi_slopes_batch_stride; | |
}; | |
//////////////////////////////////////////////////////////////////////////////////////////////////// | |
struct Flash_bwd_params : public Flash_fwd_params { | |
// The dO and dQKV matrices. | |
void *__restrict__ do_ptr; | |
void *__restrict__ dq_ptr; | |
void *__restrict__ dk_ptr; | |
void *__restrict__ dv_ptr; | |
// To accumulate dQ | |
void *__restrict__ dq_accum_ptr; | |
void *__restrict__ dk_accum_ptr; | |
void *__restrict__ dv_accum_ptr; | |
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q | |
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__ | |
// dv_accum_ptr; | |
// The stride between rows of the dO, dQ, dK and dV matrices. | |
// TD [2022-04-16]: We're using 32-bit indexing to save registers. | |
// The code probably won't work for arrays larger than 2GB. | |
index_t do_batch_stride; | |
index_t do_row_stride; | |
index_t do_head_stride; | |
index_t dq_batch_stride; | |
index_t dk_batch_stride; | |
index_t dv_batch_stride; | |
index_t dq_row_stride; | |
index_t dk_row_stride; | |
index_t dv_row_stride; | |
index_t dq_head_stride; | |
index_t dk_head_stride; | |
index_t dv_head_stride; | |
// The pointer to the softmax d sum. | |
void *__restrict__ dsoftmax_sum; | |
bool deterministic; | |
index_t dq_accum_split_stride; | |
}; | |
//////////////////////////////////////////////////////////////////////////////////////////////////// | |
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); | |
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); | |
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); | |