# Llama-2-7b-hf-asvd95 / modeling_asvd_llama.py
# Uploaded by hahnyuan -- "Upload ASVDLlamaForCausalLM" (commit d0eeb98), 1.68 kB.
from transformers import LlamaForCausalLM
from .configuration_asvd_llama import ASVDLlamaConfig
import torch.nn as nn
class ASVDLinear(nn.Module):
    """Linear layer expressed as two stacked linear maps through a ``rank``-wide bottleneck.

    ``BLinear`` maps ``in_features -> rank`` without a bias; ``ALinear`` maps
    ``rank -> out_features`` and carries the (optional) bias.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        """Create the two factor layers.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            rank: Width of the intermediate representation between the factors.
            bias: Whether ``ALinear`` (the output factor) has a bias term.
        """
        super().__init__()
        # Creation order is kept (B first, then A): it fixes both the
        # state-dict key order and the order of random initialisation.
        self.BLinear = nn.Linear(
            in_features=in_features, out_features=rank, bias=False
        )
        self.ALinear = nn.Linear(
            in_features=rank, out_features=out_features, bias=bias
        )

    def forward(self, input):
        """Apply ``BLinear`` and then ``ALinear`` to ``input``."""
        # `input` shadows the builtin, but the name is part of the public
        # signature, so it is left as-is.
        bottleneck = self.BLinear(input)
        return self.ALinear(bottleneck)
class ASVDLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM whose selected ``nn.Linear`` layers are swapped for ``ASVDLinear``.

    ``config.truncation_ranks`` maps a fully qualified submodule name (as
    yielded by ``named_modules()``) to the rank of its replacement.
    Submodules whose names are not in the mapping are left untouched, and
    mapping entries that match no submodule are ignored.
    """

    config_class = ASVDLlamaConfig

    def __init__(self, config: ASVDLlamaConfig) -> None:
        """Build the dense Llama model, then replace the configured layers.

        Args:
            config: Model configuration; ``config.truncation_ranks`` selects
                the layers to replace and gives each replacement's rank.

        Raises:
            KeyError: If a name in ``truncation_ranks`` refers to a submodule
                that is not an ``nn.Linear``.
        """
        super().__init__(config)
        self.truncation_ranks = config.truncation_ranks
        # A missing (None) mapping means there is nothing to replace.
        ranks = self.truncation_ranks or {}

        # Snapshot the targets first so the module tree is never modified
        # while the ``named_modules()`` generator is still being consumed.
        targets = [
            (name, module)
            for name, module in self.named_modules()
            if name in ranks
        ]

        for name, module in targets:
            if not isinstance(module, nn.Linear):
                raise KeyError(
                    f"truncation_ranks entry {name!r} refers to a "
                    f"{type(module).__name__}, not an nn.Linear"
                )
            # "a.b.c" -> parent "a.b", attribute "c"; a top-level child has
            # an empty parent name, which get_submodule resolves to ``self``.
            parent_name, _, child_name = name.rpartition(".")
            parent = self.get_submodule(parent_name)
            # Only the shape and bias flag carry over; the replacement has
            # freshly initialised weights (presumably overwritten when a
            # checkpoint is loaded -- confirm against the loading path).
            replacement = ASVDLinear(
                module.in_features,
                module.out_features,
                ranks[name],
                bias=module.bias is not None,
            )
            setattr(parent, child_name, replacement)