Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class WSConv2d(nn.Module): | |
""" | |
A 2D convolutional layer with weight scaling. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
kernel_size (int, optional): Size of the convolving kernel. Default is 3. | |
stride (int, optional): Stride of the convolution. Default is 1. | |
padding (int, optional): Zero-padding added to both sides of the input. Default is 1. | |
gain (float, optional): Gain factor for weight scaling. Default is 2. | |
""" | |
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2): | |
super().__init__() | |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) | |
self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5 | |
self.bias = self.conv.bias | |
self.conv.bias = None | |
# Initialize Conv Layer | |
nn.init.normal_(self.conv.weight) | |
nn.init.zeros_(self.bias) | |
def forward(self, x): | |
""" | |
Forward pass of the WSConv2d layer. | |
Args: | |
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width). | |
Returns: | |
torch.Tensor: Output tensor after applying convolution, weight scaling, and bias addition. | |
""" | |
return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1) | |
class PixelNorm(nn.Module): | |
""" | |
Pixel normalization layer. | |
Args: | |
eps (float, optional): Small value to avoid division by zero. Default is 1e-8. | |
""" | |
def __init__(self, eps=1e-8): | |
super(PixelNorm, self).__init__() | |
self.epsilon = eps | |
def forward(self, x): | |
""" | |
Forward pass of the PixelNorm layer. | |
Args: | |
x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). | |
Returns: | |
torch.Tensor: Normalized tensor. | |
""" | |
return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon) | |
class ConvBlock(nn.Module): | |
""" | |
A block of two convolutional layers, with optional pixel normalization and LeakyReLU activation. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
use_pixelnorm (bool, optional): Whether to apply pixel normalization. Default is True. | |
""" | |
def __init__(self, in_channels, out_channels, use_pixelnorm=True): | |
super(ConvBlock, self).__init__() | |
self.use_pn = use_pixelnorm | |
self.conv1 = WSConv2d(in_channels, out_channels) | |
self.conv2 = WSConv2d(out_channels, out_channels) | |
self.leaky = nn.LeakyReLU(0.2) | |
self.pn = PixelNorm() | |
def forward(self, x): | |
""" | |
Forward pass of the ConvBlock. | |
Args: | |
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width). | |
Returns: | |
torch.Tensor: Output tensor after two convolutional layers, optional pixel normalization, and LeakyReLU activation. | |
""" | |
x = self.leaky(self.conv1(x)) | |
x = self.pn(x) if self.use_pn else x | |
x = self.leaky(self.conv2(x)) | |
x = self.pn(x) if self.use_pn else x | |
return x | |