import torch
import torch.nn as nn


class WSConv2d(nn.Module):
    """
    A 2D convolutional layer with weight scaling (equalized learning rate).

    Weights are initialized from N(0, 1) and rescaled at runtime by the
    He-initialization constant sqrt(gain / fan_in), so every layer learns
    at a comparable effective rate.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolving kernel. Default is 3.
        stride (int, optional): Stride of the convolution. Default is 1.
        padding (int, optional): Zero-padding added to both sides of the input. Default is 1.
        gain (float, optional): Gain factor for weight scaling. Default is 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # He-initialization constant sqrt(gain / fan_in), applied at runtime.
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        # Detach the bias from the conv so it is added after the scaled
        # convolution rather than inside it.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Initialize weights from N(0, 1); the runtime scale supplies the
        # fan-in correction, and the bias starts at zero.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        Forward pass of the WSConv2d layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying convolution, weight scaling, and bias addition.
        """
        # Scaling the input is equivalent to scaling the weights, since the
        # convolution is linear and the bias is added separately afterwards.
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

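
# A minimal usage sketch (added for illustration; channel counts and input
# size are assumptions, not taken from the original code). With kernel 3,
# stride 1, and padding 1, WSConv2d preserves spatial dimensions.
def _wsconv2d_example():
    layer = WSConv2d(in_channels=16, out_channels=32)
    x = torch.randn(4, 16, 8, 8)
    y = layer(x)
    assert y.shape == (4, 32, 8, 8)  # spatial size kept, channels remapped
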

class PixelNorm(nn.Module):
    """
    Pixel normalization layer.

    Normalizes each spatial position's feature vector by its root mean
    square across the channel dimension.

    Args:
        eps (float, optional): Small value to avoid division by zero. Default is 1e-8.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.epsilon = eps

    def forward(self, x):
        """
        Forward pass of the PixelNorm layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Normalized tensor.
        """
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)

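
# Illustrative check (assumed shapes, not from the original file): after
# PixelNorm, each pixel's feature vector has roughly unit RMS across the
# channel dimension.
def _pixelnorm_example():
    norm = PixelNorm()
    x = torch.randn(4, 16, 8, 8)
    y = norm(x)
    rms = torch.sqrt(torch.mean(y ** 2, dim=1))  # per-pixel RMS over channels
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)
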

class ConvBlock(nn.Module):
    """
    A block of two weight-scaled 3x3 convolutions, each followed by a
    LeakyReLU activation and optional pixel normalization.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        use_pixelnorm (bool, optional): Whether to apply pixel normalization. Default is True.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """
        Forward pass of the ConvBlock.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after two convolutional layers, optional pixel normalization, and LeakyReLU activation.
        """
        x = self.leaky(self.conv1(x))
        if self.use_pn:
            x = self.pn(x)

        x = self.leaky(self.conv2(x))
        if self.use_pn:
            x = self.pn(x)
        return x