|
"""This module uses parts of rut5compressed. It shares the same module |
|
structure as the model used in neural network compression experiments with
rut5compressed.
|
""" |
|
|
|
from functools import partial |
|
from typing import Optional |
|
|
|
import torch as T |
|
from transformers import BartForConditionalGeneration |
|
|
|
from .configuration_bart import SVDCompressedBartConfig |
|
from .modules import SVDCompressedLinear |
|
from .util import compress_linear_svd, map_module |
|
|
|
|
|
class SVDCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """BART conditional-generation model with SVD-compressed linear layers.

    On construction the wrapped ``self.model`` is passed through
    ``map_module`` together with the ``LAYERS`` pattern, so that the selected
    sub-modules are either compressed from their current weights or swapped
    for freshly created ``SVDCompressedLinear`` placeholders.
    """

    # Pattern naming the ``fc1``/``fc2`` sub-modules of every encoder and
    # decoder layer. Presumably matched by ``map_module`` against module
    # paths -- confirm against ``util.map_module``.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = SVDCompressedBartConfig

    def __init__(self, config: SVDCompressedBartConfig,
                 rank: Optional[int] = None,
                 compress: bool = False):
        """Build the base model and rewrite its target layers.

        Args:
            config: Model configuration; also supplies the default rank.
            rank: Rank used for the compressed layers. When it is ``None``
                (or any other falsy value) ``config.rank`` is used instead.
            compress: If true, the selected modules are handed to
                ``compress_linear_svd`` with the chosen rank. Otherwise they
                are handed to :meth:`convert`.
        """
        super().__init__(config)
        self.rank = rank or config.rank

        if compress:
            transform = partial(compress_linear_svd, rank=self.rank)
        else:
            transform = self.convert
        self.model = map_module(self.model, transform, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Swap a linear layer for an ``SVDCompressedLinear`` of equal shape.

        Args:
            module: Candidate module to replace.
            path: Not used here; accepted because this method is passed to
                ``map_module`` as a callback (presumably it receives the
                module path -- confirm against ``util.map_module``).

        Returns:
            The result of ``SVDCompressedLinear.from_random`` called with the
            input/output feature counts of ``module`` and ``self.rank`` when
            ``module`` is a ``torch.nn.Linear``; otherwise ``module`` itself,
            untouched.
        """
        if isinstance(module, T.nn.Linear):
            return SVDCompressedLinear.from_random(module.in_features,
                                                   module.out_features,
                                                   self.rank)
        return module
|
|
|
|
|
# Register the model class under the 'AutoModelForSeq2SeqLM' auto class.
SVDCompressedBartForConditionGeneration.register_for_auto_class(
    'AutoModelForSeq2SeqLM',
)
|
|