danieldk (HF Staff) committed
Commit 3c8bb73 · Parent(s): 6d36a16

Add support for ROCm

Files changed (4):
  1. build.toml +6 -0
  2. flake.lock +12 -12
  3. flake.nix +1 -1
  4. torch-ext/torch_binding.cpp +11 -0
build.toml CHANGED
@@ -46,8 +46,12 @@ include = [ "." ]
 depends = [ "cutlass_3_6", "torch" ]
 
 [kernel.fp8_common]
+language = "cuda-hipify"
 cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 src = [
+  "fp8/amd/hip_float8.h",
+  "fp8/amd/hip_float8_impl.h",
   "fp8/common.cu",
   "fp8/common.cuh",
   "dispatch_utils.h",
@@ -66,7 +70,9 @@ src = [
 depends = [ "torch" ]
 
 [kernel.int8_common]
+language = "cuda-hipify"
 cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 src = [
   "compressed_tensors/int8_quant_kernels.cu",
   "dispatch_utils.h"
flake.lock CHANGED
@@ -41,17 +41,17 @@
         "rocm-nix": "rocm-nix"
       },
       "locked": {
-        "lastModified": 1742905006,
-        "narHash": "sha256-SCi1f5Lti4AM0kNPlAidcgN/5YM4HgJP4KwCsMrB0IE=",
-        "ref": "refs/heads/main",
-        "rev": "517a2bf2d0a3f1faf058ab995b6ca280b0999e7c",
-        "revCount": 105,
-        "type": "git",
-        "url": "ssh://git@github.com/huggingface/kernel-builder"
+        "lastModified": 1743416390,
+        "narHash": "sha256-Krrrq9asF2d5SVWGJQIhQA8UxVcTpiCor8hQU4G5J38=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "e57cbde93f29032d32bbab8e32a1c86def6e9365",
+        "type": "github"
       },
       "original": {
-        "type": "git",
-        "url": "ssh://git@github.com/huggingface/kernel-builder"
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
       }
     },
     "nixpkgs": {
@@ -78,11 +78,11 @@
         ]
       },
       "locked": {
-        "lastModified": 1742285724,
-        "narHash": "sha256-2QQn9fzmF/SKW082kXpSrEBgfmwKO2RNT5R91Fn/K4M=",
+        "lastModified": 1743085847,
+        "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
         "owner": "huggingface",
         "repo": "rocm-nix",
-        "rev": "a90de1c2e5698b2f4fe984b5f0faf052f466be49",
+        "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
         "type": "github"
       },
       "original": {
flake.nix CHANGED
@@ -2,7 +2,7 @@
   description = "Flake for quantization kernels";
 
   inputs = {
-    kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
 
   outputs =
torch-ext/torch_binding.cpp CHANGED
@@ -4,6 +4,8 @@
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+#ifndef USE_ROCM
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
@@ -26,6 +28,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+#endif
+
   // Compute FP8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
@@ -60,6 +64,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+#ifndef USE_ROCM
+
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
@@ -103,8 +109,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
+#endif
 }
 
+#ifndef USE_ROCM
+
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
   ops.impl("awq_marlin_repack", &awq_marlin_repack);
   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
@@ -120,4 +129,6 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
   ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
 
+#endif
+
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
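Note on the torch_binding.cpp hunks above: every guard follows one pattern. Ops backed by CUDA-only kernels (the CUTLASS scaled-mm family and the Marlin GEMM/repack ops) are wrapped in #ifndef USE_ROCM so they are compiled out of the ROCm build, while the portable FP8/INT8 quantization ops stay registered on both backends. A minimal sketch of that pattern with assumed names (hypothetical "demo_quant" namespace; not this extension's real registration code):

#include <torch/library.h>

TORCH_LIBRARY(demo_quant, ops) {
  // Portable op: the hipified kernel behind it exists on CUDA and ROCm alike.
  ops.def(
      "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
#ifndef USE_ROCM
  // CUDA-only op: its CUTLASS backend has no ROCm port here, so the ROCm
  // build never registers the schema.
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
#endif
}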