# JPEGArtifactRemover / network_fbcnn.py
# Ajay Harikumar
# Add initial versions of files without resolution constrain
# feff774
# raw
# history blame
# 14.4 kB
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision.models as models
'''
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
def sequential(*args):
    """Flatten the given modules into a single nn.Sequential.

    Args:
        *args: nn.Module instances. Any nn.Sequential among them is
            expanded into its direct children; arguments that are not
            nn.Module instances are skipped.

    Returns:
        The argument itself, untouched, when exactly one argument is
        given; otherwise an nn.Sequential holding the flattened modules
        in the order they were passed.

    Raises:
        NotImplementedError: if the only argument is an OrderedDict.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        # A lone module needs no wrapping container.
        return only

    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            # Splice nested containers in place instead of nesting them.
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)
# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
    """Build a stack of layers from a string of one-character layer codes.

    Each character of ``mode`` appends one layer, in order:

        'C'          nn.Conv2d
        'T'          nn.ConvTranspose2d
        'B'          nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-4)
        'I'          nn.InstanceNorm2d(out_channels, affine=True)
        'R' / 'r'    nn.ReLU, in-place / not in-place
        'L' / 'l'    nn.LeakyReLU(negative_slope), in-place / not in-place
        '2' '3' '4'  nn.PixelShuffle with that upscale factor
        'U' 'u' 'v'  nn.Upsample (nearest) with scale factor 2 / 3 / 4
        'M'          nn.MaxPool2d(kernel_size, stride, padding=0)
        'A'          nn.AvgPool2d(kernel_size, stride, padding=0)

    Args:
        in_channels: input channels of 'C' / 'T' layers.
        out_channels: output channels of 'C' / 'T' layers and feature count
            of the normalisation layers.
        kernel_size: kernel size of 'C' / 'T' / 'M' / 'A' layers.
        stride: stride of 'C' / 'T' / 'M' / 'A' layers.
        padding: padding of 'C' / 'T' layers (pooling always uses 0).
        bias: whether 'C' / 'T' layers have a bias term.
        mode: string of layer codes, processed left to right.
        negative_slope: slope used by 'L' / 'l' layers.

    Returns:
        The result of ``sequential(*layers)``: the layer itself when ``mode``
        yields exactly one layer, otherwise an nn.Sequential.

    Raises:
        NotImplementedError: if ``mode`` contains an unknown code.
    """
    layers = []
    for t in mode:
        if t == 'C':
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'T':
            layers.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'B':
            layers.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == 'I':
            layers.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == 'R':
            layers.append(nn.ReLU(inplace=True))
        elif t == 'r':
            layers.append(nn.ReLU(inplace=False))
        elif t == 'L':
            layers.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == 'l':
            layers.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == '2':
            layers.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            layers.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            layers.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            layers.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'v':
            layers.append(nn.Upsample(scale_factor=4, mode='nearest'))
        elif t == 'M':
            layers.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == 'A':
            layers.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            # The placeholder was missing, so the bad code never reached the message.
            raise NotImplementedError('Undefined type: {}'.format(t))
    return sequential(*layers)
# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class ResBlock(nn.Module):
    """Residual block: returns ``x + res(x)`` where ``res`` is built by ``conv``.

    The constructor arguments are forwarded to ``conv``; ``mode`` is the
    layer-code string of the residual branch (default 'CRC').
    Input and output channel counts must be equal.
    """

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
        super().__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        first_code = mode[0]
        if first_code == 'R' or first_code == 'L':
            # A leading in-place activation would overwrite `x`, which is
            # still needed for the skip addition; use the non-in-place code.
            mode = first_code.lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x):
        """Return the input plus the output of the residual branch."""
        return x + self.res(x)
# --------------------------------------------
# conv + subp (+ relu)
# --------------------------------------------
def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Upsample with a convolution followed by nn.PixelShuffle.

    Args:
        in_channels, out_channels: channel counts before / after upsampling.
        kernel_size, stride, padding, bias: settings of the convolution.
        mode: upscale factor ('2', '3' or '4') optionally followed by further
            ``conv`` layer codes, e.g. '2R'.
        negative_slope: forwarded to ``conv`` for leaky activations.

    Returns:
        The module built by ``conv`` for the code string ``'C' + mode``.
    """
    assert len(mode)<4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    factor = int(mode[0])
    # The convolution emits factor**2 channels per requested output channel
    # (the shuffle layer comes right after it in the code string).
    # NOTE(review): `conv` sizes any 'B'/'I' layer with this expanded count
    # as well, even though it would come after the shuffle; confirm.
    conv_out_channels = out_channels * (factor ** 2)
    return conv(in_channels, conv_out_channels, kernel_size, stride, padding, bias, mode='C' + mode, negative_slope=negative_slope)
# --------------------------------------------
# nearest_upsample + conv (+ R)
# --------------------------------------------
def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Upsample with nearest-neighbour nn.Upsample followed by a convolution.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, bias: settings of the convolution.
        mode: upscale factor ('2', '3' or '4') optionally followed by further
            ``conv`` layer codes, e.g. '2R'.
        negative_slope: forwarded to ``conv`` for leaky activations.

    Returns:
        The module built by ``conv`` once the factor has been replaced by
        the matching upsample + conv codes.
    """
    assert len(mode)<4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
    # Scale factor -> "<upsample code>C" as understood by `conv`.
    upsample_then_conv = {'2': 'UC', '3': 'uC', '4': 'vC'}
    factor_code = mode[0]
    layer_codes = mode.replace(factor_code, upsample_then_conv[factor_code])
    return conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=layer_codes, negative_slope=negative_slope)
# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Upsample with a transposed convolution.

    Args:
        in_channels, out_channels: channel counts of the transposed conv.
        kernel_size, stride: accepted but overwritten by the factor taken
            from ``mode``.
        padding, bias: settings of the transposed convolution.
        mode: upscale factor ('2', '3' or '4') optionally followed by further
            ``conv`` layer codes, e.g. '2R'.
        negative_slope: forwarded to ``conv`` for leaky activations.

    Returns:
        The module built by ``conv`` with the factor replaced by 'T'.
    """
    assert len(mode)<4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    # Kernel size and stride both equal the upscale factor.
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], 'T')
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_codes, negative_slope)
'''
# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------
'''
# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Downsample with a strided convolution.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride: accepted but overwritten by the factor taken
            from ``mode``.
        padding, bias: settings of the convolution.
        mode: downscale factor ('2', '3' or '4') optionally followed by
            further ``conv`` layer codes, e.g. '2R'.
        negative_slope: forwarded to ``conv`` for leaky activations.

    Returns:
        The module built by ``conv`` with the factor replaced by 'C'.
    """
    assert len(mode)<4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    # Kernel size and stride both equal the downscale factor.
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], 'C')
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_codes, negative_slope)
# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Downsample with max pooling followed by a convolution.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, bias: settings of the convolution
            (not of the pooling layer).
        mode: pooling factor ('2' or '3') optionally followed by further
            ``conv`` layer codes, e.g. '2R'.
        negative_slope: forwarded to ``conv`` for leaky activations.

    Returns:
        ``sequential`` of the pooling layer and the conv tail.
    """
    # NOTE(review): the default padding here is 0 while downsample_avgpool
    # defaults to 1; confirm that the difference is intended.
    assert len(mode)<4 and mode[0] in ('2', '3'), 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    # Pooling kernel size and stride both equal the factor.
    pool_factor = int(mode[0])
    layer_codes = mode.replace(mode[0], 'MC')
    pool_code, tail_codes = layer_codes[0], layer_codes[1:]
    pooling = conv(kernel_size=pool_factor, stride=pool_factor, mode=pool_code, negative_slope=negative_slope)
    tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=tail_codes, negative_slope=negative_slope)
    return sequential(pooling, tail)
# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Downsample with average pooling followed by a convolution.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, bias: settings of the convolution
            (not of the pooling layer).
        mode: pooling factor ('2' or '3') optionally followed by further
            ``conv`` layer codes, e.g. '2R'.
        negative_slope: forwarded to ``conv`` for leaky activations.

    Returns:
        ``sequential`` of the pooling layer and the conv tail.
    """
    assert len(mode)<4 and mode[0] in ('2', '3'), 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    # Pooling kernel size and stride both equal the factor.
    pool_factor = int(mode[0])
    layer_codes = mode.replace(mode[0], 'AC')
    pool_code, tail_codes = layer_codes[0], layer_codes[1:]
    pooling = conv(kernel_size=pool_factor, stride=pool_factor, mode=pool_code, negative_slope=negative_slope)
    tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=tail_codes, negative_slope=negative_slope)
    return sequential(pooling, tail)
class QFAttention(nn.Module):
    """Residual block whose branch output is scaled and shifted per channel.

    Computes ``x + (gamma * res(x) + beta)`` where ``res`` is built by
    ``conv`` from the constructor arguments. Input and output channel counts
    must be equal.
    """

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
        super().__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        first_code = mode[0]
        if first_code == 'R' or first_code == 'L':
            # A leading in-place activation would overwrite `x`, which is
            # still needed for the skip addition; use the non-in-place code.
            mode = first_code.lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x, gamma, beta):
        """Apply the modulated residual branch.

        Args:
            x: input feature map.
            gamma: multiplicative modulation; two trailing singleton axes
                are appended before use (presumably shape (N, C); confirm
                against the caller).
            beta: additive modulation, handled like ``gamma``.
        """
        # Two trailing singleton axes let the modulation broadcast over the
        # last two axes of the branch output.
        scale = gamma[..., None, None]
        shift = beta[..., None, None]
        return x + (scale * self.res(x) + shift)
class FBCNN(nn.Module):
    """Encoder-decoder network with a quality-factor (QF) branch.

    Three downsampling stages of ResBlocks lead to a bottleneck. ``qf_pred``
    maps the bottleneck features to one sigmoid-activated value per sample,
    ``qf_embed`` turns that value (or a caller-supplied one) into a 512-d
    embedding, and the ``to_gamma_*`` / ``to_beta_*`` heads derive the
    per-channel scale and shift consumed by the QFAttention blocks of the
    three upsampling stages.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        nc: four channel widths, one per resolution level.
        nb: number of ResBlock / QFAttention blocks per stage.
        act_mode: ``conv`` activation code placed between the two
            convolutions of every block.
        downsample_mode: 'avgpool', 'maxpool' or 'strideconv'.
        upsample_mode: 'upconv', 'pixelshuffle' or 'convtranspose'.

    Raises:
        NotImplementedError: for an unknown down- or upsample mode.
    """

    def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv',
                 upsample_mode='convtranspose'):
        super(FBCNN, self).__init__()
        # Private copy: the default list is shared between calls and a
        # caller's list must not be aliased by `self.nc`.
        nc = list(nc)
        # Width of the QF embedding and of the QF predictor's hidden layers.
        embed_dim = 512
        block_mode = 'C' + act_mode + 'C'

        def res_blocks(channels):
            # `nb` plain residual blocks at a fixed width.
            return [ResBlock(channels, channels, bias=True, mode=block_mode) for _ in range(nb)]

        def qf_blocks(channels):
            # `nb` QF-modulated residual blocks at a fixed width.
            return [QFAttention(channels, channels, bias=True, mode=block_mode) for _ in range(nb)]

        self.m_head = conv(in_nc, nc[0], bias=True, mode='C')
        self.nb = nb
        self.nc = nc

        # downsample
        downsample_blocks = {
            'avgpool': downsample_avgpool,
            'maxpool': downsample_maxpool,
            'strideconv': downsample_strideconv,
        }
        if downsample_mode not in downsample_blocks:
            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
        downsample_block = downsample_blocks[downsample_mode]

        self.m_down1 = sequential(
            *res_blocks(nc[0]),
            downsample_block(nc[0], nc[1], bias=True, mode='2'))
        self.m_down2 = sequential(
            *res_blocks(nc[1]),
            downsample_block(nc[1], nc[2], bias=True, mode='2'))
        self.m_down3 = sequential(
            *res_blocks(nc[2]),
            downsample_block(nc[2], nc[3], bias=True, mode='2'))

        self.m_body_encoder = sequential(*res_blocks(nc[3]))
        self.m_body_decoder = sequential(*res_blocks(nc[3]))

        # upsample
        upsample_blocks = {
            'upconv': upsample_upconv,
            'pixelshuffle': upsample_pixelshuffle,
            'convtranspose': upsample_convtranspose,
        }
        if upsample_mode not in upsample_blocks:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        upsample_block = upsample_blocks[upsample_mode]

        # ModuleList rather than Sequential: the QFAttention blocks take
        # extra (gamma, beta) arguments in forward.
        self.m_up3 = nn.ModuleList([upsample_block(nc[3], nc[2], bias=True, mode='2'),
                                    *qf_blocks(nc[2])])
        self.m_up2 = nn.ModuleList([upsample_block(nc[2], nc[1], bias=True, mode='2'),
                                    *qf_blocks(nc[1])])
        self.m_up1 = nn.ModuleList([upsample_block(nc[1], nc[0], bias=True, mode='2'),
                                    *qf_blocks(nc[0])])

        self.m_tail = conv(nc[0], out_nc, bias=True, mode='C')

        self.qf_pred = sequential(*res_blocks(nc[3]),
                                  torch.nn.AdaptiveAvgPool2d((1, 1)),
                                  torch.nn.Flatten(),
                                  # The pooled features have nc[3] channels, so
                                  # this layer must accept nc[3] inputs (it was
                                  # hard-coded to 512 before).
                                  torch.nn.Linear(nc[3], embed_dim),
                                  nn.ReLU(),
                                  torch.nn.Linear(embed_dim, embed_dim),
                                  nn.ReLU(),
                                  torch.nn.Linear(embed_dim, 1),
                                  nn.Sigmoid()
                                  )

        self.qf_embed = sequential(torch.nn.Linear(1, embed_dim),
                                   nn.ReLU(),
                                   torch.nn.Linear(embed_dim, embed_dim),
                                   nn.ReLU(),
                                   torch.nn.Linear(embed_dim, embed_dim),
                                   nn.ReLU()
                                   )

        self.to_gamma_3 = sequential(torch.nn.Linear(embed_dim, nc[2]), nn.Sigmoid())
        self.to_beta_3 = sequential(torch.nn.Linear(embed_dim, nc[2]), nn.Tanh())
        self.to_gamma_2 = sequential(torch.nn.Linear(embed_dim, nc[1]), nn.Sigmoid())
        self.to_beta_2 = sequential(torch.nn.Linear(embed_dim, nc[1]), nn.Tanh())
        self.to_gamma_1 = sequential(torch.nn.Linear(embed_dim, nc[0]), nn.Sigmoid())
        self.to_beta_1 = sequential(torch.nn.Linear(embed_dim, nc[0]), nn.Tanh())

    def _run_up_stage(self, stage, x, gamma, beta):
        """Apply one upsampling stage: its upsampler, then its QFAttention blocks."""
        x = stage[0](x)
        for i in range(self.nb):
            x = stage[i + 1](x, gamma, beta)
        return x

    def forward(self, x, qf_input=None):
        """Run the network.

        Args:
            x: input tensor; its last two axes are treated as height and
                width (presumably (N, in_nc, H, W); confirm with callers).
            qf_input: optional value fed to ``qf_embed`` instead of the
                network's own prediction (presumably shape (N, 1) to match
                the first Linear(1, 512); confirm).

        Returns:
            Tuple ``(output, qf)``: the output cropped back to the input's
            height and width, and the value predicted by ``qf_pred``.
        """
        h, w = x.size()[-2:]
        # Three factor-2 downsamplings: pad bottom/right up to the next
        # multiple of 8 by replicating border pixels.
        pad_bottom = (-h) % 8
        pad_right = (-w) % 8
        x = F.pad(x, (0, pad_right, 0, pad_bottom), mode='replicate')

        x1 = self.m_head(x)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body_encoder(x4)
        # The QF is predicted from the encoder output, before the decoder body.
        qf = self.qf_pred(x)
        x = self.m_body_decoder(x)

        # A caller-supplied QF overrides the predicted one for modulation.
        qf_embedding = self.qf_embed(qf_input if qf_input is not None else qf)
        gamma_3 = self.to_gamma_3(qf_embedding)
        beta_3 = self.to_beta_3(qf_embedding)
        gamma_2 = self.to_gamma_2(qf_embedding)
        beta_2 = self.to_beta_2(qf_embedding)
        gamma_1 = self.to_gamma_1(qf_embedding)
        beta_1 = self.to_beta_1(qf_embedding)

        # Each stage input is summed with the encoder feature of equal size.
        x = x + x4
        x = self._run_up_stage(self.m_up3, x, gamma_3, beta_3)
        x = x + x3
        x = self._run_up_stage(self.m_up2, x, gamma_2, beta_2)
        x = x + x2
        x = self._run_up_stage(self.m_up1, x, gamma_1, beta_1)
        x = x + x1

        x = self.m_tail(x)
        # Drop the rows/columns that were added by the padding above.
        x = x[..., :h, :w]
        return x, qf
if __name__ == "__main__":
    # Smoke test: push one random 96x96 RGB batch through the default model
    # and print the shapes of both outputs.
    x = torch.randn(1, 3, 96, 96)
    # The class defined in this file is FBCNN; `FBAR` was an undefined name.
    model = FBCNN()
    y, qf = model(x)
    print(y.shape, qf.shape)