from typing import Optional

import torch

import quantization as ops

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported() -> bool:
    # Marlin kernels require NVIDIA Ampere (compute capability 8.0) or newer.
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return capability >= 80


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack native FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 GEMM: the weights stay quantized
    # in FP8 and are dequantized inside the kernel.

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (
        byte_tensor[:, 0].to(torch.int32)
        | (byte_tensor[:, 1].to(torch.int32) << 8)
        | (byte_tensor[:, 2].to(torch.int32) << 16)
        | (byte_tensor[:, 3].to(torch.int32) << 24)
    )

    return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
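

if __name__ == "__main__":
    # Illustrative sanity check (a sketch, not part of the original module):
    # verify on CPU that `pack_fp8_to_int32` packs the raw bytes of four
    # consecutive rows into one int32, with the first row of each group in
    # the least-significant byte. Assumes a PyTorch build with
    # `torch.float8_e4m3fn` support; no GPU or Marlin extension is needed.
    w = torch.randn(8, 16).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(w)

    assert packed.dtype == torch.int32
    assert packed.shape == (w.shape[0] // 4, w.shape[1])

    # The low byte of each int32 must equal the FP8 byte from the first row
    # of its group of four (rows 0, 4, ...).
    low_bytes = (packed & 0xFF).to(torch.uint8)
    expected = w[0::4].contiguous().view(torch.uint8)
    assert torch.equal(low_bytes, expected)
    print("pack_fp8_to_int32 sanity check passed")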