Spaces:
Runtime error
Runtime error
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
import torch | |
import torch.nn as nn | |
import functools | |
from torch.autograd import Variable | |
import numpy as np | |
from torch.nn.utils import spectral_norm | |
# from util.util import SwitchNorm2d | |
import torch.nn.functional as F | |
############################################################################### | |
# Functions | |
############################################################################### | |
def weights_init(m):
    """Re-initialise conv and batch-norm parameters in place.

    Intended to be passed to ``Module.apply``. Modules whose class name
    contains ``"Conv"`` get weights drawn from a normal distribution with
    mean 0.0 and std 0.02. Modules whose class name contains
    ``"BatchNorm2d"`` get weights drawn with mean 1.0 and std 0.02 and a
    zero bias. All other modules are left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if "BatchNorm2d" in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type="instance"):
    """Return a factory that builds a 2-D normalisation layer from a channel count.

    Args:
        norm_type: ``"batch"`` (``nn.BatchNorm2d`` with ``affine=True``) or
            ``"instance"`` (``nn.InstanceNorm2d`` with ``affine=False``).
            ``"SwitchNorm"`` is only honoured when a ``SwitchNorm2d`` class
            has been imported into this module.

    Returns:
        A callable ``norm_layer(num_features)`` producing the layer.

    Raises:
        NotImplementedError: for unknown or unavailable normalisation types.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == "spectral":
        # The previous code evaluated ``spectral_norm()`` with no arguments,
        # which always raised TypeError. Spectral norm wraps an existing
        # module rather than being built from a channel count, so it cannot
        # serve as a ``norm_layer(num_features)`` factory.
        raise NotImplementedError(
            "normalization layer [spectral] cannot be built from a channel count; "
            "wrap the convolution with spectral_norm instead"
        )
    if norm_type == "SwitchNorm":
        # The SwitchNorm2d import is commented out at the top of this file;
        # look it up dynamically so re-enabling that import is sufficient.
        switch_norm = globals().get("SwitchNorm2d")
        if switch_norm is None:
            raise NotImplementedError(
                "normalization layer [SwitchNorm] requires SwitchNorm2d, which is not imported"
            )
        return switch_norm
    raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
