Add support for ROCm
Files changed:
- build.toml (+6, -0)
- flake.lock (+12, -12)
- flake.nix (+1, -1)
- torch-ext/torch_binding.cpp (+11, -0)
build.toml CHANGED

```diff
@@ -46,8 +46,12 @@ include = [ "." ]
 depends = [ "cutlass_3_6", "torch" ]
 
 [kernel.fp8_common]
+language = "cuda-hipify"
 cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 src = [
+  "fp8/amd/hip_float8.h",
+  "fp8/amd/hip_float8_impl.h",
   "fp8/common.cu",
   "fp8/common.cuh",
   "dispatch_utils.h",
@@ -66,7 +70,9 @@ src = [
 depends = [ "torch" ]
 
 [kernel.int8_common]
+language = "cuda-hipify"
 cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 src = [
   "compressed_tensors/int8_quant_kernels.cu",
   "dispatch_utils.h"
```
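The two added keys carry the port: `language = "cuda-hipify"` has kernel-builder translate the listed CUDA sources to HIP when targeting ROCm, and `rocm-archs` selects the AMD GPU targets, covering MI-series datacenter parts (gfx906 through gfx942) and RDNA2/RDNA3 consumer GPUs (gfx1030, gfx1100, gfx1101). A minimal sketch of what a hipify-friendly source looks like follows; the `fp8_t` alias and the kernel are illustrative, not code from this PR:

```cuda
// Illustrative only -- not code from this PR. With language = "cuda-hipify",
// sources are mechanically translated for ROCm (cudaMemcpy -> hipMemcpy and
// so on); anything the translator cannot map is guarded on USE_ROCM, which
// is where the new fp8/amd/ headers come in.
#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
  #include "fp8/amd/hip_float8.h"  // AMD fp8 shim added by this PR
  using fp8_t = hip_fp8;           // assumed name; defined in the shim header
#else
  #include <cuda_fp8.h>
  using fp8_t = __nv_fp8_e4m3;     // CUDA's fp8 e4m3 storage type
#endif

// The kernel body is written once against the CUDA API surface; hipify
// (plus the guards above) absorbs the platform differences.
__global__ void static_scale_to_fp8(float const* in, fp8_t* out, float scale,
                                    int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<fp8_t>(in[i] * scale);
}
```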
flake.lock CHANGED

```diff
@@ -41,17 +41,17 @@
         "rocm-nix": "rocm-nix"
       },
       "locked": {
-        "lastModified": …,
-        "narHash": "sha256-…",
-        "ref": "…",
-        "rev": "…",
-        "revCount": …,
-        "type": "git",
-        "url": "ssh://git@github.com/huggingface/kernel-builder"
+        "lastModified": 1743416390,
+        "narHash": "sha256-Krrrq9asF2d5SVWGJQIhQA8UxVcTpiCor8hQU4G5J38=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "e57cbde93f29032d32bbab8e32a1c86def6e9365",
+        "type": "github"
       },
       "original": {
-        "type": "git",
-        "url": "ssh://git@github.com/huggingface/kernel-builder"
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
       }
     },
     "nixpkgs": {
@@ -78,11 +78,11 @@
       ]
     },
     "locked": {
-      "lastModified": …,
-      "narHash": "sha256-…",
+      "lastModified": 1743085847,
+      "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
       "owner": "huggingface",
       "repo": "rocm-nix",
-      "rev": "…",
+      "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
       "type": "github"
     },
     "original": {
```
flake.nix CHANGED

```diff
@@ -2,7 +2,7 @@
   description = "Flake for quantization kernels";
 
   inputs = {
-    kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
 
   outputs =
```
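Switching the input from a `git+ssh` URL to the `github:` shorthand is what the flake.lock hunks above record: the locked kernel-builder entry moves from `type: git` over SSH to `type: github`, so the flake can be fetched anonymously through GitHub's tarball API instead of requiring SSH credentials.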
torch-ext/torch_binding.cpp CHANGED

```diff
@@ -4,6 +4,8 @@
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+#ifndef USE_ROCM
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
@@ -26,6 +28,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+#endif
+
   // Compute FP8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
@@ -60,6 +64,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+#ifndef USE_ROCM
+
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
@@ -103,8 +109,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
+#endif
 }
 
+#ifndef USE_ROCM
+
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
   ops.impl("awq_marlin_repack", &awq_marlin_repack);
   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
@@ -120,4 +129,6 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
   ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
 
+#endif
+
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
```
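The effect of the `#ifndef USE_ROCM` guards is that the CUTLASS and Marlin schemas are compiled out entirely on ROCm, so the extension only ever registers ops the platform can actually run. A stripped-down sketch of the pattern; `quant_ops` and `cuda_only_gemm` are hypothetical stand-ins, not this extension's real names:

```cpp
// Stripped-down sketch of the registration pattern this PR applies.
// "quant_ops" and "cuda_only_gemm" are hypothetical stand-ins.
#include <torch/library.h>

TORCH_LIBRARY(quant_ops, ops) {
#ifndef USE_ROCM
  // CUDA-only: CUTLASS/Marlin kernels have no HIP port, so their schemas
  // never exist in a ROCm build -- callers see a missing op, not a crash.
  ops.def("cuda_only_gemm(Tensor a, Tensor b) -> Tensor");
#endif
  // Portable op (schema taken from this PR): registered on both platforms.
  ops.def(
      "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
}
```

Guarding the `TORCH_LIBRARY_IMPL_EXPAND(..., CUDA, ...)` block the same way keeps the Marlin implementations out of ROCm builds too, so the schema side and the dispatch side stay consistent.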