import torch.nn as nn
import torch.nn.functional as F

import awq_inference_engine
from awq.modules.linear import WQLinear_GEMV
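

# Fused MLP for a quantized LLaMA decoder layer: the gate and up
# projections run directly through the AWQ CUDA kernels from their packed
# buffers, while down_proj stays an ordinary WQLinear module.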
class QuantLlamaMLP(nn.Module):
    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj,
    ):
        super().__init__()
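
        # Packed quantized weights, per-group scales, and zero points are
        # registered as buffers so they follow .to()/.cuda() with the module
        # but are not treated as trainable parameters.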
        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
        self.register_buffer('gate_proj_scales', gate_proj.scales)
        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
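
        # Shape bookkeeping: gate_proj and up_proj both map
        # in_features -> intermediate_size; down_proj maps
        # intermediate_size -> out_features.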
        self.in_features = gate_proj.in_features
        self.intermediate_size = gate_proj.out_features
        self.out_features = down_proj.out_features
        self.w_bit = gate_proj.w_bit
        self.down_proj = down_proj
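
        # Dispatch to the matching kernel. The GEMV path forwards the real
        # quantization group size; the GEMM path passes a fixed 8, which the
        # underlying gemm_forward_cuda kernel appears to interpret as its
        # split-k factor rather than a quantization group size.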
        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_inference_engine.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_inference_engine.gemm_forward_cuda
            self.group_size = 8

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.intermediate_size,)
        x = x.reshape(-1, x.shape[-1])
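
        # The batch/sequence dims are flattened above so each kernel call is
        # a single 2-D matmul over the packed weights.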
        gate_output = self.linear(
            x,
            self.gate_proj_qweight,
            self.gate_proj_scales,
            self.gate_proj_qzeros,
            self.group_size,
        )
        up_output = self.linear(
            x,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
            self.group_size,
        )
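
        # SwiGLU: SiLU(gate) modulates the up projection, then down_proj
        # brings the result back to the hidden size.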
        x = F.silu(gate_output) * up_output
        x = x.reshape(out_shape)
        x = self.down_proj(x)
        return x
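

# Illustrative usage (a minimal sketch, not part of the upstream module).
# Assuming `model` is an AWQ-quantized LLaMA whose decoder layers expose an
# `mlp` with WQLinear gate/up/down projections, the fused module can be
# swapped in like so:
#
#   for layer in model.model.layers:
#       mlp = layer.mlp
#       layer.mlp = QuantLlamaMLP(mlp.gate_proj, mlp.down_proj, mlp.up_proj)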