def print_network(net):
    """Print a network followed by its total parameter count.

    When ``net`` is a list, only its first element is reported.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in target.parameters())
    print(target)
    print("Total number of parameters: %d" % total)
def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[], opt=None):
    """Build, print, optionally move to GPU, and initialise a generator.

    Args:
        input_nc, output_nc: input / output channel counts.
        ngf: base number of generator filters.
        netG: generator family; only ``'global'`` is supported.
        k_size: kernel size forwarded to the generator.
        n_downsample_global, n_blocks_global: forwarded to the generator.
        n_local_enhancers, n_blocks_local: accepted for interface
            compatibility; not used here.
        norm: normalisation type passed to ``get_norm_layer``.
        gpu_ids: if non-empty, the model is moved to ``gpu_ids[0]``.
            (The list default is never mutated.)
        opt: options object; ``opt.use_v2`` selects the architecture, so
            ``opt`` must not be None despite the default.

    Returns:
        The generator with ``weights_init`` applied.

    Raises:
        NotImplementedError: if ``netG`` is not ``'global'``.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        # if opt.self_gen:
        if opt.use_v2:
            netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt)
        else:
            netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt)
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception carrying the same message.
        raise NotImplementedError('generator not implemented!')
    print(netG)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Build, print, optionally move to GPU, and initialise a multiscale discriminator.

    The normalisation factory comes from ``get_norm_layer(norm)``. When
    ``gpu_ids`` is non-empty the model is moved to ``gpu_ids[0]``.
    ``weights_init`` is applied before the model is returned.
    """
    discriminator = MultiscaleDiscriminator(
        input_nc, opt, ndf, n_layers_D, get_norm_layer(norm_type=norm), use_sigmoid, num_D, getIntermFeat
    )
    print(discriminator)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert torch.cuda.is_available()
        discriminator.cuda(gpu_ids[0])
    discriminator.apply(weights_init)
    return discriminator
class GlobalGenerator_DCDCv2(nn.Module):
    """Encoder/decoder generator whose channel widths are capped at ``opt.mc``.

    The encoder is a 7x7 stem followed by stride-2 convolutions; from level
    ``opt.start_r`` onwards each downsampling step is followed by two
    residual blocks. The decoder mirrors this with transposed convolutions.

    Args:
        input_nc, output_nc: input / output channel counts.
        ngf: base channel width (doubled per level, capped at ``opt.mc``).
        k_size: kernel size of the strided (transposed) convolutions.
        n_downsampling: number of resolution levels.
        norm_layer: factory ``norm_layer(num_features)``.
        padding_type: padding mode forwarded to ``ResnetBlock``.
        opt: options object. Reads ``mc``, ``start_r``, ``spatio_size``,
            ``feat_dim`` and ``use_segmentation_model``; it must not be None
            despite the default.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        k_size=3,
        n_downsampling=8,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        opt=None,
    ):
        super(GlobalGenerator_DCDCv2, self).__init__()
        # One ReLU instance is shared by every layer, as in the original.
        activation = nn.ReLU(True)

        def capped(channels):
            # Channel width limited by the configured maximum.
            return min(channels, opt.mc)

        def res_block(channels):
            return ResnetBlock(
                channels,
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )

        def down(mult):
            # Stride-2 conv taking ngf*mult -> ngf*mult*2 channels (both capped).
            out_ch = capped(ngf * mult * 2)
            return [
                nn.Conv2d(capped(ngf * mult), out_ch, kernel_size=k_size, stride=2, padding=1),
                norm_layer(out_ch),
                activation,
            ]

        def up(mult):
            # Stride-2 transposed conv taking ngf*mult -> ngf*mult/2 channels (both capped).
            out_ch = capped(int(ngf * mult / 2))
            return [
                nn.ConvTranspose2d(
                    capped(ngf * mult),
                    out_ch,
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(out_ch),
                activation,
            ]

        stem_ch = capped(ngf)
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, stem_ch, kernel_size=7, padding=0),
            # Fix: size the norm layer by the conv's actual output width; the
            # original used ``ngf`` even when ``opt.mc`` capped the conv below it.
            norm_layer(stem_ch),
            activation,
        ]
        ### downsample
        for i in range(opt.start_r):
            model += down(2 ** i)
        for i in range(opt.start_r, n_downsampling - 1):
            mult = 2 ** i
            model += down(mult)
            model += [res_block(capped(ngf * mult * 2))]
            model += [res_block(capped(ngf * mult * 2))]
        mult = 2 ** (n_downsampling - 1)
        if opt.spatio_size == 32:
            model += down(mult)
        if opt.spatio_size == 64:
            model += [res_block(capped(ngf * mult * 2))]
        model += [res_block(capped(ngf * mult * 2))]
        # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)]
        if opt.feat_dim > 0:
            model += [nn.Conv2d(capped(ngf * mult * 2), opt.feat_dim, 1, 1)]
        self.encoder = nn.Sequential(*model)

        # decode
        model = []
        if opt.feat_dim > 0:
            model += [nn.Conv2d(opt.feat_dim, capped(ngf * mult * 2), 1, 1)]
        # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)]
        o_pad = 0 if k_size == 4 else 1
        mult = 2 ** n_downsampling
        model += [res_block(capped(ngf * mult))]
        if opt.spatio_size == 32:
            model += up(mult)
        if opt.spatio_size == 64:
            model += [res_block(capped(ngf * mult))]
        for i in range(1, n_downsampling - opt.start_r):
            mult = 2 ** (n_downsampling - i)
            model += [res_block(capped(ngf * mult))]
            model += [res_block(capped(ngf * mult))]
            model += up(mult)
        for i in range(n_downsampling - opt.start_r, n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += up(mult)
        model += [nn.ReflectionPad2d(3), nn.Conv2d(capped(ngf), output_nc, kernel_size=7, padding=0)]
        if not opt.use_segmentation_model:
            # The Tanh output squashing is skipped only for the segmentation variant.
            model += [nn.Tanh()]
        self.decoder = nn.Sequential(*model)

    def forward(self, input, flow="enc_dec"):
        """Run the encoder (``"enc"``), the decoder (``"dec"``) or both (``"enc_dec"``).

        Raises:
            ValueError: for any other ``flow`` (the original silently returned None).
        """
        if flow == "enc":
            return self.encoder(input)
        if flow == "dec":
            return self.decoder(input)
        if flow == "enc_dec":
            return self.decoder(self.encoder(input))
        raise ValueError("unknown flow [%s]; expected 'enc', 'dec' or 'enc_dec'" % flow)
# Define a resnet block | |
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)`` with two 3x3 convolutions.

    The first convolution uses ``dilation`` (with matching padding); the
    second always uses dilation 1. ``padding_type`` is one of ``"reflect"``,
    ``"replicate"`` or ``"zero"``.
    """

    def __init__(
        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
    ):
        super(ResnetBlock, self).__init__()
        self.opt = opt
        self.dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble pad/conv/norm/activation[/dropout] then pad/conv/norm.

        Raises:
            NotImplementedError: if ``padding_type`` is not recognised.
        """
        explicit_pads = {"reflect": nn.ReflectionPad2d, "replicate": nn.ReplicationPad2d}
        if padding_type not in explicit_pads and padding_type != "zero":
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        def padded_conv(amount, dilation):
            # Either an explicit padding layer + unpadded conv, or a zero-padded conv.
            pad_cls = explicit_pads.get(padding_type)
            if pad_cls is None:
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=amount, dilation=dilation)]
            return [pad_cls(amount), nn.Conv2d(dim, dim, kernel_size=3, padding=0, dilation=dilation)]

        layers = padded_conv(self.dilation, self.dilation)
        layers += [norm_layer(dim), activation]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv(1, 1)
        layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the convolutional branch applied to it."""
        return x + self.conv_block(x)
class Encoder(nn.Module):
    """Conv encoder/decoder whose output is averaged per instance region.

    ``forward`` runs the network, then for every distinct value in ``inst``
    replaces each output channel inside that region (per batch element) by
    the channel's mean over the region.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        ### downsample
        for level in range(n_downsampling):
            width = ngf * 2 ** level
            layers.extend(
                [
                    nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                    norm_layer(width * 2),
                    nn.ReLU(True),
                ]
            )
        ### upsample
        for level in range(n_downsampling, 0, -1):
            width = ngf * 2 ** level
            halved = int(width / 2)
            layers.extend(
                [
                    nn.ConvTranspose2d(width, halved, kernel_size=3, stride=2, padding=1, output_padding=1),
                    norm_layer(halved),
                    nn.ReLU(True),
                ]
            )
        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, input, inst):
        features = self.model(input)
        # instance-wise average pooling
        pooled = features.clone()
        batch = input.size()[0]
        for label in np.unique(inst.cpu().numpy().astype(int)):
            for b in range(batch):
                hits = (inst[b : b + 1] == int(label)).nonzero()  # n x 4
                rows = hits[:, 0] + b
                chans, ys, xs = hits[:, 1], hits[:, 2], hits[:, 3]
                for j in range(self.output_nc):
                    region = features[rows, chans + j, ys, xs]
                    pooled[rows, chans + j, ys, xs] = torch.mean(region).expand_as(region)
        return pooled
