|
from typing import Optional

import torch
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from einops import rearrange |
|
from vector_quantize_pytorch import ResidualVQ, FSQ |
|
from .nn.quantize import ResidualVectorQuantize as DACResidualVQ |
|
|
|
|
|
class Bottleneck(nn.Module):
    """Base class for autoencoder bottlenecks: `encode` maps encoder output to
    the bottleneck representation (optionally returning an info dict with
    auxiliary losses), and `decode` maps it back to decoder input."""
|
def __init__(self, is_discrete: bool = False): |
|
super().__init__() |
|
|
|
self.is_discrete = is_discrete |
|
|
|
def encode(self, x, return_info=False, **kwargs): |
|
raise NotImplementedError |
|
|
|
def decode(self, x): |
|
raise NotImplementedError |
|
|
|
|
|
class DiscreteBottleneck(Bottleneck):
    """Bottleneck with a discrete code space. `decode_tokens` reconstructs
    continuous latents from the token indices that `encode` stores in its
    info dict under the key `tokens_id`."""
|
def __init__(self, num_quantizers, codebook_size, tokens_id): |
|
super().__init__(is_discrete=True) |
|
|
|
self.num_quantizers = num_quantizers |
|
self.codebook_size = codebook_size |
|
self.tokens_id = tokens_id |
|
|
|
def decode_tokens(self, codes, **kwargs): |
|
raise NotImplementedError |
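
# Usage sketch (hypothetical shapes): bottlenecks act on latents shaped
# (batch, channels, sequence). With return_info=True, encode also returns a
# dict of auxiliary losses/metadata, e.g.:
#
#     bottleneck = TanhBottleneck()
#     z, info = bottleneck.encode(torch.randn(2, 64, 100), return_info=True)
#     x = bottleneck.decode(z)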
|
|
|
|
|
class TanhBottleneck(Bottleneck): |
|
def __init__(self): |
|
super().__init__(is_discrete=False) |
|
self.tanh = nn.Tanh() |
|
|
|
def encode(self, x, return_info=False): |
|
info = {} |
|
|
|
        x = self.tanh(x)
|
|
|
if return_info: |
|
return x, info |
|
else: |
|
return x |
|
|
|
def decode(self, x): |
|
return x |
|
|
|
|
|
@torch.jit.script |
|
def vae_sample_kl(mean, scale): |
|
    # Softplus keeps the predicted scale positive; the epsilon guards the log below.
    stdev = nn.functional.softplus(scale) + 1e-4
|
var = stdev * stdev |
|
logvar = torch.log(var) |
|
latents = torch.randn_like(mean) * stdev + mean |
|
|
|
    # KL divergence from N(mean, var) to N(0, 1), summed over the channel dim
    # and averaged over the rest (the conventional factor of 0.5 is omitted).
    kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
|
|
return latents, kl |
|
|
|
|
|
@torch.jit.script |
|
def vae_sample(mean, scale): |
|
stdev = nn.functional.softplus(scale) + 1e-4 |
|
    # Reparameterization trick: mean + stdev * eps with eps ~ N(0, I).
    latents = torch.randn_like(mean) * stdev + mean
|
return latents |
|
|
|
|
|
class VAEBottleneck(Bottleneck): |
|
def __init__(self): |
|
super().__init__(is_discrete=False) |
|
|
|
def encode(self, x, return_info=False, **kwargs): |
|
mean, scale = x.chunk(2, dim=1) |
|
|
|
if return_info: |
|
info = {} |
|
x, kl = vae_sample_kl(mean, scale) |
|
info["kl"] = kl |
|
return x, info |
|
else: |
|
x = vae_sample(mean, scale) |
|
return x |
|
|
|
def decode(self, x): |
|
return x |
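
# Sketch (hypothetical shapes): the VAE bottleneck expects the encoder to emit
# twice the latent width, read as concatenated (mean, scale) along dim 1:
#
#     bottleneck = VAEBottleneck()
#     z, info = bottleneck.encode(torch.randn(2, 128, 100), return_info=True)
#     # z has 64 channels; info["kl"] is the KL regularizer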
|
|
|
|
|
def compute_mean_kernel(x, y): |
|
    # Mean RBF kernel over all row pairs: squared distances are averaged over
    # the feature dimension, then scaled down by the feature count again.
    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
|
return torch.exp(-kernel_input).mean() |
|
|
|
|
|
def compute_mmd(latents): |
|
    # Flatten (batch, channels, seq) -> (batch * seq, channels) so every
    # timestep's latent vector counts as one sample.
    latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
|
noise = torch.randn_like(latents_reshaped) |
|
|
|
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) |
|
noise_kernel = compute_mean_kernel(noise, noise) |
|
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) |
|
|
|
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel |
|
return mmd.mean() |
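
# Sanity-check sketch: compute_mmd measures the discrepancy between the latent
# distribution and a standard normal, so N(0, I) latents should score near zero:
#
#     mmd = compute_mmd(torch.randn(4, 32, 100))  # expected to be ~0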
|
|
|
|
|
class WassersteinBottleneck(Bottleneck): |
|
def __init__(self, noise_augment_dim: int = 0): |
|
super().__init__(is_discrete=False) |
|
|
|
self.noise_augment_dim = noise_augment_dim |
|
|
|
def encode(self, x, return_info=False): |
|
info = {} |
|
|
|
if self.training and return_info: |
|
mmd = compute_mmd(x) |
|
info["mmd"] = mmd |
|
|
|
if return_info: |
|
return x, info |
|
|
|
return x |
|
|
|
def decode(self, x): |
|
|
|
if self.noise_augment_dim > 0: |
|
noise = torch.randn(x.shape[0], self.noise_augment_dim, |
|
x.shape[-1]).type_as(x) |
|
x = torch.cat([x, noise], dim=1) |
|
|
|
return x |
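
# Sketch: with noise augmentation, the decoder has to accept
# latent_dim + noise_augment_dim input channels:
#
#     bottleneck = WassersteinBottleneck(noise_augment_dim=2)
#     z = bottleneck.encode(torch.randn(2, 64, 100))
#     x = bottleneck.decode(z)  # x has 64 + 2 = 66 channels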
|
|
|
|
|
class L2Bottleneck(Bottleneck): |
|
def __init__(self): |
|
super().__init__(is_discrete=False) |
|
|
|
def encode(self, x, return_info=False): |
|
info = {} |
|
|
|
x = F.normalize(x, dim=1) |
|
|
|
if return_info: |
|
return x, info |
|
else: |
|
return x |
|
|
|
def decode(self, x): |
|
return F.normalize(x, dim=1) |
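
# Note: normalizing along dim 1 puts each timestep's latent on the unit
# hypersphere; decode re-normalizes so decoder inputs stay unit-norm.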
|
|
|
|
|
class RVQBottleneck(DiscreteBottleneck): |
|
def __init__(self, **quantizer_kwargs): |
|
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
|
self.quantizer = ResidualVQ(**quantizer_kwargs) |
|
self.num_quantizers = quantizer_kwargs["num_quantizers"] |
|
|
|
def encode(self, x, return_info=False, **kwargs): |
|
info = {} |
|
|
|
x = rearrange(x, "b c n -> b n c") |
|
x, indices, loss = self.quantizer(x) |
|
x = rearrange(x, "b n c -> b c n") |
|
|
|
info["quantizer_indices"] = indices |
|
info["quantizer_loss"] = loss.mean() |
|
|
|
if return_info: |
|
return x, info |
|
else: |
|
return x |
|
|
|
def decode(self, x): |
|
return x |
|
|
|
def decode_tokens(self, codes, **kwargs): |
|
        latents = self.quantizer.get_output_from_indices(codes)
|
|
|
return self.decode(latents, **kwargs) |
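
# Sketch (assuming vector_quantize_pytorch's ResidualVQ constructor kwargs):
# round-tripping continuous latents through token indices:
#
#     bottleneck = RVQBottleneck(dim=64, num_quantizers=4, codebook_size=1024)
#     z, info = bottleneck.encode(torch.randn(2, 64, 100), return_info=True)
#     z_rt = bottleneck.decode_tokens(info["quantizer_indices"])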
|
|
|
|
|
class RVQVAEBottleneck(DiscreteBottleneck): |
|
def __init__(self, **quantizer_kwargs): |
|
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
|
self.quantizer = ResidualVQ(**quantizer_kwargs) |
|
self.num_quantizers = quantizer_kwargs["num_quantizers"] |
|
|
|
def encode(self, x, return_info=False): |
|
info = {} |
|
|
|
        # vae_sample only returns the latents; vae_sample_kl also returns the KL.
        mean, scale = x.chunk(2, dim=1)
        x, kl = vae_sample_kl(mean, scale)
|
|
|
info["kl"] = kl |
|
|
|
x = rearrange(x, "b c n -> b n c") |
|
x, indices, loss = self.quantizer(x) |
|
x = rearrange(x, "b n c -> b c n") |
|
|
|
info["quantizer_indices"] = indices |
|
info["quantizer_loss"] = loss.mean() |
|
|
|
if return_info: |
|
return x, info |
|
else: |
|
return x |
|
|
|
def decode(self, x): |
|
return x |
|
|
|
def decode_tokens(self, codes, **kwargs): |
|
        latents = self.quantizer.get_output_from_indices(codes)
|
|
|
return self.decode(latents, **kwargs) |
|
|
|
|
|
class DACRVQBottleneck(DiscreteBottleneck): |
|
def __init__(self, quantize_on_decode=False, **quantizer_kwargs): |
|
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
|
self.quantizer = DACResidualVQ(**quantizer_kwargs) |
|
self.num_quantizers = quantizer_kwargs["n_codebooks"] |
|
self.quantize_on_decode = quantize_on_decode |
|
|
|
def encode(self, x, return_info=False, **kwargs): |
|
info = {} |
|
|
|
info["pre_quantizer"] = x |
|
|
|
if self.quantize_on_decode: |
|
            # Parenthesized so the conditional selects the whole return value;
            # without them, Python always returns a tuple here.
            return (x, info) if return_info else x
|
|
|
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) |
|
|
|
output = { |
|
"z": z, |
|
"codes": codes, |
|
"latents": latents, |
|
"vq/commitment_loss": commitment_loss, |
|
"vq/codebook_loss": codebook_loss, |
|
} |
|
|
|
output["vq/commitment_loss"] /= self.num_quantizers |
|
output["vq/codebook_loss"] /= self.num_quantizers |
|
|
|
info.update(output) |
|
|
|
if return_info: |
|
return output["z"], info |
|
|
|
return output["z"] |
|
|
|
def decode(self, x): |
|
|
|
if self.quantize_on_decode: |
|
x = self.quantizer(x)[0] |
|
|
|
return x |
|
|
|
def decode_tokens(self, codes, **kwargs): |
|
latents, _, _ = self.quantizer.from_codes(codes) |
|
|
|
return self.decode(latents, **kwargs) |
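
# Sketch (assuming the descript-audio-codec ResidualVectorQuantize API): the
# token ids live under info["codes"] with shape (batch, n_codebooks, seq) and
# can be decoded back to continuous latents:
#
#     z_rt = bottleneck.decode_tokens(info["codes"])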
|
|
|
|
|
class DACRVQVAEBottleneck(DiscreteBottleneck): |
|
def __init__(self, quantize_on_decode=False, **quantizer_kwargs): |
|
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
|
self.quantizer = DACResidualVQ(**quantizer_kwargs) |
|
self.num_quantizers = quantizer_kwargs["n_codebooks"] |
|
self.quantize_on_decode = quantize_on_decode |
|
|
|
    def encode(self, x, return_info=False, n_quantizers: Optional[int] = None):
|
info = {} |
|
|
|
mean, scale = x.chunk(2, dim=1) |
|
|
|
        # vae_sample only returns the latents; vae_sample_kl also returns the KL.
        x, kl = vae_sample_kl(mean, scale)
|
|
|
info["pre_quantizer"] = x |
|
info["kl"] = kl |
|
|
|
if self.quantize_on_decode: |
|
            # Parenthesized so the conditional selects the whole return value;
            # without them, Python always returns a tuple here.
            return (x, info) if return_info else x
|
|
|
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) |
|
|
|
output = { |
|
"z": z, |
|
"codes": codes, |
|
"latents": latents, |
|
"vq/commitment_loss": commitment_loss, |
|
"vq/codebook_loss": codebook_loss, |
|
} |
|
|
|
output["vq/commitment_loss"] /= self.num_quantizers |
|
output["vq/codebook_loss"] /= self.num_quantizers |
|
|
|
info.update(output) |
|
|
|
if return_info: |
|
return output["z"], info |
|
|
|
return output["z"] |
|
|
|
def decode(self, x): |
|
|
|
if self.quantize_on_decode: |
|
x = self.quantizer(x)[0] |
|
|
|
return x |
|
|
|
def decode_tokens(self, codes, **kwargs): |
|
latents, _, _ = self.quantizer.from_codes(codes) |
|
|
|
return self.decode(latents, **kwargs) |
|
|
|
|
|
class FSQBottleneck(DiscreteBottleneck): |
|
def __init__(self, dim, levels): |
|
        super().__init__(
            num_quantizers=1,
            codebook_size=levels ** dim,
            tokens_id="quantizer_indices",
        )
|
self.quantizer = FSQ(levels=[levels] * dim) |
|
|
|
def encode(self, x, return_info=False): |
|
info = {} |
|
|
|
x = rearrange(x, "b c n -> b n c") |
|
x, indices = self.quantizer(x) |
|
x = rearrange(x, "b n c -> b c n") |
|
|
|
info["quantizer_indices"] = indices |
|
|
|
if return_info: |
|
return x, info |
|
else: |
|
return x |
|
|
|
def decode(self, x): |
|
return x |
|
|
|
def decode_tokens(self, tokens, **kwargs): |
|
latents = self.quantizer.indices_to_codes(tokens) |
|
|
|
return self.decode(latents, **kwargs) |
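
# Sketch: FSQBottleneck(dim=6, levels=5) quantizes 6-channel latents against an
# implicit codebook of 5 ** 6 = 15625 entries:
#
#     bottleneck = FSQBottleneck(dim=6, levels=5)
#     z, info = bottleneck.encode(torch.randn(2, 6, 100), return_info=True)
#     z_rt = bottleneck.decode_tokens(info["quantizer_indices"])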