from typing import Optional

import torch

import quantization as ops
from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported() -> bool:
    # FP8 Marlin kernels require compute capability 8.0 (Ampere) or newer.
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return capability >= 80


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
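

# Illustrative call pattern for apply_fp8_marlin_linear (a hedged sketch, not
# executed here). The `layer` object, `batch_size`, and shapes are hypothetical
# and assume the layer was already converted via prepare_fp8_layer_for_marlin()
# on a CUDA device where is_fp8_marlin_supported() returns True:
#
#     x = torch.randn(batch_size, layer.input_size_per_partition,
#                     dtype=layer.orig_dtype, device="cuda")
#     y = apply_fp8_marlin_linear(
#         input=x,
#         weight=layer.weight,
#         weight_scale=layer.weight_scale,
#         workspace=layer.workspace,
#         size_n=layer.output_size_per_partition,
#         size_k=layer.input_size_per_partition,
#         bias=None,
#     )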


def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (
        byte_tensor[:, 0].to(torch.int32)
        | (byte_tensor[:, 1].to(torch.int32) << 8)
        | (byte_tensor[:, 2].to(torch.int32) << 16)
        | (byte_tensor[:, 3].to(torch.int32) << 24)
    )

    return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
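

# Minimal sanity check for pack_fp8_to_int32 (an illustrative sketch, not part
# of the module's API). It assumes the module's imports resolve and that the
# torch build supports float8_e4m3fn; it runs on CPU without the Marlin kernels.
if __name__ == "__main__":
    # 8 x 2 FP8 tile; the first dimension must be divisible by 4 for packing.
    w_fp8 = torch.randn(8, 2).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(w_fp8)

    # Four FP8 rows collapse into one int32 row; columns are unchanged.
    assert packed.shape == (2, 2)
    assert packed.dtype == torch.int32

    # The low byte of each int32 holds the first of its four packed FP8 values,
    # i.e. rows 0 and 4 of the original tensor.
    low_byte = (packed & 0xFF).to(torch.uint8)
    assert torch.equal(low_byte, w_fp8.view(torch.uint8)[0::4])
    print("pack_fp8_to_int32 sanity check passed")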