def SN(module, mode=True):
    """Return ``module`` wrapped in spectral normalisation when ``mode`` is truthy, else unchanged."""
    return torch.nn.utils.spectral_norm(module) if mode else module
class NonLocalBlock2D_with_mask_Res(nn.Module):
    """Masked non-local attention block followed by three residual blocks.

    Attention weights are a softmax over all spatial positions, multiplied
    by a per-key weight derived from ``mask``. The aggregated features are
    projected by ``W``, refined by ``res_block`` and blended with the input
    using the same derived mask.

    Args:
        in_channels: channels of the input feature map.
        inter_channels: channels of the attention embedding. The residual
            blocks and the blend mask are built with ``inter_channels``
            while ``W`` emits ``in_channels``, so the two presumably must
            be equal — TODO confirm against callers.
        mode: only ``"combine"`` is implemented in ``forward``.
        re_norm: L1-renormalise the attention rows after masking.
        temperature: divisor applied to the attention logits.
        use_self: always allow a position to attend to itself.
        cosin: L2-normalise query/key embeddings before the dot product.
    """

    def __init__(
        self,
        in_channels,
        inter_channels,
        mode="add",
        re_norm=False,
        temperature=1.0,
        use_self=False,
        cosin=False,
    ):
        super(NonLocalBlock2D_with_mask_Res, self).__init__()

        self.cosin = cosin
        self.renorm = re_norm
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.W = nn.Conv2d(
            in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
        )
        # The output projection starts at zero (weight and bias).
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.phi = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.mode = mode
        self.temperature = temperature
        self.use_self = use_self

        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(3):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def forward(self, x, mask):  ## The shape of mask is Batch*1*H*W
        """Apply masked attention to ``x`` and blend the result with ``x``.

        Raises:
            NotImplementedError: if ``self.mode`` is not ``"combine"``. The
                original reached ``return z`` with ``z`` unbound in that
                case, including for the constructor default ``"add"``.
        """
        if self.mode != "combine":
            raise NotImplementedError(
                "NonLocalBlock2D_with_mask_Res only implements mode 'combine', got [%s]" % self.mode
            )

        batch_size = x.size(0)
        height, width = x.size(2), x.size(3)
        num_positions = height * width

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        if self.cosin:
            theta_x = F.normalize(theta_x, dim=2)
            phi_x = F.normalize(phi_x, dim=1)

        f = torch.matmul(theta_x, phi_x)
        f /= self.temperature
        f_div_C = F.softmax(f, dim=2)

        # Build the key weight: zero wherever the bilinearly resized mask is
        # positive, multiplied by the default-mode resize of (1 - mask).
        tmp = 1 - mask
        # align_corners=False is the library default; passing it explicitly
        # keeps the numerics and silences the bilinear-interpolate warning.
        mask = F.interpolate(mask, (height, width), mode="bilinear", align_corners=False)
        mask[mask > 0] = 1.0
        mask = 1 - mask

        tmp = F.interpolate(tmp, (height, width))
        mask *= tmp

        mask_expand = mask.view(batch_size, 1, -1)
        mask_expand = mask_expand.repeat(1, num_positions, 1)

        if self.use_self:
            # Force the diagonal so each position may attend to itself.
            mask_expand[:, range(num_positions), range(num_positions)] = 1.0

        f_div_C = mask_expand * f_div_C
        if self.renorm:
            f_div_C = F.normalize(f_div_C, p=1, dim=2)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        W_y = self.W(y)
        W_y = self.res_block(W_y)

        full_mask = mask.repeat(1, self.inter_channels, 1, 1)
        z = full_mask * x + (1 - full_mask) * W_y
        return z
