Upload 2 files
- dependencies.py +0 -0
- main_code.py +253 -253
dependencies.py
CHANGED
The diff for this file is too large to render; see the raw diff.
main_code.py
CHANGED
@@ -8,14 +8,14 @@
 # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
 #
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Licensed under the Apache License, Version 2.0 (the \"License\");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
+# distributed under the License is distributed on an \"AS IS\" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
@@ -56,25 +56,25 @@ logger = logging.get_logger(__name__)
 
 @dataclass
 @auto_docstring(
-custom_intro="""
+custom_intro=\"\"\"
 Base class for Gemma3n outputs, with hidden states and attentions.
-"""
+\"\"\"
 )
 class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
-r"""
+r\"\"\"
-past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
-Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
-`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
 
 Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
-`past_key_values` input) to speed up sequential decoding.
+\`past_key_values\` input) to speed up sequential decoding.
-image_hidden_states (`torch.FloatTensor`, *optional*):
+image_hidden_states (\`torch.FloatTensor\`, *optional*):
-A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
 image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
-audio_hidden_states (`torch.FloatTensor`, *optional*):
+audio_hidden_states (\`torch.FloatTensor\`, *optional*):
-A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
 audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
-"""
+\"\"\"
 
 image_hidden_states: Optional[torch.FloatTensor] = None
 
@@ -83,29 +83,29 @@ class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
 
 @dataclass
 @auto_docstring(
-custom_intro="""
+custom_intro=\"\"\"
 Base class for Gemma3n causal language model (or autoregressive) outputs.
-"""
+\"\"\"
 )
 class Gemma3nCausalLMOutputWithPast(ModelOutput):
-r"""
+r\"\"\"
-loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+loss (\`torch.FloatTensor\` of shape \`(1,)\`, *optional*, returned when \`labels\` is provided):
 Language modeling loss (for next-token prediction).
-logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+logits (\`torch.FloatTensor\` of shape \`(batch_size, sequence_length, config.text_config.vocab_size)\`):
 Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
-Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
-`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
 
 Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
-`past_key_values` input) to speed up sequential decoding.
+\`past_key_values\` input) to speed up sequential decoding.
-image_hidden_states (`torch.FloatTensor`, *optional*):
+image_hidden_states (\`torch.FloatTensor\`, *optional*):
-A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
 image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
-audio_hidden_states (`torch.FloatTensor`, *optional*):
+audio_hidden_states (\`torch.FloatTensor\`, *optional*):
-A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
 audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
-"""
+\"\"\"
 
 loss: Optional[torch.FloatTensor] = None
 logits: Optional[torch.FloatTensor] = None
@@ -126,7 +126,7 @@ class Gemma3nRMSNorm(nn.Module):
 if self.with_scale:
 self.weight = nn.Parameter(torch.ones(dim))
 else:
-self.register_buffer("weight", torch.tensor(1.0), persistent=False)
+self.register_buffer(\"weight\", torch.tensor(1.0), persistent=False)
 
 def _norm(self, x):
 return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@@ -138,7 +138,7 @@ class Gemma3nRMSNorm(nn.Module):
 return output.type_as(x)
 
 def extra_repr(self):
-return f"{tuple(self.weight.shape)}, eps={self.eps}"
+return f\"{tuple(self.weight.shape)}, eps={self.eps}\"
 
 
 # ==== Audio Encoder ====
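For readers skimming the two Gemma3nRMSNorm hunks above: `_norm` divides by the root mean square over the last dimension, and `weight` is a constant 1.0 buffer when scaling is disabled. A minimal standalone sketch of that computation (function name and eps value are illustrative, not from the upload):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) over the last dim + eps), computed in float32, then rescaled.
    normed = x.float() / torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (normed * weight.float()).type_as(x)

x = torch.randn(2, 4, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 4, 8])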
@@ -163,7 +163,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
 log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
 inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
 self.register_buffer(
-"inv_timescales",
+\"inv_timescales\",
 inv_timescales.float().unsqueeze(0).unsqueeze(0),
 persistent=False,
 )
@@ -184,7 +184,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
 key_context_size: int,
 max_span_plus_1: int,
 ) -> torch.Tensor:
-"""Performs the relative shift.
+\"\"\"Performs the relative shift.
 
 Args:
 term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
@@ -193,7 +193,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
 
 Returns:
 Tensor of shape [B, N, U, W, C].
-"""
+\"\"\"
 # term_bd_before_shift shape: [B, N, U, W, F_span]
 # Target shape after shift: [B, N, U, W, C]
 
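The relative-shift hunks above (and the pad/reshape code that follows) implement the usual pad-reshape-slice trick for turning absolute relative-position scores into per-row shifted scores. A toy, self-contained version of the same idea (the exact slice bounds in the model differ, since its span dimension is W + C):

import torch
import torch.nn.functional as F

def toy_relative_shift(x: torch.Tensor) -> torch.Tensor:
    # x: [..., W, F]. Pad one column, flatten the last two dims, drop the tail,
    # and re-split: row w comes back shifted right by w positions.
    *lead, w, f = x.shape
    x = F.pad(x, (0, 1)).reshape(*lead, w * (f + 1))
    return x[..., : w * f].reshape(*lead, w, f)

x = torch.arange(6.0).reshape(2, 3)
print(toy_relative_shift(x))
# tensor([[0., 1., 2.],
#         [0., 3., 4.]])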
@@ -209,7 +209,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
 term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
 # Shape after pad: [B, N, U, W, C+1]
 
-# Reshape for slicing (emulating JAX's behavior)
+# Reshape for slicing (emulating JAX\'s behavior)
 # [B, N, U, W * (C+1)]
 term_bd_reshaped = term_bd_padded.reshape(
 (
@@ -271,7 +271,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
 term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
 
 # term_bd: Query-Position interaction
-# Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
+# Original einsum: term_bd_unshifed = torch.einsum(\'buwnh,fnh->bnuwf\', queries, sin_emb)
 # queries shape: [B, U, W, N, H]
 # sin_emb shape: [F, N, H]
 # Target output shape: [B, N, U, W, F]
@@ -338,7 +338,7 @@ class Gemma3nAudioAttention(nn.Module):
 
 q_scale = self.head_dim**-0.5
 r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
-self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
+self.register_buffer(\"q_scale\", (q_scale * r_softplus_0).clone().detach(), persistent=False)
 
 lower_causal_mask = torch.tril(
 torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
@@ -350,10 +350,10 @@ class Gemma3nAudioAttention(nn.Module):
 )
 local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
 local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
-self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+self.register_buffer(\"local_causal_valid_mask\", local_causal_valid_mask, persistent=False)
 
 self.register_buffer(
-"softcap",
+\"softcap\",
 torch.tensor(self.attention_logits_soft_cap).float(),
 persistent=False,
 )
@@ -366,7 +366,7 @@ class Gemma3nAudioAttention(nn.Module):
 return x
 
 def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
-"""Turns a sequence to non overlapping blocks.
+\"\"\"Turns a sequence to non overlapping blocks.
 
 Args:
 hidden_states: a tensor of [batch, time, ...].
@@ -375,7 +375,7 @@ class Gemma3nAudioAttention(nn.Module):
 A tensor of [batch, num_blocks, block_size, ...], with necessary
 paddings,
 where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
-"""
+\"\"\"
 shape = hidden_states.shape
 b, t = shape[:2]
 num_blocks = (t + self.chunk_size - 1) // self.chunk_size
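A minimal sketch of the blocking described in the _convert_to_block docstring above, independent of the class (names here are illustrative):

import torch

def convert_to_block(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # [batch, time, ...] -> [batch, num_blocks, block_size, ...], zero-padding
    # the tail so that the time dimension divides evenly.
    b, t = x.shape[:2]
    num_blocks = (t + block_size - 1) // block_size
    pad = num_blocks * block_size - t
    if pad > 0:
        x = torch.nn.functional.pad(x, (0, 0) * (x.dim() - 2) + (0, pad))
    return x.reshape(b, num_blocks, block_size, *x.shape[2:])

x = torch.arange(10.0).reshape(1, 10)
print(convert_to_block(x, 4).shape)  # torch.Size([1, 3, 4])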
@@ -388,7 +388,7 @@ class Gemma3nAudioAttention(nn.Module):
 return hidden_states
 
 def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
-"""Extracts temporal context for every block.
+\"\"\"Extracts temporal context for every block.
 
 Args:
 hidden_states: a tensor of [batch, time, ...].
@@ -400,11 +400,11 @@ class Gemma3nAudioAttention(nn.Module):
 and output[:, i, ...] are x[:, start-left_context:end+right_context,
 ...],
 start = i * block_size, end = (i + 1) * block_size.
-"""
+\"\"\"
 pad_left = self.max_past_horizon
-# The JAX equivalent padding for signal.frame with pad_mode='valid' is
+# The JAX equivalent padding for signal.frame with pad_mode=\'valid\' is
 # (left_context, right_context + block_size - 1) on the time dimension.
-# PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
+# PyTorch\'s _pad_dim1 applies padding symmetrically if only one value is given,
 # or (pad_dim_start, pad_dim_end) if two are given.
 # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
 # or dim 1 (time for [B,T]).
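A toy version of the context extraction described above, using the padding stated in the hunk's comment ((left, right + block_size - 1)) and torch.unfold. It assumes a rank-2 [B, T] input; the code in the diff also handles [B, T, N, H]:

import torch

def extract_block_context(x: torch.Tensor, block_size: int, left: int, right: int) -> torch.Tensor:
    # For every block of `block_size` steps, grab [start - left, end + right) along time.
    x = torch.nn.functional.pad(x, (left, right + block_size - 1))
    frame_len = block_size + left + right
    return x.unfold(dimension=-1, size=frame_len, step=block_size)

x = torch.arange(1.0, 9.0).unsqueeze(0)                                # [1, 8]
print(extract_block_context(x, block_size=4, left=2, right=1).shape)   # torch.Size([1, 2, 7])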
@@ -424,7 +424,7 @@ class Gemma3nAudioAttention(nn.Module):
 
 # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
 # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
-# We want to match JAX's typical output for such operations which might be
+# We want to match JAX\'s typical output for such operations which might be
 # [B, num_blocks, frame_len, N, H] if N, H are present.
 # The relative_position_embedding expects keys as [B, U, C, N, H].
 # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
@@ -436,7 +436,7 @@ class Gemma3nAudioAttention(nn.Module):
 return x_unfolded.contiguous()
 
 def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
-# sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
+# sl.Dense uses jax.numpy.einsum(\"...a,abcd->...bcd\") and jax.numpy.select()
 qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
 query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
 key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
@@ -472,7 +472,7 @@ class Gemma3nAudioAttention(nn.Module):
 extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
 batch_size, num_query_blocks, self.context_size
 )
-# After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
+# After potential reshape, ensure it\'s [B, U, C] if it was from a [B,T] mask.
 # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
 # but for the mask case, this should hold.
 if extracted_valid_mask_blocks.shape != (
@@ -481,9 +481,9 @@ class Gemma3nAudioAttention(nn.Module):
 self.context_size,
 ):
 raise ValueError(
-"Shape of extracted_valid_mask_blocks"
+\"Shape of extracted_valid_mask_blocks\"
-f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
+f\" {extracted_valid_mask_blocks.shape} is not ({batch_size},\"
-f" {num_query_blocks}, {self.context_size}) after potential reshape."
+f\" {num_query_blocks}, {self.context_size}) after potential reshape.\"
 )
 
 # 3. Expand dimensions for broadcasting with logits and causal mask.
@@ -518,7 +518,7 @@ class Gemma3nAudioAttention(nn.Module):
 logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
 probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
 
-# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
+# context_vectors is adapted from jax.numpy.einsum(\"BNuwc,BucNH->BuwNH\", ...)
 b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
 h_dim = value_blocks.shape[-1]
 prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
@@ -539,21 +539,21 @@ class Gemma3nAudioAttention(nn.Module):
 
 
 class Gemma3nAudioCumulativeGroupNorm(nn.Module):
-"""Applies Group Normalization cumulatively over the time dimension.
+\"\"\"Applies Group Normalization cumulatively over the time dimension.
 
 This layer normalizes the input by calculating the mean and variance
 cumulatively over the time dimension (dim 1). The statistics are computed
-over all feature dimensions (specified by `feature_dims` and `num_channels`)
+over all feature dimensions (specified by \`feature_dims\` and \`num_channels\`)
-for elements marked as valid by the optional `mask`.
+for elements marked as valid by the optional \`mask\`.
 
-If a `mask` is provided (True for valid, False for invalid/padded),
+If a \`mask\` is provided (True for valid, False for invalid/padded),
 invalid time steps do not contribute to the statistics calculation, and
 their corresponding output values are zeroed out.
 
 Scale and bias, if enabled, are applied per-channel (last dimension).
-This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+This behavior is similar to JAX\'s \`GroupNormalization\` with \`num_groups=1\`
-and
-"""
 
 def __init__(
 self,
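A toy, mask-free sketch of the cumulative group normalization the docstring above describes (statistics accumulated over all steps up to t and over the channel dimension; no scale, bias, or masking, and the eps value is illustrative):

import torch

def cumulative_group_norm(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # x: [B, T, C]. Normalize step t with mean/variance taken over steps <= t and all channels.
    counts = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) * x.shape[-1]
    mean = x.sum(dim=-1, keepdim=True).cumsum(dim=1) / counts
    var = (x * x).sum(dim=-1, keepdim=True).cumsum(dim=1) / counts - mean * mean
    return (x - mean) / torch.sqrt(var + eps)

print(cumulative_group_norm(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])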
@@ -574,19 +574,19 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
 self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
 
 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-"""Applies cumulative group norm, optionally using a mask.
 
 Args:
 hidden_states: Input tensor, shape [B, T, *feature_dims, C].
 
 Returns:
 Normalized tensor with the same shape as x.
-"""
 expected_input_suffix = self.feature_dims + (self.num_channels,)
 if hidden_states.shape[2:] != expected_input_suffix:
 raise ValueError(
-f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
-f" suffix (feature_dims + num_channels) {expected_input_suffix}"
 )
 
 input_dtype = hidden_states.dtype
@@ -594,7 +594,7 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
 calc_dtype = torch.float32
 x_calc = hidden_states.to(calc_dtype)
 
-# Prepare a broadcastable mask (
 # If no mask is provided, treat all elements as valid
 # (mask_calc is all ones).
 # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
@@ -607,7 +607,7 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
 cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
 
 # 3. Count of valid elements in the normalization group at each time step.
-# (A "group" here consists of all features at a given Batch, Time).
 elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
 # 4. Cumulative count of valid elements over time.
 cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
@@ -648,11 +648,11 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
 
 
 class Gemma3nAudioSSCPConvBlock(nn.Module):
-"""A single convolution block for the SubSampleConvProjection.
 
 This block consists of a 2D convolution, followed by CumulativeGroupNorm,
 and a ReLU activation. It handles manual padding for the convolution.
-"""
 
 def __init__(
 self,
@@ -665,7 +665,7 @@ class Gemma3nAudioSSCPConvBlock(nn.Module):
 self.config = config
 self.manual_padding = manual_padding
 
-# in_channels is 1 for the first block, or C_out from previous block's conv
 in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
 out_channels = self.config.sscp_conv_channel_size[idx]
 kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
@@ -701,7 +701,7 @@ class Gemma3nAudioSSCPConvBlock(nn.Module):
 # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
 # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
 # F.pad applies to last two dims: F_in then T_in
-audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode
 # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
 # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
 audio_encodings_conv = self.conv(audio_encodings_padded)
@@ -728,7 +728,7 @@ class Gemma3nAudioSubSampleConvProjection(nn.Module):
 stride_h, stride_w = config.sscp_conv_stride_size[i]
 
 # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
-# JAX 'reverse_causal' padding is (0, kernel_size - 1)
 pad_t_top = 0
 pad_t_bottom = kernel_h - 1
 
@@ -736,7 +736,7 @@ class Gemma3nAudioSubSampleConvProjection(nn.Module):
 # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
 # and the successful test configuration.
 # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
-# to match generic JAX 'SAME' behavior if it differs.
 pad_f_left = 1
 pad_f_right = 1
 
@@ -792,7 +792,7 @@ class Gemma3nAudioConformerAttention(nn.Module):
 super().__init__()
 self.config = config
 self.post_in_features = self.config.hidden_size
-self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
 self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
 self.attn = Gemma3nAudioAttention(config)
 self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
@@ -820,7 +820,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
 super().__init__()
 self.config = config
 
-self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
 
 self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
 self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
@@ -856,7 +856,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
 groups=self.config.hidden_size, # Depthwise
 bias=False,
 )
-self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
 self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
 self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
 
@@ -892,7 +892,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
 self.attention = Gemma3nAudioConformerAttention(self.config)
 self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
 self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
-self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
 self.norm = Gemma3nRMSNorm(self.config.hidden_size)
 
 def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
@@ -911,11 +911,11 @@ class Gemma3nAudioConformerBlock(nn.Module):
 
 
 class Gemma3nAudioEncoder(PreTrainedModel):
-"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture
 
 config_class = Gemma3nAudioConfig
 
-main_input_name = "audio_mel"
 
 def __init__(self, config: Gemma3nAudioConfig):
 super().__init__(config)
@@ -929,7 +929,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
 def forward(
 self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
 ) -> tuple[torch.Tensor, torch.BoolTensor]:
-"""Encodes a batch of MELs.
 
 Args:
 audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
@@ -937,10 +937,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
 
 Returns:
 audio_encodings: a torch.Tensor of shape
-
-self.config.audio_config.hidden_size]
 audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
-"""
 audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
 
 # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
@@ -983,20 +983,20 @@ class Gemma3nAudioEncoder(PreTrainedModel):
 
 
 class Gemma3nTextScaledWordEmbedding(nn.Embedding):
-"""
-This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
-"""
 
 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
 super().__init__(num_embeddings, embedding_dim, padding_idx)
-self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
 
 def forward(self, input_ids: torch.Tensor):
 return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
 
 
 class Gemma3nTextLaurelBlock(nn.Module):
-"""Learned Augmented Residual Layer"""
 
 def __init__(self, config: Gemma3nTextConfig):
 super().__init__()
@@ -1052,16 +1052,16 @@ class Gemma3nTextMLP(nn.Module):
 
 
 class Gemma3nTextAltUp(nn.Module):
-"""Alternating Updates (AltUp)
 
-The AltUp module wraps transformer layers. The
-input to the transformer layer, and the
 of the transformer layer to the sparsely updated dimensions.
 
 See more in the research paper:
 
 https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
-"""
 
 def __init__(self, config: Gemma3nTextConfig):
 super().__init__()
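A toy illustration of the AltUp flow the docstring above refers to: keep several copies of the hidden state, run the expensive layer on one of them, predict the other copies with cheap learned mixing, then correct every copy with the layer's output. Shapes and coefficients below are random stand-ins, not the model's parameters or exact routing:

import torch

n, batch, tokens, hidden = 4, 2, 3, 8
hidden_states = torch.randn(n, batch, tokens, hidden)   # stacked altup inputs
active_idx = 0

# "predict": mix the stacked copies with learned coefficients (random here).
prediction_coefs = torch.randn(n, n) * 0.01
predictions = torch.einsum("nbtd,nm->mbtd", hidden_states, prediction_coefs) + hidden_states

# Run the real transformer layer only on the active copy (faked with noise here).
activated = predictions[active_idx] + torch.randn(batch, tokens, hidden)

# "correct": propagate the layer's innovation back to all copies.
correction_coefs = torch.randn(n) * 0.01 + 1.0
innovation = activated - predictions[active_idx]
corrected = predictions + correction_coefs.view(n, 1, 1, 1) * innovation
print(corrected.shape)  # torch.Size([4, 2, 3, 8])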
@@ -1071,7 +1071,7 @@ class Gemma3nTextAltUp(nn.Module):
 self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
 self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
 self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
-self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
 
 def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
 router_inputs = self.router_norm(x) * self.router_input_scale
@@ -1079,15 +1079,15 @@ class Gemma3nTextAltUp(nn.Module):
 return torch.tanh(routed.float()).type_as(x)
 
 def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
-"""Predicts the output of a layer using a trainable map.
 
 Args:
-hidden_states: A 4D tensor of shape
-stacking the input embeddings and preprocessing the last
 
 Returns:
-A 4D tensor of shape
-"""
 modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
 
 if self.training and self.config.altup_coef_clip is not None:
@@ -1107,17 +1107,17 @@ class Gemma3nTextAltUp(nn.Module):
 return predictions.contiguous().type_as(hidden_states)
 
 def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
-"""Corrects the predictions relative to the
 
 Args:
-predictions: A 4D tensor of shape
-stacking the input embeddings and preprocessing the last
-activated: A 3D tensor of shape
 
 Returns:
-A 4D tensor of shape
 predictions relative to the activated input embeddings.
-"""
 modalities = self.compute_router_modalities(activated)
 innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
 innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
@@ -1125,7 +1125,7 @@ class Gemma3nTextAltUp(nn.Module):
 if self.config.altup_coef_clip is not None:
 self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
 
-# all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
 # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
 # and expand on dim1 for broadcastability
 all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
@@ -1136,26 +1136,26 @@ class Gemma3nTextAltUp(nn.Module):
 return corrected.contiguous().type_as(activated)
 
 def forward(self, corrected: torch.Tensor) -> torch.Tensor:
-"""
-This is only defined as the
 (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
-
-"""
 return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
 
 def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
-"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]
 return self.forward(corrected)
 
 
 class Gemma3nTextRotaryEmbedding(nn.Module):
 def __init__(self, config: Gemma3nTextConfig, device=None):
 super().__init__()
-# BC: "rope_type" was originally "type"
-if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
 else:
-self.rope_type = "default"
 self.max_seq_len_cached = config.max_position_embeddings
 self.original_max_seq_len = config.max_position_embeddings
 
@@ -1163,7 +1163,7 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
 self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-self.register_buffer("inv_freq", inv_freq, persistent=False)
 self.original_inv_freq = self.inv_freq
 
 @torch.no_grad()
@@ -1172,7 +1172,7 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
 inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 position_ids_expanded = position_ids[:, None, :].float()
 
-device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
 with torch.autocast(device_type=device_type, enabled=False): # Force float32
 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 emb = torch.cat((freqs, freqs), dim=-1)
@@ -1183,17 +1183,17 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
 
 
 def rotate_half(x):
-"""Rotates half the hidden dims of the input
 x1 = x[..., : x.shape[-1] // 2]
 x2 = x[..., x.shape[-1] // 2 :]
 return torch.cat((-x2, x1), dim=-1)
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-"""
 This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-"""
 batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 if n_rep == 1:
 return hidden_states
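The repeat_kv docstring above claims equivalence with torch.repeat_interleave along the head dimension; a small self-contained check (the expand/reshape body is reconstructed here, since the hunk only shows the start of the function):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

x = torch.randn(2, 4, 5, 16)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))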
@@ -1243,38 +1243,38 @@ def apply_rotary_pos_emb(
 position_ids: Optional[torch.Tensor] = None,
 unsqueeze_dim: int = 1,
 ):
-"""Applies Rotary Position Embedding to the query and key tensors.
 
 Args:
-x (
-cos (
-sin (
-position_ids (
 Deprecated and unused.
-unsqueeze_dim (
-The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
 sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
 that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
 k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
 cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
 the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 Returns:
-
-"""
 cos = cos.unsqueeze(unsqueeze_dim)
 sin = sin.unsqueeze(unsqueeze_dim)
 return (x * cos) + (rotate_half(x) * sin)
 
 
 class Gemma3nTextAttention(nn.Module):
-"""Multi-headed attention from 'Attention Is All You Need' paper"""
 
 def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
 super().__init__()
-self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
 self.config = config
 self.layer_idx = layer_idx
-self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
 self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 self.attention_dropout = self.config.attention_dropout
 self.is_causal = True
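A minimal usage sketch for the rotary helpers above, with made-up shapes and the standard 10000 base (the model's actual rope parameters come from its config): unsqueeze_dim=1 broadcasts cos/sin over the heads dimension, exactly as the docstring explains.

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, dim = 1, 2, 6, 8
q = torch.randn(batch, heads, seq, dim)
pos = torch.arange(seq, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.outer(pos, inv_freq)                         # [seq, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                    # [seq, dim]
cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]    # [1, seq, dim]

q_rot = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
print(q_rot.shape)  # torch.Size([1, 2, 6, 8])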
@@ -1356,15 +1356,15 @@ class Gemma3nTextAttention(nn.Module):
 if past_key_value is not None:
 # sin and cos are specific to RoPE models; cache_position needed for the static cache
 cache_kwargs = {
-"sin": sin,
-"cos": cos,
-"cache_position": cache_position,
-"sliding_window": self.sliding_window,
 }
 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 attention_interface: Callable = eager_attention_forward
-if self.config._attn_implementation != "eager":
 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
 attn_output, attn_weights = attention_interface(
@@ -1407,7 +1407,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
 self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
 self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
 
-@deprecate_kwarg("last_cache_position", version
 def forward(
 self,
 hidden_states: torch.Tensor,
@@ -1460,12 +1460,12 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
 if self.config.altup_correct_scale:
 first_prediction = self.altup.scale_corrected_output(first_prediction)
 
-# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
 first_prediction = self.per_layer_input_gate(first_prediction)
 first_prediction = self.act_fn(first_prediction)
 first_prediction = torch.multiply(first_prediction, per_layer_input)
 
-# per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
 first_prediction = self.per_layer_projection(first_prediction)
 first_prediction = self.post_per_layer_input_norm(first_prediction)
 corrected_predictions[1:] += first_prediction
@@ -1481,10 +1481,10 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
 @auto_docstring
 class Gemma3nPreTrainedModel(PreTrainedModel):
 config_class = Gemma3nConfig
-base_model_prefix = ""
 supports_gradient_checkpointing = True
-_no_split_modules = ["Gemma3nTextDecoderLayer"]
-_skip_keys_device_placement = ["past_key_values"]
 _supports_flash_attn_3 = True
 _supports_flash_attn_2 = True
 _supports_sdpa = True
@@ -1495,9 +1495,9 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
 _supports_attention_backend = True
 
 def _init_weights(self, module):
-# important: this ported version of Gemma2 isn't meant for training from scratch - only
 # inference and fine-tuning - so the proper init weights code has been removed
-std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
 
 if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
 module.weight.data.normal_(mean=0.0, std=std)
@@ -1518,7 +1518,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
 module.correct_output_scale.data.zero_()
 
 
-@auto_docstring(custom_intro
 class Gemma3nTextModel(Gemma3nPreTrainedModel):
 config_class = Gemma3nTextConfig
 
@@ -1544,7 +1544,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 # defaults should hold values for global RoPE.
 config = copy.deepcopy(config)
 config.rope_theta = config.rope_local_base_freq
-config.rope_scaling = {"rope_type": "default"}
 self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
 
 self.hidden_size = config.hidden_size
@@ -1573,8 +1573,8 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
 )
 
-self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
-self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
 
 # Initialize weights and apply final processing
 self.post_init()
@@ -1601,10 +1601,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 cache_position: Optional[torch.LongTensor] = None,
 **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
 ) -> BaseModelOutputWithPast:
-r"""
 per_layer_inputs (torch.Tensor, *optional*, defaults to None):
 Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
-"""
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1612,11 +1612,11 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 if (input_ids is None) ^ (inputs_embeds is not None):
-raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 if self.gradient_checkpointing and self.training and use_cache:
 logger.warning_once(
-"
 )
 use_cache = False
 
@@ -1640,20 +1640,20 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 if position_ids is None:
 position_ids = cache_position.unsqueeze(0)
 
-# It may already have been prepared by e.g.
 if not isinstance(causal_mask_mapping := attention_mask, dict):
 # Prepare mask arguments
 mask_kwargs = {
-"config": self.config,
-"input_embeds": inputs_embeds,
-"attention_mask": attention_mask,
-"cache_position": cache_position,
-"past_key_values": past_key_values,
 }
 # Create the masks
 causal_mask_mapping = {
-"full_attention": create_causal_mask(**mask_kwargs),
-"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
 }
 
 # embed positions
@@ -1669,7 +1669,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 
 temp_hidden_states = [hidden_states_0]
 for i in range(1, self.config.altup_num_inputs):
-# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
 altup_proj = self.altup_projections[i - 1](hidden_states_0)
 current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
@@ -1717,7 +1717,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
 temp_hidden_states = [hidden_states[0]]
 for i in range(1, self.config.altup_num_inputs):
-# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
 altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
 current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
@@ -1771,14 +1771,14 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
 )
 
 
-@auto_docstring(custom_intro
 class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
-_tied_weights_keys = ["lm_head.weight"]
-_tp_plan = {"lm_head": "colwise_rep"}
-_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 config_class = Gemma3nTextConfig
-base_model_prefix = "model"
-_checkpoint_conversion_mapping = {"model.language_model": "model"}
 
 def __init__(self, config: Gemma3nTextConfig):
 super().__init__(config)
@@ -1824,33 +1824,33 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
 logits_to_keep: Union[int, torch.Tensor] = 0,
 **loss_kwargs,
 ) -> CausalLMOutputWithPast:
-r"""
-labels (
-Labels for computing the masked language modeling loss. Indices should either be in
-config.vocab_size]
-(masked), the loss is only computed for the tokens with labels in
 
 Example:
 
-
 >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
 
->>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
->>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
 
->>> prompt = "What is your favorite condiment
->>> inputs = tokenizer(prompt, return_tensors
 
 >>> # Generate
 >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-"What is your favorite condiment
-
 
-if self.training and self.config._attn_implementation != "eager":
 logger.warning_once(
-"It is strongly recommended to train Gemma3n models with the
-f"instead of
 )
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
@@ -1893,7 +1893,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
 
 
 class Gemma3nMultimodalEmbedder(nn.Module):
-"""Embeds token ids or soft tokens for multimodal content into language model space
 
 def __init__(
 self,
@@ -1919,18 +1919,18 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 input_ids: Optional[torch.LongTensor] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-"""Embeds token ids or soft tokens for multimodal content into language model space.
 
 Args:
 input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
-
 inputs_embeds: A torch.Tensor containing the soft tokens to embed.
 
 Returns:
-A torch.Tensor of embeddings with shape
-"""
 if (input_ids is None) ^ (inputs_embeds is not None):
-raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 if inputs_embeds is not None:
 emb_norm = self.soft_embedding_norm(inputs_embeds)
@@ -1943,14 +1943,14 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 
 
 @auto_docstring(
-custom_intro
 The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
 language modeling head.
-"""
 )
 class Gemma3nModel(Gemma3nPreTrainedModel):
 _checkpoint_conversion_mapping = {}
-# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
 accepts_loss_kwargs = False
 
 def __init__(self, config: Gemma3nConfig):
@@ -1981,16 +1981,16 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
 return self.language_model
 
 def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-"""
 Projects the last hidden state from the vision model into language model space.
 
 Args:
-pixel_values (
 The tensors corresponding to the input images.
 
 Returns:
-image_features (
-"""
 vision_outputs = self.vision_tower(
 pixel_values=pixel_values, do_pooling=False, return_dict=True
 ).last_hidden_state
@@ -2024,36 +2024,36 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
 output_hidden_states: Optional[bool] = None,
 **lm_kwargs,
 ) -> Gemma3nCausalLMOutputWithPast:
-r"""
-labels (
-Labels for computing the masked language modeling loss. Indices should either be in
-config.text_config.vocab_size]
-(masked), the loss is only computed for the tokens with labels in
 
 Example:
 
-
 >>> from PIL import Image
 >>> import requests
 >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
 
->>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
->>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
 
->>> prompt = "Where is the cat standing
->>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 
->>> inputs = processor(images=image, text=prompt, return_tensors
 
 >>> # Generate
 >>> generate_ids = model.generate(**inputs,)
 >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-"Where is the cat standing?\nsnow"
-
-"""
 if (input_ids is None) ^ (inputs_embeds is not None):
-raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
@@ -2103,9 +2103,9 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
 if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
 image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
 raise ValueError(
-f"Number of images does not match number of special image tokens in the input text. "
-f"Got {image_tokens_in_text} image tokens in the text and "
-f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings
 )
 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -2140,9 +2140,9 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
 if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
 audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
 raise ValueError(
-f"Number of audio input features does not match number of special audio tokens in the input text. "
-f"Got {audio_tokens_in_text} audio tokens in the text and "
-f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings
 )
 audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
 inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
@@ -2174,32 +2174,32 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
 def get_audio_features(
 self, input_features: torch.Tensor, input_features_mask: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor]:
-"""
 Projects the last hidden state from the audio encoder into language model space.
 
 Args:
-input_features (
 The tensors corresponding to the input audio.
-input_features (
 The attention mask for the input audio.
 
 Returns:
-audio_features (
-"""
 audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
 return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
 
 
 @auto_docstring(
-custom_intro
 The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
 head.
-"""
 )
 class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 _checkpoint_conversion_mapping = {}
-_tied_weights_keys = ["lm_head.weight"]
-base_model_prefix = "model"
 
 def __init__(self, config: Gemma3nConfig):
 super().__init__(config)
@@ -2239,7 +2239,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 
 @property
 def multi_modal_projector(self):
-raise AttributeError("Use embed_vision instead of multi_modal_projector
 
 @can_return_tuple
 @auto_docstring
@@ -2262,38 +2262,38 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 logits_to_keep: Union[int, torch.Tensor] = 0,
 **lm_kwargs,
 ) -> Gemma3nCausalLMOutputWithPast:
-r"""
 input_features (torch.Tensor, *optional*, defaults to None):
 The audio inputs to be encoded.
 input_features_mask (torch.Tensor, *optional*, defaults to None):
 The attention mask for the input audio.
-labels (
-Labels for computing the masked language modeling loss. Indices should either be in
-config.text_config.vocab_size]
 ignored (masked), the loss is only computed for the tokens with labels in
-
 
 Example:
 
-
 >>> from PIL import Image
 >>> import requests
 >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
->>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
->>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
 
 >>> messages = [
 ... {
-... "role": "system",
-... "content": [
-... {"type": "text", "text": "You are a helpful assistant
 ... ]
 ... },
 ... {
-... "role": "user", "content": [
-... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
-... {"type": "text", "text": "Where is the cat standing
 ... ]
 ... },
 ... ]
@@ -2302,15 +2302,15 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 ... messages,
 ... tokenizer=True,
 ... return_dict=True,
-... return_tensors
 ... add_generation_prompt=True
 ... )
 >>> # Generate
 >>> generate_ids = model.generate(**inputs)
 >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
-
-"""
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -2393,7 +2393,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 labels=None,
 **kwargs,
 ):
-# Overwritten -- custom
 model_inputs = super().prepare_inputs_for_generation(
 input_ids,
 past_key_values=past_key_values,
@@ -2407,13 +2407,13 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 **kwargs,
 )
 
-# If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
 # tokens anymore. Otherwise multimodal inputs should be passed to model.
 # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
 if cache_position[0] == 0:
-model_inputs["pixel_values"] = pixel_values
-model_inputs["input_features"] = input_features
-model_inputs["input_features_mask"] = input_features_mask
 
 return model_inputs
 
@@ -2423,10 +2423,10 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
 
 
 __all__ = [
-"Gemma3nAudioEncoder",
-"Gemma3nForCausalLM",
-"Gemma3nForConditionalGeneration",
-"Gemma3nModel",
-"Gemma3nPreTrainedModel",
-"Gemma3nTextModel",
 ]
|
|
|
8 |
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
9 |
#
|
10 |
#
|
11 |
+
# Licensed under the Apache License, Version 2.0 (the \"License\");
|
12 |
# you may not use this file except in compliance with the License.
|
13 |
# You may obtain a copy of the License at
|
14 |
#
|
15 |
# http://www.apache.org/licenses/LICENSE-2.0
|
16 |
#
|
17 |
# Unless required by applicable law or agreed to in writing, software
|
18 |
+
# distributed under the License is distributed on an \"AS IS\" BASIS,
|
19 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
20 |
# See the License for the specific language governing permissions and
|
21 |
# limitations under the License.
|
|
|
56 |
|
57 |
@dataclass
|
58 |
@auto_docstring(
|
59 |
+
custom_intro=\"\"\"
|
60 |
Base class for Gemma3n outputs, with hidden states and attentions.
|
61 |
+
\"\"\"
|
62 |
)
|
63 |
class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
|
64 |
+
r\"\"\"
|
65 |
+
past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
|
66 |
+
Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
|
67 |
+
\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
|
68 |
|
69 |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
70 |
+
\`past_key_values\` input) to speed up sequential decoding.
|
71 |
+
image_hidden_states (\`torch.FloatTensor\`, *optional*):
|
72 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
73 |
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
74 |
+
audio_hidden_states (\`torch.FloatTensor\`, *optional*):
|
75 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
76 |
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
|
77 |
+
\"\"\"
|
78 |
|
79 |
image_hidden_states: Optional[torch.FloatTensor] = None
|
80 |
|
|
|
83 |
|
84 |
@dataclass
|
85 |
@auto_docstring(
|
86 |
+
custom_intro=\"\"\"
|
87 |
Base class for Gemma3n causal language model (or autoregressive) outputs.
|
88 |
+
\"\"\"
|
89 |
)
|
90 |
class Gemma3nCausalLMOutputWithPast(ModelOutput):
|
91 |
+
r\"\"\"
|
92 |
+
loss (\`torch.FloatTensor\` of shape \`(1,)\`, *optional*, returned when \`labels\` is provided):
|
93 |
Language modeling loss (for next-token prediction).
|
94 |
+
logits (\`torch.FloatTensor\` of shape \`(batch_size, sequence_length, config.text_config.vocab_size)\`):
|
95 |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
96 |
+
past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
|
97 |
+
Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
|
98 |
+
\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
|
99 |
|
100 |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
101 |
+
\`past_key_values\` input) to speed up sequential decoding.
|
102 |
+
image_hidden_states (\`torch.FloatTensor\`, *optional*):
|
103 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
104 |
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
105 |
+
audio_hidden_states (\`torch.FloatTensor\`, *optional*):
|
106 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
107 |
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
|
108 |
+
\"\"\"
|
109 |
|
110 |
loss: Optional[torch.FloatTensor] = None
|
111 |
logits: Optional[torch.FloatTensor] = None
|
|
|
126 |
if self.with_scale:
|
127 |
self.weight = nn.Parameter(torch.ones(dim))
|
128 |
else:
|
129 |
+
self.register_buffer(\"weight\", torch.tensor(1.0), persistent=False)
|
130 |
|
131 |
def _norm(self, x):
|
132 |
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
|
138 |
return output.type_as(x)
|
139 |
|
140 |
def extra_repr(self):
|
141 |
+
return f\"{tuple(self.weight.shape)}, eps={self.eps}\"
|
142 |
|
143 |
|
144 |
# ==== Audio Encoder ====
|
|
|
163 |
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
|
164 |
inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
|
165 |
self.register_buffer(
|
166 |
+
\"inv_timescales\",
|
167 |
inv_timescales.float().unsqueeze(0).unsqueeze(0),
|
168 |
persistent=False,
|
169 |
)
|
|
|
184 |
key_context_size: int,
|
185 |
max_span_plus_1: int,
|
186 |
) -> torch.Tensor:
|
187 |
+
\"\"\"Performs the relative shift.
|
188 |
|
189 |
Args:
|
190 |
term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
|
|
|
193 |
|
194 |
Returns:
|
195 |
Tensor of shape [B, N, U, W, C].
|
196 |
+
\"\"\"
|
197 |
# term_bd_before_shift shape: [B, N, U, W, F_span]
|
198 |
# Target shape after shift: [B, N, U, W, C]
|
199 |
|
|
|
209 |
term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
|
210 |
# Shape after pad: [B, N, U, W, C+1]
|
211 |
|
212 |
+
# Reshape for slicing (emulating JAX\'s behavior)
|
213 |
# [B, N, U, W * (C+1)]
|
214 |
term_bd_reshaped = term_bd_padded.reshape(
|
215 |
(
|
|
|
271 |
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
|
272 |
|
273 |
# term_bd: Query-Position interaction
|
274 |
+
# Original einsum: term_bd_unshifed = torch.einsum(\'buwnh,fnh->bnuwf\', queries, sin_emb)
|
275 |
# queries shape: [B, U, W, N, H]
|
276 |
# sin_emb shape: [F, N, H]
|
277 |
# Target output shape: [B, N, U, W, F]
|
|
|
338 |
|
339 |
q_scale = self.head_dim**-0.5
|
340 |
r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
|
341 |
+
self.register_buffer(\"q_scale\", (q_scale * r_softplus_0).clone().detach(), persistent=False)
|
342 |
|
343 |
lower_causal_mask = torch.tril(
|
344 |
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
|
|
|
350 |
)
|
351 |
local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
|
352 |
local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
|
353 |
+
self.register_buffer(\"local_causal_valid_mask\", local_causal_valid_mask, persistent=False)
|
354 |
|
355 |
self.register_buffer(
|
356 |
+
\"softcap\",
|
357 |
torch.tensor(self.attention_logits_soft_cap).float(),
|
358 |
persistent=False,
|
359 |
)
|
|
|
366 |
return x
|
367 |
|
368 |
def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
369 |
+
\"\"\"Turns a sequence to non overlapping blocks.
|
370 |
|
371 |
Args:
|
372 |
hidden_states: a tensor of [batch, time, ...].
|
|
|
375 |
A tensor of [batch, num_blocks, block_size, ...], with necessary
|
376 |
paddings,
|
377 |
where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
|
378 |
+
\"\"\"
|
379 |
shape = hidden_states.shape
|
380 |
b, t = shape[:2]
|
381 |
num_blocks = (t + self.chunk_size - 1) // self.chunk_size
|
|
|
388 |
return hidden_states
|
389 |
|
390 |
def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
391 |
+
\"\"\"Extracts temporal context for every block.
|
392 |
|
393 |
Args:
|
394 |
hidden_states: a tensor of [batch, time, ...].
|
|
|
400 |
and output[:, i, ...] are x[:, start-left_context:end+right_context,
|
401 |
...],
|
402 |
start = i * block_size, end = (i + 1) * block_size.
|
403 |
+"""
|
404 |
pad_left = self.max_past_horizon
|
405 |
+# The JAX equivalent padding for signal.frame with pad_mode='valid' is
|
406 |
# (left_context, right_context + block_size - 1) on the time dimension.
|
407 |
+# PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
|
408 |
# or (pad_dim_start, pad_dim_end) if two are given.
|
409 |
# Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
|
410 |
# or dim 1 (time for [B,T]).
|
|
|
424 |
|
425 |
# If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
|
426 |
# If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
|
427 |
+# We want to match JAX's typical output for such operations which might be
|
428 |
# [B, num_blocks, frame_len, N, H] if N, H are present.
|
429 |
# The relative_position_embedding expects keys as [B, U, C, N, H].
|
430 |
# If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
|
|
|
436 |
return x_unfolded.contiguous()
|
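# --- Editor's illustrative sketch (not part of the original model code) ---
# Extracting an overlapping context window per block with Tensor.unfold: after padding,
# every block of the time axis is widened to block_size + left + right frames, which is
# what _extract_block_context feeds to the local attention.
def _extract_block_context_sketch(block_size=4, left=2, right=1):
    import torch

    x = torch.arange(12, dtype=torch.float32).unsqueeze(0)          # [batch=1, time=12]
    x = torch.nn.functional.pad(x, (left, right + block_size - 1))  # pad the time axis
    frames = x.unfold(dimension=1, size=block_size + left + right, step=block_size)
    # Block 1 covers steps 4..7 plus two steps of left context and one of right context.
    assert torch.equal(frames[0, 1], torch.arange(2.0, 9.0))
    return frames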
437 |
|
438 |
def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
|
439 |
+# sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
|
440 |
qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
|
441 |
query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
|
442 |
key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
|
|
|
472 |
extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
|
473 |
batch_size, num_query_blocks, self.context_size
|
474 |
)
|
475 |
+# After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
|
476 |
# This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
|
477 |
# but for the mask case, this should hold.
|
478 |
if extracted_valid_mask_blocks.shape != (
|
|
|
481 |
self.context_size,
|
482 |
):
|
483 |
raise ValueError(
|
484 |
+"Shape of extracted_valid_mask_blocks"
|
485 |
+f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
|
486 |
+f" {num_query_blocks}, {self.context_size}) after potential reshape."
|
487 |
)
|
488 |
|
489 |
# 3. Expand dimensions for broadcasting with logits and causal mask.
|
|
|
518 |
logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
|
519 |
probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
|
520 |
|
521 |
+# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
|
522 |
b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
|
523 |
h_dim = value_blocks.shape[-1]
|
524 |
prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
|
|
|
539 |
|
540 |
|
541 |
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
542 |
+"""Applies Group Normalization cumulatively over the time dimension.
|
543 |
|
544 |
This layer normalizes the input by calculating the mean and variance
|
545 |
cumulatively over the time dimension (dim 1). The statistics are computed
|
546 |
+over all feature dimensions (specified by `feature_dims` and `num_channels`)
|
547 |
+for elements marked as valid by the optional `mask`.
|
548 |
|
549 |
+If a `mask` is provided (True for valid, False for invalid/padded),
|
550 |
invalid time steps do not contribute to the statistics calculation, and
|
551 |
their corresponding output values are zeroed out.
|
552 |
|
553 |
Scale and bias, if enabled, are applied per-channel (last dimension).
|
554 |
+This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
|
555 |
+and `cumulative=True`.
|
556 |
+"""
|
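# --- Editor's illustrative sketch (not part of the original model code) ---
# Cumulative statistics over time: at step t the mean and variance are taken over the
# features of all steps 0..t. This mirrors the unmasked path of the layer below; the real
# module additionally supports a validity mask and a learned per-channel scale.
def _cumulative_norm_sketch(eps=1e-3):
    import torch

    x = torch.randn(2, 5, 8)                                # [batch, time, channels]
    sum_t = x.sum(dim=-1, keepdim=True)                     # per-step feature sum
    cnt_t = torch.full_like(sum_t, x.shape[-1])             # per-step element count
    cum_mean = sum_t.cumsum(1) / cnt_t.cumsum(1)
    cum_var = (x - cum_mean).pow(2).sum(-1, keepdim=True).cumsum(1) / cnt_t.cumsum(1)
    return (x - cum_mean) / torch.sqrt(cum_var + eps)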
557 |
|
558 |
def __init__(
|
559 |
self,
|
|
|
574 |
self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
|
575 |
|
576 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
577 |
+"""Applies cumulative group norm, optionally using a mask.
|
578 |
|
579 |
Args:
|
580 |
hidden_states: Input tensor, shape [B, T, *feature_dims, C].
|
581 |
|
582 |
Returns:
|
583 |
Normalized tensor with the same shape as x.
|
584 |
+"""
|
585 |
expected_input_suffix = self.feature_dims + (self.num_channels,)
|
586 |
if hidden_states.shape[2:] != expected_input_suffix:
|
587 |
raise ValueError(
|
588 |
+f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
|
589 |
+f" suffix (feature_dims + num_channels) {expected_input_suffix}"
|
590 |
)
|
591 |
|
592 |
input_dtype = hidden_states.dtype
|
|
|
594 |
calc_dtype = torch.float32
|
595 |
x_calc = hidden_states.to(calc_dtype)
|
596 |
|
597 |
+# Prepare a broadcastable mask (`mask_calc`).
|
598 |
# If no mask is provided, treat all elements as valid
|
599 |
# (mask_calc is all ones).
|
600 |
# Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
|
|
|
607 |
cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
|
608 |
|
609 |
# 3. Count of valid elements in the normalization group at each time step.
|
610 |
+# (A "group" here consists of all features at a given Batch, Time).
|
611 |
elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
|
612 |
# 4. Cumulative count of valid elements over time.
|
613 |
cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
|
|
|
648 |
|
649 |
|
650 |
class Gemma3nAudioSSCPConvBlock(nn.Module):
|
651 |
+"""A single convolution block for the SubSampleConvProjection.
|
652 |
|
653 |
This block consists of a 2D convolution, followed by CumulativeGroupNorm,
|
654 |
and a ReLU activation. It handles manual padding for the convolution.
|
655 |
+"""
|
656 |
|
657 |
def __init__(
|
658 |
self,
|
|
|
665 |
self.config = config
|
666 |
self.manual_padding = manual_padding
|
667 |
|
668 |
+# in_channels is 1 for the first block, or C_out from previous block's conv
|
669 |
in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
|
670 |
out_channels = self.config.sscp_conv_channel_size[idx]
|
671 |
kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
|
|
|
701 |
# Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
|
702 |
# manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
|
703 |
# F.pad applies to last two dims: F_in then T_in
|
704 |
+audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
|
705 |
# Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
|
706 |
# Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
|
707 |
audio_encodings_conv = self.conv(audio_encodings_padded)
|
|
|
728 |
stride_h, stride_w = config.sscp_conv_stride_size[i]
|
729 |
|
730 |
# Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
|
731 |
+# JAX 'reverse_causal' padding is (0, kernel_size - 1)
|
732 |
pad_t_top = 0
|
733 |
pad_t_bottom = kernel_h - 1
|
734 |
|
|
|
736 |
# Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
|
737 |
# and the successful test configuration.
|
738 |
# If kernel/stride/input_freq for frequency changes, this might need re-evaluation
|
739 |
+# to match generic JAX 'SAME' behavior if it differs.
|
740 |
pad_f_left = 1
|
741 |
pad_f_right = 1
|
742 |
|
|
|
792 |
super().__init__()
|
793 |
self.config = config
|
794 |
self.post_in_features = self.config.hidden_size
|
795 |
+self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
796 |
self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
797 |
self.attn = Gemma3nAudioAttention(config)
|
798 |
self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
|
|
|
820 |
super().__init__()
|
821 |
self.config = config
|
822 |
|
823 |
+self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
824 |
|
825 |
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
826 |
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
|
|
|
856 |
groups=self.config.hidden_size, # Depthwise
|
857 |
bias=False,
|
858 |
)
|
859 |
+self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
860 |
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
861 |
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
|
862 |
|
|
|
892 |
self.attention = Gemma3nAudioConformerAttention(self.config)
|
893 |
self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
|
894 |
self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
|
895 |
+self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
896 |
self.norm = Gemma3nRMSNorm(self.config.hidden_size)
|
897 |
|
898 |
def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
|
|
|
911 |
|
912 |
|
913 |
class Gemma3nAudioEncoder(PreTrainedModel):
|
914 |
+"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
|
915 |
|
916 |
config_class = Gemma3nAudioConfig
|
917 |
|
918 |
+main_input_name = "audio_mel"
|
919 |
|
920 |
def __init__(self, config: Gemma3nAudioConfig):
|
921 |
super().__init__(config)
|
|
|
929 |
def forward(
|
930 |
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
|
931 |
) -> tuple[torch.Tensor, torch.BoolTensor]:
|
932 |
+"""Encodes a batch of MELs.
|
933 |
|
934 |
Args:
|
935 |
audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
|
|
|
937 |
|
938 |
Returns:
|
939 |
audio_encodings: a torch.Tensor of shape
|
940 |
+`[batch_size, self.config.audio_soft_tokens_per_image,
|
941 |
+self.config.audio_config.hidden_size]`
|
942 |
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
|
943 |
+"""
|
944 |
audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
|
945 |
|
946 |
# Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
|
|
|
983 |
|
984 |
|
985 |
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
|
986 |
+"""
|
987 |
+This module overrides nn.Embedding's forward by multiplying it with the embedding scale.
|
988 |
+"""
|
989 |
|
990 |
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
991 |
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
992 |
+self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
993 |
|
994 |
def forward(self, input_ids: torch.Tensor):
|
995 |
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
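# --- Editor's illustrative sketch (not part of the original model code) ---
# The scaled word embedding is an ordinary lookup whose output is multiplied by a constant
# (the model reads the actual scale from its config); functionally it is just this:
def _scaled_embedding_sketch():
    import torch
    from torch import nn

    vocab, dim = 16, 8
    embed_scale = dim ** 0.5                 # assumed value, for illustration only
    emb = nn.Embedding(vocab, dim, padding_idx=0)
    input_ids = torch.tensor([[1, 2, 3]])
    return emb(input_ids) * embed_scale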
996 |
|
997 |
|
998 |
class Gemma3nTextLaurelBlock(nn.Module):
|
999 |
+"""Learned Augmented Residual Layer"""
|
1000 |
|
1001 |
def __init__(self, config: Gemma3nTextConfig):
|
1002 |
super().__init__()
|
|
|
1052 |
|
1053 |
|
1054 |
class Gemma3nTextAltUp(nn.Module):
|
1055 |
+"""Alternating Updates (AltUp)
|
1056 |
|
1057 |
+The AltUp module wraps transformer layers. The `predict` step modifies the
|
1058 |
+input to the transformer layer, and the `correct` step propagates the output
|
1059 |
of the transformer layer to the sparsely updated dimensions.
|
1060 |
|
1061 |
See more in the research paper:
|
1062 |
|
1063 |
https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
|
1064 |
+"""
|
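# --- Editor's illustrative sketch (not part of the original model code) ---
# A toy view of the AltUp predict/correct cycle: keep K parallel copies of the hidden state,
# mix them with a small KxK map ("predict"), run the transformer layer on one copy only, then
# spread the resulting innovation back to every copy ("correct"). The coefficients below are
# random stand-ins for the learned prediction/correction weights.
def _altup_sketch(num_inputs=4):
    import torch

    hidden = torch.randn(num_inputs, 2, 3, 8)                 # [K, batch, tokens, hidden]
    mix = torch.randn(num_inputs, num_inputs) * 0.1
    predictions = torch.einsum("kj,jbtd->kbtd", mix, hidden) + hidden
    activated = predictions[0] + torch.randn(2, 3, 8)          # stand-in for the layer output
    innovation = activated - predictions[0]
    corr = 1.0 + torch.randn(num_inputs) * 0.1                 # per-copy correction coefficient
    return predictions + corr[:, None, None, None] * innovation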
1065 |
|
1066 |
def __init__(self, config: Gemma3nTextConfig):
|
1067 |
super().__init__()
|
|
|
1071 |
self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
|
1072 |
self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
|
1073 |
self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
1074 |
+self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
|
1075 |
|
1076 |
def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
|
1077 |
router_inputs = self.router_norm(x) * self.router_input_scale
|
|
|
1079 |
return torch.tanh(routed.float()).type_as(x)
|
1080 |
|
1081 |
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
1082 |
+"""Predicts the output of a layer using a trainable map.
|
1083 |
|
1084 |
Args:
|
1085 |
+hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
|
1086 |
+stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
|
1087 |
|
1088 |
Returns:
|
1089 |
+A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
|
1090 |
+"""
|
1091 |
modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
|
1092 |
|
1093 |
if self.training and self.config.altup_coef_clip is not None:
|
|
|
1107 |
return predictions.contiguous().type_as(hidden_states)
|
1108 |
|
1109 |
def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
|
1110 |
+"""Corrects the predictions relative to the activated inputs.
|
1111 |
|
1112 |
Args:
|
1113 |
+predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
|
1114 |
+stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
|
1115 |
+activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
|
1116 |
|
1117 |
Returns:
|
1118 |
+A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
|
1119 |
predictions relative to the activated input embeddings.
|
1120 |
+"""
|
1121 |
modalities = self.compute_router_modalities(activated)
|
1122 |
innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
|
1123 |
innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
|
|
|
1125 |
if self.config.altup_coef_clip is not None:
|
1126 |
self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
|
1127 |
|
1128 |
+# all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
|
1129 |
# Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
|
1130 |
# and expand on dim1 for broadcastability
|
1131 |
all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
|
|
|
1136 |
return corrected.contiguous().type_as(activated)
|
1137 |
|
1138 |
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
1139 |
+"""
|
1140 |
+This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
|
1141 |
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
1142 |
+`scale_corrected_output`.
|
1143 |
+"""
|
1144 |
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
1145 |
|
1146 |
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
1147 |
+"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
|
1148 |
return self.forward(corrected)
|
1149 |
|
1150 |
|
1151 |
class Gemma3nTextRotaryEmbedding(nn.Module):
|
1152 |
def __init__(self, config: Gemma3nTextConfig, device=None):
|
1153 |
super().__init__()
|
1154 |
+# BC: "rope_type" was originally "type"
|
1155 |
+if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
1156 |
+self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
1157 |
else:
|
1158 |
+self.rope_type = "default"
|
1159 |
self.max_seq_len_cached = config.max_position_embeddings
|
1160 |
self.original_max_seq_len = config.max_position_embeddings
|
1161 |
|
|
|
1163 |
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
1164 |
|
1165 |
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
1166 |
+self.register_buffer("inv_freq", inv_freq, persistent=False)
|
1167 |
self.original_inv_freq = self.inv_freq
|
1168 |
|
1169 |
@torch.no_grad()
|
|
|
1172 |
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
1173 |
position_ids_expanded = position_ids[:, None, :].float()
|
1174 |
|
1175 |
+device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
1176 |
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
1177 |
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
1178 |
emb = torch.cat((freqs, freqs), dim=-1)
|
|
|
1183 |
|
1184 |
|
1185 |
def rotate_half(x):
|
1186 |
+"""Rotates half the hidden dims of the input."""
|
1187 |
x1 = x[..., : x.shape[-1] // 2]
|
1188 |
x2 = x[..., x.shape[-1] // 2 :]
|
1189 |
return torch.cat((-x2, x1), dim=-1)
|
1190 |
|
1191 |
|
1192 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
1193 |
+"""
|
1194 |
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
1195 |
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
1196 |
+"""
|
1197 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
1198 |
if n_rep == 1:
|
1199 |
return hidden_states
|
|
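# --- Editor's illustrative sketch (not part of the original model code) ---
# Grouped-query attention repeats each key/value head n_rep times. The expand + reshape
# used by repeat_kv matches torch.repeat_interleave on the head axis:
def _repeat_kv_sketch(n_rep=3):
    import torch

    kv = torch.randn(2, 4, 5, 8)                               # [batch, kv_heads, seq, head_dim]
    b, h, s, d = kv.shape
    expanded = kv[:, :, None, :, :].expand(b, h, n_rep, s, d).reshape(b, h * n_rep, s, d)
    assert torch.equal(expanded, torch.repeat_interleave(kv, n_rep, dim=1))
    return expanded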
|
1243 |
position_ids: Optional[torch.Tensor] = None,
|
1244 |
unsqueeze_dim: int = 1,
|
1245 |
):
|
1246 |
+"""Applies Rotary Position Embedding to the query and key tensors.
|
1247 |
|
1248 |
Args:
|
1249 |
+x (`torch.Tensor`): The tensor to embed.
|
1250 |
+cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
1251 |
+sin (`torch.Tensor`): The sine part of the rotary embedding.
|
1252 |
+position_ids (`torch.Tensor`, *optional*):
|
1253 |
Deprecated and unused.
|
1254 |
+unsqueeze_dim (`int`, *optional*, defaults to 1):
|
1255 |
+The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
1256 |
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
1257 |
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
1258 |
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
1259 |
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
1260 |
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
1261 |
Returns:
|
1262 |
+`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
1263 |
+"""
|
1264 |
cos = cos.unsqueeze(unsqueeze_dim)
|
1265 |
sin = sin.unsqueeze(unsqueeze_dim)
|
1266 |
return (x * cos) + (rotate_half(x) * sin)
|
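# --- Editor's illustrative sketch (not part of the original model code) ---
# End-to-end rotary embedding on a toy query tensor: build cos/sin from inverse frequencies
# and positions (base 10000 is assumed here; the model derives its values from the config),
# then rotate with the rotate_half formulation defined above.
def _rope_sketch():
    import torch

    batch, heads, seq, head_dim = 1, 2, 5, 8
    q = torch.randn(batch, heads, seq, head_dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)   # [seq, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                    # [seq, head_dim]
    cos, sin = emb.cos()[None], emb.sin()[None]                # [1, seq, head_dim]
    q_rot = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
    assert q_rot.shape == q.shape
    return q_rot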
1267 |
|
1268 |
|
1269 |
class Gemma3nTextAttention(nn.Module):
|
1270 |
+"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
1271 |
|
1272 |
def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
|
1273 |
super().__init__()
|
1274 |
+self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
1275 |
self.config = config
|
1276 |
self.layer_idx = layer_idx
|
1277 |
+self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
1278 |
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
1279 |
self.attention_dropout = self.config.attention_dropout
|
1280 |
self.is_causal = True
|
|
|
1356 |
if past_key_value is not None:
|
1357 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
1358 |
cache_kwargs = {
|
1359 |
+"sin": sin,
|
1360 |
+"cos": cos,
|
1361 |
+"cache_position": cache_position,
|
1362 |
+"sliding_window": self.sliding_window,
|
1363 |
}
|
1364 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
1365 |
|
1366 |
attention_interface: Callable = eager_attention_forward
|
1367 |
+if self.config._attn_implementation != "eager":
|
1368 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
1369 |
|
1370 |
attn_output, attn_weights = attention_interface(
|
|
|
1407 |
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
|
1408 |
self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
1409 |
|
1410 |
+@deprecate_kwarg("last_cache_position", version="4.53.0")
|
1411 |
def forward(
|
1412 |
self,
|
1413 |
hidden_states: torch.Tensor,
|
|
|
1460 |
if self.config.altup_correct_scale:
|
1461 |
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
1462 |
|
1463 |
+# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
1464 |
first_prediction = self.per_layer_input_gate(first_prediction)
|
1465 |
first_prediction = self.act_fn(first_prediction)
|
1466 |
first_prediction = torch.multiply(first_prediction, per_layer_input)
|
1467 |
|
1468 |
+# per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
1469 |
first_prediction = self.per_layer_projection(first_prediction)
|
1470 |
first_prediction = self.post_per_layer_input_norm(first_prediction)
|
1471 |
corrected_predictions[1:] += first_prediction
|
|
|
1481 |
@auto_docstring
|
1482 |
class Gemma3nPreTrainedModel(PreTrainedModel):
|
1483 |
config_class = Gemma3nConfig
|
1484 |
+base_model_prefix = ""
|
1485 |
supports_gradient_checkpointing = True
|
1486 |
+_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
1487 |
+_skip_keys_device_placement = ["past_key_values"]
|
1488 |
_supports_flash_attn_3 = True
|
1489 |
_supports_flash_attn_2 = True
|
1490 |
_supports_sdpa = True
|
|
|
1495 |
_supports_attention_backend = True
|
1496 |
|
1497 |
def _init_weights(self, module):
|
1498 |
+# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
1499 |
# inference and fine-tuning - so the proper init weights code has been removed
|
1500 |
+std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
1501 |
|
1502 |
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
|
1503 |
module.weight.data.normal_(mean=0.0, std=std)
|
|
|
1518 |
module.correct_output_scale.data.zero_()
|
1519 |
|
1520 |
|
1521 |
+@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
|
1522 |
class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
1523 |
config_class = Gemma3nTextConfig
|
1524 |
|
|
|
1544 |
# defaults should hold values for global RoPE.
|
1545 |
config = copy.deepcopy(config)
|
1546 |
config.rope_theta = config.rope_local_base_freq
|
1547 |
+config.rope_scaling = {"rope_type": "default"}
|
1548 |
self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
|
1549 |
|
1550 |
self.hidden_size = config.hidden_size
|
|
|
1573 |
[nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
|
1574 |
)
|
1575 |
|
1576 |
+self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
|
1577 |
+self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
|
1578 |
|
1579 |
# Initialize weights and apply final processing
|
1580 |
self.post_init()
|
|
|
1601 |
cache_position: Optional[torch.LongTensor] = None,
|
1602 |
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
1603 |
) -> BaseModelOutputWithPast:
|
1604 |
+r"""
|
1605 |
per_layer_inputs (torch.Tensor, *optional*, defaults to None):
|
1606 |
Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
|
1607 |
+"""
|
1608 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1609 |
output_hidden_states = (
|
1610 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
1612 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1613 |
|
1614 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
1615 |
+raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
1616 |
|
1617 |
if self.gradient_checkpointing and self.training and use_cache:
|
1618 |
logger.warning_once(
|
1619 |
+"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
1620 |
)
|
1621 |
use_cache = False
|
1622 |
|
|
|
1640 |
if position_ids is None:
|
1641 |
position_ids = cache_position.unsqueeze(0)
|
1642 |
|
1643 |
+# It may already have been prepared by e.g. `generate`
|
1644 |
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
1645 |
# Prepare mask arguments
|
1646 |
mask_kwargs = {
|
1647 |
+"config": self.config,
|
1647 |
+"input_embeds": inputs_embeds,
|
1648 |
+"attention_mask": attention_mask,
|
1649 |
+"cache_position": cache_position,
|
1650 |
+"past_key_values": past_key_values,
|
1652 |
}
|
1653 |
# Create the masks
|
1654 |
causal_mask_mapping = {
|
1655 |
+"full_attention": create_causal_mask(**mask_kwargs),
|
1656 |
+"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
1657 |
}
|
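# --- Editor's illustrative sketch (not part of the original model code) ---
# The two mask variants prepared above differ only in how far back a token may attend:
# "full_attention" is plain causal, while "sliding_attention" also restricts attention to
# the most recent `window` positions. A minimal boolean version of both:
def _mask_sketch(seq_len=6, window=3):
    import torch

    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]                        # token i sees tokens <= i
    sliding = causal & (idx[:, None] - idx[None, :] < window)    # ... but only `window` back
    return {"full_attention": causal, "sliding_attention": sliding}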
1658 |
|
1659 |
# embed positions
|
|
|
1669 |
|
1670 |
temp_hidden_states = [hidden_states_0]
|
1671 |
for i in range(1, self.config.altup_num_inputs):
|
1672 |
+# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
1673 |
altup_proj = self.altup_projections[i - 1](hidden_states_0)
|
1674 |
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
1675 |
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
|
|
1717 |
target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
|
1718 |
temp_hidden_states = [hidden_states[0]]
|
1719 |
for i in range(1, self.config.altup_num_inputs):
|
1720 |
+# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
1721 |
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
1722 |
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
1723 |
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
|
|
1771 |
)
|
1772 |
|
1773 |
|
1774 |
+@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
|
1775 |
class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
1776 |
+_tied_weights_keys = ["lm_head.weight"]
|
1777 |
+_tp_plan = {"lm_head": "colwise_rep"}
|
1778 |
+_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
1779 |
config_class = Gemma3nTextConfig
|
1780 |
+base_model_prefix = "model"
|
1781 |
+_checkpoint_conversion_mapping = {"model.language_model": "model"}
|
1782 |
|
1783 |
def __init__(self, config: Gemma3nTextConfig):
|
1784 |
super().__init__(config)
|
|
|
1824 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
1825 |
**loss_kwargs,
|
1826 |
) -> CausalLMOutputWithPast:
|
1827 |
+r"""
|
1828 |
+labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1829 |
+Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1830 |
+config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1831 |
+(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1832 |
|
1833 |
Example:
|
1834 |
|
1835 |
+```python
|
1836 |
>>> from transformers import AutoTokenizer, Gemma3nForCausalLM
|
1837 |
|
1838 |
+>>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
|
1839 |
+>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
1840 |
|
1841 |
+>>> prompt = "What is your favorite condiment?"
|
1842 |
+>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1843 |
|
1844 |
>>> # Generate
|
1845 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1846 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1847 |
+"What is your favorite condiment?"
|
1848 |
+```"""
|
1849 |
|
1850 |
+if self.training and self.config._attn_implementation != "eager":
|
1851 |
logger.warning_once(
|
1852 |
+"It is strongly recommended to train Gemma3n models with the `eager` attention implementation "
|
1853 |
+f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
|
1854 |
)
|
1855 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1856 |
output_hidden_states = (
|
|
|
1893 |
|
1894 |
|
1895 |
class Gemma3nMultimodalEmbedder(nn.Module):
|
1896 |
+"""Embeds token ids or soft tokens for multimodal content into language model space."""
|
1897 |
|
1898 |
def __init__(
|
1899 |
self,
|
|
|
1919 |
input_ids: Optional[torch.LongTensor] = None,
|
1920 |
inputs_embeds: Optional[torch.Tensor] = None,
|
1921 |
) -> torch.Tensor:
|
1922 |
+"""Embeds token ids or soft tokens for multimodal content into language model space.
|
1923 |
|
1924 |
Args:
|
1925 |
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
|
1926 |
+`[vocab_offset, vocab_offset + vocab_size)`.
|
1927 |
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
|
1928 |
|
1929 |
Returns:
|
1930 |
+A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
|
1931 |
+"""
|
1932 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
1933 |
+raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
1934 |
|
1935 |
if inputs_embeds is not None:
|
1936 |
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
|
|
1943 |
|
1944 |
|
1945 |
@auto_docstring(
|
1946 |
+custom_intro="""
|
1947 |
The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
|
1948 |
language modeling head.
|
1949 |
+"""
|
1950 |
)
|
1951 |
class Gemma3nModel(Gemma3nPreTrainedModel):
|
1952 |
_checkpoint_conversion_mapping = {}
|
1953 |
+# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
1954 |
accepts_loss_kwargs = False
|
1955 |
|
1956 |
def __init__(self, config: Gemma3nConfig):
|
|
|
1981 |
return self.language_model
|
1982 |
|
1983 |
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
1984 |
+"""
|
1985 |
Projects the last hidden state from the vision model into language model space.
|
1986 |
|
1987 |
Args:
|
1988 |
+pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
1989 |
The tensors corresponding to the input images.
|
1990 |
|
1991 |
Returns:
|
1992 |
+image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
1993 |
+"""
|
1994 |
vision_outputs = self.vision_tower(
|
1995 |
pixel_values=pixel_values, do_pooling=False, return_dict=True
|
1996 |
).last_hidden_state
|
|
|
2024 |
output_hidden_states: Optional[bool] = None,
|
2025 |
**lm_kwargs,
|
2026 |
) -> Gemma3nCausalLMOutputWithPast:
|
2027 |
+r"""
|
2028 |
+labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
2029 |
+Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
2030 |
+config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
2031 |
+(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
2032 |
|
2033 |
Example:
|
2034 |
|
2035 |
+```python
|
2036 |
>>> from PIL import Image
|
2037 |
>>> import requests
|
2038 |
>>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
|
2039 |
|
2040 |
+>>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
|
2041 |
+>>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
|
2042 |
|
2043 |
+>>> prompt = "Where is the cat standing?"
|
2044 |
+>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
2045 |
>>> image = Image.open(requests.get(url, stream=True).raw)
|
2046 |
|
2047 |
+>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
2048 |
|
2049 |
>>> # Generate
|
2050 |
>>> generate_ids = model.generate(**inputs,)
|
2051 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
2052 |
+"Where is the cat standing?\nsnow"
|
2053 |
+```
|
2054 |
+"""
|
2055 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
2056 |
+raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
2057 |
|
2058 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
2059 |
output_hidden_states = (
|
|
|
2103 |
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
2104 |
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
2105 |
raise ValueError(
|
2106 |
+f"Number of images does not match number of special image tokens in the input text. "
|
2107 |
+f"Got {image_tokens_in_text} image tokens in the text and "
|
2108 |
+f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
|
2109 |
)
|
2110 |
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
2111 |
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
|
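# --- Editor's illustrative sketch (not part of the original model code) ---
# masked_scatter fills the embedding positions flagged by the special-token mask with the
# projected image (or audio) features, in order. A minimal version of the merge above:
def _masked_scatter_sketch():
    import torch

    inputs_embeds = torch.zeros(1, 5, 4)                       # [batch, seq, hidden]
    special_mask = torch.tensor([[0, 1, 1, 0, 0]], dtype=torch.bool)
    special_mask = special_mask.unsqueeze(-1).expand_as(inputs_embeds)
    image_features = torch.ones(2, 4)                          # two soft tokens
    merged = inputs_embeds.masked_scatter(special_mask, image_features)
    assert merged[0, 1].eq(1).all() and merged[0, 3].eq(0).all()
    return merged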
|
2140 |
if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
|
2141 |
audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
|
2142 |
raise ValueError(
|
2143 |
+f"Number of audio input features does not match number of special audio tokens in the input text. "
|
2144 |
+f"Got {audio_tokens_in_text} audio tokens in the text and "
|
2145 |
+f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
|
2146 |
)
|
2147 |
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
2148 |
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
|
|
|
2174 |
def get_audio_features(
|
2175 |
self, input_features: torch.Tensor, input_features_mask: torch.Tensor
|
2176 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
2177 |
+"""
|
2178 |
Projects the last hidden state from the audio encoder into language model space.
|
2179 |
|
2180 |
Args:
|
2181 |
+input_features (`torch.FloatTensor` of shape `(num_images, seq_length, num_features)`):
|
2182 |
The tensors corresponding to the input audio.
|
2183 |
+input_features_mask (`torch.FloatTensor` of shape `(num_images, seq_length)`):
|
2184 |
The attention mask for the input audio.
|
2185 |
|
2186 |
Returns:
|
2187 |
+audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_images, audio_length, embed_dim)`.
|
2188 |
+"""
|
2189 |
audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
|
2190 |
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
|
2191 |
|
2192 |
|
2193 |
@auto_docstring(
|
2194 |
+custom_intro="""
|
2195 |
The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
2196 |
head.
|
2197 |
+"""
|
2198 |
)
|
2199 |
class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
2200 |
_checkpoint_conversion_mapping = {}
|
2201 |
+_tied_weights_keys = ["lm_head.weight"]
|
2202 |
+base_model_prefix = "model"
|
2203 |
|
2204 |
def __init__(self, config: Gemma3nConfig):
|
2205 |
super().__init__(config)
|
|
|
2239 |
|
2240 |
@property
|
2241 |
def multi_modal_projector(self):
|
2242 |
+raise AttributeError("Use embed_vision instead of multi_modal_projector.")
|
2243 |
|
2244 |
@can_return_tuple
|
2245 |
@auto_docstring
|
|
|
2262 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
2263 |
**lm_kwargs,
|
2264 |
) -> Gemma3nCausalLMOutputWithPast:
|
2265 |
+r"""
|
2266 |
input_features (torch.Tensor, *optional*, defaults to None):
|
2267 |
The audio inputs to be encoded.
|
2268 |
input_features_mask (torch.Tensor, *optional*, defaults to None):
|
2269 |
The attention mask for the input audio.
|
2270 |
+labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
2271 |
+Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
2272 |
+config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
|
2273 |
ignored (masked), the loss is only computed for the tokens with labels in
|
2274 |
+`[0, ..., config.text_config.vocab_size]`.
|
2275 |
|
2276 |
Example:
|
2277 |
|
2278 |
+```python
|
2279 |
>>> from PIL import Image
|
2280 |
>>> import requests
|
2281 |
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
2282 |
|
2283 |
+>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
2284 |
+>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
2285 |
|
2286 |
>>> messages = [
|
2287 |
... {
|
2288 |
+... "role": "system",
|
2289 |
+... "content": [
|
2290 |
+... {"type": "text", "text": "You are a helpful assistant."}
|
2291 |
... ]
|
2292 |
... },
|
2293 |
... {
|
2294 |
+... "role": "user", "content": [
|
2295 |
+... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
2296 |
+... {"type": "text", "text": "Where is the cat standing?"},
|
2297 |
... ]
|
2298 |
... },
|
2299 |
... ]
|
|
|
2302 |
... messages,
|
2303 |
... tokenize=True,
|
2304 |
... return_dict=True,
|
2305 |
+... return_tensors="pt",
|
2306 |
... add_generation_prompt=True
|
2307 |
... )
|
2308 |
>>> # Generate
|
2309 |
>>> generate_ids = model.generate(**inputs)
|
2310 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
2311 |
+"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
2312 |
+```
|
2313 |
+"""
|
2314 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
2315 |
output_hidden_states = (
|
2316 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
2393 |
labels=None,
|
2394 |
**kwargs,
|
2395 |
):
|
2396 |
+# Overwritten -- custom `position_ids` and `pixel_values` handling
|
2397 |
model_inputs = super().prepare_inputs_for_generation(
|
2398 |
input_ids,
|
2399 |
past_key_values=past_key_values,
|
|
|
2407 |
**kwargs,
|
2408 |
)
|
2409 |
|
2410 |
+# If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
|
2411 |
# tokens anymore. Otherwise multimodal inputs should be passed to model.
|
2412 |
# NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
|
2413 |
if cache_position[0] == 0:
|
2414 |
+model_inputs["pixel_values"] = pixel_values
|
2415 |
+model_inputs["input_features"] = input_features
|
2416 |
+model_inputs["input_features_mask"] = input_features_mask
|
2417 |
|
2418 |
return model_inputs
|
2419 |
|
|
|
2423 |
|
2424 |
|
2425 |
__all__ = [
|
2426 |
+"Gemma3nAudioEncoder",
|
2427 |
+"Gemma3nForCausalLM",
|
2428 |
+"Gemma3nForConditionalGeneration",
|
2429 |
+"Gemma3nModel",
|
2430 |
+"Gemma3nPreTrainedModel",
|
2431 |
+"Gemma3nTextModel",
|
2432 |
]
|