DifFace / models /srcnn.py
Zongsheng
first upload
06f26d7
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-07-12 20:35:28
import math
from torch import nn
import torch.nn.functional as F
class SRCNN(nn.Module):
    """Plain convolutional super-resolution network with sub-pixel upsampling.

    Layout: a 5x5 head conv, ``depth - 1`` body layers of (5x5 conv +
    LeakyReLU), ``log2(sf)`` upsampling stages of (3x3 conv to 4x channels +
    LeakyReLU + PixelShuffle(2)), and a final 5x5 conv to ``out_chns``.
    All convs are padded to preserve spatial size, so the output is ``sf``
    times larger than the input in each spatial dimension.

    Args:
        in_chns: Number of input channels.
        out_chns: Number of output channels; defaults to ``in_chns``.
        num_chns: Channel width of the hidden feature maps.
        depth: The body stacks ``depth - 1`` conv/activation pairs.
        sf: Upscaling factor; must be a positive power of two (1, 2, 4, ...).

    Raises:
        ValueError: If ``sf`` is not a positive power of two.
    """

    def __init__(self, in_chns, out_chns=None, num_chns=64, depth=8, sf=4):
        super().__init__()
        # Every tail stage upsamples by exactly 2x, so only powers of two can
        # be represented; anything else used to be silently truncated
        # (e.g. sf=3 built a 2x model).
        num_up = round(math.log2(sf)) if sf >= 1 else -1
        if num_up < 0 or 2 ** num_up != sf:
            raise ValueError(f"sf must be a positive power of two, got {sf!r}")
        self.sf = sf
        out_chns = in_chns if out_chns is None else out_chns

        self.head = nn.Conv2d(in_chns, num_chns, kernel_size=5, padding=2)

        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        tail = []
        for _ in range(num_up):
            # Expand to 4x channels so PixelShuffle(2) can trade them for a
            # 2x larger spatial grid while keeping num_chns channels.
            tail.append(nn.Conv2d(num_chns, num_chns * 4, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
        tail.append(nn.Conv2d(num_chns, out_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Map ``x`` of shape (N, in_chns, H, W) to (N, out_chns, H*sf, W*sf)."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y
class SRCNNFSR(nn.Module):
    """Resolution-preserving conv network that works on a folded feature map.

    The head folds the input down by ``down_scale_factor`` using repeated
    PixelUnshuffle(2) stages. The body is a stack of 5x5 conv + LeakyReLU
    layers. The tail unfolds back with the same number of PixelShuffle(2)
    stages. The output therefore has the same spatial size and channel count
    (``in_chns``) as the input. Input height and width must be divisible by
    ``down_scale_factor`` for PixelUnshuffle to apply.

    Args:
        in_chns: Number of input (and output) channels.
        down_scale_factor: Spatial folding applied by the head and undone by
            the tail; must be a positive power of two.
        num_chns: Hidden width. Each tail stage divides the width by 4, so it
            must be divisible by ``down_scale_factor ** 2``.
        depth: The body stacks ``depth - 1`` conv/activation pairs.
        sf: Stored as ``self.sf``; not used by this module's own layers.

    Raises:
        ValueError: If ``down_scale_factor`` is not a positive power of two,
            or ``num_chns`` is not divisible by ``down_scale_factor ** 2``.

    NOTE: with ``down_scale_factor=1`` the head is empty, so the body receives
    the raw input; that only works when ``in_chns == num_chns``.
    """

    def __init__(self, in_chns, down_scale_factor=2, num_chns=64, depth=8, sf=4):
        super().__init__()
        self.sf = sf

        # Each head/tail stage folds/unfolds by exactly 2x; other factors used
        # to be silently truncated to the nearest lower power of two.
        num_down = round(math.log2(down_scale_factor)) if down_scale_factor >= 1 else -1
        if num_down < 0 or 2 ** num_down != down_scale_factor:
            raise ValueError(
                f"down_scale_factor must be a positive power of two, got {down_scale_factor!r}"
            )
        # The tail divides the width by 4 per stage, and the head needs
        # num_chns // 4 * 4 == num_chns; both hold iff this divisibility holds.
        if num_chns % (4 ** num_down) != 0:
            raise ValueError(
                f"num_chns ({num_chns}) must be divisible by down_scale_factor**2 "
                f"({4 ** num_down})"
            )

        head = []
        in_chns_shuffle = in_chns * 4  # channels right after a PixelUnshuffle(2)
        for ii in range(num_down):
            head.append(nn.PixelUnshuffle(2))
            head.append(nn.Conv2d(in_chns_shuffle, num_chns, kernel_size=3, padding=1))
            if ii + 1 < num_down:
                # Shrink to num_chns // 4 so the next PixelUnshuffle(2)
                # brings the width back to exactly num_chns.
                # NOTE(review): no activation follows the final head conv,
                # matching SRCNN's head — confirm against trained checkpoints.
                head.append(nn.Conv2d(num_chns, num_chns // 4, kernel_size=5, padding=2))
                head.append(nn.LeakyReLU(0.2, inplace=True))
                in_chns_shuffle = num_chns
        self.head = nn.Sequential(*head)

        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        tail = []
        chns = num_chns  # local counter; the num_chns argument is left untouched
        for _ in range(num_down):
            tail.append(nn.Conv2d(chns, chns, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
            chns //= 4  # PixelShuffle(2) trades 4x channels for 2x resolution
        tail.append(nn.Conv2d(chns, in_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Map ``x`` of shape (N, in_chns, H, W) to a tensor of the same shape."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y