|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def normalize(in_channels):
    """Build a GroupNorm layer for feature maps with ``in_channels`` channels.

    The layer always uses 32 groups, ``eps=1e-6`` and learnable affine
    parameters.
    """
    group_norm = torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
    return group_norm
|
|
|
|
def swish(x):
    """Apply the swish activation: the input scaled by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        # Codebook of n_e vectors of size e_dim, initialised uniformly in
        # [-1/n_e, 1/n_e].
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def _squared_distances(self, z_flattened):
        """Return squared L2 distances between rows of ``z_flattened`` and the codebook.

        ``z_flattened`` is ``(N, e_dim)``; the result is ``(N, n_e)``, computed
        via the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z.e so that no
        ``(N, n_e, e_dim)`` intermediate is materialised.
        """
        codebook = self.embedding.weight
        return (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, codebook.t())
        )

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook vectors.

        ``z`` must be 3-D; it is flattened into rows of length ``e_dim``
        without any prior permutation, so the last axis is treated as the
        embedding axis.
        NOTE(review): unlike ``get_distance``, the input is not permuted here
        although the output is -- presumably the caller passes a
        channels-last tensor; confirm against the call site.

        Returns:
            z_q: quantized tensor with the straight-through gradient applied,
                returned with its last two axes swapped relative to ``z``.
            loss: codebook loss plus ``beta`` times the commitment loss.
            (perplexity, min_encodings, min_encoding_indices): codebook-usage
                perplexity (scalar), one-hot assignments ``(N, n_e)`` and the
                selected indices ``(N, 1)``, where ``N = z.numel() // e_dim``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        z_flattened = z.reshape(-1, self.e_dim)

        d = self._squared_distances(z_flattened)

        # Nearest code per row, also expressed as a one-hot matrix.
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = z.new_zeros(min_encoding_indices.shape[0], self.n_e)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # Select the quantized latent vectors.
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)

        # beta * ||z - sg[z_q]||^2 (commitment) + ||z_q - sg[z]||^2 (codebook).
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
            torch.mean((z_q - z.detach()) ** 2)

        # Straight-through estimator: forward value is z_q, gradient flows to z.
        z_q = z + (z_q - z).detach()

        # Perplexity of the empirical code distribution; 1e-10 guards log(0).
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        z_q = z_q.permute(0, 2, 1).contiguous()
        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_distance(self, z):
        """Return squared distances from every position of ``z`` to every code.

        ``z`` is 3-D with the embedding axis in the middle, ``(B, e_dim, T)``
        (it is permuted to ``(B, T, e_dim)`` before flattening). The result
        has shape ``(B, n_e, T)``.
        """
        z = z.permute(0, 2, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        d = self._squared_distances(z_flattened)
        # d is (B*T, n_e): restore the batch axis using n_e as the last axis
        # (using e_dim here would mix codes and positions whenever
        # n_e != e_dim), then move the code axis to the middle.
        d = torch.reshape(d, (z.shape[0], -1, self.n_e)).permute(0, 2, 1).contiguous()
        return d

    def get_codebook_entry(self, indices, shape):
        """Look up the codebook vectors for ``indices``.

        ``indices`` is a 1-D integer tensor of length ``N``. The result is
        ``(N, e_dim)``, or viewed as ``shape`` when ``shape`` is not None.
        """
        # A direct row lookup gives the same vectors as multiplying a one-hot
        # matrix by the codebook, without allocating the (N, n_e) one-hot.
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)

        return z_q
|
|
|