File size: 2,905 Bytes
eb42124
 
 
 
e61c431
eb42124
 
 
e61c431
eb42124
 
 
e61c431
 
 
 
 
 
 
eb42124
e61c431
eb42124
 
e61c431
eb42124
e61c431
eb42124
 
e61c431
 
 
 
 
 
 
eb42124
 
e61c431
eb42124
e61c431
eb42124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e61c431
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn

from .base import WSConv2d, ConvBlock, PixelNorm
from config.core import config


class Generator(nn.Module):
    """Conditional GAN generator.

    Maps a latent noise vector plus an integer class label to an image:
    the label is embedded, linearly projected, reshaped to a 1x1 feature
    map, concatenated with the noise along the channel axis, and pushed
    through a stack of upsampling convolutional stages.
    """

    def __init__(self, embed_size=128, num_classes=3, image_size=128, features_generator=128, input_dim=128, image_channel=3):
        super().__init__()

        fg = features_generator
        # First stage grows the 1x1 input to 4x4; the middle stages each
        # double the spatial size; the last stage maps to image channels.
        stages = [
            self._block(input_dim + embed_size, fg * 2, first_double_up=True),
            self._block(fg * 2, fg * 4, first_double_up=False, final_layer=False),
            self._block(fg * 4, fg * 4, first_double_up=False, final_layer=False),
            self._block(fg * 4, fg * 4, first_double_up=False, final_layer=False),
            self._block(fg * 4, fg * 2, first_double_up=False, final_layer=False),
            self._block(fg * 2, fg, first_double_up=False, final_layer=False),
            self._block(fg, image_channel, first_double_up=False, use_double=False, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

        self.image_size = image_size
        self.embed_size = embed_size

        self.embed = nn.Embedding(num_classes, embed_size)
        self.embed_linear = nn.Linear(embed_size, embed_size)

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: latent tensor; reshaped to (N, input_dim, 1, 1).
            labels: integer class labels of shape (N,).

        Returns:
            Generated image tensor of shape (N, image_channel, H, W).
        """
        label_vec = self.embed_linear(self.embed(labels))
        # Turn the projected embedding into a 1x1 "feature map" so it can
        # be concatenated with the noise along the channel dimension.
        label_map = label_vec.unsqueeze(2).unsqueeze(3)

        z = noise.view(noise.size(0), noise.size(1), 1, 1)

        return self.gen(torch.cat([z, label_map], dim=1))

    def _block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
           first_double_up=False, use_double=True, final_layer=False):
        """Assemble one generator stage as an nn.Sequential.

        Non-final stages start with a ConvBlock; the final stage is a
        weight-scaled conv followed by Tanh. When `use_double` is set, a
        transposed conv is appended that either expands 1x1 -> 4x4
        (`first_double_up`) or doubles H and W, then PixelNorm + LeakyReLU.
        """
        if final_layer:
            layers = [
                WSConv2d(in_channels, out_channels, kernel_size, stride, padding),
                nn.Tanh(),
            ]
        else:
            layers = [ConvBlock(in_channels, out_channels)]

        if use_double:
            if first_double_up:
                # kernel 4, stride 1, padding 0: 1x1 -> 4x4
                upsample = nn.ConvTranspose2d(out_channels, out_channels, 4, 1, 0)
            else:
                # kernel 4, stride 2, padding 1: HxW -> 2Hx2W
                upsample = nn.ConvTranspose2d(out_channels, out_channels, 4, 2, 1)
            layers.extend([upsample, PixelNorm(), nn.LeakyReLU(0.2)])

        return nn.Sequential(*layers)

def test():
    """Smoke-test the Generator end to end.

    Builds a model from the project config, runs a single forward pass on
    a one-sample batch, and asserts basic output-shape invariants. The
    original version only printed the shape, so a broken forward pass
    could never make this test fail.
    """
    sample = torch.randn(1, config.INPUT_Z_DIM, 1, 1)
    label = torch.tensor([1])

    model = Generator(
                embed_size=config.EMBED_SIZE,
                num_classes=config.NUM_CLASSES,
                image_size=config.IMAGE_SIZE,
                features_generator=config.FEATURES_GENERATOR,
                input_dim=config.INPUT_Z_DIM,
            )

    preds = model(sample, label)
    # Channel count comes from Generator's image_channel default (3).
    # NOTE(review): spatial size is fixed by the hard-coded 6-stage stack
    # (4x4 then five doublings -> 128x128) regardless of image_size, so we
    # deliberately do not assert H/W against config.IMAGE_SIZE here.
    assert preds.ndim == 4, f"expected 4D output, got {tuple(preds.shape)}"
    assert preds.shape[0] == 1, f"expected batch of 1, got {tuple(preds.shape)}"
    assert preds.shape[1] == 3, f"expected 3 channels, got {tuple(preds.shape)}"
    print(preds.shape)