# 57894 / models /generator.py
# Muhammad Naufal Rizqullah
# Experiment 2
# e61c431
import torch
import torch.nn as nn
from .base import WSConv2d, ConvBlock, PixelNorm
from config.core import config
class Generator(nn.Module):
    """Label-conditioned generator built from a fixed stack of seven blocks.

    The forward pass embeds the class labels, projects the embedding with a
    linear layer, concatenates it with the noise along the channel axis and
    feeds the result through ``self.gen``.

    Args:
        embed_size: Width of the label embedding (and of its linear projection).
        num_classes: Number of rows in the label embedding table.
        image_size: Stored on the instance; the visible code does not use it
            to size the network.
        features_generator: Base channel width; blocks use 1x, 2x and 4x of it.
        input_dim: Number of channels of the noise input.
        image_channel: Number of output channels of the final block.

    NOTE(review): assuming ``ConvBlock`` and ``WSConv2d`` keep the spatial
    size, a 1x1 input becomes 4x4 after the first block and is doubled by the
    next five blocks (4 -> 128) -- TODO confirm against ``.base``.
    """

    def __init__(self, embed_size=128, num_classes=3, image_size=128, features_generator=128, input_dim=128, image_channel=3):
        super().__init__()

        base = features_generator
        # Channel widths at the boundaries of the intermediate blocks.
        widths = [base * 2, base * 4, base * 4, base * 4, base * 2, base]

        # Entry block: the only one using the stride-1 transposed convolution.
        stages = [self._block(input_dim + embed_size, widths[0], first_double_up=True)]
        # Intermediate blocks, each ending in a stride-2 transposed convolution.
        for width_in, width_out in zip(widths, widths[1:]):
            stages.append(
                self._block(width_in, width_out, first_double_up=False, final_layer=False)
            )
        # Output block: weight-scaled conv + Tanh, no upsampling.
        stages.append(
            self._block(base, image_channel, first_double_up=False, use_double=False, final_layer=True)
        )
        self.gen = nn.Sequential(*stages)

        self.image_size = image_size
        self.embed_size = embed_size

        # Label conditioning: embedding table followed by a linear projection.
        self.embed = nn.Embedding(num_classes, embed_size)
        self.embed_linear = nn.Linear(embed_size, embed_size)

    def forward(self, noise, labels):
        """Generate an output from ``noise`` conditioned on ``labels``.

        Args:
            noise: Tensor viewed as ``(noise.size(0), noise.size(1), 1, 1)``.
            labels: Index tensor passed to the embedding table.

        Returns:
            The output of ``self.gen`` applied to the channel-wise
            concatenation of the noise and the projected label embedding.
        """
        label_features = self.embed_linear(self.embed(labels))
        # Append two singleton axes so the labels can be concatenated with the noise.
        label_map = label_features.unsqueeze(2).unsqueeze(3)
        latent_map = noise.view(noise.size(0), noise.size(1), 1, 1)
        conditioned = torch.cat([latent_map, label_map], dim=1)
        return self.gen(conditioned)

    def _block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
               first_double_up=False, use_double=True, final_layer=False):
        """Assemble one generator stage as an ``nn.Sequential``.

        Args:
            in_channels: Input channel count of the stage.
            out_channels: Output channel count of the stage.
            kernel_size, stride, padding: Forwarded to ``WSConv2d``; only used
                when ``final_layer`` is true.
            first_double_up: Selects the ``(4, 1, 0)`` transposed convolution
                instead of the ``(4, 2, 1)`` one.
            use_double: When true, append the transposed convolution followed
                by ``PixelNorm`` and ``LeakyReLU(0.2)``.
            final_layer: When true, the stage starts with ``WSConv2d`` + ``Tanh``
                instead of ``ConvBlock``.

        Returns:
            The stage wrapped in ``nn.Sequential``.
        """
        if final_layer:
            layers = [
                WSConv2d(in_channels, out_channels, kernel_size, stride, padding),
                nn.Tanh(),
            ]
        else:
            layers = [ConvBlock(in_channels, out_channels)]

        if use_double:
            # Kernel size is always 4; only stride and padding differ.
            up_stride, up_padding = (1, 0) if first_double_up else (2, 1)
            layers += [
                nn.ConvTranspose2d(out_channels, out_channels, 4, up_stride, up_padding),
                PixelNorm(),
                nn.LeakyReLU(0.2),
            ]

        return nn.Sequential(*layers)
def test():
    """Smoke-check the generator: run one sample through it and print the output shape.

    The model is built from the project ``config`` values; a single noise
    sample and the class index ``1`` are used as inputs.
    """
    # Draw the noise before building the model, matching the original order.
    latent = torch.randn(1, config.INPUT_Z_DIM, 1, 1)
    class_index = torch.tensor([1])

    generator_kwargs = {
        "embed_size": config.EMBED_SIZE,
        "num_classes": config.NUM_CLASSES,
        "image_size": config.IMAGE_SIZE,
        "features_generator": config.FEATURES_GENERATOR,
        "input_dim": config.INPUT_Z_DIM,
    }
    generator = Generator(**generator_kwargs)

    output = generator(latent, class_index)
    print(output.shape)