danieldk (HF staff) committed
Commit 2dd62c9 · Parent: 9d045c3

Import CUTLASS tests and add missing scaled mm with zp signature

ext-torch/__init__.py CHANGED
@@ -42,6 +42,33 @@ def cutlass_scaled_mm(a: torch.Tensor,
 
     return out
 
+def cutlass_scaled_mm_azp(a: torch.Tensor,
+                          b: torch.Tensor,
+                          scale_a: torch.Tensor,
+                          scale_b: torch.Tensor,
+                          out_dtype: torch.dtype,
+                          azp_adj: torch.Tensor,
+                          azp: Optional[torch.Tensor] = None,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    :param azp_adj: In the per-tensor case, this should include the azp.
+        Always per-channel.
+    :param azp: Only set in the per-token case. Per-token if set.
+    """
+    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+    assert bias is None or bias.numel(
+    ) == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]
+
+    m = a.shape[0]
+    n = b.shape[1]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
+                              azp, bias)
+    return out
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
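
For orientation, here is a minimal usage sketch of the new cutlass_scaled_mm_azp entry point (not part of the commit). It assumes the extension is importable as `quantization` and builds azp_adj from the column sums of b, mirroring the imported tests below; the shapes, scales, and zero points are purely illustrative.

# Hedged sketch: requires a CUDA device and the built `quantization` extension.
import torch
import quantization as ops

m, n, k = 32, 64, 128  # n and k divisible by 16, as the asserts require
a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda").t()

scale_a = torch.rand((m, 1), device="cuda") / 10  # per-token scales
scale_b = torch.rand((1, n), device="cuda") / 10  # per-channel scales

# Per-token zero points and the per-channel adjustment (column sums of b),
# following the convention used in tests/kernels/test_cutlass.py.
azp = torch.randint(0, 16, (m, 1), dtype=torch.int32, device="cuda")
azp_adj = b.to(torch.int32).sum(dim=0, keepdim=True)

out = ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16,
                                azp_adj, azp=azp)
# out has shape (m, n) and dtype torch.bfloat16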
tests/__init__.py ADDED
File without changes
tests/kernels/__init__.py ADDED
File without changes
tests/kernels/test_cutlass.py ADDED
@@ -0,0 +1,454 @@
+"""Tests for cutlass kernels
+
+Run `pytest tests/kernels/test_cutlass.py`.
+"""
+from typing import Optional, Type
+
+import pytest
+import torch
+
+from tests.kernels.utils import opcheck
+import quantization as ops
+
+MNK_FACTORS = [
+    (1, 256, 128),
+    (1, 16384, 1024),
+    (1, 24576, 496),
+    (16, 256, 496),
+    (16, 16384, 128),
+    (16, 24576, 4096),
+    (32, 8192, 4096),
+    (32, 16384, 4096),
+    (33, 1024, 1024),
+    (33, 8192, 128),
+    (64, 2048, 496),
+    (64, 16384, 1024),
+    (100, 8192, 496),
+    (128, 32768, 4096),
+    (256, 4096, 4096),
+    (512, 256, 1024),
+    (512, 8192, 4096),
+    (512, 16384, 128),
+    (512, 24576, 128),
+]
+
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+capability = torch.cuda.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+
+
+def to_fp8(tensor: torch.Tensor):
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    return torch.round(tensor.clamp(
+        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+
+
+def to_int8(tensor: torch.Tensor):
+    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+
+
+def rand_int8(shape: tuple, device: str = "cuda"):
+    return to_int8(torch.rand(shape, device=device) * 255 - 128)
+
+
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
+def cutlass_fp8_gemm_helper(m: int,
+                            n: int,
+                            k: int,
+                            per_token_act_quant: bool,
+                            per_out_channel_weight_quant: bool,
+                            use_bias: bool,
+                            out_dtype: Type[torch.dtype] = torch.bfloat16,
+                            device: str = "cuda"):
+    # Test for a cutlass kernel with per-token activation quantization
+    # and per-output channel weight quantization.
+    a = to_fp8(torch.randn((m, k), device=device))
+    b = to_fp8(torch.randn((n, k), device=device).t())
+
+    m_a_scales = m if per_token_act_quant else 1
+    n_b_scales = n if per_out_channel_weight_quant else 1
+
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
+    else:
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
+
+    opcheck(ops.ops.cutlass_scaled_mm,
+            (out, a, b, scale_a, scale_b, bias))
+
+
+def cutlass_int8_gemm_helper(m: int,
+                             n: int,
+                             k: int,
+                             per_token_act_quant: bool,
+                             per_out_channel_weight_quant: bool,
+                             use_bias: bool,
+                             out_dtype: Type[torch.dtype] = torch.bfloat16,
+                             device: str = "cuda"):
+    # Test for a cutlass kernel with per-token activation quantization
+    # and per-output channel weight quantization.
+    a = to_int8(torch.randn((m, k), device=device) * 5)
+    b = to_int8(torch.randn((n, k), device=device).t() * 5)
+
+    m_a_scales = m if per_token_act_quant else 1
+    n_b_scales = n if per_out_channel_weight_quant else 1
+
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
+    else:
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+
+    opcheck(ops.ops.cutlass_scaled_mm,
+            (out, a, b, scale_a, scale_b, bias))
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(capability < 89,
+                    reason="FP8 is not supported on this GPU type.")
+def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
+
+
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
+                                        out_dtype: Type[torch.dtype],
+                                        use_bias: bool):
+    cutlass_int8_gemm_helper(512,
+                             512,
+                             512,
+                             per_act_token,
+                             per_out_ch,
+                             use_bias,
+                             out_dtype=out_dtype)
+
+
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(capability < 89,
+                    reason="FP8 is not supported on this GPU type.")
+def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
+                                       out_dtype: Type[torch.dtype],
+                                       use_bias: bool):
+    cutlass_fp8_gemm_helper(512,
+                            512,
+                            512,
+                            per_act_token,
+                            per_out_ch,
+                            use_bias,
+                            out_dtype=out_dtype)
+
+
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(capability < 89,
+                    reason="FP8 is not supported on this GPU type.")
+def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
+                            torch.bfloat16, device)
+
+
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
+                                   use_bias: bool, device: str):
+    cutlass_int8_gemm_helper(512,
+                             512,
+                             512,
+                             per_act_token,
+                             per_out_ch,
+                             use_bias,
+                             out_dtype=torch.bfloat16,
+                             device=device)
+
+
+# For the following two tests:
+# N and K correspond to the size of the weight matrix and likely to be multiples
+# of a large power of two. In any case, the kernel will have a naive fallback
+# when N and K are not divisible by 16. But M is the number of tokens and the
+# kernel must handle any M thrown at it.
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(capability < 89,
+                    reason="FP8 is not supported on this GPU type.")
+def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
+                                  use_bias: bool):
+    for nk in range(32, 128, 32):
+        for m in range(1, 128):
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)
+
+
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
+                                   use_bias: bool):
+    for nk in range(32, 128, 32):
+        for m in range(1, 128):
+            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                     use_bias)
+
+
+@pytest.mark.parametrize("m", [32, 64, 128])
+@pytest.mark.parametrize("n", [16, 32, 64])
+@pytest.mark.parametrize("k", [64, 128, 256])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.skip
+def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
+                                    out_dtype: torch.dtype):
+    # Currently, the test is failing because folding azp into
+    # 16-bit bias loses too much precision
+    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
+    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
+
+    aq_i8 = rand_int8((m, k))
+    bq_i8 = rand_int8((n, k)).t()
+
+    aq_i32 = aq_i8.to(dtype=torch.int32)
+    bq_i32 = bq_i8.to(dtype=torch.int32)
+
+    aq_f32 = aq_i8.to(dtype=torch.float32)
+    bq_f32 = bq_i8.to(dtype=torch.float32)
+
+    b_dq = scale_b * bq_f32
+
+    azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
+    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
+    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding
+
+    a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
+    torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
+
+    baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
+
+    J = torch.ones((1, k), device="cuda", dtype=torch.float32)
+    azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
+    assert azp_bias.shape == (1, n)
+    assert azp_bias[0, :].shape == (n, )
+
+    baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
+        (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
+            dtype=out_dtype, device='cuda')
+
+    out = ops.cutlass_scaled_mm(aq_i8,
+                                bq_i8,
+                                scale_a,
+                                scale_b,
+                                out_dtype=out_dtype,
+                                bias=azp_bias[0, :])
+    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
+    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
+
+
+@pytest.mark.parametrize("m", [32, 64, 128])
+@pytest.mark.parametrize("n", [16, 32, 64])
+@pytest.mark.parametrize("k", [64, 128, 256])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("azp_per_token", [True, False])
+def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
+                          use_bias: bool, azp_per_token: bool):
+    m_azp = m if azp_per_token else 1
+    scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
+    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
+
+    aq_i8 = rand_int8((m, k))
+    aq_i32 = aq_i8.to(dtype=torch.int32)
+    aq_f32 = aq_i8.to(dtype=torch.float32)
+
+    bq_i8 = rand_int8((n, k)).t()
+    bq_i32 = bq_i8.to(dtype=torch.int32)
+    bq_f32 = bq_i8.to(dtype=torch.float32)
+    b_dq = scale_b * bq_f32
+
+    azp_a = torch.rand(
+        (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
+    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
+    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding
+
+    a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
+    torch.testing.assert_close(a_dq,
+                               scale_a * aq_f32 - azp_a,
+                               rtol=1e-4,
+                               atol=1e-3)
+
+    if use_bias:
+        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
+    else:
+        bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
+
+    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
+
+    # int32 mm not supported on CUDA
+    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
+    cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
+    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
+
+    # Hadamard is just the sum of the cols
+    azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
+    azp_i32 = azp_aq_i8.to(dtype=torch.int32)
+    func_bias = bias if use_bias else None
+
+    if azp_per_token:
+        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
+                                        out_dtype, azp_adj_i32, azp_i32,
+                                        func_bias)
+    else:
+        azp_with_adj_i32 = azp_i32 * azp_adj_i32
+        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
+                                        out_dtype, azp_with_adj_i32, None,
+                                        func_bias)
+
+    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
+    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
+    rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
+    atol = 1e-3
+    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
+    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
+
+    if azp_per_token:
+        opcheck(ops.ops.cutlass_scaled_mm_azp,
+                (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
+                 func_bias))
+    else:
+        opcheck(ops.ops.cutlass_scaled_mm_azp,
+                (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
+                 func_bias))
+
+
+# Test working with a subset of A and B
+def test_cutlass_subset():
+    big_m, big_n, big_k = 1024, 1024, 1024
+    m, n, k = 512, 512, 512
+
+    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
+    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
+    a = whole_a[0:m, 0:k]
+    b = whole_b[0:k, 0:n]
+
+    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
+    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
+
+    out = ops.cutlass_scaled_mm(a,
+                                b,
+                                scale_a,
+                                scale_b,
+                                out_dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
+
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+
+
+# Test to make sure cuda graphs work
+class CutlassLayer(torch.nn.Module):
+
+    def __init__(self, b, scale_a, scale_b, out_dtype):
+        super().__init__()
+        self.b = b
+        self.scale_a = scale_a
+        self.scale_b = scale_b
+        self.out_dtype = out_dtype
+
+    def forward(self, a):
+        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
+                                     self.out_dtype)
+
+
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
+    m, n, k = 512, 512, 512
+
+    a = to_int8(torch.randn((m, k), device="cuda"))
+    b = to_int8(torch.randn((n, k), device="cuda").t())
+
+    m_a_scales = m if per_act_token else 1
+    n_b_scales = n if per_out_ch else 1
+
+    scale_a = (torch.randn(
+        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
+    scale_b = (torch.randn(
+        (1, n_b_scales), device="cuda", dtype=torch.float32) / 10)
+
+    # Construct a trivial model with a single layer that calls a CUTLASS kernel
+    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
+
+    # Run the model with a cuda graph
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            out = model(a)
+    out.zero_()
+    g.replay()
+
+    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
+                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+
+
+def test_cutlass_support_opcheck():
+    opcheck(ops.ops.cutlass_scaled_mm_supports_fp8, (capability, ))
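
A brief note on the zero-point algebra that test_cutlass_int8_azp exercises (the derivation is standard; the notation is ours, not part of the commit): for quantized activations $A_q$ with per-token zero points $z$ and quantized weights $B$,

$$(A_q - z\,\mathbf{1}^{\top})\,B = A_q B - z\,(\mathbf{1}^{\top} B),$$

so the only extra term the kernel needs is $\mathbf{1}^{\top}B$, the vector of per-output-channel column sums of $B$. That is what the test computes as azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32), and it explains the "Hadamard is just the sum of the cols" comment. In the per-tensor case the scalar zero point is folded in up front (azp_with_adj_i32 = azp_i32 * azp_adj_i32); in the per-token case azp is passed separately so the correction can be applied row by row.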
tests/kernels/utils.py ADDED
@@ -0,0 +1,62 @@
+"""Kernel test utils"""
+
+import itertools
+import random
+import unittest
+from numbers import Number
+from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
+                    Union)
+
+import pytest
+import torch
+from torch._prims_common import TensorLikeType
+
+ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+    "test_aot_dispatch_dynamic",
+)
+
+# Copied/modified from torch._refs.__init__.py
+def fp8_allclose(
+    a: TensorLikeType,
+    b: TensorLikeType,
+    rtol: float = 1e-05,
+    atol: float = 1e-08,
+    equal_nan: bool = False,
+) -> bool:
+    """
+    Reference implementation of torch.allclose
+    """
+    torch._refs._check_close_args(name="torch.allclose",
+                                  a=a,
+                                  b=b,
+                                  rtol=rtol,
+                                  atol=atol)
+
+    return bool(
+        torch.all(
+            torch.isclose(a.double(),
+                          b.double(),
+                          rtol=rtol,
+                          atol=atol,
+                          equal_nan=equal_nan)).item())
+
+# A special version of op check that has a restricted default set of test_utils
+# and a patched version of allclose that supports fp8 types.
+def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
+                      torch._library.custom_ops.CustomOpDef],
+            args: Tuple[Any, ...],
+            kwargs: Optional[Dict[str, Any]] = None,
+            *,
+            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+            raise_exception: bool = True,
+            cond: bool = True) -> Dict[str, str]:
+    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
+        return torch.library.opcheck(
+            op,
+            args,
+            kwargs,
+            test_utils=test_utils,
+            raise_exception=raise_exception) if cond else {}
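
Finally, a minimal sketch of reusing this opcheck wrapper outside the imported tests (assumption: the quantization extension exposes its raw op handles under ops.ops, as test_cutlass_support_opcheck above does):

# Minimal sketch; runs the restricted opcheck suite with torch.allclose
# patched to the fp8-aware fp8_allclose defined above.
import torch
import quantization as ops
from tests.kernels.utils import opcheck

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]

opcheck(ops.ops.cutlass_scaled_mm_supports_fp8, (capability, ))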