import torch.nn as nn

from .transformer import TransformerModel, EncoderLayer


# class TransformerModels(nn.Module):
class TransformerModels:
    """Wrapper that applies fine-tuning modifications (normalization, activations,
    extra encoder layers, dropout, output heads, cross-attention) to a pretrained model."""

    def __init__(self, model, device):
        self.model = model
        self.device = device

""" ------------------------------- 1) Normalize ------------------------------- """ | |
def replace_InstanceNorm1d_LayerNorm(self): | |
self.freeze_unfreeze(True) | |
for name, layer in self.model.named_modules(): | |
if isinstance(layer, nn.InstanceNorm1d): | |
num_features = layer.num_features | |
new_layer = nn.LayerNorm(normalized_shape=num_features).to(self.device) | |
parent_module = dict(self.model.named_modules())[name.rsplit('.', 1)[0]] | |
setattr(parent_module, name.split('.')[-1], new_layer) | |
return self.model | |
    def set_affine_true_for_instance_norm(self):
        """Rebuild every nn.InstanceNorm1d with affine=True so it gains learnable scale/shift."""
        self.freeze_unfreeze(True)
        for name, layer in self.model.named_modules():
            if isinstance(layer, nn.InstanceNorm1d):
                # Preserve the layer's own width instead of a hard-coded num_features
                new_layer = nn.InstanceNorm1d(num_features=layer.num_features, affine=True).to(self.device)
                parent_name, _, child_name = name.rpartition('.')
                parent_module = dict(self.model.named_modules())[parent_name]
                setattr(parent_module, child_name, new_layer)
        return self.model
    """ ---------------------------------------------------------------------------- """
""" -------------------------- 2) Activation Function -------------------------- """ | |
def replace_activation_function(self, activation): | |
self.freeze_unfreeze(True) | |
functions = { | |
"GELU": nn.GELU(), | |
"LeakyReLU": nn.LeakyReLU(), | |
"ELU": nn.ELU(), | |
"Mish": nn.Mish(), | |
# "ReLU": nn.ReLU(), | |
} | |
def replace_activation_in_module(module, activation_layer): | |
for name, child in module.named_children(): | |
if isinstance(child, nn.ReLU): | |
setattr(module, name, activation_layer) | |
else: | |
replace_activation_in_module(child, activation_layer) | |
new_activation_layer = functions[activation].to(self.device) | |
replace_activation_in_module(self.model, new_activation_layer) | |
return self.model | |
""" ---------------------------------------------------------------------------- """ | |
""" ---------------------------- 3) New Encoder Layers ------------------------- """ | |
def add_encoder_layers(self, num_new_layers=2): | |
self.freeze_unfreeze(True) | |
new_encoder_layers = [EncoderLayer(512, 4, 0.1, nn.ReLU()).to(self.device) for _ in range(num_new_layers)] | |
for i, new_layer in enumerate(new_encoder_layers): | |
self.model.transformer_layers.insert(4 + i, new_layer.to(self.device)) | |
return self.model | |
""" ---------------------------------------------------------------------------- """ | |
""" -------------------------------- 4) Dropout -------------------------------- """ | |
# def dropout_value_change(self, val=0.1): | |
# self.freeze_unfreeze(True) | |
# for layer in self.model.modules(): | |
# if isinstance(layer, nn.Dropout): | |
# layer.p = val | |
# | |
# return self.model | |
def dropout_value_change(self, val=0.1): | |
self.freeze_unfreeze(True) | |
def replace_dropouts_in_module(module, rate): | |
for name, child in module.named_children(): | |
if isinstance(child, nn.Dropout): | |
setattr(module, name, nn.Dropout(rate).to(self.device)) | |
else: | |
replace_dropouts_in_module(child, rate) | |
replace_dropouts_in_module(self.model, val) | |
return self.model | |
""" ---------------------------------------------------------------------------- """ | |
""" ------------------------- 5) Output linear layers -------------------------- """ | |
def change_linear_output_layers(self): | |
output_layers_names = [ | |
"output_linear1", | |
"output_linear2", | |
"output_linear3", | |
"output_linear_bin1", | |
"output_linear_bin2", | |
"output_linear_bin3", | |
] | |
for name, param in self.model.named_parameters(): | |
param.requires_grad = False | |
if name.split(".")[0] in output_layers_names: | |
param.requires_grad = True | |
output_linear1 = self.model.output_linear1 | |
output_linear2 = self.model.output_linear2 | |
output_linear3 = self.model.output_linear3 | |
output_linear_bin1 = self.model.output_linear_bin1 | |
output_linear_bin2 = self.model.output_linear_bin2 | |
output_linear_bin3 = self.model.output_linear_bin3 | |
output_linear11 = nn.Linear(output_linear1.out_features, | |
output_linear1.out_features).to(self.device) | |
output_linear21 = nn.Linear(output_linear2.out_features, | |
output_linear2.out_features).to(self.device) | |
# self.model.output_layers = nn.Sequential( | |
# output_linear1, | |
# output_linear11, | |
# output_linear2, | |
# output_linear21, | |
# output_linear3, | |
# output_linear_bin1, | |
# output_linear_bin2, | |
# output_linear_bin3, | |
# ) | |
self.model.insert(6, output_linear11) | |
self.model.insert(8, output_linear21) | |
return self.model | |
    # def change_linear_output_layers(self):
    #     output_layers_names = [
    #         "output_linear1",
    #         "output_linear2",
    #         "output_linear3",
    #         "output_linear_bin1",
    #         "output_linear_bin2",
    #         "output_linear_bin3",
    #     ]
    #     for name, param in self.model.named_parameters():
    #         param.requires_grad = False
    #         if name.split(".")[0] in output_layers_names:
    #             param.requires_grad = True
    #
    #     output_linear1 = self.model.output_linear1
    #     output_linear2 = self.model.output_linear2
    #     output_linear3 = self.model.output_linear3
    #     # output_linear_bin1 = self.model.output_linear_bin1
    #     # output_linear_bin2 = self.model.output_linear_bin2
    #     # output_linear_bin3 = self.model.output_linear_bin3
    #
    #     output_linear11 = nn.Linear(output_linear1.out_features,
    #                                 output_linear1.out_features).to(self.device)
    #     output_linear21 = nn.Linear(output_linear2.out_features,
    #                                 output_linear2.out_features).to(self.device)
    #
    #     self.model.output_linear1.append(output_linear11.to(self.device))
    #     self.model.output_linear2.append(output_linear21.to(self.device))
    #
    #     # self.model.output_layers = nn.Sequential(
    #     #     output_linear1,
    #     #     output_linear11,
    #     #     output_linear2,
    #     #     output_linear21,
    #     #     output_linear3,
    #     #     output_linear_bin1,
    #     #     output_linear_bin2,
    #     #     output_linear_bin3,
    #     # )
    #
    #     return self.model
    """ ---------------------------------------------------------------------------- """
""" ---------------------------- 6) Cross-Attention ---------------------------- """ | |
def add_cross_attention(self, embed_dim=512, num_heads=8, dropout=0.1): | |
self.freeze_unfreeze(True) | |
for idx, layer in enumerate(self.model.transformer_layers): | |
cross_attn_layer = CrossAttentionLayer(embed_dim, num_heads, dropout).to(self.device) | |
layer.gen_attn = nn.Sequential(layer.gen_attn, cross_attn_layer).to(self.device) | |
return self.model | |
""" ---------------------------------------------------------------------------- """ | |
""" -------------------------- 7) Residual Connections? ------------------------- """ | |
""" ---------------------------------------------------------------------------- """ | |
""" ------------------------------- 8) Attention Heads? (check if works with same params) ------------------------------- """ | |
""" ---------------------------------------------------------------------------- """ | |
#Add LayerNorm Before/After Attention | |
# ADAM ? | |
# weight decay ? | |
# learning rate? | |
def freeze_unfreeze(self, flag): | |
for param in self.model.parameters(): | |
param.requires_grad = flag | |
def count_parameters(self): | |
model = self.model | |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
untrainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) | |
print(f"Trainable parameters: {trainable_params}") | |
print(f"Untrainable parameters: {untrainable_params}") | |
return trainable_params, untrainable_params | |

class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention with a residual connection and post-LayerNorm."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value=None, attn_mask=None):
        # key_value defaults to query so the layer also works when chained inside
        # nn.Sequential (see add_cross_attention), where only one tensor is passed along.
        if key_value is None:
            key_value = query
        attn_output, _ = self.cross_attn(query, key_value, key_value, attn_mask=attn_mask)
        return self.norm(self.dropout(attn_output) + query)
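

# ------------------------------------------------------------------------------
# Minimal smoke test (sketch). The real TransformerModel lives in .transformer and
# its constructor signature is not shown here, so a small stand-in module with the
# same kinds of sub-layers (InstanceNorm1d, ReLU, Dropout, Linear) is used purely
# to exercise the helpers above; all dimensions below are arbitrary assumptions.
# Because of the relative import at the top, run this as a module (python -m ...).
if __name__ == "__main__":
    import torch

    class _DummyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.norm = nn.InstanceNorm1d(64)
            self.act = nn.ReLU()
            self.drop = nn.Dropout(0.5)
            self.proj = nn.Linear(64, 64)

        def forward(self, x):
            return self.proj(self.drop(self.act(self.norm(x))))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    wrapper = TransformerModels(_DummyBlock().to(device), device)

    wrapper.replace_InstanceNorm1d_LayerNorm()   # InstanceNorm1d -> LayerNorm
    wrapper.replace_activation_function("GELU")  # ReLU -> GELU
    wrapper.dropout_value_change(0.2)            # Dropout(p=0.5) -> Dropout(p=0.2)
    wrapper.count_parameters()

    # Stand-alone check of the cross-attention block; inputs are (seq_len, batch, embed_dim)
    # because nn.MultiheadAttention defaults to batch_first=False.
    attn = CrossAttentionLayer(embed_dim=32, num_heads=4).to(device)
    q = torch.randn(10, 2, 32, device=device)
    kv = torch.randn(10, 2, 32, device=device)
    print(attn(q, kv).shape)  # torch.Size([10, 2, 32])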