# SwinTExCo/src/models/CNN/ColorVidNet.py
# Source listing: duongttr, commit 3d85088 ("Update new app")
import torch
import torch.nn as nn
import torch.nn.parallel
class ColorVidNet(nn.Module):
    """Encoder-decoder CNN that predicts two chrominance (ab) channels.

    Encoder: three resolution-halving stages, each ending in a depthwise 1x1
    stride-2 conv (``*_norm_ss``), followed by 512-channel stages at 1/8
    resolution; the conv5 and conv6 stages use dilation 2. Decoder: three
    nearest-neighbour x2 upsample + 3x3 conv stages, each summed with a 3x3
    conv projection (``*_short``) of the matching encoder feature map.
    All normalization layers are non-affine ``InstanceNorm2d``.

    Attribute names must not change: they are the ``state_dict`` keys that
    checkpoints are saved and loaded with.

    Args:
        ic: number of input channels expected by ``forward``.
    """

    def __init__(self, ic):
        super().__init__()
        # Each submodule is built exactly once, in its final form. Registration
        # order is significant: it fixes the order of ``parameters()`` (and hence
        # of saved optimizer state) and of ``state_dict`` keys, so keep it.

        # ---- Encoder, level 1 (full resolution) ----
        self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv1_2norm = nn.InstanceNorm2d(64)
        # Depthwise 1x1 conv with stride 2: halves H and W.
        self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)

        # ---- Encoder, level 2 (1/2 resolution) ----
        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv2_2norm = nn.InstanceNorm2d(128)
        self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)

        # ---- Encoder, level 3 (1/4 resolution) ----
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv3_3norm = nn.InstanceNorm2d(256)
        self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)

        # ---- Bottleneck (1/8 resolution) ----
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv4_3norm = nn.InstanceNorm2d(512)
        # conv5 / conv6: 3x3 kernels with dilation 2 and padding 2 keep H and W.
        self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv5_3norm = nn.InstanceNorm2d(512)
        self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv6_3norm = nn.InstanceNorm2d(512)
        self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv7_3norm = nn.InstanceNorm2d(512)

        # ---- Decoder, back to 1/4 resolution (+ skip from level 3) ----
        self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
        self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv8_3norm = nn.InstanceNorm2d(256)

        # ---- Decoder, back to 1/2 resolution (+ skip from level 2) ----
        self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
        self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv9_2norm = nn.InstanceNorm2d(128)

        # ---- Decoder, back to full resolution (+ skip from level 1) ----
        self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
        self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
        # 1x1 projection to the two output channels.
        self.conv10_ab = nn.Conv2d(128, 2, 1, 1)

        # ---- Activations: one PReLU (learnable slope) per conv stage ----
        self.relu1_1 = nn.PReLU()
        self.relu1_2 = nn.PReLU()
        self.relu2_1 = nn.PReLU()
        self.relu2_2 = nn.PReLU()
        self.relu3_1 = nn.PReLU()
        self.relu3_2 = nn.PReLU()
        self.relu3_3 = nn.PReLU()
        self.relu4_1 = nn.PReLU()
        self.relu4_2 = nn.PReLU()
        self.relu4_3 = nn.PReLU()
        self.relu5_1 = nn.PReLU()
        self.relu5_2 = nn.PReLU()
        self.relu5_3 = nn.PReLU()
        self.relu6_1 = nn.PReLU()
        self.relu6_2 = nn.PReLU()
        self.relu6_3 = nn.PReLU()
        self.relu7_1 = nn.PReLU()
        self.relu7_2 = nn.PReLU()
        self.relu7_3 = nn.PReLU()
        self.relu8_1_comb = nn.PReLU()
        self.relu8_2 = nn.PReLU()
        self.relu8_3 = nn.PReLU()
        self.relu9_1_comb = nn.PReLU()
        self.relu9_2 = nn.PReLU()
        self.relu10_1_comb = nn.PReLU()
        # Fixed-slope, in-place; safe because its input (conv10_2 output) is not reused.
        self.relu10_2 = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        """Predict ab chrominance from the stacked input channels.

        Args:
            x: tensor of shape (N, ic, H, W). Per the original docstring the
                channels stack: gray image (1 channel), ab (2 channels),
                ab_err, ba_err — the exact layout is the caller's contract.
                H and W must be divisible by 8: the three stride-2
                downsamples and the three x2 upsamples must line up for the
                skip-connection sums.

        Returns:
            Tensor of shape (N, 2, H, W) with values in (-128, 128)
            (``tanh`` scaled by 128).
        """
        # Encoder. ``h`` is the running trunk tensor; rebinding it lets
        # activations that are no longer needed be freed under no_grad.
        # Only the three normalized skip features are kept alive.
        h = self.relu1_1(self.conv1_1(x))
        h = self.relu1_2(self.conv1_2(h))
        skip1 = self.conv1_2norm(h)  # full resolution, 64 ch

        h = self.conv1_2norm_ss(skip1)
        h = self.relu2_1(self.conv2_1(h))
        h = self.relu2_2(self.conv2_2(h))
        skip2 = self.conv2_2norm(h)  # 1/2 resolution, 128 ch

        h = self.conv2_2norm_ss(skip2)
        h = self.relu3_1(self.conv3_1(h))
        h = self.relu3_2(self.conv3_2(h))
        h = self.relu3_3(self.conv3_3(h))
        skip3 = self.conv3_3norm(h)  # 1/4 resolution, 256 ch

        # Bottleneck at 1/8 resolution.
        h = self.conv3_3norm_ss(skip3)
        h = self.relu4_1(self.conv4_1(h))
        h = self.relu4_2(self.conv4_2(h))
        h = self.relu4_3(self.conv4_3(h))
        h = self.conv4_3norm(h)
        h = self.relu5_1(self.conv5_1(h))
        h = self.relu5_2(self.conv5_2(h))
        h = self.relu5_3(self.conv5_3(h))
        h = self.conv5_3norm(h)
        h = self.relu6_1(self.conv6_1(h))
        h = self.relu6_2(self.conv6_2(h))
        h = self.relu6_3(self.conv6_3(h))
        h = self.conv6_3norm(h)
        h = self.relu7_1(self.conv7_1(h))
        h = self.relu7_2(self.conv7_2(h))
        h = self.relu7_3(self.conv7_3(h))
        h = self.conv7_3norm(h)

        # Decoder: upsample, add the projected skip feature, refine.
        h = self.relu8_1_comb(self.conv8_1(h) + self.conv3_3_short(skip3))
        h = self.relu8_2(self.conv8_2(h))
        h = self.relu8_3(self.conv8_3(h))
        h = self.conv8_3norm(h)

        h = self.relu9_1_comb(self.conv9_1(h) + self.conv2_2_short(skip2))
        h = self.relu9_2(self.conv9_2(h))
        h = self.conv9_2norm(h)

        h = self.relu10_1_comb(self.conv10_1(h) + self.conv1_2_short(skip1))
        h = self.relu10_2(self.conv10_2(h))
        return torch.tanh(self.conv10_ab(h)) * 128