File size: 1,679 Bytes
d0eeb98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from transformers import LlamaForCausalLM
from .configuration_asvd_llama import ASVDLlamaConfig
import torch.nn as nn
class ASVDLinear(nn.Module):
    """Low-rank replacement for an ``nn.Linear`` layer.

    The input is first mapped down to ``rank`` features by ``BLinear``
    (which has no bias). It is then mapped up to ``out_features`` by
    ``ALinear``, which carries the optional bias. The layer therefore
    computes ``ALinear(BLinear(x))``.

    The submodule attribute names ``BLinear`` / ``ALinear`` become the
    prefixes of this layer's ``state_dict()`` keys.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        # Down-projection: in_features -> rank, bias-free.
        self.BLinear = nn.Linear(
            in_features=in_features, out_features=rank, bias=False
        )
        # Up-projection: rank -> out_features; the only factor that may hold a bias.
        self.ALinear = nn.Linear(
            in_features=rank, out_features=out_features, bias=bias
        )

    def forward(self, input):
        """Project ``input`` through the rank bottleneck and back out."""
        bottleneck = self.BLinear(input)
        return self.ALinear(bottleneck)
class ASVDLlamaForCausalLM(LlamaForCausalLM):
    """LLaMA causal LM with selected ``nn.Linear`` layers swapped for
    low-rank :class:`ASVDLinear` modules.

    ``config.truncation_ranks`` maps a module name to a factorization rank.
    The name is the dotted path reported by ``named_modules()``. Each named
    ``nn.Linear`` is replaced by an ``ASVDLinear`` that has the same
    in/out features and bias presence and the given rank.

    Names that do not occur in the model are ignored. The new factors start
    with ``nn.Linear``'s default initialization. They presumably receive
    real values when an ASVD checkpoint is loaded afterwards (TODO confirm
    with the loading code).
    """

    config_class = ASVDLlamaConfig

    def __init__(self, config: ASVDLlamaConfig):
        """Build the base model, then replace every linear layer listed in
        ``config.truncation_ranks``.

        Raises:
            ValueError: if a configured name refers to a module that is not
                an ``nn.Linear``.
        """
        super().__init__(config)
        self.truncation_ranks = config.truncation_ranks
        # A missing mapping means no layer is truncated.
        ranks = self.truncation_ranks or {}

        # Snapshot the targets first. Replacing children while the
        # named_modules() generator is still walking the tree would mutate
        # the structure mid-traversal.
        targets = [
            (name, module)
            for name, module in self.named_modules()
            if name in ranks
        ]

        for name, module in targets:
            if not isinstance(module, nn.Linear):
                raise ValueError(
                    f"truncation_ranks entry {name!r} refers to a "
                    f"{type(module).__name__}, not an nn.Linear"
                )
            # Resolve the owner from the dotted path so the replacement lands
            # exactly at the configured name. get_submodule("") returns self,
            # which covers top-level children.
            parent_name, _, child_name = name.rpartition(".")
            parent = self.get_submodule(parent_name)
            new_layer = ASVDLinear(
                module.in_features,
                module.out_features,
                ranks[name],
                bias=module.bias is not None,
            )
            setattr(parent, child_name, new_layer)
|