Commit c516610 (parent 229f047) by danieldk

Fix absolute imports
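This commit replaces absolute imports of the quantization package inside torch-ext/quantization/utils/ with package-relative imports, presumably so the modules keep resolving when the package is installed or loaded under a different top-level name. In outline (an excerpt of the change applied in the files below, with arguments elided):

# before: the utils modules hard-code the installed package name
import quantization as ops
from quantization.scalar_type import ScalarType, scalar_types
output = ops.gptq_marlin_gemm(...)

# after: relative imports resolve against the enclosing package, whatever it is named
from .. import ScalarType, gptq_marlin_gemm, scalar_types
output = gptq_marlin_gemm(...)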
flake.lock CHANGED
@@ -73,11 +73,11 @@
       "nixpkgs": "nixpkgs"
     },
     "locked": {
-      "lastModified": 1750234878,
-      "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
+      "lastModified": 1751968576,
+      "narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
       "owner": "huggingface",
       "repo": "hf-nix",
-      "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
+      "rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
       "type": "github"
     },
     "original": {
@@ -98,15 +98,16 @@
       ]
     },
     "locked": {
-      "lastModified": 1751630801,
-      "narHash": "sha256-wSVDQZejEifZQfcQF4R7tUd1QJo4wTwu2PMnF3lX0Z0=",
+      "lastModified": 1751968677,
+      "narHash": "sha256-5gtVPN6uk+H3yq2gJRDjSTcaVSgGJZjbMALlO6TBcT8=",
       "owner": "huggingface",
       "repo": "kernel-builder",
-      "rev": "99306a90b243bf603f4d41906bd51af92266fcbd",
+      "rev": "54eea2ce49889202e7018792f407046e36f89bc5",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
+      "ref": "get-kernel-check",
       "repo": "kernel-builder",
       "type": "github"
     }
flake.nix CHANGED
@@ -2,7 +2,7 @@
   description = "Flake for quantization kernels";
 
   inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder";
+    kernel-builder.url = "github:huggingface/kernel-builder/get-kernel-check";
   };
 
   outputs =
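Together with the flake.lock changes above, this pins the kernel-builder input to its get-kernel-check branch, and the lock file records the matching rev and narHash. After changing an input URL like this, the lock is normally refreshed with a command such as nix flake lock --update-input kernel-builder (or nix flake update kernel-builder on newer Nix); the exact invocation depends on the Nix version.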
torch-ext/quantization/utils/marlin_utils.py CHANGED
@@ -6,8 +6,7 @@ from typing import Optional
 import numpy
 import torch
 
-import quantization as ops
-from quantization.scalar_type import ScalarType, scalar_types
+from .. import ScalarType, gptq_marlin_gemm, scalar_types
 
 from .quant_utils import pack_cols, unpack_cols
 
@@ -383,7 +382,7 @@ def apply_gptq_marlin_linear(
         device=input.device,
         dtype=input.dtype)
 
-    output = ops.gptq_marlin_gemm(reshaped_x,
+    output = gptq_marlin_gemm(reshaped_x,
         None,
         weight,
         weight_scale,
@@ -429,7 +428,7 @@ def apply_awq_marlin_linear(
         device=input.device,
         dtype=input.dtype)
 
-    output = ops.gptq_marlin_gemm(reshaped_x,
+    output = gptq_marlin_gemm(reshaped_x,
         None,
         weight,
         weight_scale,
torch-ext/quantization/utils/marlin_utils_fp4.py CHANGED
@@ -5,12 +5,11 @@ from typing import Optional
 
 import torch
 
-import quantization as ops
-
+from .. import gptq_marlin_gemm, gptq_marlin_repack
 from .marlin_utils import (
     USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
     should_use_atomic_add_reduce)
-from quantization.scalar_type import scalar_types
+from ..scalar_type import scalar_types
 
 FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
 
@@ -90,7 +89,7 @@ def apply_fp4_marlin_linear(
         device=input.device,
         dtype=input.dtype)
 
-    output = ops.gptq_marlin_gemm(a=reshaped_x,
+    output = gptq_marlin_gemm(a=reshaped_x,
         c=None,
         b_q_weight=weight,
         b_scales=weight_scale,
@@ -135,7 +134,7 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
     perm = torch.empty(0, dtype=torch.int, device=device)
     qweight = layer.weight.view(torch.int32).T.contiguous()
 
-    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
+    marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
         perm=perm,
         size_k=part_size_k,
         size_n=part_size_n,
@@ -192,7 +191,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
     for i in range(e):
         qweight = weight[i].view(torch.int32).T.contiguous()
 
-        marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
+        marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
             perm=perm,
             size_k=size_k,
             size_n=size_n,
@@ -263,7 +262,7 @@ def rand_marlin_weight_fp4_like(weight, group_size):
     weight_ref = weight_ref * global_scale.to(weight.dtype) * \
         scales.repeat_interleave(group_size, 1).to(weight.dtype)
 
-    marlin_qweight = ops.gptq_marlin_repack(
+    marlin_qweight = gptq_marlin_repack(
         b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
         perm=torch.empty(0, dtype=torch.int, device=device),
         size_k=size_k,
torch-ext/quantization/utils/marlin_utils_fp8.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-import quantization as ops
+from .. import gptq_marlin_gemm, gptq_marlin_repack
 
 from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
 
@@ -51,7 +51,7 @@ def apply_fp8_marlin_linear(
         device=input.device,
         dtype=input.dtype)
 
-    output = ops.gptq_marlin_gemm(a=reshaped_x,
+    output = gptq_marlin_gemm(a=reshaped_x,
         c=None,
         b_q_weight=weight,
         b_scales=weight_scale,
@@ -104,7 +104,7 @@ def marlin_quant_fp8_torch(weight, group_size):
     weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
 
     packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
-    marlin_qweight = ops.gptq_marlin_repack(
+    marlin_qweight = gptq_marlin_repack(
         b_q_weight=packed_weight,
         perm=torch.empty(0, dtype=torch.int, device=device),
         size_k=size_k,
torch-ext/quantization/utils/marlin_utils_test.py CHANGED
@@ -5,8 +5,7 @@ from typing import List, Optional
 import numpy as np
 import torch
 
-from quantization.scalar_type import ScalarType
-
+from ..scalar_type import ScalarType
 from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
 from .quant_utils import (
     get_pack_factor,
torch-ext/quantization/utils/marlin_utils_test_24.py CHANGED
@@ -6,8 +6,7 @@ from typing import List
 import numpy
 import torch
 
-from quantization.scalar_type import ScalarType
-
+from ..scalar_type import ScalarType
 from .marlin_utils_test import marlin_weights
 from .quant_utils import gptq_quantize_weights
 
torch-ext/quantization/utils/quant_utils.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional
 import numpy
 import torch
 
-from quantization.scalar_type import ScalarType, scalar_types
+from ..scalar_type import ScalarType, scalar_types
 
 SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
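For the relative form from .. import ScalarType, gptq_marlin_gemm, scalar_types to resolve, the package root has to re-export those names. A minimal sketch of what torch-ext/quantization/__init__.py is assumed to provide (that file is not part of this diff, so the module that supplies the compiled ops is hypothetical here):

# torch-ext/quantization/__init__.py (sketch, not part of this commit)
from .scalar_type import ScalarType, scalar_types

# Hypothetical: re-export the compiled kernels at the package root so that
# "from .. import gptq_marlin_gemm, gptq_marlin_repack" works in utils/.
# The real definitions live wherever the extension registers its ops.
from ._ops import ops as _ops  # hypothetical binding module
gptq_marlin_gemm = _ops.gptq_marlin_gemm
gptq_marlin_repack = _ops.gptq_marlin_repack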