class MultiscaleDiscriminator(nn.Module):
    """Runs ``num_D`` patch discriminators on progressively downsampled input.

    Sub-networks are registered as ``layer{i}`` or, when intermediate
    features are requested, stage by stage as ``scale{i}_layer{j}``.
    ``forward`` returns one list per scale, applying the highest-indexed
    discriminator to the full-resolution input.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, 'layer' + str(scale), netD.model)
                continue
            for stage in range(n_layers + 2):
                setattr(self, 'scale' + str(scale) + '_layer' + str(stage), getattr(netD, 'model' + str(stage)))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator; return every stage output or a one-element list."""
        if not self.getIntermFeat:
            return [model(input)]
        outputs = []
        current = input
        for stage in model:
            current = stage(current)
            outputs.append(current)
        return outputs

    def forward(self, input):
        num_D = self.num_D
        per_scale = []
        current = input
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(scale) + '_layer' + str(j)) for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(scale))
            per_scale.append(self.singleD_forward(model, current))
            # No need to downsample after the last discriminator has run.
            if i != (num_D - 1):
                current = self.downsample(current)
        return per_scale
# Defines the PatchGAN discriminator with the specified arguments. | |
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convolutions.

    Convolutions are wrapped by ``SN`` according to ``opt.use_SN``. With
    ``getIntermFeat`` the stages are registered separately as ``model{n}``
    and ``forward`` returns the output of the first ``n_layers + 2`` stages;
    otherwise a single ``model`` Sequential is used.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))

        def conv(in_ch, out_ch, stride):
            return SN(nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=stride, padding=padw), opt.use_SN)

        stages = [[conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]]

        width = ndf
        for _ in range(1, n_layers):
            prev, width = width, min(width * 2, 512)
            stages.append([conv(prev, width, 2), norm_layer(width), nn.LeakyReLU(0.2, True)])

        prev, width = width, min(width * 2, 512)
        stages.append([conv(prev, width, 1), norm_layer(width), nn.LeakyReLU(0.2, True)])
        stages.append([conv(width, 1, 1)])

        if use_sigmoid:
            stages.append([nn.Sigmoid()])

        if getIntermFeat:
            for n, stage in enumerate(stages):
                setattr(self, 'model'+str(n), nn.Sequential(*stage))
        else:
            self.model = nn.Sequential(*(layer for stage in stages for layer in stage))

    def forward(self, input):
        if not self.getIntermFeat:
            return self.model(input)
        feats = []
        current = input
        for n in range(self.n_layers+2):
            current = getattr(self, 'model'+str(n))(current)
            feats.append(current)
        return feats
