# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VQ-VAE autoencoder with ViT backbone."""
import functools
from typing import Mapping, Optional, Sequence, Union
from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
import einops
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp
import numpy as np
partial = functools.partial
# Multiplicative perturbation applied to codewords when doing the split.
# Note that the multiplicative perturbation is not perfectly symmetric and
# repeated applications can shrink the embedding. However, in practice this
# does not matter for the value we use.
PERTURB = 0.001
# The function below takes a vector `x` and a dictionary of vectors `e` as
# input. It then returns a "quantized" version of `x` (namely the vector from
# `e` closest to `x`) as well as its index in `e`.
# On top of this, it has two extra features:
#   1. The double `vmap` vectorizes this function to operate on many `x`
#      vectors. More concretely, we add two extra dimensions (batch and space)
#      to `x`. Also note we compute the euclidean distance in a decomposed way,
#      because it makes it more efficient for vmapping.
#   2. `quantize` is a "discrete" operation, so it does not have a gradient for
#      `x`. So we implement the so-called "straight-through" gradient estimator
#      using `stop_gradient` magic. It does not affect the forward pass, but
#      changes the gradient.
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
def quantize(x, e):
  dist = jnp.sum(x * x)[None] - 2 * x.dot(e.T) + jnp.sum(e * e, axis=1)
  idx = jnp.argmin(dist)
  x_q = jax.lax.stop_gradient(e[idx] - x) + x  # just `e[idx]` for the fwd pass.
  return x_q, idx
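

# A minimal usage sketch of `quantize` (illustrative only, not part of the
# original module): after the double vmap it maps a (batch, space, dim)
# activation tensor and a (dict_size, dim) codebook to quantized vectors of the
# same shape plus integer code indices of shape (batch, space). Gradients flow
# to `x` unchanged thanks to the straight-through estimator above.
def _quantize_example():
  x = jnp.zeros((2, 16, 8))                               # (batch, space, dim)
  e = jax.random.normal(jax.random.PRNGKey(0), (32, 8))   # (dict_size, dim)
  x_q, idx = quantize(x, e)                  # x_q: (2, 16, 8), idx: (2, 16)
  return x_q, idx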


def split_the_most_frequent_embedding(state):
  """Splits most frequent embedding into two and eliminates least frequent.

  Args:
    state: a dict. that contains current jax rng, embeddings and their counts.

  Returns:
    New dict. with the updated jax rng, embeddings and counts.
  """
  rng, e, c = state["rng"], state["dictionary"], state["counts"]
  rng, rng_local = jax.random.split(rng)

  i_max = jnp.argmax(c)
  i_min = jnp.argmin(c)

  e = e.at[i_min].set(
      e[i_max] * jax.random.uniform(rng_local, (e.shape[1],), jnp.float32,
                                    1.0 - PERTURB, 1.0 + PERTURB))
  c = c.at[i_min].set(c[i_max] / 2.0)
  c = c.at[i_max].set(c[i_max] / 2.0)
  e = e.at[i_min].set(e[i_min] / 2.0)
  e = e.at[i_max].set(e[i_max] / 2.0)

  return {"rng": rng, "dictionary": e, "counts": c}


class Model(nn.Module):
  """VQ-VAE autoencoder with a ViT encoder and decoder."""

  inputs: Mapping[str, Sequence[int]]
  outputs: Mapping[str, Sequence[int]]
  input_size: Sequence[int] = (256, 256)
  patch_size: Sequence[int] = (8, 8)
  code_len: int = 256
  width: int = 768
  enc_depth: int = 6
  dec_depth: int = 6
  mlp_dim: Optional[int] = None
  num_heads: int = 12
  posemb: str = "learn"  # Can also be "sincos2d".
  rep_size: Union[int, bool] = False
  dropout: float = 0.0
  reinit: Optional[Sequence[str]] = None
  head_zeroinit: bool = True
  dict_size: int = 512  # Number of words in the dictionary.
  codeword_dim: Optional[int] = None
  dict_momentum: float = 0.995  # Exp. moving average coeff. for dict. learning.
  quantize: bool = True
  # Useful to set to None when running without pmap, e.g. for testing.
  statistics_axis_name: str = "batch"
  # Threshold for the discounted count below which a codeword is considered
  # unused. For a `dict_momentum` of 0.995 this means the codeword has not been
  # selected for ~500 batches in a row.
  min_count: float = 0.1  # ~= 0.995 ** 500
  with_encoder_ctx: bool = False
  with_decoder_ctx: bool = False
  code_dropout: str = "none"
  bottleneck_resize: bool = False
  zero_decoder_seq: bool = False
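
  # Note (added for clarity): with `dict_momentum = 0.995`, a codeword that is
  # never selected for N consecutive batches has its discounted count scaled by
  # 0.995 ** N, so after ~500 batches it drops to 0.995 ** 500 ≈ 0.08, i.e.
  # below `min_count`, at which point `split_the_most_frequent_embedding`
  # re-assigns it.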

  def setup(self):
    self.grid_size = np.array(self.input_size) // np.array(self.patch_size)

    self.embeddings = {
        k: nn.DenseGeneral(features=(self.width,), axis=range(-len(shape), 0),
                           name=f"embedding_{k}")
        for k, shape in self.inputs.items()
    }
    kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
    self.heads = {
        k: nn.DenseGeneral(features=shape, name=f"head_{k}", **kw)
        for k, shape in self.outputs.items()
    }

    if self.with_encoder_ctx:
      self.stem_conv_ctx_enc = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_enc_embedding")

    if self.with_decoder_ctx:
      self.stem_conv_ctx_dec = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_dec_embedding")

    self.pos_embedding_encoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_encoder")
    self.encoder = vit.Encoder(
        depth=self.enc_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="encoder")

    if not self.bottleneck_resize:
      self.bottleneck_downsample = self.param(
          "bottleneck_downsample",
          nn.initializers.xavier_uniform(),
          (np.prod(self.grid_size), self.code_len))

    norm_init = nn.initializers.normal(stddev=1.0 / np.sqrt(self.dict_size))
    self.dictionary = self.variable(
        "state", "dictionary",
        lambda shape: norm_init(self.make_rng("state"), shape),
        (self.dict_size, self.codeword_dim or self.width))
    self.counts = self.variable("state", "counts", jnp.ones, (self.dict_size,))

    if not self.bottleneck_resize:
      self.bottleneck_upsample = self.param(
          "bottleneck_upsample",
          nn.initializers.xavier_uniform(),
          (self.code_len, np.prod(self.grid_size)))

    self.pos_embedding_decoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_decoder")
    self.decoder = vit.Encoder(
        depth=self.dec_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="decoder")

    self.encoder_head = nn.Dense(self.codeword_dim or self.width)
    self.decoder_stem = nn.Dense(self.width)

  def get_codewords(self):
    # `dictionary` holds discounted sums of assigned vectors and `counts` the
    # discounted assignment counts, so their ratio behaves like an exponential
    # moving average of per-codeword means; codewords are then l2-normalized.
    e = self.dictionary.value / self.counts.value[:, None]
    e = e / jnp.linalg.norm(e, axis=-1, keepdims=True)
    return e

  def encode(self, x, *, ctx=None, train=False, update_dict=True):
    out = {}

    out["stem"] = {}
    for key, embed in self.embeddings.items():
      out["stem"][key] = embed(x[key])
    x = sum(out["stem"].values())

    if self.with_encoder_ctx:
      ctx_tokens = self.stem_conv_ctx_enc(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.encoder(x + self.pos_embedding_encoder,
                        deterministic=not train)

    if self.bottleneck_resize:
      x = einops.rearrange(x, "b (h w) c -> b h w c",
                           h=self.grid_size[0], w=self.grid_size[1])
      l = int(np.round(self.code_len ** 0.5))
      x = jax.image.resize(
          x, (x.shape[0], l, l, x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("btc,tn->bnc", x, self.bottleneck_downsample)

    x = self.encoder_head(x)
    x = jax.nn.standardize(x, axis=-1)
    x_pre_q = out["bottleneck"] = x

    e = self.get_codewords()
    x, idx = quantize(x, e)
    out["bottleneck_q"] = x
    out["code"] = idx

    # Implements the explicit dictionary learning algorithm outlined in the
    # VQ-VAE paper. We slightly deviate from the paper's formulation, as we
    # find it confusing, especially in the multi-host scenario. What is
    # implemented below can be seen as computing discounted counts and sums of
    # all embeddings.
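
    # Added note (illustrative): together with `get_codewords`, the updates
    # below amount to
    #   counts     <- dict_momentum * counts     + batch_counts
    #   dictionary <- dict_momentum * dictionary + batch_sum_of_pre_quantized_x
    # so dictionary / counts is an exponential moving average of per-codeword
    # cluster means.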
    if train:
      # Compute counts and sum(x) of each code in the global batch.
      counts = jnp.zeros(self.dict_size, dtype=jnp.int32)
      counts = counts.at[idx].add(1)

      # Below we introduce a redundant stop_gradient, because JAX's dead code
      # elimination for our program's gradient fails to infer that the code
      # below does not require gradient computation.
      # Relevant GitHub issue: https://github.com/google/jax/issues/9042.
      # TODO: remove the stop_gradient when the bug is fixed.
      x_sum = jnp.zeros_like(self.dictionary.value)
      x_sum = x_sum.at[idx].add(jax.lax.stop_gradient(x_pre_q))

      if self.statistics_axis_name:
        counts = jax.lax.psum(counts, axis_name=self.statistics_axis_name)
        x_sum = jax.lax.psum(x_sum, axis_name=self.statistics_axis_name)

      out["codebook_max_ratio"] = jnp.max(counts) / jnp.sum(counts)
      out["codebook_zeros_ratio"] = jnp.sum(counts == 0) / len(counts)

      if update_dict:
        self.counts.value = self.counts.value * self.dict_momentum + counts
        self.dictionary.value = (self.dictionary.value * self.dict_momentum +
                                 x_sum)

        state = {"dictionary": self.dictionary.value,
                 "counts": self.counts.value,
                 "rng": self.make_rng("vqvae")}
        new_state = jax.lax.while_loop(
            lambda state: jnp.any(state["counts"] < self.min_count),
            split_the_most_frequent_embedding,
            state)
        self.counts.value = new_state["counts"]
        self.dictionary.value = new_state["dictionary"]

    if not self.quantize:
      x = x_pre_q
      out["bottleneck_q"] = x

    return x, out

  def decode(self, x, ctx=None, discrete_input=False, train=False):
    out = {}

    if discrete_input:
      e = self.get_codewords()
      x = e[x]

    if self.zero_decoder_seq:
      x = jnp.zeros_like(x)

    if train and self.code_dropout != "none":
      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
      thr = jax.random.uniform(self.make_rng("dropout"), x.shape[:1])
      mask = importance[None, :] > thr[:, None]
      if self.code_dropout == "random":
        mask = jax.random.permutation(
            self.make_rng("dropout"), mask, axis=-1, independent=True)
      x = x * mask[:, :, None]

    x = self.decoder_stem(x)

    if self.bottleneck_resize:
      l = int(np.round(self.code_len ** 0.5))
      x = einops.rearrange(x, "b (h w) c -> b h w c", h=l, w=l)
      x = jax.image.resize(
          x, (x.shape[0], self.grid_size[0], self.grid_size[1], x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("bnc,nt->btc", x, self.bottleneck_upsample)

    if self.with_decoder_ctx:
      ctx_tokens = self.stem_conv_ctx_dec(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.decoder(x + self.pos_embedding_decoder)

    out["logits"] = {}
    for key, head in self.heads.items():
      out["logits"][key] = head(x)

    return out["logits"], out

  def __call__(self, x, *, ctx=None, train=False, update_dict=True):
    x, out_enc = self.encode(x, ctx=ctx, train=train, update_dict=update_dict)
    x, out_dec = self.decode(x, ctx=ctx, train=train)
    return x, {**out_enc, **out_dec}
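

# A minimal end-to-end sketch (hypothetical shapes and sizes, not from the
# original repo) of initializing and applying the model for inference. Training
# additionally requires a "vqvae" rng and `mutable=["state"]`, and is normally
# run under pmap with an axis named `statistics_axis_name`.
def _model_example():
  model = Model(
      inputs={"image": (8, 8, 3)}, outputs={"image": (8, 8, 3)},
      input_size=(32, 32), patch_size=(8, 8), code_len=16,
      width=64, enc_depth=1, dec_depth=1, num_heads=4, dict_size=32)
  batch = {"image": jnp.zeros((2, 16, 8, 8, 3))}  # (batch, tokens, *patch dims)
  variables = model.init(
      {"params": jax.random.PRNGKey(0), "state": jax.random.PRNGKey(1)},
      batch, train=False)
  # Full autoencoding pass: logits["image"]: (2, 16, 8, 8, 3), code: (2, 16).
  logits, out = model.apply(variables, batch, train=False)
  # Decoding directly from discrete codes, as done at inference time.
  logits_from_codes, _ = model.apply(
      variables, out["code"], discrete_input=True, method=Model.decode)
  return logits, logits_from_codes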


def load(init_params, init_file, model_params=None, dont_load=()):
  """Loads params from init checkpoint and merges into init_params."""
  del model_params
  ckpt = flax.core.unfreeze(utils.load_checkpoint(None, init_file))
  params = {"params": ckpt["params"], "state": ckpt["state"]}
  params = flax.training.checkpoints.convert_pre_linen(params)

  # Fix old-style param names.
  if "Encoder" in params["params"]:
    p = params["params"]
    p["encoder"] = p.pop("Encoder")
    p["decoder"] = p.pop("Decoder")
    params["params"] = p

  if init_params is not None:
    params = common.merge_params(params, init_params, dont_load)
  return params["params"], params["state"]