# Copied from
#   rut5compressed/nn/functional.py
#   rut5compressed/nn/modules.py
# modules of the original repository.

from typing import Optional, Tuple

import torch as T


class SVDCompressedLinearFunc(T.autograd.Function):

    @staticmethod
    def forward(ctx, input: T.Tensor, lhs: T.Tensor, rhs: T.Tensor,
                bias: Optional[T.Tensor] = None) -> T.Tensor:
        # See PEP-0465 on matmul operator associativity.
        # https://peps.python.org/pep-0465/#precedence-and-associativity
        output = (input @ lhs) @ rhs
        if bias is not None:
            output += bias[None, :]
        ctx.bias = bias is not None
        ctx.save_for_backward(input, lhs, rhs)
        return output

    @staticmethod
    def backward(ctx, grad_output: T.Tensor):
        input, lhs, rhs = ctx.saved_tensors

        # Flatten input and output gradients over the leading dimensions.
        inp_size = lhs.shape[0]
        out_size = rhs.shape[1]
        input_shape = input.shape
        input = input.reshape(-1, inp_size)
        grad_output = grad_output.reshape(-1, out_size)

        input_grad = None
        if ctx.needs_input_grad[0]:
            input_grad = (grad_output @ rhs.T) @ lhs.T
            # Restore the original shape of the input gradients.
            input_grad = input_grad.reshape(input_shape)

        lhs_grad = None
        if ctx.needs_input_grad[1]:
            # In practice the embedding dimension of large models is larger
            # than the batch size, so contract over the batch dimension last.
            lhs_grad = input.T @ (grad_output @ rhs.T)

        rhs_grad = None
        if ctx.needs_input_grad[2]:
            # Again, the batch size is usually smaller than the embedding
            # dimension.
            rhs_grad = (input @ lhs).T @ grad_output

        bias_grad = None
        if ctx.needs_input_grad[3]:
            bias_grad = grad_output.sum(axis=0)

        return input_grad, lhs_grad, rhs_grad, bias_grad


compressed_linear_svd = SVDCompressedLinearFunc.apply


class SVDCompressedLinear(T.nn.Module):
    """Class SVDCompressedLinear is a layer which represents the weight matrix
    of a linear layer in factorized (SVD) form.

    >>> linear_layer = T.nn.Linear(10, 20)
    >>> svd_layer = SVDCompressedLinear.from_linear(linear_layer, rank=5)
    """

    def __init__(self, factors: Tuple[T.Tensor, T.Tensor, T.Tensor],
                 bias: Optional[T.Tensor] = None):
        super().__init__()

        # We do not want to track singular values separately, so mix them into
        # the left and right factors.
        scale = T.sqrt(factors[1])

        # The factors come from an SVD of W, but we store factors of W^T so
        # that the forward pass is a plain right-multiplication.
        self.lhs = T.nn.Parameter(factors[2].T * scale[None, :])
        self.rhs = T.nn.Parameter(factors[0].T * scale[:, None])

        self.bias = None
        if bias is not None:
            self.bias = T.nn.Parameter(bias)

        self.in_features = self.lhs.shape[0]
        self.out_features = self.rhs.shape[1]

    @classmethod
    def from_linear(cls, linear: T.nn.Linear, rank: Optional[int] = None,
                    tol: float = 1e-6):
        with T.no_grad():
            data = linear.weight.data
            lhs, vals, rhs = T.linalg.svd(data)
            if rank is None:
                # Rank selection by tolerance `tol` is not implemented yet.
                raise NotImplementedError
            else:
                lhs = lhs[:, :rank]
                rhs = rhs[:rank, :]
                vals = vals[:rank]

            bias = None
            if linear.bias is not None:
                bias = T.clone(linear.bias.data)

        return SVDCompressedLinear((lhs, vals, rhs), bias)

    @classmethod
    def from_random(cls, in_features: int, out_features: int, rank: int,
                    bias: bool = True):
        lvecs = T.randn((out_features, rank))
        svals = T.ones(rank)
        rvecs = T.randn((rank, in_features))
        bias_term = None
        if bias:
            bias_term = T.randn(out_features)
        return SVDCompressedLinear((lvecs, svals, rvecs), bias_term)

    def forward(self, input: T.Tensor) -> T.Tensor:
        return compressed_linear_svd(input, self.lhs, self.rhs, self.bias)
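

# The block below is a minimal self-check sketch, not part of the original
# module. It assumes the file is run as a script and checks that (i) a
# full-rank SVDCompressedLinear reproduces the dense T.nn.Linear it was built
# from and (ii) the custom backward matches numerical gradients via
# T.autograd.gradcheck. Sizes and tolerances are illustrative choices.
if __name__ == '__main__':
    T.manual_seed(0)

    # Full-rank factorization should reproduce the dense layer up to
    # floating-point error.
    dense = T.nn.Linear(10, 20)
    full = SVDCompressedLinear.from_linear(dense, rank=10)
    x = T.randn(4, 10)
    assert T.allclose(dense(x), full(x), atol=1e-4)

    # Gradient check of the custom autograd function in double precision.
    inp = T.randn(3, 10, dtype=T.double, requires_grad=True)
    lhs = T.randn(10, 5, dtype=T.double, requires_grad=True)
    rhs = T.randn(5, 20, dtype=T.double, requires_grad=True)
    bias = T.randn(20, dtype=T.double, requires_grad=True)
    assert T.autograd.gradcheck(compressed_linear_svd, (inp, lhs, rhs, bias))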