import torch
import torch.nn as nn
import numpy as np
from torch.optim import Adam, SGD
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torch.nn.utils.weight_norm as weightNorm
from utils.util import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dim = 128      # spatial size of discriminator input: images are (6, dim, dim)
LAMBDA = 10    # gradient-penalty coefficient (WGAN-GP lambda)


class TReLU(nn.Module):
    """ReLU with a learnable threshold: f(x) = relu(x - alpha) + alpha.

    alpha is a single learned scalar, initialized to 0 (i.e. a plain ReLU).
    """

    def __init__(self):
        super(TReLU, self).__init__()
        self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.alpha.data.fill_(0)  # start out as an ordinary ReLU

    def forward(self, x):
        return F.relu(x - self.alpha) + self.alpha


class Discriminator(nn.Module):
    """Weight-normalized convolutional critic for WGAN-GP.

    Input is a 6-channel image (two 3-channel images concatenated along the
    channel axis — see cal_reward/update below); output is one unbounded
    score per sample, shape (N, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Five stride-2 convs halve the spatial size each time:
        # dim -> dim/2 -> dim/4 -> dim/8 -> dim/16 -> dim/32.
        self.conv0 = weightNorm(nn.Conv2d(6, 16, 5, 2, 2))
        self.conv1 = weightNorm(nn.Conv2d(16, 32, 5, 2, 2))
        self.conv2 = weightNorm(nn.Conv2d(32, 64, 5, 2, 2))
        self.conv3 = weightNorm(nn.Conv2d(64, 128, 5, 2, 2))
        self.conv4 = weightNorm(nn.Conv2d(128, 1, 5, 2, 2))
        self.relu0 = TReLU()
        self.relu1 = TReLU()
        self.relu2 = TReLU()
        self.relu3 = TReLU()

    def forward(self, x):
        x = self.relu0(self.conv0(x))
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.relu3(self.conv3(x))
        x = self.conv4(x)          # no activation on the final critic layer
        x = F.avg_pool2d(x, 4)     # collapse remaining 4x4 map (for dim=128)
        x = x.view(-1, 1)          # one scalar score per sample
        return x


netD = Discriminator()
target_netD = Discriminator()   # slow-moving copy used only to score rewards
netD = netD.to(device)
target_netD = target_netD.to(device)
hard_update(target_netD, netD)  # start the target as an exact copy

optimizerD = Adam(netD.parameters(), lr=3e-4, betas=(0.5, 0.999))


def cal_gradient_penalty(netD, real_data, fake_data, batch_size):
    """WGAN-GP gradient penalty on random interpolates of real/fake batches.

    Samples one uniform mixing weight per sample, evaluates the critic on the
    interpolated inputs, and penalizes the squared deviation of the input
    gradient norm from 1, scaled by LAMBDA.
    """
    # One alpha per sample, broadcast over all (6, dim, dim) elements.
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, int(real_data.nelement() / batch_size)).contiguous()
    alpha = alpha.view(batch_size, 6, dim, dim)
    alpha = alpha.to(device)
    fake_data = fake_data.view(batch_size, 6, dim, dim)
    # Detach both endpoints so the penalty only flows through the critic;
    # requires_grad_ replaces the deprecated Variable(..., requires_grad=True).
    interpolates = (alpha * real_data.detach()
                    + (1 - alpha) * fake_data.detach()).requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(disc_interpolates, interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty


def cal_reward(fake_data, real_data):
    """Score (real, fake) pairs with the slowly-updated target critic."""
    return target_netD(torch.cat([real_data, fake_data], 1))


def save_gan(path):
    """Save the critic's weights (moved to CPU first) to {path}/wgan.pkl."""
    netD.cpu()
    torch.save(netD.state_dict(), '{}/wgan.pkl'.format(path))
    netD.to(device)


def load_gan(path):
    """Load critic weights from {path}/wgan.pkl onto the current device."""
    # map_location keeps this working on CPU-only hosts regardless of where
    # the checkpoint was written.
    netD.load_state_dict(torch.load('{}/wgan.pkl'.format(path), map_location=device))


def update(fake_data, real_data):
    """One WGAN-GP critic step on a batch of (real, fake) image pairs.

    Both "real" and "fake" critic inputs are 6-channel: the real image
    concatenated with itself, vs. the real image concatenated with the fake.
    Returns (mean fake score, mean real score, gradient penalty).
    """
    fake_data = fake_data.detach()
    real_data = real_data.detach()
    fake = torch.cat([real_data, fake_data], 1)
    real = torch.cat([real_data, real_data], 1)
    D_real = netD(real)
    D_fake = netD(fake)
    gradient_penalty = cal_gradient_penalty(netD, real, fake, real.shape[0])
    optimizerD.zero_grad()
    # Critic minimizes fake score minus real score (Wasserstein estimate)
    # plus the gradient penalty.
    D_cost = D_fake.mean() - D_real.mean() + gradient_penalty
    D_cost.backward()
    optimizerD.step()
    soft_update(target_netD, netD, 0.001)  # Polyak-average into the target
    return D_fake.mean(), D_real.mean(), gradient_penalty