"""Ternary Weight Network (TWN) quantizer and a linear layer built on it."""

import copy
import functools
import math

import torch
import torch.nn as nn

try:
    import transformers
except ImportError:
    # NOTE(review): nothing in the code reviewed here references
    # ``transformers``. The import is kept, but its absence is tolerated so
    # the quantizer stays importable in minimal environments -- confirm no
    # other part of the project relies on this import failing loudly.
    transformers = None


class TWNLinear(nn.Linear):
    """Linear layer whose weight is ternarized by ``TwnQuantizer`` on each call.

    The full-precision weight stays the trainable parameter; gradients reach
    it through the quantizer's straight-through backward pass.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer using a ternarized copy of ``self.weight``.

        :param input: activations whose last dimension is ``in_features``
        :return: ``input @ Q(weight).T`` plus the bias when one exists
        """
        # ``apply`` is invoked on the class: autograd Functions are not meant
        # to be instantiated.
        quantized = TwnQuantizer.apply(self.weight)

        # Honour the ``bias`` constructor flag; with the default
        # ``bias=False`` this stays ``None``.
        bias = self.bias
        if bias is not None:
            bias = bias.to(input.dtype)

        return torch.nn.functional.linear(
            input, quantized.to(input.dtype), bias
        )


class TwnQuantizer(torch.autograd.Function):
    """Ternary Weight Networks (TWN)

    Ref: https://arxiv.org/abs/1605.04711

    Forward maps every weight to one of ``{-alpha, 0, +alpha}``. Backward is
    a straight-through estimator that passes the gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input, max_scale=0.7, clip=None, group_size=-1,
                per_tensor=False, max_scale_dummy=0.7):
        """Ternarize ``input``.

        :param input: tensor to be ternarized
        :param max_scale: threshold factor; entries whose magnitude exceeds
            ``max_scale * mean(|w|)`` are kept, the rest become zero
        :param clip: optional factor; when given, values are first clamped
            to ``+/- clip * mean(|w|)``
        :param group_size: when positive, statistics are computed over
            consecutive groups of this many elements (must divide the last
            dimension); otherwise over each row of the last dimension
        :param per_tensor: use one threshold and scale for the whole tensor;
            incompatible with a custom ``group_size``
        :param max_scale_dummy: unused; kept for interface compatibility
        :return: quantized tensor with the same shape as ``input``
        """
        # Nothing is saved on ``ctx``: the straight-through backward does
        # not need the forward input.
        org_w_shape = input.shape

        if group_size > 0:
            assert org_w_shape[-1] % group_size == 0
            input = input.reshape(-1, group_size)
        else:
            input = input.reshape(-1, input.shape[-1])
        if per_tensor:
            assert group_size == -1, "Conflict with Per Tensor and Per Group Quant!"

        if clip is not None:
            if per_tensor:
                mean_abs = input.norm(p=1).div(input.nelement())
            else:
                mean_abs = input.norm(p=1, dim=1, keepdim=True).div(input.shape[1])
            clip_alpha = mean_abs * clip
            input = torch.where(input <= clip_alpha, input, clip_alpha)
            input = torch.where(input >= -clip_alpha, input, -clip_alpha)

        # Threshold: one scalar per tensor, or one value per row/group that
        # broadcasts across the row.
        if per_tensor:
            thres = max_scale * input.abs().mean()
        else:
            thres = max_scale * input.norm(p=1, dim=1, keepdim=True).div(input.shape[1])

        pos = (input > thres).float()
        neg = (input < -thres).float()
        mask = (input.abs() > thres).float()

        kept_magnitude = (mask * input).abs()
        if per_tensor:
            magnitude_sum = kept_magnitude.sum()
            kept_count = mask.sum()
        else:
            magnitude_sum = kept_magnitude.sum(dim=1, keepdim=True)
            kept_count = mask.sum(dim=1, keepdim=True)

        # When nothing exceeds the threshold (e.g. an all-zero row) the count
        # is 0; clamping it avoids 0/0 = NaN and yields alpha = 0 instead.
        alpha = magnitude_sum / kept_count.clamp(min=1)

        result = alpha * (pos - neg)
        return result.reshape(org_w_shape)  # undo the per-group flattening

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator.

        :param ctx: autograd context (unused)
        :param grad_output: gradient wrt the quantized tensor
        :return: gradient wrt the full-precision tensor, followed by one
            ``None`` for each of the five non-tensor arguments of ``forward``
        """
        grad_input = grad_output.clone()
        return grad_input, None, None, None, None, None