import torch
import torch.nn as nn
import torch.nn.functional as F


class Quantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta):
        super(Quantizer, self).__init__()

        self.e_dim = e_dim  # dimensionality of each codebook vector
        self.n_e = n_e      # number of codebook entries
        self.beta = beta    # weight of the commitment loss

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Takes the encoder output z and maps each vector to the index of its
        closest codebook embedding e_j: z (continuous) -> z_q (discrete).

        :param z (B, seq_len, channel):
        :return loss, z_q, min_encoding_indices, perplexity:
        """
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)

        # Squared distances ||z - e_j||^2 = ||z||^2 + ||e_j||^2 - 2 z.e_j
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Codebook loss plus beta-weighted commitment loss
        loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean((z_q.detach() - z) ** 2)

        # Straight-through estimator: pass decoder gradients through to z
        z_q = z + (z_q - z).detach()

        # Perplexity of codebook usage within this batch
        min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        return loss, z_q, min_encoding_indices, perplexity

    def map2index(self, z):
        """
        Takes the encoder output z and returns, for each vector, the index of
        its closest codebook embedding e_j: z (continuous) -> indices (discrete).

        :param z (B, seq_len, channel):
        :return min_encoding_indices (B, seq_len):
        """
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)

        # Squared distances ||z - e_j||^2 = ||z||^2 + ||e_j||^2 - 2 z.e_j
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        return min_encoding_indices.reshape(z.shape[0], -1)

    def get_codebook_entry(self, indices):
        """
        Looks up the codebook vectors for the given indices.

        :param indices (B, seq_len):
        :return z_q (B, seq_len, e_dim):
        """
        index_flattened = indices.view(-1)
        z_q = self.embedding(index_flattened)
        z_q = z_q.view(indices.shape + (self.e_dim,)).contiguous()
        return z_q

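
# Illustrative usage sketch (not part of the original module): shows a typical
# round trip through Quantizer. The shapes and hyperparameters below are
# assumptions chosen for the example, and _demo_quantizer is a hypothetical helper.
def _demo_quantizer():
    quantizer = Quantizer(n_e=512, e_dim=64, beta=0.25)
    z = torch.randn(8, 16, 64)                        # (B, seq_len, channel) encoder output
    loss, z_q, indices, perplexity = quantizer(z)
    assert z_q.shape == z.shape                       # straight-through output matches z's shape
    tokens = quantizer.map2index(z)                   # (B, seq_len) codebook indices
    z_q_again = quantizer.get_codebook_entry(tokens)  # (B, seq_len, e_dim) vectors for those indices
    return loss, perplexity, tokens, z_q_again
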
class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super(EmbeddingEMA, self).__init__()
        self.decay = decay
        self.eps = eps
        weight = torch.randn(num_tokens, codebook_dim)
        # The codebook and its EMA statistics are never updated by the optimizer;
        # they are maintained by the explicit EMA update methods below.
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        self.update = True

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_emb_avg):
        # In-place add_ is required here; a plain add() would discard the update.
        self.embed_avg.data.mul_(self.decay).add_(new_emb_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        # Laplace smoothing of the cluster sizes avoids division by zero for unused codes.
        smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)

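
# Illustrative sketch (not part of the original module): the three update methods
# above implement, for each code i with hit count n_i and summed assigned vectors s_i,
#     cluster_size_i <- decay * cluster_size_i + (1 - decay) * n_i
#     embed_avg_i    <- decay * embed_avg_i    + (1 - decay) * s_i
#     weight_i       <- embed_avg_i / cluster_size_i   (with Laplace smoothing)
# _demo_embedding_ema_update is a hypothetical helper; the sizes are assumptions.
def _demo_embedding_ema_update():
    ema = EmbeddingEMA(num_tokens=3, codebook_dim=2, decay=0.9)
    hits = torch.tensor([2.0, 1.0, 0.0])   # how often each code was selected in a batch
    sums = torch.randn(3, 2)               # sum of encoder vectors assigned to each code
    ema.cluster_size_ema_update(hits)
    ema.embed_avg_ema_update(sums)
    ema.weight_update(num_tokens=3)        # smoothing keeps the unused third code finite
    return ema.weight
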
class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5):
        super(EMAVectorQuantizer, self).__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

    def forward(self, z):
        z_flattened = z.view(-1, self.codebook_dim)

        # Squared distances ||z - e_j||^2 = ||z||^2 + ||e_j||^2 - 2 z.e_j
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Perplexity of codebook usage within this batch
        min_encodings = F.one_hot(min_encoding_indices, self.num_tokens).type(z.dtype)
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        if self.training and self.embedding.update:
            # EMA codebook update: accumulate assignment counts and summed vectors
            encoding_sum = min_encodings.sum(0)
            embed_sum = min_encodings.transpose(0, 1) @ z_flattened

            self.embedding.cluster_size_ema_update(encoding_sum)
            self.embedding.embed_avg_ema_update(embed_sum)
            self.embedding.weight_update(self.num_tokens)

        # Only the commitment loss is backpropagated; the codebook is trained via EMA.
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # Straight-through estimator: pass decoder gradients through to z
        z_q = z + (z_q - z).detach()
        return loss, z_q, min_encoding_indices, perplexity
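
# Illustrative usage sketch (not part of the original module): shapes and
# hyperparameters below are assumptions chosen for the example, and
# _demo_ema_vector_quantizer is a hypothetical helper. In training mode the
# codebook is updated by EMA inside forward(), so only the beta-weighted
# commitment loss needs to be backpropagated.
def _demo_ema_vector_quantizer():
    quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25)
    quantizer.train()
    z = torch.randn(8, 16, 64, requires_grad=True)    # (B, seq_len, channel) encoder output
    loss, z_q, indices, perplexity = quantizer(z)
    loss.backward()                                   # gradients reach z via the commitment term
    return z_q, indices, perplexity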