class Patch_Attention_4(nn.Module):  ## While combine the feature map, use conv and mask
    """Hard patch attention: each patch is replaced by its most similar unmasked patch.

    The input passes through one residual block and is cut into
    non-overlapping ``patch_size`` x ``patch_size`` patches. A patch counts
    as masked when more than 60% of its resized mask is set. The composed
    map is concatenated with the input and the resized mask and fused by
    ``F_Combine``.

    Args:
        in_channels: accepted for interface compatibility but not used.
            ``F_Combine`` is hard-wired to 1025 -> 512 channels, i.e. a
            512-channel input (512 + 512 + 1 mask channel).
        inter_channels: channel count of the residual block.
        patch_size: side length of the square patches.
    """

    def __init__(self, in_channels, inter_channels, patch_size):
        super(Patch_Attention_4, self).__init__()

        self.patch_size = patch_size

        self.F_Combine = nn.Conv2d(in_channels=1025, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(1):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def Hard_Compose(self, input, dim, index):
        """Batched index-select: gather ``input`` along ``dim`` using per-batch indices.

        input: [B,C,HW]; dim: scalar > 0; index: [B, HW].
        """
        views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))]
        expanse = list(input.size())
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def _resized_binary_mask(self, mask, height, width):
        # Resize the mask to the feature resolution and dilate it: any
        # positive interpolated value becomes 1. align_corners=False is the
        # library default, stated explicitly to avoid the runtime warning.
        mask = F.interpolate(mask, (height, width), mode="bilinear", align_corners=False)
        mask[mask > 0] = 1.0
        return mask

    def forward(self, z, mask):  ## The shape of mask is Batch*1*H*W
        """Training path: dense patch-to-patch correlation with masked keys suppressed."""
        x = self.res_block(z)
        b, c, h, w = x.shape
        patch = (self.patch_size, self.patch_size)

        ## 1: mask position 0: non-mask
        mask = self._resized_binary_mask(mask, h, w)
        mask_unfold = F.unfold(mask, kernel_size=patch, padding=0, stride=self.patch_size)
        masked_patches = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()

        x_unfold = F.unfold(x, kernel_size=patch, padding=0, stride=self.patch_size)
        # Fix: take the patch count from unfold itself. The original derived
        # it as int(h*w/patch_size/patch_size), which is only correct when
        # both sides are divisible by patch_size.
        num_patches = x_unfold.size(2)
        masked_patches = masked_patches.repeat(1, num_patches, 1)

        y_unfold = x_unfold.permute(0, 2, 1)
        x_unfold_normalized = F.normalize(x_unfold, dim=1)
        y_unfold_normalized = F.normalize(y_unfold, dim=2)
        correlation_matrix = torch.bmm(y_unfold_normalized, x_unfold_normalized)
        # Masked patches may not be selected as a source for any query.
        correlation_matrix = correlation_matrix.masked_fill(masked_patches == 1., -1e9)
        correlation_matrix = F.softmax(correlation_matrix, dim=2)

        R, max_arg = torch.max(correlation_matrix, dim=2)

        composed_unfold = self.Hard_Compose(x_unfold, 2, max_arg)
        composed_fold = F.fold(composed_unfold, output_size=(h, w), kernel_size=patch, padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)
        return concat_1

    def inference_forward(self, z, mask):  ## Reduce the extra memory cost
        """Inference path: only masked patches are matched against unmasked ones.

        The masked-patch layout is read from batch element 0 only, so every
        element of the batch is assumed to share one mask.
        """
        x = self.res_block(z)
        b, c, h, w = x.shape
        patch = (self.patch_size, self.patch_size)

        ## 1: mask position 0: non-mask
        mask = self._resized_binary_mask(mask, h, w)
        mask_unfold = F.unfold(mask, kernel_size=patch, padding=0, stride=self.patch_size)
        masked_flags = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()[0, 0, :]  # 1*1*all_patch_num

        mask_index = torch.nonzero(masked_flags, as_tuple=True)[0]
        unmask_index = torch.nonzero(masked_flags != 1, as_tuple=True)[0]

        if len(mask_index) == 0 or len(unmask_index) == 0:
            # Nothing to fill, or (fix) no unmasked patch to copy from. The
            # latter used to crash taking a max over an empty key dimension.
            composed_fold = x
        else:
            x_unfold = F.unfold(x, kernel_size=patch, padding=0, stride=self.patch_size)

            Query_Patch = torch.index_select(x_unfold, 2, mask_index)
            Key_Patch = torch.index_select(x_unfold, 2, unmask_index)

            Query_Patch = Query_Patch.permute(0, 2, 1)
            Query_Patch_normalized = F.normalize(Query_Patch, dim=2)
            Key_Patch_normalized = F.normalize(Key_Patch, dim=1)

            correlation_matrix = torch.bmm(Query_Patch_normalized, Key_Patch_normalized)
            correlation_matrix = F.softmax(correlation_matrix, dim=2)

            R, max_arg = torch.max(correlation_matrix, dim=2)

            composed_unfold = self.Hard_Compose(Key_Patch, 2, max_arg)
            # Overwrite only the masked patches with their best unmasked match.
            x_unfold[:, :, mask_index] = composed_unfold
            composed_fold = F.fold(x_unfold, output_size=(h, w), kernel_size=patch, padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)
        return concat_1
