cartoonize / models /anime_gan.py
YANGYYYY's picture
Upload 13 files
7e4e601 verified
raw
history blame
3.66 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from .conv_blocks import DownConv
from .conv_blocks import UpConv
from .conv_blocks import SeparableConv2D
from .conv_blocks import InvertedResBlock
from .conv_blocks import ConvBlock
from .layers import get_norm
from utils.common import initialize_weights
class GeneratorV1(nn.Module):
    """Encoder / residual / decoder image generator (version 1).

    The network is three ``nn.Sequential`` stages applied in order:

    * ``encode_blocks``: 3 -> 256 channels, built from ``ConvBlock``,
      ``SeparableConv2D`` and two ``DownConv`` stages (defined in
      ``.conv_blocks``; presumably each halves spatial size — confirm there).
    * ``res_blocks``: eight ``InvertedResBlock(256, 256)`` in sequence.
    * ``decode_blocks``: 256 -> 3 channels through two ``UpConv`` stages,
      a final 1x1 ``nn.Conv2d`` and ``nn.Tanh``, so outputs lie in [-1, 1].

    Every layer is constructed with ``bias=False``. After construction,
    ``initialize_weights`` (from ``utils.common``) is applied to the whole
    module.

    Args:
        dataset: Only used to build ``self.name`` as
            ``"GeneratorV1_<dataset>"``.
    """

    def __init__(self, dataset=''):
        super(GeneratorV1, self).__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'
        use_bias = False

        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=use_bias),
            ConvBlock(64, 128, bias=use_bias),
            DownConv(128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            SeparableConv2D(128, 256, bias=use_bias),
            DownConv(256, bias=use_bias),
            ConvBlock(256, 256, bias=use_bias),
        )

        # Eight identical residual blocks at the 256-channel bottleneck.
        # The comprehension keeps construction order and Sequential keys
        # ('0'..'7') the same as listing them out by hand.
        self.res_blocks = nn.Sequential(
            *[InvertedResBlock(256, 256, bias=use_bias) for _ in range(8)]
        )

        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            SeparableConv2D(128, 128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            ConvBlock(128, 64, bias=use_bias),
            ConvBlock(64, 64, bias=use_bias),
            # 1x1 projection to RGB, squashed into [-1, 1] by Tanh.
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=use_bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        """Run encoder, residual stack and decoder; return a 3-channel image."""
        return self.decode_blocks(self.res_blocks(self.encode_blocks(x)))
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel logit map.

    Layout (all convolutions are 3x3 and bias-free; activations are
    ``LeakyReLU(0.2, inplace=True)``):

    * Stem: stride-1 conv, 3 -> 32 channels, then LeakyReLU.
    * ``num_layers`` stages, each made of:
      a stride-2 conv (spatial size -> ceil(size / 2)), LeakyReLU,
      a stride-1 conv, a normalization layer and LeakyReLU.
      Each stage doubles the running channel width.
    * Head: stride-1 conv, normalization, LeakyReLU, then a stride-1
      conv down to 1 channel with no activation (raw logits).

    Args:
        dataset: Only used to build ``self.name`` as
            ``"discriminator_<dataset>"``.
        num_layers: Number of stride-2 downsampling stages. ``0`` is accepted
            and yields a discriminator that does not downsample.
        use_sn: If True, every ``nn.Conv2d`` is wrapped with
            ``torch.nn.utils.spectral_norm``.
        norm_type: Passed to ``get_norm`` (from ``.layers``). The returned
            factory is called with a channel count to build each norm layer.
    """

    def __init__(
        self,
        dataset=None,
        num_layers=1,
        use_sn=False,
        norm_type="instance",
    ):
        super(Discriminator, self).__init__()
        self.name = f'discriminator_{dataset}'
        self.bias = False
        norm_layer = get_norm(norm_type)

        channels = 32
        layers = [
            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]
        # Number of channels produced by the most recent layer.
        in_channels = channels

        for _ in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                norm_layer(channels * 4),
                nn.LeakyReLU(0.2, True),
            ]
            in_channels = channels * 4
            channels *= 2

        # Size the head from the channels actually produced so far. For
        # num_layers >= 1 this equals the previous ``channels * 2``. For
        # num_layers == 0 the old formula gave 64 while the stem emits 32,
        # which made forward() fail with a channel mismatch.
        layers += [
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            norm_layer(in_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ]

        if use_sn:
            # spectral_norm registers its parametrization on the conv and
            # returns the same module; non-conv layers pass through unchanged.
            layers = [
                spectral_norm(layer) if isinstance(layer, nn.Conv2d) else layer
                for layer in layers
            ]

        self.discriminate = nn.Sequential(*layers)
        initialize_weights(self)

    def forward(self, img):
        """Return the 1-channel logit map for ``img`` (no final activation)."""
        logits = self.discriminate(img)
        return logits