import torch
import torch.nn.functional as F
from torch import nn
def adopt_weight(weight, global_step, threshold=0, value=0.):
    # Return `value` instead of `weight` until `global_step` reaches `threshold`,
    # e.g. to keep an adversarial loss term switched off during warm-up.
    if global_step < threshold:
        weight = value
    return weight
def hinge_d_loss(logits_real, logits_fake):
    # Hinge GAN loss for the discriminator: push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
def vanilla_d_loss(logits_real, logits_fake):
    # Non-saturating (softplus) GAN loss for the discriminator.
    d_loss = 0.5 * (
        torch.mean(F.softplus(-logits_real)) +
        torch.mean(F.softplus(logits_fake)))
    return d_loss
def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use
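# Worked example (illustrative, not from the original file): with n_embed = 4 and
# indices [0, 1, 2, 3, 0, 1, 2, 3], avg_probs is uniform at 0.25, so
# perplexity = exp(-4 * 0.25 * log(0.25)) = 4.0 and cluster_use = 4,
# i.e. every codebook entry is used equally.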
def l1(x, y):
    # Element-wise L1 distance.
    return torch.abs(x - y)
def l2(x, y):
    # Element-wise squared L2 distance.
    return torch.pow((x - y), 2)
def square_dist_loss(x, y):
    # Per-sample squared Euclidean distance, summed over dim 1 with keepdim.
    return torch.sum((x - y) ** 2, dim=1, keepdim=True)
def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for Conv weights,
    # N(1, 0.02) for BatchNorm weights and zero BatchNorm bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
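# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It wires the helpers above into a toy discriminator step; `toy_disc`, the
# warm-up threshold, and the random tensors are assumptions for demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy convolutional discriminator producing one logit map per image.
    toy_disc = nn.Sequential(
        nn.Conv2d(3, 8, 4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(8, 1, 4, stride=2, padding=1),
    )
    toy_disc.apply(weights_init)  # DCGAN init is applied to the Conv layers only
    real = torch.randn(2, 3, 16, 16)
    fake = torch.randn(2, 3, 16, 16)
    logits_real, logits_fake = toy_disc(real), toy_disc(fake)
    # Gate the adversarial term with adopt_weight until a warm-up step is reached.
    disc_weight = adopt_weight(1.0, global_step=100, threshold=50)
    print("hinge d_loss:", (disc_weight * hinge_d_loss(logits_real, logits_fake)).item())
    print("vanilla d_loss:", (disc_weight * vanilla_d_loss(logits_real, logits_fake)).item())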