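"""ColorVidNet: an encoder-decoder colorization network.

Three strided stages downsample the input by 8x, a bottleneck of plain and
dilated 3x3 convolutions (conv5-conv7) enlarges the receptive field at 1/8
resolution, and three upsampling stages with additive skip connections
restore the input resolution before a final 1x1 convolution predicts the
two ab chrominance channels.
"""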
import torch
import torch.nn as nn


class ColorVidNet(nn.Module):
    def __init__(self, ic):
        super(ColorVidNet, self).__init__()
        # Encoder. The *_ss layers are 1x1 depthwise convolutions with
        # stride 2 that subsample the feature maps between stages.
        self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv1_2norm = nn.InstanceNorm2d(64)
        self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)
        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv2_2norm = nn.InstanceNorm2d(128)
        self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv3_3norm = nn.InstanceNorm2d(256)
        self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv4_3norm = nn.InstanceNorm2d(512)
        # Bottleneck at 1/8 resolution. conv5/conv6 use stride 1,
        # padding 2, dilation 2, so spatial size is preserved.
        self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv5_3norm = nn.InstanceNorm2d(512)
        self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
        self.conv6_3norm = nn.InstanceNorm2d(512)
        self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv7_3norm = nn.InstanceNorm2d(512)
        # Decoder. Each conv*_1 upsamples 2x; the conv*_short layers
        # project encoder features for the additive skip connections.
        self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
        self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv8_3norm = nn.InstanceNorm2d(256)
        self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
        self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv9_2norm = nn.InstanceNorm2d(128)
        self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
        self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.conv10_ab = nn.Conv2d(128, 2, 1, 1)
        # One activation per convolution; all PReLU except the final LeakyReLU.
        self.relu1_1 = nn.PReLU()
        self.relu1_2 = nn.PReLU()
        self.relu2_1 = nn.PReLU()
        self.relu2_2 = nn.PReLU()
        self.relu3_1 = nn.PReLU()
        self.relu3_2 = nn.PReLU()
        self.relu3_3 = nn.PReLU()
        self.relu4_1 = nn.PReLU()
        self.relu4_2 = nn.PReLU()
        self.relu4_3 = nn.PReLU()
        self.relu5_1 = nn.PReLU()
        self.relu5_2 = nn.PReLU()
        self.relu5_3 = nn.PReLU()
        self.relu6_1 = nn.PReLU()
        self.relu6_2 = nn.PReLU()
        self.relu6_3 = nn.PReLU()
        self.relu7_1 = nn.PReLU()
        self.relu7_2 = nn.PReLU()
        self.relu7_3 = nn.PReLU()
        self.relu8_1_comb = nn.PReLU()
        self.relu8_2 = nn.PReLU()
        self.relu8_3 = nn.PReLU()
        self.relu9_1_comb = nn.PReLU()
        self.relu9_2 = nn.PReLU()
        self.relu10_1_comb = nn.PReLU()
        self.relu10_2 = nn.LeakyReLU(0.2, True)
    def forward(self, x):
        """x: concatenation along the channel dimension of the gray image
        (1 channel), the ab channels (2), and the ab_err / ba_err maps."""
        # Encoder: three stages, each ending in a strided depthwise
        # convolution that halves the spatial resolution.
        conv1_1 = self.relu1_1(self.conv1_1(x))
        conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
        conv1_2norm = self.conv1_2norm(conv1_2)
        conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm)
        conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss))
        conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
        conv2_2norm = self.conv2_2norm(conv2_2)
        conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm)
        conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss))
        conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
        conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
        conv3_3norm = self.conv3_3norm(conv3_3)
        conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm)
        conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss))
        conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
        conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
        conv4_3norm = self.conv4_3norm(conv4_3)
        # Bottleneck at 1/8 resolution (conv5/conv6 are dilated).
        conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm))
        conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
        conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
        conv5_3norm = self.conv5_3norm(conv5_3)
        conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm))
        conv6_2 = self.relu6_2(self.conv6_2(conv6_1))
        conv6_3 = self.relu6_3(self.conv6_3(conv6_2))
        conv6_3norm = self.conv6_3norm(conv6_3)
        conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm))
        conv7_2 = self.relu7_2(self.conv7_2(conv7_1))
        conv7_3 = self.relu7_3(self.conv7_3(conv7_2))
        conv7_3norm = self.conv7_3norm(conv7_3)
        # Decoder: upsample and fuse encoder features via additive skips.
        conv8_1 = self.conv8_1(conv7_3norm)
        conv3_3_short = self.conv3_3_short(conv3_3norm)
        conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short)
        conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb))
        conv8_3 = self.relu8_3(self.conv8_3(conv8_2))
        conv8_3norm = self.conv8_3norm(conv8_3)
        conv9_1 = self.conv9_1(conv8_3norm)
        conv2_2_short = self.conv2_2_short(conv2_2norm)
        conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short)
        conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb))
        conv9_2norm = self.conv9_2norm(conv9_2)
        conv10_1 = self.conv10_1(conv9_2norm)
        conv1_2_short = self.conv1_2_short(conv1_2norm)
        conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short)
        conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb))
        conv10_ab = self.conv10_ab(conv10_2)
        # Map predictions to the Lab ab range (-128, 128).
        return torch.tanh(conv10_ab) * 128
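

if __name__ == "__main__":
    # Minimal shape check, not part of the original module. `ic` depends on
    # how the caller stacks the inputs named in forward's docstring; 7
    # (1 gray + 2 ab + 2 ab_err + 2 ba_err) is only an illustrative guess.
    net = ColorVidNet(ic=7)
    x = torch.randn(1, 7, 64, 64)  # H and W should be divisible by 8
    with torch.no_grad():
        ab = net(x)
    print(ab.shape)  # torch.Size([1, 2, 64, 64]), values in (-128, 128)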