"""This module uses parts of rut5compressed. It shares the same module structure as model used in neural network compression experiments with rut5compressed. """ from functools import partial from typing import Optional import torch as T from transformers import BartForConditionalGeneration from .configuration_bart import SVDCompressedBartConfig from .modules import SVDCompressedLinear from .util import compress_linear_svd, map_module class SVDCompressedBartForConditionGeneration(BartForConditionalGeneration): """Class SVDCompressedBartForConditionGeneration defines a BART-based model with compressed linear layers with SVD. """ LAYERS = r'/(de|en)coder/layers/\d+/fc[12]' config_class = SVDCompressedBartConfig def __init__(self, config: SVDCompressedBartConfig, rank: Optional[int] = None, compress: bool = False): super().__init__(config) self.rank = rank or config.rank compress_fn = partial(compress_linear_svd, rank=self.rank) if not compress: compress_fn = self.convert self.model = map_module(self.model, compress_fn, self.LAYERS) def convert(self, module: T.nn.Module, path: str) -> T.nn.Module: if not isinstance(module, T.nn.Linear): return module return SVDCompressedLinear.from_random(module.in_features, module.out_features, self.rank) SVDCompressedBartForConditionGeneration \ .register_for_auto_class('AutoModelForSeq2SeqLM')