# Copyright 2024 Big Vision Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Prediction functions for PaliGemma.""" import collections import functools from big_vision.pp import registry import big_vision.utils as u import einops import jax import jax.numpy as jnp import numpy as np P = jax.sharding.PartitionSpec # pylint: disable=missing-function-docstring def get_all(model): """Returns `predict_fns` for evaluators.""" fns = { "logits": _logits, "image_avg_repr": _image_avg_repr, "decode": _decode, "decode_with_logp": _decode_with_logp, "beam_decode": _beam_decode, } return {name: functools.partial(fn, model=model) for name, fn in fns.items()} def _logits(train_state, batch, *, model): images, text, mask = batch["image"], batch["text"], batch["mask_ar"] text_logits, out = model.apply( {"params": train_state["params"]}, images, text[:, :-1], mask[:, :-1], ) return text_logits, out def _image_avg_repr(train_state, batch, *, model, key="img/pre_logits"): zimg, out = model.apply( {"params": train_state["params"]}, image=batch["image"], method=model.embed_image, ) if key: zimg = u.tree_get(out, key) # At this point, zimg is a (batch of) sequence of image tokens, because we # assume the model is a vit with "none" head. This predict-fn is for fewshot # evaluator, so we need to turn it into reasonably-sized vector -> avg. zimg = jnp.mean(zimg, axis=range(1, zimg.ndim - 1)) return zimg, out def _decode_with_logp( train_state, batch, *, model, devices, max_decode_len, eos_token, best_of_n=1, sampler="greedy", replicate_out=False, eos_look_behind=0): """Sample token continuations to the input sequences.""" mesh = jax.sharding.Mesh(devices, ("devices",)) replicate_sharding = jax.sharding.NamedSharding(mesh, P()) out_sharding = jax.sharding.NamedSharding( mesh, P() if replicate_out else P("devices") ) # Prefill the model cache and generate logits for first token. logits, cache = jax.jit( _prefill_cache, out_shardings=out_sharding, static_argnames=("model", "max_decode_len"), )( train_state["params"], { "image": batch["image"], "text": batch["text"], "mask_input": batch["mask_input"], "mask_ar": batch["mask_ar"], }, model=model, max_decode_len=max_decode_len, ) # Mask indicating real examples. False if example is used to pad the batch. mask = batch["_mask"] # Repeat example in case we are picking the best of n. logits, cache, mask = jax.jit( _bon_repeat, static_argnames=("n",) )((logits, cache, mask), n=best_of_n) decode_sample_output = jax.jit( _decode_sample_output, static_argnames=("max_decode_len", "sampler"), ) decode_early_stop = jax.jit( _decode_early_stop, out_shardings=replicate_sharding, static_argnames=("eos_token",), ) extend_cache = jax.jit( _extend_cache, donate_argnums=1, static_argnames=("model",), ) # Keep sampling tokens from last logits until EOS or max_decode_len. state = None # Setting `eos_look_behind>0` removes blocking transfer with small batches. stops = collections.deque(maxlen=1 + eos_look_behind) for idx in range(max_decode_len): tokens, state = decode_sample_output( state, logits, max_decode_len=max_decode_len, sampler=sampler ) if idx + 1 >= max_decode_len: break stops.append(decode_early_stop(state, mask, eos_token=eos_token)) if len(stops) == stops.maxlen and jax.device_get(stops[0]): break # Compute logits for next token logits, cache = extend_cache( train_state["params"], cache, tokens, model=model ) # Select the best of n sample for each example. _, tokens, logp = jax.jit( _bon_select, out_shardings=out_sharding, static_argnames=("n", "eos_token"), )(state, n=best_of_n, eos_token=eos_token) return tokens, logp def _decode(train_state, batch, **kwargs): tokens, _ = _decode_with_logp(train_state, batch, **kwargs) return tokens def _bon_repeat(tree, *, n): return jax.tree.map(lambda x: jnp.repeat(x, n, axis=0), tree) def _compute_score(tokens, logp, eos_token): """Compute log-probability of each sequence up to first eos (including it).""" seqlen = jnp.sum(jnp.cumsum(tokens == eos_token, axis=-1) == 0, axis=-1) + 1 token_mask = jnp.arange(tokens.shape[-1]) < seqlen[..., None] scores = jnp.sum(logp * token_mask, axis=-1) return scores def _bon_select(state, *, n, eos_token): """Pick the sampled sequence with the highest likelihood for each example.""" (_, tokens, logp) = state # Filter state to only keep the best of each example. scores = _compute_score(tokens, logp, eos_token) scores = einops.rearrange(scores, "(b n) -> b n", n=n) state = jax.tree.map( lambda x: einops.rearrange(x, "(b n) l -> b n l", n=n), state) best_indices = jnp.argmax(scores, -1) # [b] state = jax.tree.map( lambda x: jnp.take_along_axis(x, best_indices[:, None, None], axis=1), state) state = jax.tree.map(lambda x: x[:, 0], state) return state def _decode_sample_output(state, logits, *, max_decode_len, sampler): if state is None: # Decode state keeps track of sampled tokens and their logp. bs = logits.shape[0] seqlen = jnp.zeros((bs, 1), dtype=jnp.int32) tokens = jnp.zeros((bs, max_decode_len), dtype=jnp.int32) logp = jnp.zeros((bs, max_decode_len), dtype=logits.dtype) else: (seqlen, tokens, logp) = state # Sample tokens. sampled_tokens, sampled_logp = _sample_logits(logits, sampler=sampler) # Update state with sampled outputs. new_len = seqlen + 1 new_tokens = _put_along_last_axis(tokens, seqlen, sampled_tokens) new_logp = _put_along_last_axis(logp, seqlen, sampled_logp) new_state = (new_len, new_tokens, new_logp) return sampled_tokens, new_state def _decode_early_stop(state, mask, *, eos_token): (seqlen, tokens, unused_logp) = state token_mask = jnp.arange(tokens.shape[-1])[None, :] < seqlen has_eos = jnp.any(jnp.logical_and(tokens == eos_token, token_mask), axis=-1) done = jnp.logical_or(has_eos, jnp.logical_not(mask)) return jnp.all(done) def _put_along_last_axis(arr, indices, values): """Like np.put_along_axis(..., axis=-1), since jax is missing it.""" assert arr.ndim == indices.ndim == values.ndim, ( arr.ndim, indices.ndim, values.ndim) onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype) put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot) put_values = jnp.einsum("...i,...in->...n", values, onehot) return jnp.where(put_mask, put_values, arr) def _prefill_cache(params, batch, *, model, max_decode_len): """Initialize the model cache for decoding with the prompts.""" variables = {"params": params} (x, input_mask, mask_ar), _ = model.apply( variables, batch["image"], batch["text"], input_mask=batch["mask_input"], mask_ar=batch["mask_ar"], method=model.embed_image_and_text) last_logits, variables = model.apply( variables, x, input_mask, mask_ar, cache_size=x.shape[1] + max_decode_len, method=model.prefill_cache, mutable=("cache",)) return last_logits, variables["cache"] def _extend_cache(params, cache, tokens, *, model): """Extend the model cache for decoding with one token per sequence.""" variables = {"params": params, "cache": cache} x, _ = model.apply(variables, tokens, method=model.embed_text) last_logits, variables = model.apply( variables, x, method=model.extend_cache, mutable=("cache",)) return last_logits, variables["cache"] def _sample_logits(logits, sampler): """Returns a sampled token and its logp from logits.""" # Note: Consider making it possible for evaluators to pass rng seed to # decode functions. For now generate it from jax.lax and avoid evaluators # having to deal with it. rng = jax.random.PRNGKey( jax.lax.rng_uniform(0, np.iinfo(np.int32).max, tuple())) # Use Registry to support specifying things like: # "greedy", "nucleus(0.2)", "temperature(t=1.0)" sampled_tokens = registry.Registry.lookup("paligemma_sampler." + sampler)( logits=logits, rng=rng) # Find the log probability (normalized logits) of selected tokens. sampled_logp = jnp.take_along_axis( jax.nn.log_softmax(logits, axis=-1), sampled_tokens[..., None], -1)[..., 0] return sampled_tokens, sampled_logp @registry.Registry.register("paligemma_sampler.greedy") def _greedy_sampling(*, logits, rng): del rng return jnp.argmax(logits, axis=-1) @registry.Registry.register("paligemma_sampler.temperature") def _temperature_sampling(t, *, logits, rng): return jax.random.categorical(rng, logits / t) @registry.Registry.register("paligemma_sampler.nucleus") def _nucleus_sampling(p: float, t: float = 1.0, *, logits, rng): logits = logits / t neg_inf = np.array(-1.0e7) # Effective negative infinity. logits_sorted = jnp.sort(logits, axis=-1, descending=True) sorted_cum_probs = jnp.cumsum( jax.nn.softmax(logits_sorted, axis=-1), axis=-1) cutoff_index = jnp.sum(sorted_cum_probs < p, axis=-1, keepdims=True) cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) logits = jnp.where(logits < cutoff_logit, jnp.full_like(logits, neg_inf), logits) return jax.random.categorical(rng, logits) def _beam_decode(train_state, batch, *, model, devices, max_decode_len, eos_token, beam_size, replicate_out=False): """Beam search (greedy/top-k exploration).""" mesh = jax.sharding.Mesh(devices, ("devices",)) replicate_sharding = jax.sharding.NamedSharding(mesh, P()) out_sharding = jax.sharding.NamedSharding( mesh, P() if replicate_out else P("devices") ) # Prefill the model cache and generate logits for first token. logits, cache = jax.jit( _prefill_cache, out_shardings=out_sharding, static_argnames=("model", "max_decode_len"), )( train_state["params"], { "image": batch["image"], "text": batch["text"], "mask_input": batch["mask_input"], "mask_ar": batch["mask_ar"], }, model=model, max_decode_len=max_decode_len, ) # Mask indicating real examples. False if example is used to pad the batch. mask = batch["_mask"] beam_sample_output = jax.jit( _beam_sample_output, static_argnames=("max_decode_len", "beam_size", "eos_token"), ) beam_early_stop = jax.jit( _beam_early_stop, out_shardings=replicate_sharding, static_argnames=("eos_token",), ) extend_cache = jax.jit( _extend_cache, donate_argnums=1, static_argnames=("model",), ) # Keep sampling tokens from last logits until EOS or max_decode_len. state = None for idx in range(max_decode_len): tokens, state, cache = beam_sample_output( state, logits, cache, max_decode_len=max_decode_len, beam_size=beam_size, eos_token=eos_token) early_stop = beam_early_stop(state, mask, eos_token=eos_token) if jax.device_get(early_stop) or (idx + 1 >= max_decode_len): break # Compute logits for next token logits, cache = extend_cache( train_state["params"], cache, tokens, model=model) return jax.jit(_beam_make_output, out_shardings=out_sharding)(state) def _beam_early_stop(state, mask, eos_token): (best_tokens, best_logp, seqlen, unused_tokens, logp) = state # Scores of finalized sequences. best_scores = _compute_score(best_tokens, best_logp, eos_token) # Scores of live sequences. live_mask = jnp.arange(logp.shape[-1])[None, None] < seqlen live_scores = jnp.sum(logp * live_mask, axis=-1) live_scores = jnp.max(live_scores, axis=1) done = live_scores < best_scores return jnp.all(jnp.logical_or(done, jnp.logical_not(mask))) def _beam_make_output(state): (best_tokens, *_) = state return best_tokens[:, 0, ...] def _beam_sample_output(state, logits, cache, *, beam_size, max_decode_len, eos_token): assert logits.shape[1] == 1 logits = jax.nn.log_softmax(logits[:, 0, :]) # Normalize logits if state is None: bs = logits.shape[0] # Beam decode state keeps track of: # A) Best sampled output for each example. At initialization these have # shape[1]=0, but end up with shape[1]=1 after first call. best_tokens = jnp.zeros((bs, 0, max_decode_len), dtype=jnp.int32) best_logp = jnp.zeros((bs, 0, max_decode_len), dtype=logits.dtype) # B) N candidate sequences for each example. At initialization these have # beam_size=1, but end up with correct beam_size when expanded. seqlen = jnp.zeros((bs, 1, 1), dtype=jnp.int32) tokens = jnp.zeros((bs, 1, max_decode_len), dtype=jnp.int32) logp = jnp.zeros((bs, 1, max_decode_len), dtype=logits.dtype) else: (best_tokens, best_logp, seqlen, tokens, logp) = state bs = logits.shape[0] // beam_size assert best_tokens.shape[0] == bs # Reshape cache to [example, candidate, ...]. # Note: on first call the number of candidates is 1. Later it is beam_size. cache, logits = jax.tree.map( lambda x: einops.rearrange(x, "(b n) ... -> b n ...", b=bs), (cache, logits)) # Consider a live sequence could end now and update the best finished # sequences so far for each example. This strategy is found in some beam # implementations such as in praxis. # The code below also adjusts the best shape[1]=0 -> 1 during first call. eos_tokens = jnp.array(eos_token)[None, None, None] new_tokens = _put_along_last_axis(tokens, seqlen, eos_tokens) new_logp = _put_along_last_axis(logp, seqlen, logits[:, :, eos_token, None]) best_tokens = jnp.concatenate([best_tokens, new_tokens], axis=1) best_logp = jnp.concatenate([best_logp, new_logp], axis=1) best_scores = _compute_score(best_tokens, best_logp, eos_token=eos_token) _, top_indices = jax.lax.top_k(best_scores, k=1) best_tokens = jnp.take_along_axis(best_tokens, top_indices[..., None], axis=1) best_logp = jnp.take_along_axis(best_logp, top_indices[..., None], axis=1) # To find the next best N live candidates we expand each candidate and keep # the best N (ignoring EOS tokens). In this case we expand into (N+1) # candidates and set their likelihood to "-inf" (if EOS) after the fact. live_mask = jnp.arange(logp.shape[-1])[None, None] < seqlen live_scores = jnp.sum(logp * live_mask, axis=-1) topk_logits, topk_tokens = jax.lax.top_k(logits, beam_size+1) scores = live_scores[..., None] + topk_logits scores = jnp.where( topk_tokens != eos_token, scores, jnp.finfo(scores.dtype).min) # From the N*(N+1) candidates find the top N for each example. topk_logits, topk_tokens, scores = jax.tree.map( lambda x: einops.rearrange(x, "b n np1 -> b (n np1)"), (topk_logits, topk_tokens, scores)) _, topk_indices = jax.lax.top_k(scores, k=beam_size) sampled_indices = topk_indices // (beam_size+1) sampled_tokens = jnp.take_along_axis( topk_tokens, topk_indices, axis=-1)[..., None] sampled_logits = jnp.take_along_axis( topk_logits, topk_indices, axis=-1)[..., None] # Adjust cache and state so it matches the selected top N input candidates. # This also adjusts the beam_size=1->n during first call. def take_candidates(x): one_hot_matrix = jax.nn.one_hot(sampled_indices, x.shape[1], dtype=x.dtype) return jnp.einsum("bi...,boi->bo...", x, one_hot_matrix) cache, seqlen, tokens, logp = jax.tree.map( take_candidates, (cache, seqlen, tokens, logp)) # Write the sampled tokens/logits on the reshuffled state. tokens = _put_along_last_axis(tokens, seqlen, sampled_tokens) logp = _put_along_last_axis(logp, seqlen, sampled_logits) seqlen = seqlen + 1 state = (best_tokens, best_logp, seqlen, tokens, logp) # Reshape to [(example, candidate), ...]. sampled_tokens, cache = jax.tree.map( lambda x: einops.rearrange(x, "b n ... -> (b n) ..."), (sampled_tokens, cache)) return sampled_tokens, state, cache