import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter

from deepfillv2.network_utils import *


class Conv2dLayer(nn.Module):
    """Conv2d with a selectable padding scheme, optional normalization and
    activation, and optional spectral normalization on the weights."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="elu",
        norm="none",
        sn=False,
    ):
        super(Conv2dLayer, self).__init__()

        # Padding scheme (applied by a dedicated layer, so the conv itself uses padding=0)
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Normalization
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "ln":
            self.norm = LayerNorm(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Activation
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU(inplace=True)
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Convolution, optionally wrapped with spectral normalization
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding=0,
                    dilation=dilation,
                )
            )
        else:
            self.conv2d = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
            )

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x
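

# Illustrative usage sketch, not part of the original module: it shows how
# Conv2dLayer composes padding, spectral normalization, normalization, and
# activation. The helper name and the shapes below are assumptions made for
# the example only.
def _example_conv2d_layer():
    layer = Conv2dLayer(
        in_channels=3,
        out_channels=16,
        kernel_size=3,
        stride=1,
        padding=1,
        pad_type="reflect",
        activation="lrelu",
        norm="in",
        sn=True,
    )
    x = torch.randn(2, 3, 64, 64)
    # reflect-pad -> spectrally normalized conv -> InstanceNorm -> LeakyReLU
    y = layer(x)
    return y.shape  # torch.Size([2, 16, 64, 64])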


class TransposeConv2dLayer(nn.Module):
    """Upsampling block: nearest-neighbor interpolation by `scale_factor`
    followed by a Conv2dLayer (used in place of a strided transposed conv)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="lrelu",
        norm="none",
        sn=False,
        scale_factor=2,
    ):
        super(TransposeConv2dLayer, self).__init__()
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            pad_type,
            activation,
            norm,
            sn,
        )

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode="nearest",
            recompute_scale_factor=False,
        )
        x = self.conv2d(x)
        return x


class GatedConv2d(nn.Module):
    """Gated convolution (Yu et al., "Free-Form Image Inpainting with Gated
    Convolution"): a feature convolution and a parallel gating convolution
    whose sigmoid output rescales the activated features element-wise."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="reflect",
        activation="elu",
        norm="none",
        sn=False,
    ):
        super(GatedConv2d, self).__init__()

        # Padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Normalization
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "ln":
            self.norm = LayerNorm(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Activation for the feature branch
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Feature and gating convolutions, optionally spectrally normalized
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding=0,
                    dilation=dilation,
                )
            )
            self.mask_conv2d = SpectralNorm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding=0,
                    dilation=dilation,
                )
            )
        else:
            self.conv2d = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
            )
            self.mask_conv2d = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
            )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.pad(x)
        conv = self.conv2d(x)
        mask = self.mask_conv2d(x)
        gated_mask = self.sigmoid(mask)
        if self.activation:
            conv = self.activation(conv)
        # Gate: activated features scaled by the learned soft mask.
        # (The normalization layer built in __init__ is not applied in this forward pass.)
        x = conv * gated_mask
        return x
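

# Illustrative sketch, not part of the original module: the gating mechanism
# runs two convolutions over the same padded input; the feature branch goes
# through the activation and is then scaled element-wise by a sigmoid gate,
# letting the layer learn to suppress features in masked regions. The helper
# name and shapes are assumptions for the example.
def _example_gated_conv2d():
    layer = GatedConv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
    x = torch.randn(1, 4, 32, 32)
    y = layer(x)  # activation(conv2d(x)) * sigmoid(mask_conv2d(x))
    return y.shape  # torch.Size([1, 8, 32, 32])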


class TransposeGatedConv2d(nn.Module):
    """Upsampling gated convolution: nearest-neighbor interpolation by
    `scale_factor` followed by a GatedConv2d."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="lrelu",
        norm="none",
        sn=True,
        scale_factor=2,
    ):
        super(TransposeGatedConv2d, self).__init__()
        self.scale_factor = scale_factor
        self.gated_conv2d = GatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            pad_type,
            activation,
            norm,
            sn,
        )

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode="nearest",
            recompute_scale_factor=False,
        )
        x = self.gated_conv2d(x)
        return x


class LayerNorm(nn.Module):
    """Layer normalization over all non-batch dimensions of a sample, with an
    optional learnable per-channel affine transform (gamma, beta)."""

    def __init__(self, num_features, eps=1e-8, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
            self.beta = Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Per-sample mean/std over all remaining dimensions
        shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # A single sample: flatten everything before computing statistics
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)

        if self.affine:
            # Broadcast gamma/beta across the spatial dimensions
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
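

# Illustrative sketch, not part of the original module: this LayerNorm
# standardizes each sample over all of its non-batch dimensions (channels and
# spatial positions together) and then applies the per-channel affine
# parameters. The helper name and shapes are assumptions for the example.
def _example_layer_norm():
    ln = LayerNorm(num_features=8)
    x = torch.randn(4, 8, 16, 16)
    y = ln(x)  # each sample standardized over C*H*W, then scaled/shifted per channel
    return y.shape  # torch.Size([4, 8, 16, 16])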


def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Spectral normalization (Miyato et al.): reparameterizes the wrapped
    module's weight as weight_bar / sigma, where sigma is an estimate of the
    weight matrix's largest singular value obtained by power iteration."""

    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # Power iteration: refine the left/right singular vector estimates
        for _ in range(self.power_iterations):
            v.data = l2normalize(
                torch.mv(torch.t(w.view(height, -1).data), u.data)
            )
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma approximates the largest singular value of the flattened weight
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Replace the original weight parameter with the reparameterized ones
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
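

# Illustrative sketch, not part of the original module: SpectralNorm wraps a
# layer and, on every forward pass, divides its weight by a power-iteration
# estimate of the weight matrix's largest singular value, keeping the layer
# approximately 1-Lipschitz. The helper name and shapes are assumptions for
# the example.
def _example_spectral_norm():
    conv = SpectralNorm(nn.Conv2d(3, 16, 3, 1, padding=1))
    x = torch.randn(2, 3, 32, 32)
    y = conv(x)  # forward() first resets conv.weight to weight_bar / sigma
    return y.shape  # torch.Size([2, 16, 32, 32])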


class ContextualAttention(nn.Module):
    """Contextual attention layer from "Generative Image Inpainting with
    Contextual Attention" (Yu et al.): matches foreground patches against
    background patches and reconstructs the foreground from the best matches.

    Args (set in __init__):
        ksize: kernel size (patch size) for contextual attention.
        stride: stride for extracting patches from the background.
        rate: dilation rate for matching.
        fuse_k: kernel size of the score-fusion convolution.
        softmax_scale: scale applied before the attention softmax.
        fuse: whether to fuse attention scores for spatial coherence.
        use_cuda: move internal constant tensors to the GPU.
        device_ids: stored but not used in forward.
    """

    def __init__(
        self,
        ksize=3,
        stride=1,
        rate=1,
        fuse_k=3,
        softmax_scale=10,
        fuse=True,
        use_cuda=True,
        device_ids=None,
    ):
        super(ContextualAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.rate = rate
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        self.use_cuda = use_cuda
        self.device_ids = device_ids

    def forward(self, f, b, mask=None):
        """Contextual attention layer implementation.

        Contextual attention was first introduced in the publication
        "Generative Image Inpainting with Contextual Attention", Yu et al.

        Args:
            f: Input feature to match (foreground).
            b: Input feature to match against (background).
            mask: Input mask for b, indicating patches that are not available.

        Returns:
            torch.Tensor: output feature, reconstructed from background patches.
        """
        # Original (full-resolution) sizes of foreground and background
        raw_int_fs = list(f.size())
        raw_int_bs = list(b.size())

        # Extract full-resolution background patches, later used to
        # reconstruct the output (kernel is twice the dilation rate)
        kernel = 2 * self.rate
        raw_w = extract_image_patches(
            b,
            ksizes=[kernel, kernel],
            strides=[self.rate * self.stride, self.rate * self.stride],
            rates=[1, 1],
            padding="same",
        )
        # Reshape to (N, L, C, kernel, kernel): one patch per output location
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # Downscale foreground and background before matching
        f = F.interpolate(
            f,
            scale_factor=1.0 / self.rate,
            mode="nearest",
            recompute_scale_factor=False,
        )
        b = F.interpolate(
            b,
            scale_factor=1.0 / self.rate,
            mode="nearest",
            recompute_scale_factor=False,
        )
        int_fs = list(f.size())
        int_bs = list(b.size())
        f_groups = torch.split(f, 1, dim=0)  # split along the batch dimension

        # Extract background patches used as convolution filters for matching
        w = extract_image_patches(
            b,
            ksizes=[self.ksize, self.ksize],
            strides=[self.stride, self.stride],
            rates=[1, 1],
            padding="same",
        )
        # Reshape to (N, L, C, ksize, ksize)
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)
        w_groups = torch.split(w, 1, dim=0)

        # Process the mask: downscale it to the matching resolution. If no mask
        # is given, fall back to an all-zero mask (every background patch valid).
        if mask is None:
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            mask = F.interpolate(
                mask,
                scale_factor=1.0 / self.rate,
                mode="nearest",
                recompute_scale_factor=False,
            )
        int_ms = list(mask.size())

        m = extract_image_patches(
            mask,
            ksizes=[self.ksize, self.ksize],
            strides=[self.stride, self.stride],
            rates=[1, 1],
            padding="same",
        )
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)
        m = m[0]  # only the first sample's mask is used for all batch elements

        # mm marks background patches that contain no masked pixels
        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.0).to(
            torch.float32
        )
        mm = mm.permute(1, 0, 2, 3)

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale  # scale applied before the attention softmax
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # identity kernel for score fusion
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            """
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            """
            # Normalize background patches to (roughly) unit norm before matching
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]
            max_wi = torch.sqrt(
                reduce_sum(
                    torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True
                )
            )
            wi_normed = wi / max_wi

            # Cosine-similarity matching: convolve the foreground with the
            # normalized background patches
            xi = same_padding(
                xi, [self.ksize, self.ksize], [1, 1], [1, 1]
            )
            yi = F.conv2d(xi, wi_normed, stride=1)

            if self.fuse:
                # Fuse scores with the identity kernel to encourage spatially
                # coherent attention, once per axis
                yi = yi.view(
                    1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
                )
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(
                    1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]
                )
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(
                    1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
                )
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(
                    1, int_bs[3], int_bs[2], int_fs[3], int_fs[2]
                )
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
                yi = yi.view(
                    1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]
                )

            # Mask out invalid background patches, softmax over patch scores,
            # then mask again
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm

            offset = torch.argmax(yi, dim=1, keepdim=True)

            if int_bs != int_fs:
                # Normalize the argmax indices when fore/background sizes differ
                times = float(int_fs[2] * int_fs[3]) / float(
                    int_bs[2] * int_bs[3]
                )
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat(
                [offset // int_fs[3], offset % int_fs[3]], dim=1
            )

            # Reconstruct the foreground as a weighted sum of the raw
            # (full-resolution) background patches; /4 compensates for overlap
            wi_center = raw_wi[0]
            yi = (
                F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1)
                / 4.0
            )
            y.append(yi)
            offsets.append(offset)  # collected but not returned by this implementation

        y = torch.cat(y, dim=0)
        y = y.contiguous().view(raw_int_fs)

        return y
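

# Illustrative sketch, not part of the original module: feeding the layer a
# foreground/background feature pair plus a hole mask. The helper name,
# shapes, and values are assumptions for the example; use_cuda=False keeps it
# on the CPU.
def _example_contextual_attention():
    attn = ContextualAttention(ksize=3, stride=1, rate=2, fuse=True, use_cuda=False)
    f = torch.randn(1, 64, 32, 32)    # features of the region to fill
    b = torch.randn(1, 64, 32, 32)    # features to borrow patches from
    mask = torch.zeros(1, 1, 32, 32)  # 1 = missing pixel, 0 = valid
    mask[:, :, 8:24, 8:24] = 1.0
    y = attn(f, b, mask)
    return y.shape  # torch.Size([1, 64, 32, 32])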