# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3n.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, HybridCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
logging,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel
from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
custom_intro=\"\"\"
Base class for Gemma3n outputs, with hidden states and attentions.
\"\"\"
)
class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
r\"\"\"
past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
\`past_key_values\` input) to speed up sequential decoding.
image_hidden_states (\`torch.FloatTensor\`, *optional*):
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
audio_hidden_states (\`torch.FloatTensor\`, *optional*):
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
\"\"\"
image_hidden_states: Optional[torch.FloatTensor] = None
audio_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
custom_intro=\"\"\"
Base class for Gemma3n causal language model (or autoregressive) outputs.
\"\"\"
)
class Gemma3nCausalLMOutputWithPast(ModelOutput):
r\"\"\"
loss (\`torch.FloatTensor\` of shape \`(1,)\`, *optional*, returned when \`labels\` is provided):
Language modeling loss (for next-token prediction).
logits (\`torch.FloatTensor\` of shape \`(batch_size, sequence_length, config.text_config.vocab_size)\`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
\`past_key_values\` input) to speed up sequential decoding.
image_hidden_states (\`torch.FloatTensor\`, *optional*):
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
audio_hidden_states (\`torch.FloatTensor\`, *optional*):
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
\"\"\"
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
audio_hidden_states: Optional[torch.FloatTensor] = None
class Gemma3nRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
super().__init__()
self.eps = eps
self.with_scale = with_scale
if self.with_scale:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_buffer(\"weight\", torch.tensor(1.0), persistent=False)
def _norm(self, x):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = self._norm(x.float()) * self.weight.float()
return output.type_as(x)
def extra_repr(self):
return f\"{tuple(self.weight.shape)}, eps={self.eps}\"
# ==== Audio Encoder ====
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
self.max_forward = self.config.conf_attention_context_right
self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
min_timescale = 1.0
max_timescale = 1.0e4
num_timescales = self.channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
self.register_buffer(
\"inv_timescales\",
inv_timescales.float().unsqueeze(0).unsqueeze(0),
persistent=False,
)
def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
position = position.float().unsqueeze(-1)
scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
return timing_signal.type(dtype)
def _relative_shift(
self,
term_bd_before_shift: torch.Tensor,
batch_size: int,
num_heads: int,
num_query_blocks: int,
query_block_size: int,
key_context_size: int,
max_span_plus_1: int,
) -> torch.Tensor:
\"\"\"Performs the relative shift.
Args:
term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
(B), num_heads (N), num_query_blocks (U), query_block_size (W),
key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
Returns:
Tensor of shape [B, N, U, W, C].
\"\"\"
# term_bd_before_shift shape: [B, N, U, W, F_span]
# Target shape after shift: [B, N, U, W, C]
# Padding amount for the last dimension (F_span) to become (C + 1)
# C = key_context_size
# F_span = max_span_plus_1
pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
# PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
# We only pad the last dimension on the right.
padding_tuple = (0, pad_amount_last_dim)
term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
# Shape after pad: [B, N, U, W, C+1]
# Reshape for slicing (emulating JAX's behavior)
# [B, N, U, W * (C+1)]
term_bd_reshaped = term_bd_padded.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size * (key_context_size + 1),
)
)
# Slice to effective [B, N, U, W * C]
term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
# Reshape back to [B, N, U, W, C]
term_bd_shifted = term_bd_sliced.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
)
)
return term_bd_shifted
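# A minimal sketch of the pad-reshape-slice trick implemented above, with toy sizes
# W (query_block_size) = 2, C (key_context_size) = 4, F_span = 3:
#
#     term_bd = torch.arange(6.0).reshape(1, 1, 1, 2, 3)        # [B=1, N=1, U=1, W, F_span]
#     padded = nn.functional.pad(term_bd, (0, (4 + 1) - 3))     # -> [1, 1, 1, W, C+1]
#     flat = padded.reshape(1, 1, 1, 2 * (4 + 1))               # -> [1, 1, 1, W*(C+1)]
#     shifted = flat[:, :, :, : 2 * 4].reshape(1, 1, 1, 2, 4)   # -> [1, 1, 1, W, C]
#
# The extra padding column plus the flatten/slice moves row w of the relative-position table
# over by w positions, aligning relative offsets with absolute key positions per query row.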
def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
# queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
# keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
# C = W + L + R (key_context_size)
# F_span = L + R + 1 (max_span + 1)
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
_, _, key_context_size, _, _ = keys.shape
# Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
# Length is L+R+1 = self.max_span + 1
pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
0
) # Shape [1, F_span]
max_span_plus_1 = pos_indices.shape[1] # F_span
sin_emb_timing_signal = self._get_timing_signal_1d_pos(
pos_indices, dtype=queries.dtype
) # Shape [1, F_span, self.channels]
# Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
# Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
0
) # Shape [F, N, H]
# term_ac: Query-Key content interaction
# queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
# keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
# term_bd: Query-Position interaction
# Original einsum: term_bd_unshifted = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
# queries shape: [B, U, W, N, H]
# sin_emb shape: [F, N, H]
# Target output shape: [B, N, U, W, F]
# Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
q_permuted = queries.permute(0, 3, 1, 2, 4)
# Permute sin_emb to [N, H, F] to prepare for matmul
# sin_emb original is [F, N, H]
s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
# Reshape queries for matmul: [B, N, U*W, H]
q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
# Perform matmul: [B, N, U*W, H] @ [N, H, F]
# s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
# Result: [B, N, U*W, F]
term_bd_unshifted_matmul = torch.matmul(q_reshaped, s_permuted)
# Reshape to target [B, N, U, W, F]
term_bd_unshifted = term_bd_unshifted_matmul.reshape(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
max_span_plus_1,
)
# Apply relative shift to term_bd_unshifted
term_bd_shifted = self._relative_shift(
term_bd_unshifted,
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
max_span_plus_1,
) # Shape [B, N, U, W, C]
return term_ac + term_bd_shifted
class Gemma3nAudioAttention(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.hidden_size = self.config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.chunk_size = self.config.conf_attention_chunk_size
self.max_future_horizon = self.config.conf_attention_context_right
self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
q_scale = self.head_dim**-0.5
r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
self.register_buffer(\"q_scale\", (q_scale * r_softplus_0).clone().detach(), persistent=False)
lower_causal_mask = torch.tril(
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
diagonal=0,
).T
upper_causal_mask = torch.tril(
torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
diagonal=self.max_past_horizon + self.max_future_horizon,
)
local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
self.register_buffer(\"local_causal_valid_mask\", local_causal_valid_mask, persistent=False)
self.register_buffer(
\"softcap\",
torch.tensor(self.attention_logits_soft_cap).float(),
persistent=False,
)
def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
batch, _, *tail_shape = x.shape
left = x.new_zeros((batch, pad_left, *tail_shape))
right = x.new_zeros((batch, pad_right, *tail_shape))
x = torch.cat([left, x, right], dim=1)
return x
def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
\"\"\"Turns a sequence to non overlapping blocks.
Args:
hidden_states: a tensor of [batch, time, ...].
Returns:
A tensor of [batch, num_blocks, block_size, ...], with necessary
paddings,
where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
\"\"\"
shape = hidden_states.shape
b, t = shape[:2]
num_blocks = (t + self.chunk_size - 1) // self.chunk_size
if (padding_len := num_blocks * self.chunk_size - t) > 0:
hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
hidden_states = hidden_states.reshape(permute_dims).contiguous()
return hidden_states
def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
\"\"\"Extracts temporal context for every block.
Args:
hidden_states: a tensor of [batch, time, ...].
Returns:
A tensor of [batch, num_blocks, context_size, ...], with necessary
paddings,
where context_size = block_size + left_context + right_context,
and output[:, i, ...] are x[:, start-left_context:end+right_context,
...],
start = i * block_size, end = (i + 1) * block_size.
\"\"\"
pad_left = self.max_past_horizon
# The JAX equivalent padding for signal.frame with pad_mode='valid' is
# (left_context, right_context + block_size - 1) on the time dimension.
# PyTorch\'s _pad_dim1 applies padding symmetrically if only one value is given,
# or (pad_dim_start, pad_dim_end) if two are given.
# Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
# or dim 1 (time for [B,T]).
# The current pad_right calculation matches the JAX effective padding.
pad_right = self.max_future_horizon + self.chunk_size - 1
hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
frame_len = self.context_size
frame_step = self.chunk_size
# Directly use unfold without the subframe_factor logic
# x.unfold(dimension, size, step)
# dimension=1 (time dimension, assuming x is [B, T_padded, ...])
# size=frame_len (context_size)
# step=frame_step (chunk_size)
x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
# If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
# If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
# We want to match JAX's typical output for such operations which might be
# [B, num_blocks, frame_len, N, H] if N, H are present.
# The relative_position_embedding expects keys as [B, U, C, N, H].
# If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist
# Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C]
# Target shape for keys in RPE: [B, U, C, N, H]
x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
return x_unfolded.contiguous()
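# A minimal sketch of the unfold-based framing above, assuming hypothetical sizes
# chunk_size=4, max_past_horizon=1, max_future_horizon=2 (so context_size=7) and a [B, T] input:
#
#     x = torch.arange(8.0).unsqueeze(0)                                    # [B=1, T=8]
#     x = torch.cat([x.new_zeros(1, 1), x, x.new_zeros(1, 2 + 4 - 1)], 1)   # pad left=1, right=5
#     frames = x.unfold(dimension=1, size=7, step=4)                        # [1, num_blocks=2, 7]
#
# Frame i then covers original positions [i*4 - 1, i*4 + 4 + 2), i.e. each block together with
# its left and right context, mirroring jax signal.frame with 'valid' padding as noted above.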
def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
# sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
broadcast_shape = (1, 1, 1, self.head_dim)
per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
batch_size, q_time = query_states.shape[:2]
query_blocks = self._convert_to_block(query_states)
key_blocks = self._extract_block_context(key_states)
value_blocks = self._extract_block_context(value_states)
num_query_blocks = query_blocks.shape[1]
# 1. Create a mask indicating originally valid positions.
original_valid_mask = ~mask # True for valid, False for padded
# 2. Extract blocks from this validity mask.
extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
# If subframe_factor was used in _extract_block_context for a [B, T] input mask,
# the shape might be [B, U, C/SF, SF]. Reshape to [B, U, C].
# batch_size and num_query_blocks are known from query_blocks.
# self.context_size is C.
if (
extracted_valid_mask_blocks.ndim == 4
and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
):
extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
batch_size, num_query_blocks, self.context_size
)
# After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
# This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
# but for the mask case, this should hold.
if extracted_valid_mask_blocks.shape != (
batch_size,
num_query_blocks,
self.context_size,
):
raise ValueError(
\"Shape of extracted_valid_mask_blocks\"
f\" {extracted_valid_mask_blocks.shape} is not ({batch_size},\"
f\" {num_query_blocks}, {self.context_size}) after potential reshape.\"
)
# 3. Expand dimensions for broadcasting with logits and causal mask.
# Target shape for broadcasting with logits [B,N,U,W,C]
# extracted_valid_mask_blocks to [B, 1, U, 1, C]
condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
# self.local_causal_valid_mask is [W, C], True where allowed by local window.
# Expand to [1, 1, 1, W, C]
condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
# 4. Combine the two conditions.
# final_condition will be True where a key is *both* originally valid *and* causally accessible.
# Broadcasts to [B, 1, U, W, C]
final_condition_for_where = torch.logical_and(
condition_from_input_validity,
condition_from_causality.to(condition_from_input_validity.device), # Ensure same device
)
# Embed queries and keys
logits = self.relative_position_embedding(query_blocks, key_blocks)
# Apply attention logit softcap
# Ensure softcap is on the same device as logits
softcap_val = self.softcap.to(logits.device)
logits = logits / softcap_val
logits = torch.tanh(logits)
logits = logits * softcap_val
# Apply the combined mask.
# final_condition_for_where will broadcast with logits [B,N,U,W,C]
logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
h_dim = value_blocks.shape[-1]
prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
result_bmm = torch.bmm(prob_bun, v_bun)
context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
context_vectors = context_vectors.reshape(
(
batch_size,
num_query_blocks * self.chunk_size,
self.num_heads,
self.head_dim,
)
)
context_vectors = context_vectors[:, :q_time]
return context_vectors
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
\"\"\"Applies Group Normalization cumulatively over the time dimension.
This layer normalizes the input by calculating the mean and variance
cumulatively over the time dimension (dim 1). The statistics are computed
over all feature dimensions (specified by `feature_dims` and `num_channels`)
for elements marked as valid by the optional `mask`.
If a `mask` is provided (True for valid, False for invalid/padded),
invalid time steps do not contribute to the statistics calculation, and
their corresponding output values are zeroed out.
Scale and bias, if enabled, are applied per-channel (last dimension).
This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
and `cumulative=True`.
"""
def __init__(
self,
num_channels: int, # Number of channels (size of the last dimension)
feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
eps: float = 1e-3,
):
super().__init__()
self.num_channels = num_channels
self.feature_dims = tuple(feature_dims)
self.eps = eps
# Scale parameter depends only on the channel dimension
self.weight = nn.Parameter(torch.ones(num_channels))
# Axes for normalization: all dimensions except Batch (0) and Time (1).
# For input [B, T, *feature_dims, C], these are dims from 2 onwards.
self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
\"\"\"Applies cumulative group norm, optionally using a mask.
Args:
hidden_states: Input tensor, shape [B, T, *feature_dims, C].
Returns:
Normalized tensor with the same shape as x.
\"\"\"
expected_input_suffix = self.feature_dims + (self.num_channels,)
if hidden_states.shape[2:] != expected_input_suffix:
raise ValueError(
f\"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected\"
f\" suffix (feature_dims + num_channels) {expected_input_suffix}\"
)
input_dtype = hidden_states.dtype
# Calculations are performed in float32 for numerical stability.
calc_dtype = torch.float32
x_calc = hidden_states.to(calc_dtype)
# Prepare a broadcastable mask (`mask_calc`).
# If no mask is provided, treat all elements as valid
# (mask_calc is all ones).
# Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
# Cumulative Statistics Calculation
# 1. Sum of values over reduction axes at each time step.
sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
# 2. Cumulative sum of values over time.
cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
# 3. Count of valid elements in the normalization group at each time step.
# (A \"group\" here consists of all features at a given Batch, Time).
elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
# 4. Cumulative count of valid elements over time.
cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
# Avoid division by zero if all preceding elements were masked.
safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
# 5. Cumulative mean.
cum_mean = cum_sum_values / safe_cum_count_elements
# 6. Sum of squared differences from the cumulative mean.
# Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
# Using x_calc here for the difference, as cum_mean already accounts for masking.
squared_diff_from_mean = (x_calc - cum_mean).pow(2)
sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
# 7. Cumulative sum of squared differences over time.
cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
# 8. Cumulative variance.
cum_variance = cum_sum_sq_diff / safe_cum_count_elements
# Normalize the input using the calculated cumulative statistics:
# (x - E[x]) / sqrt(Var[x] + eps)
normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
# Apply affine transformation (scale and bias) if enabled.
# Scale and bias are applied per-channel (last dimension).
scale = self.weight.to(calc_dtype)
# Reshape for broadcasting: [C] -> [1, ..., 1, C]
scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
normalized_x = normalized_x * scale.view(scale_view_shape)
# Zero out outputs for time steps that were originally masked (where mask_calc is 0).
# This ensures padded/invalid positions in the input result in zero output.
final_output = normalized_x * mask_calc
return final_output.to(input_dtype)
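# A minimal sketch of the cumulative statistics above for a [B, T, C] input with no extra
# feature dims (all positions treated as valid): the mean/variance at step t use every feature
# seen up to and including t, accumulated with cumulative sums.
#
#     x = torch.randn(1, 5, 8)
#     counts = torch.cumsum(torch.full((1, 5, 1), 8.0), dim=1)              # features seen so far
#     mean = torch.cumsum(x.sum(-1, keepdim=True), dim=1) / counts
#     var = torch.cumsum((x - mean).pow(2).sum(-1, keepdim=True), dim=1) / counts
#     y = (x - mean) * torch.rsqrt(var + 1e-3)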
class Gemma3nAudioSSCPConvBlock(nn.Module):
\"\"\"A single convolution block for the SubSampleConvProjection.
This block consists of a 2D convolution, followed by CumulativeGroupNorm,
and a ReLU activation. It handles manual padding for the convolution.
\"\"\"
def __init__(
self,
config: Gemma3nAudioConfig,
idx: int,
input_freq_dim: int, # Changed from input_spatial_dim
manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
):
super().__init__()
self.config = config
self.manual_padding = manual_padding
# in_channels is 1 for the first block, or C_out from previous block's conv
in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
out_channels = self.config.sscp_conv_channel_size[idx]
kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(
kernel_h,
kernel_w,
), # Kernel (kH, kW) operates on (Time, Freq_dim)
stride=(stride_h, stride_w),
padding=(0, 0), # Manual padding is used
bias=False,
)
# Calculate output frequency dimension (f_out_conv) after this convolution.
# input_freq_dim is the unpadded width (feature dimension).
# self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
self.norm = Gemma3nAudioCumulativeGroupNorm(
num_channels=out_channels, # Channels of the conv output
feature_dims=(f_out_conv,), # The frequency dimension size after conv
eps=self.config.sscp_conv_group_norm_eps,
)
self.activation = nn.ReLU()
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
# Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
# manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
# F.pad applies to last two dims: F_in then T_in
audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
# Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
# Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
audio_encodings_conv = self.conv(audio_encodings_padded)
# Expected conv output shape: [B, C_out, T_out, F_out]
# Input to norm is [B, T_out, F_out, C_out]
x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
x_normed = self.norm(x_for_norm)
# Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
return self.activation(audio_encodings_normed)
class Gemma3nAudioSubSampleConvProjection(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
current_f_for_block_input = config.input_feat_size # Start with original feature dim
calculated_block_padding = []
calculated_f_out_dims = [] # Tracking frequency dimension output sizes
for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
stride_h, stride_w = config.sscp_conv_stride_size[i]
# Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
# JAX 'reverse_causal' padding is (0, kernel_size - 1)
pad_t_top = 0
pad_t_bottom = kernel_h - 1
# Frequency Padding (Width for Conv2d)
# Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
# and the successful test configuration.
# If kernel/stride/input_freq for frequency changes, this might need re-evaluation
# to match generic JAX 'SAME' behavior if it differs.
pad_f_left = 1
pad_f_right = 1
manual_padding_tuple = (
pad_f_left,
pad_f_right,
pad_t_top,
pad_t_bottom,
)
calculated_block_padding.append(manual_padding_tuple)
# Calculate output frequency dimension after this convolution
# This uses the actual padding applied and kernel/stride.
f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
calculated_f_out_dims.append(f_out_after_conv)
current_f_for_block_input = f_out_after_conv
self.conv_0 = Gemma3nAudioSSCPConvBlock(
idx=0,
input_freq_dim=config.input_feat_size, # Pass original feature dim
config=config,
manual_padding=calculated_block_padding[0],
)
self.conv_1 = Gemma3nAudioSSCPConvBlock(
idx=1,
input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
config=config,
manual_padding=calculated_block_padding[1],
)
final_c_out = config.sscp_conv_channel_size[-1]
final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
self.input_proj_in_features = final_c_out * final_f_out
self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
# audio_encodings is [B, T, F_in]
# Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
audio_encodings_reshaped = audio_encodings.unsqueeze(1)
x = self.conv_0(audio_encodings_reshaped)
x = self.conv_1(x)
# x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
b, c_out, t_out, f_out = x.shape
# Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
x_permuted = x.permute(0, 2, 3, 1).contiguous()
output_flattened = x_permuted.view(b, t_out, f_out * c_out)
output = self.input_proj_linear(output_flattened)
return output
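# Worked example of the frequency-dimension bookkeeping above, with hypothetical values
# (kernel_w=3, stride_w=2, freq padding (1, 1), input_feat_size=80 -- not necessarily the
# shipped config):
#
#     conv_0: f_in_padded = 80 + 1 + 1 = 82 -> f_out = (82 - 3) // 2 + 1 = 40
#     conv_1: f_in_padded = 40 + 1 + 1 = 42 -> f_out = (42 - 3) // 2 + 1 = 20
#
# giving input_proj_in_features = sscp_conv_channel_size[-1] * 20 for that hypothetical setup.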
class Gemma3nAudioConformerAttention(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.post_in_features = self.config.hidden_size
self.register_buffer(\"gradient_clipping\", torch.tensor(self.config.gradient_clipping), persistent=False)
self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.attn = Gemma3nAudioAttention(config)
self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
audio_encodings_input_to_attn = audio_encodings
audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
audio_encodings_norm = self.pre_attn_norm(audio_encodings)
# Output of self.attn is [B, T, NumHeads, HeadDim]
audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
# Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
# NumHeads * HeadDim = hidden_size
b, t, num_heads, head_dim = audio_encodings_attn_out.shape
audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
audio_encodings = self.post(audio_encodings_reshaped)
audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
class Gemma3nAudioConformerFeedForward(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.register_buffer(\"gradient_clipping\", torch.tensor(self.config.gradient_clipping), persistent=False)
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
residual = audio_encodings
audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
audio_encodings = self.pre_layer_norm(audio_encodings)
audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
audio_encodings = nn.functional.silu(audio_encodings)
audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
audio_encodings = self.post_layer_norm(audio_encodings)
return residual + (audio_encodings * self.post_layer_scale)
class Gemma3nAudioConformerLightConv1d(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
self.depthwise_conv1d = nn.Conv1d(
in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size,
kernel_size=self.config.conf_conv_kernel_size,
stride=1,
padding=0, # Manual causal padding
groups=self.config.hidden_size, # Depthwise
bias=False,
)
self.register_buffer(\"gradient_clipping\", torch.tensor(self.config.gradient_clipping), persistent=False)
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
self.causal_padding = self.config.conf_conv_kernel_size - 1
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_residual = audio_encodings # Save for residual connection
audio_encodings = self.pre_layer_norm(audio_encodings)
audio_encodings = self.linear_start(audio_encodings)
audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
# Permute for Conv1d: [B, T, D] -> [B, D, T]
audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
# Apply manual causal padding
audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
# Permute back: [B, D, T_out] -> [B, T_out, D]
audio_encodings = audio_encodings.permute(0, 2, 1)
audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
audio_encodings = self.conv_norm(audio_encodings)
audio_encodings = nn.functional.silu(audio_encodings)
audio_encodings = self.linear_end(audio_encodings)
output = audio_encodings + audio_encodings_residual
return output
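# A minimal sketch of the manual causal padding used above: padding only on the left by
# kernel_size - 1 keeps the output length equal to the input length while letting position t
# see only inputs at or before t.
#
#     conv = nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False)   # depthwise, like above
#     x = torch.randn(1, 4, 10)                                     # [B, D, T]
#     y = conv(F.pad(x, (3 - 1, 0)))                                # -> [1, 4, 10], causal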
class Gemma3nAudioConformerBlock(nn.Module):
def __init__(self, config: Gemma3nAudioConfig):
super().__init__()
self.config = config
self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
self.attention = Gemma3nAudioConformerAttention(self.config)
self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
self.register_buffer(\"gradient_clipping\", torch.tensor(self.config.gradient_clipping), persistent=False)
self.norm = Gemma3nRMSNorm(self.config.hidden_size)
def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
audio_encodings = self.ffw_layer_start(audio_encodings)
audio_encodings = self.attention(audio_encodings, audio_mel_mask)
validity_mask_for_lconv = ~audio_mel_mask # True for valid
audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
audio_encodings.dtype
)
audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
audio_encodings = self.ffw_layer_end(audio_encodings)
audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
output = self.norm(audio_encodings)
return output
class Gemma3nAudioEncoder(PreTrainedModel):
\"\"\"An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture.\"\"\"
config_class = Gemma3nAudioConfig
main_input_name = "audio_mel"
def __init__(self, config: Gemma3nAudioConfig):
super().__init__(config)
self.config = config
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
self.conformer = nn.ModuleList(
[Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
)
def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
) -> tuple[torch.Tensor, torch.BoolTensor]:
\"\"\"Encodes a batch of MELs.
Args:
audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
mel_bins].
Returns:
audio_encodings: a torch.Tensor of shape
`[batch_size, self.config.audio_soft_tokens_per_image,
self.config.audio_config.hidden_size]`
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
\"\"\"
audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
# Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
t_sub = audio_encodings.shape[1]
time_stride_product = 1
for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
# Create indices for gathering from the original mask.
# These indices map to original time steps corresponding to the start of each
# receptive field in the subsampled output.
indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
# Expand indices for batch compatibility if B > 1 and indices is 1D.
if audio_mel_mask.ndim > 1 and indices.ndim == 1:
indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
elif (
audio_mel_mask.ndim == indices.ndim
and audio_mel_mask.shape[0] == 1
and indices.shape[0] != 1
and t_sub == indices.shape[0]
):
# Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
indices = indices.unsqueeze(0)
current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
for block in self.conformer:
audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
if self.config.conf_reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
# Reduce the mask as well
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
return audio_encodings, current_mask
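# A minimal sketch of the mask subsampling above, assuming hypothetical time strides of 2 in
# both SSCP conv layers (time_stride_product = 4): the [B, T] padding mask is sampled at the
# first original frame of each subsampled step.
#
#     mask = torch.zeros(1, 10, dtype=torch.bool)                   # [B, T], True = padded
#     t_sub, stride = 3, 4
#     idx = (torch.arange(t_sub) * stride).clamp(max=10 - 1).unsqueeze(0)
#     sub_mask = torch.gather(mask, 1, idx)                         # [1, t_sub]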
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
\"\"\"
This module overrides nn.Embeddings\' forward by multiplying with embeddings scale.
\"\"\"
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.register_buffer(\"embed_scale\", torch.tensor(embed_scale), persistent=False)
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
class Gemma3nTextLaurelBlock(nn.Module):
\"\"\"Learned Augmented Residual Layer\"\"\"
def __init__(self, config: Gemma3nTextConfig):
super().__init__()
self.config = config
self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
return hidden_states + normed_laurel_hidden_states
class Gemma3nTextMLP(nn.Module):
def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size[layer_idx]
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_proj = self.gate_proj(hidden_states)
if self.activation_sparsity > 0.0:
gate_proj = self._gaussian_topk(gate_proj)
activations = self.act_fn(gate_proj)
up_proj = self.up_proj(hidden_states)
down_proj = self.down_proj(activations * up_proj)
return down_proj
def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
# normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
#
# References:
# * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
# * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
# * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
normal_dist = torch.distributions.normal.Normal(0, 1)
std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
std_multiplier = std_multiplier.type(inputs.dtype)
inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
cutoff_x = inputs_mean + inputs_std * std_multiplier
return nn.functional.relu(inputs - cutoff_x)
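# A minimal sketch of the Gaussian top-k gating above, assuming a hypothetical target
# sparsity of 0.95: the cutoff sits at mean + std * Phi^-1(0.95), so roughly the top 5% of
# pre-activations per token pass through the ReLU.
#
#     z = torch.randn(2, 16)
#     k = torch.distributions.normal.Normal(0, 1).icdf(torch.tensor(0.95))
#     cutoff = z.mean(-1, keepdim=True) + z.std(-1, keepdim=True, unbiased=False) * k
#     sparse = torch.relu(z - cutoff)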
class Gemma3nTextAltUp(nn.Module):
\"\"\"Alternating Updates (AltUp)
The AltUp module wraps transformer layers. The \`predict\` step modifies the
input to the transformer layer, and the \`correct\` step propagates the output
of the transformer layer to the sparsely updated dimensions.
See more in the research paper:
https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
\"\"\"
def __init__(self, config: Gemma3nTextConfig):
super().__init__()
self.config = config
self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.register_buffer(\"router_input_scale\", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
router_inputs = self.router_norm(x) * self.router_input_scale
routed = self.modality_router(router_inputs)
return torch.tanh(routed.float()).type_as(x)
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
\"\"\"Predicts the output of a layer using a trainable map.
Args:
hidden_states: A 4D tensor of shape \`[num_altup_inputs, batch_size, num_tokens, hidden_size]\` derived by
stacking the input embeddings and preprocessing the last \`num_altup_inputs - 1\` matrices.
Returns:
A 4D tensor of shape \`[num_altup_inputs, batch_size, num_tokens, hidden_size]\` containing the predictions.
\"\"\"
modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
if self.training and self.config.altup_coef_clip is not None:
self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
# Project and then transpose all 2D matrices contained so that matmul gives the correct result
all_coefs: torch.Tensor = (
self.prediction_coefs(modalities)
.reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
.permute(0, 1, 3, 2)
)
# permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
predictions = predictions.permute(3, 0, 1, 2) # undo the permute
predictions += hidden_states # add the original input
return predictions.contiguous().type_as(hidden_states)
def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
\"\"\"Corrects the predictions relative to the
Args:
predictions: A 4D tensor of shape \`[num_altup_inputs, batch_size, num_tokens, hidden_size]\` derived by
stacking the input embeddings and preprocessing the last \`num_altup_inputs - 1\` matrices.
activated: A 3D tensor of shape \`[batch_size, num_tokens, hidden_size]\` containing the activated inputs.
Returns:
A 4D tensor of shape \`[num_altup_inputs, batch_size, num_tokens, hidden_size]\` correcting the original
predictions relative to the activated input embeddings.
\"\"\"
modalities = self.compute_router_modalities(activated)
innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
if self.config.altup_coef_clip is not None:
self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
# all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
# Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
# and expand on dim1 for broadcastability
all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
corrected = torch.mul(innovation, all_coefs)
corrected += predictions # add the original input
return corrected.contiguous().type_as(activated)
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
\"\"\"
This is only defined as the \`forward\` so that accelerate hooks can move correctly \`correct_output_scale\`
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
\`scale_corrected_output\`
\"\"\"
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
\"\"\"Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].\"\"\"
return self.forward(corrected)
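# Shape sketch for the AltUp flow above, assuming a hypothetical altup_num_inputs=4 and
# hidden_size=8: `predict` mixes the stack of 4 hidden states along the stack dimension with
# per-token 4x4 coefficients, and `correct` later spreads (activated - active prediction)
# back across all 4 entries.
#
#     hs = torch.randn(4, 2, 3, 8)                   # [num_altup_inputs, B, T, D]
#     coefs = torch.randn(2, 3, 4, 4)                # per-token mixing matrices
#     preds = torch.matmul(hs.permute(1, 2, 3, 0), coefs).permute(3, 0, 1, 2) + hs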
class Gemma3nTextRotaryEmbedding(nn.Module):
def __init__(self, config: Gemma3nTextConfig, device=None):
super().__init__()
# BC: \"rope_type\" was originally \"type\"
if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))
else:
self.rope_type = \"default\"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer(\"inv_freq\", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
\"\"\"Rotates half the hidden dims of the input.\"\"\"
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
\"\"\"
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
\"\"\"
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
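# A minimal sketch of the grouped-query expansion above: with, say, 8 query heads and 2 KV
# heads, each KV head is repeated 4 times so every query head has a matching key/value head.
#
#     kv = torch.randn(1, 2, 5, 16)                  # [B, num_kv_heads, T, head_dim]
#     expanded = repeat_kv(kv, n_rep=4)              # -> [1, 8, 5, 16]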
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
softcap: Optional[float] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if scaling is None:
scaling = module.head_dim**-0.5
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if softcap is not None:
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
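# A minimal sketch of the optional logit soft-capping applied above: scores are squashed into
# (-softcap, softcap) with a tanh before masking and softmax, bounding attention logits
# without a hard clip.
#
#     logits = torch.randn(1, 8, 5, 5) * 100.0
#     softcap = 50.0
#     capped = softcap * torch.tanh(logits / softcap)    # values now lie in (-50, 50)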
def apply_rotary_pos_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
unsqueeze_dim: int = 1,
):
\"\"\"Applies Rotary Position Embedding to the query and key tensors.
Args:
x (\`torch.Tensor\`): The tensor to embed.
cos (\`torch.Tensor\`): The cosine part of the rotary embedding.
sin (\`torch.Tensor\`): The sine part of the rotary embedding.
position_ids (\`torch.Tensor\`, *optional*):
Deprecated and unused.
unsqueeze_dim (\`int\`, *optional*, defaults to 1):
The \'unsqueeze_dim\' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
\`tuple(torch.Tensor)\` comprising of the query and key tensors rotated using the Rotary Position Embedding.
\"\"\"
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
return (x * cos) + (rotate_half(x) * sin)
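# A minimal sketch of applying the rotary embedding above to a [B, T, heads, head_dim] tensor
# (hence unsqueeze_dim=2, as used by the attention module below); the cos/sin construction
# here is illustrative, not the model's exact rotary setup.
#
#     B, T, H, D = 1, 6, 2, 8
#     q = torch.randn(B, T, H, D)
#     inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))      # [D/2]
#     freqs = torch.arange(T)[None, :, None].float() * inv_freq            # [1, T, D/2]
#     emb = torch.cat((freqs, freqs), dim=-1)                              # [1, T, D]
#     q_rot = apply_rotary_pos_emb(q, emb.cos(), emb.sin(), unsqueeze_dim=2)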
class Gemma3nTextAttention(nn.Module):
\"\"\"Multi-headed attention from \'Attention Is All You Need\' paper\"\"\"
def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
super().__init__()
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.attention_dropout = self.config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.sliding_window = config.sliding_window if self.is_sliding else None
self.q_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
self.k_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = (
first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
if self.is_kv_shared_layer
else None
)
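# A minimal sketch of the shared-KV index lookup above, with hypothetical values
# (num_hidden_layers=6, num_kv_shared_layers=2, alternating layer types):
#
#     layer_types = ["sliding_attention", "full_attention"] * 3
#     first_shared = 6 - 2                       # layers 4 and 5 reuse earlier KV caches
#     # layer 4 (sliding): 3 - layer_types[3::-1].index("sliding_attention") = 3 - 1 = 2
#     # layer 5 (full):    3 - layer_types[3::-1].index("full_attention")    = 3 - 0 = 3
#
# i.e. each KV-shared layer reads the cache written by the last non-shared layer of its type.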
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.config.head_dim)
cos, sin = position_embeddings
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# Device of past layer may be different from current one
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
if isinstance(past_key_value, HybridCache) and self.is_sliding:
max_length = past_key_value.sliding_window
indices = (
slice(0, max_length)
if cache_position.shape[0] > max_length
else cache_position.clamp(min=0, max=max_length - 1)
)
# Device of past layer may be different from current one
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
query_states.device
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)
key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
key_states = key_states.transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape)
value_states = self.v_norm(value_states)
value_states = value_states.transpose(1, 2)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
\"sin\": sin,
\"cos\": cos,
\"cache_position\": cache_position,
\"sliding_window\": self.sliding_window,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=self.attention_dropout if self.training else 0.0,
scaling=1.0,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma3nTextAttention(config, layer_idx)
self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
self.act_fn = ACT2FN[config.hidden_activation]
self.altup = Gemma3nTextAltUp(config)
self.laurel = Gemma3nTextLaurelBlock(config)
self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
@deprecate_kwarg(\"last_cache_position\", version=\"4.53.0\")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings_global: torch.Tensor,
position_embeddings_local: torch.Tensor,
per_layer_input: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
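        # Layer flow: altup.predict produces per-stream predictions from the stacked altup streams; attention
        # and the LAUREL block operate on the normalized active prediction; the MLP output feeds altup.correct;
        # finally the gated per-layer input is projected back and added to the non-active streams.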
predictions = self.altup.predict(hidden_states)
active_prediction = predictions[self.config.altup_active_idx]
active_prediction_normed = self.input_layernorm(active_prediction)
laurel_output = self.laurel(active_prediction_normed)
# apply global RoPE to non-sliding layer only
if self.self_attn.is_sliding:
position_embeddings = position_embeddings_local
else:
position_embeddings = position_embeddings_global
attn, self_attn_weights = self.self_attn(
hidden_states=active_prediction_normed,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
attn = self.post_attention_layernorm(attn)
attn_gated = active_prediction + attn
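        # Combine the attention and LAUREL residual branches; dividing by sqrt(2) keeps the sum at roughly
        # the same scale as a single branch.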
attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
attn_norm = self.pre_feedforward_layernorm(attn_laurel)
attn_ffw = self.mlp(attn_norm)
attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
if self.config.altup_correct_scale:
first_prediction = self.altup.scale_corrected_output(first_prediction)
# per_layer_input_gate adapted from jax.numpy.einsum(\"btd,dp->btp\", ...)
first_prediction = self.per_layer_input_gate(first_prediction)
first_prediction = self.act_fn(first_prediction)
first_prediction = torch.multiply(first_prediction, per_layer_input)
# per_layer_projection adapted from jax.numpy.einsum(\"btp,pd->btd\", ...)
first_prediction = self.per_layer_projection(first_prediction)
first_prediction = self.post_per_layer_input_norm(first_prediction)
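        # Add the gated per-layer contribution to every altup stream except stream 0.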
corrected_predictions[1:] += first_prediction
outputs = (corrected_predictions,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
@auto_docstring
class Gemma3nPreTrainedModel(PreTrainedModel):
config_class = Gemma3nConfig
base_model_prefix = \"\"
supports_gradient_checkpointing = True
_no_split_modules = [\"Gemma3nTextDecoderLayer\"]
_skip_keys_device_placement = [\"past_key_values\"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
        # important: this ported version of Gemma 3n isn\'t meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed
std = getattr(self.config, \"initializer_range\", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Gemma3nRMSNorm):
if module.with_scale:
module.weight.data.fill_(1.0)
elif isinstance(module, Gemma3nAudioCumulativeGroupNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Gemma3nAudioAttention):
module.per_dim_scale.data.zero_()
elif isinstance(module, Gemma3nTextAltUp):
module.correct_output_scale.data.zero_()
@auto_docstring(custom_intro=\"The base Gemma 3n language model without a language modeling head.\")
class Gemma3nTextModel(Gemma3nPreTrainedModel):
config_class = Gemma3nTextConfig
def __init__(self, config: Gemma3nTextConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
self.embed_tokens = Gemma3nTextScaledWordEmbedding(
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
)
self.layers = nn.ModuleList(
[Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# TODO (raushan): Fix this after RoPE refactor. For now we hack it by
# reassigning thetas when we want to create a local RoPE layer. Config
# defaults should hold values for global RoPE.
config = copy.deepcopy(config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {\"rope_type\": \"default\"}
self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
self.hidden_size = config.hidden_size
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
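        # Per-layer embeddings: a single table whose rows hold num_hidden_layers * hidden_size_per_layer_input
        # values per token, later reshaped to one small vector per decoder layer (see get_per_layer_inputs).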
self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input,
self.padding_idx,
embed_scale=config.hidden_size_per_layer_input**0.5,
)
self.per_layer_model_projection = nn.Linear(
self.hidden_size,
config.num_hidden_layers * config.hidden_size_per_layer_input,
bias=False,
)
self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
self.altup_projections = nn.ModuleList(
[nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
)
self.altup_unembed_projections = nn.ModuleList(
[nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
)
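        # Scaling constants used in project_per_layer_inputs: hidden_size**-0.5 rescales the projection and
        # 1/sqrt(2) rescales the sum of the projection and the per-layer embeddings.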
self.register_buffer(\"per_layer_projection_scale\", torch.tensor(self.hidden_size**-0.5), persistent=False)
self.register_buffer(\"per_layer_input_scale\", torch.rsqrt(torch.tensor(2.0)), persistent=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
r\"\"\"
        per_layer_inputs (\`torch.Tensor\`, *optional*):
            Pre-computed per-layer embeddings. If \`None\`, they are derived from \`input_ids\` when it is provided.
\"\"\"
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
\"\`use_cache=True\` is incompatible with gradient checkpointing. Setting \`use_cache=False\`.\"
)
use_cache = False
if input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids)
per_layer_inputs = self.get_per_layer_inputs(input_ids)
per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# It may already have been prepared by e.g. \`generate\`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
\"config\": self.config,
\"input_embeds\": inputs_embeds,
\"attention_mask\": attention_mask,
\"cache_position\": cache_position,
\"past_key_values\": past_key_values,
}
            # Create both mask variants once; each decoder layer selects the one matching its attention type (full vs. sliding window)
causal_mask_mapping = {
\"full_attention\": create_causal_mask(**mask_kwargs),
\"sliding_attention\": create_sliding_window_causal_mask(**mask_kwargs),
}
# embed positions
hidden_states_0 = inputs_embeds
# Initialize RoPE embeddings
position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
# Expand hidden_states to support per-layer inputs
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(1e-5)
temp_hidden_states = [hidden_states_0]
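        # Each additional altup stream is a linear projection of the input embeddings, rescaled so its RMS
        # magnitude matches that of the original stream (clamped below by epsilon for stability).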
for i in range(1, self.config.altup_num_inputs):
# altup_proj adapted from jax.numpy.einsum(\"btp,pd->btd\", ...)
altup_proj = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
causal_mask = causal_mask_mapping[decoder_layer.attention_type]
per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global,
position_embeddings_local,
per_layer_input,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
if output_hidden_states:
all_hidden_states += (hidden_states,)
        # Collapse the altup streams back into a single hidden state
target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
temp_hidden_states = [hidden_states[0]]
for i in range(1, self.config.altup_num_inputs):
# altup_unembed_projections adapted from jax.numpy.einsum(\"btp,pd->btd\", ...)
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states)
hidden_states = torch.mean(hidden_states, dim=0)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
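        # [batch, seq_len] token ids -> [batch, seq_len, num_hidden_layers, hidden_size_per_layer_input]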
return self.embed_tokens_per_layer(input_ids).reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
per_layer_projection *= self.per_layer_projection_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
if per_layer_projection.shape != per_layer_inputs.shape:
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
@auto_docstring(custom_intro=\"The base Gemma 3n language model with a language modeling head.\")
class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
_tied_weights_keys = [\"lm_head.weight\"]
_tp_plan = {\"lm_head\": \"colwise_rep\"}
_pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}
config_class = Gemma3nTextConfig
base_model_prefix = \"model\"
_checkpoint_conversion_mapping = {\"model.language_model\": \"model\"}
def __init__(self, config: Gemma3nTextConfig):
super().__init__(config)
self.model = Gemma3nTextModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> CausalLMOutputWithPast:
r\"\"\"
labels (\`torch.LongTensor\` of shape \`(batch_size, sequence_length)\`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in \`[0, ...,
config.vocab_size]\` or -100 (see \`input_ids\` docstring). Tokens with indices set to \`-100\` are ignored
(masked), the loss is only computed for the tokens with labels in \`[0, ..., config.vocab_size]\`.
Example:
\`\`\`python
>>> from transformers import AutoTokenizer, Gemma3nForCausalLM
        >>> model = Gemma3nForCausalLM.from_pretrained(\"google/gemma-3n-E4B-it\")
        >>> tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-3n-E4B-it\")
>>> prompt = \"What is your favorite condiment?\"
>>> inputs = tokenizer(prompt, return_tensors=\"pt\")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
\"What is your favorite condiment?\"
\`\`\`\"\"\"
if self.training and self.config._attn_implementation != \"eager\":
logger.warning_once(
\"It is strongly recommended to train Gemma3n models with the \`eager\` attention implementation \"
f\"instead of \`{self.config._attn_implementation}\`. Use \`eager\` with \`AutoModelForCausalLM.from_pretrained(\'<path-to-checkpoint>\', attn_implementation=\'eager\')\`.\"
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**loss_kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
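        # Optional final logit soft-capping: logits = cap * tanh(logits / cap), which smoothly bounds the
        # logits to (-cap, cap).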
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class Gemma3nMultimodalEmbedder(nn.Module):
\"\"\"Embeds token ids or soft tokens for multimodal content into language model space.\"\"\"
def __init__(
self,
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
text_config: Gemma3nTextConfig,
):
super().__init__()
self.multimodal_hidden_size = multimodal_config.hidden_size
self.eps = multimodal_config.rms_norm_eps
self.vocab_offset = multimodal_config.vocab_offset
self.vocab_size = multimodal_config.vocab_size
self.text_hidden_size = text_config.hidden_size
self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
\"\"\"Embeds token ids or soft tokens for multimodal content into language model space.
Args:
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
\`[vocab_offset, vocab_offset + vocab_size)\`.
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
Returns:
A torch.Tensor of embeddings with shape \`[batch_size, seq_len, self.config.text_config.hidden_size]\`.
\"\"\"
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
else:
hard_emb = self.embedding(input_ids - self.vocab_offset)
emb_norm = self.hard_embedding_norm(hard_emb)
emb_norm_proj = self.embedding_projection(emb_norm)
return self.embedding_post_projection_norm(emb_norm_proj)
@auto_docstring(
custom_intro=\"\"\"
The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
language modeling head.
\"\"\"
)
class Gemma3nModel(Gemma3nPreTrainedModel):
_checkpoint_conversion_mapping = {}
# we are filtering the logits/labels so we shouldn\'t divide the loss based on num_items_in_batch
accepts_loss_kwargs = False
def __init__(self, config: Gemma3nConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
self.audio_tower = AutoModel.from_config(config.audio_config)
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
\"\"\"
Projects the last hidden state from the vision model into language model space.
Args:
            pixel_values (\`torch.FloatTensor\` of shape \`(batch_size, channels, height, width)\`):
The tensors corresponding to the input images.
Returns:
            image_features (\`torch.Tensor\`): Image feature tensor of shape \`(num_images, image_length, embed_dim)\`.
\"\"\"
vision_outputs = self.vision_tower(
pixel_values=pixel_values, do_pooling=False, return_dict=True
).last_hidden_state
# Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
# height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
).permute(0, 2, 1)
# Normalize and embed the soft tokens into language model space.
vision_outputs *= self.config.vision_config.hidden_size**0.5
return self.embed_vision(inputs_embeds=vision_outputs)
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None, # text inputs
pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
input_features: Optional[torch.FloatTensor] = None, # audio inputs
attention_mask: Optional[torch.Tensor] = None,
input_features_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**lm_kwargs,
) -> Gemma3nCausalLMOutputWithPast:
r\"\"\"
labels (\`torch.LongTensor\` of shape \`(batch_size, sequence_length)\`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in \`[0, ...,
config.text_config.vocab_size]\` or -100 (see \`input_ids\` docstring). Tokens with indices set to \`-100\` are ignored
(masked), the loss is only computed for the tokens with labels in \`[0, ..., config.text_config.vocab_size]\`.
Example:
\`\`\`python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
        >>> model = Gemma3nForConditionalGeneration.from_pretrained(\"google/gemma-3n-E4B-it\")
        >>> processor = AutoProcessor.from_pretrained(\"google/gemma-3n-E4B-it\")
>>> prompt = \"Where is the cat standing?\"
>>> url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors=\"pt\")
>>> # Generate
>>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
\"Where is the cat standing?\nsnow\"
\`\`\`
\"\"\"
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None:
inputs_embeds = self.get_input_embeddings()(input_ids)
            # Prepare per-layer inputs from input_ids
per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
# Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
vision_mask = torch.logical_and(
input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
)
dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
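            # Non-vision positions get a dummy (in-range) vision token id so the embedding lookup is valid;
            # their embeddings are discarded by the torch.where below.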
vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
vision_embeds = self.embed_vision(input_ids=vision_input_ids)
expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
# Handle audio tokens (>= embed_audio.vocab_offset)
audio_mask = input_ids >= self.embed_audio.vocab_offset
dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
audio_embeds = self.embed_audio(input_ids=audio_input_ids)
expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
else:
per_layer_inputs = None
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f\"Number of images does not match number of special image tokens in the input text. \"
f\"Got {image_tokens_in_text} image tokens in the text and \"
f\"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings.\"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# Merge text and audio
if input_features is not None and input_features_mask is not None:
audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
            # The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
            # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
            # produce 188 soft tokens; they produce at most that many, but may produce fewer depending on the
            # length of the longest audio input in the batch. When this happens, we pad the audio features out to
            # 188 soft tokens with the embedding of the last token in the embed_audio vocab.
audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
if input_ids is None:
special_audio_mask = inputs_embeds == self.embed_audio(
input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f\"Number of audio input features does not match number of special audio tokens in the input text. \"
f\"Got {audio_tokens_in_text} audio tokens in the text and \"
f\"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings.\"
)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
outputs = self.language_model(
input_ids=None,
per_layer_inputs=per_layer_inputs,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
return Gemma3nModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
audio_hidden_states=audio_features if input_features is not None else None,
)
def get_audio_features(
self, input_features: torch.Tensor, input_features_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
\"\"\"
Projects the last hidden state from the audio encoder into language model space.
Args:
            input_features (\`torch.FloatTensor\` of shape \`(batch_size, seq_length, num_features)\`):
                The tensors corresponding to the input audio.
            input_features_mask (\`torch.FloatTensor\` of shape \`(batch_size, seq_length)\`):
                The attention mask for the input audio.
        Returns:
            audio_features (\`torch.Tensor\`): Audio feature tensor of shape \`(batch_size, audio_length, embed_dim)\`.
\"\"\"
audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
@auto_docstring(
custom_intro=\"\"\"
The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
head.
\"\"\"
)
class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {}
_tied_weights_keys = [\"lm_head.weight\"]
base_model_prefix = \"model\"
def __init__(self, config: Gemma3nConfig):
super().__init__(config)
self.model = Gemma3nModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.set_decoder(decoder)
def get_decoder(self):
return self.model.get_decoder()
def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)
    # Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
raise AttributeError(\"Use embed_vision instead of multi_modal_projector.\")
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None, # text inputs
pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
input_features: Optional[torch.FloatTensor] = None, # audio inputs
attention_mask: Optional[torch.Tensor] = None,
input_features_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Gemma3nCausalLMOutputWithPast:
r\"\"\"
        input_features (\`torch.FloatTensor\`, *optional*):
            The audio inputs to be encoded.
        input_features_mask (\`torch.Tensor\`, *optional*):
            The attention mask for the input audio.
labels (\`torch.LongTensor\` of shape \`(batch_size, sequence_length)\`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in \`[0, ...,
config.text_config.vocab_size]\` or -100 (see \`input_ids\` docstring). Tokens with indices set to \`-100\` are
ignored (masked), the loss is only computed for the tokens with labels in
\`[0, ..., config.text_config.vocab_size]\`.
Example:
\`\`\`python
>>> from PIL import Image
>>> import requests
        >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
        >>> model = Gemma3nForConditionalGeneration.from_pretrained(\"google/gemma-3n-E4B-it\")
        >>> processor = AutoProcessor.from_pretrained(\"google/gemma-3n-E4B-it\")
>>> messages = [
... {
... \"role\": \"system\",
... \"content\": [
... {\"type\": \"text\", \"text\": \"You are a helpful assistant.\"}
... ]
... },
... {
... \"role\": \"user\", \"content\": [
... {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"},
... {\"type\": \"text\", \"text\": \"Where is the cat standing?\"},
... ]
... },
... ]
>>> inputs = processor.apply_chat_template(
... messages,
        ...     tokenize=True,
... return_dict=True,
... return_tensors=\"pt\",
... add_generation_prompt=True
... )
>>> # Generate
>>> generate_ids = model.generate(**inputs)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
\"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to\"
\`\`\`
\"\"\"
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
input_features=input_features,
attention_mask=attention_mask,
input_features_mask=input_features_mask,
position_ids=position_ids,
past_key_values=past_key_values,
token_type_ids=token_type_ids,
cache_position=cache_position,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
**lm_kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
logits = logits / final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * final_logit_softcapping
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
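            # Standard next-token shift: logits at position t are scored against the label at position t + 1.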
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
return Gemma3nCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
audio_hidden_states=outputs.audio_hidden_states,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
input_features=None,
attention_mask=None,
input_features_mask=None,
token_type_ids=None,
use_cache=True,
logits_to_keep=None,
labels=None,
**kwargs,
):
# Overwritten -- custom \`position_ids\` and \`pixel_values\` handling
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# If we\'re in cached decoding stage, multimodal inputs should be None because input ids do not contain special
# tokens anymore. Otherwise multimodal inputs should be passed to model.
# NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
if cache_position[0] == 0:
model_inputs[\"pixel_values\"] = pixel_values
model_inputs[\"input_features\"] = input_features
model_inputs[\"input_features_mask\"] = input_features_mask
return model_inputs
@property
def audio_tower(self):
return self.model.audio_tower
__all__ = [
\"Gemma3nAudioEncoder\",
\"Gemma3nForCausalLM\",
\"Gemma3nForConditionalGeneration\",
\"Gemma3nModel\",
\"Gemma3nPreTrainedModel\",
\"Gemma3nTextModel\",
]