from typing import TYPE_CHECKING

import torch

# The Neuron build ships a torch version that does not even have impl_abstract.
if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

try:
from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
# Fallback for local development.
try:
import _quantization
ops = torch.ops._quantization
def add_op_namespace_prefix(op_name: str):
return f"_quantization::{op_name}"
except ImportError:
raise e
from .scalar_type import ScalarType


# fp8 marlin
def fp8_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
    """FP8 weight-only GEMM using the Marlin kernel.

    Multiplies the activations `a` by the Marlin-packed FP8 weights
    `b_q_weight`, applying the weight scales `b_scales` on the fly.
    `workspace` is a scratch buffer required by the kernel and `size_m`,
    `size_n`, `size_k` are the GEMM dimensions.
    """
return ops.fp8_marlin_gemm(
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
)
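

# Illustrative sketch, not part of the original API: one way a caller might
# drive fp8_marlin_gemm for a weight-only FP8 matmul. The weight packing and
# the workspace allocation are assumed to have been produced elsewhere by the
# corresponding setup code; the shape conventions in the comments are
# assumptions for illustration only.
def _example_fp8_marlin_matmul(
    a: torch.Tensor,              # (size_m, size_k) activations
    packed_weight: torch.Tensor,  # Marlin-packed FP8 weight for a (size_k, size_n) matrix
    scales: torch.Tensor,         # weight scales, last dimension assumed to be size_n
    workspace: torch.Tensor,      # scratch buffer required by the kernel
) -> torch.Tensor:
    size_m, size_k = a.shape
    size_n = scales.shape[-1]
    return fp8_marlin_gemm(
        a,
        packed_weight,
        scales,
        workspace,
        num_bits=8,
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
    )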


# gptq_marlin
def gptq_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False,
) -> torch.Tensor:
    """Weight-only GEMM for GPTQ/AWQ-style quantization using the Marlin kernel.

    `b_q_weight` holds Marlin-repacked quantized weights of type `b_q_type`,
    dequantized on the fly with `b_scales` and, when `has_zp` is set, with the
    zero points in `b_zeros`. `g_idx` and `perm` carry the activation-order
    (act_order) permutation, and `workspace` is scratch space for the kernel.
    """
return ops.gptq_marlin_gemm(
a,
b_q_weight,
b_scales,
b_zeros,
g_idx,
perm,
workspace,
b_q_type.id,
size_m,
size_n,
size_k,
is_k_full,
has_zp,
use_fp32_reduce,
is_zp_float,
)


# gptq_marlin
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
    """Repack a GPTQ-format quantized weight into the Marlin weight layout."""
return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)


# awq_marlin
def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack an AWQ-format quantized weight into the Marlin weight layout."""
return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
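

# Illustrative sketch, not part of the original module: the intended flow for
# a GPTQ checkpoint is to repack the weights once into the Marlin layout and
# then call gptq_marlin_gemm at runtime. The empty tensors stand in for
# "no zero points / no act_order"; whether that is the right sentinel, and the
# exact scale layout, are assumptions to be checked against the kernel sources.
def _example_gptq_marlin_forward(
    a: torch.Tensor,             # (size_m, size_k) activations
    gptq_qweight: torch.Tensor,  # GPTQ-format packed weight for a (size_k, size_n) matrix
    scales: torch.Tensor,        # group scales, last dimension assumed to be size_n
    workspace: torch.Tensor,     # scratch buffer required by the kernel
    weight_type: ScalarType,     # quantized weight type, e.g. a 4-bit type
    num_bits: int = 4,
) -> torch.Tensor:
    size_m, size_k = a.shape
    size_n = scales.shape[-1]
    empty = torch.empty(0, dtype=torch.int, device=a.device)

    # One-off repack from GPTQ layout to the Marlin layout (no act_order perm).
    marlin_qweight = gptq_marlin_repack(gptq_qweight, empty, size_k, size_n, num_bits)

    return gptq_marlin_gemm(
        a,
        marlin_qweight,
        scales,
        b_zeros=empty,
        g_idx=empty,
        perm=empty,
        workspace=workspace,
        b_q_type=weight_type,
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
        is_k_full=True,
    )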


# marlin
def marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
    """GEMM using the original Marlin kernel (symmetric 4-bit weights).

    `b_q_weight` holds the Marlin-packed weights, dequantized on the fly with
    `b_scales`; `workspace` is scratch space for the kernel and `size_m`,
    `size_n`, `size_k` are the GEMM dimensions.
    """
return ops.marlin_gemm(
a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
)


# marlin_24
def gptq_marlin_24_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_meta: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
    """Marlin GEMM for 2:4-sparse quantized weights.

    `b_meta` holds the 2:4 structured-sparsity metadata and `b_q_type` the
    quantized weight type; the remaining arguments match `marlin_gemm`.
    """
return ops.gptq_marlin_24_gemm(
a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
)


# qqq ops
def marlin_qqq_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
s_tok: torch.Tensor,
s_ch: torch.Tensor,
s_group: torch.Tensor,
workspace: torch.Tensor,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
    """Marlin QQQ (W4A8) GEMM.

    `s_tok` holds the per-token activation scales; `s_ch` and `s_group` hold
    the per-channel and per-group weight scales used to dequantize the packed
    weights `b_q_weight`.
    """
return ops.marlin_qqq_gemm(
a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
)


# Fake ops: shape/dtype-only implementations registered for FakeTensor tracing
# (e.g. under torch.compile), so graphs can be traced without running the real
# CUDA kernels. Only registered when the compiled extension provides the ops.
if hasattr(ops, "gptq_marlin_24_gemm"):
@register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@register_fake(add_op_namespace_prefix("marlin_gemm"))
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
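

# Illustrative sketch, not part of the original module: the fake kernels above
# only describe output shapes and dtypes, which is what FakeTensor tracing and
# torch.compile need in order to propagate shapes without launching the real
# CUDA kernels. Under that assumption, a compiled wrapper like the one below
# should trace through these ops.
def _example_compiled_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    # Compiled inline here for brevity; real code would cache the compiled callable.
    compiled = torch.compile(marlin_gemm, fullgraph=True)
    return compiled(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)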