import torch.nn as nn
import awq_inference_engine
import torch.nn.functional as F
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV


class QuantLlamaMLP(nn.Module):
    """Fused Llama MLP whose gate/up projections run through AWQ CUDA kernels."""

    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj,
    ):
        super().__init__()

        # Keep the quantized weights, scales, and zero points of the gate and up
        # projections as buffers so they follow the module across devices
        # (.to / .cuda) without being treated as trainable parameters.
        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
        self.register_buffer('gate_proj_scales', gate_proj.scales)
        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)

        self.in_features = gate_proj.in_features
        self.intermediate_size = gate_proj.out_features
        self.out_features = down_proj.out_features
        self.w_bit = gate_proj.w_bit
        self.down_proj = down_proj

        # Pick the CUDA kernel that matches the quantized layer layout.
        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_inference_engine.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_inference_engine.gemm_forward_cuda
            # The GEMM kernel expects a fixed value of 8 as its final argument.
            self.group_size = 8

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.intermediate_size,)
        # Flatten all leading dimensions into a single batch dimension for the kernels.
        x = x.reshape(-1, x.shape[-1])

        gate_output = self.linear(
            x,
            self.gate_proj_qweight,
            self.gate_proj_scales,
            self.gate_proj_qzeros,
            self.group_size,
        )
        up_output = self.linear(
            x,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
            self.group_size,
        )

        # SwiGLU activation: silu(gate) * up, then project back down.
        x = F.silu(gate_output) * up_output
        x = x.reshape(out_shape)
        x = self.down_proj(x)
        return x
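

# Illustrative usage sketch (not part of the original module): fusing the three
# quantized projections of a Llama MLP block into a single QuantLlamaMLP.
# The `mlp` and `decoder_layer` objects below are hypothetical; they stand for a
# module whose gate_proj / up_proj / down_proj attributes are already
# WQLinear_GEMM or WQLinear_GEMV layers.
#
#     fused_mlp = QuantLlamaMLP(
#         gate_proj=mlp.gate_proj,
#         down_proj=mlp.down_proj,
#         up_proj=mlp.up_proj,
#     )
#     decoder_layer.mlp = fused_mlp  # swap the fused module into the model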