"""This file contains the definition of the look-free quantizer."""
from typing import Mapping, Text, Tuple
import torch
from einops import rearrange, reduce
from .quantizer_utils import entropy_loss_fn
class LookupFreeQuantizer(torch.nn.Module):
    def __init__(
        self,
        token_bits: int = 10,
        commitment_cost: float = 0.25,
        entropy_loss_weight: float = 0.1,
        entropy_loss_temperature: float = 0.01,
        entropy_gamma: float = 1.0,
    ):
        """Initializes the lookup-free quantizer.

        Args:
            token_bits -> int: The number of bits per token.
            commitment_cost -> float: The commitment cost.
            entropy_loss_weight -> float: The weight of the entropy loss.
            entropy_loss_temperature -> float: The temperature for the entropy loss.
            entropy_gamma -> float: The gamma for the entropy loss.
        """
        super().__init__()
        self.token_size = token_bits
        self.codebook_size = 2**token_bits
        self.commitment_cost = commitment_cost
        self.entropy_loss_weight = entropy_loss_weight
        self.entropy_loss_temperature = entropy_loss_temperature
        self.entropy_gamma = entropy_gamma

        # Powers of two used to pack/unpack the per-channel bits into integer indices.
        bits_to_indices = torch.pow(
            2.0, torch.arange(0, self.token_size, dtype=torch.float32)
        )
        self.register_buffer("bits_to_indices", bits_to_indices.int())

        # Enumerate the implicit codebook once: every index as its {-1, 1} bit pattern.
        all_codes = torch.arange(self.codebook_size)
        bits = ((all_codes[..., None].int() & self.bits_to_indices) != 0).float()
        self.register_buffer("codebook", bits * 2.0 - 1.0)

    def forward(
        self, z: torch.Tensor
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Forward pass of the quantizer.

        Args:
            z -> torch.Tensor: The input tensor.

        Returns:
            z_quantized -> torch.Tensor: The quantized latent representation.
            result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results
                and losses from the quantizer.
        """
z = rearrange(z, "b c h w -> b h w c").contiguous()
ones = torch.ones_like(z)
sign_mask = z > 0.0
z_quantized = torch.where(sign_mask, ones, -ones)
min_encoding_indices = self.convert_bits_to_indices(z_quantized)
# compute loss for embedding
commitment_loss = self.commitment_cost * torch.mean(
(z_quantized.detach() - z) ** 2
)
entropy_loss = torch.zeros((), device=z.device)
per_sample_entropy = torch.zeros((), device=z.device)
avg_entropy = torch.zeros((), device=z.device)
# Use entropy loss on the codebook
if self.entropy_loss_weight != 0.0 and self.training:
d = -2 * torch.einsum("b h w c, n c -> b h w n", z, self.codebook)
per_sample_entropy, avg_entropy = entropy_loss_fn(
-1 * d, self.entropy_loss_temperature, self.entropy_gamma
)
entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy)
loss = commitment_loss + entropy_loss
# preserve gradients
z_quantized = z + (z_quantized - z).detach()
# reshape back to match original input shape
z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous()
result_dict = dict(
quantizer_loss=loss,
commitment_loss=commitment_loss,
entropy_loss=entropy_loss,
per_sample_entropy=per_sample_entropy,
avg_entropy=avg_entropy,
min_encoding_indices=min_encoding_indices,
)
return z_quantized, result_dict

    def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
        """Returns the `codebook entry` for the given indices.

        As the codebook exists only implicitly, this is mainly an integer conversion to a bit representation.
        Note: The bits are represented by {-1, 1}.

        Args:
            indices -> torch.Tensor: The indices in range 0 to codebook size - 1.

        Returns:
            tokens -> torch.Tensor: The bit representation.
        """
        indices = indices.long()
        bits = ((indices[..., None].int() & self.bits_to_indices) != 0).float()
        tokens = bits * 2.0 - 1.0  # scale to -1..1
        return tokens

    def convert_bits_to_indices(self, tokens: torch.Tensor) -> torch.Tensor:
        """Converts the given tokens to index numbers.

        As the codebook exists only implicitly, this is mainly an integer conversion from a bit representation.
        Note: The bits are represented by {-1, 1}.

        Args:
            tokens -> torch.Tensor: The tokens.

        Returns:
            indices -> torch.Tensor: The indices in range 0 to codebook size - 1.
        """
        tokens = rearrange(tokens, "b h w c -> b h w c").contiguous()
        sign_mask = tokens > 0.0
        return reduce(sign_mask.int() * self.bits_to_indices, "b h w c -> b h w", "sum")

    def convert_indices_to_bits(self, indices: torch.Tensor) -> torch.Tensor:
        """Converts the given indices to tokens.

        As the codebook exists only implicitly, this is mainly an integer conversion to a bit representation.
        Note: The bits are represented by {-1, 1}.

        Args:
            indices -> torch.Tensor: The indices in range 0 to codebook size - 1.

        Returns:
            tokens -> torch.Tensor: The bit representation.
        """
        indices = indices.long()
        return self.get_codebook_entry(indices)


if __name__ == "__main__":
    quantizer = LookupFreeQuantizer(
        token_bits=10,
        commitment_cost=0.25,
        entropy_loss_weight=0.1,
        entropy_loss_temperature=0.01,
        entropy_gamma=1.0,
    )

    all_entries = torch.arange(1024).reshape(1, 1, 1024)
    indices = quantizer.convert_bits_to_indices(
        quantizer.convert_indices_to_bits(all_entries)
    )
    assert torch.equal(indices, all_entries)
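
    # Worked example (illustrative addition, not part of the original checks). The bit
    # order implied by bits_to_indices = [1, 2, 4, ...] is least-significant-bit first,
    # so index 5 = 0b0000000101 maps to the pattern (+1, -1, +1, -1, ..., -1).
    example_bits = quantizer.convert_indices_to_bits(torch.tensor([[[5]]]))
    expected_bits = torch.tensor([1.0, -1.0, 1.0] + [-1.0] * 7)
    assert torch.equal(example_bits[0, 0, 0], expected_bits)
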
    assert torch.equal(
        quantizer.convert_bits_to_indices(quantizer.codebook.reshape(1, 1, 1024, 10)),
        all_entries,
    )
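
    # Minimal forward-pass sketch (an assumption for illustration, not an original test):
    # a random latent whose channel dimension equals token_bits is binarized per channel,
    # and eval() skips the training-only entropy loss term.
    quantizer.eval()
    latent = torch.randn(2, 10, 4, 4)
    z_quantized, result_dict = quantizer(latent)
    assert z_quantized.shape == latent.shape
    assert result_dict["min_encoding_indices"].shape == (2, 4, 4)
    assert int(result_dict["min_encoding_indices"].max()) < quantizer.codebook_size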