"""
Lookup Free Quantization

Proposed in https://arxiv.org/abs/2310.05737

Basically a 2-level FSQ (Finite Scalar Quantization) with an entropy loss.
https://arxiv.org/abs/2309.15505
"""

import torch
from torch.nn import Module


def log(t, eps=1e-20):
    # numerically stable log: clamp away from zero before taking the log
    return t.clamp(min=eps).log()


def binary_entropy(prob):
    # entropy of a Bernoulli distribution with success probability `prob`
    return -prob * log(prob) - (1 - prob) * log(1 - prob)


def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
    # expand integer indices into binary codes in {-1., 1.}, least significant bit first
    mask = 2 ** torch.arange(bits).to(x)
    bits = ((x.unsqueeze(-1) & mask) != 0).float()
    return bits * 2 - 1


def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
    # collapse codes back to integer indices, treating positive entries as 1-bits
    x = (x > 0).long()
    mask = 2 ** torch.arange(x.size(-1)).to(x)
    dec = (x * mask).sum(-1)
    return dec
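

# Minimal sanity-check sketch for the two helpers above; the 3-bit setup is an
# illustrative assumption, not something prescribed by the papers.
def _bit_roundtrip_example():
    idx = torch.arange(8)              # all indices of a 3-bit (2 ** 3 = 8 entry) codebook
    codes = decimal_to_bits(idx, 3)    # shape (8, 3), entries in {-1., 1.}
    assert torch.equal(bits_to_decimal(codes), idx)
    return codes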


class LFQY(Module):
    def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
        super().__init__()
        # dim is the number of bits per code, i.e. log2 of the implicit codebook size
        self.dim = dim
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

    def indices_to_codes(self, indices):
        # recover the {-1., 1.} codes from integer indices
        codes = decimal_to_bits(indices, self.dim)
        return codes

    def forward(self, x, mask=None, inv_temperature=1.):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        """
        # note: `mask` is accepted but currently unused
        assert x.shape[-1] == self.dim

        # soft codes for the gradient path
        # (note: `inv_temperature` divides here, so values > 1 soften the tanh / sigmoid)
        z = torch.tanh(x / inv_temperature)

        # hard quantization to {-1., 1.}; torch.where avoids the 0 that
        # torch.sign would produce for exactly-zero inputs
        quantized = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))

        # straight-through estimator: forward pass uses the quantized codes,
        # backward pass flows through tanh
        z = z + (quantized - z).detach()

        # integer codebook indices of the quantized codes
        indices = bits_to_decimal(z)

        if self.training:
            # per-bit probability of being +1
            prob = torch.sigmoid(x / inv_temperature)

            # per-sample entropy, summed over bits and averaged over the batch
            bit_entropy = binary_entropy(prob).sum(-1).mean()

            # entropy of the batch-averaged bit distribution
            avg_prob = prob.flatten(0, -2).mean(0)
            codebook_entropy = binary_entropy(avg_prob).sum()

            """
            1. bit entropy will be nudged to be low,
               so each scalar commits to one latent binary bit or the other.
            2. codebook entropy will be nudged to be high,
               to encourage all codes to be uniformly used.
            """

            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
        else:
            entropy_aux_loss = torch.zeros(1).to(z)

        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight

        return z, entropy_aux_loss, indices

    def get_codebook_entry(self, encoding_indices):
        # thin wrapper around indices_to_codes
        return self.indices_to_codes(encoding_indices)
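

# Minimal usage sketch (not from the original source): shapes and hyperparameters
# below are illustrative assumptions.
if __name__ == "__main__":
    quantizer = LFQY(dim=8)      # 8 bits per code -> implicit codebook of 2 ** 8 = 256 entries
    quantizer.train()

    x = torch.randn(2, 16, 8)    # (batch, sequence, dim) continuous features

    z, entropy_aux_loss, indices = quantizer(x)
    print(z.shape)               # torch.Size([2, 16, 8]), entries in {-1., 1.}
    print(indices.shape)         # torch.Size([2, 16]), integer codes in [0, 256)
    print(entropy_aux_loss)      # auxiliary loss, already scaled by entropy_loss_weight

    # indices map back to codes that match the quantized output exactly
    assert torch.equal(quantizer.get_codebook_entry(indices), z)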