#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from math import ceil
from random import randrange
from typing import Callable

import torch
import torch.distributed as distributed
import torch.nn.functional as F  # noqa: N812
from einops import pack, rearrange, reduce, repeat, unpack
from torch import einsum, nn
from torch.cuda.amp import autocast
from torch.optim import Optimizer

from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig

# ruff: noqa: N806
""" | |
This file is part of a VQ-BeT that utilizes code from the following repositories: | |
- Vector Quantize PyTorch code is licensed under the MIT License: | |
Original source: https://github.com/lucidrains/vector-quantize-pytorch | |
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. | |
Original source: https://github.com/karpathy/nanoGPT | |
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below. | |
""" | |
""" | |
This is a part for nanoGPT that utilizes code from the following repository: | |
- Andrej Karpathy's nanoGPT implementation in PyTorch. | |
Original source: https://github.com/karpathy/nanoGPT | |
- The nanoGPT code is licensed under the MIT License: | |
MIT License | |
Copyright (c) 2022 Andrej Karpathy | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
- We've made some changes to the original code to adapt it to our needs. | |
Changed variable names: | |
- n_head -> gpt_n_head | |
- n_embd -> gpt_hidden_dim | |
- block_size -> gpt_block_size | |
- n_layer -> gpt_n_layer | |
class GPT(nn.Module): | |
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained` | |
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop. | |
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads). | |
""" | |


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.gpt_hidden_dim % config.gpt_n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.gpt_hidden_dim, 3 * config.gpt_hidden_dim)
        # output projection
        self.c_proj = nn.Linear(config.gpt_hidden_dim, config.gpt_hidden_dim)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view(
                1, 1, config.gpt_block_size, config.gpt_block_size
            ),
        )
        self.gpt_n_head = config.gpt_n_head
        self.gpt_hidden_dim = config.gpt_hidden_dim

    def forward(self, x):
        (
            B,
            T,
            C,
        ) = x.size()  # batch size, sequence length, embedding dimensionality (gpt_hidden_dim)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
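

# Illustrative note (not part of the original file): the "bias" buffer
# registered in CausalSelfAttention.__init__ is a lower-triangular mask, so
# position t can only attend to positions <= t. For gpt_block_size = 4:
#
#   torch.tril(torch.ones(4, 4))
#   # tensor([[1., 0., 0., 0.],
#   #         [1., 1., 0., 0.],
#   #         [1., 1., 1., 0.],
#   #         [1., 1., 1., 1.]])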


class Block(nn.Module):
    # causal self-attention block for GPT
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.gpt_hidden_dim)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.gpt_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(config.gpt_hidden_dim, 4 * config.gpt_hidden_dim),
            nn.GELU(),
            nn.Linear(4 * config.gpt_hidden_dim, config.gpt_hidden_dim),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    """
    Original comments:
    Full definition of a GPT Language Model, all of it in this single file.
    References:
    1) the official GPT-2 TensorFlow implementation released by OpenAI:
    https://github.com/openai/gpt-2/blob/master/src/model.py
    2) huggingface/transformers PyTorch implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
    """

    def __init__(self, config: VQBeTConfig):
        """
        The GPT model takes its hyperparameters from a config object. Please refer to configuration_vqbet.py for more details.
        """
        super().__init__()
        assert config.gpt_output_dim is not None
        assert config.gpt_block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Linear(config.gpt_input_dim, config.gpt_hidden_dim),
                "wpe": nn.Embedding(config.gpt_block_size, config.gpt_hidden_dim),
                "drop": nn.Dropout(config.dropout),
                "h": nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]),
                "ln_f": nn.LayerNorm(config.gpt_hidden_dim),
            }
        )
        self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))

        # report number of parameters
        n_params = sum(p.numel() for p in self.parameters())
        print("number of parameters: {:.2f}M".format(n_params / 1e6))

    def forward(self, input, targets=None):
        device = input.device
        b, t, d = input.size()
        assert t <= self.config.gpt_block_size, (
            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
        )

        # positional encodings that are added to the input embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def crop_block_size(self, gpt_block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert gpt_block_size <= self.config.gpt_block_size
        self.config.gpt_block_size = gpt_block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]

    def configure_parameters(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _p in m.named_parameters():
                fpn = "{}.{}".format(mn, pn) if mn else pn  # full param name
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters {} were not separated into either decay/no_decay set!".format(
                str(param_dict.keys() - union_params),
            )
        )

        decay = [param_dict[pn] for pn in sorted(decay)]
        no_decay = [param_dict[pn] for pn in sorted(no_decay)]
        # return the parameters that require weight decay, and the parameters that don't, separately.
        return decay, no_decay
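

# Illustrative sketch (not part of the original file): `configure_parameters`
# only splits the parameters; the optimizer lives in the training loop. One
# plausible way to consume the two buckets (hypothetical lr / weight_decay
# values, continuing the `gpt` sketch near the top of this file):
#
#   decay_params, no_decay_params = gpt.configure_parameters()
#   optimizer = torch.optim.AdamW(
#       [
#           {"params": decay_params, "weight_decay": 1e-1},
#           {"params": no_decay_params, "weight_decay": 0.0},
#       ],
#       lr=1e-4,
#       betas=(0.9, 0.95),
#   )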
""" | |
This file is a part for Residual Vector Quantization that utilizes code from the following repository: | |
- Phil Wang's vector-quantize-pytorch implementation in PyTorch. | |
Original source: https://github.com/lucidrains/vector-quantize-pytorch | |
- The vector-quantize-pytorch code is licensed under the MIT License: | |
MIT License | |
Copyright (c) 2020 Phil Wang | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
- We've made some changes to the original code to adapt it to our needs. | |
class ResidualVQ(nn.Module): | |
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method: | |
This enables the user to save an indicator whether the codebook is frozen or not. | |
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: | |
This is to make the function name more descriptive. | |
class VectorQuantize(nn.Module): | |
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method: | |
These parameters are not used in the code. | |
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: | |
This is to make the function name more descriptive. | |
""" | |


class ResidualVQ(nn.Module):
    """
    Residual VQ is composed of multiple VectorQuantize layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is
    passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional
    Nq-1 vector quantizers, as described in Algorithm 1."

    self.project_in: function for projecting input to codebook dimension
    self.project_out: function for projecting codebook dimension to output dimension
    self.layers: nn.ModuleList of VectorQuantize layers that contains Nq layers of VQ as described in the paper.
    self.freeze_codebook: buffer to save an indicator of whether the codebook is frozen or not. VQ-BeT will check this to determine whether to update the codebook or not.
    """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_dim=None,
        shared_codebook=False,
        heads=1,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        accept_image_fmap=False,
        **kwargs,
    ):
        super().__init__()
        assert heads == 1, "residual vq is not compatible with multi-headed codes"

        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
        codebook_input_dim = codebook_dim * heads
        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
        self.layers = nn.ModuleList(
            [
                VectorQuantize(
                    dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
                )
                for _ in range(num_quantizers)
            ]
        )

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0
        self.register_buffer("freeze_codebook", torch.tensor(False))
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

        if not shared_codebook:
            return

        first_vq, *rest_vq = self.layers
        codebook = first_vq._codebook

        for vq in rest_vq:
            vq._codebook = codebook

    @property
    def codebooks(self):
        codebooks = [layer._codebook.embed for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        codebooks = rearrange(codebooks, "q 1 c d -> q c d")
        return codebooks

    def get_codebook_vector_from_indices(self, indices):
        # this function will return the codes from all codebooks across layers corresponding to the indices
        batch, quantize_dim = indices.shape[0], indices.shape[-1]
        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct
        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0.0, (
                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            )
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # get ready for gathering
        codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch)
        gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1])

        # take care of quantizer dropout
        mask = gather_indices == -1.0
        gather_indices = gather_indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # mask out any codes that were dropout-ed
        all_codes = all_codes.masked_fill(mask, 0.0)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes

    def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
        """
        For a given input tensor x, this function returns the quantized output, the indices of the quantized output, and the loss.
        First, the input tensor x is projected to the codebook dimension. Then, it is passed through Nq layers of VectorQuantize.
        The residual value of each layer is fed to the next layer.
        """
        num_quant, quant_dropout_multiple_of, return_loss, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            (indices is not None),
            x.device,
        )

        x = self.project_in(x)
        assert not (self.accept_image_fmap and (indices is not None))

        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        if return_loss:
            assert not torch.any(indices == -1), (
                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
            )
            ce_losses = []

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss
        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of)
                    * quant_dropout_multiple_of
                    - 1
                )

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)

        # go through the layers
        for quantizer_index, layer in enumerate(self.layers):
            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            layer_indices = None
            if return_loss:
                layer_indices = indices[..., quantizer_index]

            quantized, *rest = layer(
                residual,
                indices=layer_indices,
                sample_codebook_temp=sample_codebook_temp,
                freeze_codebook=self.freeze_codebook,
            )

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            if return_loss:
                ce_loss = rest[0]
                ce_losses.append(ce_loss)
                continue

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # project out, if needed
        quantized_out = self.project_out(quantized_out)

        # whether to early return the cross entropy loss
        if return_loss:
            return quantized_out, sum(ce_losses)

        # stack all losses and indices
        all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers
            all_codes = self.get_codebook_vector_from_indices(all_indices)

            # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
            ret = (*ret, all_codes)

        return ret
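

# Illustrative sketch (not part of the original file): the residual cascade in
# `ResidualVQ.forward`. Each layer quantizes what the previous layers have not
# yet explained, and the output (before `project_out`) is the sum of the
# per-layer codes. Hypothetical sizes; eval mode so the codebook is not updated
# between the forward pass and the lookup.
#
#   rvq = ResidualVQ(dim=32, num_quantizers=4, codebook_size=64).eval()
#   x = torch.randn(2, 5, 32)
#   quantized, indices, _ = rvq(x)
#
#   # quantized_out = sum_q layer_q(residual_q), with
#   # residual_{q+1} = residual_q - layer_q(residual_q), so summing the
#   # per-layer codebook vectors reproduces the quantized output:
#   all_codes = rvq.get_codebook_vector_from_indices(indices)  # (4, 2, 5, 32)
#   assert torch.allclose(all_codes.sum(dim=0), quantized, atol=1e-5)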


class VectorQuantize(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim=None,
        heads=1,
        separate_codebook_per_head=False,
        decay=0.8,
        eps=1e-5,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        threshold_ema_dead_code=0,
        channel_last=True,
        accept_image_fmap=False,
        commitment_weight=1.0,
        commitment_use_cross_entropy_loss=False,
        orthogonal_reg_weight=0.0,
        orthogonal_reg_active_codes_only=False,
        orthogonal_reg_max_codes=None,
        stochastic_sample_codes=False,
        sample_codebook_temp=1.0,
        straight_through=False,
        reinmax=False,  # using reinmax for improved straight-through, assuming straight through helps at all
        sync_codebook=None,
        sync_affine_param=False,
        ema_update=True,
        learnable_codebook=False,
        in_place_codebook_optimizer: Callable[
            ..., Optimizer
        ] = None,  # Optimizer used to update the codebook embedding if using learnable_codebook
        affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
        sync_update_v=0.0,  # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.separate_codebook_per_head = separate_codebook_per_head

        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.eps = eps
        self.commitment_weight = commitment_weight
        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss  # whether to use cross entropy loss to codebook as commitment loss

        self.learnable_codebook = learnable_codebook

        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"

        assert 0 <= sync_update_v <= 1.0
        assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"

        self.sync_update_v = sync_update_v

        gumbel_sample_fn = partial(
            gumbel_sample,
            stochastic=stochastic_sample_codes,
            reinmax=reinmax,
            straight_through=straight_through,
        )

        if sync_codebook is None:
            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1

        codebook_kwargs = {
            "dim": codebook_dim,
            "num_codebooks": heads if separate_codebook_per_head else 1,
            "codebook_size": codebook_size,
            "kmeans_init": kmeans_init,
            "kmeans_iters": kmeans_iters,
            "sync_kmeans": sync_kmeans,
            "decay": decay,
            "eps": eps,
            "threshold_ema_dead_code": threshold_ema_dead_code,
            "use_ddp": sync_codebook,
            "learnable_codebook": has_codebook_orthogonal_loss or learnable_codebook,
            "sample_codebook_temp": sample_codebook_temp,
            "gumbel_sample": gumbel_sample_fn,
            "ema_update": ema_update,
        }

        if affine_param:
            codebook_kwargs = dict(
                **codebook_kwargs,
                affine_param=True,
                sync_affine_param=sync_affine_param,
                affine_param_batch_decay=affine_param_batch_decay,
                affine_param_codebook_decay=affine_param_codebook_decay,
            )

        self._codebook = EuclideanCodebook(**codebook_kwargs)

        self.in_place_codebook_optimizer = (
            in_place_codebook_optimizer(self._codebook.parameters())
            if (in_place_codebook_optimizer is not None)
            else None
        )

        self.codebook_size = codebook_size

        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

    @property
    def codebook(self):
        codebook = self._codebook.embed

        if self.separate_codebook_per_head:
            return codebook

        return rearrange(codebook, "1 ... -> ...")

    @codebook.setter
    def codebook(self, codes):
        if not self.separate_codebook_per_head:
            codes = rearrange(codes, "... -> 1 ...")

        self._codebook.embed.copy_(codes)

    def get_codebook_vector_from_indices(self, indices):
        codebook = self.codebook
        is_multiheaded = codebook.ndim > 2

        if not is_multiheaded:
            codes = codebook[indices]
            return rearrange(codes, "... h d -> ... (h d)")

        indices, ps = pack_one(indices, "b * h")
        indices = rearrange(indices, "b n h -> b h n")

        indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1])
        codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0])

        codes = codebook.gather(2, indices)
        codes = rearrange(codes, "b h n d -> b n (h d)")
        codes = unpack_one(codes, ps, "b * d")
        return codes

    def forward(
        self,
        x,
        indices=None,
        mask=None,
        sample_codebook_temp=None,
        freeze_codebook=False,
    ):
        orig_input = x

        only_one = x.ndim == 2

        if only_one:
            assert mask is None
            x = rearrange(x, "b d -> b 1 d")

        shape, device, heads, is_multiheaded, _codebook_size, return_loss = (
            x.shape,
            x.device,
            self.heads,
            self.heads > 1,
            self.codebook_size,
            (indices is not None),
        )

        need_transpose = not self.channel_last and not self.accept_image_fmap
        should_inplace_optimize = self.in_place_codebook_optimizer is not None

        # rearrange inputs
        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, "b c h w -> b (h w) c")

        if need_transpose:
            x = rearrange(x, "b d n -> b n d")

        # project input
        x = self.project_in(x)

        # handle multi-headed separate codebooks
        if is_multiheaded:
            ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
            x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)

        # l2norm for cosine sim, otherwise identity
        x = self._codebook.transform_input(x)

        # codebook forward kwargs
        codebook_forward_kwargs = {
            "sample_codebook_temp": sample_codebook_temp,
            "mask": mask,
            "freeze_codebook": freeze_codebook,
        }

        # quantize
        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        # one step in-place update
        if should_inplace_optimize and self.training and not freeze_codebook:
            if mask is not None:
                loss = F.mse_loss(quantize, x.detach(), reduction="none")

                loss_mask = mask
                if is_multiheaded:
                    loss_mask = repeat(
                        mask,
                        "b n -> c (b h) n",
                        c=loss.shape[0],
                        h=loss.shape[1] // mask.shape[0],
                    )

                loss = loss[loss_mask].mean()
            else:
                loss = F.mse_loss(quantize, x.detach())

            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            # quantize again
            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        if self.training:
            # determine code to use for commitment loss
            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

            commit_quantize = maybe_detach(quantize)

            # straight through
            quantize = x + (quantize - x).detach()

            if self.sync_update_v > 0.0:
                # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
                quantize = quantize + self.sync_update_v * (quantize - quantize.detach())

        # function for calculating cross entropy loss to distance matrix
        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
        def calculate_ce_loss(codes):
            if not is_multiheaded:
                dist_einops_eq = "1 b n l -> b l n"
            elif self.separate_codebook_per_head:
                dist_einops_eq = "c b n l -> b l n c"
            else:
                dist_einops_eq = "1 (b h) n l -> b l n h"

            ce_loss = F.cross_entropy(
                rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1
            )

            return ce_loss

        # if returning cross entropy loss on codes that were passed in
        if return_loss:
            return quantize, calculate_ce_loss(indices)

        # transform embedding indices
        if is_multiheaded:
            if self.separate_codebook_per_head:
                embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
            else:
                embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)

        if self.accept_image_fmap:
            embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)

        if only_one:
            embed_ind = rearrange(embed_ind, "b 1 -> b")

        # aggregate loss
        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                if self.commitment_use_cross_entropy_loss:
                    if mask is not None:
                        ce_loss_mask = mask
                        if is_multiheaded:
                            ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)

                        embed_ind.masked_fill_(~ce_loss_mask, -1)

                    commit_loss = calculate_ce_loss(embed_ind)
                else:
                    if mask is not None:
                        # with variable-length sequences
                        commit_loss = F.mse_loss(commit_quantize, x, reduction="none")

                        loss_mask = mask
                        if is_multiheaded:
                            loss_mask = repeat(
                                loss_mask,
                                "b n -> c (b h) n",
                                c=commit_loss.shape[0],
                                h=commit_loss.shape[1] // mask.shape[0],
                            )

                        commit_loss = commit_loss[loss_mask].mean()
                    else:
                        commit_loss = F.mse_loss(commit_quantize, x)

                loss = loss + commit_loss * self.commitment_weight

            if self.has_codebook_orthogonal_loss:
                codebook = self._codebook.embed

                # only calculate orthogonal loss for the activated codes for this batch
                if self.orthogonal_reg_active_codes_only:
                    assert not (is_multiheaded and self.separate_codebook_per_head), (
                        "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
                    )
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[:, unique_code_ids]

                num_codes = codebook.shape[-2]

                if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
                    codebook = codebook[:, rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        # handle multi-headed quantized embeddings
        if is_multiheaded:
            if self.separate_codebook_per_head:
                quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
            else:
                quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)

        # project out
        quantize = self.project_out(quantize)

        # rearrange quantized embeddings
        if need_transpose:
            quantize = rearrange(quantize, "b n d -> b d n")

        if self.accept_image_fmap:
            quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width)

        if only_one:
            quantize = rearrange(quantize, "b 1 d -> b d")

        # if masking, only return quantized for where mask has True
        if mask is not None:
            quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)

        return quantize, embed_ind, loss


def noop(*args, **kwargs):
    pass


def identity(t):
    return t


def cdist(x, y):
    x2 = reduce(x**2, "b n d -> b n", "sum")
    y2 = reduce(y**2, "b n d -> b n", "sum")
    xy = einsum("b i d, b j d -> b i j", x, y) * -2
    return (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy).sqrt()
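

# Illustrative note (not part of the original file): `cdist` above expands
# ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2<x_i, y_j> batch-wise; it is what
# `EuclideanCodebook.forward` uses to score codes against inputs. A quick
# sanity check against torch.cdist (equal up to float32 round-off):
#
#   x = torch.randn(2, 5, 8)
#   y = torch.randn(2, 7, 8)
#   assert torch.allclose(cdist(x, y), torch.cdist(x, y, p=2), atol=1e-3)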


def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def ema_inplace(old, new, decay):
    is_mps = str(old.device).startswith("mps:")

    if not is_mps:
        old.lerp_(new, 1 - decay)
    else:
        old.mul_(decay).add_(new * (1 - decay))


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def uniform_init(*shape):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(
    logits,
    temperature=1.0,
    stochastic=False,
    straight_through=False,
    reinmax=False,
    dim=-1,
    training=True,
):
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(dim=dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (reinmax and not straight_through), (
        "reinmax can only be turned on if using straight through gumbel softmax"
    )

    if not straight_through or temperature <= 0.0 or not training:
        return ind, one_hot

    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
    # algorithm 2
    if reinmax:
        π0 = logits.softmax(dim=dim)
        π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2
        π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1)
        π2 = 2 * π1 - 0.5 * π0
        one_hot = π2 - π2.detach() + one_hot
    else:
        π1 = (logits / temperature).softmax(dim=dim)
        one_hot = one_hot + π1 - π1.detach()

    return ind, one_hot
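

# Illustrative sketch (not part of the original file): with
# straight_through=True (while training and temperature > 0), the returned
# `one_hot` equals the hard one-hot vector in the forward pass but carries
# softmax gradients back to `logits` (hard value + soft - soft.detach()).
#
#   logits = torch.randn(3, 10, requires_grad=True)
#   ind, one_hot = gumbel_sample(
#       logits, temperature=1.0, stochastic=True, straight_through=True, training=True
#   )
#   assert torch.equal(one_hot.argmax(dim=-1), ind)  # forward value is hard
#   one_hot.sum().backward()
#   assert logits.grad is not None                   # gradients flow to the logits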


def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1):
    denom = x.sum(dim=dim, keepdim=True)
    return (x + eps) / (denom + n_categories * eps)


def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def batched_sample_vectors(samples, num):
    return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)


def pad_shape(shape, size, dim=0):
    return [size if i == dim else s for i, s in enumerate(shape)]


def sample_multinomial(total_count, probs):
    device = probs.device
    probs = probs.cpu()

    total_count = probs.new_full((), total_count)
    remainder = probs.new_ones(())
    sample = torch.empty_like(probs, dtype=torch.long)

    for i, p in enumerate(probs):
        s = torch.binomial(total_count, p / remainder)
        sample[i] = s
        total_count -= s
        remainder -= p

    return sample.to(device)


def all_gather_sizes(x, dim):
    size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device)
    all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
    distributed.all_gather(all_sizes, size)
    return torch.stack(all_sizes)


def all_gather_variably_sized(x, sizes, dim=0):
    rank = distributed.get_rank()
    all_x = []

    for i, size in enumerate(sizes):
        t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
        distributed.broadcast(t, src=i, async_op=True)
        all_x.append(t)

    distributed.barrier()
    return all_x


def sample_vectors_distributed(local_samples, num):
    local_samples = rearrange(local_samples, "1 ... -> ...")

    rank = distributed.get_rank()
    all_num_samples = all_gather_sizes(local_samples, dim=0)

    if rank == 0:
        samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
    else:
        samples_per_rank = torch.empty_like(all_num_samples)

    distributed.broadcast(samples_per_rank, src=0)
    samples_per_rank = samples_per_rank.tolist()

    local_samples = sample_vectors(local_samples, samples_per_rank[rank])
    all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0)
    out = torch.cat(all_samples, dim=0)

    return rearrange(out, "... -> 1 ...")


def batched_bincount(x, *, minlength):
    batch, dtype, device = x.shape[0], x.dtype, x.device
    target = torch.zeros(batch, minlength, dtype=dtype, device=device)
    values = torch.ones_like(x)
    target.scatter_add_(-1, x, values)
    return target


def kmeans(
    samples,
    num_clusters,
    num_iters=10,
    sample_fn=batched_sample_vectors,
    all_reduce_fn=noop,
):
    num_codebooks, dim, dtype, _device = (
        samples.shape[0],
        samples.shape[-1],
        samples.dtype,
        samples.device,
    )

    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        dists = -torch.cdist(samples, means, p=2)

        buckets = torch.argmax(dists, dim=-1)
        bins = batched_bincount(buckets, minlength=num_clusters)
        all_reduce_fn(bins)

        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)

        new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
        new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
        all_reduce_fn(new_means)

        means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)

    return means, bins
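

# Illustrative sketch (not part of the original file): `kmeans` above operates
# per codebook, i.e. on inputs of shape (num_codebooks, num_samples, dim), and
# is what `EuclideanCodebook.init_embed_` runs when `kmeans_init=True`.
# Hypothetical single-process example (default no-op all-reduce):
#
#   samples = torch.randn(1, 1000, 16)  # (num_codebooks, n, d)
#   means, bins = kmeans(samples, num_clusters=32, num_iters=5)
#   # means: (1, 32, 16) cluster centers, bins: (1, 32) assignment counts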


def batched_embedding(indices, embeds):
    batch, dim = indices.shape[1], embeds.shape[-1]
    indices = repeat(indices, "h b n -> h b n d", d=dim)
    embeds = repeat(embeds, "h c d -> h b c d", b=batch)
    return embeds.gather(2, indices)


def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    h, n = t.shape[:2]
    normed_codes = F.normalize(t, p=2, dim=-1)
    cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
    return (cosine_sim**2).sum() / (h * n**2) - (1 / n)
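

# Illustrative note (not part of the original file): `orthogonal_loss_fn`
# implements eq. (2) of https://arxiv.org/abs/2112.00384 on L2-normalized
# codes: the mean squared cosine similarity over all code pairs, minus the 1/n
# contributed by the diagonal, so it is minimized when distinct codes are
# mutually orthogonal. For instance, an orthonormal codebook gives zero loss:
#
#   eye = torch.eye(8).unsqueeze(0)  # (1, 8, 8): 8 orthonormal codes
#   assert torch.allclose(orthogonal_loss_fn(eye), torch.tensor(0.0), atol=1e-6)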


class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks=1,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        reset_cluster_size=None,
        use_ddp=False,
        learnable_codebook=False,
        gumbel_sample=gumbel_sample,
        sample_codebook_temp=1.0,
        ema_update=True,
        affine_param=False,
        sync_affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
    ):
        super().__init__()
        self.transform_input = identity

        self.decay = decay
        self.ema_update = ema_update

        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = (
            reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
        )

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
        )

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer("initted", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
        self.register_buffer("embed_avg", embed.clone())

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer("embed", embed)

        # affine related params
        self.affine_param = affine_param
        self.sync_affine_param = sync_affine_param

        if not affine_param:
            return

        self.affine_param_batch_decay = affine_param_batch_decay
        self.affine_param_codebook_decay = affine_param_codebook_decay

        self.register_buffer("batch_mean", None)
        self.register_buffer("batch_variance", None)

        self.register_buffer("codebook_mean_needs_init", torch.Tensor([True]))
        self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim))
        self.register_buffer("codebook_variance_needs_init", torch.Tensor([True]))
        self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim))

    def init_embed_(self, data, mask=None):
        if self.initted:
            return

        if mask is not None:
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn=self.sample_fn,
            all_reduce_fn=self.kmeans_all_reduce_fn,
        )

        embed_sum = embed * rearrange(cluster_size, "... -> ... 1")

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def update_with_decay(self, buffer_name, new_value, decay):
        old_value = getattr(self, buffer_name)

        needs_init = getattr(self, buffer_name + "_needs_init", False)

        if needs_init:
            self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

        if old_value is None or needs_init:
            self.register_buffer(buffer_name, new_value.detach())
            return

        value = old_value * decay + new_value.detach() * (1 - decay)
        self.register_buffer(buffer_name, value)

    def update_affine(self, data, embed, mask=None):
        assert self.affine_param

        var_fn = partial(torch.var, unbiased=False)

        # calculate codebook mean and variance
        embed = rearrange(embed, "h ... d -> h (...) d")

        if self.training:
            self.update_with_decay(
                "codebook_mean",
                reduce(embed, "h n d -> h 1 d", "mean"),
                self.affine_param_codebook_decay,
            )
            self.update_with_decay(
                "codebook_variance",
                reduce(embed, "h n d -> h 1 d", var_fn),
                self.affine_param_codebook_decay,
            )

        # prepare batch data, which depends on whether it has masking
        data = rearrange(data, "h ... d -> h (...) d")

        if mask is not None:
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        # calculate batch mean and variance
        if not self.sync_affine_param:
            self.update_with_decay(
                "batch_mean",
                reduce(data, "h n d -> h 1 d", "mean"),
                self.affine_param_batch_decay,
            )
            self.update_with_decay(
                "batch_variance",
                reduce(data, "h n d -> h 1 d", var_fn),
                self.affine_param_batch_decay,
            )
            return

        num_vectors, device, dtype = data.shape[-2], data.device, data.dtype

        # number of vectors, for denominator
        num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype)
        distributed.all_reduce(num_vectors)

        # calculate distributed mean
        batch_sum = reduce(data, "h n d -> h 1 d", "sum")
        distributed.all_reduce(batch_sum)
        batch_mean = batch_sum / num_vectors

        self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)

        # calculate distributed variance
        variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
        distributed.all_reduce(variance_number)
        batch_variance = variance_number / num_vectors

        self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)

    def replace(self, batch_samples, batch_mask):
        for ind, (samples, mask) in enumerate(
            zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0), strict=False)
        ):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
            sampled = rearrange(sampled, "1 ... -> ...")

            self.embed.data[ind][mask] = sampled
            self.cluster_size.data[ind][mask] = self.reset_cluster_size
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
        self.replace(batch_samples, batch_mask=expired_codes)

    # run the codebook computation in full precision, even when the caller is under autocast
    @autocast(enabled=False)
    def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
        needs_codebook_dim = x.ndim < 4
        sample_codebook_temp = (
            sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
        )

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, "... -> 1 ...")

        flatten, ps = pack_one(x, "h * d")

        if mask is not None:
            mask = repeat(
                mask,
                "b n -> c (b h n)",
                c=flatten.shape[0],
                h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
            )

        self.init_embed_(flatten, mask=mask)

        if self.affine_param:
            self.update_affine(flatten, self.embed, mask=mask)

        embed = self.embed if self.learnable_codebook else self.embed.detach()

        if self.affine_param:
            codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
            batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
            embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean

        dist = -cdist(flatten, embed)

        embed_ind, embed_onehot = self.gumbel_sample(
            dist, dim=-1, temperature=sample_codebook_temp, training=self.training
        )

        embed_ind = unpack_one(embed_ind, ps, "h *")

        if self.training:
            unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
            quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
        else:
            quantize = batched_embedding(embed_ind, embed)

        if self.training and self.ema_update and not freeze_codebook:
            if self.affine_param:
                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

            if mask is not None:
                embed_onehot[~mask] = 0.0

            cluster_size = embed_onehot.sum(dim=1)
            self.all_reduce_fn(cluster_size)
            ema_inplace(self.cluster_size.data, cluster_size, self.decay)

            embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            ema_inplace(self.embed_avg.data, embed_sum, self.decay)

            cluster_size = laplace_smoothing(
                self.cluster_size, self.codebook_size, self.eps
            ) * self.cluster_size.sum(dim=-1, keepdim=True)

            embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))

        dist = unpack_one(dist, ps, "h * d")

        return quantize, embed_ind, dist
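

# Illustrative note (not part of the original file): the EMA branch of
# `EuclideanCodebook.forward` maintains, for each code c,
#
#   cluster_size_c <- decay * cluster_size_c + (1 - decay) * (# vectors assigned to c)
#   embed_avg_c    <- decay * embed_avg_c    + (1 - decay) * (sum of vectors assigned to c)
#   embed_c        =  embed_avg_c / (Laplace-smoothed cluster_size_c)
#
# so the codebook tracks running cluster means without gradient descent, which
# is why `VectorQuantize.__init__` asserts that `ema_update=True` is not
# compatible with a learnable (gradient-trained) codebook. A hypothetical
# minimal run exercising this path:
#
#   codebook = EuclideanCodebook(dim=8, codebook_size=4).train()
#   before = codebook.embed.clone()
#   _ = codebook(torch.randn(1, 2, 16, 8))  # (num_codebooks, batch, tokens, dim)
#   assert not torch.equal(before, codebook.embed)  # the EMA update moved the codes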