# quantization/ext-torch/utils/marlin_utils_fp8.py
from typing import Optional

import torch

import quantization as ops

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
    # The Marlin FP8 kernels require compute capability 8.0 (Ampere) or newer.
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return capability >= 80
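

# Usage sketch (illustrative, not part of the original module): the capability
# check above is what gates the FP8 Marlin weight-only path; GPUs older than
# compute capability 8.0 cannot use it and need a different fallback.
def _example_fp8_marlin_path_available() -> bool:
    return torch.cuda.is_available() and is_fp8_marlin_supported()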


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack native FP8 hardware support, we can leverage the
    # Marlin kernel to run a fast weight-only FP8 GEMM.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
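

# Usage sketch (illustrative only): how a linear module converted by
# prepare_fp8_layer_for_marlin (below) would route its forward pass through
# apply_fp8_marlin_linear. The `layer` attributes referenced here are the ones
# that function sets or reads; the `_example_*` helpers are not part of the
# original module.
def _example_fp8_marlin_forward(
    layer: torch.nn.Module, x: torch.Tensor
) -> torch.Tensor:
    return apply_fp8_marlin_linear(
        input=x,
        weight=layer.weight,              # Marlin-repacked FP8 weight
        weight_scale=layer.weight_scale,  # permuted scales
        workspace=layer.workspace,        # buffer from marlin_make_workspace
        size_n=layer.output_size_per_partition,
        size_k=layer.input_size_per_partition,
        bias=getattr(layer, "bias", None),
    )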


def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
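

# Conversion sketch (illustrative only; the attribute names, weight layout and
# scale shape below are assumptions about the layer objects this function is
# called with, not guarantees from this module): build a minimal FP8 layer and
# convert it in place for Marlin. The FP8 weight is assumed to be laid out as
# (size_k, size_n), matching the packing along dim 0 in pack_fp8_to_int32, and
# the scale is assumed to be per output channel.
def _example_prepare_fp8_layer(
    size_k: int = 128, size_n: int = 256
) -> torch.nn.Module:
    layer = torch.nn.Module()
    layer.input_size_per_partition = size_k
    layer.output_size_per_partition = size_n
    layer.orig_dtype = torch.float16
    layer.weight = torch.nn.Parameter(
        torch.randn(size_k, size_n, dtype=torch.float16, device="cuda").to(
            torch.float8_e4m3fn
        ),
        requires_grad=False,
    )
    # One scale value per column (output channel) of the weight.
    layer.weight_scale = torch.nn.Parameter(
        torch.ones(size_n, dtype=torch.float32, device="cuda"),
        requires_grad=False,
    )
    # Repacks the weight, permutes the scales and attaches a workspace buffer.
    prepare_fp8_layer_for_marlin(layer)
    return layer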


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (
        byte_tensor[:, 0].to(torch.int32)
        | (byte_tensor[:, 1].to(torch.int32) << 8)
        | (byte_tensor[:, 2].to(torch.int32) << 16)
        | (byte_tensor[:, 3].to(torch.int32) << 24)
    )

    return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
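

# Worked example (illustrative only): packing an (8, 2) FP8 tensor yields a
# (2, 2) int32 tensor, with each group of four consecutive rows along dim 0
# packed into one int32 in little-endian byte order.
def _example_pack_fp8_to_int32() -> torch.Tensor:
    fp8_weight = torch.randn(8, 2, dtype=torch.float16).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(fp8_weight)
    assert packed.shape == (2, 2) and packed.dtype == torch.int32
    return packed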