"""This file contains the definition of the VQ quantizer."""
from typing import Mapping, Text, Tuple
import torch
from einops import rearrange
from .quantizer_utils import entropy_loss_fn
class SimpleVectorizer(torch.nn.Module):
def __init__(
self,
codebook_size: int = 1024,
token_size: int = 256,
commitment_cost: float = 0.25,
entropy_loss_weight: float = 0.0,
entropy_loss_temperature: float = 0.01,
entropy_gamma: float = 1.0,
use_l2_normalisation: bool = False,
):
"""Initializes the quantizer.
Args:
codebook_size -> int: The size of the codebook.
token_size -> int: The feature dimensions of the tokens.
commitment_cost -> float: The commitment cost.
entropy_loss_weight -> float: The weight of the entropy loss.
entropy_loss_temperature -> float: The temperature of the entropy loss.
entropy_gamma -> float: The gamma of the entropy loss.
use_l2_normalisation -> bool: Whether to use L2 normalisation.
"""
super().__init__()
self.commitment_cost = commitment_cost
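        # Codebook: `codebook_size` learnable vectors of dimension `token_size`,
        # initialised uniformly in [-1/codebook_size, 1/codebook_size].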
self.embedding = torch.nn.Embedding(codebook_size, token_size)
self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
self.entropy_loss_weight = entropy_loss_weight
self.entropy_loss_temperature = entropy_loss_temperature
self.entropy_gamma = entropy_gamma
self.use_l2_normalisation = use_l2_normalisation
def forward(
self, z: torch.Tensor
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
"""Computes the quantization loss and returns the quantized latent representation.
Args:
z -> torch.Tensor: The latent representation.
Returns:
z_quantized -> torch.Tensor: The quantized latent representation.
result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results
and losses from the quantizer.
"""
        # permute z -> (batch, height, width, channel); flattening to (batch*height*width, channel)
        # happens below, after the optional L2 normalisation
z = rearrange(z, "b c h w -> b h w c").contiguous()
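        # With L2 normalisation, z and the codebook lie on the unit sphere, so the
        # squared distance below reduces to 2 - 2 * cosine_similarity(z, e).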
if self.use_l2_normalisation:
z = torch.nn.functional.normalize(z, dim=-1)
embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
else:
embedding = self.embedding.weight
z_flattened = rearrange(z, "b h w c -> (b h w) c")
        # squared distances from z to the embeddings e_j:
        # d = ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z . e
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(embedding**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T)
)
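        # nearest-neighbour lookup: index of the closest codebook entry per token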
min_encoding_indices = torch.argmin(d, dim=1)
z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
        # commitment loss: pull the encoder output towards its (detached) codebook entry;
        # codebook loss: move the codebook entry towards the (detached) encoder output
commitment_loss = self.commitment_cost * torch.mean(
(z_quantized.detach() - z) ** 2
)
codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
entropy_loss = torch.zeros((), device=z.device)
per_sample_entropy = torch.zeros((), device=z.device)
avg_entropy = torch.zeros((), device=z.device)
        # optional entropy regularisation over the codebook assignments (training only)
if self.entropy_loss_weight != 0.0 and self.training:
per_sample_entropy, avg_entropy = entropy_loss_fn(
-1 * d, self.entropy_loss_temperature, self.entropy_gamma
)
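            # minimising (per_sample_entropy - avg_entropy) encourages confident
            # per-token assignments while spreading usage over the whole codebook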
entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy)
loss = commitment_loss + codebook_loss + entropy_loss
        # preserve gradients with the straight-through estimator:
        # the forward pass uses z_quantized, gradients flow back through z
z_quantized = z + (z_quantized - z).detach()
# reshape back to match original input shape
z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()
result_dict = dict(
quantizer_loss=loss,
commitment_loss=commitment_loss,
codebook_loss=codebook_loss,
entropy_loss=entropy_loss,
per_sample_entropy=per_sample_entropy,
avg_entropy=avg_entropy,
min_encoding_indices=min_encoding_indices.view(
z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]
),
)
return z_quantized, result_dict
def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
"""Returns the codebook entry for the given indices.
Args:
indices -> torch.Tensor: The indices of the codebook entries.
Returns:
z_quantized -> torch.Tensor: The codebook entries.
"""
# get quantized latent vectors
z_quantized = self.embedding(indices.int())
if self.use_l2_normalisation:
z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
return z_quantized
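
# Minimal usage sketch (illustrative; assumes this module is imported as part of
# its package so that the relative `quantizer_utils` import resolves):
#
#     quantizer = SimpleVectorizer(codebook_size=1024, token_size=256)
#     z = torch.randn(2, 256, 16, 16)             # (batch, channel, height, width)
#     z_quantized, results = quantizer(z)         # z_quantized has the same shape as z
#     indices = results["min_encoding_indices"]   # (batch, height, width) codebook ids
#     loss = results["quantizer_loss"]            # commitment + codebook (+ entropy) terms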