#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Reference: the original pybind11 bindings. All of the functions listed here
// are registered below through the Torch library API instead.
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.doc() = "FlashAttention";
//     m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
//     m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
//     m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
//     m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
//     m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
// } 

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
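  // Forward pass (fixed-length, padded batches).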
  ops.def("fwd("
    "Tensor! q, "
    "Tensor k, "
    "Tensor v, "
    "Tensor(out_!)? out_, "
    "Tensor? alibi_slopes_, "
    "float p_dropout, "
    "float softmax_scale, "
    "bool is_causal,"
    "int window_size_left, "
    "int window_size_right, "
    "float softcap, "
    "bool return_softmax, "
    "Generator? gen_) -> Tensor[]");
  ops.impl("fwd", torch::kCUDA, &mha_fwd);

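  // Forward pass over variable-length (packed) sequences, indexed by the
  // cu_seqlens_q / cu_seqlens_k offset tensors.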
  ops.def("varlen_fwd("
    "Tensor! q, "
    "Tensor k, "
    "Tensor v, "
    "Tensor? out_, "
    "Tensor cu_seqlens_q, "
    "Tensor cu_seqlens_k, "
    "Tensor? seqused_k_, "
    "Tensor? leftpad_k_, "
    "Tensor? block_table_, "
    "Tensor? alibi_slopes_, "
    "int max_seqlen_q, "
    "int max_seqlen_k, "
    "float p_dropout, "
    "float softmax_scale, "
    "bool zero_tensors, "
    "bool is_causal, "
    "int window_size_left, "
    "int window_size_right, "
    "float softcap, "
    "bool return_softmax, "
    "Generator? gen_) -> Tensor[]");
  ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);

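  // Backward pass.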
  ops.def("bwd("
    "Tensor! dout, "
    "Tensor! q, "
    "Tensor! k, "
    "Tensor! v, "
    "Tensor! out, "
    "Tensor! "
    "softmax_lse, "
    "Tensor? dq_, "
    "Tensor? dk_, "
    "Tensor? dv_, "
    "Tensor? alibi_slopes_, "
    "float p_dropout, "
    "float softmax_scale, "
    "bool is_causal, "
    "int window_size_left, "
    "int window_size_right, "
    "float softcap, "
    "bool deterministic, "
    "Generator? gen_, "
    "Tensor? rng_state) -> Tensor[]");
  ops.impl("bwd", torch::kCUDA, &mha_bwd);

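  // Backward pass over variable-length (packed) sequences.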
  ops.def("varlen_bwd("
    "Tensor! dout, "
    "Tensor! q, "
    "Tensor! k, "
    "Tensor! v, "
    "Tensor! out, "
    "Tensor! softmax_lse, "
    "Tensor? dq_, "
    "Tensor? dk_, "
    "Tensor? dv_, "
    "Tensor cu_seqlens_q, "
    "Tensor cu_seqlens_k, "
    "Tensor? alibi_slopes_, "
    "int max_seqlen_q, "
    "int max_seqlen_k, "
    "float p_dropout, float softmax_scale, "
    "bool zero_tensors, "
    "bool is_causal, "
    "int window_size_left, "
    "int window_size_right, "
    "float softcap, "
    "bool deterministic, "
    "Generator? gen_, "
    "Tensor? rng_state) -> Tensor[]");
  ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd);

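  // Forward pass against a KV cache; supports paged caches via block_table_
  // and optional rotary embedding via rotary_cos_ / rotary_sin_.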
  ops.def("fwd_kvcache("
    "Tensor! q, "
    "Tensor! kcache, "
    "Tensor! vcache, "
    "Tensor? k_, "
    "Tensor? v_, "
    "Tensor? seqlens_k_, "
    "Tensor? rotary_cos_, "
    "Tensor? rotary_sin_, "
    "Tensor? cache_batch_idx_, "
    "Tensor? leftpad_k_, "
    "Tensor? block_table_, "
    "Tensor? alibi_slopes_, "
    "Tensor? out_, "
    "float softmax_scale, "
    "bool is_causal, "
    "int window_size_left, "
    "int window_size_right, "
    "float softcap, "
    "bool is_rotary_interleaved, "
    "int num_splits) -> Tensor[]");
  ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
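
// Usage sketch (illustrative only, not compiled): once the extension is built
// and imported, the ops registered above are dispatched through torch.ops
// under the extension's namespace. The namespace shown below ("flash_attn")
// and the number of returned tensors are assumptions; they depend on
// TORCH_EXTENSION_NAME and on the mha_fwd implementation. The argument order
// follows the "fwd" schema defined above.
//
//   import torch
//   out, softmax_lse, *rest = torch.ops.flash_attn.fwd(
//       q, k, v,              # query / key / value tensors
//       None,                 # out_: let the kernel allocate the output
//       None,                 # alibi_slopes_
//       0.0,                  # p_dropout
//       q.shape[-1] ** -0.5,  # softmax_scale
//       True,                 # is_causal
//       -1, -1,               # window_size_left / window_size_right (disabled)
//       0.0,                  # softcap (disabled)
//       False,                # return_softmax
//       None)                 # gen_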