# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma reimplementation for big_vision.

We follow this einsum axis naming convention:
  B: batch
  T: query length
  S: k/v length
  N: num query heads
  K: num k/v heads
  G: num query heads per k/v head
  H: head dim
  D: d_model ("features")

Example Colab using the models via the PaliGemma decoding logic:
(internal link)

Doc locating the variable initializers in the original code and validating
them: (internal link)
"""

from big_vision.models import common
import big_vision.utils as u
import einops
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import orbax.checkpoint


def get_config(variant):
  """Returns config for the specified gemma variant."""
  if variant == "gemma_2b":
    return ml_collections.ConfigDict(dict(
        variant=variant,
        width=2048,
        depth=18,
        mlp_dim=16_384,
        num_heads=8,
        num_kv_heads=1,
        head_dim=256,
        norm_eps=1e-6,
        vocab_size=256_128,
        scan=True,
        remat_policy="nothing_saveable",
    ))
  if variant == "gemma_7b":
    return ml_collections.ConfigDict(dict(
        variant=variant,
        width=3072,
        depth=28,
        mlp_dim=24_576,
        num_heads=16,
        num_kv_heads=16,
        head_dim=256,
        norm_eps=1e-6,
        vocab_size=256_128,
        scan=True,
        remat_policy="nothing_saveable",
    ))
  raise ValueError(f"Unknown variant: {variant}")


def _apply_rope(x, *, positions, max_wavelength=10_000):
  """Applies RoPE positions [B, L] to x [B, L, H, D]."""
  freq_exponents = (2. / x.shape[-1]) * jnp.arange(x.shape[-1] // 2)
  timescale = max_wavelength**freq_exponents
  radians = positions[..., None] / timescale[None, None, :]
  radians = radians[..., None, :]  # radians.shape = [..., L, 1, d=D/2]
  sin, cos = jnp.sin(radians), jnp.cos(radians)
  x1, x2 = jnp.split(x, 2, axis=-1)
  res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
  return res


def _update_kv_cache(module, k, v, cache_size, cache_dtype):
  """Updates the KV cache and returns its current contents."""
  initialized = module.has_variable("cache", "idx")
  batch_size, update_len, num_heads, head_dim = k.shape
  cache_dtype = cache_dtype or k.dtype

  # The index of the cache row to update next is the same for all examples, so
  # the cache can be updated with dynamic_update_slice. But in order to keep
  # things nicely partitioned we store it with a leading batch dimension and
  # use only the first entry.
  idx = module.variable("cache", "idx", jnp.zeros, (batch_size,), jnp.int32)

  kv_shape = (batch_size, cache_size, num_heads, head_dim)
  k_cache = module.variable(
      "cache", "k_cache", jnp.zeros, kv_shape, cache_dtype)
  v_cache = module.variable(
      "cache", "v_cache", jnp.zeros, kv_shape, cache_dtype)

  if initialized:  # write k, v in the next cache position.
    assert update_len == 1, update_len
    # Note: idx is the same for all examples. Use the value from example 0.
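    # dynamic_update_slice writes this step's [B, 1, num_heads, head_dim] k/v
    # block into the cache at row idx of the length axis (offset 0 elsewhere).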
    indices = (0, idx.value[0], 0, 0)
    k_cache.value = jax.lax.dynamic_update_slice(
        k_cache.value, k.astype(cache_dtype), indices)
    v_cache.value = jax.lax.dynamic_update_slice(
        v_cache.value, v.astype(cache_dtype), indices)
    idx.value = idx.value + 1
  else:  # init cache with k, v after padding to cache_size.
    prefill_len = k.shape[1]
    pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
    k_cache.value = jnp.pad(k.astype(cache_dtype), pad_width)
    v_cache.value = jnp.pad(v.astype(cache_dtype), pad_width)
    idx.value = idx.value + prefill_len

  return k_cache.value.astype(k.dtype), v_cache.value.astype(v.dtype)


def trunc_norm_init(in_axis, out_axis, batch_axis):
  return nn.initializers.variance_scaling(
      1.0, "fan_in", "truncated_normal",
      in_axis=in_axis, out_axis=out_axis, batch_axis=batch_axis)


class Einsum(nn.Module):
  shape: tuple[int, ...]
  w_init: nn.initializers.Initializer = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, eqn, x):
    w = self.param("w", self.w_init, self.shape)
    return jnp.einsum(eqn, x, w)


class RMSNorm(nn.Module):

  @nn.compact
  def __call__(self, x):
    scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
    normed_inputs = normed_inputs * (1 + scale)
    return normed_inputs


class Embedder(nn.Module):
  """Embedder module."""

  vocab_size: int
  embed_dim: int

  def setup(self):
    self.input_embedding_table = self.param(
        "input_embedding",
        nn.initializers.variance_scaling(
            scale=1.0,
            mode="fan_in",
            distribution="normal",
            in_axis=1,
            out_axis=0,
        ),
        (self.vocab_size, self.embed_dim),
    )

  def encode(self, x):
    x = self.input_embedding_table[(x,)]
    x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
    return x

  def decode(self, x):
    return jnp.dot(x, self.input_embedding_table.T)


class Attention(nn.Module):
  """Attention module."""

  num_heads: int
  num_kv_heads: int
  features: int
  head_dim: int

  cache_dtype: str | None = None

  def setup(self):
    if self.num_kv_heads == self.num_heads:
      self.qkv_einsum = Einsum(
          shape=(3, self.num_heads, self.features, self.head_dim),
          w_init=trunc_norm_init(
              in_axis=(2,), out_axis=(0, 1, 3), batch_axis=()),
      )
    else:  # MQA
      self.q_einsum = Einsum(
          shape=(self.num_heads, self.features, self.head_dim),
          w_init=trunc_norm_init(in_axis=(1,), out_axis=(0, 2), batch_axis=()),
      )
      self.kv_einsum = Einsum(
          shape=(2, self.num_kv_heads, self.features, self.head_dim),
          w_init=trunc_norm_init(
              in_axis=(2,), out_axis=(0, 1, 3), batch_axis=()),
      )
    self.attn_vec_einsum = Einsum(
        shape=(self.num_heads, self.head_dim, self.features),
        w_init=trunc_norm_init(in_axis=(0, 1), out_axis=(2,), batch_axis=()),
    )

  @nn.compact
  def __call__(self, x, positions, attn_mask, decode, deterministic=True):
    if self.num_kv_heads == self.num_heads:
      q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
    else:
      q = self.q_einsum("BTD,NDH->BTNH", x)
      k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)

    q = _apply_rope(q, positions=positions)
    q *= self.head_dim**-0.5

    k = _apply_rope(k, positions=positions)
    if decode:
      k, v = _update_kv_cache(self, k, v,
                              cache_size=attn_mask.shape[-1],
                              cache_dtype=self.cache_dtype)

    q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
    logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)
    logits = logits.astype(jnp.float32)

    if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
      raise ValueError(
          f"Attention mask with shape {attn_mask.shape} but shapes for q and k "
          f"are: {q.shape} and {k.shape}"
      )

    # big_neg = jnp.finfo(logits.dtype).min
    big_neg = -2.3819763e38  # See gemma/modules.py
    masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

    probs = jax.nn.softmax(masked_logits, axis=-1).astype(k.dtype)

    encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
    encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
    attn_output = self.attn_vec_einsum("BTNH,NHD->BTD", encoded)

    return attn_output


class FeedForward(nn.Module):
  """Feed forward module."""

  features: int
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    w_gating = self.param(
        "gating_einsum",
        trunc_norm_init(in_axis=(1,), out_axis=(0, 2), batch_axis=()),
        (2, self.features, self.hidden_dim),
    )
    ff_gate = jnp.dot(x, w_gating[0])
    gate_value = nn.gelu(ff_gate)

    ff1 = jnp.dot(x, w_gating[1])
    activations = gate_value * ff1

    w_linear = self.param(
        "linear",
        trunc_norm_init(in_axis=(0,), out_axis=(1,), batch_axis=()),
        (self.hidden_dim, self.features),
    )
    outputs = jnp.dot(activations, w_linear)

    return outputs


class Block(nn.Module):
  """Transformer block."""

  num_heads: int
  num_kv_heads: int
  embed_dim: int
  head_dim: int
  hidden_dim: int

  dropout: float = 0.0
  dropout_bdims: tuple[int, ...] = ()
  cache_dtype: str | None = None

  def setup(self):
    self.pre_attention_norm = RMSNorm()
    self.attn = Attention(
        num_heads=self.num_heads,
        num_kv_heads=self.num_kv_heads,
        features=self.embed_dim,
        head_dim=self.head_dim,
        cache_dtype=self.cache_dtype,
    )
    self.pre_ffw_norm = RMSNorm()
    self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
    if self.dropout:
      self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
    else:
      self.drop = lambda x, _: x

  def __call__(self, x, unused_scan_arg, positions, attn_mask, decode,
               deterministic=True):
    x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))

    inputs_normalized = self.pre_attention_norm(x)
    attn_output = self.attn(inputs_normalized, positions, attn_mask,
                            decode, deterministic)
    attn_output = self.drop(attn_output, deterministic)
    attn_output += x
    residual = attn_output
    attn_output = self.pre_ffw_norm(attn_output)
    outputs = self.mlp(attn_output)
    outputs = self.drop(outputs, deterministic)
    outputs = residual + outputs
    return outputs, unused_scan_arg


class Model(nn.Module):
  """Gemma model."""

  variant: str

  width: int
  depth: int
  mlp_dim: int
  num_heads: int
  num_kv_heads: int
  head_dim: int
  norm_eps: float
  vocab_size: int

  dropout: float = 0.0
  dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.

  cache_dtype: str | None = None

  # TODO: Wire this in all places needed so that the model can be run with a
  # different activation dtype. For now only float32 runs.
  embed_dtype: str = "float32"

  scan: bool = False
  remat_policy: str = "none"

  @nn.compact
  def __call__(
      self,
      tokens,
      *,
      embedded_prefix=None,
      embed_only=False,
      pre_logits=None,
      positions=None,
      mask=None,
      decode=False,
      deterministic=True,
  ):
    """Embed only, or complete forward pass.

    Args:
      tokens: Embedded and then appended to `embedded_prefix`. Can be None.
      embedded_prefix: Optional prefix that is already embedded.
      embed_only: Whether to compute embeddings only.
      pre_logits: If present, computes logits from `pre_logits` and returns.
      positions: Optional `[B, T]` array specifying the absolute positions of
        the tokens.
      mask: Optional attention mask `[B, T, S]`.
      decode: Whether to use the kv-cache. The caller must pass masks and
        positions.
      deterministic: Forwarded to all dropout layers.

    Returns:
      If `embed_only=False`, then `(logits, out)` will be returned.
      If `embed_only=True`, then the embeddings will be returned.
""" out = {} embedder = Embedder( vocab_size=self.vocab_size, embed_dim=self.width, name="embedder") if pre_logits is not None: x = out["pre_logits"] = pre_logits logits = out["logits"] = embedder.decode(x) return logits, out x = [] if embedded_prefix is not None: x.append(embedded_prefix) if tokens is not None: x.append(embedder.encode(tokens)) x = jnp.concatenate(x, axis=-2) x = x.astype(self.embed_dtype) batch_size, seq_len, width = x.shape if embed_only: return x if decode: assert positions is not None and mask is not None, ( "Must explicitly pass positions and mask for decoding.") if positions is None: positions = jnp.arange(seq_len).astype(jnp.int32)[None, :] assert positions.shape[1] == x.shape[1], (positions.shape, x.shape) if mask is None: mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len])) if mask.ndim == 3: mask = mask[:, None, :, :] cache_size = max(seq_len, mask.shape[-1]) assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape if self.remat_policy == "none": block_cls = Block else: block_cls = nn.remat( Block, prevent_cse=not self.scan, static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic policy=getattr(jax.checkpoint_policies, self.remat_policy), ) block_kw = dict( num_heads=self.num_heads, head_dim=self.head_dim, num_kv_heads=self.num_kv_heads, embed_dim=width, hidden_dim=self.mlp_dim, dropout=self.dropout, dropout_bdims=self.dropout_bdims, cache_dtype=self.cache_dtype, ) layers = self.scope.push("layers") if self.scan: blocks = [nn.scan( block_cls, # cache has axis 1 since we want leading dimension to be batch size. variable_axes={"params": 0, "cache": 1}, split_rngs={"params": True, "dropout": True}, in_axes=nn.broadcast, length=self.depth, )( parent=layers, **block_kw )] else: blocks = [ block_cls( parent=layers.push(str(layer)), **block_kw, ) for layer in range(self.depth) ] unused_scan_arg = () for block in blocks: x, unused_scan_arg = block( x, unused_scan_arg, positions, mask, decode, deterministic) assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check. 
out["encoded"] = x x = RMSNorm(name="final_norm")(x) out["pre_logits"] = x x = embedder.decode(x) out["logits"] = x return x, out _ORBAX_INITS = {} _BV_INITS = {} def _load_orbax(path): """Loads and coverts Orbax gemma checkpoint.""" checkpointer = orbax.checkpoint.PyTreeCheckpointer() params = checkpointer.restore(path) params = flax.traverse_util.unflatten_dict(params, sep="/")["transformer"] n = sum(1 for k in params if k.startswith("layer_")) params["layers"] = jax.tree.map( lambda *xs: np.stack(xs), *(params.pop(f"layer_{i}") for i in range(n)) ) mlp = params["layers"]["mlp"] mlp["gating_einsum"] = mlp["gating_einsum"].pop("w") mlp["linear"] = mlp["linear"].pop("w") return params def _del_pad_rows(params): """Some checkpoints have 128 unused padding tokens.""" emb = params["embedder"]["input_embedding"] assert emb.shape[0] == 256_128 params["embedder"]["input_embedding"] = np.asarray(emb)[:256_000] return params def load(init_params, init_file, model_cfg=None, dont_load=()): """Loads existing weights.""" model_cfg = model_cfg or {} variant = model_cfg.get("variant", "gemma_2b") init_variant = f"{init_file} {variant}" if init_variant in _ORBAX_INITS: params = _del_pad_rows(_load_orbax(_ORBAX_INITS[init_variant])) elif init_variant in _BV_INITS: params = _del_pad_rows(u.load_params(_BV_INITS[init_variant])) else: params = u.load_params(init_file) def extend_rows(emb1, target_rows): if (missing_rows := target_rows - emb1.shape[0]) == 0: return emb1 assert missing_rows > 0, "You're asking to shrink vocab?!" new_rows = np.random.randn(missing_rows, emb1.shape[1]) new_rows = (new_rows * 0.02).astype(emb1.dtype) return np.r_[np.asarray(emb1), new_rows] if "vocab_size" in model_cfg: params["embedder"]["input_embedding"] = extend_rows( params["embedder"]["input_embedding"], model_cfg["vocab_size"], ) return common.merge_params(params, init_params, dont_load)