Colorization / models.py
ChiKyi's picture
update file
4d92358
raw
history blame
6.92 kB
import torch
from torch import nn, optim
from loss import GANLoss
class UnetBlock(nn.Module):
def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
innermost=False, outermost=False):
super().__init__()
self.outermost = outermost
if input_c is None: input_c = nf
downconv = nn.Conv2d(input_c, ni, kernel_size=4,
stride=2, padding=1, bias=False)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = nn.BatchNorm2d(ni)
uprelu = nn.ReLU(True)
upnorm = nn.BatchNorm2d(nf)
if outermost:
upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
stride=2, padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
stride=2, padding=1, bias=False)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
stride=2, padding=1, bias=False)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if dropout: up += [nn.Dropout(0.5)]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
return torch.cat([x, self.model(x)], 1)
class Unet(nn.Module):
def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
super().__init__()
unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
for _ in range(n_down - 5):
unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
out_filters = num_filters * 8
for _ in range(3):
unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
out_filters //= 2
self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
def forward(self, x):
return self.model(x)
class PatchDiscriminator(nn.Module):
def __init__(self, input_c, num_filters=64, n_down=3):
super().__init__()
model = [self.get_layers(input_c, num_filters, norm=False)]
model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down - 1) else 2)
for i in range(n_down)] # the 'if' statement is taking care of not using
# stride of 2 for the last block in this loop
model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False,
act=False)] # Make sure to not use normalization or
# activation for the last layer of the model
self.model = nn.Sequential(*model)
def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True,
act=True): # when needing to make some repeatitive blocks of layers,
layers = [
nn.Conv2d(ni, nf, k, s, p, bias=not norm)] # it's always helpful to make a separate method for that purpose
if norm: layers += [nn.BatchNorm2d(nf)]
if act: layers += [nn.LeakyReLU(0.2, True)]
return nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
def init_weights(net, init='norm', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and 'Conv' in classname:
if init == 'norm':
nn.init.normal_(m.weight.data, mean=0.0, std=gain)
elif init == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif 'BatchNorm2d' in classname:
nn.init.normal_(m.weight.data, 1., gain)
nn.init.constant_(m.bias.data, 0.)
net.apply(init_func)
print(f"model initialized with {init} initialization")
return net
def init_model(model, device):
model = model.to(device)
model = init_weights(model)
return model
class MainModel(nn.Module):
def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
beta1=0.5, beta2=0.999, lambda_L1=100.):
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.lambda_L1 = lambda_L1
if net_G is None:
self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
else:
self.net_G = net_G.to(self.device)
self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
self.L1criterion = nn.L1Loss()
self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
def set_requires_grad(self, model, requires_grad=True):
for p in model.parameters():
p.requires_grad = requires_grad
def setup_input(self, data):
self.L = data['L'].to(self.device)
self.ab = data['ab'].to(self.device)
def forward(self):
self.fake_color = self.net_G(self.L)
def backward_D(self):
fake_image = torch.cat([self.L, self.fake_color], dim=1)
fake_preds = self.net_D(fake_image.detach())
self.loss_D_fake = self.GANcriterion(fake_preds, False)
real_image = torch.cat([self.L, self.ab], dim=1)
real_preds = self.net_D(real_image)
self.loss_D_real = self.GANcriterion(real_preds, True)
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
self.loss_D.backward()
def backward_G(self):
fake_image = torch.cat([self.L, self.fake_color], dim=1)
fake_preds = self.net_D(fake_image)
self.loss_G_GAN = self.GANcriterion(fake_preds, True)
self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
self.loss_G = self.loss_G_GAN + self.loss_G_L1
self.loss_G.backward()
def optimize(self):
self.forward()
self.net_D.train()
self.set_requires_grad(self.net_D, True)
self.opt_D.zero_grad()
self.backward_D()
self.opt_D.step()
self.net_G.train()
self.set_requires_grad(self.net_D, False)
self.opt_G.zero_grad()
self.backward_G()
self.opt_G.step()