import copy
import functools
import math

import transformers

import torch
import torch.nn as nn


class TWNLinear(nn.Linear):
    """nn.Linear drop-in whose weight is ternarized on the fly in forward."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # TwnQuantizer.apply must be called on the class, not an instance,
        # so the custom autograd function is dispatched correctly.
        w_q = TwnQuantizer.apply(self.weight)
        # Pass self.bias through so bias=True layers are handled too.
        return torch.nn.functional.linear(input, w_q.to(input.dtype), self.bias)

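
# A minimal sketch (not part of the original API) of how TWNLinear could be
# swapped in for the nn.Linear layers of an existing model, e.g. one loaded
# with transformers; `replace_linear_with_twn` is a hypothetical helper name,
# and copying parameters in place assumes matching shapes.
def replace_linear_with_twn(model: nn.Module) -> nn.Module:
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, TWNLinear):
            twn = TWNLinear(child.in_features, child.out_features,
                            bias=child.bias is not None)
            # Reuse the pretrained full-precision parameters.
            twn.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                twn.bias.data.copy_(child.bias.data)
            setattr(model, name, twn)
        else:
            replace_linear_with_twn(child)
    return model
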
class TwnQuantizer(torch.autograd.Function):
    """Ternary Weight Networks (TWN).

    Ref: https://arxiv.org/abs/1605.04711
    """

    @staticmethod
    def forward(ctx, input, max_scale=0.7, clip=None, group_size=-1, per_tensor=False):
        """
        :param input: tensor to be ternarized
        :param max_scale: fraction of the mean absolute value used as the ternary threshold
        :param clip: optional clipping factor applied to the weights before ternarization
        :param group_size: quantization group size along the last dim (-1 = per row)
        :param per_tensor: use a single threshold and scale for the whole tensor
        :return: quantized tensor
        """
        ctx.save_for_backward(input)

        org_w_shape = input.shape
        q_group_size = group_size

        if q_group_size > 0:
            assert org_w_shape[-1] % q_group_size == 0
            input = input.reshape(-1, q_group_size)
        else:
            input = input.reshape(-1, input.shape[-1])

        if per_tensor:
            assert q_group_size == -1, "Per-tensor and per-group quantization are mutually exclusive!"

        # Optional clipping: clamp weights to [-clip * m, clip * m], where m is
        # the mean absolute value (per tensor or per row).
        if clip is not None:
            if per_tensor:
                m = input.norm(p=1).div(input.nelement())
                clip_alpha = m * clip
            else:
                m = input.norm(p=1, dim=1).div(input[0].nelement())
                m = m.view(-1, 1).expand_as(input)
                clip_alpha = m * clip
            input = torch.where(input <= clip_alpha, input, clip_alpha)
            input = torch.where(input >= -clip_alpha, input, -clip_alpha)

        if per_tensor:
            # Single threshold and scale: thres = max_scale * E[|w|],
            # alpha = mean of |w| over the weights above the threshold.
            m = input.abs().mean()
            thres = max_scale * m
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = (mask * input).abs().sum() / mask.sum()
            result = alpha * pos - alpha * neg
        else:
            # One threshold and scale per row (or per group after the reshape).
            n = input[0].nelement()
            m = input.data.norm(p=1, dim=1).div(n)
            thres = (max_scale * m).view(-1, 1).expand_as(input)
            pos = (input > thres).float()
            neg = (input < -thres).float()
            mask = (input.abs() > thres).float()
            alpha = ((mask * input).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
            result = alpha * pos - alpha * neg

        result = result.reshape(org_w_shape)

        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param ctx: context holding the saved non-clipped full-precision tensor
        :param grad_output: gradient w.r.t. the quantized tensor
        :return: estimated gradient w.r.t. the full-precision tensor
        """
        (input,) = ctx.saved_tensors
        # Straight-through estimator: pass the gradient through unchanged.
        grad_input = grad_output.clone()
        # One gradient slot per forward argument:
        # (input, max_scale, clip, group_size, per_tensor).
        return grad_input, None, None, None, None
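
# A quick sanity check, a sketch rather than part of the original file: each
# row of the ternarized weight should take at most three values
# {-alpha, 0, +alpha}, and the straight-through backward should deliver a
# gradient to the full-precision weight.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = TWNLinear(8, 4)
    x = torch.randn(2, 8)

    w_q = TwnQuantizer.apply(layer.weight)
    for row in w_q.detach():
        assert row.unique().numel() <= 3  # {-alpha, 0, +alpha} per row

    out = layer(x)
    out.sum().backward()
    assert layer.weight.grad is not None
    assert layer.weight.grad.shape == layer.weight.shape

    # The replacement helper sketched above should convert every nn.Linear.
    model = replace_linear_with_twn(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4)))
    assert all(isinstance(m, TWNLinear) for m in model if isinstance(m, nn.Linear))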