############################################################################## | |
# Losses | |
############################################################################## | |
class GANLoss(nn.Module):
    """GAN objective against constant real/fake label targets.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is true, else ``nn.BCELoss``.
    Target tensors are built with the ``tensor`` constructor and cached
    until the prediction shape changes.

    Args:
        use_lsgan: choose the least-squares loss instead of BCE.
        target_real_label, target_fake_label: fill values for the targets.
        tensor: tensor constructor used to allocate the targets.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def _make_target(self, input, value):
        # Constant target with the prediction's shape; it carries no gradient.
        return self.Tensor(input.size()).fill_(value)

    def get_target_tensor(self, input, target_is_real):
        """Return a (cached) constant target tensor with the same shape as ``input``.

        Fix: the cache is revalidated by shape. The original compared only
        ``numel``, so a same-size but differently shaped prediction reused a
        target of the wrong shape.
        """
        if target_is_real:
            cached = self.real_label_var
            if cached is None or cached.size() != input.size():
                cached = self._make_target(input, self.real_label)
                self.real_label_var = cached
            return cached
        cached = self.fake_label_var
        if cached is None or cached.size() != input.size():
            cached = self._make_target(input, self.fake_label)
            self.fake_label_var = cached
        return cached

    def __call__(self, input, target_is_real):
        """Compute the loss on the last element of ``input``.

        ``input`` is either a list of tensors (the last one is scored) or a
        list of such lists (the last tensor of each is scored and the losses
        are summed).
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                loss += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return loss
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
####################################### VGG Loss | |
from torchvision import models | |
class VGG19_torch(torch.nn.Module):
    """Pretrained VGG-19 feature extractor split into five consecutive slices.

    The slices cover feature-layer indices [0, 2), [2, 7), [7, 12), [12, 21)
    and [21, 30). ``forward`` returns the output of every slice. Parameters
    are frozen unless ``requires_grad`` is true.
    """

    def __init__(self, requires_grad=False):
        super(VGG19_torch, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        boundaries = (0, 2, 7, 12, 21, 30)
        for number, (start, stop) in enumerate(zip(boundaries, boundaries[1:]), start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(layer_idx), vgg_pretrained_features[layer_idx])
            setattr(self, "slice%d" % number, stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        outputs = []
        current = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            current = stage(current)
            outputs.append(current)
        return outputs
class VGGLoss_torch(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG-19 feature maps.

    Args:
        gpu_ids: accepted for interface compatibility; not used to select a
            device.
    """

    def __init__(self, gpu_ids):
        super(VGGLoss_torch, self).__init__()
        vgg = VGG19_torch()
        # Fix: the unconditional ``.cuda()`` crashed on CPU-only hosts. Use
        # the current CUDA device when one exists, otherwise stay on CPU.
        if torch.cuda.is_available():
            vgg = vgg.cuda()
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        # One weight per VGG slice, doubling at each slice up to 1.0.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """Return the weighted sum of L1 distances over the five feature levels.

        The features of ``y`` are detached, so gradients flow only through ``x``.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss