# house_diffusion/transformer_models.py
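"""Utilities for adapting a pretrained house_diffusion transformer for fine-tuning.

`TransformerModels` wraps an existing model instance and rewrites parts of it in
place: normalization layers, activation functions, extra encoder layers, dropout
rates, output heads, and cross-attention blocks.
"""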
import torch.nn as nn
from .transformer import TransformerModel, EncoderLayer
class TransformerModels:
    """Wrapper that applies architectural modifications to a trained model in place."""

    def __init__(self, model, device):
        self.model = model
        self.device = device

""" ------------------------------- 1) Normalize ------------------------------- """
def replace_InstanceNorm1d_LayerNorm(self):
self.freeze_unfreeze(True)
for name, layer in self.model.named_modules():
if isinstance(layer, nn.InstanceNorm1d):
num_features = layer.num_features
new_layer = nn.LayerNorm(normalized_shape=num_features).to(self.device)
parent_module = dict(self.model.named_modules())[name.rsplit('.', 1)[0]]
setattr(parent_module, name.split('.')[-1], new_layer)
return self.model

    def set_affine_true_for_instance_norm(self):
        """Replace each nn.InstanceNorm1d with an affine (learnable) variant."""
        self.freeze_unfreeze(True)
        modules = dict(self.model.named_modules())
        for name, layer in modules.items():
            if isinstance(layer, nn.InstanceNorm1d):
                # Use the layer's own feature count instead of a hard-coded value.
                new_layer = nn.InstanceNorm1d(num_features=layer.num_features,
                                              affine=True).to(self.device)
                parent_name, _, child_name = name.rpartition('.')
                setattr(modules[parent_name], child_name, new_layer)
        return self.model
""" ---------------------------------------------------------------------------- """
""" -------------------------- 2) Activation Function -------------------------- """
def replace_activation_function(self, activation):
self.freeze_unfreeze(True)
functions = {
"GELU": nn.GELU(),
"LeakyReLU": nn.LeakyReLU(),
"ELU": nn.ELU(),
"Mish": nn.Mish(),
# "ReLU": nn.ReLU(),
}
def replace_activation_in_module(module, activation_layer):
for name, child in module.named_children():
if isinstance(child, nn.ReLU):
setattr(module, name, activation_layer)
else:
replace_activation_in_module(child, activation_layer)
new_activation_layer = functions[activation].to(self.device)
replace_activation_in_module(self.model, new_activation_layer)
return self.model
""" ---------------------------------------------------------------------------- """
""" ---------------------------- 3) New Encoder Layers ------------------------- """
def add_encoder_layers(self, num_new_layers=2):
self.freeze_unfreeze(True)
new_encoder_layers = [EncoderLayer(512, 4, 0.1, nn.ReLU()).to(self.device) for _ in range(num_new_layers)]
for i, new_layer in enumerate(new_encoder_layers):
self.model.transformer_layers.insert(4 + i, new_layer.to(self.device))
return self.model
""" ---------------------------------------------------------------------------- """
""" -------------------------------- 4) Dropout -------------------------------- """
    def dropout_value_change(self, val=0.1):
        """Replace every nn.Dropout in the model with one using probability `val`."""
        self.freeze_unfreeze(True)

        def replace_dropouts_in_module(module, rate):
            # Recurse through the module tree and swap Dropout children in place.
            for name, child in module.named_children():
                if isinstance(child, nn.Dropout):
                    setattr(module, name, nn.Dropout(rate).to(self.device))
                else:
                    replace_dropouts_in_module(child, rate)

        replace_dropouts_in_module(self.model, val)
        return self.model
""" ---------------------------------------------------------------------------- """
""" ------------------------- 5) Output linear layers -------------------------- """
    def change_linear_output_layers(self):
        """Freeze the backbone and deepen the first two output heads with an extra
        trainable linear layer each."""
        output_layers_names = [
            "output_linear1",
            "output_linear2",
            "output_linear3",
            "output_linear_bin1",
            "output_linear_bin2",
            "output_linear_bin3",
        ]
        for name, param in self.model.named_parameters():
            # Only the listed output heads remain trainable.
            param.requires_grad = name.split(".")[0] in output_layers_names

        output_linear1 = self.model.output_linear1
        output_linear2 = self.model.output_linear2
        # Extra square layers that preserve each head's output width.
        output_linear11 = nn.Linear(output_linear1.out_features,
                                    output_linear1.out_features).to(self.device)
        output_linear21 = nn.Linear(output_linear2.out_features,
                                    output_linear2.out_features).to(self.device)
        # nn.Module has no insert(); chain the new layers behind the existing heads
        # instead. This assumes the model calls each head as a plain callable, so
        # wrapping it in nn.Sequential preserves the forward interface.
        self.model.output_linear1 = nn.Sequential(output_linear1, output_linear11)
        self.model.output_linear2 = nn.Sequential(output_linear2, output_linear21)
        return self.model
""" ---------------------------------------------------------------------------- """
""" ---------------------------- 6) Cross-Attention ---------------------------- """
def add_cross_attention(self, embed_dim=512, num_heads=8, dropout=0.1):
self.freeze_unfreeze(True)
for idx, layer in enumerate(self.model.transformer_layers):
cross_attn_layer = CrossAttentionLayer(embed_dim, num_heads, dropout).to(self.device)
layer.gen_attn = nn.Sequential(layer.gen_attn, cross_attn_layer).to(self.device)
return self.model
""" ---------------------------------------------------------------------------- """
""" -------------------------- 7) Residual Connections? ------------------------- """
""" ---------------------------------------------------------------------------- """
""" ------------------------------- 8) Attention Heads? (check if works with same params) ------------------------------- """
""" ---------------------------------------------------------------------------- """
#Add LayerNorm Before/After Attention
# ADAM ?
# weight decay ?
# learning rate?
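
    # A minimal sketch for the LayerNorm TODO above, not part of the original set
    # of modifications. It assumes, as add_cross_attention does, that each entry in
    # self.model.transformer_layers exposes its attention block as `gen_attn` and
    # that the block maps a single (..., embed_dim) tensor to one of the same shape.
    def add_layernorm_around_attention(self, embed_dim=512, position="after"):
        self.freeze_unfreeze(True)
        for layer in self.model.transformer_layers:
            norm = nn.LayerNorm(embed_dim).to(self.device)
            if position == "before":
                layer.gen_attn = nn.Sequential(norm, layer.gen_attn)
            else:
                layer.gen_attn = nn.Sequential(layer.gen_attn, norm)
        return self.model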

    def freeze_unfreeze(self, flag):
        """Set requires_grad on every model parameter (True = trainable)."""
        for param in self.model.parameters():
            param.requires_grad = flag

    def count_parameters(self):
        """Print and return the number of trainable and frozen parameters."""
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        untrainable_params = sum(p.numel() for p in self.model.parameters() if not p.requires_grad)
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Frozen parameters: {untrainable_params:,}")
        return trainable_params, untrainable_params

class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention with dropout, a residual connection, and LayerNorm."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        # batch_first defaults to False: inputs are (seq_len, batch, embed_dim).
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value=None, attn_mask=None):
        # With no separate key/value source (e.g. when chained inside nn.Sequential),
        # attend over the query itself.
        if key_value is None:
            key_value = query
        attn_output, _ = self.cross_attn(query, key_value, key_value, attn_mask=attn_mask)
        return self.norm(self.dropout(attn_output) + query)
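

# ---------------------------------------------------------------------------- #
# Minimal usage sketch (not part of the original file's API). The real
# house_diffusion TransformerModel is not constructed here because its
# constructor arguments are project-specific; `stand_in_model` below is a
# hypothetical toy module that only exists to exercise the wrapper's methods.
def _example_usage(device="cpu"):
    stand_in_model = nn.Sequential(
        nn.Linear(16, 100),
        nn.InstanceNorm1d(100),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(100, 16),
    )
    wrapper = TransformerModels(stand_in_model, device)
    wrapper.replace_InstanceNorm1d_LayerNorm()   # 1) InstanceNorm1d -> LayerNorm
    wrapper.replace_activation_function("GELU")  # 2) ReLU -> GELU
    wrapper.dropout_value_change(0.1)            # 4) lower the dropout probability
    return wrapper.count_parameters()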