# M3Site/esm/layers/codebook.py
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class EMACodebook(nn.Module):
    """Vector-quantization codebook updated by an exponential moving average (EMA).

    Maps continuous embeddings to their nearest code vectors and returns a
    straight-through quantized tensor, the code indices, and the commitment loss.
    """

    def __init__(
        self,
        n_codes,
        embedding_dim,
        no_random_restart=True,
        restart_thres=1.0,
        ema_decay=0.99,
    ):
        super().__init__()
        # Code vectors, per-code EMA usage counts, and EMA sum of assigned inputs.
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres
        self.freeze_codebook = False
        self.ema_decay = ema_decay
    def reset_parameters(self):
        # For meta init
        pass

    def _tile(self, x):
        # If there are fewer input vectors than codes, repeat them (with a
        # little noise) so that n_codes candidates can be sampled from them.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x
    def _init_embeddings(self, z):
        # z: [b, t, c]
        self._need_init = False
        flat_inputs = z.view(-1, self.embedding_dim)
        y = self._tile(flat_inputs)
        # Initialize the codebook with n_codes randomly chosen input vectors,
        # broadcast from rank 0 so all workers share the same codes.
        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))
    def forward(self, z):
        # z: [b, t, c] = [batch_size, sequence_length, channels]
        if self._need_init and self.training and not self.freeze_codebook:
            self._init_embeddings(z)

        flat_inputs = z.view(-1, self.embedding_dim)
        # Squared Euclidean distance between each input vector and each code,
        # expanded as ||x||^2 - 2 x.e + ||e||^2.
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )  # [bt, n_codes]

        encoding_indices = torch.argmin(distances, dim=1)
        encoding_indices = encoding_indices.view(*z.shape[:2])  # [b, t]
        embeddings = F.embedding(encoding_indices, self.embeddings)  # [b, t, c]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update: the training-time update is not implemented in
        # this file, so training with an unfrozen codebook fails loudly here.
        if self.training and not self.freeze_codebook:
            assert False, "Not implemented"

        # Straight-through estimator: the forward pass uses the quantized
        # values while gradients flow back to z unchanged.
        embeddings_st = (embeddings - z).detach() + z
        return embeddings_st, encoding_indices, commitment_loss
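
    def _ema_update_sketch(self, flat_inputs, flat_indices):
        # Hypothetical helper, not part of the original file, which leaves the
        # training-time update unimplemented. This is a minimal sketch of the
        # standard EMA codebook update used by VQ-VAE-style models, included
        # only to document what the N, z_avg and ema_decay buffers are for.
        # Assumed shapes: flat_inputs [bt, embedding_dim], flat_indices [bt].
        one_hot = F.one_hot(flat_indices, self.n_codes).type_as(flat_inputs)  # [bt, n_codes]
        n_per_code = one_hot.sum(dim=0)  # [n_codes]
        embed_sum = one_hot.t() @ flat_inputs  # [n_codes, embedding_dim]
        if dist.is_initialized():
            dist.all_reduce(n_per_code)
            dist.all_reduce(embed_sum)
        decay = self.ema_decay
        # Exponential moving averages of code usage counts and assigned inputs.
        self.N.data.mul_(decay).add_(n_per_code, alpha=1 - decay)
        self.z_avg.data.mul_(decay).add_(embed_sum, alpha=1 - decay)
        # Laplace smoothing so rarely used codes do not collapse to zero.
        n = self.N.sum()
        weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
        self.embeddings.data.copy_(self.z_avg / weights.unsqueeze(1))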
    def dictionary_lookup(self, encodings):
        # Map integer code indices back to their code vectors.
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

    def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor:
        # Weighted (soft) combination of code vectors, e.g. from a softmax
        # over codes; weights: [..., n_codes].
        return weights @ self.embeddings
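

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). All shapes and
    # values below are made up for illustration: 4 sequences of length 8 with
    # 16-dimensional embeddings quantized against a 32-entry codebook.
    codebook = EMACodebook(n_codes=32, embedding_dim=16)
    codebook.eval()  # inference mode; the unimplemented training update is skipped
    z = torch.randn(4, 8, 16)  # [b, t, c]
    quantized, indices, commit_loss = codebook(z)
    print(quantized.shape, indices.shape, commit_loss.item())  # (4, 8, 16), (4, 8), scalar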