import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

class EMACodebook(nn.Module):
    def __init__(
        self,
        n_codes,
        embedding_dim,
        no_random_restart=True,
        restart_thres=1.0,
        ema_decay=0.99,
    ):
        super().__init__()
        # Codebook vectors, EMA cluster counts, and EMA sum of assigned vectors.
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres
        self.freeze_codebook = False
        self.ema_decay = ema_decay

    def reset_parameters(self):
        # For meta init
        pass

    def _tile(self, x):
        # Repeat x (with a little noise) until it has at least n_codes rows.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, t, c]
        self._need_init = False
        flat_inputs = z.view(-1, self.embedding_dim)
        y = self._tile(flat_inputs)
        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, t, c]
        if self._need_init and self.training and not self.freeze_codebook:
            self._init_embeddings(z)

        # z is of shape [batch_size, sequence_length, channels]
        flat_inputs = z.view(-1, self.embedding_dim)
        # Squared L2 distance from every input vector to every codebook vector.
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )  # [bt, n_codes]

        encoding_indices = torch.argmin(distances, dim=1)
        encoding_indices = encoding_indices.view(*z.shape[:2])  # [b, t]
        embeddings = F.embedding(encoding_indices, self.embeddings)  # [b, t, c]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update. The original file asserted "Not implemented" in
        # this branch; a hedged sketch of the standard update is filled in below.
        if self.training and not self.freeze_codebook:
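            # A minimal sketch, assuming the usual VQ-VAE exponential-moving-average
            # rule: N holds smoothed per-code assignment counts and z_avg holds the
            # smoothed sum of encoder outputs assigned to each code. This body is an
            # added assumption, not part of the original file.
            encode_onehot = F.one_hot(
                encoding_indices.flatten(), self.n_codes
            ).type_as(flat_inputs)  # [bt, n_codes]
            n_total = encode_onehot.sum(dim=0)  # vectors assigned to each code
            encode_sum = flat_inputs.t() @ encode_onehot  # [c, n_codes]
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(self.ema_decay).add_(n_total, alpha=1 - self.ema_decay)
            self.z_avg.data.mul_(self.ema_decay).add_(
                encode_sum.t(), alpha=1 - self.ema_decay
            )

            # Laplace-smoothed counts keep rarely used codes from dividing by zero.
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            self.embeddings.data.copy_(self.z_avg / weights.unsqueeze(1))

            # Optionally restart codes whose usage falls below restart_thres with
            # random encoder outputs.
            if not self.no_random_restart:
                y = self._tile(flat_inputs)
                _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
                if dist.is_initialized():
                    dist.broadcast(_k_rand, 0)
                usage = (self.N.view(self.n_codes, 1) >= self.restart_thres).float()
                self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))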

        # Straight-through estimator: values come from the codebook, gradients flow to z.
        embeddings_st = (embeddings - z).detach() + z
        return embeddings_st, encoding_indices, commitment_loss

    def dictionary_lookup(self, encodings):
        # Map integer code indices to their codebook vectors.
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

    def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor:
        # Weighted combination of codebook vectors; weights has shape [..., n_codes].
        return weights @ self.embeddings
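
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): quantize a random batch
    # of 2 sequences of length 16 with 64-dim vectors against 512 codes.
    codebook = EMACodebook(n_codes=512, embedding_dim=64)
    z = torch.randn(2, 16, 64)
    z_q, indices, commit_loss = codebook(z)
    print(z_q.shape, indices.shape, commit_loss.item())  # (2, 16, 64), (2, 16), scalar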