|
"""This module contains simple helper functions """
|
|
from __future__ import print_function
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
import os
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Variable
|
|
|
|
def random_word(len_word, alphabet):
|
|
|
|
char = np.random.randint(low=0, high=len(alphabet), size=len_word)
|
|
word = [alphabet[c] for c in char]
|
|
return ''.join(word)
|
|
|
|
def load_network(net, save_dir, epoch):
|
|
"""Load all the networks from the disk.
|
|
|
|
Parameters:
|
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
|
"""
|
|
load_filename = '%s_net_%s.pth' % (epoch, net.name)
|
|
load_path = os.path.join(save_dir, load_filename)
|
|
|
|
|
|
state_dict = torch.load(load_path)
|
|
if hasattr(state_dict, '_metadata'):
|
|
del state_dict._metadata
|
|
net.load_state_dict(state_dict)
|
|
return net
|
|
|
|
def writeCache(env, cache):
|
|
with env.begin(write=True) as txn:
|
|
for k, v in cache.items():
|
|
if type(k) == str:
|
|
k = k.encode()
|
|
if type(v) == str:
|
|
v = v.encode()
|
|
txn.put(k, v)
|
|
|
|
def loadData(v, data):
|
|
with torch.no_grad():
|
|
v.resize_(data.size()).copy_(data)
|
|
|
|
def multiple_replace(string, rep_dict):
|
|
for key in rep_dict.keys():
|
|
string = string.replace(key, rep_dict[key])
|
|
return string
|
|
|
|
def get_curr_data(data, batch_size, counter):
|
|
curr_data = {}
|
|
for key in data:
|
|
curr_data[key] = data[key][batch_size*counter:batch_size*(counter+1)]
|
|
return curr_data
|
|
|
|
|
|
def seed_rng(seed):
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed(seed)
|
|
np.random.seed(seed)
|
|
|
|
|
|
def make_one_hot(labels, len_labels, n_classes):
|
|
one_hot = torch.zeros((labels.shape[0], labels.shape[1], n_classes),dtype=torch.float32)
|
|
for i in range(len(labels)):
|
|
one_hot[i,np.array(range(len_labels[i])), labels[i,:len_labels[i]]-1]=1
|
|
return one_hot
|
|
|
|
|
|
def loss_hinge_dis(dis_fake, dis_real, len_text_fake, len_text, mask_loss):
|
|
mask_real = torch.ones(dis_real.shape).to(dis_real.device)
|
|
mask_fake = torch.ones(dis_fake.shape).to(dis_fake.device)
|
|
if mask_loss and len(dis_fake.shape)>2:
|
|
for i in range(len(len_text)):
|
|
mask_real[i, :, :, len_text[i]:] = 0
|
|
mask_fake[i, :, :, len_text_fake[i]:] = 0
|
|
loss_real = torch.sum(F.relu(1. - dis_real * mask_real))/torch.sum(mask_real)
|
|
loss_fake = torch.sum(F.relu(1. + dis_fake * mask_fake))/torch.sum(mask_fake)
|
|
return loss_real, loss_fake
|
|
|
|
|
|
def loss_hinge_gen(dis_fake, len_text_fake, mask_loss):
|
|
mask_fake = torch.ones(dis_fake.shape).to(dis_fake.device)
|
|
if mask_loss and len(dis_fake.shape)>2:
|
|
for i in range(len(len_text_fake)):
|
|
mask_fake[i, :, :, len_text_fake[i]:] = 0
|
|
loss = -torch.sum(dis_fake*mask_fake)/torch.sum(mask_fake)
|
|
return loss
|
|
|
|
def loss_std(z, lengths, mask_loss):
|
|
loss_std = torch.zeros(1).to(z.device)
|
|
z_mean = torch.ones((z.shape[0], z.shape[1])).to(z.device)
|
|
for i in range(len(lengths)):
|
|
if mask_loss:
|
|
if lengths[i]>1:
|
|
loss_std += torch.mean(torch.std(z[i, :, :, :lengths[i]], 2))
|
|
z_mean[i,:] = torch.mean(z[i, :, :, :lengths[i]], 2).squeeze(1)
|
|
else:
|
|
z_mean[i, :] = z[i, :, :, 0].squeeze(1)
|
|
else:
|
|
loss_std += torch.mean(torch.std(z[i, :, :, :], 2))
|
|
z_mean[i,:] = torch.mean(z[i, :, :, :], 2).squeeze(1)
|
|
loss_std = loss_std/z.shape[0]
|
|
return loss_std, z_mean
|
|
|
|
|
|
def toggle_grad(model, on_or_off):
|
|
for param in model.parameters():
|
|
param.requires_grad = on_or_off
|
|
|
|
|
|
|
|
|
|
|
|
def ortho(model, strength=1e-4, blacklist=[]):
|
|
with torch.no_grad():
|
|
for param in model.parameters():
|
|
|
|
if len(param.shape) < 2 or any([param is item for item in blacklist]):
|
|
continue
|
|
w = param.view(param.shape[0], -1)
|
|
grad = (2 * torch.mm(torch.mm(w, w.t())
|
|
* (1. - torch.eye(w.shape[0], device=w.device)), w))
|
|
param.grad.data += strength * grad.view(param.shape)
|
|
|
|
|
|
|
|
|
|
|
|
def default_ortho(model, strength=1e-4, blacklist=[]):
|
|
with torch.no_grad():
|
|
for param in model.parameters():
|
|
|
|
if len(param.shape) < 2 or param in blacklist:
|
|
continue
|
|
w = param.view(param.shape[0], -1)
|
|
grad = (2 * torch.mm(torch.mm(w, w.t())
|
|
- torch.eye(w.shape[0], device=w.device), w))
|
|
param.grad.data += strength * grad.view(param.shape)
|
|
|
|
|
|
|
|
def toggle_grad(model, on_or_off):
|
|
for param in model.parameters():
|
|
param.requires_grad = on_or_off
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Distribution(torch.Tensor):
|
|
|
|
def init_distribution(self, dist_type, **kwargs):
|
|
seed_rng(kwargs['seed'])
|
|
self.dist_type = dist_type
|
|
self.dist_kwargs = kwargs
|
|
if self.dist_type == 'normal':
|
|
self.mean, self.var = kwargs['mean'], kwargs['var']
|
|
elif self.dist_type == 'categorical':
|
|
self.num_categories = kwargs['num_categories']
|
|
elif self.dist_type == 'poisson':
|
|
self.lam = kwargs['var']
|
|
elif self.dist_type == 'gamma':
|
|
self.scale = kwargs['var']
|
|
|
|
|
|
def sample_(self):
|
|
if self.dist_type == 'normal':
|
|
self.normal_(self.mean, self.var)
|
|
elif self.dist_type == 'categorical':
|
|
self.random_(0, self.num_categories)
|
|
elif self.dist_type == 'poisson':
|
|
type = self.type()
|
|
device = self.device
|
|
data = np.random.poisson(self.lam, self.size())
|
|
self.data = torch.from_numpy(data).type(type).to(device)
|
|
elif self.dist_type == 'gamma':
|
|
type = self.type()
|
|
device = self.device
|
|
data = np.random.gamma(shape=1, scale=self.scale, size=self.size())
|
|
self.data = torch.from_numpy(data).type(type).to(device)
|
|
|
|
|
|
|
|
|
|
def to(self, *args, **kwargs):
|
|
new_obj = Distribution(self)
|
|
new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
|
|
new_obj.data = super().to(*args, **kwargs)
|
|
return new_obj
|
|
|
|
|
|
def to_device(net, gpu_ids):
|
|
if len(gpu_ids) > 0:
|
|
assert(torch.cuda.is_available())
|
|
net.to(gpu_ids[0])
|
|
|
|
if len(gpu_ids)>1:
|
|
net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
|
|
|
|
return net
|
|
|
|
|
|
|
|
def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda',
|
|
fp16=False, z_var=1.0, z_dist='normal', seed=0):
|
|
z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
|
|
z_.init_distribution(z_dist, mean=0, var=z_var, seed=seed)
|
|
z_ = z_.to(device, torch.float16 if fp16 else torch.float32)
|
|
|
|
if fp16:
|
|
z_ = z_.half()
|
|
|
|
y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False))
|
|
y_.init_distribution('categorical', num_categories=nclasses, seed=seed)
|
|
y_ = y_.to(device, torch.int64)
|
|
return z_, y_
|
|
|
|
|
|
def tensor2im(input_image, imtype=np.uint8):
|
|
""""Converts a Tensor array into a numpy image array.
|
|
|
|
Parameters:
|
|
input_image (tensor) -- the input image tensor array
|
|
imtype (type) -- the desired type of the converted numpy array
|
|
"""
|
|
if not isinstance(input_image, np.ndarray):
|
|
if isinstance(input_image, torch.Tensor):
|
|
image_tensor = input_image.data
|
|
else:
|
|
return input_image
|
|
image_numpy = image_tensor[0].cpu().float().numpy()
|
|
if image_numpy.shape[0] == 1:
|
|
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
|
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
|
|
else:
|
|
image_numpy = input_image
|
|
return image_numpy.astype(imtype)
|
|
|
|
|
|
def diagnose_network(net, name='network'):
|
|
"""Calculate and print the mean of average absolute(gradients)
|
|
|
|
Parameters:
|
|
net (torch network) -- Torch network
|
|
name (str) -- the name of the network
|
|
"""
|
|
mean = 0.0
|
|
count = 0
|
|
for param in net.parameters():
|
|
if param.grad is not None:
|
|
mean += torch.mean(torch.abs(param.grad.data))
|
|
count += 1
|
|
if count > 0:
|
|
mean = mean / count
|
|
print(name)
|
|
print(mean)
|
|
|
|
|
|
def save_image(image_numpy, image_path):
|
|
"""Save a numpy image to the disk
|
|
|
|
Parameters:
|
|
image_numpy (numpy array) -- input numpy array
|
|
image_path (str) -- the path of the image
|
|
"""
|
|
image_pil = Image.fromarray(image_numpy)
|
|
image_pil.save(image_path)
|
|
|
|
|
|
def print_numpy(x, val=True, shp=False):
|
|
"""Print the mean, min, max, median, std, and size of a numpy array
|
|
|
|
Parameters:
|
|
val (bool) -- if print the values of the numpy array
|
|
shp (bool) -- if print the shape of the numpy array
|
|
"""
|
|
x = x.astype(np.float64)
|
|
if shp:
|
|
print('shape,', x.shape)
|
|
if val:
|
|
x = x.flatten()
|
|
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
|
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
|
|
|
|
|
def mkdirs(paths):
|
|
"""create empty directories if they don't exist
|
|
|
|
Parameters:
|
|
paths (str list) -- a list of directory paths
|
|
"""
|
|
if isinstance(paths, list) and not isinstance(paths, str):
|
|
for path in paths:
|
|
mkdir(path)
|
|
else:
|
|
mkdir(paths)
|
|
|
|
|
|
def mkdir(path):
|
|
"""create a single empty directory if it didn't exist
|
|
|
|
Parameters:
|
|
path (str) -- a single directory path
|
|
"""
|
|
if not os.path.exists(path):
|
|
os.makedirs(path)
|
|
|