File size: 1,554 Bytes
db45d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""This module uses parts of rut5compressed. It shares the same module
structure as the model used in the neural network compression experiments
with rut5compressed.
"""

from functools import partial
from typing import Optional

import torch as T
from transformers import BartForConditionalGeneration

from .configuration_bart import SVDCompressedBartConfig
from .modules import SVDCompressedLinear
from .util import compress_linear_svd, map_module


class SVDCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """BART conditional-generation model with SVD-compressed linear layers.

    On construction, the modules of the wrapped ``self.model`` selected by
    :attr:`LAYERS` are passed through a per-module transform (via
    ``map_module``) and ``self.model`` is replaced with the result.
    """

    # Pattern handed to ``map_module`` to select the modules to transform.
    # It matches strings such as ``/encoder/layers/0/fc1`` or
    # ``/decoder/layers/11/fc2``.
    # NOTE(review): presumably matched against slash-separated module paths;
    # confirm against ``map_module``.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = SVDCompressedBartConfig

    def __init__(self, config: SVDCompressedBartConfig,
                 rank: Optional[int] = None,
                 compress: bool = False):
        """Build the base BART model and transform the selected layers.

        :param config: Model configuration; ``config.rank`` is used when
            ``rank`` is not given.
        :param rank: Rank forwarded to the layer transform. Any falsy value
            (``None`` or ``0``) falls back to ``config.rank``.
        :param compress: If true, the selected modules are transformed with
            ``compress_linear_svd(..., rank=self.rank)``; otherwise they are
            transformed with :meth:`convert`.
        """
        super().__init__(config)
        self.rank = rank or config.rank

        if compress:
            # NOTE(review): presumably factorises the existing weights;
            # confirm against ``compress_linear_svd``.
            transform = partial(compress_linear_svd, rank=self.rank)
        else:
            transform = self.convert
        self.model = map_module(self.model, transform, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Swap a ``torch.nn.Linear`` for an ``SVDCompressedLinear``.

        :param module: Module selected for transformation.
        :param path: Location of ``module`` in the model; unused here.
        :return: ``SVDCompressedLinear.from_random(in_features,
            out_features, self.rank)`` for a linear module; any other module
            is returned unchanged.
        """
        if isinstance(module, T.nn.Linear):
            return SVDCompressedLinear.from_random(
                module.in_features, module.out_features, self.rank)
        return module


# Register the model class for the 'AutoModelForSeq2SeqLM' auto class.
SVDCompressedBartForConditionGeneration.register_for_auto_class(
    'AutoModelForSeq2SeqLM')