|
import os |
|
import time |
|
|
|
import numpy as np |
|
import torch |
|
import torch.optim as optim |
|
from cvxopt import matrix |
|
from cvxopt import solvers |
|
from cvxopt import sparse |
|
from cvxopt import spmatrix |
|
from torch.autograd import grad as torch_grad |
|
from tqdm import tqdm |
|
|
|
|
|
class WassersteinGanQuadraticCost: |
|
|
|
def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations, |
|
data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0): |
|
self.G = generator |
|
self.G_opt = gen_optimizer |
|
self.D = discriminator |
|
self.D_opt = dis_optimizer |
|
self.losses = { |
|
'D' : [], |
|
'WD': [], |
|
'G' : [] |
|
} |
|
self.num_steps = 0 |
|
self.gen_steps = 0 |
|
self.epochs = epochs |
|
self.n_max_iterations = n_max_iterations |
|
|
|
self.data_dim = data_dimensions[0] * data_dimensions[1] * data_dimensions[2] |
|
self.batch_size = batch_size |
|
self.device = device |
|
self.criterion = criterion |
|
self.mone = torch.FloatTensor([-1]).to(device) |
|
self.tensorboard_counter = 0 |
|
|
|
if K <= 0: |
|
self.K = 1 / self.data_dim |
|
else: |
|
self.K = K |
|
self.Kr = np.sqrt(self.K) |
|
self.LAMBDA = 2 * self.Kr * gamma * 2 |
|
|
|
self.G = self.G.to(self.device) |
|
self.D = self.D.to(self.device) |
|
|
|
self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal) |
|
self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal) |
|
|
|
self.c, self.A, self.pStart = self._prepare_linear_programming_solver_(self.batch_size) |
|
|
|
def _build_lr_scheduler_(self, optimizer, milestones, lr_anneal, last_epoch=-1): |
|
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=lr_anneal, last_epoch=-1) |
|
return scheduler |
|
|
|
def _quadratic_wasserstein_distance_(self, real, generated): |
|
num_r = real.size(0) |
|
num_f = generated.size(0) |
|
real_flat = real.view(num_r, -1) |
|
fake_flat = generated.view(num_f, -1) |
|
|
|
real3D = real_flat.unsqueeze(1).expand(num_r, num_f, self.data_dim) |
|
fake3D = fake_flat.unsqueeze(0).expand(num_r, num_f, self.data_dim) |
|
|
|
dif = real3D - fake3D |
|
dist = 0.5 * dif.pow(2).sum(2).squeeze() |
|
|
|
return self.K * dist |
|
|
|
def _prepare_linear_programming_solver_(self, batch_size): |
|
A = spmatrix(1.0, range(batch_size), [0] * batch_size, (batch_size, batch_size)) |
|
for i in range(1, batch_size): |
|
Ai = spmatrix(1.0, range(batch_size), [i] * batch_size, (batch_size, batch_size)) |
|
A = sparse([A, Ai]) |
|
|
|
D = spmatrix(-1.0, range(batch_size), range(batch_size), (batch_size, batch_size)) |
|
DM = D |
|
for i in range(1, batch_size): |
|
DM = sparse([DM, D]) |
|
|
|
A = sparse([[A], [DM]]) |
|
|
|
cr = matrix([-1.0 / batch_size] * batch_size) |
|
cf = matrix([1.0 / batch_size] * batch_size) |
|
c = matrix([cr, cf]) |
|
|
|
pStart = {} |
|
pStart['x'] = matrix([matrix([1.0] * batch_size), matrix([-1.0] * batch_size)]) |
|
pStart['s'] = matrix([1.0] * (2 * batch_size)) |
|
|
|
return c, A, pStart |
|
|
|
def _linear_programming_(self, distance, batch_size): |
|
b = matrix(distance.cpu().double().detach().numpy().flatten()) |
|
sol = solvers.lp(self.c, self.A, b, primalstart=self.pStart, solver='glpk', |
|
options={'glpk': {'msg_lev': 'GLP_MSG_OFF'}}) |
|
offset = 0.5 * (sum(sol['x'])) / batch_size |
|
sol['x'] = sol['x'] - offset |
|
self.pStart['x'] = sol['x'] |
|
self.pStart['s'] = sol['s'] |
|
|
|
return sol |
|
|
|
def _approx_OT_(self, sol): |
|
|
|
ResMat = np.array(sol['z']).reshape((self.batch_size, self.batch_size)) |
|
mapping = torch.from_numpy(np.argmax(ResMat, axis=0)).long().to(self.device) |
|
|
|
return mapping |
|
|
|
def _optimal_transport_regularization_(self, output_fake, fake, real_fake_diff): |
|
output_fake_grad = torch.ones(output_fake.size()).to(self.device) |
|
gradients = torch_grad(outputs=output_fake, inputs=fake, |
|
grad_outputs=output_fake_grad, |
|
create_graph=True, retain_graph=True, only_inputs=True)[0] |
|
n = gradients.size(0) |
|
RegLoss = 0.5 * ((gradients.view(n, -1).norm(dim=1) / (2 * self.Kr) - self.Kr / 2 * real_fake_diff.view(n, |
|
-1).norm( |
|
dim=1)).pow(2)).mean() |
|
fake.requires_grad = False |
|
|
|
return RegLoss |
|
|
|
def _critic_deep_regression_(self, images, opt_iterations=1): |
|
images = images.to(self.device) |
|
|
|
for p in self.D.parameters(): |
|
p.requires_grad = True |
|
|
|
self.G.train() |
|
self.D.train() |
|
|
|
|
|
generated_data = self.sample_generator(self.batch_size) |
|
|
|
|
|
distance = self._quadratic_wasserstein_distance_(images, generated_data) |
|
|
|
sol = self._linear_programming_(distance, self.batch_size) |
|
|
|
mapping = self._approx_OT_(sol) |
|
real_ordered = images[mapping] |
|
real_fake_diff = real_ordered - generated_data |
|
|
|
|
|
target = torch.from_numpy(np.array(sol['x'])).float() |
|
target = target.squeeze().to(self.device) |
|
|
|
for i in range(opt_iterations): |
|
self.D.zero_grad() |
|
self.D_opt.zero_grad() |
|
generated_data.requires_grad_() |
|
if generated_data.grad is not None: |
|
generated_data.grad.data.zero_() |
|
output_real = self.D(images) |
|
output_fake = self.D(generated_data) |
|
output_real, output_fake = output_real.squeeze(), output_fake.squeeze() |
|
output_R_mean = output_real.mean(0).view(1) |
|
output_F_mean = output_fake.mean(0).view(1) |
|
|
|
L2LossD_real = self.criterion(output_R_mean[0], target[:self.batch_size].mean()) |
|
L2LossD_fake = self.criterion(output_fake, target[self.batch_size:]) |
|
L2LossD = 0.5 * L2LossD_real + 0.5 * L2LossD_fake |
|
|
|
reg_loss_D = self._optimal_transport_regularization_(output_fake, generated_data, real_fake_diff) |
|
|
|
total_loss = L2LossD + self.LAMBDA * reg_loss_D |
|
|
|
self.losses['D'].append(float(total_loss.data)) |
|
|
|
total_loss.backward() |
|
self.D_opt.step() |
|
|
|
|
|
wasserstein_distance = output_R_mean - output_F_mean |
|
self.losses['WD'].append(float(wasserstein_distance.data)) |
|
|
|
def _generator_train_iteration(self, batch_size): |
|
for p in self.D.parameters(): |
|
p.requires_grad = False |
|
|
|
self.G.zero_grad() |
|
self.G_opt.zero_grad() |
|
|
|
if isinstance(self.G, torch.nn.parallel.DataParallel): |
|
z = self.G.module.sample_latent(batch_size, self.G.module.z_dim) |
|
else: |
|
z = self.G.sample_latent(batch_size, self.G.z_dim) |
|
z.requires_grad = True |
|
|
|
fake = self.G(z) |
|
output_fake = self.D(fake) |
|
output_F_mean_after = output_fake.mean(0).view(1) |
|
|
|
self.losses['G'].append(float(output_F_mean_after.data)) |
|
|
|
output_F_mean_after.backward(self.mone) |
|
self.G_opt.step() |
|
|
|
self.schedulerD.step() |
|
self.schedulerG.step() |
|
|
|
def _train_epoch(self, data_loader, writer, experiment): |
|
for i, data in enumerate(tqdm(data_loader)): |
|
images = data[0] |
|
speaker_ids = data[1] |
|
self.num_steps += 1 |
|
|
|
if self.gen_steps >= self.n_max_iterations: |
|
return |
|
self._critic_deep_regression_(images) |
|
self._generator_train_iteration(images.size(0)) |
|
|
|
D_loss_avg = np.average(self.losses['D']) |
|
G_loss_avg = np.average(self.losses['G']) |
|
wd_avg = np.average(self.losses['WD']) |
|
|
|
def train(self, data_loader, writer, experiment=None): |
|
self.G.train() |
|
self.D.train() |
|
|
|
for epoch in range(self.epochs): |
|
if self.gen_steps >= self.n_max_iterations: |
|
return |
|
time_start_epoch = time.time() |
|
self._train_epoch(data_loader, writer, experiment) |
|
|
|
D_loss_avg = np.average(self.losses['D']) |
|
|
|
time_end_epoch = time.time() |
|
|
|
return self |
|
|
|
def sample_generator(self, num_samples, nograd=False, return_intermediate=False): |
|
self.G.eval() |
|
if isinstance(self.G, torch.nn.parallel.DataParallel): |
|
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim, 1.0) |
|
else: |
|
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim, 1.0) |
|
latent_samples = latent_samples.to(self.device) |
|
if nograd: |
|
with torch.no_grad(): |
|
generated_data = self.G(latent_samples, return_intermediate=return_intermediate) |
|
else: |
|
generated_data = self.G(latent_samples) |
|
self.G.train() |
|
if return_intermediate: |
|
return generated_data[0].detach(), generated_data[1], latent_samples |
|
return generated_data.detach() |
|
|
|
def sample(self, num_samples): |
|
generated_data = self.sample_generator(num_samples) |
|
|
|
return generated_data.data.cpu().numpy()[:, 0, :, :] |
|
|
|
def save_model_checkpoint(self, model_path, model_parameters, timestampStr): |
|
|
|
|
|
name = '%s_%s' % (timestampStr, 'wgan') |
|
model_filename = os.path.join(model_path, name) |
|
torch.save({ |
|
'generator_state_dict' : self.G.state_dict(), |
|
'critic_state_dict' : self.D.state_dict(), |
|
'gen_optimizer_state_dict' : self.G_opt.state_dict(), |
|
'critic_optimizer_state_dict': self.D_opt.state_dict(), |
|
'model_parameters' : model_parameters, |
|
'iterations' : self.num_steps |
|
}, model_filename) |
|
|