# deepfillv2/network_module.py
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter
from deepfillv2.network_utils import *
# -----------------------------------------------
# Normal ConvBlock
# -----------------------------------------------
class Conv2dLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
pad_type="zero",
activation="elu",
norm="none",
sn=False,
):
super(Conv2dLayer, self).__init__()
# Initialize the padding scheme
if pad_type == "reflect":
self.pad = nn.ReflectionPad2d(padding)
elif pad_type == "replicate":
self.pad = nn.ReplicationPad2d(padding)
elif pad_type == "zero":
self.pad = nn.ZeroPad2d(padding)
else:
assert 0, "Unsupported padding type: {}".format(pad_type)
# Initialize the normalization type
if norm == "bn":
self.norm = nn.BatchNorm2d(out_channels)
elif norm == "in":
self.norm = nn.InstanceNorm2d(out_channels)
elif norm == "ln":
self.norm = LayerNorm(out_channels)
elif norm == "none":
self.norm = None
else:
assert 0, "Unsupported normalization: {}".format(norm)
        # Initialize the activation function
if activation == "relu":
self.activation = nn.ReLU(inplace=True)
elif activation == "lrelu":
self.activation = nn.LeakyReLU(0.2, inplace=True)
elif activation == "elu":
self.activation = nn.ELU(inplace=True)
elif activation == "selu":
self.activation = nn.SELU(inplace=True)
elif activation == "tanh":
self.activation = nn.Tanh()
elif activation == "sigmoid":
self.activation = nn.Sigmoid()
elif activation == "none":
self.activation = None
else:
assert 0, "Unsupported activation: {}".format(activation)
# Initialize the convolution layers
if sn:
self.conv2d = SpectralNorm(
nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
)
)
else:
self.conv2d = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
)
def forward(self, x):
x = self.pad(x)
x = self.conv2d(x)
if self.norm:
x = self.norm(x)
if self.activation:
x = self.activation(x)
return x
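
# Usage sketch (illustrative only; shapes below are hypothetical): a 3x3
# conv block that preserves spatial size via padding=1.
#
#   layer = Conv2dLayer(3, 64, kernel_size=3, stride=1, padding=1)
#   out = layer(torch.randn(1, 3, 256, 256))  # -> [1, 64, 256, 256]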
class TransposeConv2dLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
pad_type="zero",
activation="lrelu",
norm="none",
sn=False,
scale_factor=2,
):
super(TransposeConv2dLayer, self).__init__()
# Initialize the conv scheme
self.scale_factor = scale_factor
self.conv2d = Conv2dLayer(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
pad_type,
activation,
norm,
sn,
)
def forward(self, x):
x = F.interpolate(
x,
scale_factor=self.scale_factor,
mode="nearest",
recompute_scale_factor=False,
)
x = self.conv2d(x)
return x
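
# Usage sketch (illustrative only): nearest-neighbour upsampling by 2x
# followed by a size-preserving conv, so spatial dims double.
#
#   up = TransposeConv2dLayer(64, 32, kernel_size=3, padding=1)
#   out = up(torch.randn(1, 64, 32, 32))  # -> [1, 32, 64, 64]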
# -----------------------------------------------
# Gated ConvBlock
# -----------------------------------------------
class GatedConv2d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
pad_type="reflect",
activation="elu",
norm="none",
sn=False,
):
super(GatedConv2d, self).__init__()
# Initialize the padding scheme
if pad_type == "reflect":
self.pad = nn.ReflectionPad2d(padding)
elif pad_type == "replicate":
self.pad = nn.ReplicationPad2d(padding)
elif pad_type == "zero":
self.pad = nn.ZeroPad2d(padding)
else:
assert 0, "Unsupported padding type: {}".format(pad_type)
# Initialize the normalization type
if norm == "bn":
self.norm = nn.BatchNorm2d(out_channels)
elif norm == "in":
self.norm = nn.InstanceNorm2d(out_channels)
elif norm == "ln":
self.norm = LayerNorm(out_channels)
elif norm == "none":
self.norm = None
else:
assert 0, "Unsupported normalization: {}".format(norm)
        # Initialize the activation function
if activation == "relu":
self.activation = nn.ReLU(inplace=True)
elif activation == "lrelu":
self.activation = nn.LeakyReLU(0.2, inplace=True)
elif activation == "elu":
self.activation = nn.ELU()
elif activation == "selu":
self.activation = nn.SELU(inplace=True)
elif activation == "tanh":
self.activation = nn.Tanh()
elif activation == "sigmoid":
self.activation = nn.Sigmoid()
elif activation == "none":
self.activation = None
else:
assert 0, "Unsupported activation: {}".format(activation)
# Initialize the convolution layers
if sn:
self.conv2d = SpectralNorm(
nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
)
)
self.mask_conv2d = SpectralNorm(
nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
)
)
else:
self.conv2d = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
)
self.mask_conv2d = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
)
self.sigmoid = torch.nn.Sigmoid()
    def forward(self, x):
        x = self.pad(x)
        conv = self.conv2d(x)
        mask = self.mask_conv2d(x)
        gated_mask = self.sigmoid(mask)
        if self.activation:
            conv = self.activation(conv)
        # gate the activated features with the learned soft mask
        # (note: self.norm is built in __init__ but not applied in this variant)
        x = conv * gated_mask
        return x
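
# Usage sketch (illustrative only): a gated conv learns a per-pixel soft mask
# alongside the features, which helps the network ignore hole regions.
#
#   gate = GatedConv2d(4, 48, kernel_size=5, stride=1, padding=2)
#   out = gate(torch.randn(1, 4, 256, 256))  # -> [1, 48, 256, 256]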
class TransposeGatedConv2d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
pad_type="zero",
activation="lrelu",
norm="none",
sn=True,
scale_factor=2,
):
super(TransposeGatedConv2d, self).__init__()
# Initialize the conv scheme
self.scale_factor = scale_factor
self.gated_conv2d = GatedConv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
pad_type,
activation,
norm,
sn,
)
def forward(self, x):
x = F.interpolate(
x,
scale_factor=self.scale_factor,
mode="nearest",
recompute_scale_factor=False,
)
x = self.gated_conv2d(x)
return x
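
# Usage sketch (illustrative only; sn=False skips the spectral-norm wrapper):
#
#   up = TransposeGatedConv2d(128, 64, kernel_size=3, padding=1, sn=False)
#   out = up(torch.randn(1, 128, 32, 32))  # -> [1, 64, 64, 64]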
# ----------------------------------------
# Layer Norm
# ----------------------------------------
class LayerNorm(nn.Module):
def __init__(self, num_features, eps=1e-8, affine=True):
super(LayerNorm, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
if self.affine:
self.gamma = Parameter(torch.Tensor(num_features).uniform_())
self.beta = Parameter(torch.zeros(num_features))
    def forward(self, x):
        # normalize over all non-batch dimensions
        shape = [-1] + [1] * (x.dim() - 1)  # for 4d input: [-1, 1, 1, 1]
        if x.size(0) == 1:
            # fast path for batch size 1: flatten everything at once
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        # apply the learnable per-channel affine transform, if enabled
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)  # for 4d input: [1, -1, 1, 1]
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
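
# Usage sketch (illustrative only): unlike nn.LayerNorm, this normalizes over
# *all* non-batch dimensions, then applies a per-channel affine transform.
#
#   ln = LayerNorm(64)
#   y = ln(torch.randn(2, 64, 16, 16))  # per-sample mean ~0, std ~1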
# -----------------------------------------------
# SpectralNorm
# -----------------------------------------------
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
def __init__(self, module, name="weight", power_iterations=1):
super(SpectralNorm, self).__init__()
self.module = module
self.name = name
self.power_iterations = power_iterations
if not self._made_params():
self._make_params()
def _update_u_v(self):
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(
torch.mv(torch.t(w.view(height, -1).data), u.data)
)
u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _made_params(self):
try:
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
return True
except AttributeError:
return False
def _make_params(self):
w = getattr(self.module, self.name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
u.data = l2normalize(u.data)
v.data = l2normalize(v.data)
w_bar = Parameter(w.data)
del self.module._parameters[self.name]
self.module.register_parameter(self.name + "_u", u)
self.module.register_parameter(self.name + "_v", v)
self.module.register_parameter(self.name + "_bar", w_bar)
def forward(self, *args):
self._update_u_v()
return self.module.forward(*args)
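
# Usage sketch (illustrative only): wrap any module whose `weight` should be
# divided by an estimate of its largest singular value (one power iteration
# per forward pass), as in SN-GAN discriminators.
#
#   sn_conv = SpectralNorm(nn.Conv2d(3, 64, kernel_size=3))
#   y = sn_conv(torch.randn(1, 3, 32, 32))  # -> [1, 64, 30, 30]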
class ContextualAttention(nn.Module):
def __init__(
self,
ksize=3,
stride=1,
rate=1,
fuse_k=3,
softmax_scale=10,
fuse=True,
use_cuda=True,
device_ids=None,
):
super(ContextualAttention, self).__init__()
self.ksize = ksize
self.stride = stride
self.rate = rate
self.fuse_k = fuse_k
self.softmax_scale = softmax_scale
self.fuse = fuse
self.use_cuda = use_cuda
self.device_ids = device_ids
    def forward(self, f, b, mask=None):
        """Contextual attention layer implementation.
        Contextual attention was first introduced in the paper:
        Generative Image Inpainting with Contextual Attention, Yu et al.
        Args:
            f: Input feature to match (foreground).
            b: Input feature to be matched against (background).
            mask: Input mask for b, indicating patches that are not available
                (optional; if None, all background patches are treated as
                available).
        The patch size (ksize), stride, matching dilation (rate) and
        softmax_scale are taken from the constructor arguments.
        Returns:
            torch.Tensor: output
        """
# get shapes
raw_int_fs = list(f.size()) # b*c*h*w
raw_int_bs = list(b.size()) # b*c*h*w
# extract patches from background with stride and rate
kernel = 2 * self.rate
# raw_w is extracted for reconstruction
raw_w = extract_image_patches(
b,
ksizes=[kernel, kernel],
strides=[self.rate * self.stride, self.rate * self.stride],
rates=[1, 1],
padding="same",
) # [N, C*k*k, L]
# raw_shape: [N, C, k, k, L] [4, 192, 4, 4, 1024]
raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
raw_w_groups = torch.split(raw_w, 1, dim=0)
# downscaling foreground option: downscaling both foreground and
# background for matching and use original background for reconstruction.
f = F.interpolate(
f,
scale_factor=1.0 / self.rate,
mode="nearest",
recompute_scale_factor=False,
)
b = F.interpolate(
b,
scale_factor=1.0 / self.rate,
mode="nearest",
recompute_scale_factor=False,
)
int_fs = list(f.size()) # b*c*h*w
int_bs = list(b.size())
f_groups = torch.split(
f, 1, dim=0
) # split tensors along the batch dimension
# w shape: [N, C*k*k, L]
w = extract_image_patches(
b,
ksizes=[self.ksize, self.ksize],
strides=[self.stride, self.stride],
rates=[1, 1],
padding="same",
)
# w shape: [N, C, k, k, L]
w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
w_groups = torch.split(w, 1, dim=0)
        # process mask (downscale it to match the downscaled background);
        # if no mask is given, treat every background patch as available
        if mask is None:
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            mask = F.interpolate(
                mask,
                scale_factor=1.0 / self.rate,
                mode="nearest",
                recompute_scale_factor=False,
            )
int_ms = list(mask.size())
# m shape: [N, C*k*k, L]
m = extract_image_patches(
mask,
ksizes=[self.ksize, self.ksize],
strides=[self.stride, self.stride],
rates=[1, 1],
padding="same",
)
# m shape: [N, C, k, k, L]
m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k]
m = m[0] # m shape: [L, C, k, k]
# mm shape: [L, 1, 1, 1]
mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.0).to(
torch.float32
)
mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]
y = []
offsets = []
k = self.fuse_k
        scale = self.softmax_scale  # temperature applied before the softmax to sharpen the attention
fuse_weight = torch.eye(k).view(1, 1, k, k) # 1*1*k*k
if self.use_cuda:
fuse_weight = fuse_weight.cuda()
for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
"""
O => output channel as a conv filter
I => input channel as a conv filter
xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
"""
# conv for compare
escape_NaN = torch.FloatTensor([1e-4])
if self.use_cuda:
escape_NaN = escape_NaN.cuda()
wi = wi[0] # [L, C, k, k]
max_wi = torch.sqrt(
reduce_sum(
torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True
)
)
wi_normed = wi / max_wi
# xi shape: [1, C, H, W], yi shape: [1, L, H, W]
xi = same_padding(
xi, [self.ksize, self.ksize], [1, 1], [1, 1]
) # xi: 1*c*H*W
yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W]
# conv implementation for fuse scores to encourage large patches
if self.fuse:
# make all of depth to spatial resolution
yi = yi.view(
1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
) # (B=1, I=1, H=32*32, W=32*32)
yi = same_padding(yi, [k, k], [1, 1], [1, 1])
yi = F.conv2d(
yi, fuse_weight, stride=1
) # (B=1, C=1, H=32*32, W=32*32)
yi = yi.contiguous().view(
1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]
) # (B=1, 32, 32, 32, 32)
yi = yi.permute(0, 2, 1, 4, 3)
yi = yi.contiguous().view(
1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
)
yi = same_padding(yi, [k, k], [1, 1], [1, 1])
yi = F.conv2d(yi, fuse_weight, stride=1)
yi = yi.contiguous().view(
1, int_bs[3], int_bs[2], int_fs[3], int_fs[2]
)
yi = yi.permute(0, 2, 1, 4, 3).contiguous()
yi = yi.view(
1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]
) # (B=1, C=32*32, H=32, W=32)
# softmax to match
yi = yi * mm
yi = F.softmax(yi * scale, dim=1)
yi = yi * mm # [1, L, H, W]
offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W
if int_bs != int_fs:
# Normalize the offset value to match foreground dimension
times = float(int_fs[2] * int_fs[3]) / float(
int_bs[2] * int_bs[3]
)
offset = ((offset + 1).float() * times - 1).to(torch.int64)
            # decode the flat argmax index into (row, col) coordinates;
            # torch.div with rounding_mode="floor" replaces the deprecated
            # integer `//` on tensors
            offset = torch.cat(
                [
                    torch.div(offset, int_fs[3], rounding_mode="floor"),
                    offset % int_fs[3],
                ],
                dim=1,
            )  # 1*2*H*W
# deconv for patch pasting
wi_center = raw_wi[0]
# yi = F.pad(yi, [0, 1, 0, 1]) # here may need conv_transpose same padding
yi = (
F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1)
/ 4.0
) # (B=1, C=128, H=64, W=64)
y.append(yi)
            offsets.append(offset)  # collected but not returned in this variant
y = torch.cat(y, dim=0) # back to the mini-batch
        # assign the reshape: previously the result of .view() was discarded,
        # so the output was returned in its pre-reshape layout
        y = y.contiguous().view(raw_int_fs)
        return y
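
# Usage sketch (illustrative only; shapes are hypothetical): with rate=2 the
# layer matches 3x3 patches at half resolution and pastes the corresponding
# 4x4 background patches back at full resolution.
#
#   attn = ContextualAttention(ksize=3, stride=1, rate=2, use_cuda=False)
#   f = torch.randn(1, 128, 64, 64)  # foreground (hole) features
#   b = torch.randn(1, 128, 64, 64)  # background (context) features
#   m = torch.zeros(1, 1, 64, 64)    # nonzero marks unavailable b patches
#   out = attn(f, b, m)              # -> [1, 128, 64, 64]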