# MinhNH
# Initial commit
# 48c5871
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Adam, SGD
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torch.nn.utils.weight_norm as weightNorm
from utils.util import *
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dim = 128  # input image side length assumed by this module
LAMBDA = 10  # Gradient penalty lambda hyperparameter
class TReLU(nn.Module):
    """Thresholded ReLU with a learnable threshold.

    Computes ``relu(x - alpha) + alpha``, i.e. element-wise ``max(x, alpha)``,
    where ``alpha`` is a single trainable scalar initialised to 0 (so it starts
    out as a plain ReLU).
    """

    def __init__(self):
        super(TReLU, self).__init__()
        # One-element float32 parameter, zero-initialised; nn.Parameter
        # requires grad by default.
        self.alpha = nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)

    def forward(self, x):
        shifted = x - self.alpha
        return F.relu(shifted) + self.alpha
class Discriminator(nn.Module):
    """Critic built from five weight-normalised 5x5, stride-2, padding-2 convolutions.

    ``conv0`` expects 6 input channels. In this file, ``update``/``cal_reward``
    feed it two image batches concatenated along dim 1. Each conv halves an even
    spatial size, so a 128x128 input reaches 4x4 before the final 4x4 average
    pool and yields one score per sample. The output is reshaped to ``(-1, 1)``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Channel progression 6 -> 16 -> 32 -> 64 -> 128 -> 1. The attribute
        # names and registration order (conv0..conv4, then relu0..relu3) are
        # kept so that state_dict keys and parameter order are unchanged.
        widths = (6, 16, 32, 64, 128, 1)
        for i, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            setattr(self, 'conv%d' % i, weightNorm(nn.Conv2d(c_in, c_out, 5, 2, 2)))
        for i in range(4):
            setattr(self, 'relu%d' % i, TReLU())

    def forward(self, x):
        stages = (
            (self.conv0, self.relu0),
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
        )
        for conv, act in stages:
            x = act(conv(x))
        # The last conv has no activation; pool to one value per sample.
        pooled = F.avg_pool2d(self.conv4(x), 4)
        return pooled.view(-1, 1)
# Online critic (trained in update) and a second copy used by cal_reward.
netD = Discriminator().to(device)
target_netD = Discriminator().to(device)
# hard_update comes from utils.util; presumably it copies netD's weights into
# target_netD so both start identical - confirm against utils.util.
hard_update(target_netD, netD)
# Adam optimises only the online critic; target_netD is changed through
# soft_update in update().
optimizerD = Adam(netD.parameters(), lr=3e-4, betas=(0.5, 0.999))
def cal_gradient_penalty(netD, real_data, fake_data, batch_size, lambda_gp=None):
    """Return the WGAN-GP gradient penalty of ``netD`` at random interpolates.

    For each sample a weight ``alpha ~ U[0, 1)`` is drawn, and the critic is
    evaluated at ``alpha * real + (1 - alpha) * fake``. The penalty is
    ``lambda_gp * mean((||d netD(x) / d x||_2 - 1) ** 2)``, where the norm runs
    over all non-batch elements of each sample.

    Args:
        netD: critic callable mapping a batch to per-sample scores.
        real_data: real batch, batch dimension first. Its shape defines the
            layout of the interpolates. Only its values are used; no gradient
            flows back into it.
        fake_data: fake batch with the same number of elements as
            ``real_data``. It is reshaped to ``real_data``'s shape.
        batch_size: number of samples (the leading dimension of ``real_data``).
        lambda_gp: penalty weight. ``None`` (the default) uses the module-level
            ``LAMBDA``.

    Returns:
        A scalar tensor that is differentiable w.r.t. netD's parameters, because
        the input gradients are taken with ``create_graph=True``.
    """
    if lambda_gp is None:
        lambda_gp = LAMBDA
    # One weight per sample, drawn from the CPU generator as before. It is
    # shaped (batch, 1, 1, ...) so it broadcasts, rather than being
    # materialised at full input size.
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.view(batch_size, *([1] * (real_data.dim() - 1))).to(real_data.device)
    # Follow real_data's layout instead of a hard-coded (batch, 6, dim, dim).
    fake_data = fake_data.reshape(real_data.shape)
    # Cut the interpolates off from whatever produced real/fake, and make them a
    # fresh leaf so the critic can be differentiated w.r.t. them.
    interpolates = (alpha * real_data.detach() + (1 - alpha) * fake_data.detach()).requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
def cal_reward(fake_data, real_data):
    """Score (real_data, fake_data) pairs with ``target_netD``.

    The two batches are concatenated along dim 1, real first, which matches the
    "fake" pairing used in ``update``. ``target_netD`` is moved toward ``netD``
    by ``soft_update(..., 0.001)`` in ``update`` (soft_update lives in
    utils.util). The forward pass is not run under ``no_grad``; callers that
    need a plain value should detach it.
    """
    pair = torch.cat((real_data, fake_data), dim=1)
    return target_netD(pair)
def save_gan(path):
    """Write netD's weights to ``<path>/wgan.pkl``.

    netD is moved to the CPU for the save, so the checkpoint holds CPU tensors,
    and is then moved back to ``device``. The move back is in a ``finally`` so
    that a failed save (bad path, full disk, ...) does not leave netD stranded
    on the CPU while training continues.
    """
    netD.cpu()
    try:
        torch.save(netD.state_dict(), '{}/wgan.pkl'.format(path))
    finally:
        netD.to(device)
def load_gan(path):
    """Load netD's weights from ``<path>/wgan.pkl`` (the file save_gan writes).

    NOTE(review): only netD is restored. target_netD, which cal_reward uses, is
    not resynced here. Confirm whether callers run hard_update(target_netD, netD)
    after loading.
    NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    """
    checkpoint_file = '{}/wgan.pkl'.format(path)
    state = torch.load(checkpoint_file)
    netD.load_state_dict(state)
def update(fake_data, real_data):
    """Run one WGAN-GP optimisation step on netD, then soft-update target_netD.

    Both inputs are detached, so no gradient reaches whatever produced them.
    The critic scores two kinds of channel-concatenated pairs:
    (real_data, real_data) as "real" and (real_data, fake_data) as "fake".
    The loss minimised is ``mean(D(fake)) - mean(D(real)) + gradient_penalty``.

    Returns:
        (mean critic score on fake pairs, mean critic score on real pairs,
        gradient penalty), as tensors.
    """
    fake_data, real_data = fake_data.detach(), real_data.detach()
    pair_fake = torch.cat([real_data, fake_data], 1)
    pair_real = torch.cat([real_data, real_data], 1)
    score_real = netD(pair_real)
    score_fake = netD(pair_fake)
    penalty = cal_gradient_penalty(netD, pair_real, pair_fake, pair_real.shape[0])
    mean_fake = score_fake.mean()
    mean_real = score_real.mean()
    optimizerD.zero_grad()
    loss = mean_fake - mean_real + penalty
    loss.backward()
    optimizerD.step()
    # Move the target critic a small step (tau=0.001) toward netD.
    soft_update(target_netD, netD, 0.001)
    return mean_fake, mean_real, penalty