import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class EMACodebook(nn.Module):
    def __init__(
        self,
        n_codes,
        embedding_dim,
        no_random_restart=True,
        restart_thres=1.0,
        ema_decay=0.99,
    ):
        super().__init__()
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres
        self.freeze_codebook = False
        self.ema_decay = ema_decay

    def reset_parameters(self):
        # For meta init
        pass

    def _tile(self, x):
        # Repeat (and slightly perturb) x until it has at least n_codes rows.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, t, c]
        self._need_init = False
        flat_inputs = z.view(-1, self.embedding_dim)
        y = self._tile(flat_inputs)

        # Initialize the codebook from n_codes randomly chosen encoder outputs,
        # broadcast from rank 0 so every process starts with the same codes.
        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, t, c]
        if self._need_init and self.training and not self.freeze_codebook:
            self._init_embeddings(z)

        # z is of shape [batch_size, sequence_length, channels]
        flat_inputs = z.view(-1, self.embedding_dim)
        # Squared Euclidean distance between each input vector and each code.
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )  # [bt, n_codes]

        encoding_indices = torch.argmin(distances, dim=1)
        encoding_indices = encoding_indices.view(*z.shape[:2])  # [b, t]

        embeddings = F.embedding(encoding_indices, self.embeddings)  # [b, t, c]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
        if self.training and not self.freeze_codebook:
            raise NotImplementedError("EMA codebook update is not implemented")

        # Straight-through estimator: values come from the codebook,
        # gradients flow back to z.
        embeddings_st = (embeddings - z).detach() + z

        return embeddings_st, encoding_indices, commitment_loss

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

    def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor:
        return weights @ self.embeddings
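

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the training branch in
# forward() raises NotImplementedError, so below is a minimal sketch of the
# standard EMA codebook update (VQ-VAE-2 / VideoGPT style), assuming that is
# the update the N, z_avg buffers and ema_decay are intended for.
# `_ema_update_sketch` is a hypothetical helper, not the author's code.
@torch.no_grad()
def _ema_update_sketch(codebook, flat_inputs, encoding_indices):
    # One-hot assignment of each input vector to its nearest code: [bt, n_codes]
    onehot = F.one_hot(encoding_indices.reshape(-1), codebook.n_codes).type_as(flat_inputs)
    n_total = onehot.sum(dim=0)            # per-code assignment counts, [n_codes]
    encode_sum = onehot.t() @ flat_inputs  # per-code input sums, [n_codes, c]

    decay = codebook.ema_decay
    # Exponential moving averages of counts and summed vectors.
    codebook.N.mul_(decay).add_(n_total, alpha=1 - decay)
    codebook.z_avg.mul_(decay).add_(encode_sum, alpha=1 - decay)

    # Laplace smoothing keeps rarely used codes from dividing by zero.
    n = codebook.N.sum()
    weights = (codebook.N + 1e-7) / (n + codebook.n_codes * 1e-7) * n
    codebook.embeddings.data.copy_(codebook.z_avg / weights.unsqueeze(1))


# Quick usage example (eval mode, since the training branch is unimplemented).
if __name__ == "__main__":
    codebook = EMACodebook(n_codes=512, embedding_dim=64).eval()
    z = torch.randn(2, 16, 64)  # [b, t, c]
    z_q, indices, commit_loss = codebook(z)
    print(z_q.shape, indices.shape, commit_loss.item())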