"""Autorgregressive sampler for GIVT.""" |
|
|
|
import functools |
|
from typing import Any, Optional |
|
|
|
from big_vision.models.proj.givt import parallel_decode |
|
import flax |
|
from flax import linen as nn |
|
import jax |
|
from jax import lax |
|
from jax import numpy as jnp |
|
import ml_collections |
|
|
|
|
|
def _sample_gmm(
    gmm_pdf,
    *,
    rng,
    cfg_inference_weight=None,
    gmm_pdf_uncond=None,
):
  """Draws a single sample from a GMM."""
  if cfg_inference_weight is not None:
    # Classifier-free guidance: combine the conditional and unconditional
    # densities into a single re-weighted density before sampling.
    assert gmm_pdf_uncond is not None
    gmm_pdf = parallel_decode.CFGDensity(
        gmm_pdf, gmm_pdf_uncond, w=cfg_inference_weight, rng=rng
    )
  samples = gmm_pdf.sample(seed=rng)
  logprobs = gmm_pdf.log_prob(samples)
  # Ensure the log-probabilities carry a trailing feature axis.
  if logprobs.ndim == 2:
    logprobs = logprobs[..., None]
  return samples, logprobs
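
# Example usage (sketch; B = batch size, D = feature dimension are
# assumptions, not fixed by this module):
#
#   pdf = model.get_pdf(logits)                   # density over one position
#   tokens, logprobs = _sample_gmm(pdf, rng=key)
#
# For a (B, 1)-batched GMM over D-dimensional features, `tokens` has shape
# (B, 1, D), and `logprobs` is padded with a trailing singleton axis if the
# density returns per-position rather than per-feature log-probabilities.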
|
|
def _flatten_samples_dim(x):
  """Flattens the samples dimension into the batch dimension."""
  if x.ndim == 0:  # Scalars (e.g. the cache index) are left as-is.
    return x
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def _unflatten_samples_dim(x, batch_size, num_samples):
  """Unflattens the first dimension into batch and samples dimensions."""
  if x.ndim == 0:  # Scalars (e.g. the cache index) are left as-is.
    return x
  assert batch_size * num_samples == x.shape[0]
  return x.reshape((batch_size, num_samples) + x.shape[1:])
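
# Shape example: _flatten_samples_dim maps (B, S, ...) to (B * S, ...), and
# _unflatten_samples_dim inverts it:
#
#   y = _flatten_samples_dim(x)             # (B * S, ...)
#   x2 = _unflatten_samples_dim(y, B, S)    # (B, S, ...), equal to x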
|
|
def _cache_map(fn, cache, scan=False):
  """Maps a function over all entries of the decoding cache."""
  if scan:
    # For models scanned over layers, cache entries carry a leading layer
    # axis, so map `fn` over that axis.
    fn_mod = lambda x: jax.lax.map(fn, x) if x.ndim > 0 else fn(x)
  else:
    fn_mod = fn

  frozen = isinstance(cache, flax.core.FrozenDict)
  if frozen:
    cache = flax.core.unfreeze(cache)
  flat_cache = flax.traverse_util.flatten_dict(cache)

  # Skip cached biases, which do not carry a batch dimension.
  keyvals = {k: v for k, v in flat_cache.items() if k[-1] != "cached_bias"}
  keyvals = jax.tree_map(fn_mod, keyvals)
  flat_cache.update(keyvals)
  new_cache = flax.traverse_util.unflatten_dict(flat_cache)
  if frozen:
    new_cache = flax.core.freeze(new_cache)
  return new_cache
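
# Example (sketch): grow the batch dimension of every cache entry, e.g. to
# replicate the cache for multiple samples per input. The ndim guard mirrors
# the helpers above, since flax caches contain a scalar cache index:
#
#   cache = _cache_map(
#       lambda x: x.repeat(2, axis=0) if x.ndim else x, cache,
#       scan=getattr(model, "scan", False))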
|
|
@flax.struct.dataclass
class LoopState:
  """Internal state of the sampling loop."""
  rng: jnp.ndarray        # PRNG key threaded through the loop.
  cache: Any              # Decoding cache of the conditional model.
  sequences: jnp.ndarray  # Sampled sequences, (batch * beams, seq_len, dim).
  logprobs: jnp.ndarray   # Log-probabilities of the sampled features.
  cache_u: Any            # Decoding cache of the unconditional model (CFG).
|
|
def _create_cache(
    labels,
    model,
    init_sequence,
    params,
    encoded,
    uncond=False,
):
  """Creates the decoding cache and returns the prefill logits."""
  if uncond:
    # For the unconditional model, drop the labels of all batch elements.
    assert labels is not None
    drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_)
  else:
    drop_labels = None

  def init_cache(model):
    return model.decode(
        init_sequence, labels, encoded, decode=True, drop_labels=drop_labels
    )

  cache = nn.apply(init_cache, model, mutable=True)(params)[1]["cache"]

  def prefill_cache(model):
    return model.prefill(
        labels, init_sequence.shape[0], encoded, drop_labels=drop_labels
    )

  prefill_logits, aux = nn.apply(prefill_cache, model, mutable=True)(
      {"params": params["params"], "cache": cache})
  cache = aux["cache"]
  return cache, prefill_logits
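
# Note: the model.decode call above serves to initialize the cache collection
# with correctly shaped entries; model.prefill then writes the conditioning
# activations into it and returns the logits for the first position.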
|
|
def generate(
    params: Any,
    seed: jax.Array,
    *,
    model: nn.Module,
    seq_len: int,
    feature_dim: int,
    labels: Optional[jnp.ndarray] = None,
    cond_image: Optional[jnp.ndarray] = None,
    batch_size: Optional[int] = None,
    config: Optional[ml_collections.ConfigDict] = None,
) -> tuple[jax.Array, jax.Array]:
  """Sampling loop for GIVT."""
|
if model.style != "ar": |
|
raise ValueError(f"Invalid style: {model.style}") |
|
if model.has_encoder != (cond_image is not None): |
|
raise ValueError("Need cond_image if and only if the model has an encoder!") |
|
|
|
assert labels is not None or batch_size, ( |
|
"Please provide either labels or batch_size.") |
|
|
|
config = config or {} |
|
config = dict(config) |
|
keep_gt = config.pop("keep_gt", None) |
|
gt = config.pop("gt", None) |
|
|
|
if isinstance(seed, int): |
|
seed = jax.random.PRNGKey(seed) |
|
|
|
beam_size = config.pop("beam_size", 1) |
|
fan_size = config.pop("fan_size", 1) |
|
  if labels is not None:
    batch_size = labels.shape[0]
    # Replicate the labels for every beam.
    labels = labels.repeat(beam_size, axis=0)

  init_sequence = jnp.zeros((batch_size * beam_size, seq_len, feature_dim))
  init_logprobs = jnp.zeros_like(init_sequence)
|
  if cond_image is not None:
    # Encode the conditioning image once and replicate it for every beam.
    def encode_cond_img(model, cond_img):
      return model.encode(cond_img)

    encoded = nn.apply(encode_cond_img, model)(params, cond_image)
    encoded = jnp.repeat(encoded, beam_size, axis=0)
  else:
    encoded = None

  cache, prefill_logits = _create_cache(
      labels, model, init_sequence, params, encoded
  )

  # Classifier-free guidance (CFG); a weight of 0.0 disables it.
  cfg_inference_weight = config.pop("cfg_inference_weight", None)
  if cfg_inference_weight == 0.0:
    cfg_inference_weight = None
  cfg = cfg_inference_weight is not None
|
  get_pdf = functools.partial(
      model.get_pdf,
      temperature_scales=config.pop("temp", None),
      temperature_probs=config.pop("temp_probs", None),
  )

  sample = functools.partial(
      _sample_gmm, cfg_inference_weight=cfg_inference_weight
  )
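
  # Sketch: `pdf = get_pdf(logits)` builds the per-position density from the
  # model outputs, and `sample(pdf, rng=key)` draws from it, applying CFG
  # reweighting whenever `cfg_inference_weight` is set.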
|
  # Sample the feature vector at position 0 from the prefill logits.
  pdf_first = get_pdf(prefill_logits)
  rng_first, rng = jax.random.split(seed)

  if cfg:
    # CFG is only supported without beam search; it needs a second,
    # unconditional cache and density.
    assert beam_size == 1 and fan_size == 1
    cache_u, prefill_logits_u = _create_cache(
        labels, model, init_sequence, params, encoded, uncond=True
    )
    pdf_first_u = get_pdf(prefill_logits_u)
  else:
    cache_u = None
    pdf_first_u = None

  tokens_first, logprobs_first = sample(
      pdf_first, rng=rng_first, gmm_pdf_uncond=pdf_first_u
  )
  init_sequence = init_sequence.at[:, 0].set(tokens_first.squeeze(axis=1))
  init_logprobs = init_logprobs.at[:, 0].set(logprobs_first.squeeze(axis=1))
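  # Position 0 is now filled; the loop below samples the remaining positions.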
|
  def tokens_to_logits(tokens, cache, uncond=False):
    """Runs a single decode step; returns logits and the updated cache."""
    if uncond:
      drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_)
    else:
      drop_labels = None

    def decode_step(model, tokens):
      return model.decode(tokens, labels, encoded,
                          decode=True, drop_labels=drop_labels)

    logits, aux = nn.apply(decode_step, model, mutable=True)(
        {"params": params["params"], "cache": cache}, tokens)
    return logits, aux["cache"]
|
  init_state = LoopState(
      cache=cache,
      sequences=init_sequence,
      logprobs=init_logprobs,
      rng=rng,
      cache_u=cache_u,
  )

  # Optionally replace deterministic top-k beam selection with sampling from
  # the softmax of the beam scores.
  rand_top_k = config.pop("rand_top_k", False)
  rand_top_k_temp = config.pop("rand_top_k_temp", 1.0)

  assert not config, f"Sampling config is expected to be empty: {config}"
|
  def sampling_iteration(i, state):
    """Samples the feature vector at position i + 1 given positions <= i."""
    rng_sampling, rng_local = jax.random.split(state.rng)
    cur_tokens = state.sequences[:, i][:, None]

    cur_logits, cache = tokens_to_logits(cur_tokens, state.cache)
    # (batch * beams, 1, ...) -> (batch, beams, ...): split out the beams and
    # drop the singleton time axis.
    cur_logits = _unflatten_samples_dim(
        cur_logits, batch_size, beam_size).squeeze(axis=2)

    # Fan out every beam into `fan_size` candidate continuations.
    cur_pdf = get_pdf(cur_logits.repeat(fan_size, axis=1))
|
    if cfg:
      cur_logits_u, cache_u = tokens_to_logits(
          cur_tokens, state.cache_u, uncond=True
      )
      cur_logits_u = _unflatten_samples_dim(
          cur_logits_u, batch_size, beam_size).squeeze(axis=2)
      cur_pdf_u = get_pdf(cur_logits_u.repeat(fan_size, axis=1))
      new_tokens, new_logprobs = sample(
          cur_pdf, rng=rng_sampling, gmm_pdf_uncond=cur_pdf_u
      )
    else:
      new_tokens, new_logprobs = sample(cur_pdf, rng=rng_sampling)
      cache_u = None

    if gt is not None:
      # Teacher forcing: keep the ground-truth features at steps where
      # keep_gt is set.
      assert keep_gt is not None
      new_tokens = jnp.where(keep_gt[i], gt[:, i, :][:, None], new_tokens)
|
    # Fast path: without beam search there is nothing to select or reorder.
    if beam_size == fan_size == 1:
      sampled_tokens = new_tokens.squeeze(axis=1)
      sequences = state.sequences.at[:, i + 1].set(sampled_tokens)
      return LoopState(
          cache=cache,
          rng=rng_local,
          sequences=sequences,
          logprobs=state.logprobs,
          cache_u=cache_u,
      )
|
    # Beam search bookkeeping: accumulate log-probabilities for scoring.
    logprobs = _unflatten_samples_dim(state.logprobs, batch_size, beam_size)
    cur_logprobs = logprobs[:, :, i]

    new_logprobs = new_logprobs + cur_logprobs.repeat(fan_size, axis=1)
    beam_logprobs = new_logprobs.sum(axis=-1)
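    # beam_logprobs has shape (batch, beams * fan_size): one cumulative score
    # per candidate continuation, summed over the feature dimension.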
|
    if rand_top_k:
      # Stochastic beam selection: sample `beam_size` candidates without
      # replacement, proportional to the softmaxed beam scores.
      def stoc_top_k(r, x, p):
        return jax.random.choice(r, x, shape=(beam_size,), replace=False, p=p)

      index_grid = jnp.arange(beam_logprobs.shape[1], dtype=jnp.int32)
      index_grid = index_grid[None].repeat(beam_logprobs.shape[0], axis=0)
      top_k_rng, rng_local = jax.random.split(rng_local)
      top_k_rng = jax.random.split(top_k_rng, beam_logprobs.shape[0])

      top_beam_fan_indices = jax.vmap(stoc_top_k, in_axes=(0, 0, 0))(
          top_k_rng,
          index_grid,
          nn.softmax(beam_logprobs / rand_top_k_temp, axis=-1))
    else:
      _, top_beam_fan_indices = lax.top_k(beam_logprobs, k=beam_size)

    # Map candidate indices back to the beams they originated from.
    top_beam_indices = top_beam_fan_indices // fan_size
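    # Example: with beam_size=2 and fan_size=3 there are 6 candidates per
    # batch element, and candidate 4 originated from beam 4 // 3 = 1.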
|
    def _gather_beams(x):
      if x.ndim == 0:  # Scalars (e.g. the cache index) are left as-is.
        return x
      # Expand the indices to match the trailing dimensions of x so that
      # take_along_axis broadcasts correctly: (batch, beams, 1, ..., 1).
      expanded_indices = top_beam_indices.reshape(
          top_beam_indices.shape + (1,) * (x.ndim - 2))
      return jnp.take_along_axis(x, expanded_indices, axis=1)

    def _gather_tokens(x):
      # Unlike _gather_beams, this indexes into the beams * fan_size
      # candidates produced at the current step.
      return jnp.take_along_axis(x, top_beam_fan_indices[..., None], axis=1)
|
    sequences = _unflatten_samples_dim(state.sequences, batch_size, beam_size)
    sequences = _gather_beams(sequences)
    sequences = sequences.at[:, :, i + 1].set(_gather_tokens(new_tokens))
    sequences = _flatten_samples_dim(sequences)

    logprobs = _gather_beams(logprobs)
    logprobs = logprobs.at[:, :, i + 1].set(_gather_tokens(new_logprobs))
    logprobs = _flatten_samples_dim(logprobs)

    # Reorder the cache(s) so that they match the selected beams.
    scanned_cache = getattr(model, "scan", False)
    cache = _cache_map(
        lambda x: _unflatten_samples_dim(x, batch_size, beam_size),
        cache, scanned_cache)
    cache = _cache_map(_gather_beams, cache, scanned_cache)
    cache = _cache_map(_flatten_samples_dim, cache, scanned_cache)

    if cfg:
      assert cache_u is not None
      cache_u = _cache_map(
          lambda x: _unflatten_samples_dim(x, batch_size, beam_size),
          cache_u, scanned_cache
      )
      cache_u = _cache_map(_gather_beams, cache_u, scanned_cache)
      cache_u = _cache_map(_flatten_samples_dim, cache_u, scanned_cache)
    else:
      assert cache_u is None

    return LoopState(
        cache=cache,
        rng=rng_local,
        sequences=sequences,
        logprobs=logprobs,
        cache_u=cache_u,
    )
|
  final_state = lax.fori_loop(0, seq_len, sampling_iteration, init_state)
  # Keep only the first beam of every batch element (for deterministic top-k
  # this is the highest-scoring one).
  final_logprobs = final_state.logprobs[::beam_size][:, -1].sum(axis=-1)

  return final_state.sequences[::beam_size], final_logprobs
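
# Example invocation (sketch; the model/params objects and all shapes are
# assumptions, and only config keys handled above are used):
#
#   seqs, logps = generate(
#       params, seed=0, model=givt_model, seq_len=256, feature_dim=16,
#       labels=jnp.zeros((8,), jnp.int32),
#       config=ml_collections.ConfigDict(
#           {"temp": 0.95, "cfg_inference_weight": 0.4}),
#   )
#   # seqs: (8, 256, 16) sampled feature sequences; logps: (8,) scores.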