bart-base-detox-svd / modeling_bart.py
not-found's picture
Add SVD-compressed model with rank 512
db45d00
"""This module uses parts of rut5compressed. It shares the same module
structure as model used in neural network compression experiments with
rut5compressed.
"""
from functools import partial
from typing import Optional
import torch as T
from transformers import BartForConditionalGeneration
from .configuration_bart import SVDCompressedBartConfig
from .modules import SVDCompressedLinear
from .util import compress_linear_svd, map_module
class SVDCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """BART model whose feed-forward linear layers are SVD-compressed.

    After the regular ``BartForConditionalGeneration`` is constructed, every
    submodule of ``self.model`` whose path matches ``LAYERS`` (the ``fc1`` /
    ``fc2`` projections of each encoder and decoder layer) is passed through
    ``map_module`` and replaced by a low-rank counterpart.

    Args:
        config: Model configuration. ``config.rank`` is the default rank of
            the low-rank factorization.
        rank: Optional rank override. When given, it is also written back
            to ``self.config.rank`` so that a saved config records the rank
            the layers were actually built with.
        compress: If true, matched modules are passed to
            ``compress_linear_svd`` with the chosen rank. Otherwise, matched
            ``torch.nn.Linear`` modules are replaced by randomly initialized
            ``SVDCompressedLinear`` modules (see ``convert``), presumably so
            that already-compressed weights can be loaded into them.
    """

    # Module paths (as matched by ``map_module`` against ``self.model``) of
    # the feed-forward projections fc1/fc2 in every encoder/decoder layer.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = SVDCompressedBartConfig

    def __init__(self, config: SVDCompressedBartConfig,
                 rank: Optional[int] = None,
                 compress: bool = False):
        super().__init__(config)
        # Fall back to the config only when no rank was supplied at all.
        self.rank = rank if rank is not None else config.rank
        # Keep the config in sync with the rank used to build the layers.
        # Otherwise save_pretrained() would persist a stale config.rank, and
        # a later reload (compress=False -> convert) would create factors of
        # a different shape than the saved weights.
        self.config.rank = self.rank

        if compress:
            compress_fn = partial(compress_linear_svd, rank=self.rank)
        else:
            compress_fn = self.convert
        self.model = map_module(self.model, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Swap a matched linear layer for a random SVD-compressed one.

        Used as the ``map_module`` callback when ``compress`` is false.

        Args:
            module: Module found at a path matching ``LAYERS``.
            path: Path of ``module``. It is unused here and presumably part
                of the callback signature that ``map_module`` calls with.

        Returns:
            ``module`` itself if it is not a ``torch.nn.Linear``. Otherwise,
            a new ``SVDCompressedLinear`` built by ``from_random`` with the
            same input/output sizes and rank ``self.rank``. The original
            weights are not copied.
        """
        if not isinstance(module, T.nn.Linear):
            return module
        # NOTE(review): the original layer's bias setting is not forwarded;
        # confirm SVDCompressedLinear.from_random handles bias as intended.
        return SVDCompressedLinear.from_random(module.in_features,
                                               module.out_features,
                                               self.rank)
# Register the model class with the transformers Auto API under the
# AutoModelForSeq2SeqLM entry point.
SVDCompressedBartForConditionGeneration.register_for_auto_class(
    'AutoModelForSeq2SeqLM')