import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from .conv_blocks import (
    ConvBlock,
    DownConv,
    InvertedResBlock,
    SeparableConv2D,
    UpConv,
)
from .layers import get_norm
from utils.common import initialize_weights


class GeneratorV1(nn.Module):
    """Encoder-residual-decoder generator.

    The encoder maps a 3-channel image to 256 feature channels through two
    DownConv stages, eight InvertedResBlocks refine the features at 256
    channels, and the decoder mirrors the encoder with two UpConv stages
    before a 1x1 convolution and Tanh produce a 3-channel output in [-1, 1].
    """

    def __init__(self, dataset=''):
        super(GeneratorV1, self).__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'
        bias = False

        # Encoder: 3 -> 64 -> 128 -> (down) -> 128 -> 256 -> (down) -> 256
        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=bias),
            ConvBlock(64, 128, bias=bias),
            DownConv(128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            SeparableConv2D(128, 256, bias=bias),
            DownConv(256, bias=bias),
            ConvBlock(256, 256, bias=bias),
        )

        # Bottleneck: eight inverted residual blocks at 256 channels.
        self.res_blocks = nn.Sequential(
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
        )

        # Decoder: 256 -> 128 -> (up) -> 128 -> 128 -> (up) -> 64 -> 64 -> 3
        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=bias),
            UpConv(128, bias=bias),
            SeparableConv2D(128, 128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            UpConv(128, bias=bias),
            ConvBlock(128, 64, bias=bias),
            ConvBlock(64, 64, bias=bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        out = self.encode_blocks(x)
        out = self.res_blocks(out)
        img = self.decode_blocks(out)
        return img
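

# Usage sketch (illustrative only): assuming DownConv and UpConv respectively
# halve and double the spatial size, an input whose height and width are
# divisible by 4 comes back at its original resolution, with values in [-1, 1]
# from the final Tanh.
#
#   g = GeneratorV1(dataset='some_dataset')          # 'some_dataset' is a placeholder
#   fake = g(torch.randn(1, 3, 256, 256))            # expected: (1, 3, 256, 256)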


class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a map of per-location logits.

    Each of the `num_layers` stages halves the spatial resolution with a
    stride-2 convolution; `use_sn` wraps every convolution in spectral
    normalization.
    """

    def __init__(
        self,
        dataset=None,
        num_layers=1,
        use_sn=False,
        norm_type="instance",
    ):
        super(Discriminator, self).__init__()
        self.name = f'discriminator_{dataset}'
        self.bias = False
        channels = 32

        layers = [
            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]

        # Each stage downsamples by 2 and widens the feature maps.
        in_channels = channels
        for i in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                get_norm(norm_type, channels * 4),
                nn.LeakyReLU(0.2, True),
            ]
            in_channels = channels * 4
            channels *= 2

        # Final stage: keep resolution, then project to a single logit channel.
        channels *= 2
        layers += [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            get_norm(norm_type, channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ]

        # Optionally wrap every convolution in spectral normalization.
        if use_sn:
            for i in range(len(layers)):
                if isinstance(layers[i], nn.Conv2d):
                    layers[i] = spectral_norm(layers[i])

        self.discriminate = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, img):
        logits = self.discriminate(img)
        return logits
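

# Minimal smoke test (a sketch; assumes this file is run as a module within the
# package so the relative imports above resolve). The discriminator uses only
# plain strided convolutions, so with num_layers=1 a 256x256 input yields a
# 128x128 map of logits.
if __name__ == "__main__":
    d = Discriminator(dataset="some_dataset", num_layers=1, use_sn=True)  # placeholder dataset name
    x = torch.randn(1, 3, 256, 256)
    print(d(x).shape)  # expected: torch.Size([1, 1, 128, 128])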