manhkhanhUIT's picture
Add code
7fab858
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from models.networks.normalization import SPADE
# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This architecture seemed like a standard architecture for unconditional or
# class-conditional GAN architecture using residual block.
# The code was inspired from https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, opt):
super().__init__()
# Attributes
self.learned_shortcut = fin != fout
fmiddle = min(fin, fout)
self.opt = opt
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if "spectral" in opt.norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = opt.norm_G.replace("spectral", "")
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
# note the resnet block with SPADE also takes in |seg|,
# the semantic segmentation map as input
def forward(self, x, seg, degraded_image):
x_s = self.shortcut(x, seg, degraded_image)
dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))
out = x_s + dx
return out
def shortcut(self, x, seg, degraded_image):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
super().__init__()
pw = (kernel_size - 1) // 2
self.conv_block = nn.Sequential(
nn.ReflectionPad2d(pw),
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
activation,
nn.ReflectionPad2d(pw),
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
)
def forward(self, x):
y = self.conv_block(x)
out = x + y
return out
# VGG architecter, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
def __init__(self, requires_grad=False):
super().__init__()
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class SPADEResnetBlock_non_spade(nn.Module):
def __init__(self, fin, fout, opt):
super().__init__()
# Attributes
self.learned_shortcut = fin != fout
fmiddle = min(fin, fout)
self.opt = opt
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if "spectral" in opt.norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = opt.norm_G.replace("spectral", "")
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
# note the resnet block with SPADE also takes in |seg|,
# the semantic segmentation map as input
def forward(self, x, seg, degraded_image):
x_s = self.shortcut(x, seg, degraded_image)
dx = self.conv_0(self.actvn(x))
dx = self.conv_1(self.actvn(dx))
out = x_s + dx
return out
def shortcut(self, x, seg, degraded_image):
if self.learned_shortcut:
x_s = self.conv_s(x)
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)