from typing import Optional, Tuple, Union

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fall back to the globally registered extension ops if the
    # package-local bindings are unavailable.
    try:
        import _quantization
        ops = torch.ops._quantization
    except ImportError:
        raise e


def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
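    """
    Scaled matrix multiply backed by the `ops.cutlass_scaled_mm` kernel:
    computes an (M, N) output of `out_dtype` from `a` (M, K) and `b` (K, N),
    applying the scaling factors `scale_a` and `scale_b` and the optional
    `bias`.
    """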
    # Both dimensions of b (K and N) must be multiples of 16, and the
    # output dtype must be a 16-bit float type.
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP8 and return the quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, static scaling is used; if you omit it, the scale
    is determined dynamically. The function also allows optional padding
    of the output tensor for downstream kernels that benefit from it.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        scale_ub: Optional upper bound for the scaling factor in the
            dynamic per-token case.
        num_token_padding: If specified, pad the first dimension of the
            output to at least this value.
        use_per_token_if_dynamic: Whether to do per-tensor or per-token
            quantization in the dynamic case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            the scaling factor.
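
    Example (illustrative; assumes a CUDA device and the compiled
    quantization ops are available):

        >>> x = torch.randn(16, 32, dtype=torch.float16, device="cuda")
        >>> q, s = scaled_fp8_quant(x)  # dynamic per-tensor: s has shape (1,)
        >>> qt, st = scaled_fp8_quant(x, use_per_token_if_dynamic=True)  # st: (16, 1)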
    """
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            # Dynamic per-token scaling: the kernel writes one scale per row.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            # Dynamic per-tensor scaling: the kernel writes a single scale.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static scaling: padding is only supported with a per-tensor scale.
        assert (scale.numel() == 1 or num_token_padding is None)
        ops.static_scaled_fp8_quant(output, input, scale)

    return output, scale


def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor, the
    scale, and optionally the azp (zero point).

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, dynamic per-token quantization is used.
        azp: Optional zero point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is
            provided.
        symmetric: Whether to use symmetric quantization (scale only, azp
            ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output
            int8 tensor, scales, and optionally azp.
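
    Example (illustrative; assumes a CUDA device and the compiled
    quantization ops are available):

        >>> x = torch.randn(16, 32, dtype=torch.float16, device="cuda")
        >>> q, scales, azp = scaled_int8_quant(x)  # q: int8, scales: (16, 1), azp: None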
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # Static-scale quantization; azp is only allowed in the asymmetric case.
        assert symmetric == (azp is None), (
            "azp must only be provided for asymmetric quantization.")
        ops.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # Dynamic per-token quantization: the kernel computes one scale (and one
    # zero point, if asymmetric) per token.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    input_azp = None if symmetric else torch.empty_like(input_scales,
                                                        dtype=torch.int32)
    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
    return output, input_scales, input_azp


def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
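    """
    Thin wrapper around the compiled `ops.fp8_marlin_gemm` kernel: a GEMM over
    the quantized weights `b_q_weight` (with bit-width `num_bits`) and scales
    `b_scales`, using `workspace` as the kernel's scratch buffer, for problem
    sizes `size_m` x `size_n` x `size_k`.
    """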
    return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                               num_bits, size_m, size_n, size_k)