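"""Ternary Weight Network (TWN) quantization for linear layers.

``TWNLinear`` is a drop-in replacement for ``nn.Linear`` that ternarizes its
weights on the fly through ``TwnQuantizer``, a straight-through-estimator
autograd function (ref: https://arxiv.org/abs/1605.04711).
"""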
import torch
import torch.nn as nn


class TWNLinear(nn.Linear):
    """``nn.Linear`` whose weights are ternarized on every forward pass."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Ternarize the full-precision weights with the straight-through
        # quantizer defined below; apply() is the entry point for a
        # torch.autograd.Function (no instantiation needed).
        w_q = TwnQuantizer.apply(self.weight)
        return torch.nn.functional.linear(input, w_q.to(input.dtype))


class TwnQuantizer(torch.autograd.Function):
    """Ternary Weight Networks (TWN).

    Ref: https://arxiv.org/abs/1605.04711
    """

    @staticmethod
    def forward(ctx, input, max_scale=0.7, clip=None, group_size=-1,
                per_tensor=False, max_scale_dummy=0.7):
        """Ternarize ``input``.

        :param input: tensor to be ternarized
        :param max_scale: fraction of the mean |w| used as the ternary threshold
        :param clip: optional multiple of the mean |w| at which weights are clipped
        :param group_size: group size along the last dim (-1 = per channel)
        :param per_tensor: quantize with a single scale for the whole tensor
        :param max_scale_dummy: unused placeholder
        :return: quantized (ternary) tensor
        """
        ctx.save_for_backward(input)
        org_w_shape = input.shape
        q_group_size = group_size
        if q_group_size > 0:
            # Per-group: fold the tensor into rows of one group each.
            assert org_w_shape[-1] % q_group_size == 0
            input = input.reshape(-1, q_group_size)
        else:
            # Per-channel: one row per output channel.
            input = input.reshape(-1, input.shape[-1])
        if per_tensor:
            assert q_group_size == -1, "Per-tensor and per-group quantization conflict!"
        if clip is not None:
            if per_tensor:
                # Clip threshold: clip * mean absolute value of the tensor.
                m = input.norm(p=1).div(input.nelement())
                clip_alpha = m * clip
            else:
                # Clip threshold per row, broadcast across columns.
                m = input.norm(p=1, dim=1).div(input[0].nelement())
                m = m.view(-1, 1).expand_as(input)
                clip_alpha = m * clip
            # Clamp weights to [-clip_alpha, clip_alpha].
            input = torch.where(input <= clip_alpha, input, clip_alpha)
            input = torch.where(input >= -clip_alpha, input, -clip_alpha)
        if per_tensor:
            # Per-tensor quantization: one threshold and one scale overall.
            m = input.abs().mean()
            thres = max_scale * m
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            # Scale alpha: mean |w| over the weights above the threshold.
            alpha = (mask * input).abs().sum() / mask.sum()
            result = alpha * pos - alpha * neg
        else:
            # Per-channel/group quantization: one threshold and scale per row.
            n = input[0].nelement()
            m = input.data.norm(p=1, dim=1).div(n)
            thres = (max_scale * m).view(-1, 1).expand_as(input)
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = ((mask * input).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
            result = alpha * pos - alpha * neg
        # Restore the original shape (undoes the per-group reshape).
        result = result.reshape(org_w_shape)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient through unchanged.

        :param ctx: context holding the saved (unclipped) full-precision tensor
        :param grad_output: gradient w.r.t. the quantized tensor
        :return: estimated gradient w.r.t. the full-precision tensor
        """
        (input,) = ctx.saved_tensors  # saved_tensors is a tuple
        grad_input = grad_output.clone()
        # One None per non-tensor forward argument
        # (max_scale, clip, group_size, per_tensor, max_scale_dummy).
        return grad_input, None, None, None, None, None
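

# Minimal usage sketch (a quick sanity check with the default quantizer
# arguments; the shapes and seed below are illustrative, not from the
# original module): ternarize a matrix directly, then run a TWNLinear
# layer end to end to confirm the straight-through backward yields grads.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Direct use of the quantizer: each row gets its own scale alpha, so
    # the unique values of a row are a subset of {-alpha, 0, +alpha}.
    w = torch.randn(8, 16)
    w_q = TwnQuantizer.apply(w)
    print("row 0 levels:", w_q[0].unique().tolist())

    # Drop-in replacement for nn.Linear; weights are ternarized on the fly.
    layer = TWNLinear(16, 4)
    x = torch.randn(2, 16)
    layer(x).sum().backward()
    print("weight grad shape:", tuple(layer.weight.grad.shape))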