Fix absolute imports
Browse files- flake.lock +7 -6
- flake.nix +1 -1
- torch-ext/quantization/utils/marlin_utils.py +3 -4
- torch-ext/quantization/utils/marlin_utils_fp4.py +6 -7
- torch-ext/quantization/utils/marlin_utils_fp8.py +3 -3
- torch-ext/quantization/utils/marlin_utils_test.py +1 -2
- torch-ext/quantization/utils/marlin_utils_test_24.py +1 -2
- torch-ext/quantization/utils/quant_utils.py +1 -1
flake.lock
CHANGED
@@ -73,11 +73,11 @@
|
|
73 |
"nixpkgs": "nixpkgs"
|
74 |
},
|
75 |
"locked": {
|
76 |
-
"lastModified":
|
77 |
-
"narHash": "sha256-
|
78 |
"owner": "huggingface",
|
79 |
"repo": "hf-nix",
|
80 |
-
"rev": "
|
81 |
"type": "github"
|
82 |
},
|
83 |
"original": {
|
@@ -98,15 +98,16 @@
|
|
98 |
]
|
99 |
},
|
100 |
"locked": {
|
101 |
-
"lastModified":
|
102 |
-
"narHash": "sha256-
|
103 |
"owner": "huggingface",
|
104 |
"repo": "kernel-builder",
|
105 |
-
"rev": "
|
106 |
"type": "github"
|
107 |
},
|
108 |
"original": {
|
109 |
"owner": "huggingface",
|
|
|
110 |
"repo": "kernel-builder",
|
111 |
"type": "github"
|
112 |
}
|
|
|
73 |
"nixpkgs": "nixpkgs"
|
74 |
},
|
75 |
"locked": {
|
76 |
+
"lastModified": 1751968576,
|
77 |
+
"narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
|
78 |
"owner": "huggingface",
|
79 |
"repo": "hf-nix",
|
80 |
+
"rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
|
81 |
"type": "github"
|
82 |
},
|
83 |
"original": {
|
|
|
98 |
]
|
99 |
},
|
100 |
"locked": {
|
101 |
+
"lastModified": 1751968677,
|
102 |
+
"narHash": "sha256-5gtVPN6uk+H3yq2gJRDjSTcaVSgGJZjbMALlO6TBcT8=",
|
103 |
"owner": "huggingface",
|
104 |
"repo": "kernel-builder",
|
105 |
+
"rev": "54eea2ce49889202e7018792f407046e36f89bc5",
|
106 |
"type": "github"
|
107 |
},
|
108 |
"original": {
|
109 |
"owner": "huggingface",
|
110 |
+
"ref": "get-kernel-check",
|
111 |
"repo": "kernel-builder",
|
112 |
"type": "github"
|
113 |
}
|
flake.nix
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
description = "Flake for quantization kernels";
|
3 |
|
4 |
inputs = {
|
5 |
-
kernel-builder.url = "github:huggingface/kernel-builder";
|
6 |
};
|
7 |
|
8 |
outputs =
|
|
|
2 |
description = "Flake for quantization kernels";
|
3 |
|
4 |
inputs = {
|
5 |
+
kernel-builder.url = "github:huggingface/kernel-builder/get-kernel-check";
|
6 |
};
|
7 |
|
8 |
outputs =
|
torch-ext/quantization/utils/marlin_utils.py
CHANGED
@@ -6,8 +6,7 @@ from typing import Optional
|
|
6 |
import numpy
|
7 |
import torch
|
8 |
|
9 |
-
import
|
10 |
-
from quantization.scalar_type import ScalarType, scalar_types
|
11 |
|
12 |
from .quant_utils import pack_cols, unpack_cols
|
13 |
|
@@ -383,7 +382,7 @@ def apply_gptq_marlin_linear(
|
|
383 |
device=input.device,
|
384 |
dtype=input.dtype)
|
385 |
|
386 |
-
output =
|
387 |
None,
|
388 |
weight,
|
389 |
weight_scale,
|
@@ -429,7 +428,7 @@ def apply_awq_marlin_linear(
|
|
429 |
device=input.device,
|
430 |
dtype=input.dtype)
|
431 |
|
432 |
-
output =
|
433 |
None,
|
434 |
weight,
|
435 |
weight_scale,
|
|
|
6 |
import numpy
|
7 |
import torch
|
8 |
|
9 |
+
from .. import ScalarType, gptq_marlin_gemm, scalar_types
|
|
|
10 |
|
11 |
from .quant_utils import pack_cols, unpack_cols
|
12 |
|
|
|
382 |
device=input.device,
|
383 |
dtype=input.dtype)
|
384 |
|
385 |
+
output = gptq_marlin_gemm(reshaped_x,
|
386 |
None,
|
387 |
weight,
|
388 |
weight_scale,
|
|
|
428 |
device=input.device,
|
429 |
dtype=input.dtype)
|
430 |
|
431 |
+
output = gptq_marlin_gemm(reshaped_x,
|
432 |
None,
|
433 |
weight,
|
434 |
weight_scale,
|
torch-ext/quantization/utils/marlin_utils_fp4.py
CHANGED
@@ -5,12 +5,11 @@ from typing import Optional
|
|
5 |
|
6 |
import torch
|
7 |
|
8 |
-
import
|
9 |
-
|
10 |
from .marlin_utils import (
|
11 |
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
12 |
should_use_atomic_add_reduce)
|
13 |
-
from
|
14 |
|
15 |
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
|
16 |
|
@@ -90,7 +89,7 @@ def apply_fp4_marlin_linear(
|
|
90 |
device=input.device,
|
91 |
dtype=input.dtype)
|
92 |
|
93 |
-
output =
|
94 |
c=None,
|
95 |
b_q_weight=weight,
|
96 |
b_scales=weight_scale,
|
@@ -135,7 +134,7 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
|
135 |
perm = torch.empty(0, dtype=torch.int, device=device)
|
136 |
qweight = layer.weight.view(torch.int32).T.contiguous()
|
137 |
|
138 |
-
marlin_qweight =
|
139 |
perm=perm,
|
140 |
size_k=part_size_k,
|
141 |
size_n=part_size_n,
|
@@ -192,7 +191,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
|
192 |
for i in range(e):
|
193 |
qweight = weight[i].view(torch.int32).T.contiguous()
|
194 |
|
195 |
-
marlin_qweight =
|
196 |
perm=perm,
|
197 |
size_k=size_k,
|
198 |
size_n=size_n,
|
@@ -263,7 +262,7 @@ def rand_marlin_weight_fp4_like(weight, group_size):
|
|
263 |
weight_ref = weight_ref * global_scale.to(weight.dtype) * \
|
264 |
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
265 |
|
266 |
-
marlin_qweight =
|
267 |
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
268 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
269 |
size_k=size_k,
|
|
|
5 |
|
6 |
import torch
|
7 |
|
8 |
+
from .. import gptq_marlin_gemm, gptq_marlin_repack
|
|
|
9 |
from .marlin_utils import (
|
10 |
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
11 |
should_use_atomic_add_reduce)
|
12 |
+
from ..scalar_type import scalar_types
|
13 |
|
14 |
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
|
15 |
|
|
|
89 |
device=input.device,
|
90 |
dtype=input.dtype)
|
91 |
|
92 |
+
output = gptq_marlin_gemm(a=reshaped_x,
|
93 |
c=None,
|
94 |
b_q_weight=weight,
|
95 |
b_scales=weight_scale,
|
|
|
134 |
perm = torch.empty(0, dtype=torch.int, device=device)
|
135 |
qweight = layer.weight.view(torch.int32).T.contiguous()
|
136 |
|
137 |
+
marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
|
138 |
perm=perm,
|
139 |
size_k=part_size_k,
|
140 |
size_n=part_size_n,
|
|
|
191 |
for i in range(e):
|
192 |
qweight = weight[i].view(torch.int32).T.contiguous()
|
193 |
|
194 |
+
marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
|
195 |
perm=perm,
|
196 |
size_k=size_k,
|
197 |
size_n=size_n,
|
|
|
262 |
weight_ref = weight_ref * global_scale.to(weight.dtype) * \
|
263 |
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
264 |
|
265 |
+
marlin_qweight = gptq_marlin_repack(
|
266 |
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
267 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
268 |
size_k=size_k,
|
torch-ext/quantization/utils/marlin_utils_fp8.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
|
|
5 |
|
6 |
import torch
|
7 |
|
8 |
-
import
|
9 |
|
10 |
from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
|
11 |
|
@@ -51,7 +51,7 @@ def apply_fp8_marlin_linear(
|
|
51 |
device=input.device,
|
52 |
dtype=input.dtype)
|
53 |
|
54 |
-
output =
|
55 |
c=None,
|
56 |
b_q_weight=weight,
|
57 |
b_scales=weight_scale,
|
@@ -104,7 +104,7 @@ def marlin_quant_fp8_torch(weight, group_size):
|
|
104 |
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
|
105 |
|
106 |
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
|
107 |
-
marlin_qweight =
|
108 |
b_q_weight=packed_weight,
|
109 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
110 |
size_k=size_k,
|
|
|
5 |
|
6 |
import torch
|
7 |
|
8 |
+
from .. import gptq_marlin_gemm, gptq_marlin_repack
|
9 |
|
10 |
from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
|
11 |
|
|
|
51 |
device=input.device,
|
52 |
dtype=input.dtype)
|
53 |
|
54 |
+
output = gptq_marlin_gemm(a=reshaped_x,
|
55 |
c=None,
|
56 |
b_q_weight=weight,
|
57 |
b_scales=weight_scale,
|
|
|
104 |
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
|
105 |
|
106 |
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
|
107 |
+
marlin_qweight = gptq_marlin_repack(
|
108 |
b_q_weight=packed_weight,
|
109 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
110 |
size_k=size_k,
|
torch-ext/quantization/utils/marlin_utils_test.py
CHANGED
@@ -5,8 +5,7 @@ from typing import List, Optional
|
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
8 |
-
from
|
9 |
-
|
10 |
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
11 |
from .quant_utils import (
|
12 |
get_pack_factor,
|
|
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
8 |
+
from ..scalar_type import ScalarType
|
|
|
9 |
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
10 |
from .quant_utils import (
|
11 |
get_pack_factor,
|
torch-ext/quantization/utils/marlin_utils_test_24.py
CHANGED
@@ -6,8 +6,7 @@ from typing import List
|
|
6 |
import numpy
|
7 |
import torch
|
8 |
|
9 |
-
from
|
10 |
-
|
11 |
from .marlin_utils_test import marlin_weights
|
12 |
from .quant_utils import gptq_quantize_weights
|
13 |
|
|
|
6 |
import numpy
|
7 |
import torch
|
8 |
|
9 |
+
from ..scalar_type import ScalarType
|
|
|
10 |
from .marlin_utils_test import marlin_weights
|
11 |
from .quant_utils import gptq_quantize_weights
|
12 |
|
torch-ext/quantization/utils/quant_utils.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional
|
|
5 |
import numpy
|
6 |
import torch
|
7 |
|
8 |
-
from
|
9 |
|
10 |
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
11 |
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
|
|
5 |
import numpy
|
6 |
import torch
|
7 |
|
8 |
+
from ..scalar_type import ScalarType, scalar_types
|
9 |
|
10 |
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
11 |
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|