# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Autorgregressive sampler for GIVT."""
import functools
from typing import Any, Optional
from big_vision.models.proj.givt import parallel_decode
import flax
from flax import linen as nn
import jax
from jax import lax
from jax import numpy as jnp
import ml_collections


def _sample_gmm(
gmm_pdf,
*,
rng,
cfg_inference_weight=None,
gmm_pdf_uncond=None,
):
"""Draw a single sample from a GMM."""
if cfg_inference_weight is not None:
assert gmm_pdf_uncond is not None
gmm_pdf = parallel_decode.CFGDensity(
gmm_pdf, gmm_pdf_uncond, w=cfg_inference_weight, rng=rng
)
samples = gmm_pdf.sample(seed=rng)
logprobs = gmm_pdf.log_prob(samples)
if logprobs.ndim == 2:
logprobs = logprobs[..., None]
return samples, logprobs


# Beam search reshaping utils
def _flatten_samples_dim(x):
"""Flattens samples dimension into batch dimension."""
if x.ndim == 0: # ignore scalars (e.g. cache index)
return x
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def _unflatten_samples_dim(x, batch_size, num_samples):
"""Unflattens first dimension into batch and samples dimensions."""
if x.ndim == 0: # ignore scalars (e.g. cache index)
return x
assert batch_size * num_samples == x.shape[0]
return x.reshape((batch_size, num_samples) + x.shape[1:])
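# Shape round trip (illustrative): x = jnp.zeros((2, 3, 16)) flattens to
# (6, 16) via _flatten_samples_dim(x), and _unflatten_samples_dim(y, 2, 3)
# restores (2, 3, 16).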


def _cache_map(fn, cache, scan=False):
"""Maps function over cache."""
if scan:
    # Assuming the cache is scanned over the first dimension, we map the
    # function over that dimension for scanned models.
fn_mod = lambda x: jax.lax.map(fn, x) if x.ndim > 0 else fn(x)
else:
fn_mod = fn
frozen = isinstance(cache, flax.core.FrozenDict)
if frozen:
cache = flax.core.unfreeze(cache)
flat_cache = flax.traverse_util.flatten_dict(cache)
# Exclude cached relative position bias from beam expansion, etc.
keyvals = {k: v for k, v in flat_cache.items() if k[-1] != "cached_bias"}
keyvals = jax.tree_map(fn_mod, keyvals)
flat_cache.update(keyvals)
new_cache = flax.traverse_util.unflatten_dict(flat_cache)
if frozen:
new_cache = flax.core.freeze(new_cache)
return new_cache


@flax.struct.dataclass
class LoopState:
"""Internal state of the sampling loop."""
# Terminology
# b: batch size
# nb: number of beams
# nf: number of fans
  # s: sequence length
# d: feature dimension
rng: jnp.ndarray # PRNGKey of the loop state.
cache: Any # Cache for fast auto-regressive decoding.
sequences: jnp.ndarray # (b * nb, s, d)
logprobs: jnp.ndarray # (b * nb, s, d)
  cache_u: Any  # Unconditional cache when CFG is active, otherwise None.


def _create_cache(
labels,
model,
init_sequence,
params,
encoded,
uncond=False,
):
"""Creates the cache and returns initial logits."""
if uncond:
assert labels is not None # Need labels for CFG!
drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_)
else:
drop_labels = None
def init_cache(model):
return model.decode(
init_sequence, labels, encoded, decode=True, drop_labels=drop_labels
)
cache = nn.apply(init_cache, model, mutable=True)(params)[1]["cache"]
def prefill_cache(model):
return model.prefill(
labels, init_sequence.shape[0], encoded, drop_labels=drop_labels
)
# prefill class label or BOS token
prefill_logits, aux = nn.apply(prefill_cache, model, mutable=True)(
{"params": params["params"], "cache": cache})
cache = aux["cache"]
return cache, prefill_logits


def generate(
params: Any,
seed: jax.Array,
*,
model: nn.Module,
seq_len: int,
feature_dim: int,
labels: Optional[jnp.ndarray] = None,
cond_image: Optional[jnp.ndarray] = None,
batch_size: Optional[int] = None,
config: Optional[ml_collections.ConfigDict] = None,
) -> tuple[jax.Array, jax.Array]:
"""Sampling loop for GIVT."""
if model.style != "ar": # pytype: disable=wrong-arg-types
raise ValueError(f"Invalid style: {model.style}")
if model.has_encoder != (cond_image is not None):
raise ValueError("Need cond_image if and only if the model has an encoder!")
assert labels is not None or batch_size, (
"Please provide either labels or batch_size.")
config = config or {}
config = dict(config) # copy
  # For sampling, we support keep_gt (a bool mask) and gt (ground-truth
  # tokens) to substitute for sampled tokens.
keep_gt = config.pop("keep_gt", None)
gt = config.pop("gt", None)
if isinstance(seed, int):
seed = jax.random.PRNGKey(seed)
beam_size = config.pop("beam_size", 1)
fan_size = config.pop("fan_size", 1)
if labels is not None:
batch_size = labels.shape[0]
# fold beams into batch dimension
labels = labels.repeat(beam_size, axis=0)
# initialize sequence and logprobs (we track per feature dim logprobs)
init_sequence = jnp.zeros((batch_size * beam_size, seq_len, feature_dim))
init_logprobs = jnp.zeros_like(init_sequence)
if cond_image is not None:
# embed conditioning image if provided
def encode_cond_img(model, cond_img):
return model.encode(cond_img)
encoded = nn.apply(encode_cond_img, model)(params, cond_image)
encoded = jnp.repeat(encoded, beam_size, axis=0)
else:
encoded = None
cache, prefill_logits = _create_cache(
labels, model, init_sequence, params, encoded
)
cfg_inference_weight = config.pop("cfg_inference_weight", None)
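  # A guidance weight of 0.0 is equivalent to no guidance, so normalize it to
  # None and skip the unconditional pass entirely.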
if cfg_inference_weight == 0.0:
cfg_inference_weight = None
cfg = cfg_inference_weight is not None
get_pdf = functools.partial(
model.get_pdf,
temperature_scales=config.pop("temp", None),
temperature_probs=config.pop("temp_probs", None),
)
  # set up the sampling function
sample = functools.partial(
_sample_gmm, cfg_inference_weight=cfg_inference_weight
)
# draw first output token
pdf_first = get_pdf(prefill_logits)
rng_first, rng = jax.random.split(seed)
if cfg:
assert beam_size == 1 and fan_size == 1 # CFG + Beam not supported.
cache_u, prefill_logits_u = _create_cache(
labels, model, init_sequence, params, encoded, uncond=True
)
pdf_first_u = get_pdf(prefill_logits_u)
else:
cache_u = None
pdf_first_u = None
tokens_first, logprobs_first = sample(
pdf_first, rng=rng_first, gmm_pdf_uncond=pdf_first_u
)
init_sequence = init_sequence.at[:, 0].set(tokens_first.squeeze(axis=1))
init_logprobs = init_logprobs.at[:, 0].set(logprobs_first.squeeze(axis=1))
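  # Position 0 is sampled from the prefill logits above; the loop below
  # feeds token i back into the decoder to produce token i + 1.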
def tokens_to_logits(tokens, cache, uncond=False):
if uncond:
drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_)
else:
drop_labels = None
def decode_step(model, tokens):
return model.decode(tokens, labels, encoded,
decode=True, drop_labels=drop_labels)
logits, aux = nn.apply(decode_step, model, mutable=True)(
{"params": params["params"], "cache": cache}, tokens)
return logits, aux["cache"]
init_state = LoopState(
cache=cache,
sequences=init_sequence, # (b * nb, s, d)
logprobs=init_logprobs, # (b * nb, s, d)
rng=rng,
cache_u=cache_u,
)
rand_top_k = config.pop("rand_top_k", False)
rand_top_k_temp = config.pop("rand_top_k_temp", 1.0)
assert not config, f"Sampling config is expected to be empty: {config}"
def sampling_iteration(i, state):
rng_sampling, rng_local = jax.random.split(state.rng)
cur_tokens = state.sequences[:, i][:, None]
# (b * nb, d)
cur_logits, cache = tokens_to_logits(cur_tokens, state.cache)
# (b, nb, d)
cur_logits = _unflatten_samples_dim(
cur_logits, batch_size, beam_size).squeeze(axis=2)
# (b, nb * nf, d)
cur_pdf = get_pdf(cur_logits.repeat(fan_size, axis=1))
if cfg:
cur_logits_u, cache_u = tokens_to_logits(
cur_tokens, state.cache_u, uncond=True
)
cur_logits_u = _unflatten_samples_dim(
cur_logits_u, batch_size, beam_size).squeeze(axis=2)
cur_pdf_u = get_pdf(cur_logits_u.repeat(fan_size, axis=1))
new_tokens, new_logprobs = sample(
cur_pdf, rng=rng_sampling, gmm_pdf_uncond=cur_pdf_u
)
else:
new_tokens, new_logprobs = sample(cur_pdf, rng=rng_sampling)
cache_u = None
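    # Optionally teacher-force ground-truth tokens wherever keep_gt is set
    # (e.g. to condition sampling on known parts of the sequence).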
if gt is not None:
assert keep_gt is not None
new_tokens = jnp.where(keep_gt[i], gt[:, i, :][:, None], new_tokens)
# Skip beam search if not needed
if beam_size == fan_size == 1:
sampled_tokens = new_tokens.squeeze(axis=1)
sequences = state.sequences.at[:, i + 1].set(sampled_tokens)
return LoopState(
cache=cache,
rng=rng_local,
sequences=sequences,
logprobs=state.logprobs,
cache_u=cache_u,
)
# (b, nb, s, d)
logprobs = _unflatten_samples_dim(state.logprobs, batch_size, beam_size)
cur_logprobs = logprobs[:, :, i] # (b, nb, d)
# (b, nb * nf, d)
new_logprobs = new_logprobs + cur_logprobs.repeat(fan_size, axis=1)
beam_logprobs = new_logprobs.sum(axis=-1) # (b, nb * nf)
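    # Each of the nb * nf candidates is scored by its accumulated sequence
    # log-probability, summed over the feature dimension.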
if rand_top_k:
      # stochastic top-k: sample beam indices from a categorical distribution
      # without replacement instead of taking the deterministic top-k
def stoc_top_k(r, x, p):
return jax.random.choice(r, x, shape=(beam_size,), replace=False, p=p)
# construct index grid
index_grid = jnp.arange(beam_logprobs.shape[1], dtype=jnp.int32)
# (b, nb * nf)
index_grid = index_grid[None].repeat(beam_logprobs.shape[0], axis=0)
top_k_rng, rng_local = jax.random.split(rng_local)
top_k_rng = jax.random.split(top_k_rng, beam_logprobs.shape[0])
# vmap categorical sampling
top_beam_fan_indices = jax.vmap(stoc_top_k, in_axes=(0, 0, 0))(
top_k_rng,
index_grid,
nn.softmax(beam_logprobs / rand_top_k_temp, axis=-1))
else:
_, top_beam_fan_indices = lax.top_k(beam_logprobs, k=beam_size) # (b, nb)
top_beam_indices = top_beam_fan_indices // fan_size
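    # Each parent beam was fanned out fan_size times, so integer division
    # maps a candidate index back to the beam whose prefix it extends.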
def _gather_beams(x):
if x.ndim == 0:
return x
# checkify.check(jnp.all(top_beam_indices < x.shape[1]),
# f"`take_along_axis` out of bounds in `_gather_beams`: "
# f"{top_beam_indices.max()} vs. {x.shape[1]}")
# (b, nb, 1 ... 1)
expanded_indices = top_beam_indices.reshape(
top_beam_indices.shape + (1,) * (x.ndim - 2))
return jnp.take_along_axis(x, expanded_indices, axis=1)
def _gather_tokens(x):
# (b, nb * nf, d) -> (b, nb, d)
# checkify.check(jnp.all(top_beam_fan_indices < x.shape[1]),
# f"`take_along_axis` out of bounds in `_gather_tokens`: "
# f"{top_beam_fan_indices.max()} vs. {x.shape[1]}")
return jnp.take_along_axis(x, top_beam_fan_indices[..., None], axis=1)
# (b, nb, s, d)
sequences = _unflatten_samples_dim(state.sequences, batch_size, beam_size)
sequences = _gather_beams(sequences) # (b, nb, s, d)
sequences = sequences.at[:, :, i + 1].set(_gather_tokens(new_tokens))
# (b, nb, s, d)
sequences = _flatten_samples_dim(sequences)
logprobs = _gather_beams(logprobs)
logprobs = logprobs.at[:, :, i + 1].set(_gather_tokens(new_logprobs))
logprobs = _flatten_samples_dim(logprobs)
scanned_cache = getattr(model, "scan", False)
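    # Reorder the KV cache consistently with the surviving beams so every
    # beam keeps attending to its own past tokens.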
cache = _cache_map(
lambda x: _unflatten_samples_dim(x, batch_size, beam_size),
cache, scanned_cache)
cache = _cache_map(_gather_beams, cache, scanned_cache)
cache = _cache_map(_flatten_samples_dim, cache, scanned_cache)
if cfg:
assert cache_u is not None
cache_u = _cache_map(
lambda x: _unflatten_samples_dim(x, batch_size, beam_size),
cache_u, scanned_cache
)
cache_u = _cache_map(_gather_beams, cache_u, scanned_cache)
cache_u = _cache_map(_flatten_samples_dim, cache_u, scanned_cache)
else:
assert cache_u is None
return LoopState(
cache=cache,
rng=rng_local,
sequences=sequences,
logprobs=logprobs,
cache_u=cache_u,
)
final_state = lax.fori_loop(0, seq_len, sampling_iteration, init_state)
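  # With deterministic top-k, candidates come back sorted best-first, so
  # `[::beam_size]` picks the best surviving beam per batch element.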
final_logprobs = final_state.logprobs[::beam_size][:, -1].sum(axis=-1)
# return top beams and corresponding log probs
return final_state.sequences[::beam_size], final_logprobs