# bart-base-detox-svd / modules.py
# Copied from the rut5compressed/nn/functional.py and
# rut5compressed/nn/modules.py modules of the original repository.
from typing import Optional, Tuple

import torch as T


class SVDCompressedLinearFunc(T.autograd.Function):
    """Linear map whose weight is kept as a product of two low-rank factors:
    output = (input @ lhs) @ rhs + bias.
    """

    @staticmethod
    def forward(ctx, input: T.Tensor, lhs: T.Tensor, rhs: T.Tensor,
                bias: Optional[T.Tensor] = None) -> T.Tensor:
        # Multiply by the factors left to right so the full weight matrix is
        # never materialised. See PEP-0465 on matmul operator associativity.
        # https://peps.python.org/pep-0465/#precedence-and-associativity
        output = (input @ lhs) @ rhs
        if bias is not None:
            output += bias[None, :]
        # Remember whether a bias was supplied.
        ctx.bias = bias is not None
        ctx.save_for_backward(input, lhs, rhs)
        return output

    @staticmethod
    def backward(ctx, grad_output: T.Tensor):
        input, lhs, rhs = ctx.saved_tensors
        # Flatten input and output gradients over the leading dimensions.
        inp_size = lhs.shape[0]
        out_size = rhs.shape[1]
        input_shape = input.shape
        input = input.reshape(-1, inp_size)
        grad_output = grad_output.reshape(-1, out_size)

        input_grad = None
        if ctx.needs_input_grad[0]:
            input_grad = (grad_output @ rhs.T) @ lhs.T
            # Restore the original shape of the input gradient.
            input_grad = input_grad.reshape(input_shape)

        lhs_grad = None
        if ctx.needs_input_grad[1]:
            # In practice, for large models the embedding dimension is larger
            # than the batch size, so contract with rhs first.
            lhs_grad = input.T @ (grad_output @ rhs.T)

        rhs_grad = None
        if ctx.needs_input_grad[2]:
            # Again, the batch size is usually smaller than the embedding
            # dimension.
            rhs_grad = (input @ lhs).T @ grad_output

        bias_grad = None
        if ctx.needs_input_grad[3]:
            bias_grad = grad_output.sum(dim=0)

        return input_grad, lhs_grad, rhs_grad, bias_grad


compressed_linear_svd = SVDCompressedLinearFunc.apply
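

# Illustrative sanity check, not part of the original module: a minimal
# sketch that uses torch.autograd.gradcheck to compare the hand-written
# backward above against numerical gradients. The helper name and the sizes
# below are arbitrary assumptions chosen for the example.
def _gradcheck_compressed_linear() -> bool:
    """Run gradcheck on SVDCompressedLinearFunc with small random factors."""
    T.manual_seed(0)
    batch, in_features, out_features, rank = 3, 6, 5, 2
    # gradcheck wants double precision and requires_grad on the checked inputs.
    input = T.randn(batch, in_features, dtype=T.double, requires_grad=True)
    lhs = T.randn(in_features, rank, dtype=T.double, requires_grad=True)
    rhs = T.randn(rank, out_features, dtype=T.double, requires_grad=True)
    bias = T.randn(out_features, dtype=T.double, requires_grad=True)
    return T.autograd.gradcheck(compressed_linear_svd,
                                (input, lhs, rhs, bias))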


class SVDCompressedLinear(T.nn.Module):
    """SVDCompressedLinear is a layer that represents the weight matrix of a
    linear layer in factorized (truncated SVD) form.

    >>> linear_layer = T.nn.Linear(10, 20)
    >>> svd_layer = SVDCompressedLinear.from_linear(linear_layer, rank=5)
    """

    def __init__(self, factors: Tuple[T.Tensor, T.Tensor, T.Tensor],
                 bias: Optional[T.Tensor] = None):
        super().__init__()
        # We do not want to track singular values separately, so mix them into
        # the left and right factors.
        scale = T.sqrt(factors[1])
        # The given factors decompose W; store factors of W^T so that forward
        # can compute input @ W^T without transposing.
        self.lhs = T.nn.Parameter(factors[2].T * scale[None, :])
        self.rhs = T.nn.Parameter(factors[0].T * scale[:, None])

        self.bias = None
        if bias is not None:
            self.bias = T.nn.Parameter(bias)

        self.in_features = self.lhs.shape[0]
        self.out_features = self.rhs.shape[1]

    @classmethod
    def from_linear(cls, linear: T.nn.Linear, rank: Optional[int] = None,
                    tol: float = 1e-6):
        with T.no_grad():
            data = linear.weight.data
            lhs, vals, rhs = T.linalg.svd(data)
            if rank is None:
                # Choosing the rank from a tolerance is not implemented.
                raise NotImplementedError
            else:
                lhs = lhs[:, :rank]
                rhs = rhs[:rank, :]
                vals = vals[:rank]
            bias = None
            if linear.bias is not None:
                bias = T.clone(linear.bias.data)
        return SVDCompressedLinear((lhs, vals, rhs), bias)

    @classmethod
    def from_random(cls, in_features: int, out_features: int, rank: int,
                    bias: bool = True):
        lvecs = T.randn((out_features, rank))
        svals = T.ones(rank)
        rvecs = T.randn((rank, in_features))
        bias_term = None
        if bias:
            bias_term = T.randn(out_features)
        return SVDCompressedLinear((lvecs, svals, rvecs), bias_term)

    def forward(self, input: T.Tensor) -> T.Tensor:
        return compressed_linear_svd(input, self.lhs, self.rhs, self.bias)
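

# Illustrative usage, a sketch rather than part of the original file: compress
# a dense layer, compare parameter counts, and check how close the low-rank
# outputs stay to the dense ones. The layer sizes below are assumptions
# (a BART-base-like 768 -> 3072 projection); only the rank matches the commit.
if __name__ == '__main__':
    dense = T.nn.Linear(768, 3072)
    compressed = SVDCompressedLinear.from_linear(dense, rank=512)

    x = T.randn(4, 16, 768)  # (batch, sequence, features)
    with T.no_grad():
        error = T.linalg.norm(dense(x) - compressed(x)) / T.linalg.norm(dense(x))

    dense_params = sum(p.numel() for p in dense.parameters())
    compressed_params = sum(p.numel() for p in compressed.parameters())
    print(f'relative error: {error:.4f}')
    print(f'parameters: {dense_params} dense vs {compressed_params} compressed')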