Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Lookup Free Quantization | |
Proposed in https://arxiv.org/abs/2310.05737 | |
In the simplest setup, each dimension is quantized into {-1, 1}. | |
An entropy penalty is used to encourage utilization. | |
Refer to | |
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py | |
https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py | |
""" | |
""" | |
Modified Open-MAGVIT2 code to use VQConfig. | |
""" | |
from math import log2, ceil | |
from collections import namedtuple | |
import torch | |
from torch import nn, einsum | |
import torch.nn.functional as F | |
from torch.nn import Module | |
from einops import rearrange, reduce, pack, unpack | |
from magvit2.config import VQConfig | |
# constants | |
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs']) | |
# helper functions | |
def exists(v): | |
return v is not None | |
def default(*args): | |
for arg in args: | |
if exists(arg): | |
return arg() if callable(arg) else arg | |
return None | |
def pack_one(t, pattern): | |
return pack([t], pattern) | |
def unpack_one(t, ps, pattern): | |
return unpack(t, ps, pattern)[0] | |
# entropy | |
def entropy(prob): | |
return (-prob * torch.log(prob + 1e-5)).sum(dim=-1) | |
# class | |
def mult_along_first_dims(x, y): | |
""" | |
returns x * y elementwise along the leading dimensions of y | |
""" | |
ndim_to_expand = x.ndim - y.ndim | |
for _ in range(ndim_to_expand): | |
y = y.unsqueeze(-1) | |
return x * y | |
def masked_mean(x, m): | |
""" | |
takes the mean of the elements of x that are not masked | |
the mean is taken along the shared leading dims of m | |
equivalent to: x[m].mean(tuple(range(m.ndim))) | |
The benefit of using masked_mean rather than using | |
tensor indexing is that masked_mean is much faster | |
for torch-compile on batches. | |
The drawback is larger floating point errors | |
""" | |
x = mult_along_first_dims(x, m) | |
x = x / m.sum() | |
return x.sum(tuple(range(m.ndim))) | |
def entropy_loss( | |
logits, | |
mask=None, | |
temperature=0.01, | |
sample_minimization_weight=1.0, | |
batch_maximization_weight=1.0, | |
eps=1e-5, | |
): | |
""" | |
Entropy loss of unnormalized logits | |
logits: Affinities are over the last dimension | |
https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 | |
LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) | |
""" | |
probs = F.softmax(logits / temperature, -1) | |
log_probs = F.log_softmax(logits / temperature + eps, -1) | |
if mask is not None: | |
avg_probs = masked_mean(probs, mask) | |
else: | |
avg_probs = reduce(probs, "... D -> D", "mean") | |
avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) | |
sample_entropy = -torch.sum(probs * log_probs, -1) | |
if mask is not None: | |
sample_entropy = masked_mean(sample_entropy, mask).mean() | |
else: | |
sample_entropy = torch.mean(sample_entropy) | |
loss = (sample_minimization_weight * sample_entropy) - ( | |
batch_maximization_weight * avg_entropy | |
) | |
return sample_entropy, avg_entropy, loss | |
class LFQ(Module): | |
def __init__(self, config: VQConfig): | |
super().__init__() | |
# some assert validations | |
assert exists(config.z_channels) or exists(config.codebook_size), \ | |
"either dim or codebook_size must be specified for LFQ" | |
assert not exists(config.codebook_size) or log2(config.codebook_size).is_integer(), \ | |
f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(config.codebook_size))})" | |
self.codebook_size = default(config.codebook_size, lambda: 2 ** dim) | |
self.codebook_dim = int(log2(config.codebook_size)) | |
codebook_dims = self.codebook_dim * config.num_codebooks | |
dim = default(config.z_channels, codebook_dims) | |
has_projections = dim != codebook_dims | |
self.has_projections = has_projections | |
self.dim = dim | |
self.codebook_dim = self.codebook_dim | |
self.num_codebooks = config.num_codebooks | |
# for entropy loss | |
self.sample_minimization_weight = config.sample_minimization_weight | |
self.batch_maximization_weight = config.batch_maximization_weight | |
# for no auxiliary loss, during inference | |
self.token_factorization = config.token_factorization # only utilized in second stage | |
if not self.token_factorization: # for first stage model | |
self.register_buffer('mask', 2 ** torch.arange(self.codebook_dim - 1, -1, -1), persistent=False) | |
else: | |
k = self.codebook_dim // 2 | |
self.register_buffer("mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False) | |
self.register_buffer('zero', torch.tensor(0.), persistent=False) | |
# codes | |
all_codes = torch.arange(config.codebook_size) | |
bits = self.indices_to_bits(all_codes) | |
codebook = bits * 2.0 - 1.0 | |
self.register_buffer('codebook', codebook, persistent=False) | |
def dtype(self): | |
return self.codebook.dtype | |
def indices_to_bits(self, x): | |
""" | |
x: long tensor of indices for constructing codebook, but actually not utilized in all the experiments. | |
returns big endian bits | |
""" | |
mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) | |
# x is now big endian bits, the last dimension being the bits | |
x = (x.unsqueeze(-1) & mask) != 0 | |
return x | |
def get_codebook_entry(self, x, bhwc): | |
if self.token_factorization: | |
k = self.codebook_dim // 2 | |
mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long) | |
else: | |
mask = 2 ** torch.arange(self.codebook_dim-1, -1, -1, device=x.device, dtype=torch.long) | |
x = (x.unsqueeze(-1) & mask) != 0 # find its bit representation | |
x = x * 2.0 - 1.0 #back to the float | |
## scale back to the desired shape | |
b, h, w, c = bhwc | |
x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c) | |
x = rearrange(x, "b h w c -> b c h w") | |
return x | |
def bits_to_indices(self, bits): | |
""" | |
bits: bool tensor of big endian bits, where the last dimension is the bit dimension | |
returns indices, which are long integers from 0 to self.codebook_size | |
""" | |
assert bits.shape[-1] == self.codebook_dim | |
indices = 2 ** torch.arange( | |
0, | |
self.codebook_dim, | |
1, | |
dtype=torch.long, | |
device=bits.device, | |
) | |
return (bits * indices).sum(-1) | |
def decode(self, x): | |
""" | |
x: ... NH | |
where NH is number of codebook heads | |
A longtensor of codebook indices, containing values from | |
0 to self.codebook_size | |
""" | |
x = self.indices_to_bits(x) | |
# to some sort of float | |
x = x.to(self.dtype) | |
# -1 or 1 | |
x = x * 2 - 1 | |
x = rearrange(x, "... NC Z-> ... (NC Z)") | |
return x | |
def forward( | |
self, | |
x, | |
return_loss_breakdown=False, | |
mask=None, | |
return_loss=True, | |
flip=False, | |
): | |
""" | |
einstein notation | |
b - batch | |
n - sequence (or flattened spatial dimensions) | |
d - feature dimension, which is also log2(codebook size) | |
c - number of codebook dim | |
""" | |
x = rearrange(x, 'b d ... -> b ... d') | |
x, ps = pack_one(x, 'b * d') | |
# split out number of codebooks | |
x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks) | |
codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype) | |
quantized = torch.where(x > 0, codebook_value, -codebook_value) # higher than 0 filled | |
# calculate indices | |
if self.token_factorization: | |
k = self.codebook_dim // 2 | |
indices_pre = reduce((quantized[..., :k] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") | |
indices_post = reduce((quantized[..., k:] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") | |
# indices_post = 2**k + indices_post #shifter to the 1024 | |
else: | |
if not flip: | |
indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum') | |
else: | |
# not sure why this is necessary | |
indices = reduce((quantized > 0).flip(-1).int() * self.mask.int(), 'b n c d -> b n c', 'sum') | |
# entropy aux loss | |
if self.training and return_loss: | |
logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook) | |
# the same as Euclidean distance up to a constant | |
per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss( | |
logits=logits, | |
sample_minimization_weight=self.sample_minimization_weight, | |
batch_maximization_weight=self.batch_maximization_weight | |
) | |
avg_probs = self.zero | |
else: | |
## calculate the codebook_entropy needed for one batch evaluation | |
#------------------------------------------------------------------ | |
# logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook) | |
# probs = F.softmax(logits / 0.01, -1) | |
# avg_probs = reduce(probs, "b n c d -> b d", "mean") | |
# avg_probs = torch.sum(avg_probs, 0) #batch dimension | |
#------------------------------------------------------------------- | |
# if not training, just return dummy 0 | |
per_sample_entropy = codebook_entropy = self.zero | |
entropy_aux_loss = self.zero | |
avg_probs = self.zero | |
# commit loss | |
if self.training: | |
commit_loss = F.mse_loss(x, quantized.detach(), reduction='none') | |
if exists(mask): | |
commit_loss = commit_loss[mask] | |
commit_loss = commit_loss.mean() | |
else: | |
commit_loss = self.zero | |
# use straight-through gradients (optionally with custom activation fn) if training | |
quantized = x + (quantized - x).detach() # transfer to quantized | |
# merge back codebook dim | |
quantized = rearrange(quantized, 'b n c d -> b n (c d)') | |
# reconstitute image or video dimensions | |
quantized = unpack_one(quantized, ps, 'b * d') | |
quantized = rearrange(quantized, 'b ... d -> b d ...') | |
if self.token_factorization: | |
indices_pre = unpack_one(indices_pre, ps, "b * c") | |
indices_post = unpack_one(indices_post, ps, "b * c") | |
indices_pre = indices_pre.flatten() | |
indices_post = indices_post.flatten() | |
indices = (indices_pre, indices_post) | |
else: | |
indices = unpack_one(indices, ps, 'b * c') | |
indices = indices.flatten() | |
ret = (quantized, entropy_aux_loss, indices) | |
if not return_loss_breakdown: | |
return ret | |
return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs) | |