SongGeneration / codeclm /tokenizer /Flow1dVAE /libs /rvq /descript_quantize3_4layer_freezelayer1.py
hainazhu
Add application file
258fd02
# compared with `descript_quantize2`, we use rvq & random_dropout
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
import random
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
class VectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
self.codebook = nn.Embedding(codebook_size, codebook_dim)
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
self.stale_tolerance = stale_tolerance
def forward(self, z):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x D x T)
z_q, indices = self.decode_latents(z_e)
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
z_q = self.out_proj(z_q)
return z_q, commitment_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
if(self.training):
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
stale_codes = (onehots.sum(0).sum(0) == 0).float()
self.stale_counter = self.stale_counter * stale_codes + stale_codes
# random replace codes that haven't been used for a while
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
if replace_code.sum(-1) > 0:
print("Replace {} codes".format(replace_code.sum(-1)))
random_input_idx = torch.randperm(encodings.shape[0])
random_input = encodings[random_input_idx].view(encodings.shape)
if random_input.shape[0] < self.codebook_size:
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
self.stale_counter = self.stale_counter * (1 - replace_code)
return z_q, indices
class ResidualVectorQuantize(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.0,
stale_tolerance: int = 100,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
for i in range(n_codebooks)
]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"""
z_q = 0
residual = z
commitment_loss = 0
codebook_loss = 0
codebook_indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
random_num = random.random()
if random_num<0.6:
n_quantizers = torch.ones((z.shape[0],)) * 2
else:
n_quantizers = torch.ones((z.shape[0],)) * 4
n_quantizers = n_quantizers.to(z.device)
else:
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
# if self.training is False and i >= n_quantizers:
# break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = (
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
)
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
codes = torch.stack(codebook_indices, dim=1)
latents = torch.cat(latents, dim=1)
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
# for n in range(encodings.shape[1]):
# print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
# (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
# ))
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), codes
def from_latents(self, latents: torch.Tensor):
"""Given the unquantized latents, reconstruct the
continuous representation after quantization.
Parameters
----------
latents : Tensor[B x N x T]
Continuous representation of input after projection
Returns
-------
Tensor[B x D x T]
Quantized representation of full-projected space
Tensor[B x D x T]
Quantized representation of latent space
"""
z_q = 0
z_p = []
codes = []
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
0
]
for i in range(n_codebooks):
j, k = dims[i], dims[i + 1]
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
z_p.append(z_p_i)
codes.append(codes_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
if __name__ == "__main__":
rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
x = torch.randn(16, 1024, 80)
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
print(quantized_prompt_embeds.shape)
print(codes.shape)
# w/o reconstruction
loss = commitment_loss * 0.25 + codebook_loss * 1.0
# w/ reconstruction
loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()