"""Decode autoregressive/bidirectional masked transformers. |
|
|
|
|
|
Currently, we implement MaskGIT style temperature sampling: |
|
|
|
In each step: |
|
1. Get P = model(inputs), predicted GMMs |
|
2. Get samples = sample_from(P) |
|
3. Get probs = P[samples], ie, model evaluated at samples. |
|
We use this now as a confidence metric, but we scale the probs: |
|
4. probs = probs ^ 1/choice_temperature |
|
4. set probs[already_uncovered_points] = inf, ie, we will always keep |
|
uncovered points (no resampling!) |
|
5. Now pick top K points from probs to keep for the next steps, where |
|
K = some monotonically increasing ratio of points as we go along decoding |
|
""" |
import dataclasses
from typing import Literal

from absl import logging
from big_vision.models.proj.givt import givt
import distrax
import flax
import jax
import jax.numpy as jnp


# Already-uncovered tokens get infinite confidence, so they are never
# re-masked: we always mask out the lowest-confidence points.
_CONFIDENCE_OF_KNOWN_TOKENS = jnp.inf

@jax.vmap
def _get_per_batch_mask(arr, k):
  """Returns a mask that is True for the `k` smallest entries of `arr`."""
  (d,) = arr.shape
  indices = jnp.argsort(arr)
  valid_indices = jnp.arange(d) < k
  return jnp.zeros((d,), jnp.bool_).at[indices].set(valid_indices)


def _get_bottom_k_mask(arr, k):
  """Returns a mask selecting the `k` smallest entries in the last axis."""
  *leading, d = arr.shape
  arr = arr.reshape((-1, d))
  mask = _get_per_batch_mask(arr, k)
  return mask.reshape(*leading, -1)
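
# Illustration (hypothetical values): for arr = [[3., 1., 2.]] and k = [2],
# _get_bottom_k_mask returns [[False, True, True]], marking the two smallest
# entries.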


def mask_by_random_topk(rng, mask_len, probs, temperature=1.0):
  """Creates a mask of the tokens to mask out, picking low-confidence ones.

  Adaptation of jax.random.choice where the probabilities are scaled with
  `temperature` (probs = probs ^ (1/temperature)).

  This function returns a mask of tokens to mask out, which are picked to be
  the low-confidence ones. It is thus roughly equivalent to (but not exactly,
  at edge cases such as prob = inf):

    keep = jax.random.choice(
        rng, seq_len,
        shape=(seq_len - mask_len,),
        # NOTE: probabilities are updated with `temperature`.
        p=jnp.power(probs, 1/temperature),
        replace=False,
    )
    mask = jnp.ones((seq_len,), dtype=jnp.bool_)
    return mask.at[..., keep].set(False)

  Args:
    rng: a PRNG key used as the random key.
    mask_len: the number of tokens to mask.
    probs: the probabilities associated with each entry.
    temperature: when temperature = 1.0, this is identical to jax's
      implementation. The larger this value, the more random the masking.

  Returns:
    A binary mask of shape [batch_size, seq_len]. Contains True where we
    should mask (at mask_len locations), and False where we should keep.
  """
  confidence = jnp.log(probs) + temperature * jax.random.gumbel(
      rng, probs.shape)
  return _get_bottom_k_mask(confidence, mask_len)

@flax.struct.dataclass
class DecodeState:
  """Holds decoding state data."""

  rng: jax.Array

  # Current decoding step, in [0, num_steps].
  step: jax.Array

  # Model inputs per step, shape (num_steps + 1, b, seq_len, c): entry i is
  # the (partially uncovered) input fed to the model at step i, and the final
  # entry is the fully decoded output.
  all_inputs_q: jax.Array

  # Per-step bookkeeping; every array has leading dimension num_steps.
  uncovered_per_step: jax.Array
  logits_per_step: jax.Array
  uncond_logits_per_step: jax.Array
  prob_per_step: jax.Array

  # Percentage of points where CFG rejection sampling succeeded (NaN for
  # steps where it was not used).
  rejection_sampling_success_per_step: jax.Array

  @classmethod
  def make(
      cls,
      initial_rng: jax.Array,
      all_masked_input: jax.Array,
      num_logits: int,
      num_steps: int,
  ) -> "DecodeState":
    """Creates the initial state."""
    b, seq_len, c = all_masked_input.shape
    all_inputs_q = jnp.broadcast_to(
        all_masked_input,
        (num_steps + 1, b, seq_len, c),
    )
    return cls(
        initial_rng,
        step=jnp.array(0),
        all_inputs_q=all_inputs_q,
        uncovered_per_step=jnp.full((num_steps, b, seq_len), False, jnp.bool_),
        logits_per_step=jnp.full(
            (num_steps, b, seq_len, num_logits), jnp.nan, jnp.float32
        ),
        uncond_logits_per_step=jnp.full(
            (num_steps, b, seq_len, num_logits), jnp.nan, jnp.float32
        ),
        prob_per_step=jnp.full((num_steps, b, seq_len), jnp.nan, jnp.float32),
        rejection_sampling_success_per_step=jnp.full(
            (num_steps,), jnp.nan, jnp.float32
        ),
    )
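
  # Illustrative shapes (assuming b=2, seq_len=16, c=4, num_logits=8,
  # num_steps=3): all_inputs_q is (4, 2, 16, 4), logits_per_step is
  # (3, 2, 16, 8), and prob_per_step is (3, 2, 16).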
  @property
  def current_inputs_q(self) -> jax.Array:
    """Returns the current quantized input."""
    return self.all_inputs_q[self.step, ...]

  @property
  def num_steps(self) -> int:
    """Returns the number of decode steps."""
    return self.uncovered_per_step.shape[0]

  def _steps_mask(self) -> jax.Array:
    return jnp.arange(self.num_steps) <= self.step

  @property
  def total_uncovered(self) -> jax.Array:
    """Returns the total uncovered mask up to and including the current step."""
    return self.uncovered_per_step.sum(
        axis=0, where=self._steps_mask()[:, jnp.newaxis, jnp.newaxis]
    ).astype(jnp.bool_)
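
  # For example (illustrative): with num_steps = 4 and step = 1, _steps_mask()
  # is [True, True, False, False], so only rows 0 and 1 of uncovered_per_step
  # contribute to total_uncovered.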

  def split_rng(self) -> tuple["DecodeState", jax.Array]:
    """Splits off an RNG for the current step."""
    rng, step_rng = jax.random.split(self.rng, 2)
    return self.replace(rng=rng), step_rng

  def set_next_input(self, next_input_q: jax.Array) -> "DecodeState":
    """Sets the input for the next step."""
    return self._set_row("all_inputs_q", self.step + 1, next_input_q)

  def set_uncover_at_current_step(self, uncovered: jax.Array) -> "DecodeState":
    """Sets what was uncovered after the current step."""
    return self._set_row("uncovered_per_step", self.step, uncovered)

  def set_logits_at_current_step(self, logits: jax.Array) -> "DecodeState":
    """Sets the logits predicted at the current step."""
    return self._set_row("logits_per_step", self.step, logits)

  def set_uncond_logits_at_current_step(
      self, logits: jax.Array
  ) -> "DecodeState":
    """Sets the unconditional (CFG) logits predicted at the current step."""
    return self._set_row("uncond_logits_per_step", self.step, logits)

  def set_rejection_sampling_success_at_current_step(
      self, success: jax.Array
  ) -> "DecodeState":
    """Sets the rejection sampling success rate at the current step."""
    return self._set_row(
        "rejection_sampling_success_per_step", self.step, success
    )

  def set_prob_at_current_step(self, prob: jax.Array) -> "DecodeState":
    """Sets the probability of the sampled values at the current step."""
    return self._set_row("prob_per_step", self.step, prob)

  def increment_step(self) -> "DecodeState":
    """Increments the step."""
    return self.replace(step=self.step + 1)

  def _set_row(self, attr_name, row_index, row_value):
    """Sets one row of the variables that have shape (num_steps, ...)."""
    current_value = getattr(self, attr_name)
    _, *expected_shape = current_value.shape
    if row_value.shape != tuple(expected_shape):
      raise ValueError(f"Expected {row_value.shape} == {expected_shape}!")
    if row_value.dtype != current_value.dtype:
      raise ValueError(f"Expected {row_value.dtype} == {current_value.dtype}")
    new_value = current_value.at[row_index, ...].set(row_value)
    return self.replace(**{attr_name: new_value})


@dataclasses.dataclass(frozen=True)
class MaskedGenerationConfig:
  """Config for masked generation.

  Attributes:
    num_steps: Number of sampling steps.
    should_anneal_temperature: If True, anneal the choice temperature as we go
      through the sampling steps.
    choice_temperature: Temperature for picking points.
    ordering: How to order points when selecting. Supports:
      maskgit: MaskGIT style, use P[samples].
    schedule: Inference mask schedule.
    cfg_inference_weight: CFG inference weight.
  """

  num_steps: int = 16
  should_anneal_temperature: bool = True
  choice_temperature: float = 1.0
  ordering: Literal["maskgit"] = "maskgit"
  schedule: str = "cosine"
  cfg_inference_weight: float = 0.0
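
# A hypothetical configuration with classifier-free guidance enabled (the
# values are illustrative only):
#   config = MaskedGenerationConfig(num_steps=32, cfg_inference_weight=0.4)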


def _assert_single_component_get_loc_scale(
    pdf: distrax.Distribution, rng=None, mixture=None
):
  """Extracts loc and scale of a single component from a mixture GMM."""
  if not isinstance(pdf, distrax.MixtureSameFamily):
    raise ValueError(f"Expected mixture! Got {type(pdf)}")
  components_d = pdf.components_distribution
  if isinstance(components_d, distrax.MultivariateNormalDiag):
    loc, scale_diag = components_d.loc, components_d.scale_diag
    b, s, m, _ = loc.shape
    if mixture is None:
      assert rng is not None
      # Sample a mixture component and turn it into a one-hot selector.
      mixture = pdf.mixture_distribution.sample(seed=rng)
      mixture = jax.nn.one_hot(mixture, num_classes=m, axis=-1)
    assert mixture.shape == (b, s, m), (mixture.shape, loc.shape)
    # Select the parameters of the chosen component.
    loc = (loc * mixture[..., None]).sum(-2)
    scale_diag = (scale_diag * mixture[..., None]).sum(-2)
    return loc, scale_diag, mixture
  else:
    loc, scale = components_d.loc, components_d.scale
    if loc.shape[-1] != 1 or scale.shape[-1] != 1:
      raise ValueError(
          f"Expected one mixture component! {loc.shape}/{scale.shape}")
    return loc[..., 0], scale[..., 0], None


class CFGDensity:
  """Helper to get probability and samples via CFG."""

  pdf_c: distrax.Distribution
  pdf_u: distrax.Distribution
  w: float
  simple: distrax.Distribution
  fac: jax.Array

  def __init__(
      self,
      pdf_c: distrax.Distribution,
      pdf_u: distrax.Distribution,
      w: float,
      rng: jax.Array,
  ) -> None:
    loc_c, scale_c, mixture = _assert_single_component_get_loc_scale(pdf_c, rng)
    # Reuse the conditional mixture selection for the unconditional density.
    loc_u, scale_u, _ = _assert_single_component_get_loc_scale(
        pdf_u, rng, mixture=mixture
    )

    # Simple proposal distribution for rejection sampling: centered at the
    # conditional mean, with an inflated scale so it covers the CFG density.
    loc_simple = loc_c
    scale_simple = jnp.stack([scale_c, scale_u], -1).max(-1) * 2
    self.simple = distrax.Normal(loc_simple, scale_simple)

    self.pdf_c = distrax.Normal(loc_c, scale_c)
    self.pdf_u = distrax.Normal(loc_u, scale_u)
    self.w = w

    assert loc_c.ndim == 3, loc_c.shape
    # Estimate the envelope factor on a grid of points around loc_c, such that
    # fac * simple(x) approximately upper-bounds the unnormalized CFG density.
    points = loc_c[jnp.newaxis, ...] + jnp.linspace(-10, 10, 1001).reshape(
        -1, 1, 1, 1
    )
    p_at_c, _ = self._unnormalized_p(points)
    self.fac = jnp.max(p_at_c / self.simple.prob(loc_c), axis=0)
    jax.debug.print("🎲 CFG {fac}", fac=self.fac.mean())
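
  # Classifier-free guidance density: p_cfg(x) ∝ p_c(x)^(1+w) / p_u(x)^w,
  # evaluated in log space for numerical stability; w = 0 recovers the
  # conditional density exactly.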
  def _unnormalized_p(self, x):
    w = self.w
    logp_cfg = (1 + w) * self.pdf_c.log_prob(x) - w * self.pdf_u.log_prob(x)
    return jnp.exp(logp_cfg), logp_cfg

  def rejection_sample(
      self,
      seed: jax.Array,
      max_samples: int = 1_000,
  ) -> tuple[jax.Array, jax.Array]:
    """Rejection sampling: try `max_samples`, take the first match."""
    rng_sample, rng_uni = jax.random.split(seed, 2)

    # Draw proposals from the simple distribution and, for each, a uniform
    # threshold in [0, fac * simple(x)].
    xs = self.simple.sample(seed=rng_sample, sample_shape=(max_samples,))
    facq = self.fac * self.simple.prob(xs)
    ys = jax.random.uniform(rng_uni, shape=facq.shape, minval=0.0, maxval=facq)

    # Accept a proposal when its threshold falls below the unnormalized
    # target density.
    p, _ = self._unnormalized_p(xs)
    mask = ys < p

    # Pick the *first* accepted proposal along the sample axis: `cmask` is
    # True from the first acceptance onwards, `shifted_cmask` is the same mask
    # delayed by one, so `keep` is True exactly at the first acceptance.
    cmask = jnp.cumsum(mask, axis=0).astype(jnp.bool_)
    shifted_cmask = jnp.pad(
        cmask, [(1, 0), (0, 0), (0, 0), (0, 0)], constant_values=False
    )[:-1]
    assert shifted_cmask.shape == mask.shape
    keep = jnp.logical_and(cmask, jnp.logical_not(shifted_cmask))

    sample = jnp.where(keep, xs, 0).sum(0)

    # Where no proposal was accepted, fall back to a sample from the
    # conditional density.
    ok = mask.sum(0) > 0
    sample = jnp.where(
        ok, sample, self.pdf_c.sample(seed=rng_sample)
    )
    return sample, ok.mean() * 100
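
  # Note: the success percentage returned above is the fraction of points
  # (averaged over batch, sequence, and channels) where at least one of the
  # `max_samples` proposals was accepted.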

  def sample(
      self,
      seed: jax.Array,
      max_samples: int = 1_000,
  ) -> jax.Array:
    result, ok = self.rejection_sample(seed, max_samples)
    jax.debug.print("Debug ok={ok}%", ok=ok)
    return result

  def prob(self, xs: jax.Array) -> jax.Array:
    p, _ = self._unnormalized_p(xs)
    return p

  def log_prob(self, xs: jax.Array) -> jax.Array:
    _, lp = self._unnormalized_p(xs)
    return lp


def decode_masked(
    rng: jax.Array,
    labels: jax.Array,
    seq_len: int,
    feature_dim: int,
    model: givt.Model,
    variables: flax.core.FrozenDict,
    config: MaskedGenerationConfig,
) -> DecodeState:
  """Implements a masked bidirectional sampling loop.

  This function implements the loop from the module docstring.

  Args:
    rng: RNG, only required if sampling.
    labels: Shape (b,), labels per batch element. Determines the batch size.
    seq_len: How many tokens to sample per batch element.
    feature_dim: Output dimension of the VAE, i.e., number of channels, `c`.
    model: GIVT model to sample from.
    variables: Variables of the model.
    config: Configures the sampling style.

  Returns:
    Final state.
  """
  logging.info("Masked Generation Config:\n%s", config)

  if model.style != "masked":
    raise ValueError(f"Need masked model! Got `{model.style}`.")

  (b,) = labels.shape
  all_masked_input = jnp.zeros((b, seq_len, feature_dim))
  init_state = DecodeState.make(
      rng,
      all_masked_input,
      num_logits=model.num_logits,
      num_steps=config.num_steps,
  )

  def loop_cond_fn(state: DecodeState):
    return state.step < state.num_steps

  def tokens_to_logits(tokens, input_mask, drop_labels=None):
    return model.apply(
        variables,
        tokens,
        labels=labels,
        input_mask=input_mask,
        drop_labels=drop_labels,
        method="decode",
    )
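
  # The loop body implements steps 1-6 from the module docstring: predict,
  # sample, score confidence, and re-mask the lowest-confidence points
  # according to the schedule.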
  def loop_body_fn(state: DecodeState) -> DecodeState:
    unknown = jnp.logical_not(state.total_uncovered)

    # Fraction of decoding progress so far, in (0, 1].
    ratio = (state.step + 1) / config.num_steps

    # Number of tokens that should still be masked *after* this step,
    # according to the schedule.
    mask_ratio = givt.apply_mask_schedule(ratio, method=config.schedule)
    mask_len = jnp.floor(seq_len * mask_ratio).reshape(1, 1)
    num_unknown = jnp.sum(unknown, axis=-1, keepdims=True)
    # Mask at most num_unknown - 1 tokens, so at least one token gets
    # uncovered in every step, and never a negative number of tokens.
    mask_len = jnp.maximum(
        0,
        jnp.minimum(num_unknown - 1, mask_len))

    logits = tokens_to_logits(state.current_inputs_q, unknown)
    state = state.set_logits_at_current_step(logits)

    pdf = model.get_pdf(logits)
    state, sample_rng = state.split_rng()
    if config.cfg_inference_weight > 0:
      # Get unconditional logits by dropping all labels, and combine the two
      # densities via CFG rejection sampling.
      drop_all_labels = jnp.full((b,), True, jnp.bool_)
      logits_uncond = tokens_to_logits(
          state.current_inputs_q, unknown, drop_labels=drop_all_labels
      )
      state = state.set_uncond_logits_at_current_step(logits_uncond)
      pdf_uncond = model.get_pdf(logits_uncond)
      state, cfg_rng = state.split_rng()
      pdf = CFGDensity(
          pdf_c=pdf,
          pdf_u=pdf_uncond,
          w=config.cfg_inference_weight,
          rng=cfg_rng,
      )
      sample, rejection_sampling_success = pdf.rejection_sample(sample_rng)
      state = state.set_rejection_sampling_success_at_current_step(
          rejection_sampling_success
      )
    else:
      sample = pdf.sample(seed=sample_rng)

    # Only update the still-masked positions; keep uncovered points as-is.
    sampled = jnp.where(unknown[:, :, None], sample, state.current_inputs_q)
    assert sampled.shape == (b, seq_len, feature_dim), (
        sampled.shape,
        b,
        seq_len,
        feature_dim,
    )

    prob = pdf.prob(sampled)
    if model.multivariate:
      assert prob.ndim == 2
    elif model.per_channel_mixtures or config.cfg_inference_weight > 0:
      # We get per-channel probabilities; take the product over channels to
      # obtain a single per-token confidence.
      prob = prob.prod(-1)
    state = state.set_prob_at_current_step(prob)

    if config.ordering == "maskgit":
      ordering = jnp.where(unknown, prob, _CONFIDENCE_OF_KNOWN_TOKENS)
    else:
      raise NotImplementedError(config.ordering)

    assert ordering.shape == (b, seq_len), (ordering.shape, b, seq_len)

    temp = config.choice_temperature
    if config.should_anneal_temperature:
      temp *= (1. - ratio)

    # Keep the `mask_len` lowest-confidence points masked for the next step.
    state, choice_rng = state.split_rng()
    masking = mask_by_random_topk(choice_rng, mask_len, ordering, temp)
    assert masking.shape == (b, seq_len)
    masking = jnp.where(mask_len == 0, jnp.zeros_like(masking), masking)

    # Zero out the points that remain masked.
    sampled = jnp.where(masking[:, :, None], jnp.zeros_like(sampled), sampled)

    # Everything that was unknown and is no longer masked got uncovered in
    # this step.
    next_uncover = jnp.logical_and(unknown, jnp.logical_not(masking))
    assert next_uncover.shape == (b, seq_len), (next_uncover.shape, b, seq_len)
    state = state.set_uncover_at_current_step(next_uncover)
    state = state.set_next_input(sampled)
    return state.increment_step()

  return jax.lax.while_loop(loop_cond_fn, loop_body_fn, init_state)