#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"
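
// Register the attention ops with the PyTorch dispatcher.
// TORCH_LIBRARY_EXPAND comes from registration.h and wraps TORCH_LIBRARY so
// that TORCH_EXTENSION_NAME is macro-expanded before the library name is
// fixed; once loaded, the ops are callable as torch.ops.<extension name>.<op>.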
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
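  // Batched, fixed-length multi-head attention forward pass. Optional
  // arguments follow the trailing-underscore naming used throughout this
  // schema; the returned Tensor[] contents are defined by mha_fwd
  // (declared in torch_binding.h).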
ops.def("fwd(" |
|
"Tensor! q, " |
|
"Tensor k, " |
|
"Tensor v, " |
|
"Tensor(out_!)? out_, " |
|
"Tensor? alibi_slopes_, " |
|
"float p_dropout, " |
|
"float softmax_scale, " |
|
"bool is_causal," |
|
"int window_size_left, " |
|
"int window_size_right, " |
|
"float softcap, " |
|
"bool return_softmax, " |
|
"Generator? gen_) -> Tensor[]"); |
|
ops.impl("fwd", torch::kCUDA, &mha_fwd); |
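
  // Variable-length forward pass: queries/keys/values are packed along the
  // token dimension and delimited by cu_seqlens_q / cu_seqlens_k, with
  // block_table_ available for paged KV layouts.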
ops.def("varlen_fwd(" |
|
"Tensor! q, " |
|
"Tensor k, " |
|
"Tensor v, " |
|
"Tensor? out_, " |
|
"Tensor cu_seqlens_q, " |
|
"Tensor cu_seqlens_k, " |
|
"Tensor? seqused_k_, " |
|
"Tensor? leftpad_k_, " |
|
"Tensor? block_table_, " |
|
"Tensor? alibi_slopes_, " |
|
"int max_seqlen_q, " |
|
"int max_seqlen_k, " |
|
"float p_dropout, " |
|
"float softmax_scale, " |
|
"bool zero_tensors, " |
|
"bool is_causal, " |
|
"int window_size_left, " |
|
"int window_size_right, " |
|
"float softcap, " |
|
"bool return_softmax, " |
|
"Generator? gen_) -> Tensor[]"); |
|
ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd); |
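
  // Backward pass for the fixed-length forward. Consumes dout along with the
  // saved out and softmax_lse, returning gradients for q, k, and v; optional
  // dq_/dk_/dv_ buffers can be supplied to receive the results (see mha_bwd
  // in torch_binding.h).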
ops.def("bwd(" |
|
"Tensor! dout, " |
|
"Tensor! q, " |
|
"Tensor! k, " |
|
"Tensor! v, " |
|
"Tensor! out, " |
|
"Tensor! " |
|
"softmax_lse, " |
|
"Tensor? dq_, " |
|
"Tensor? dk_, " |
|
"Tensor? dv_, " |
|
"Tensor? alibi_slopes_, " |
|
"float p_dropout, " |
|
"float softmax_scale, " |
|
"bool is_causal, " |
|
"int window_size_left, " |
|
"int window_size_right, " |
|
"float softcap, " |
|
"bool deterministic, " |
|
"Generator? gen_, " |
|
"Tensor? rng_state) -> Tensor[]"); |
|
ops.impl("bwd", torch::kCUDA, &mha_bwd); |
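
  // Variable-length counterpart of bwd, using the same packed layout and
  // cu_seqlens_q / cu_seqlens_k bookkeeping as varlen_fwd.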
ops.def("varlen_bwd(" |
|
"Tensor! dout, " |
|
"Tensor! q, " |
|
"Tensor! k, " |
|
"Tensor! v, " |
|
"Tensor! out, " |
|
"Tensor! softmax_lse, " |
|
"Tensor? dq_, " |
|
"Tensor? dk_, " |
|
"Tensor? dv_, " |
|
"Tensor cu_seqlens_q, " |
|
"Tensor cu_seqlens_k, " |
|
"Tensor? alibi_slopes_, " |
|
"int max_seqlen_q, " |
|
"int max_seqlen_k, " |
|
"float p_dropout, float softmax_scale, " |
|
"bool zero_tensors, " |
|
"bool is_causal, " |
|
"int window_size_left, " |
|
"int window_size_right, " |
|
"float softcap, " |
|
"bool deterministic, " |
|
"Generator? gen_, " |
|
"Tensor? rng_state) -> Tensor[]"); |
|
ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd); |
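
  // Decode-oriented forward pass that attends against an existing KV cache:
  // optional k_/v_ updates are written into kcache/vcache, rotary embeddings
  // can be applied via rotary_cos_/rotary_sin_, and block_table_ addresses a
  // paged cache; see mha_fwd_kvcache in torch_binding.h for the full contract.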
ops.def("fwd_kvcache(" |
|
"Tensor! q, " |
|
"Tensor! kcache, " |
|
"Tensor! vcache, " |
|
"Tensor? k_, " |
|
"Tensor? v_, " |
|
"Tensor? seqlens_k_, " |
|
"Tensor? rotary_cos_, " |
|
"Tensor? rotary_sin_, " |
|
"Tensor? cache_batch_idx_, " |
|
"Tensor? leftpad_k_, " |
|
"Tensor? block_table_, " |
|
"Tensor? alibi_slopes_, " |
|
"Tensor? out_, " |
|
"float softmax_scale, " |
|
"bool is_causal, " |
|
"int window_size_left, " |
|
"int window_size_right, " |
|
"float softcap, " |
|
"bool is_rotary_interleaved, " |
|
"int num_splits) -> Tensor[]"); |
|
ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache); |
}
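
// REGISTER_EXTENSION (also from registration.h) provides the Python module
// init entry point (PyInit_<name>) so the compiled library can be imported
// under TORCH_EXTENSION_NAME.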
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)