import torch
import torch.nn as nn

from .base import WSConv2d, ConvBlock
from config.core import config


class Discriminator(nn.Module):
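    """Conditional GAN discriminator.

    The class label is embedded, projected to a ``label_channel x image_size
    x image_size`` map, and concatenated with the input image along the
    channel dimension before being scored by the convolutional stack.
    """
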
    def __init__(
        self,
        num_classes=3,
        embed_size=128,
        image_size=128,
        features_discriminator=128,
        image_channel=3,
        label_channel=3,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.image_size = image_size
        self.embed_size = embed_size
        self.label_channel = label_channel

        # Five stride-2 blocks halve the (default 128x128) input down to 4x4;
        # the final 4x4 convolution collapses it to a single logit per sample.
        self.disc = nn.Sequential(
            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True),
        )

        # Label conditioning: embed the class index, then project the embedding
        # so it can be reshaped into an image-sized map in forward().
        self.embed = nn.Embedding(num_classes, embed_size)
        self.embed_linear = nn.Linear(embed_size, label_channel * image_size * image_size)

    def forward(self, image, label):
        # Embed the label and reshape it into an image-sized conditioning map.
        embedding = self.embed(label)
        linear_embedding = self.embed_linear(embedding)
        embedding_layer = linear_embedding.view(
            label.shape[0],
            self.label_channel,
            self.image_size,
            self.image_size,
        )
        # Concatenate the image and the label map along the channel dimension.
        data = torch.cat([image, embedding_layer], dim=1)
        x = self.disc(data)
        return x.view(len(x), -1)

    def _block_discriminator(
        self,
        input_channels,
        output_channels,
        kernel_size=3,
        stride=2,
        padding=0,
        final_layer=False,
    ):
        # Intermediate blocks pair a ConvBlock with a weight-scaled (WSConv2d)
        # downsampling convolution; the final block is a single WSConv2d that
        # produces the raw logit.
        if not final_layer:
            return nn.Sequential(
                ConvBlock(input_channels, output_channels),
                WSConv2d(output_channels, output_channels, kernel_size, stride, padding),
            )
        return WSConv2d(input_channels, output_channels, kernel_size, stride, padding)


def test():
    # Smoke test: the hard-coded 128x128 sample assumes config.IMAGE_SIZE == 128.
    sample = torch.randn(1, 3, 128, 128)
    label = torch.tensor([1])
    model = Discriminator(
        num_classes=config.NUM_CLASSES,
        embed_size=config.EMBED_SIZE,
        image_size=config.IMAGE_SIZE,
        features_discriminator=config.FEATURES_DISCRIMINATOR,
        image_channel=config.IMAGE_CHANNEL,
        label_channel=config.LABEL_CHANNEL,
    )
    preds = model(sample, label)
    print(preds.shape)  # torch.Size([1, 1]) for a single 128x128 sample
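

# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    test()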