import torch
from torch import nn
import torch.nn.functional as F
def modified_weight_quant(w):
    """Per-tensor quantization to 1.58 bits, i.e. ternary values {-1, 0, 1}.
    No grouping is needed for quantization.
    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        u: a quantized weight with shape [d, k]
    """
    u = w.clamp(-1, 1).round()
    return u
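
# Illustration (a minimal sketch): clamp-then-round maps any real-valued
# tensor onto the ternary set {-1, 0, 1}, i.e. log2(3) ≈ 1.58 bits per weight:
#   modified_weight_quant(torch.tensor([0.7, -1.3, 0.1]))  # -> tensor([ 1., -1.,  0.])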
def normalize(w):
    """Rescale each row of w (shape [d, k]) to unit L2 norm."""
    w = w / torch.norm(w, dim=1, keepdim=True)
    return w
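
# Illustration (a minimal sketch): each output channel's weight vector is
# projected onto the unit sphere:
#   normalize(torch.tensor([[3.0, 4.0]]))  # -> tensor([[0.6000, 0.8000]])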
class QLinear(nn.Linear):
    """A linear layer whose forward pass uses 1.58-bit (ternary) weights.
    This is only for training, and kernel optimization is needed for efficiency.
    """
    def __init__(self, *args, **kwargs):
        super(QLinear, self).__init__(*args, **kwargs)
        # learnable per-output-channel scales, applied after the matmul
        self.scales = nn.Parameter(torch.ones(self.out_features))
        self.quantizer = modified_weight_quant
    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, k], where k = in_features
        Returns:
            y: an output tensor with shape [n, d], where d = out_features
        """
        w_quant = self.weight
        # move the input to the weight's device
        x = x.to(w_quant.device)
        # STE weight quantization: the forward pass uses the quantized weights,
        # while the gradient bypasses the non-differentiable round() and flows
        # to the latent full-precision weights as if quantization were identity
        w_quant = w_quant + (self.quantizer(w_quant) - w_quant).detach()
        y = F.linear(x, w_quant)
        # apply the learnable per-channel scales post matmul
        y = y * self.scales
        if self.bias is not None:
            y = y + self.bias
        return y
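
# A minimal usage sketch (feature sizes are assumed for illustration):
# QLinear drops in for nn.Linear, and the STE lets gradients reach the
# latent full-precision weights despite the non-differentiable round().
if __name__ == "__main__":
    layer = QLinear(16, 8)          # assumed sizes: in_features=16, out_features=8
    x = torch.randn(4, 16)
    y = layer(x)
    print(y.shape)                  # torch.Size([4, 8])
    y.sum().backward()
    print(layer.weight.grad.shape)  # torch.Size([8, 16])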