#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
// Legacy pybind11 bindings, kept for reference; all of the functions listed
// below are registered through TORCH_LIBRARY instead:
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.doc() = "FlashAttention";
// m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
// m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
// m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
// m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
// m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
// }
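// The ops below are registered through the TORCH_LIBRARY schema language:
// "Tensor!" marks an argument that may be mutated in place, "Tensor?" marks
// an optional argument, and "-> Tensor[]" declares a list of returned tensors.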
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
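// Standard (fixed-length, batched) forward pass with optional dropout, causal
// masking, sliding-window attention, ALiBi slopes, and softcapping.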
ops.def("fwd("
"Tensor! q, "
"Tensor k, "
"Tensor v, "
"Tensor(out_!)? out_, "
"Tensor? alibi_slopes_, "
"float p_dropout, "
"float softmax_scale, "
"bool is_causal,"
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool return_softmax, "
"Generator? gen_) -> Tensor[]");
ops.impl("fwd", torch::kCUDA, &mha_fwd);
ops.def("varlen_fwd("
"Tensor! q, "
"Tensor k, "
"Tensor v, "
"Tensor? out_, "
"Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, "
"Tensor? seqused_k_, "
"Tensor? leftpad_k_, "
"Tensor? block_table_, "
"Tensor? alibi_slopes_, "
"int max_seqlen_q, "
"int max_seqlen_k, "
"float p_dropout, "
"float softmax_scale, "
"bool zero_tensors, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool return_softmax, "
"Generator? gen_) -> Tensor[]");
ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
ops.def("bwd("
"Tensor! dout, "
"Tensor! q, "
"Tensor! k, "
"Tensor! v, "
"Tensor! out, "
"Tensor! "
"softmax_lse, "
"Tensor? dq_, "
"Tensor? dk_, "
"Tensor? dv_, "
"Tensor? alibi_slopes_, "
"float p_dropout, "
"float softmax_scale, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool deterministic, "
"Generator? gen_, "
"Tensor? rng_state) -> Tensor[]");
ops.impl("bwd", torch::kCUDA, &mha_bwd);
ops.def("varlen_bwd("
"Tensor! dout, "
"Tensor! q, "
"Tensor! k, "
"Tensor! v, "
"Tensor! out, "
"Tensor! softmax_lse, "
"Tensor? dq_, "
"Tensor? dk_, "
"Tensor? dv_, "
"Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, "
"Tensor? alibi_slopes_, "
"int max_seqlen_q, "
"int max_seqlen_k, "
"float p_dropout, float softmax_scale, "
"bool zero_tensors, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool deterministic, "
"Generator? gen_, "
"Tensor? rng_state) -> Tensor[]");
ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd);
ops.def("fwd_kvcache("
"Tensor! q, "
"Tensor! kcache, "
"Tensor! vcache, "
"Tensor? k_, "
"Tensor? v_, "
"Tensor? seqlens_k_, "
"Tensor? rotary_cos_, "
"Tensor? rotary_sin_, "
"Tensor? cache_batch_idx_, "
"Tensor? leftpad_k_, "
"Tensor? block_table_, "
"Tensor? alibi_slopes_, "
"Tensor? out_, "
"float softmax_scale, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool is_rotary_interleaved, "
"int num_splits) -> Tensor[]");
ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
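// Minimal Python usage sketch (assumes this extension has been built and
// loaded so that its ops appear under torch.ops.<TORCH_EXTENSION_NAME>; the
// exact namespace is fixed at build time). Arguments follow the "fwd" schema:
//
//   results = torch.ops.<namespace>.fwd(
//       q, k, v,         # query / key / value
//       None,            # out_: optional preallocated output
//       None,            # alibi_slopes_
//       0.0,             # p_dropout
//       softmax_scale,
//       True,            # is_causal
//       -1, -1,          # window_size_left / window_size_right (disabled)
//       0.0,             # softcap (disabled)
//       False,           # return_softmax
//       None)            # gen_: optional RNG generator
//
// "results" is the Tensor[] declared above (the attention output plus
// auxiliary tensors such as the softmax log-sum-exp).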