# Web-page residue from the file viewer (uploader: manhkhanhUIT; commit message: "Add code"; commit: 7fab858).
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade
class SPADEGenerator(BaseNetwork):
    """Generator built from a stack of SPADE / non-SPADE residual blocks.

    The network starts from a low-resolution feature map (either a projected
    random ``z`` vector when ``opt.use_vae`` is set, or a convolution over a
    downsampled input) and repeatedly upsamples by a factor of two, refining
    the features with one residual block per scale. ``opt.injection_layer``
    selects which stages use ``SPADEResnetBlock``; every other stage uses
    ``SPADEResnetBlock_non_spade``. The output is a 3-channel tensor squashed
    to ``[-1, 1]`` by ``tanh``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register generator-specific options on ``parser`` and return it.

        Sets the default of ``norm_G`` and adds ``--num_upsampling_layers``.
        ``is_train`` is accepted for interface compatibility and is not used.
        """
        parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
        parser.add_argument(
            "--num_upsampling_layers",
            choices=("normal", "more", "most"),
            default="normal",
            help="If 'more', adds upsampling layer between the two middle resnet blocks. "
            "If 'most', also add one more upsampling + resnet layer at the end of the generator",
        )
        return parser

    def __init__(self, opt):
        """Build all layers according to the options namespace ``opt``.

        Attributes read from ``opt`` here: ``ngf``, ``use_vae``, ``z_dim``,
        ``no_parsing_map``, ``semantic_nc``, ``injection_layer``,
        ``num_upsampling_layers`` (plus whatever
        ``compute_latent_vector_size`` and the residual blocks read).
        """
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # a downsampled input instead of random z: the 3-channel degraded
            # image when no parsing map is used, the segmentation map otherwise.
            fc_in_channels = 3 if opt.no_parsing_map else opt.semantic_nc
            self.fc = nn.Conv2d(fc_in_channels, 16 * nf, 3, padding=1)

        # NOTE: attribute names and creation order are kept exactly as before
        # so that existing checkpoints (state_dict keys) remain loadable.
        self.head_0 = self._make_block("1", 16 * nf, 16 * nf)

        self.G_middle_0 = self._make_block("2", 16 * nf, 16 * nf)
        self.G_middle_1 = self._make_block("2", 16 * nf, 16 * nf)

        self.up_0 = self._make_block("3", 16 * nf, 8 * nf)
        self.up_1 = self._make_block("4", 8 * nf, 4 * nf)
        self.up_2 = self._make_block("5", 4 * nf, 2 * nf)
        self.up_3 = self._make_block("6", 2 * nf, 1 * nf)

        final_nc = nf

        if opt.num_upsampling_layers == "most":
            # The extra stage is always a SPADE block, independent of
            # opt.injection_layer.
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def _make_block(self, layer_id, fin, fout):
        """Return the residual block for stage ``layer_id``.

        A ``SPADEResnetBlock`` is used when ``opt.injection_layer`` is
        ``"all"`` or equals ``layer_id``; otherwise the non-SPADE variant.
        """
        if self.opt.injection_layer in ("all", layer_id):
            block_cls = SPADEResnetBlock
        else:
            block_cls = SPADEResnetBlock_non_spade
        return block_cls(fin, fout, self.opt)

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``, the spatial size of the coarsest feature map.

        ``sw`` is ``opt.load_size`` divided by ``2 ** num_up_layers`` (5, 6 or
        7 for "normal", "more", "most"); ``sh`` is ``sw / opt.aspect_ratio``
        rounded to the nearest integer.

        Raises:
            ValueError: if ``opt.num_upsampling_layers`` is not recognized.
        """
        up_layers_by_mode = {"normal": 5, "more": 6, "most": 7}
        num_up_layers = up_layers_by_mode.get(opt.num_upsampling_layers)
        if num_up_layers is None:
            raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)

        sw = opt.load_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

    def forward(self, input, degraded_image, z=None):
        """Run the generator.

        Args:
            input: tensor passed to every residual block as ``seg``; when
                ``opt.use_vae`` and ``opt.no_parsing_map`` are both false it
                is also downsampled to form the initial feature map.
            degraded_image: tensor passed to every residual block; used as
                the starting feature source when ``opt.no_parsing_map`` is set.
            z: optional latent tensor for the VAE path. When ``None`` and
                ``opt.use_vae`` is set, it is sampled from a unit normal with
                shape ``(input.size(0), opt.z_dim)``.

        Returns:
            A 3-channel tensor with values in ``[-1, 1]``.
        """
        seg = input

        if self.opt.use_vae:
            # we sample z from unit normal and reshape the tensor
            if z is None:
                # Use ``input.device`` rather than ``input.get_device()``: the
                # latter returns -1 for CPU tensors, which is not a valid
                # device argument for ``torch.randn``.
                z = torch.randn(
                    input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.device
                )
            x = self.fc(z)
            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            # we downsample the starting map and run convolution
            if self.opt.no_parsing_map:
                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
            else:
                x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
            x = self.fc(x)

        x = self.head_0(x, seg, degraded_image)

        x = self.up(x)
        x = self.G_middle_0(x, seg, degraded_image)

        # "more" and "most" insert an additional 2x upsampling between the
        # two middle blocks.
        if self.opt.num_upsampling_layers in ("more", "most"):
            x = self.up(x)

        x = self.G_middle_1(x, seg, degraded_image)

        x = self.up(x)
        x = self.up_0(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_1(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_2(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_3(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "most":
            x = self.up(x)
            x = self.up_4(x, seg, degraded_image)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh is the non-deprecated equivalent of F.tanh.
        x = torch.tanh(x)

        return x
class Pix2PixHDGenerator(BaseNetwork):
    """Encoder / residual / decoder generator assembled as one ``nn.Sequential``.

    Layout: reflection-padded initial conv, ``resnet_n_downsample`` stride-2
    convs (doubling channels each time), ``resnet_n_blocks`` ``ResnetBlock``s,
    ``resnet_n_downsample`` stride-2 transposed convs (halving channels), and
    a final reflection-padded 7x7 conv followed by ``Tanh``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the resnet-generator options on ``parser`` and return it.

        ``is_train`` is accepted for interface compatibility and is not used.
        """
        parser.add_argument(
            "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
        )
        parser.add_argument(
            "--resnet_n_blocks",
            type=int,
            default=9,
            help="number of residual blocks in the global generator network",
        )
        parser.add_argument(
            "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
        )
        parser.add_argument(
            "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
        )
        # NOTE: the default of norm_G is intentionally not overridden here
        # (a former ``parser.set_defaults(norm_G='instance')`` is disabled).
        return parser

    def __init__(self, opt):
        """Build the sequential model from the options namespace ``opt``.

        Attributes read from ``opt``: ``norm_G``, ``ngf``, ``output_nc``,
        ``resnet_initial_kernel_size``, ``resnet_n_downsample``,
        ``resnet_n_blocks``, ``resnet_kernel_size``.
        """
        super().__init__()
        input_nc = 3  # the network consumes the 3-channel degraded image

        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
        # A single non-inplace ReLU instance is shared by every stage.
        activation = nn.ReLU(False)

        model = []

        # initial conv
        model += [
            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
            norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
            activation,
        ]

        # downsample: each stage halves the resolution and doubles the channels
        mult = 1
        for _ in range(opt.resnet_n_downsample):
            model += [
                norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
                activation,
            ]
            mult *= 2

        # resnet blocks
        for _ in range(opt.resnet_n_blocks):
            model += [
                ResnetBlock(
                    opt.ngf * mult,
                    norm_layer=norm_layer,
                    activation=activation,
                    kernel_size=opt.resnet_kernel_size,
                )
            ]

        # upsample: each stage doubles the resolution and halves the channels.
        # ``nc_out`` is initialised up front so the final conv still has a
        # valid input width when ``resnet_n_downsample`` is 0 (previously this
        # raised NameError because the loop body never ran).
        nc_out = int(opt.ngf * mult)
        for _ in range(opt.resnet_n_downsample):
            nc_in = int(opt.ngf * mult)
            nc_out = nc_in // 2
            model += [
                norm_layer(
                    nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
                ),
                activation,
            ]
            mult = mult // 2

        # final output conv
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input, degraded_image, z=None):
        """Apply the sequential model to ``degraded_image``.

        ``input`` and ``z`` are accepted only for signature compatibility
        with ``SPADEGenerator.forward`` and are ignored.
        """
        return self.model(degraded_image)