|
from typing import Optional |
|
|
|
import torch |
|
|
|
import quantization as ops |
|
|
|
from .marlin_utils import marlin_make_workspace, marlin_permute_scales |
|
|
|
|
|
def is_fp8_marlin_supported():
    """Return True when the current CUDA device can run FP8 Marlin kernels.

    The kernels require compute capability >= 8.0 (i.e. SM80 or newer).
    """
    major, minor = torch.cuda.get_device_capability()
    return (10 * major + minor) >= 80
|
|
|
|
|
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply an FP8 Marlin GEMM to `input`, optionally adding `bias`.

    All leading dimensions of `input` are flattened into a single M
    dimension for the kernel call; the returned tensor restores them,
    with the last dimension replaced by `size_n`.

    Args:
        input: activation tensor; its last dimension is the K dimension.
        weight: Marlin-packed quantized weight (see prepare step).
        weight_scale: Marlin-permuted weight scales.
        workspace: scratch buffer required by the Marlin kernel.
        size_n: output feature dimension (N).
        size_k: reduction dimension (K).
        bias: optional bias added in place to the kernel output.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (size_n,)

    gemm_out = ops.fp8_marlin_gemm(
        a=x_2d,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=x_2d.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        gemm_out.add_(bias)

    return gemm_out.reshape(result_shape)
|
|
|
|
|
def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    """Convert an FP8 linear layer, in place, into Marlin kernel format.

    Replaces `layer.weight` and `layer.weight_scale` with Marlin-formatted
    parameters and attaches a `layer.workspace` scratch buffer, so the layer
    can be used with `apply_fp8_marlin_linear`.

    Args:
        layer: module exposing `weight` (FP8), `weight_scale`, `orig_dtype`,
            `output_size_per_partition` and `input_size_per_partition`.
        strategy: accepted for interface compatibility; unused in this body.
    """
    n = layer.output_size_per_partition
    k = layer.input_size_per_partition
    device = layer.weight.device

    # Scratch space the Marlin kernel needs, sized for this output dim.
    layer.workspace = marlin_make_workspace(n, device)

    # Pack the FP8 weight bytes into int32 words, then repack into Marlin's
    # weight layout. An empty perm means "no act-order permutation".
    int32_weight = pack_fp8_to_int32(layer.weight)
    no_perm = torch.empty(0, dtype=torch.int, device=device)
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=int32_weight,
        perm=no_perm,
        size_k=k,
        size_n=n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # Scales are cast back to the layer's original dtype and permuted to
    # match the Marlin weight layout (group_size=-1: per-channel/tensor).
    permuted_scales = marlin_permute_scales(
        s=layer.weight_scale.to(layer.orig_dtype),
        size_k=k,
        size_n=n,
        group_size=-1,
    )
    layer.weight_scale = torch.nn.Parameter(permuted_scales, requires_grad=False)
|
|
|
|
|
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Repack FP8 weights to gptq format (packed int32 elements) |
|
""" |
|
assert fp8_tensor.dtype == torch.float8_e4m3fn |
|
assert fp8_tensor.shape[0] % 4 == 0 |
|
|
|
|
|
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) |
|
|
|
|
|
byte_tensor = reshaped.view(torch.uint8) |
|
|
|
|
|
packed = ( |
|
byte_tensor[:, 0].to(torch.int32) |
|
| (byte_tensor[:, 1].to(torch.int32) << 8) |
|
| (byte_tensor[:, 2].to(torch.int32) << 16) |
|
| (byte_tensor[:, 3].to(torch.int32) << 24) |
|
) |
|
|
|
return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() |
|
|