import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def sub_mean(x):
    # Subtract the per-sample spatial mean over (H, W). Note that the
    # subtraction is in-place, so the caller's tensor is modified as well.
    mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
    x -= mean
    return x, mean


def InOutPaddings(x):
    # Reflection-pad the input so that height and width become multiples of 128
    # (the bit shifts round down to the nearest multiple of 2**7), and build the
    # matching negative padding that crops the output back to the original size.
    w, h = x.size(3), x.size(2)
    padding_width, padding_height = 0, 0
    if w != ((w >> 7) << 7):
        padding_width = (((w >> 7) + 1) << 7) - w
    if h != ((h >> 7) << 7):
        padding_height = (((h >> 7) + 1) << 7) - h
    paddingInput = nn.ReflectionPad2d(padding=[padding_width // 2, padding_width - padding_width // 2,
                                               padding_height // 2, padding_height - padding_height // 2])
    paddingOutput = nn.ReflectionPad2d(padding=[0 - padding_width // 2, padding_width // 2 - padding_width,
                                                0 - padding_height // 2, padding_height // 2 - padding_height])
    return paddingInput, paddingOutput


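# Usage sketch (illustrative, not part of the original API): pad an arbitrary
# input to a 128-multiple before running a model, then crop the result back.
# `model` below is a hypothetical network operating on the padded tensor.
#
#   padIn, padOut = InOutPaddings(x)   # x: (N, C, H, W)
#   x_padded = padIn(x)                # H, W now multiples of 128
#   y = model(x_padded)
#   y = padOut(y)                      # back to the original H, W

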
class ConvNorm(nn.Module):
    """Reflection-padded convolution with optional ('IN'/'BN') normalization."""
    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super(ConvNorm, self).__init__()

        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv = nn.Conv2d(in_feat, out_feat, stride=stride, kernel_size=kernel_size, bias=True)

        # self.norm stays False (no normalization) unless 'IN' or 'BN' is requested.
        self.norm = norm
        if norm == 'IN':
            self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
        elif norm == 'BN':
            self.norm = nn.BatchNorm2d(out_feat)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv(out)
        if self.norm:
            out = self.norm(out)
        return out


class UpConvNorm(nn.Module):
    """2x upsampling block: transposed conv, pixel shuffle, or bilinear upsample + 1x1 conv."""
    def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
        super(UpConvNorm, self).__init__()

        if mode == 'transpose':
            self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        elif mode == 'shuffle':
            self.upconv = nn.Sequential(
                ConvNorm(in_channels, 4*out_channels, kernel_size=3, stride=1, norm=norm),
                PixelShuffle(2))
        else:
            # Default: bilinear upsampling followed by a 1x1 convolution.
            self.upconv = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))

    def forward(self, x):
        out = self.upconv(x)
        return out


class meanShift(nn.Module):
    """Fixed 1x1 convolution that adds (sign=+1) or subtracts (sign=-1) the channel means."""
    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super(meanShift, self).__init__()
        if nChannel == 1:
            l = rgbMean[0] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1)
            self.shifter.bias.data = torch.Tensor([l])
        elif nChannel == 3:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b])
        else:
            # 6-channel case: two RGB frames stacked along the channel dimension.
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(6, 6, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b, r, g, b])

        # The shift is fixed, so keep its parameters out of the optimizer.
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        x = self.shifter(x)
        return x


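# Usage sketch (illustrative; the mean values below are example numbers such as
# the ImageNet channel means, not values taken from this module):
#
#   sub_shift = meanShift(rgbRange=1., rgbMean=(0.485, 0.456, 0.406), sign=-1)
#   add_shift = meanShift(rgbRange=1., rgbMean=(0.485, 0.456, 0.406), sign=1)
#   x = add_shift(sub_shift(x))   # round trip restores the original tensor

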
""" CONV - (BN) - RELU - CONV - (BN) """ |
|
class ResBlock(nn.Module): |
|
def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True, |
|
norm=False, act=nn.ReLU(True), downscale=False): |
|
super(ResBlock, self).__init__() |
|
|
|
self.body = nn.Sequential( |
|
ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=2 if downscale else 1), |
|
act, |
|
ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1) |
|
) |
|
|
|
self.downscale = None |
|
if downscale: |
|
self.downscale = nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2) |
|
|
|
def forward(self, x): |
|
res = x |
|
out = self.body(x) |
|
if self.downscale is not None: |
|
res = self.downscale(res) |
|
out += res |
|
|
|
return out |
|
|
|
|
|
|
|
class CALayer(nn.Module):
    """Channel attention: squeeze with global average pooling, then excite with a bottleneck MLP."""
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()

        # Global average pooling: (N, C, H, W) -> (N, C, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Channel-wise gating weights in (0, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        # Return both the re-weighted features and the attention map itself.
        return x * y, y


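# Example (illustrative): CALayer returns a tuple, which is why RCAB below
# unpacks two values from its nn.Sequential body.
#
#   ca = CALayer(64)
#   out, weights = ca(torch.randn(1, 64, 32, 32))   # weights: (1, 64, 1, 1)

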
class RCAB(nn.Module):
    """Residual channel attention block: Conv - act - Conv - CALayer, plus a skip connection."""
    def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
                 norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
        super(RCAB, self).__init__()

        # The final CALayer returns (features, attention), so self.body(x) yields a tuple.
        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm),
            act,
            ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
            CALayer(out_feat, reduction)
        )
        self.downscale = downscale
        if downscale:
            self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
        self.return_ca = return_ca

    def forward(self, x):
        res = x
        out, ca = self.body(x)
        if self.downscale:
            res = self.downConv(res)
        out += res

        if self.return_ca:
            return out, ca
        else:
            return out


class ResidualGroup(nn.Module):
    """A stack of residual blocks (RCAB or ResBlock) followed by a conv, with a long skip connection."""
    def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
        super(ResidualGroup, self).__init__()

        modules_body = [Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)
                        for _ in range(n_resblocks)]
        modules_body.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


def pixel_shuffle(input, scale_factor):
    """Rearrange channels and space: upscale for scale_factor > 1, downscale (unshuffle) for scale_factor < 1."""
    batch_size, channels, in_height, in_width = input.size()

    out_channels = int(int(channels / scale_factor) / scale_factor)
    out_height = int(in_height * scale_factor)
    out_width = int(in_width * scale_factor)

    if scale_factor >= 1:
        # Standard pixel shuffle: (N, C*r*r, H, W) -> (N, C, H*r, W*r)
        input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor, in_height, in_width)
        shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
    else:
        # Inverse (pixel unshuffle): (N, C, H, W) -> (N, C*b*b, H/b, W/b) with b = 1/scale_factor
        block_size = int(1 / scale_factor)
        input_view = input.contiguous().view(batch_size, channels, out_height, block_size, out_width, block_size)
        shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()

    return shuffle_out.view(batch_size, out_channels, out_height, out_width)


class PixelShuffle(nn.Module):
    """Module wrapper around pixel_shuffle that also supports fractional (downscaling) factors."""
    def __init__(self, scale_factor):
        super(PixelShuffle, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return pixel_shuffle(x, self.scale_factor)

    def extra_repr(self):
        return 'scale_factor={}'.format(self.scale_factor)


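# Sanity check (illustrative): for an integer factor this matches
# torch.nn.functional.pixel_shuffle, and for a fractional factor it performs
# the inverse rearrangement (pixel unshuffle).
#
#   x = torch.randn(1, 16, 8, 8)
#   assert torch.equal(pixel_shuffle(x, 2), F.pixel_shuffle(x, 2))
#   assert pixel_shuffle(x, 0.5).shape == (1, 64, 4, 4)

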
def conv(in_channels, out_channels, kernel_size,
         stride=1, bias=True, groups=1):
    # 'same'-style padding for odd kernel sizes; the stride argument is forwarded to the conv.
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=kernel_size//2,
        stride=stride,
        bias=bias,
        groups=groups)


def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=stride,
        bias=bias,
        groups=groups)


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def conv5x5(in_channels, out_channels, stride=1,
            padding=2, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=5,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def conv7x7(in_channels, out_channels, stride=1,
            padding=3, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=7,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def upconv2x2(in_channels, out_channels, mode='shuffle'):
    # 2x upsampling via transposed conv, pixel shuffle, or bilinear upsample + 1x1 conv.
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1)
    elif mode == 'shuffle':
        return nn.Sequential(
            conv3x3(in_channels, 4*out_channels),
            PixelShuffle(2))
    else:
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
            conv1x1(in_channels, out_channels))


class Interpolation(nn.Module):
    """Fuses two feature maps with a stack of residual groups of RCABs."""
    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation, self).__init__()

        # Reduce the concatenated (2 * n_feats) input back to n_feats channels.
        self.headConv = conv3x3(n_feats * 2, n_feats)

        modules_body = [
            ResidualGroup(
                RCAB,
                n_resblocks=n_resblocks,
                n_feat=n_feats,
                kernel_size=3,
                reduction=reduction,
                act=act,
                norm=norm)
            for _ in range(n_resgroups)]
        self.body = nn.Sequential(*modules_body)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        # Concatenate the two input feature maps along the channel dimension.
        x = torch.cat([x0, x1], dim=1)
        x = self.headConv(x)

        res = self.body(x)
        res += x

        out = self.tailConv(res)
        return out


class Interpolation_res(nn.Module):
    """Variant of Interpolation that uses plain ResBlocks (no channel attention)."""
    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation_res, self).__init__()

        # Reduce the concatenated (2 * n_feats) input back to n_feats channels.
        self.headConv = conv3x3(n_feats * 2, n_feats)

        modules_body = [ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats, kernel_size=3,
                                      reduction=0, act=act, norm=norm)
                        for _ in range(n_resgroups)]
        self.body = nn.Sequential(*modules_body)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        x = torch.cat([x0, x1], dim=1)
        x = self.headConv(x)

        # Run the residual groups one by one (equivalent to self.body(x)), then add the long skip.
        res = x
        for m in self.body:
            res = m(res)
        res += x

        x = self.tailConv(res)

        return x
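

if __name__ == '__main__':
    # Minimal smoke test (illustrative only; the sizes below are arbitrary and
    # not taken from any original training configuration).
    f0 = torch.randn(1, 32, 64, 64)
    f1 = torch.randn(1, 32, 64, 64)
    fusion = Interpolation(n_resgroups=2, n_resblocks=2, n_feats=32)
    print(fusion(f0, f1).shape)   # expected: torch.Size([1, 32, 64, 64])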