"""VQ-VAE autoencoder with ViT backbone.""" |

import functools
from typing import Mapping, Optional, Sequence, Union

from big_vision import utils
from big_vision.models import common
from big_vision.models import vit

import einops
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp
import numpy as np

partial = functools.partial

# Relative size of the random multiplicative perturbation applied to a
# codeword when it is split in two (see split_the_most_frequent_embedding).
PERTURB = 0.001


@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
def quantize(x, e):
  # Squared Euclidean distance from the token x to every codeword in e.
  dist = jnp.sum(x * x)[None] - 2 * x.dot(e.T) + jnp.sum(e * e, axis=1)
  idx = jnp.argmin(dist)
  # Straight-through estimator: the forward pass returns the nearest codeword,
  # while gradients flow to the un-quantized input x.
  x_q = jax.lax.stop_gradient(e[idx] - x) + x
  return x_q, idx
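
# Usage sketch (illustrative only, shapes are made up): for a batch of tokens
# of shape (B, N, D) and a codebook e of shape (K, D), the two vmaps above map
# the per-token nearest-codeword lookup over the token and batch axes:
#
#   x = jnp.zeros((2, 16, 8))    # (batch, tokens, dim)
#   e = jnp.zeros((512, 8))      # (dict_size, dim)
#   x_q, idx = quantize(x, e)    # x_q: (2, 16, 8), idx: (2, 16)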


def split_the_most_frequent_embedding(state):
  """Splits most frequent embedding into two and eliminates least frequent.

  Args:
    state: a dict that contains the current jax rng, embeddings and their
      counts.

  Returns:
    A new dict with the updated jax rng, embeddings and counts.
  """
  rng, e, c = state["rng"], state["dictionary"], state["counts"]
  rng, rng_local = jax.random.split(rng)

  i_max = jnp.argmax(c)
  i_min = jnp.argmin(c)

  # Overwrite the least frequent codeword with a slightly perturbed copy of
  # the most frequent one, then split its count and embedding sum evenly.
  e = e.at[i_min].set(
      e[i_max] * jax.random.uniform(rng_local, (e.shape[1],), jnp.float32,
                                    1.0 - PERTURB, 1.0 + PERTURB))

  c = c.at[i_min].set(c[i_max] / 2.0)
  c = c.at[i_max].set(c[i_max] / 2.0)

  e = e.at[i_min].set(e[i_min] / 2.0)
  e = e.at[i_max].set(e[i_max] / 2.0)

  return {"rng": rng, "dictionary": e, "counts": c}
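
# Worked example (hypothetical numbers): with counts c = [8.0, 0.02, 3.0] and
# min_count = 0.1, one application replaces entry 1 with a perturbed copy of
# entry 0 and halves entry 0's count, giving roughly c = [4.0, 4.0, 3.0]. The
# jax.lax.while_loop in Model.encode repeats this until all counts are above
# min_count.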


class Model(nn.Module):
  """VQ-VAE model with a ViT encoder and decoder."""

  inputs: Mapping[str, Sequence[int]]
  outputs: Mapping[str, Sequence[int]]
  input_size: Sequence[int] = (256, 256)
  patch_size: Sequence[int] = (8, 8)
  code_len: int = 256
  width: int = 768
  enc_depth: int = 6
  dec_depth: int = 6
  mlp_dim: Optional[int] = None
  num_heads: int = 12
  posemb: str = "learn"
  rep_size: Union[int, bool] = False
  dropout: float = 0.0
  reinit: Optional[Sequence[str]] = None
  head_zeroinit: bool = True
  dict_size: int = 512
  codeword_dim: Optional[int] = None
  dict_momentum: float = 0.995
  quantize: bool = True
  # Axis name over which codebook statistics are summed with jax.lax.psum;
  # set to an empty value to skip the cross-device reduction.
  statistics_axis_name: str = "batch"
  # Codewords whose moving-average count falls below this threshold are
  # re-initialized by splitting the most frequent codeword.
  min_count: float = 0.1
  with_encoder_ctx: bool = False
  with_decoder_ctx: bool = False
  code_dropout: str = "none"
  bottleneck_resize: bool = False
  zero_decoder_seq: bool = False

  def setup(self):
    self.grid_size = np.array(self.input_size) // np.array(self.patch_size)

    # One dense embedding per input modality and one prediction head per
    # output modality.
    self.embeddings = {
        k: nn.DenseGeneral(features=(self.width,), axis=range(-len(shape), 0),
                           name=f"embedding_{k}")
        for k, shape in self.inputs.items()
    }

    kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
    self.heads = {
        k: nn.DenseGeneral(features=shape, name=f"head_{k}", **kw)
        for k, shape in self.outputs.items()
    }

    if self.with_encoder_ctx:
      self.stem_conv_ctx_enc = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_enc_embedding")

    if self.with_decoder_ctx:
      self.stem_conv_ctx_dec = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_dec_embedding")

    self.pos_embedding_encoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_encoder")
    self.encoder = vit.Encoder(
        depth=self.enc_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="encoder")

    # Learned linear map from the full token grid to the (shorter) code
    # sequence, used when the bottleneck is not resized via image resizing.
    if not self.bottleneck_resize:
      self.bottleneck_downsample = self.param(
          "bottleneck_downsample",
          nn.initializers.xavier_uniform(),
          (np.prod(self.grid_size), self.code_len))

    # Codebook state: running sums of assigned vectors and their counts,
    # updated with an exponential moving average during training.
    norm_init = nn.initializers.normal(stddev=1.0 / np.sqrt(self.dict_size))
    self.dictionary = self.variable(
        "state", "dictionary",
        lambda shape: norm_init(self.make_rng("state"), shape),
        (self.dict_size, self.codeword_dim or self.width))
    self.counts = self.variable("state", "counts", jnp.ones, (self.dict_size,))

    if not self.bottleneck_resize:
      self.bottleneck_upsample = self.param(
          "bottleneck_upsample",
          nn.initializers.xavier_uniform(),
          (self.code_len, np.prod(self.grid_size)))

    self.pos_embedding_decoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_decoder")
    self.decoder = vit.Encoder(
        depth=self.dec_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="decoder")

    self.encoder_head = nn.Dense(self.codeword_dim or self.width)
    self.decoder_stem = nn.Dense(self.width)

  def get_codewords(self):
    # The dictionary holds running sums of assigned vectors; dividing by the
    # running counts yields the mean codewords, which are then l2-normalized.
    e = self.dictionary.value / self.counts.value[:, None]
    e = e / jnp.linalg.norm(e, axis=-1, keepdims=True)
    return e
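
  # Shape sketch (illustrative): with the defaults dict_size=512 and
  # codeword_dim=None, get_codewords() returns a (512, width) array whose rows
  # have unit l2 norm; quantize() measures distances against these rows.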

  def encode(self, x, *, ctx=None, train=False, update_dict=True):
    out = {}

    # Embed every input modality and sum the per-modality token embeddings.
    out["stem"] = {}
    for key, embed in self.embeddings.items():
      out["stem"][key] = embed(x[key])
    x = sum(out["stem"].values())

    if self.with_encoder_ctx:
      ctx_tokens = self.stem_conv_ctx_enc(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.encoder(x + self.pos_embedding_encoder, deterministic=not train)

    # Reduce the token sequence to code_len tokens, either by bilinear
    # resizing on the 2d grid or with a learned linear map over the token axis.
    if self.bottleneck_resize:
      x = einops.rearrange(x, "b (h w) c -> b h w c",
                           h=self.grid_size[0], w=self.grid_size[1])
      l = int(np.round(self.code_len ** 0.5))
      x = jax.image.resize(
          x, (x.shape[0], l, l, x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("btc,tn->bnc", x, self.bottleneck_downsample)

    x = self.encoder_head(x)

    x = jax.nn.standardize(x, axis=-1)
    x_pre_q = out["bottleneck"] = x
    e = self.get_codewords()
    x, idx = quantize(x, e)
    out["bottleneck_q"] = x
    out["code"] = idx

    if train:
      # Accumulate assignment statistics for the EMA codebook update.
      counts = jnp.zeros(self.dict_size, dtype=jnp.int32)
      counts = counts.at[idx].add(1)

      x_sum = jnp.zeros_like(self.dictionary.value)
      x_sum = x_sum.at[idx].add(jax.lax.stop_gradient(x_pre_q))

      if self.statistics_axis_name:
        counts = jax.lax.psum(counts, axis_name=self.statistics_axis_name)
        x_sum = jax.lax.psum(x_sum, axis_name=self.statistics_axis_name)

      out["codebook_max_ratio"] = jnp.max(counts) / jnp.sum(counts)
      out["codebook_zeros_ratio"] = jnp.sum(counts == 0) / len(counts)

      if update_dict:
        self.counts.value = self.counts.value * self.dict_momentum + counts
        self.dictionary.value = (self.dictionary.value * self.dict_momentum +
                                 x_sum)

        # Re-initialize rarely used codewords by repeatedly splitting the most
        # frequent one until all counts exceed min_count.
        state = {"dictionary": self.dictionary.value,
                 "counts": self.counts.value,
                 "rng": self.make_rng("vqvae")}
        new_state = jax.lax.while_loop(
            lambda state: jnp.any(state["counts"] < self.min_count),
            split_the_most_frequent_embedding,
            state)
        self.counts.value = new_state["counts"]
        self.dictionary.value = new_state["dictionary"]

    if not self.quantize:
      x = x_pre_q
      out["bottleneck_q"] = x
    return x, out

  def decode(self, x, ctx=None, discrete_input=False, train=False):
    out = {}

    # Optionally map discrete code indices back to their codewords.
    if discrete_input:
      e = self.get_codewords()
      x = e[x]

    if self.zero_decoder_seq:
      x = jnp.zeros_like(x)

    # Code dropout: keep a random-length prefix of the code sequence (or a
    # random subset of codes if code_dropout == "random") and zero the rest.
    if train and self.code_dropout != "none":
      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
      thr = jax.random.uniform(self.make_rng("dropout"), x.shape[:1])
      mask = importance[None, :] > thr[:, None]
      if self.code_dropout == "random":
        mask = jax.random.permutation(
            self.make_rng("dropout"), mask, axis=-1, independent=True)
      x = x * mask[:, :, None]

    x = self.decoder_stem(x)

    # Expand the code sequence back to the full token grid.
    if self.bottleneck_resize:
      l = int(np.round(self.code_len ** 0.5))
      x = einops.rearrange(x, "b (h w) c -> b h w c", h=l, w=l)
      x = jax.image.resize(
          x, (x.shape[0], self.grid_size[0], self.grid_size[1], x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("bnc,nt->btc", x, self.bottleneck_upsample)

    if self.with_decoder_ctx:
      ctx_tokens = self.stem_conv_ctx_dec(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.decoder(x + self.pos_embedding_decoder)

    out["logits"] = {}
    for key, head in self.heads.items():
      out["logits"][key] = head(x)

    return out["logits"], out

  def __call__(self, x, *, ctx=None, train=False, update_dict=True):
    x, out_enc = self.encode(x, ctx=ctx, train=train, update_dict=update_dict)
    x, out_dec = self.decode(x, ctx=ctx, train=train)
    return x, {**out_enc, **out_dec}
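

# Example usage (a hedged sketch; keys and shapes are assumptions, not part of
# this module). Inputs are pre-patchified per modality, so with the default
# 256x256 input size and 8x8 patches each modality carries 32*32=1024 tokens.
# The EMA codebook lives in the "state" collection, which must be mutable
# during training steps:
#
#   model = Model(inputs={"image": (8, 8, 3)}, outputs={"image": (8, 8, 3)},
#                 statistics_axis_name="")  # no cross-device psum in this demo
#   batch = {"image": jnp.zeros((2, 1024, 8, 8, 3))}
#   variables = model.init(
#       {"params": jax.random.PRNGKey(0), "state": jax.random.PRNGKey(1)},
#       batch)
#   (logits, out), new_state = model.apply(
#       variables, batch, train=True,
#       rngs={"dropout": jax.random.PRNGKey(2), "vqvae": jax.random.PRNGKey(3)},
#       mutable=["state"])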


def load(init_params, init_file, model_params=None, dont_load=()):
  """Loads params from init checkpoint and merges them into init_params."""
  del model_params
  ckpt = flax.core.unfreeze(utils.load_checkpoint(None, init_file))
  params = {"params": ckpt["params"], "state": ckpt["state"]}
  params = flax.training.checkpoints.convert_pre_linen(params)

  # Rename modules from older checkpoints that used capitalized names.
  if "Encoder" in params["params"]:
    p = params["params"]
    p["encoder"] = p.pop("Encoder")
    p["decoder"] = p.pop("Decoder")
    params["params"] = p
  if init_params is not None:
    params = common.merge_params(params, init_params, dont_load)
  return params["params"], params["state"]
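
# Example (hypothetical checkpoint path): for a checkpoint that stores both
# "params" and "state" collections, one might restore it with
#
#   params, state = load(None, "/path/to/vqvae_checkpoint.npz")
#
# or merge it into freshly initialized params via the init_params argument.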