import torch.nn as nn

from .transformer import TransformerModel, EncoderLayer


# class TransformerModels(nn.Module):
class TransformerModels:
    def __init__(self, model, device):
        self.model = model
        self.device = device

    """ ------------------------------- 1) Normalize ------------------------------- """
    def replace_InstanceNorm1d_LayerNorm(self):
        self.freeze_unfreeze(True)
        # Snapshot the module list first so we do not mutate the model while iterating it.
        for name, layer in list(self.model.named_modules()):
            if isinstance(layer, nn.InstanceNorm1d):
                num_features = layer.num_features
                new_layer = nn.LayerNorm(normalized_shape=num_features).to(self.device)
                # Top-level modules have no dot in their name; '' maps to the root model.
                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                parent_module = dict(self.model.named_modules())[parent_name]
                setattr(parent_module, name.split('.')[-1], new_layer)
        return self.model

    def set_affine_true_for_instance_norm(self):
        self.freeze_unfreeze(True)
        for name, layer in list(self.model.named_modules()):
            if isinstance(layer, nn.InstanceNorm1d):
                # Keep the layer's original feature count instead of a hard-coded value.
                new_layer = nn.InstanceNorm1d(num_features=layer.num_features, affine=True).to(self.device)
                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                parent_module = dict(self.model.named_modules())[parent_name]
                setattr(parent_module, name.split('.')[-1], new_layer)
        return self.model
    """ ---------------------------------------------------------------------------- """

    """ -------------------------- 2) Activation Function -------------------------- """
    def replace_activation_function(self, activation):
        self.freeze_unfreeze(True)
        functions = {
            "GELU": nn.GELU(),
            "LeakyReLU": nn.LeakyReLU(),
            "ELU": nn.ELU(),
            "Mish": nn.Mish(),
            # "ReLU": nn.ReLU(),
        }

        def replace_activation_in_module(module, activation_layer):
            # Recursively swap every nn.ReLU for the requested activation.
            for name, child in module.named_children():
                if isinstance(child, nn.ReLU):
                    setattr(module, name, activation_layer)
                else:
                    replace_activation_in_module(child, activation_layer)

        new_activation_layer = functions[activation].to(self.device)
        replace_activation_in_module(self.model, new_activation_layer)
        return self.model
    """ ---------------------------------------------------------------------------- """

    """ ---------------------------- 3) New Encoder Layers ------------------------- """
    def add_encoder_layers(self, num_new_layers=2):
        self.freeze_unfreeze(True)
        new_encoder_layers = [EncoderLayer(512, 4, 0.1, nn.ReLU()).to(self.device)
                              for _ in range(num_new_layers)]
        for i, new_layer in enumerate(new_encoder_layers):
            self.model.transformer_layers.insert(4 + i, new_layer)
        return self.model
    """ ---------------------------------------------------------------------------- """

    """ -------------------------------- 4) Dropout -------------------------------- """
    # Alternative: mutate `layer.p` in place on every existing nn.Dropout instead of
    # replacing the modules.
    def dropout_value_change(self, val=0.1):
        self.freeze_unfreeze(True)

        def replace_dropouts_in_module(module, rate):
            for name, child in module.named_children():
                if isinstance(child, nn.Dropout):
                    setattr(module, name, nn.Dropout(rate).to(self.device))
                else:
                    replace_dropouts_in_module(child, rate)

        replace_dropouts_in_module(self.model, val)
        return self.model
    """ ---------------------------------------------------------------------------- """

    """ ------------------------- 5) Output linear layers -------------------------- """
    def change_linear_output_layers(self):
        output_layers_names = [
            "output_linear1",
            "output_linear2",
            "output_linear3",
            "output_linear_bin1",
            "output_linear_bin2",
            "output_linear_bin3",
        ]
        # Freeze everything except the output heads.
        for name, param in self.model.named_parameters():
            param.requires_grad = False
            if name.split(".")[0] in output_layers_names:
                param.requires_grad = True
        output_linear1 = self.model.output_linear1
        output_linear2 = self.model.output_linear2
        output_linear3 = self.model.output_linear3
        output_linear_bin1 = self.model.output_linear_bin1
        output_linear_bin2 = self.model.output_linear_bin2
        output_linear_bin3 = self.model.output_linear_bin3

        output_linear11 = nn.Linear(output_linear1.out_features,
                                    output_linear1.out_features).to(self.device)
        output_linear21 = nn.Linear(output_linear2.out_features,
                                    output_linear2.out_features).to(self.device)

        # A plain nn.Module has no insert()/append(), so stack the new layers behind
        # the existing heads by wrapping each head in an nn.Sequential.
        self.model.output_linear1 = nn.Sequential(output_linear1, output_linear11)
        self.model.output_linear2 = nn.Sequential(output_linear2, output_linear21)

        return self.model
    """ ---------------------------------------------------------------------------- """

    """ ---------------------------- 6) Cross-Attention ---------------------------- """
    def add_cross_attention(self, embed_dim=512, num_heads=8, dropout=0.1):
        self.freeze_unfreeze(True)
        for idx, layer in enumerate(self.model.transformer_layers):
            cross_attn_layer = CrossAttentionLayer(embed_dim, num_heads, dropout).to(self.device)
            # NOTE: nn.Sequential passes a single tensor between stages, so the wrapped
            # CrossAttentionLayer runs as self-attention over the output of gen_attn
            # (see its forward below).
            layer.gen_attn = nn.Sequential(layer.gen_attn, cross_attn_layer).to(self.device)
        return self.model
    """ ---------------------------------------------------------------------------- """

    """ -------------------------- 7) Residual Connections? ------------------------ """
    # See the add_residual_layernorm_around_attention sketch below.
    """ ---------------------------------------------------------------------------- """

    """ ------------- 8) Attention Heads? (check if works with same params) -------- """
    """ ---------------------------------------------------------------------------- """

    # Add LayerNorm before/after attention -> see the sketch below.
    # ADAM? weight decay? learning rate? -> see build_finetune_optimizer at the end
    # of this file.
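    """ ------------ Residual + LayerNorm around attention (hedged sketch) --------- """
    # Hedged sketch for the "Residual Connections?" and "Add LayerNorm Before/After
    # Attention" notes above, not a confirmed part of this project. It assumes each
    # entry in self.model.transformer_layers exposes its attention sub-module as
    # `gen_attn` (as add_cross_attention does) and that gen_attn maps a tensor to a
    # tensor of the same shape; adapt to the real EncoderLayer interface before use.
    def add_residual_layernorm_around_attention(self, embed_dim=512, dropout=0.1):
        self.freeze_unfreeze(True)

        class _PreNormResidual(nn.Module):
            # Pre-norm residual block: x + Dropout(sublayer(LayerNorm(x))).
            def __init__(self, sublayer, dim, p):
                super().__init__()
                self.sublayer = sublayer
                self.norm = nn.LayerNorm(dim)
                self.dropout = nn.Dropout(p)

            def forward(self, x, *args, **kwargs):
                return x + self.dropout(self.sublayer(self.norm(x), *args, **kwargs))

        for layer in self.model.transformer_layers:
            layer.gen_attn = _PreNormResidual(layer.gen_attn, embed_dim, dropout).to(self.device)
        return self.model
    """ ---------------------------------------------------------------------------- """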
    def freeze_unfreeze(self, flag):
        # Set requires_grad on every parameter: True unfreezes, False freezes.
        for param in self.model.parameters():
            param.requires_grad = flag

    def count_parameters(self):
        model = self.model
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        untrainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
        print(f"Trainable parameters: {trainable_params}")
        print(f"Untrainable parameters: {untrainable_params}")
        return trainable_params, untrainable_params


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(CrossAttentionLayer, self).__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value=None, attn_mask=None):
        # Fall back to self-attention when no separate key/value tensor is given,
        # e.g. when this layer is chained inside an nn.Sequential (see
        # add_cross_attention above).
        if key_value is None:
            key_value = query
        attn_output, _ = self.cross_attn(query, key_value, key_value, attn_mask=attn_mask)
        return self.norm(self.dropout(attn_output) + query)
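

# Hedged sketch for the "ADAM? / weight decay? / learning rate?" note above: build an
# AdamW optimizer over only the parameters left trainable by TransformerModels. The
# default lr and weight_decay are illustrative placeholders, not values taken from
# this project.
def build_finetune_optimizer(model, lr=1e-4, weight_decay=0.01):
    import torch  # local import to keep the sketch self-contained

    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)

# Example usage (hypothetical names):
#   wrapper = TransformerModels(model, device)
#   model = wrapper.dropout_value_change(0.2)
#   optimizer = build_finetune_optimizer(model, lr=3e-5, weight_decay=0.01)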