Spaces:
Running
Running
File size: 4,705 Bytes
d960e2d 8dc35d7 d960e2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
try:
from .nafnet_utils.arch_util import LayerNorm2d
from .nafnet_utils.arch_model import SimpleGate
except:
from nafnet_utils.arch_util import LayerNorm2d
from nafnet_utils.arch_model import SimpleGate
'''
https://github.com/wangchx67/FourLLIE.git
'''
class FreNAFBlock(nn.Module):
    '''
    Frequency-domain block that reweights the Fourier amplitude spectrum.

    The input goes through a real 2-D FFT. Only the amplitude is passed through
    a small 1x1-conv MLP (``process1``); the phase is kept unchanged. The
    modified spectrum is then transformed back to the spatial domain.

    Args:
        nc: number of input (and output) channels.
        expand: hidden-channel multiplier of the 1x1-conv MLP.
    '''

    def __init__(self, nc, expand=2):
        super().__init__()
        hidden = expand * nc
        # 1x1 conv -> LeakyReLU(0.1) -> 1x1 conv, applied to the amplitude spectrum.
        self.process1 = nn.Sequential(
            nn.Conv2d(nc, hidden, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(hidden, nc, 1, 1, 0),
        )

    def forward(self, x):
        '''Return a tensor with the same shape as the 4-D input ``x``.'''
        _, _, height, width = x.shape
        spectrum = torch.fft.rfft2(x, norm='backward')
        amplitude = self.process1(torch.abs(spectrum))
        phase = torch.angle(spectrum)
        rebuilt = torch.complex(amplitude * torch.cos(phase),
                                amplitude * torch.sin(phase))
        # Passing the original (H, W) lets irfft2 restore the exact input size,
        # including odd widths.
        return torch.fft.irfft2(rebuilt, s=(height, width), norm='backward')
# ------------------------------------------------------------------------------------------------
class Branch(nn.Module):
    '''
    One dilated branch built from three layers:
    an optional depth-wise 3x3 conv, a 1x1 conv expanding c -> DW_Expand * c
    channels, and a depth-wise 3x3 conv with the requested dilation.

    Args:
        c: number of input channels.
        DW_Expand: channel expansion factor (output has DW_Expand * c channels).
        dilation: dilation (and padding) of the final depth-wise 3x3 conv.
        extra_depth_wise: if True, prepend a depth-wise 3x3 conv on the input.
    '''
    def __init__(self, c, DW_Expand, dilation = 1, extra_depth_wise = False):
        super().__init__()
        self.dw_channel = DW_Expand * c
        expanded = self.dw_channel
        if extra_depth_wise:
            first = nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c,
                              bias=True, dilation=1)
        else:
            # Identity placeholder: keeps the layer indices (and thus the
            # state_dict keys branch.1 / branch.2) the same in both configurations.
            first = nn.Identity()
        expand_conv = nn.Conv2d(in_channels=c, out_channels=expanded, kernel_size=1,
                                padding=0, stride=1, groups=1, bias=True, dilation=1)
        # padding == dilation keeps H and W unchanged for a 3x3 kernel with stride 1.
        dilated_dw = nn.Conv2d(in_channels=expanded, out_channels=expanded,
                               kernel_size=3, padding=dilation, stride=1,
                               groups=expanded, bias=True, dilation=dilation)
        self.branch = nn.Sequential(first, expand_conv, dilated_dw)

    def forward(self, input):
        '''Apply the branch; the output has DW_Expand * c channels and the input's H and W.'''
        return self.branch(input)
class EBlock_freq(nn.Module):
    '''
    Residual block with multi-dilation spatial branches followed by a
    frequency-domain refinement step.

    Step 1 (spatial):
        norm1 -> sum of one ``Branch`` per dilation -> SimpleGate ->
        global-average-pool + 1x1-conv channel reweighting (``sca``) -> 1x1 conv.
        The result is added to the input, scaled by the learnable ``beta``.
    Step 2 (frequency):
        norm2 -> ``FreNAFBlock``. Its output multiplies the step-1 result, and
        that product is added back, scaled by the learnable ``gamma``.

    ``beta`` and ``gamma`` are initialised to zero, so a freshly built block
    returns its input unchanged (for finite inputs).

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion factor inside each branch (each branch
            outputs ``DW_Expand * c`` channels).
        dilations: iterable of dilation rates; one branch is built per entry.
            Must be non-empty.
        extra_depth_wise: forwarded to every ``Branch``.

    Raises:
        ValueError: if ``dilations`` is empty.
    '''
    def __init__(self, c, DW_Expand=2, dilations = (1,), extra_depth_wise = False):
        super().__init__()
        # Tuple default avoids the shared-mutable-default pitfall; any iterable is accepted.
        dilations = tuple(dilations)
        if not dilations:
            # forward() sums the branch outputs; with no branches there is nothing to gate.
            raise ValueError("EBlock_freq needs at least one dilation")
        # One spatial branch per dilation rate; their outputs are summed in forward().
        self.branches = nn.ModuleList(
            Branch(c, DW_Expand, dilation=dilation, extra_depth_wise=extra_depth_wise)
            for dilation in dilations
        )
        self.dw_channel = DW_Expand * c
        # The layers below take dw_channel // 2 channels. This assumes SimpleGate
        # halves the channel count of the summed branch output — TODO confirm
        # against SimpleGate.
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True, dilation = 1),
        )
        self.sg1 = SimpleGate()
        self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
        # Normalisation layers for the two steps, then the frequency-domain step.
        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)
        self.freq = FreNAFBlock(nc = c, expand=2)
        # Per-channel residual scales; zero init makes the block an identity at start.
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # Step 1: spatial mixing over all dilation branches.
        x = self.norm1(inp)
        z = sum(branch(x) for branch in self.branches)
        z = self.sg1(z)
        x = self.sca(z) * z
        x = self.conv3(x)
        y = inp + self.beta * x
        # Step 2: frequency-domain refinement. norm2(y) and the FreNAFBlock
        # output both have c channels, like y.
        x_freq = self.freq(self.norm2(y))
        x = y * x_freq
        return y + x * self.gamma
#----------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Quick complexity check: print the MACs and parameter count of one
    # EBlock_freq. Requires the third-party ``ptflops`` package.
    img_channel = 128
    dilations = [1, 4, 9]
    extra_depth_wise = True

    net = EBlock_freq(c = img_channel,
                      dilations = dilations,
                      extra_depth_wise=extra_depth_wise)

    # ptflops expects the input resolution without the batch dimension: (C, H, W).
    # C is tied to img_channel so the probe always matches the block's channel count.
    inp_shape = (img_channel, 32, 32)

    from ptflops import get_model_complexity_info
    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
    print(macs, params)