Import CUTLASS tests and add missing scaled_mm with zero-point (azp) signature
- ext-torch/__init__.py +27 -0
- tests/__init__.py +0 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/test_cutlass.py +454 -0
- tests/kernels/utils.py +62 -0
ext-torch/__init__.py
CHANGED
@@ -42,6 +42,33 @@ def cutlass_scaled_mm(a: torch.Tensor,
 
     return out
 
+def cutlass_scaled_mm_azp(a: torch.Tensor,
+                          b: torch.Tensor,
+                          scale_a: torch.Tensor,
+                          scale_b: torch.Tensor,
+                          out_dtype: torch.dtype,
+                          azp_adj: torch.Tensor,
+                          azp: Optional[torch.Tensor] = None,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    :param azp_adj: In the per-tensor case, this should include the azp.
+        Always per-channel.
+    :param azp: Only set in the per-token case. Per-token if set.
+    """
+    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+    assert bias is None or bias.numel(
+    ) == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]
+
+    m = a.shape[0]
+    n = b.shape[1]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
+                              azp, bias)
+    return out
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
tests/__init__.py
ADDED
(empty file)
tests/kernels/__init__.py
ADDED
(empty file)
tests/kernels/test_cutlass.py
ADDED
@@ -0,0 +1,454 @@
"""Tests for cutlass kernels

Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Optional, Type

import pytest
import torch

from tests.kernels.utils import opcheck
import quantization as ops

MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 496),
    (16, 256, 496),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 496),
    (64, 16384, 1024),
    (100, 8192, 496),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            per_token_act_quant: bool,
                            per_out_channel_weight_quant: bool,
                            use_bias: bool,
                            out_dtype: Type[torch.dtype] = torch.bfloat16,
                            device: str = "cuda"):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))
    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)

    opcheck(ops.ops.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))


def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
                             use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

    opcheck(ops.ops.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                          per_out_ch: bool, use_bias: bool):
    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, use_bias: bool):
    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
                                        use_bias: bool):
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             use_bias,
                             out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                       out_dtype: Type[torch.dtype],
                                       use_bias: bool):
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_act_token,
                            per_out_ch,
                            use_bias,
                            out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool, device: str):
    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                            torch.bfloat16, device)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool, device: str):
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             use_bias,
                             out_dtype=torch.bfloat16,
                             device=device)


# For the following two tests:
# N and K correspond to the size of the weight matrix and likely to be multiples
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                    use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                     use_bias)


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                    out_dtype: torch.dtype):
    # Currently, the test is failing because folding azp into
    # 16-bit bias loses too much precision
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    bq_i8 = rand_int8((n, k)).t()

    aq_i32 = aq_i8.to(dtype=torch.int32)
    bq_i32 = bq_i8.to(dtype=torch.int32)

    aq_f32 = aq_i8.to(dtype=torch.float32)
    bq_f32 = bq_i8.to(dtype=torch.float32)

    b_dq = scale_b * bq_f32

    azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

    baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

    J = torch.ones((1, k), device="cuda", dtype=torch.float32)
    azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
    assert azp_bias.shape == (1, n)
    assert azp_bias[0, :].shape == (n, )

    baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
        (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
            dtype=out_dtype, device='cuda')

    out = ops.cutlass_scaled_mm(aq_i8,
                                bq_i8,
                                scale_a,
                                scale_b,
                                out_dtype=out_dtype,
                                bias=azp_bias[0, :])
    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                          use_bias: bool, azp_per_token: bool):
    m_azp = m if azp_per_token else 1
    scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    aq_i32 = aq_i8.to(dtype=torch.int32)
    aq_f32 = aq_i8.to(dtype=torch.float32)

    bq_i8 = rand_int8((n, k)).t()
    bq_i32 = bq_i8.to(dtype=torch.int32)
    bq_f32 = bq_i8.to(dtype=torch.float32)
    b_dq = scale_b * bq_f32

    azp_a = torch.rand(
        (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq,
                               scale_a * aq_f32 - azp_a,
                               rtol=1e-4,
                               atol=1e-3)

    if use_bias:
        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
    else:
        bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)

    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA
    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
    cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

    # Hadamard is just the sum of the cols
    azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
    azp_i32 = azp_aq_i8.to(dtype=torch.int32)
    func_bias = bias if use_bias else None

    if azp_per_token:
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_adj_i32, azp_i32,
                                        func_bias)
    else:
        azp_with_adj_i32 = azp_i32 * azp_adj_i32
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_with_adj_i32, None,
                                        func_bias)

    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
    rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
    atol = 1e-3
    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)

    if azp_per_token:
        opcheck(ops.ops.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
                 func_bias))
    else:
        opcheck(ops.ops.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
                 func_bias))


# Test working with a subset of A and B
def test_cutlass_subset():
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
    a = whole_a[0:m, 0:k]
    b = whole_b[0:k, 0:n]

    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a,
                                b,
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):

    def __init__(self, b, scale_a, scale_b, out_dtype):
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a):
        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                     self.out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    m, n, k = 512, 512, 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device="cuda", dtype=torch.float32) / 10)

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = model(a)
        out.zero_()
        g.replay()

    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


def test_cutlass_support_opcheck():
    opcheck(ops.ops.cutlass_scaled_mm_supports_fp8, (capability, ))
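The two azp baselines above (a dequantized float reference and an exact integer reference) rely on the same zero-point identity, which can be checked in isolation with exact integer arithmetic. A standalone sketch, pure PyTorch on CPU with hypothetical sizes and no custom kernels involved:

import torch

# (A_q - azp) @ B_q == A_q @ B_q - azp * colsum(B_q), exact in integers.
m, n, k = 4, 8, 16
aq = torch.randint(-128, 128, (m, k), dtype=torch.int64)
bq = torch.randint(-128, 128, (k, n), dtype=torch.int64)
azp = torch.randint(1, 16, (m, 1), dtype=torch.int64)  # per-token zero points

lhs = (aq - azp) @ bq
rhs = aq @ bq - azp * bq.sum(dim=0, keepdim=True)
assert torch.equal(lhs, rhs)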
tests/kernels/utils.py
ADDED
@@ -0,0 +1,62 @@
"""Kernel test utils"""

import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Union)

import pytest
import torch
from torch._prims_common import TensorLikeType

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    return bool(
        torch.all(
            torch.isclose(a.double(),
                          b.double(),
                          rtol=rtol,
                          atol=atol,
                          equal_nan=equal_nan)).item())


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(
            op,
            args,
            kwargs,
            test_utils=test_utils,
            raise_exception=raise_exception) if cond else {}