# 57894 / models/discriminator.py
# Muhammad Naufal Rizqullah
# Experiment 2
# e61c431
import torch
import torch.nn as nn
from .base import WSConv2d, ConvBlock
from config.core import config
class Discriminator(nn.Module):
    """Label-conditioned discriminator.

    The integer class label is embedded, projected by a linear layer to a
    ``label_channel x image_size x image_size`` map, concatenated with the
    input image along dim 1, and scored by a stack of convolutional stages.
    ``forward`` returns the stack's output flattened to ``(batch, -1)``.

    NOTE(review): the stack is five stages whose trailing conv is
    kernel 4 / stride 2 / padding 1, followed by a kernel-4, padding-0 conv.
    Assuming ``ConvBlock`` preserves spatial size (TODO confirm), a 128x128
    input ends at 1x1, i.e. an output of shape (batch, 1); other
    ``image_size`` values yield a different (or invalid) output size.
    """

    def __init__(self, num_classes=3, embed_size=128, image_size=128,
                 features_discriminator=128, image_channel=3, label_channel=3):
        super().__init__()
        self.num_classes = num_classes
        self.image_size = image_size
        self.embed_size = embed_size
        self.label_channel = label_channel

        width = features_discriminator
        # (in_channels, out_channels) for each downsampling stage; the first
        # stage sees the image and the label map stacked on the channel axis.
        stages = (
            (image_channel + label_channel, width),
            (width, width),
            (width, width * 2),
            (width * 2, width * 4),
            (width * 4, width * 4),
        )
        layers = [
            self._block_discriminator(c_in, c_out, kernel_size=4, stride=2, padding=1)
            for c_in, c_out in stages
        ]
        # Final projection to a single channel: bare conv, no padding.
        layers.append(
            self._block_discriminator(width * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True)
        )
        # Keep the conv stack ahead of the embedding layers: module creation
        # order fixes parameter order and seeded weight initialisation.
        self.disc = nn.Sequential(*layers)

        self.embed = nn.Embedding(num_classes, embed_size)
        self.embed_linear = nn.Linear(embed_size, label_channel * image_size * image_size)

    def forward(self, image, label):
        """Score ``image`` conditioned on the class indices in ``label``.

        Args:
            image: image batch concatenated with the label map on dim 1;
                presumably (batch, image_channel, image_size, image_size) —
                TODO confirm against callers.
            label: class indices fed to ``nn.Embedding``; its first dimension
                is used as the batch size of the label map.

        Returns:
            The conv stack's output reshaped to ``(batch, -1)``.
        """
        projected = self.embed_linear(self.embed(label))
        label_map = projected.view(
            label.shape[0], self.label_channel, self.image_size, self.image_size
        )
        scores = self.disc(torch.cat([image, label_map], dim=1))
        return scores.view(scores.shape[0], -1)

    def _block_discriminator(
        self,
        input_channels,
        output_channels,
        kernel_size=3,
        stride=2,
        padding=0,
        final_layer=False
    ):
        """Build one discriminator stage.

        A final stage is a single
        ``WSConv2d(input_channels, output_channels, kernel_size, stride, padding)``.
        Any other stage is ``ConvBlock(input_channels, output_channels)``
        followed by ``WSConv2d(output_channels, output_channels, ...)``; the
        given kernel/stride/padding apply only to that trailing convolution.
        """
        if final_layer:
            return WSConv2d(input_channels, output_channels, kernel_size, stride, padding)
        return nn.Sequential(
            ConvBlock(input_channels, output_channels),
            WSConv2d(output_channels, output_channels, kernel_size, stride, padding),
        )
def test():
    """Smoke-test ``Discriminator`` on one random image built from ``config``.

    The sample's channel count and spatial size come from the same config
    values the model is constructed with (previously hard-coded to
    3x128x128), so the image and the label map always agree when
    ``forward`` concatenates them. Prints the prediction shape and returns
    the predictions.
    """
    sample = torch.randn(1, config.IMAGE_CHANNEL, config.IMAGE_SIZE, config.IMAGE_SIZE)
    # Class index 1; assumes config.NUM_CLASSES >= 2.
    label = torch.tensor([1])
    model = Discriminator(
        num_classes=config.NUM_CLASSES,
        embed_size=config.EMBED_SIZE,
        image_size=config.IMAGE_SIZE,
        features_discriminator=config.FEATURES_DISCRIMINATOR,
        image_channel=config.IMAGE_CHANNEL,
        label_channel=config.LABEL_CHANNEL
    )
    preds = model(sample, label)
    print(preds.shape)
    return preds