import torch
import torch.nn as nn

from .base import WSConv2d, ConvBlock
from config.core import config


class Discriminator(nn.Module):
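    """Label-conditioned discriminator.

    The class label is embedded, projected to a ``label_channel`` x ``image_size`` x
    ``image_size`` map, and concatenated with the input image along the channel
    dimension before being scored by the convolutional stack.
    """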
    def __init__(self, num_classes=3, embed_size=128, image_size=128, features_discriminator=128, image_channel=3, label_channel=3):
        super().__init__()
        
        self.num_classes = num_classes
        self.image_size = image_size
        
        self.embed_size = embed_size
        self.label_channel = label_channel
        
        self.disc = nn.Sequential(
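            # Five stride-2 blocks halve the spatial resolution of the concatenated
            # image + label map; for a 128x128 input, the final 4x4 conv (stride 1,
            # no padding) then reduces the 4x4 feature map to a single logit.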
            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True)
        )
        
        self.embed = nn.Embedding(num_classes, embed_size)
        self.embed_linear = nn.Linear(embed_size, label_channel * image_size * image_size)

    def forward(self, image, label):
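        # Embed the label, project it to an image-sized map, concatenate it with the
        # image along the channel axis, and score the pair with the conv stack.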
        embedding = self.embed(label)
        
        linear_embedding = self.embed_linear(embedding)
        
        embedding_layer = linear_embedding.view(
            label.shape[0],
            self.label_channel,
            self.image_size,
            self.image_size
        )
        
        data = torch.cat([image, embedding_layer], dim=1)
        
        x = self.disc(data)
        
        return x.view(len(x), -1)

    def _block_discriminator(
        self,
        input_channels,
        output_channels,
        kernel_size=3,
        stride=2,
        padding=0,
        final_layer=False
    ):
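        # Intermediate blocks: a ConvBlock followed by a strided WSConv2d downsample.
        # The final layer is a single WSConv2d mapping features to one logit channel.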
        if not final_layer:
            return nn.Sequential(
                ConvBlock(input_channels, output_channels),
                WSConv2d(output_channels, output_channels, kernel_size, stride, padding),
            )
        else:
            return WSConv2d(input_channels, output_channels, kernel_size, stride, padding)


def test():
    sample = torch.randn(1, 3, 128, 128)
    label = torch.tensor([1])

    model = Discriminator(
        num_classes=config.NUM_CLASSES,
        embed_size=config.EMBED_SIZE,
        image_size=config.IMAGE_SIZE,
        features_discriminator=config.FEATURES_DISCRIMINATOR,
        image_channel=config.IMAGE_CHANNEL,
        label_channel=config.LABEL_CHANNEL
    )

    preds = model(sample, label)
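    # Expected output shape: (1, 1), assuming config.IMAGE_SIZE is 128 to match the sample.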
    print(preds.shape)
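

if __name__ == "__main__":
    # Run the quick shape check when this module is executed directly.
    test()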