# Copied from
#   rut5compressed/nn/functional.py
#   rut5compressed/nn/modules.py
# modules of the original repository.

from typing import Optional, Tuple

import torch as T


class SVDCompressedLinearFunc(T.autograd.Function):

    @staticmethod
    def forward(ctx, input: T.Tensor, lhs: T.Tensor,
                rhs: T.Tensor, bias: Optional[T.Tensor] = None) -> T.Tensor:
        # See PEP-0465 on matmul operator associativity.
        # https://peps.python.org/pep-0465/#precedence-and-associativity
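        # Left association costs O(B n r + B r m) multiply-adds for input of
        # shape (B, n), lhs of shape (n, r) and rhs of shape (r, m), while
        # materializing the full weight via input @ (lhs @ rhs) would cost
        # O(n r m + B n m).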
        output = (input @ lhs) @ rhs
        if bias is not None:
            output += bias[None, :]
        ctx.bias = bias is not None
        ctx.save_for_backward(input, lhs, rhs)
        return output

    @staticmethod
    def backward(ctx, grad_output: T.Tensor):
        input, lhs, rhs = ctx.saved_tensors

        # Flatten input and output gradients over the leading dimensions.
        inp_size = lhs.shape[0]
        out_size = rhs.shape[1]
        input_shape = input.shape
        input = input.reshape(-1, inp_size)
        grad_output = grad_output.reshape(-1, out_size)

        input_grad = None
        if ctx.needs_input_grad[0]:
            input_grad = (grad_output @ rhs.T) @ lhs.T

        lhs_grad = None
        if ctx.needs_input_grad[1]:
            # In practice, for large models the embedding dimension is
            # larger than the batch size, so this association is cheaper.
            lhs_grad = input.T @ (grad_output @ rhs.T)

        rhs_grad = None
        if ctx.needs_input_grad[2]:
            # Again, the batch size is usually smaller than the embedding
            # dimension.
            rhs_grad = (input @ lhs).T @ grad_output

        bias_grad = None
        if ctx.bias and ctx.needs_input_grad[3]:
            bias_grad = grad_output.sum(dim=0)

        # Restore the shape of the input gradient (it is None when the
        # input does not require gradients).
        if input_grad is not None:
            input_grad = input_grad.reshape(input_shape)
        return input_grad, lhs_grad, rhs_grad, bias_grad


compressed_linear_svd = SVDCompressedLinearFunc.apply
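
# A minimal sketch of how the custom backward can be verified against
# numerical gradients with torch.autograd.gradcheck (shapes below are
# illustrative, not prescribed by the original code):
#
#     x = T.randn(3, 10, dtype=T.double, requires_grad=True)
#     lhs = T.randn(10, 4, dtype=T.double, requires_grad=True)
#     rhs = T.randn(4, 20, dtype=T.double, requires_grad=True)
#     bias = T.randn(20, dtype=T.double, requires_grad=True)
#     assert T.autograd.gradcheck(compressed_linear_svd,
#                                 (x, lhs, rhs, bias))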


class SVDCompressedLinear(T.nn.Module):
    """Class SVDCompressedLinear is a layer which represents a weight matrix of
    lineaer layer in factorized view.

    >>> linear_layer = T.nn.Linear(10, 20)
    >>> svd_layer = SVDCompressedLinear.from_linear(linear_layer, rank=5)
    """

    def __init__(self, factors: Tuple[T.Tensor, T.Tensor, T.Tensor],
                 bias: Optional[T.Tensor] = None):
        super().__init__()

        # We do not want to track singular values separately, so we mix
        # them into the left and right factors.
        scale = T.sqrt(factors[1])

        # The factors (U, S, Vh) describe W = U @ diag(S) @ Vh of shape
        # (out_features, in_features), while forward computes x @ W^T, so
        # store lhs = Vh^T * sqrt(S) and rhs = sqrt(S) * U^T.
        self.lhs = T.nn.Parameter(factors[2].T * scale[None, :])
        self.rhs = T.nn.Parameter(factors[0].T * scale[:, None])

        self.bias = None
        if bias is not None:
            self.bias = T.nn.Parameter(bias)

        self.in_features = self.lhs.shape[0]
        self.out_features = self.rhs.shape[1]

    @classmethod
    def from_linear(cls, linear: T.nn.Linear, rank: Optional[int] = None,
                    tol: float = 1e-6):
        with T.no_grad():
            data = linear.weight.data
            lhs, vals, rhs = T.linalg.svd(data)
            if rank is None:
                # Truncation by the tolerance `tol` is not implemented yet.
                raise NotImplementedError
            else:
                lhs = lhs[:, :rank]
                rhs = rhs[:rank, :]
                vals = vals[:rank]

            bias = None
            if linear.bias is not None:
                bias = T.clone(linear.bias.data)

        return cls((lhs, vals, rhs), bias)

    @classmethod
    def from_random(cls, in_features: int, out_features: int, rank: int,
                    bias: bool = True):
        lvecs = T.randn((out_features, rank))
        svals = T.ones(rank)
        rvecs = T.randn((rank, in_features))
        bias_term = None
        if bias:
            bias_term = T.randn(out_features)
        return cls((lvecs, svals, rvecs), bias_term)

    def forward(self, input: T.Tensor) -> T.Tensor:
        return compressed_linear_svd(input, self.lhs, self.rhs, self.bias)
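

# A small usage sketch (not part of the original module): compress a dense
# linear layer, compare the factorized forward pass against the original,
# and report the reduction in parameter count. Sizes and rank are assumed
# for illustration.
if __name__ == '__main__':
    T.manual_seed(0)
    linear = T.nn.Linear(64, 32)
    svd = SVDCompressedLinear.from_linear(linear, rank=8)
    x = T.randn(4, 64)
    y_dense = linear(x)
    y_low_rank = svd(x)
    # Outputs differ by the discarded singular values; with rank equal to
    # min(in_features, out_features) they would match up to rounding.
    err = (y_low_rank - y_dense).norm() / y_dense.norm()
    n_dense = sum(p.numel() for p in linear.parameters())
    n_low_rank = sum(p.numel() for p in svd.parameters())
    print(f'relative error: {err.item():.3f}, '
          f'params: {n_dense} -> {n_low_rank}')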