quantization / ext-torch /__init__.py
danieldk's picture
danieldk HF staff
Add `scaled_(int|fp8)_quant` and `fp8_marlin_gemm`
5c6fb68
raw
history blame
6.07 kB
from typing import Optional, Tuple
import torch
# Resolve the compiled op namespace `ops` that every wrapper below calls into.
try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development.
    # NOTE(review): presumably importing `_quantization` registers its kernels
    # under `torch.ops._quantization` as a side effect — confirm.
    try:
        import _quantization  # noqa: F401  (imported for its side effect)
        ops = torch.ops._quantization
    except ImportError:
        # Neither import path worked: re-raise the original package-relative
        # ImportError (the fallback's own failure becomes its __context__).
        raise e
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether the CUTLASS FP8 scaled-mm is usable.

    Args:
        cuda_device_capability: Device capability value, forwarded unchanged
            to ``ops.cutlass_scaled_mm_supports_fp8``. NOTE(review): the
            expected integer encoding is defined by the kernel, not here.

    Returns:
        The answer reported by the underlying op.
    """
    query = ops.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute a scaled matrix product of ``a`` and ``b`` with the CUTLASS op.

    A result buffer of shape ``(a.shape[0], b.shape[1])`` and dtype
    ``out_dtype`` is allocated on ``a.device`` and filled in place by
    ``ops.cutlass_scaled_mm``. How the scales and bias are applied is
    defined by that kernel.

    Args:
        a: Left operand; only its first dimension and device are read here.
        b: Right operand; its first two dimensions must be multiples of 16.
        scale_a: Scale(s) for ``a``, passed through to the kernel.
        scale_b: Scale(s) for ``b``, passed through to the kernel.
        out_dtype: Must be ``torch.bfloat16`` or ``torch.float16``.
        bias: Optional bias; when given, ``bias.shape[0]`` must equal
            ``b.shape[1]`` and its dtype must be ``out_dtype``.

    Returns:
        The freshly allocated and kernel-filled output tensor.

    Raises:
        AssertionError: if a shape or dtype precondition is violated.
    """
    # Shape check short-circuits exactly like the combined original condition.
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert any(out_dtype is allowed
               for allowed in (torch.bfloat16, torch.float16))
    if bias is not None:
        assert bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    # NOTE: a ROCm-only Triton fallback (vLLM's `triton_scaled_mm`) is not
    # wired up in this package, so every platform goes through the CUTLASS op.
    result_shape = (a.shape[0], b.shape[1])
    result = torch.empty(result_shape, dtype=out_dtype, device=a.device)
    ops.cutlass_scaled_mm(result, a, b, scale_a, scale_b, bias)
    return result
# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide ``scale``, static scaling is used; if you omit it, the scale is
    determined dynamically by the kernel. The function also allows optional
    padding of the output's first dimension for downstream kernels that
    benefit from padding.

    Args:
        input: The 2-D input tensor to be quantized to FP8 (batch and token
            dimensions are expected to be flattened into the first dim).
        scale: Optional scaling factor for static FP8 quantization.
        num_token_padding: If specified (and non-zero), pad the first
            dimension of the output to at least this value. In the static
            case this is only allowed when ``scale`` has a single element.
        scale_ub: Optional upper bound for the scaling factor; it is only
            passed to the kernel in the dynamic per-token case.
        use_per_token_if_dynamic: In the dynamic case, compute one scale per
            row (``True``) or a single scale for the whole tensor (``False``).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The ``torch.float8_e4m3fn`` output
        tensor and the scale. In the static case the scale is the caller's
        ``scale``. Otherwise it is a new ``float32`` tensor of shape
        ``(rows, 1)`` (per-token) or ``(1,)`` (per-tensor), where ``rows``
        includes any padding.

    Raises:
        AssertionError: if ``input`` is not 2-D, or if padding is requested
            together with a static ``scale`` that has more than one element.
    """
    # This code assumes batch_dim and num_tokens are flattened.
    assert input.ndim == 2
    # `torch.Size` is a tuple subclass, so this covers both the input's shape
    # and the padded plain tuple built below.
    shape: Tuple[int, ...] = input.shape
    # NOTE: on ROCm the FP8 dtype would be torch.float8_e4m3fnuz; that
    # platform switch is not wired up here, so the output is always
    # torch.float8_e4m3fn.
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    # NOTE(review): rows beyond input.shape[0] come from torch.empty; whether
    # the kernels initialize them is not visible here — confirm.
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            # Zero-initialized, unlike the per-token buffer; presumably the
            # kernel reduces into it — confirm against the kernel.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding is not implemented for a multi-element static scale.
        assert scale.numel() == 1 or num_token_padding is None
        ops.static_scaled_fp8_quant(output, input, scale)
    return output, scale
# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize ``input`` to int8, returning the quantized tensor, its scale(s)
    and, for asymmetric quantization, the zero point(s) (azp).

    Two modes are supported:

    * static (``scale`` given): the caller's ``scale`` and ``azp`` are used
      by the kernel and returned unchanged;
    * dynamic per-token (``scale`` omitted): the kernel fills one float32
      scale per row, where every dimension except the last counts toward the
      row count. When ``symmetric`` is False it also fills one int32 zero
      point per row. Any ``azp`` argument is ignored in this mode.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor; selects static mode when given.
        azp: Optional zero point for static asymmetric quantization. It must
            be given exactly when ``symmetric`` is False and ``scale`` is given.
        symmetric: Whether to use symmetric quantization (scale only).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: the int8
        tensor (same shape as ``input``), the scales, and the zero points or
        ``None``.

    Raises:
        AssertionError: in static mode, if the presence of ``azp`` does not
            match ``not symmetric``.
    """
    quantized = torch.empty_like(input, dtype=torch.int8)

    if scale is None:
        # Dynamic per-token path: the kernel writes into these fresh buffers.
        num_rows = input.numel() // input.shape[-1]
        row_scales = torch.empty((num_rows, 1),
                                 device=input.device,
                                 dtype=torch.float32)
        if not symmetric:
            row_azp = torch.empty_like(row_scales, dtype=torch.int32)
        else:
            row_azp = None
        ops.dynamic_scaled_int8_quant(quantized, input, row_scales, row_azp)
        return quantized, row_scales, row_azp

    # Static per-tensor path with caller-provided parameters.
    assert (azp is None) == symmetric, \
        "azp must only be provided for asymmetric quantization."
    ops.static_scaled_int8_quant(quantized, input, scale, azp)
    return quantized, scale, azp
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Run the FP8 Marlin GEMM kernel from the compiled extension.

    Every argument is forwarded positionally, unchanged and in order, to
    ``ops.fp8_marlin_gemm``. No validation happens here; layout and size
    requirements (``size_m``/``size_n``/``size_k`` are presumably the GEMM
    problem dimensions) are defined by the kernel.

    Returns:
        The tensor produced by ``ops.fp8_marlin_gemm``.
    """
    kernel_args = (a, b_q_weight, b_scales, workspace,
                   num_bits, size_m, size_n, size_k)
    return ops.fp8_marlin_gemm(*kernel_args)