# Compared with `descript_quantize2`, we use RVQ & random quantizer dropout

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses the following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in a low-dimensional space
           for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity,
           which improves training stability
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        stale_tolerance: int = 100,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        self.register_buffer("stale_counter", torch.zeros(self.codebook_size))
        self.stale_tolerance = stale_tolerance

    def forward(self, z):
        """Quantize the input tensor using a fixed codebook and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        Tensor[B]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """
        # Factorized codes (ViT-VQGAN): project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        # Per-sample losses (mean over channel and time dims)
        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        if self.training:
            onehots = F.one_hot(indices, self.codebook_size).float()  # B, T, codebook_size
            stale_codes = (onehots.sum(0).sum(0) == 0).float()
            # Reset the counter for codes used in this batch, increment it for stale ones
            self.stale_counter = self.stale_counter * stale_codes + stale_codes

            # Randomly replace codes that haven't been used for a while
            replace_code = (self.stale_counter == self.stale_tolerance).float()  # codebook_size
            if replace_code.sum(-1) > 0:
                print("Replace {} codes".format(replace_code.sum(-1)))
                random_input_idx = torch.randperm(encodings.shape[0])
                random_input = encodings[random_input_idx].view(encodings.shape)
                if random_input.shape[0] < self.codebook_size:  # if we need more rows, tile the batch
                    random_input = torch.cat(
                        [random_input] * (self.codebook_size // random_input.shape[0] + 1), 0
                    )
                random_input = random_input[: self.codebook_size, :].contiguous()  # codebook_size, dim

                self.codebook.weight.data = (
                    self.codebook.weight.data * (1 - replace_code).unsqueeze(-1)
                    + random_input * replace_code.unsqueeze(-1)
                )
                self.stale_counter = self.stale_counter * (1 - replace_code)

        return z_q, indices
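
# Illustrative shape check for `VectorQuantize` (kept as a comment so importing
# this module stays side-effect free; the dimensions below are arbitrary):
#   vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
#   z = torch.randn(2, 512, 50)
#   z_q, commitment_loss, codebook_loss, indices, z_e = vq(z)
#   # z_q: (2, 512, 50), indices: (2, 50), z_e: (2, 8, 50)
#   # commitment_loss, codebook_loss: per-sample losses of shape (2,)
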
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end-to-end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        stale_tolerance: int = 100,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance
                )
                for i in range(n_codebooks)
            ]
        )

        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantize the input tensor using a fixed set of `n` codebooks and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: in training mode this argument is ignored; a fraction
                `self.quantizer_dropout` of the batch is assigned a random
                number of quantizers, and the rest use all of them.

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of input
        codes : Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        latents : Tensor[B x N*D x T]
            Projected latents (continuous representation of input before quantization)
        commitment_loss : Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss : Tensor[1]
            Codebook loss to update the codebook
        rvq_usage : Tensor[B]
            Per-sample index of the last quantizer used
            (i.e. number of quantizers used minus one)
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        else:
            n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # if self.training is False and i >= n_quantizers:
            #     break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        encodings = F.one_hot(codes, self.codebook_size).float()  # B N T codebook_size
        # for n in range(encodings.shape[1]):
        #     print("Layer {}, ratio of unused vectors: {:.1f}".format(
        #         n,
        #         (encodings[:, n, :, :].sum(0).sum(0) < 1.0).sum()
        #         / torch.numel(encodings[:, n, :, :].sum(0).sum(0) < 1.0) * 100.0,
        #     ))

        return (
            z_q,
            codes,
            latents,
            commitment_loss,
            codebook_loss,
            n_quantizers.clamp(max=self.n_codebooks).long() - 1,
        )
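
    # Illustrative dropout example: with n_codebooks=4 and a training batch of 2
    # where sample 0 draws dropout value 2 and sample 1 is not dropped,
    # n_quantizers = [2., 5.]. The per-stage mask `i < n_quantizers` is then
    # [True, True] for stages 0-1 and [False, True] for stages 2-3, so sample 0
    # accumulates (and pays losses for) only its first 2 codebooks.
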
{:.1f}".format(n, # (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. # )) return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 def from_codes(self, codes: torch.Tensor): """Given the quantized codes, reconstruct the continuous representation Parameters ---------- codes : Tensor[B x N x T] Quantized discrete representation of input Returns ------- Tensor[B x D x T] Quantized continuous representation of input """ z_q = 0.0 z_p = [] n_codebooks = codes.shape[1] for i in range(n_codebooks): z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) z_p.append(z_p_i) z_q_i = self.quantizers[i].out_proj(z_p_i) z_q = z_q + z_q_i return z_q, torch.cat(z_p, dim=1), codes def from_latents(self, latents: torch.Tensor): """Given the unquantized latents, reconstruct the continuous representation after quantization. Parameters ---------- latents : Tensor[B x N x T] Continuous representation of input after projection Returns ------- Tensor[B x D x T] Quantized representation of full-projected space Tensor[B x D x T] Quantized representation of latent space """ z_q = 0 z_p = [] codes = [] dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 0 ] for i in range(n_codebooks): j, k = dims[i], dims[i + 1] z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) z_p.append(z_p_i) codes.append(codes_i) z_q_i = self.quantizers[i].out_proj(z_p_i) z_q = z_q + z_q_i return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) if __name__ == "__main__": rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) x = torch.randn(16, 1024, 80) quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) print(quantized_prompt_embeds.shape) print(codes.shape) # w/o reconstruction loss = commitment_loss * 0.25 + codebook_loss * 1.0 # w/ reconstruction loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()