Spaces:
Sleeping
Sleeping
File size: 3,298 Bytes
21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 21a662b eb42124 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import torch
import torch.nn as nn
class WSConv2d(nn.Module):
    """
    A 2D convolutional layer with weight scaling.

    The wrapped ``nn.Conv2d`` holds weights drawn from a standard normal
    distribution; instead of storing pre-scaled weights, the *input* is
    multiplied by ``self.scale`` on every forward pass. Because convolution
    is linear and zero padding stays zero under scaling, this is equivalent
    to convolving with ``weight * scale``. The bias is kept outside the conv
    so it is added after the scaled convolution and is not itself scaled.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple[int, int], optional): Size of the convolving
            kernel. Default is 3.
        stride (int or tuple[int, int], optional): Stride of the convolution.
            Default is 1.
        padding (int or tuple[int, int], optional): Zero-padding added to both
            sides of the input. Default is 1.
        gain (float, optional): Gain factor for weight scaling. Default is 2.
            The scale is ``sqrt(gain / fan_in)``, where
            ``fan_in = in_channels * kernel_h * kernel_w``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Fan-in is the number of weights feeding one output unit
        # (in_channels * kernel_h * kernel_w). Reading it from the weight
        # tensor instead of squaring kernel_size also supports tuple kernels
        # such as (3, 1); for int kernels the value is identical.
        fan_in = self.conv.weight[0].numel()
        self.scale = (gain / fan_in) ** 0.5
        # Move the bias out of the conv: assigning the Parameter to self
        # registers it on this module, and clearing conv.bias makes self.conv
        # compute a bias-free convolution, so the bias is not scaled.
        self.bias = self.conv.bias
        self.conv.bias = None
        # Initialize Conv Layer: standard-normal weights, zero bias.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        Forward pass of the WSConv2d layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying convolution, weight scaling, and bias addition.
        """
        # Broadcast the per-channel bias over batch and spatial dimensions.
        return self.conv(x * self.scale) + self.bias.view(1, -1, 1, 1)
class PixelNorm(nn.Module):
    """
    Per-pixel normalization across the channel dimension.

    At every spatial location, the channel vector is divided by the square
    root of (mean of its squared entries over dim 1 + eps). For non-zero
    vectors, the normalized vector's mean square is therefore close to 1.

    Args:
        eps (float, optional): Constant added under the square root to avoid
            division by zero. Default is 1e-8.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.epsilon = eps

    def forward(self, x):
        """
        Normalize ``x`` across channels at each spatial position.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Tensor of the same shape as ``x``, rescaled per pixel.
        """
        # keepdim=True leaves a size-1 channel axis so the divisor broadcasts.
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        rms = (mean_square + self.epsilon).sqrt()
        return x / rms
class ConvBlock(nn.Module):
    """
    Two stacked WSConv2d layers, each followed by LeakyReLU(0.2) and, when
    enabled, PixelNorm.

    Both convolutions use WSConv2d's default kernel size, stride and padding.
    The first maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        use_pixelnorm (bool, optional): Whether to apply pixel normalization
            after each activation. Default is True.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        # Construction order is kept (conv1 before conv2) so parameter
        # registration and random weight initialization happen in the same order.
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """
        Run both conv -> LeakyReLU (-> PixelNorm) stages in sequence.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output of the second stage, with ``out_channels`` channels.
        """
        for stage_conv in (self.conv1, self.conv2):
            x = self.leaky(stage_conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x
|