import torch
from torch import nn
import torch.nn.functional as F

def modified_weight_quant(w):
    """Per-tensor quantization to 1.58 bits (ternary values {-1, 0, 1}).
    No grouping is needed for quantization.
    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        u: a quantized weight with shape [d, k]
    """
    u = w.clamp(-1, 1).round()
    return u

def normalize(w):
    """L2-normalize each row of w so that every output channel has unit norm."""
    w = w / torch.norm(w, dim=1, keepdim=True)
    return w

class QLinear(nn.Linear):
    """Quantization-aware linear layer with 1.58-bit (ternary) weights.

    This is only for training, and kernel optimization is needed for efficiency.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # learnable per-output-channel scales, applied after the matmul
        self.scales = nn.Parameter(torch.ones(self.out_features))
        self.quantizer = modified_weight_quant

    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, d]
        Returns:
            y: an output tensor with shape [n, out_features]
        """
        w = self.weight
        x = x.to(w.device)
        # straight-through estimator (STE): the forward pass uses the quantized
        # weights while gradients flow to the latent full-precision weights
        w_quant = w + (self.quantizer(w) - w).detach()
        y = F.linear(x, w_quant)
        # apply the learnable per-channel scales after the matmul
        y = y * self.scales
        if self.bias is not None:
            y = y + self.bias
        return y
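

# A minimal usage sketch, not part of the original module: it assumes QLinear is meant
# as a drop-in replacement for nn.Linear during quantization-aware training, which is
# how the class above is written. Shapes and values are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)

    # the quantizer maps any weight tensor onto ternary values in {-1, 0, 1}
    w = torch.tensor([[-1.7, -0.4, 0.2, 0.9]])
    print(modified_weight_quant(w))

    # QLinear takes the same constructor arguments as nn.Linear
    layer = QLinear(16, 4, bias=True)
    x = torch.randn(8, 16)
    y = layer(x)
    print(y.shape)  # torch.Size([8, 4])

    # gradients still reach the latent full-precision weights thanks to the STE
    y.sum().backward()
    print(layer.weight.grad is not None)  # True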