Molbap HF Staff committed on
Commit d793afd · verified
Parent(s): e221c2d

Upload 2 files

Files changed (2)
  1. dependencies.py +0 -0
  2. main_code.py +253 -253
dependencies.py CHANGED
The diff for this file is too large to render. See raw diff
 
main_code.py CHANGED
@@ -8,14 +8,14 @@
8
  # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
9
  #
10
  #
11
- # Licensed under the Apache License, Version 2.0 (the "License");
12
  # you may not use this file except in compliance with the License.
13
  # You may obtain a copy of the License at
14
  #
15
  # http://www.apache.org/licenses/LICENSE-2.0
16
  #
17
  # Unless required by applicable law or agreed to in writing, software
18
- # distributed under the License is distributed on an "AS IS" BASIS,
19
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
  # See the License for the specific language governing permissions and
21
  # limitations under the License.
@@ -56,25 +56,25 @@ logger = logging.get_logger(__name__)
56
 
57
  @dataclass
58
  @auto_docstring(
59
- custom_intro="""
60
  Base class for Gemma3n outputs, with hidden states and attentions.
61
- """
62
  )
63
  class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
64
- r"""
65
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
66
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
67
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`
68
 
69
  Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
70
- `past_key_values` input) to speed up sequential decoding.
71
- image_hidden_states (`torch.FloatTensor`, *optional*):
72
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
73
  image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
74
- audio_hidden_states (`torch.FloatTensor`, *optional*):
75
- A `torch.FloatTensor` of size `(batch_size, num_audio, sequence_length, hidden_size)`.
76
  audio_hidden_states of the model produced by the audio encoder after projecting the last hidden state.
77
- """
78
 
79
  image_hidden_states: Optional[torch.FloatTensor] = None
80
 
@@ -83,29 +83,29 @@ class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
83
 
84
  @dataclass
85
  @auto_docstring(
86
- custom_intro="""
87
  Base class for Gemma3n causal language model (or autoregressive) outputs.
88
- """
89
  )
90
  class Gemma3nCausalLMOutputWithPast(ModelOutput):
91
- r"""
92
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
93
  Language modeling loss (for next-token prediction).
94
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
95
  Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
96
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
97
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
98
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`
99
 
100
  Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
101
- `past_key_values` input) to speed up sequential decoding.
102
- image_hidden_states (`torch.FloatTensor`, *optional*):
103
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
104
  image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
105
- audio_hidden_states (`torch.FloatTensor`, *optional*):
106
- A `torch.FloatTensor` of size `(batch_size, num_audio, sequence_length, hidden_size)`.
107
  audio_hidden_states of the model produced by the audio encoder after projecting the last hidden state.
108
- """
109
 
110
  loss: Optional[torch.FloatTensor] = None
111
  logits: Optional[torch.FloatTensor] = None
@@ -126,7 +126,7 @@ class Gemma3nRMSNorm(nn.Module):
126
  if self.with_scale:
127
  self.weight = nn.Parameter(torch.ones(dim))
128
  else:
129
- self.register_buffer("weight", torch.tensor(1.0), persistent=False)
130
 
131
  def _norm(self, x):
132
  return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
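For reference, `_norm` is plain RMS normalization over the last dimension. A minimal standalone sketch (toy shapes, not part of this file):

```python
import torch

x = torch.randn(2, 4, 8)  # (batch, seq, hidden)
eps = 1e-6

# Divide by the root-mean-square of the last dimension.
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
normed = x / rms  # matches Gemma3nRMSNorm._norm(x)
```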
@@ -138,7 +138,7 @@ class Gemma3nRMSNorm(nn.Module):
138
  return output.type_as(x)
139
 
140
  def extra_repr(self):
141
- return f"{tuple(self.weight.shape)}, eps={self.eps}"
142
 
143
 
144
  # ==== Audio Encoder ====
@@ -163,7 +163,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
163
  log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
164
  inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
165
  self.register_buffer(
166
- "inv_timescales",
167
  inv_timescales.float().unsqueeze(0).unsqueeze(0),
168
  persistent=False,
169
  )
@@ -184,7 +184,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
184
  key_context_size: int,
185
  max_span_plus_1: int,
186
  ) -> torch.Tensor:
187
- """Performs the relative shift.
188
 
189
  Args:
190
  term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
@@ -193,7 +193,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
193
 
194
  Returns:
195
  Tensor of shape [B, N, U, W, C].
196
- """
197
  # term_bd_before_shift shape: [B, N, U, W, F_span]
198
  # Target shape after shift: [B, N, U, W, C]
199
 
@@ -209,7 +209,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
209
  term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
210
  # Shape after pad: [B, N, U, W, C+1]
211
 
212
- # Reshape for slicing (emulating JAX's behavior)
213
  # [B, N, U, W * (C+1)]
214
  term_bd_reshaped = term_bd_padded.reshape(
215
  (
@@ -271,7 +271,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
271
  term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
272
 
273
  # term_bd: Query-Position interaction
274
- # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
275
  # queries shape: [B, U, W, N, H]
276
  # sin_emb shape: [F, N, H]
277
  # Target output shape: [B, N, U, W, F]
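The decomposition these comments describe can be sanity-checked against the original einsum on toy shapes. A sketch with made-up dimensions (the permute/matmul split below is one equivalent formulation, not necessarily the exact lines used in this method):

```python
import torch

B, U, W, N, H, F_span = 2, 3, 4, 2, 5, 7
queries = torch.randn(B, U, W, N, H)
sin_emb = torch.randn(F_span, N, H)

# Reference: the original JAX-style einsum.
ref = torch.einsum("buwnh,fnh->bnuwf", queries, sin_emb)

# Equivalent permute + batched matmul.
q = queries.permute(0, 3, 1, 2, 4)         # [B, N, U, W, H]
s = sin_emb.permute(1, 2, 0).unsqueeze(1)  # [N, 1, H, F]
out = torch.matmul(q, s)                   # broadcasts to [B, N, U, W, F]
assert torch.allclose(ref, out, atol=1e-5)
```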
@@ -338,7 +338,7 @@ class Gemma3nAudioAttention(nn.Module):
338
 
339
  q_scale = self.head_dim**-0.5
340
  r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
341
- self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
342
 
343
  lower_causal_mask = torch.tril(
344
  torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
@@ -350,10 +350,10 @@ class Gemma3nAudioAttention(nn.Module):
350
  )
351
  local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
352
  local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
353
- self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
354
 
355
  self.register_buffer(
356
- "softcap",
357
  torch.tensor(self.attention_logits_soft_cap).float(),
358
  persistent=False,
359
  )
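The `softcap` buffer registered here is the usual Gemma-style logit soft cap, squashing attention logits into `(-softcap, softcap)` before masking and softmax. A sketch of the transform with illustrative values:

```python
import torch

softcap = torch.tensor(50.0)            # stand-in for attention_logits_soft_cap
logits = torch.randn(2, 8, 16) * 100.0  # unbounded raw logits

capped = torch.tanh(logits / softcap) * softcap  # bounded to (-50, 50)
```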
@@ -366,7 +366,7 @@ class Gemma3nAudioAttention(nn.Module):
366
  return x
367
 
368
  def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
369
- """Turns a sequence to non overlapping blocks.
370
 
371
  Args:
372
  hidden_states: a tensor of [batch, time, ...].
@@ -375,7 +375,7 @@ class Gemma3nAudioAttention(nn.Module):
375
  A tensor of [batch, num_blocks, block_size, ...], with necessary
376
  paddings,
377
  where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
378
- """
379
  shape = hidden_states.shape
380
  b, t = shape[:2]
381
  num_blocks = (t + self.chunk_size - 1) // self.chunk_size
@@ -388,7 +388,7 @@ class Gemma3nAudioAttention(nn.Module):
388
  return hidden_states
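A self-contained sketch of the same blocking logic (assuming zero padding at the end of the time axis, as the surrounding code provides via `_pad_dim1`):

```python
import torch

def convert_to_block(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Pad time (dim 1) to a multiple of chunk_size, then split into blocks."""
    b, t = x.shape[:2]
    num_blocks = (t + chunk_size - 1) // chunk_size
    pad = num_blocks * chunk_size - t
    if pad > 0:
        # Pad only the end of dim 1; F.pad pads from the last dim backwards.
        x = torch.nn.functional.pad(x, (0, 0) * (x.dim() - 2) + (0, pad))
    return x.reshape(b, num_blocks, chunk_size, *x.shape[2:])

convert_to_block(torch.randn(2, 10, 4), chunk_size=4).shape  # [2, 3, 4, 4]
```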
389
 
390
  def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
391
- """Extracts temporal context for every block.
392
 
393
  Args:
394
  hidden_states: a tensor of [batch, time, ...].
@@ -400,11 +400,11 @@ class Gemma3nAudioAttention(nn.Module):
400
  and output[:, i, ...] are x[:, start-left_context:end+right_context,
401
  ...],
402
  start = i * block_size, end = (i + 1) * block_size.
403
- """
404
  pad_left = self.max_past_horizon
405
- # The JAX equivalent padding for signal.frame with pad_mode='valid' is
406
  # (left_context, right_context + block_size - 1) on the time dimension.
407
- # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
408
  # or (pad_dim_start, pad_dim_end) if two are given.
409
  # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
410
  # or dim 1 (time for [B,T]).
@@ -424,7 +424,7 @@ class Gemma3nAudioAttention(nn.Module):
424
 
425
  # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
426
  # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
427
- # We want to match JAX's typical output for such operations which might be
428
  # [B, num_blocks, frame_len, N, H] if N, H are present.
429
  # The relative_position_embedding expects keys as [B, U, C, N, H].
430
  # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
@@ -436,7 +436,7 @@ class Gemma3nAudioAttention(nn.Module):
436
  return x_unfolded.contiguous()
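A sketch of the same framing with `Tensor.unfold`, assuming zero pads; with `frame_len = left + block + right` and a step equal to the block size, this yields exactly `ceil(T / block_size)` windows:

```python
import torch

def extract_block_context(x, block_size, left_context, right_context):
    """For block i, gather x[:, i*block - left : (i+1)*block + right]."""
    pad = (0, 0) * (x.dim() - 2) + (left_context, right_context + block_size - 1)
    x = torch.nn.functional.pad(x, pad)
    frame_len = block_size + left_context + right_context
    # unfold(dim, size, step) places the window on a new trailing dim.
    x = x.unfold(1, frame_len, block_size)
    return x.movedim(-1, 2).contiguous()  # [B, U, C, ...] like the keys layout

extract_block_context(torch.randn(2, 10, 3, 4), 4, 2, 1).shape  # [2, 3, 7, 3, 4]
```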
437
 
438
  def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
439
- # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
440
  qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
441
  query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
442
  key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
@@ -472,7 +472,7 @@ class Gemma3nAudioAttention(nn.Module):
472
  extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
473
  batch_size, num_query_blocks, self.context_size
474
  )
475
- # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
476
  # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
477
  # but for the mask case, this should hold.
478
  if extracted_valid_mask_blocks.shape != (
@@ -481,9 +481,9 @@ class Gemma3nAudioAttention(nn.Module):
481
  self.context_size,
482
  ):
483
  raise ValueError(
484
- "Shape of extracted_valid_mask_blocks"
485
- f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
486
- f" {num_query_blocks}, {self.context_size}) after potential reshape."
487
  )
488
 
489
  # 3. Expand dimensions for broadcasting with logits and causal mask.
@@ -518,7 +518,7 @@ class Gemma3nAudioAttention(nn.Module):
518
  logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
519
  probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
520
 
521
- # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
522
  b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
523
  h_dim = value_blocks.shape[-1]
524
  prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
@@ -539,21 +539,21 @@ class Gemma3nAudioAttention(nn.Module):
539
 
540
 
541
  class Gemma3nAudioCumulativeGroupNorm(nn.Module):
542
- """Applies Group Normalization cumulatively over the time dimension.
543
 
544
  This layer normalizes the input by calculating the mean and variance
545
  cumulatively over the time dimension (dim 1). The statistics are computed
546
- over all feature dimensions (specified by `feature_dims` and `num_channels`)
547
- for elements marked as valid by the optional `mask`.
548
 
549
- If a `mask` is provided (True for valid, False for invalid/padded),
550
  invalid time steps do not contribute to the statistics calculation, and
551
  their corresponding output values are zeroed out.
552
 
553
  Scale and bias, if enabled, are applied per-channel (last dimension).
554
- This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
555
- and `cumulative=True`.
556
- """
557
 
558
  def __init__(
559
  self,
@@ -574,19 +574,19 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
574
  self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
575
 
576
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
577
- """Applies cumulative group norm, optionally using a mask.
578
 
579
  Args:
580
  hidden_states: Input tensor, shape [B, T, *feature_dims, C].
581
 
582
  Returns:
583
  Normalized tensor with the same shape as x.
584
- """
585
  expected_input_suffix = self.feature_dims + (self.num_channels,)
586
  if hidden_states.shape[2:] != expected_input_suffix:
587
  raise ValueError(
588
- f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
589
- f" suffix (feature_dims + num_channels) {expected_input_suffix}"
590
  )
591
 
592
  input_dtype = hidden_states.dtype
@@ -594,7 +594,7 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
594
  calc_dtype = torch.float32
595
  x_calc = hidden_states.to(calc_dtype)
596
 
597
- # Prepare a broadcastable mask (`mask_calc`).
598
  # If no mask is provided, treat all elements as valid
599
  # (mask_calc is all ones).
600
  # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
@@ -607,7 +607,7 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
607
  cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
608
 
609
  # 3. Count of valid elements in the normalization group at each time step.
610
- # (A "group" here consists of all features at a given Batch, Time).
611
  elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
612
  # 4. Cumulative count of valid elements over time.
613
  cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
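As a standalone illustration of cumulative statistics via `cumsum` (a sketch of the idea with all steps treated as valid, using the E[x²] − E[x]² variance form rather than this module's exact sequence of ops):

```python
import torch

x = torch.randn(2, 5, 3, 8)  # [B, T, feature, C]
axes = (2, 3)

# Running sums over time of per-step group sums.
cum_sum = torch.cumsum(x.sum(dim=axes, keepdim=True), dim=1)
cum_sq = torch.cumsum((x * x).sum(dim=axes, keepdim=True), dim=1)
count = x.shape[2] * x.shape[3]
cum_count = torch.arange(1, x.shape[1] + 1).view(1, -1, 1, 1) * count

cum_mean = cum_sum / cum_count
cum_var = cum_sq / cum_count - cum_mean**2
normed = (x - cum_mean) * torch.rsqrt(cum_var + 1e-6)
```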
@@ -648,11 +648,11 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
648
 
649
 
650
  class Gemma3nAudioSSCPConvBlock(nn.Module):
651
- """A single convolution block for the SubSampleConvProjection.
652
 
653
  This block consists of a 2D convolution, followed by CumulativeGroupNorm,
654
  and a ReLU activation. It handles manual padding for the convolution.
655
- """
656
 
657
  def __init__(
658
  self,
@@ -665,7 +665,7 @@ class Gemma3nAudioSSCPConvBlock(nn.Module):
665
  self.config = config
666
  self.manual_padding = manual_padding
667
 
668
- # in_channels is 1 for the first block, or C_out from previous block's conv
669
  in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
670
  out_channels = self.config.sscp_conv_channel_size[idx]
671
  kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
@@ -701,7 +701,7 @@ class Gemma3nAudioSSCPConvBlock(nn.Module):
701
  # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
702
  # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
703
  # F.pad applies to last two dims: F_in then T_in
704
- audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
705
  # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
706
  # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
707
  audio_encodings_conv = self.conv(audio_encodings_padded)
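The asymmetric padding is the load-bearing detail here; a standalone sketch (toy sizes) of reverse-causal time padding plus 'same'-style frequency padding for a 3x3 kernel:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 10, 12)  # [B, C_in, T, F]
kernel_h, kernel_w = 3, 3

# F.pad orders pads last-dim-first: (F_left, F_right, T_top, T_bottom).
# Reverse-causal time: all padding after the sequence, none before.
manual_padding = (1, 1, 0, kernel_h - 1)
x_padded = F.pad(x, manual_padding, mode="constant", value=0.0)

conv = torch.nn.Conv2d(1, 4, (kernel_h, kernel_w), stride=(2, 2))
y = conv(x_padded)  # [1, 4, 5, 6]
```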
@@ -728,7 +728,7 @@ class Gemma3nAudioSubSampleConvProjection(nn.Module):
728
  stride_h, stride_w = config.sscp_conv_stride_size[i]
729
 
730
  # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
731
- # JAX 'reverse_causal' padding is (0, kernel_size - 1)
732
  pad_t_top = 0
733
  pad_t_bottom = kernel_h - 1
734
 
@@ -736,7 +736,7 @@ class Gemma3nAudioSubSampleConvProjection(nn.Module):
736
  # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
737
  # and the successful test configuration.
738
  # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
739
- # to match generic JAX 'SAME' behavior if it differs.
740
  pad_f_left = 1
741
  pad_f_right = 1
742
 
@@ -792,7 +792,7 @@ class Gemma3nAudioConformerAttention(nn.Module):
792
  super().__init__()
793
  self.config = config
794
  self.post_in_features = self.config.hidden_size
795
- self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
796
  self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
797
  self.attn = Gemma3nAudioAttention(config)
798
  self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
@@ -820,7 +820,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
820
  super().__init__()
821
  self.config = config
822
 
823
- self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
824
 
825
  self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
826
  self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
@@ -856,7 +856,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
856
  groups=self.config.hidden_size, # Depthwise
857
  bias=False,
858
  )
859
- self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
860
  self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
861
  self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
862
 
@@ -892,7 +892,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
892
  self.attention = Gemma3nAudioConformerAttention(self.config)
893
  self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
894
  self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
895
- self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
896
  self.norm = Gemma3nRMSNorm(self.config.hidden_size)
897
 
898
  def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
@@ -911,11 +911,11 @@ class Gemma3nAudioConformerBlock(nn.Module):
911
 
912
 
913
  class Gemma3nAudioEncoder(PreTrainedModel):
914
- """An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
915
 
916
  config_class = Gemma3nAudioConfig
917
 
918
- main_input_name = "audio_mel"
919
 
920
  def __init__(self, config: Gemma3nAudioConfig):
921
  super().__init__(config)
@@ -929,7 +929,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
929
  def forward(
930
  self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
931
  ) -> tuple[torch.Tensor, torch.BoolTensor]:
932
- """Encodes a batch of MELs.
933
 
934
  Args:
935
  audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
@@ -937,10 +937,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
937
 
938
  Returns:
939
  audio_encodings: a torch.Tensor of shape
940
- `[batch_size, self.config.audio_soft_tokens_per_image,
941
- self.config.audio_config.hidden_size]`
942
  audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
943
- """
944
  audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
945
 
946
  # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
@@ -983,20 +983,20 @@ class Gemma3nAudioEncoder(PreTrainedModel):
983
 
984
 
985
  class Gemma3nTextScaledWordEmbedding(nn.Embedding):
986
- """
987
- This module overrides nn.Embedding's forward by multiplying the embeddings by `embed_scale`.
988
- """
989
 
990
  def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
991
  super().__init__(num_embeddings, embedding_dim, padding_idx)
992
- self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
993
 
994
  def forward(self, input_ids: torch.Tensor):
995
  return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
996
 
997
 
998
  class Gemma3nTextLaurelBlock(nn.Module):
999
- """Learned Augmented Residual Layer"""
1000
 
1001
  def __init__(self, config: Gemma3nTextConfig):
1002
  super().__init__()
@@ -1052,16 +1052,16 @@ class Gemma3nTextMLP(nn.Module):
1052
 
1053
 
1054
  class Gemma3nTextAltUp(nn.Module):
1055
- """Alternating Updates (AltUp)
1056
 
1057
- The AltUp module wraps transformer layers. The `predict` step modifies the
1058
- input to the transformer layer, and the `correct` step propagates the output
1059
  of the transformer layer to the sparsely updated dimensions.
1060
 
1061
  See more in the research paper:
1062
 
1063
  https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
1064
- """
1065
 
1066
  def __init__(self, config: Gemma3nTextConfig):
1067
  super().__init__()
@@ -1071,7 +1071,7 @@ class Gemma3nTextAltUp(nn.Module):
1071
  self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
1072
  self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
1073
  self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
1074
- self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
1075
 
1076
  def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
1077
  router_inputs = self.router_norm(x) * self.router_input_scale
@@ -1079,15 +1079,15 @@ class Gemma3nTextAltUp(nn.Module):
1079
  return torch.tanh(routed.float()).type_as(x)
1080
 
1081
  def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
1082
- """Predicts the output of a layer using a trainable map.
1083
 
1084
  Args:
1085
- hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
1086
- stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
1087
 
1088
  Returns:
1089
- A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
1090
- """
1091
  modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
1092
 
1093
  if self.training and self.config.altup_coef_clip is not None:
@@ -1107,17 +1107,17 @@ class Gemma3nTextAltUp(nn.Module):
1107
  return predictions.contiguous().type_as(hidden_states)
1108
 
1109
  def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
1110
- """Corrects the predictions relative to the
1111
 
1112
  Args:
1113
- predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
1114
- stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
1115
- activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
1116
 
1117
  Returns:
1118
- A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
1119
  predictions relative to the activated input embeddings.
1120
- """
1121
  modalities = self.compute_router_modalities(activated)
1122
  innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
1123
  innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
@@ -1125,7 +1125,7 @@ class Gemma3nTextAltUp(nn.Module):
1125
  if self.config.altup_coef_clip is not None:
1126
  self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
1127
 
1128
- # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
1129
  # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
1130
  # and expand on dim1 for broadcastability
1131
  all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
@@ -1136,26 +1136,26 @@ class Gemma3nTextAltUp(nn.Module):
1136
  return corrected.contiguous().type_as(activated)
1137
 
1138
  def forward(self, corrected: torch.Tensor) -> torch.Tensor:
1139
- """
1140
- This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
1141
  (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
1142
- `scale_corrected_output`
1143
- """
1144
  return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
1145
 
1146
  def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
1147
- """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
1148
  return self.forward(corrected)
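Putting the docstrings above together, the core of `predict` is a per-token mixing of the stacked AltUp streams. A toy sketch of that contraction (hypothetical shapes; the residual add mirrors the description above, not the exact lines of this class):

```python
import torch

num_inputs, B, T, D = 4, 2, 3, 8
hidden_states = torch.randn(num_inputs, B, T, D)       # stacked altup streams
all_coefs = torch.randn(B, T, num_inputs, num_inputs)  # router output per token

# "...p,pi->...i" analogue: mix the p input streams into i predicted streams.
predictions = torch.einsum("pbtd,btpi->ibtd", hidden_states, all_coefs)
predictions = predictions + hidden_states              # residual connection
```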
1149
 
1150
 
1151
  class Gemma3nTextRotaryEmbedding(nn.Module):
1152
  def __init__(self, config: Gemma3nTextConfig, device=None):
1153
  super().__init__()
1154
- # BC: "rope_type" was originally "type"
1155
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1156
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
1157
  else:
1158
- self.rope_type = "default"
1159
  self.max_seq_len_cached = config.max_position_embeddings
1160
  self.original_max_seq_len = config.max_position_embeddings
1161
 
@@ -1163,7 +1163,7 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
1163
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
1164
 
1165
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
1166
- self.register_buffer("inv_freq", inv_freq, persistent=False)
1167
  self.original_inv_freq = self.inv_freq
1168
 
1169
  @torch.no_grad()
@@ -1172,7 +1172,7 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
1172
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1173
  position_ids_expanded = position_ids[:, None, :].float()
1174
 
1175
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1176
  with torch.autocast(device_type=device_type, enabled=False): # Force float32
1177
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1178
  emb = torch.cat((freqs, freqs), dim=-1)
@@ -1183,17 +1183,17 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
1183
 
1184
 
1185
  def rotate_half(x):
1186
- """Rotates half the hidden dims of the input."""
1187
  x1 = x[..., : x.shape[-1] // 2]
1188
  x2 = x[..., x.shape[-1] // 2 :]
1189
  return torch.cat((-x2, x1), dim=-1)
1190
 
1191
 
1192
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1193
- """
1194
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
1195
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
1196
- """
1197
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
1198
  if n_rep == 1:
1199
  return hidden_states
@@ -1243,38 +1243,38 @@ def apply_rotary_pos_emb(
1243
  position_ids: Optional[torch.Tensor] = None,
1244
  unsqueeze_dim: int = 1,
1245
  ):
1246
- """Applies Rotary Position Embedding to the query and key tensors.
1247
 
1248
  Args:
1249
- x (`torch.Tensor`): The tensor to embed.
1250
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
1251
- sin (`torch.Tensor`): The sine part of the rotary embedding.
1252
- position_ids (`torch.Tensor`, *optional*):
1253
  Deprecated and unused.
1254
- unsqueeze_dim (`int`, *optional*, defaults to 1):
1255
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
1256
  sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
1257
  that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
1258
  k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
1259
  cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
1260
  the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
1261
  Returns:
1262
- `torch.Tensor`: the input tensor rotated using the Rotary Position Embedding.
1263
- """
1264
  cos = cos.unsqueeze(unsqueeze_dim)
1265
  sin = sin.unsqueeze(unsqueeze_dim)
1266
  return (x * cos) + (rotate_half(x) * sin)
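A short, runnable use of the two helpers above, building the cos/sin tables the way the rotary embedding module does (duplicated frequencies via `cat`), with toy sizes and an illustrative base of 10_000:

```python
import torch

B, n_heads, T, D = 1, 2, 4, 8
q = torch.randn(B, n_heads, T, D)

inv_freq = 1.0 / (10_000 ** (torch.arange(0, D, 2).float() / D))
freqs = torch.outer(torch.arange(T).float(), inv_freq)  # [T, D/2]
emb = torch.cat((freqs, freqs), dim=-1)                 # [T, D]
cos, sin = emb.cos()[None], emb.sin()[None]             # [1, T, D]

q_rot = apply_rotary_pos_emb(q, cos, sin)  # unsqueeze_dim=1 broadcasts over heads
```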
1267
 
1268
 
1269
  class Gemma3nTextAttention(nn.Module):
1270
- """Multi-headed attention from 'Attention Is All You Need' paper"""
1271
 
1272
  def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
1273
  super().__init__()
1274
- self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
1275
  self.config = config
1276
  self.layer_idx = layer_idx
1277
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
1278
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
1279
  self.attention_dropout = self.config.attention_dropout
1280
  self.is_causal = True
@@ -1356,15 +1356,15 @@ class Gemma3nTextAttention(nn.Module):
1356
  if past_key_value is not None:
1357
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
1358
  cache_kwargs = {
1359
- "sin": sin,
1360
- "cos": cos,
1361
- "cache_position": cache_position,
1362
- "sliding_window": self.sliding_window,
1363
  }
1364
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1365
 
1366
  attention_interface: Callable = eager_attention_forward
1367
- if self.config._attn_implementation != "eager":
1368
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
1369
 
1370
  attn_output, attn_weights = attention_interface(
@@ -1407,7 +1407,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
1407
  self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
1408
  self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
1409
 
1410
- @deprecate_kwarg("last_cache_position", version="4.53.0")
1411
  def forward(
1412
  self,
1413
  hidden_states: torch.Tensor,
@@ -1460,12 +1460,12 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
1460
  if self.config.altup_correct_scale:
1461
  first_prediction = self.altup.scale_corrected_output(first_prediction)
1462
 
1463
- # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
1464
  first_prediction = self.per_layer_input_gate(first_prediction)
1465
  first_prediction = self.act_fn(first_prediction)
1466
  first_prediction = torch.multiply(first_prediction, per_layer_input)
1467
 
1468
- # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
1469
  first_prediction = self.per_layer_projection(first_prediction)
1470
  first_prediction = self.post_per_layer_input_norm(first_prediction)
1471
  corrected_predictions[1:] += first_prediction
@@ -1481,10 +1481,10 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
1481
  @auto_docstring
1482
  class Gemma3nPreTrainedModel(PreTrainedModel):
1483
  config_class = Gemma3nConfig
1484
- base_model_prefix = ""
1485
  supports_gradient_checkpointing = True
1486
- _no_split_modules = ["Gemma3nTextDecoderLayer"]
1487
- _skip_keys_device_placement = ["past_key_values"]
1488
  _supports_flash_attn_3 = True
1489
  _supports_flash_attn_2 = True
1490
  _supports_sdpa = True
@@ -1495,9 +1495,9 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
1495
  _supports_attention_backend = True
1496
 
1497
  def _init_weights(self, module):
1498
- # important: this ported version of Gemma 3n isn't meant for training from scratch - only
1499
  # inference and fine-tuning - so the proper init weights code has been removed
1500
- std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
1501
 
1502
  if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
1503
  module.weight.data.normal_(mean=0.0, std=std)
@@ -1518,7 +1518,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
1518
  module.correct_output_scale.data.zero_()
1519
 
1520
 
1521
- @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
1522
  class Gemma3nTextModel(Gemma3nPreTrainedModel):
1523
  config_class = Gemma3nTextConfig
1524
 
@@ -1544,7 +1544,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1544
  # defaults should hold values for global RoPE.
1545
  config = copy.deepcopy(config)
1546
  config.rope_theta = config.rope_local_base_freq
1547
- config.rope_scaling = {"rope_type": "default"}
1548
  self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
1549
 
1550
  self.hidden_size = config.hidden_size
@@ -1573,8 +1573,8 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1573
  [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
1574
  )
1575
 
1576
- self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
1577
- self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
1578
 
1579
  # Initialize weights and apply final processing
1580
  self.post_init()
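The two buffers registered above set up how per-layer embeddings are later folded into each decoder layer's input: roughly, the projection is scaled by `hidden_size**-0.5` and the combined sum by `1/sqrt(2)` so its magnitude stays comparable to either term. A hedged sketch of that combination (illustrative only, not the exact lines from this file):

```python
import torch

hidden_size = 16
per_layer_projection_scale = torch.tensor(hidden_size**-0.5)
per_layer_input_scale = torch.rsqrt(torch.tensor(2.0))

projected = torch.randn(2, 5, hidden_size)        # per-layer projection output
per_layer_inputs = torch.randn(2, 5, hidden_size)

mixed = (projected * per_layer_projection_scale + per_layer_inputs) * per_layer_input_scale
```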
@@ -1601,10 +1601,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1601
  cache_position: Optional[torch.LongTensor] = None,
1602
  **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1603
  ) -> BaseModelOutputWithPast:
1604
- r"""
1605
  per_layer_inputs (torch.Tensor, *optional*, defaults to None):
1606
  Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
1607
- """
1608
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1609
  output_hidden_states = (
1610
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1612,11 +1612,11 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1612
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1613
 
1614
  if (input_ids is None) ^ (inputs_embeds is not None):
1615
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1616
 
1617
  if self.gradient_checkpointing and self.training and use_cache:
1618
  logger.warning_once(
1619
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1620
  )
1621
  use_cache = False
1622
 
@@ -1640,20 +1640,20 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1640
  if position_ids is None:
1641
  position_ids = cache_position.unsqueeze(0)
1642
 
1643
- # It may already have been prepared by e.g. `generate`
1644
  if not isinstance(causal_mask_mapping := attention_mask, dict):
1645
  # Prepare mask arguments
1646
  mask_kwargs = {
1647
- "config": self.config,
1648
- "input_embeds": inputs_embeds,
1649
- "attention_mask": attention_mask,
1650
- "cache_position": cache_position,
1651
- "past_key_values": past_key_values,
1652
  }
1653
  # Create the masks
1654
  causal_mask_mapping = {
1655
- "full_attention": create_causal_mask(**mask_kwargs),
1656
- "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
1657
  }
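A toy illustration of how the two mask types in this mapping differ (boolean, `True` = may attend; illustrative only, not the masks' actual internal format):

```python
import torch

seq_len, window = 6, 3
i = torch.arange(seq_len)[:, None]  # query positions
j = torch.arange(seq_len)[None, :]  # key positions

full_attention = j <= i                          # plain causal
sliding_attention = (j <= i) & (i - j < window)  # causal + local window
```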
1658
 
1659
  # embed positions
@@ -1669,7 +1669,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1669
 
1670
  temp_hidden_states = [hidden_states_0]
1671
  for i in range(1, self.config.altup_num_inputs):
1672
- # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
1673
  altup_proj = self.altup_projections[i - 1](hidden_states_0)
1674
  current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
1675
  new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
@@ -1717,7 +1717,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1717
  target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
1718
  temp_hidden_states = [hidden_states[0]]
1719
  for i in range(1, self.config.altup_num_inputs):
1720
- # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
1721
  altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
1722
  current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
1723
  new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
@@ -1771,14 +1771,14 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
1771
  )
1772
 
1773
 
1774
- @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
1775
  class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
1776
- _tied_weights_keys = ["lm_head.weight"]
1777
- _tp_plan = {"lm_head": "colwise_rep"}
1778
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1779
  config_class = Gemma3nTextConfig
1780
- base_model_prefix = "model"
1781
- _checkpoint_conversion_mapping = {"model.language_model": "model"}
1782
 
1783
  def __init__(self, config: Gemma3nTextConfig):
1784
  super().__init__(config)
@@ -1824,33 +1824,33 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
1824
  logits_to_keep: Union[int, torch.Tensor] = 0,
1825
  **loss_kwargs,
1826
  ) -> CausalLMOutputWithPast:
1827
- r"""
1828
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1829
- Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1830
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1831
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1832
 
1833
  Example:
1834
 
1835
- ```python
1836
  >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
1837
 
1838
- >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
1839
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
1840
 
1841
- >>> prompt = "What is your favorite condiment?"
1842
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1843
 
1844
  >>> # Generate
1845
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1846
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1847
- "What is your favorite condiment?"
1848
- ```"""
1849
 
1850
- if self.training and self.config._attn_implementation != "eager":
1851
  logger.warning_once(
1852
- "It is strongly recommended to train Gemma3n models with the `eager` attention implementation "
1853
- f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
1854
  )
1855
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1856
  output_hidden_states = (
@@ -1893,7 +1893,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
1893
 
1894
 
1895
  class Gemma3nMultimodalEmbedder(nn.Module):
1896
- """Embeds token ids or soft tokens for multimodal content into language model space."""
1897
 
1898
  def __init__(
1899
  self,
@@ -1919,18 +1919,18 @@ class Gemma3nMultimodalEmbedder(nn.Module):
1919
  input_ids: Optional[torch.LongTensor] = None,
1920
  inputs_embeds: Optional[torch.Tensor] = None,
1921
  ) -> torch.Tensor:
1922
- """Embeds token ids or soft tokens for multimodal content into language model space.
1923
 
1924
  Args:
1925
  input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
1926
- `[vocab_offset, vocab_offset + vocab_size)`.
1927
  inputs_embeds: A torch.Tensor containing the soft tokens to embed.
1928
 
1929
  Returns:
1930
- A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
1931
- """
1932
  if (input_ids is None) ^ (inputs_embeds is not None):
1933
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1934
 
1935
  if inputs_embeds is not None:
1936
  emb_norm = self.soft_embedding_norm(inputs_embeds)
@@ -1943,14 +1943,14 @@ class Gemma3nMultimodalEmbedder(nn.Module):
1943
 
1944
 
1945
  @auto_docstring(
1946
- custom_intro="""
1947
  The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
1948
  language modeling head.
1949
- """
1950
  )
1951
  class Gemma3nModel(Gemma3nPreTrainedModel):
1952
  _checkpoint_conversion_mapping = {}
1953
- # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
1954
  accepts_loss_kwargs = False
1955
 
1956
  def __init__(self, config: Gemma3nConfig):
@@ -1981,16 +1981,16 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
1981
  return self.language_model
1982
 
1983
  def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
1984
- """
1985
  Projects the last hidden state from the vision model into language model space.
1986
 
1987
  Args:
1988
- pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
1989
  The tensors corresponding to the input images.
1990
 
1991
  Returns:
1992
- image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
1993
- """
1994
  vision_outputs = self.vision_tower(
1995
  pixel_values=pixel_values, do_pooling=False, return_dict=True
1996
  ).last_hidden_state
@@ -2024,36 +2024,36 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
2024
  output_hidden_states: Optional[bool] = None,
2025
  **lm_kwargs,
2026
  ) -> Gemma3nCausalLMOutputWithPast:
2027
- r"""
2028
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
2029
- Labels for computing the language modeling loss. Indices should either be in `[0, ...,
2030
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
2031
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
2032
 
2033
  Example:
2034
 
2035
- ```python
2036
  >>> from PIL import Image
2037
  >>> import requests
2038
  >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
2039
 
2040
- >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E2B-it")
2041
- >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it")
2042
 
2043
- >>> prompt = "Where is the cat standing?"
2044
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
2045
  >>> image = Image.open(requests.get(url, stream=True).raw)
2046
 
2047
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
2048
 
2049
  >>> # Generate
2050
  >>> generate_ids = model.generate(**inputs,)
2051
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2052
- "Where is the cat standing?\nsnow"
2053
- ```
2054
- """
2055
  if (input_ids is None) ^ (inputs_embeds is not None):
2056
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
2057
 
2058
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2059
  output_hidden_states = (
@@ -2103,9 +2103,9 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
2103
  if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
2104
  image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
2105
  raise ValueError(
2106
- f"Number of images does not match number of special image tokens in the input text. "
2107
- f"Got {image_tokens_in_text} image tokens in the text and "
2108
- f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
2109
  )
2110
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
2111
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
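`masked_scatter` fills the `True` positions of the mask, in order, from the flattened source, so the number of masked embedding slots must equal `image_features.numel()` (which the check above enforces). A toy sketch:

```python
import torch

inputs_embeds = torch.zeros(1, 5, 4)
mask = torch.tensor([[False, True, True, False, False]])
special_image_mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
image_features = torch.arange(8.0).reshape(2, 4)  # 2 image tokens, hidden 4

inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# rows 1 and 2 of the sequence now hold the two image feature vectors
```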
@@ -2140,9 +2140,9 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
2140
  if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
2141
  audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
2142
  raise ValueError(
2143
- f"Number of audio input features does not match number of special audio tokens in the input text. "
2144
- f"Got {audio_tokens_in_text} audio tokens in the text and "
2145
- f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
2146
  )
2147
  audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
2148
  inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
@@ -2174,32 +2174,32 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
2174
  def get_audio_features(
2175
  self, input_features: torch.Tensor, input_features_mask: torch.Tensor
2176
  ) -> tuple[torch.Tensor, torch.Tensor]:
2177
- """
2178
  Projects the last hidden state from the audio encoder into language model space.
2179
 
2180
  Args:
2181
- input_features (`torch.FloatTensor` of shape `(num_audio, seq_length, num_features)`):
2182
  The tensors corresponding to the input audio.
2183
- input_features_mask (`torch.Tensor` of shape `(num_audio, seq_length)`):
2184
  The attention mask for the input audio.
2185
 
2186
  Returns:
2187
- audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audio, audio_length, embed_dim)`.
2188
- """
2189
  audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
2190
  return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
2191
 
2192
 
2193
  @auto_docstring(
2194
- custom_intro="""
2195
  The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
2196
  head.
2197
- """
2198
  )
2199
  class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2200
  _checkpoint_conversion_mapping = {}
2201
- _tied_weights_keys = ["lm_head.weight"]
2202
- base_model_prefix = "model"
2203
 
2204
  def __init__(self, config: Gemma3nConfig):
2205
  super().__init__(config)
@@ -2239,7 +2239,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2239
 
2240
  @property
2241
  def multi_modal_projector(self):
2242
- raise AttributeError("Use embed_vision instead of multi_modal_projector.")
2243
 
2244
  @can_return_tuple
2245
  @auto_docstring
@@ -2262,38 +2262,38 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2262
  logits_to_keep: Union[int, torch.Tensor] = 0,
2263
  **lm_kwargs,
2264
  ) -> Gemma3nCausalLMOutputWithPast:
2265
- r"""
2266
  input_features (torch.Tensor, *optional*, defaults to None):
2267
  The audio inputs to be encoded.
2268
  input_features_mask (torch.Tensor, *optional*, defaults to None):
2269
  The attention mask for the input audio.
2270
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
2271
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
2272
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
2273
  ignored (masked), the loss is only computed for the tokens with labels in
2274
- `[0, ..., config.text_config.vocab_size]`.
2275
 
2276
  Example:
2277
 
2278
- ```python
2279
  >>> from PIL import Image
2280
  >>> import requests
2281
  >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
2282
 
2283
- >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E2B-it")
2284
- >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it")
2285
 
2286
  >>> messages = [
2287
  ... {
2288
- ... "role": "system",
2289
- ... "content": [
2290
- ... {"type": "text", "text": "You are a helpful assistant."}
2291
  ... ]
2292
  ... },
2293
  ... {
2294
- ... "role": "user", "content": [
2295
- ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
2296
- ... {"type": "text", "text": "Where is the cat standing?"},
2297
  ... ]
2298
  ... },
2299
  ... ]
@@ -2302,15 +2302,15 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2302
  ... messages,
2303
  ... tokenize=True,
2304
  ... return_dict=True,
2305
- ... return_tensors="pt",
2306
  ... add_generation_prompt=True
2307
  ... )
2308
  >>> # Generate
2309
  >>> generate_ids = model.generate(**inputs)
2310
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2311
- "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
2312
- ```
2313
- """
2314
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2315
  output_hidden_states = (
2316
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -2393,7 +2393,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2393
  labels=None,
2394
  **kwargs,
2395
  ):
2396
- # Overwritten -- custom `position_ids` and `pixel_values` handling
2397
  model_inputs = super().prepare_inputs_for_generation(
2398
  input_ids,
2399
  past_key_values=past_key_values,
@@ -2407,13 +2407,13 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2407
  **kwargs,
2408
  )
2409
 
2410
- # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
2411
  # tokens anymore. Otherwise multimodal inputs should be passed to model.
2412
  # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
2413
  if cache_position[0] == 0:
2414
- model_inputs["pixel_values"] = pixel_values
2415
- model_inputs["input_features"] = input_features
2416
- model_inputs["input_features_mask"] = input_features_mask
2417
 
2418
  return model_inputs
2419
 
@@ -2423,10 +2423,10 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2423
 
2424
 
2425
  __all__ = [
2426
- "Gemma3nAudioEncoder",
2427
- "Gemma3nForCausalLM",
2428
- "Gemma3nForConditionalGeneration",
2429
- "Gemma3nModel",
2430
- "Gemma3nPreTrainedModel",
2431
- "Gemma3nTextModel",
2432
  ]
 
8
  # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
9
  #
10
  #
11
+ # Licensed under the Apache License, Version 2.0 (the \"License\");
12
  # you may not use this file except in compliance with the License.
13
  # You may obtain a copy of the License at
14
  #
15
  # http://www.apache.org/licenses/LICENSE-2.0
16
  #
17
  # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an \"AS IS\" BASIS,
19
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
  # See the License for the specific language governing permissions and
21
  # limitations under the License.
 
56
 
57
  @dataclass
58
  @auto_docstring(
59
+ custom_intro=\"\"\"
60
  Base class for Gemma3n outputs, with hidden states and attentions.
61
+ \"\"\"
62
  )
63
  class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
64
+ r\"\"\"
65
+ past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
66
+ Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
67
+ \`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
68
 
69
  Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
70
+ \`past_key_values\` input) to speed up sequential decoding.
71
+ image_hidden_states (\`torch.FloatTensor\`, *optional*):
72
+ A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
73
  image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
74
+ audio_hidden_states (\`torch.FloatTensor\`, *optional*):
75
+ A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
76
  audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
77
+ \"\"\"
78
 
79
  image_hidden_states: Optional[torch.FloatTensor] = None
80
 
 
83
 
84
  @dataclass
85
  @auto_docstring(
86
+ custom_intro=\"\"\"
87
  Base class for Gemma3n causal language model (or autoregressive) outputs.
88
+ \"\"\"
89
  )
90
  class Gemma3nCausalLMOutputWithPast(ModelOutput):
91
+ r\"\"\"
92
+ loss (\`torch.FloatTensor\` of shape \`(1,)\`, *optional*, returned when \`labels\` is provided):
93
  Language modeling loss (for next-token prediction).
94
+ logits (\`torch.FloatTensor\` of shape \`(batch_size, sequence_length, config.text_config.vocab_size)\`):
95
  Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
96
+ past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
97
+ Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
98
+ \`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
99
 
100
  Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
101
+ \`past_key_values\` input) to speed up sequential decoding.
102
+ image_hidden_states (\`torch.FloatTensor\`, *optional*):
103
+ A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
104
  image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
105
+ audio_hidden_states (\`torch.FloatTensor\`, *optional*):
106
+ A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
107
  audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
108
+ \"\"\"
109
 
110
  loss: Optional[torch.FloatTensor] = None
111
  logits: Optional[torch.FloatTensor] = None
 
126
  if self.with_scale:
127
  self.weight = nn.Parameter(torch.ones(dim))
128
  else:
129
+ self.register_buffer(\"weight\", torch.tensor(1.0), persistent=False)
130
 
131
  def _norm(self, x):
132
  return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
138
  return output.type_as(x)
139
 
140
  def extra_repr(self):
141
+ return f\"{tuple(self.weight.shape)}, eps={self.eps}\"
142
 
143
 
144
  # ==== Audio Encoder ====
 
163
  log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
164
  inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
165
  self.register_buffer(
166
+ \"inv_timescales\",
167
  inv_timescales.float().unsqueeze(0).unsqueeze(0),
168
  persistent=False,
169
  )
 
184
  key_context_size: int,
185
  max_span_plus_1: int,
186
  ) -> torch.Tensor:
187
+ \"\"\"Performs the relative shift.
188
 
189
  Args:
190
  term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
 
193
 
194
  Returns:
195
  Tensor of shape [B, N, U, W, C].
196
+ \"\"\"
197
  # term_bd_before_shift shape: [B, N, U, W, F_span]
198
  # Target shape after shift: [B, N, U, W, C]
199
 
 
209
  term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
210
  # Shape after pad: [B, N, U, W, C+1]
211
 
212
+ # Reshape for slicing (emulating JAX\'s behavior)
213
  # [B, N, U, W * (C+1)]
214
  term_bd_reshaped = term_bd_padded.reshape(
215
  (
 
271
  term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
272
 
273
  # term_bd: Query-Position interaction
274
+ # Original einsum: term_bd_unshifed = torch.einsum(\'buwnh,fnh->bnuwf\', queries, sin_emb)
275
  # queries shape: [B, U, W, N, H]
276
  # sin_emb shape: [F, N, H]
277
  # Target output shape: [B, N, U, W, F]
 
338
 
339
  q_scale = self.head_dim**-0.5
340
  r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
341
+ self.register_buffer(\"q_scale\", (q_scale * r_softplus_0).clone().detach(), persistent=False)
342
 
343
  lower_causal_mask = torch.tril(
344
  torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
 
350
  )
351
  local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
352
  local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
353
+ self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
354
 
355
  self.register_buffer(
356
+ "softcap",
357
  torch.tensor(self.attention_logits_soft_cap).float(),
358
  persistent=False,
359
  )
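# Illustrative sketch (not part of the original file): the soft-capping the
# softcap buffer above supports, assuming it is applied Gemma-style, i.e.
# attention logits are squashed into (-softcap, +softcap) via tanh.
def _softcap_sketch(logits, softcap: float = 50.0):
    import torch

    return torch.tanh(logits / softcap) * softcap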
 
366
  return x
367
 
368
  def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
369
+ """Turns a sequence into non-overlapping blocks.
370
 
371
  Args:
372
  hidden_states: a tensor of [batch, time, ...].
 
375
  A tensor of [batch, num_blocks, block_size, ...], with necessary
376
  paddings,
377
  where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
378
+ """
379
  shape = hidden_states.shape
380
  b, t = shape[:2]
381
  num_blocks = (t + self.chunk_size - 1) // self.chunk_size
 
388
  return hidden_states
389
 
390
  def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
391
+ """Extracts temporal context for every block.
392
 
393
  Args:
394
  hidden_states: a tensor of [batch, time, ...].
 
400
  and output[:, i, ...] are x[:, start-left_context:end+right_context,
401
  ...],
402
  start = i * block_size, end = (i + 1) * block_size.
403
+ """
404
  pad_left = self.max_past_horizon
405
+ # The JAX equivalent padding for signal.frame with pad_mode='valid' is
406
  # (left_context, right_context + block_size - 1) on the time dimension.
407
+ # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
408
  # or (pad_dim_start, pad_dim_end) if two are given.
409
  # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
410
  # or dim 1 (time for [B,T]).
 
424
 
425
  # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
426
  # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
427
+ # We want to match JAX's typical output for such operations which might be
428
  # [B, num_blocks, frame_len, N, H] if N, H are present.
429
  # The relative_position_embedding expects keys as [B, U, C, N, H].
430
  # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
 
436
  return x_unfolded.contiguous()
437
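# Illustrative sketch (not part of the original file): how blocking and context
# extraction relate, with assumed toy sizes. A [B, T] sequence is reshaped into
# non-overlapping chunks, while unfold() yields each chunk plus its left/right
# halo, mirroring _convert_to_block / _extract_block_context above.
def _blocking_sketch():
    import torch

    B, T, chunk, left, right = 1, 8, 4, 2, 1
    x = torch.arange(B * T).reshape(B, T)
    blocks = x.reshape(B, T // chunk, chunk)  # [B, U, W]
    padded = torch.nn.functional.pad(x, (left, right + chunk - 1))
    context = padded.unfold(dimension=1, size=chunk + left + right, step=chunk)
    return blocks, context  # context: [B, U, C] with C = chunk + left + right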
 
438
  def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
439
+ # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
440
  qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
441
  query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
442
  key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
 
472
  extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
473
  batch_size, num_query_blocks, self.context_size
474
  )
475
+ # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
476
  # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
477
  # but for the mask case, this should hold.
478
  if extracted_valid_mask_blocks.shape != (
 
481
  self.context_size,
482
  ):
483
  raise ValueError(
484
+ "Shape of extracted_valid_mask_blocks"
485
+ f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
486
+ f" {num_query_blocks}, {self.context_size}) after potential reshape."
487
  )
488
 
489
  # 3. Expand dimensions for broadcasting with logits and causal mask.
 
518
  logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
519
  probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
520
 
521
+ # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
522
  b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
523
  h_dim = value_blocks.shape[-1]
524
  prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
 
539
 
540
 
541
  class Gemma3nAudioCumulativeGroupNorm(nn.Module):
542
+ """Applies Group Normalization cumulatively over the time dimension.
543
 
544
  This layer normalizes the input by calculating the mean and variance
545
  cumulatively over the time dimension (dim 1). The statistics are computed
546
+ over all feature dimensions (specified by `feature_dims` and `num_channels`)
547
+ for elements marked as valid by the optional `mask`.
548
 
549
+ If a `mask` is provided (True for valid, False for invalid/padded),
550
  invalid time steps do not contribute to the statistics calculation, and
551
  their corresponding output values are zeroed out.
552
 
553
  Scale and bias, if enabled, are applied per-channel (last dimension).
554
+ This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
555
+ and `cumulative=True`.
556
+ """
557
 
558
  def __init__(
559
  self,
 
574
  self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
575
 
576
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
577
+ """Applies cumulative group norm, optionally using a mask.
578
 
579
  Args:
580
  hidden_states: Input tensor, shape [B, T, *feature_dims, C].
581
 
582
  Returns:
583
  Normalized tensor with the same shape as x.
584
+ """
585
  expected_input_suffix = self.feature_dims + (self.num_channels,)
586
  if hidden_states.shape[2:] != expected_input_suffix:
587
  raise ValueError(
588
+ f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
589
+ f" suffix (feature_dims + num_channels) {expected_input_suffix}"
590
  )
591
 
592
  input_dtype = hidden_states.dtype
 
594
  calc_dtype = torch.float32
595
  x_calc = hidden_states.to(calc_dtype)
596
 
597
+ # Prepare a broadcastable mask (`mask_calc`).
598
  # If no mask is provided, treat all elements as valid
599
  # (mask_calc is all ones).
600
  # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
 
607
  cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
608
 
609
  # 3. Count of valid elements in the normalization group at each time step.
610
+ # (A "group" here consists of all features at a given Batch, Time).
611
  elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
612
  # 4. Cumulative count of valid elements over time.
613
  cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
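# Illustrative sketch (not part of the original file): cumulative statistics
# via cumsum, assuming every element is valid. The mean at step t only uses
# steps <= t, which is the core idea of the cumulative sums and counts above.
def _cumulative_mean_sketch():
    import torch

    x = torch.randn(2, 5, 3)  # [B, T, C]
    sums = torch.cumsum(x.sum(dim=-1, keepdim=True), dim=1)  # running feature sums
    counts = torch.cumsum(torch.full((2, 5, 1), 3.0), dim=1)  # running valid counts
    return sums / counts  # cumulative mean per time step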
 
648
 
649
 
650
  class Gemma3nAudioSSCPConvBlock(nn.Module):
651
+ """A single convolution block for the SubSampleConvProjection.
652
 
653
  This block consists of a 2D convolution, followed by CumulativeGroupNorm,
654
  and a ReLU activation. It handles manual padding for the convolution.
655
+ """
656
 
657
  def __init__(
658
  self,
 
665
  self.config = config
666
  self.manual_padding = manual_padding
667
 
668
+ # in_channels is 1 for the first block, or C_out from previous block's conv
669
  in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
670
  out_channels = self.config.sscp_conv_channel_size[idx]
671
  kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
 
701
  # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
702
  # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
703
  # F.pad applies to last two dims: F_in then T_in
704
+ audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
705
  # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
706
  # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
707
  audio_encodings_conv = self.conv(audio_encodings_padded)
 
728
  stride_h, stride_w = config.sscp_conv_stride_size[i]
729
 
730
  # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
731
+ # JAX 'reverse_causal' padding is (0, kernel_size - 1)
732
  pad_t_top = 0
733
  pad_t_bottom = kernel_h - 1
734
 
 
736
  # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
737
  # and the successful test configuration.
738
  # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
739
+ # to match generic JAX 'SAME' behavior if it differs.
740
  pad_f_left = 1
741
  pad_f_right = 1
742
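# Illustrative sketch (not part of the original file): reverse-causal padding
# in PyTorch with assumed kernel sizes. Padding (0, kernel_h - 1) on the time
# axis means each output frame only depends on current and future input frames.
def _reverse_causal_pad_sketch():
    import torch

    x = torch.randn(1, 1, 10, 10)  # [B, C, T, F]
    x = torch.nn.functional.pad(x, (1, 1, 0, 2))  # (F_left, F_right, T_top, T_bottom)
    return torch.nn.Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2))(x)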
 
 
792
  super().__init__()
793
  self.config = config
794
  self.post_in_features = self.config.hidden_size
795
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
796
  self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
797
  self.attn = Gemma3nAudioAttention(config)
798
  self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
 
820
  super().__init__()
821
  self.config = config
822
 
823
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
824
 
825
  self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
826
  self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
 
856
  groups=self.config.hidden_size, # Depthwise
857
  bias=False,
858
  )
859
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
860
  self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
861
  self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
862
 
 
892
  self.attention = Gemma3nAudioConformerAttention(self.config)
893
  self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
894
  self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
895
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
896
  self.norm = Gemma3nRMSNorm(self.config.hidden_size)
897
 
898
  def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
 
911
 
912
 
913
  class Gemma3nAudioEncoder(PreTrainedModel):
914
+ """An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
915
 
916
  config_class = Gemma3nAudioConfig
917
 
918
+ main_input_name = "audio_mel"
919
 
920
  def __init__(self, config: Gemma3nAudioConfig):
921
  super().__init__(config)
 
929
  def forward(
930
  self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
931
  ) -> tuple[torch.Tensor, torch.BoolTensor]:
932
+ """Encodes a batch of MELs.
933
 
934
  Args:
935
  audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
 
937
 
938
  Returns:
939
  audio_encodings: a torch.Tensor of shape
940
+ `[batch_size, self.config.audio_soft_tokens_per_image,
941
+ self.config.audio_config.hidden_size]`
942
  audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
943
+ """
944
  audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
945
 
946
  # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
 
983
 
984
 
985
  class Gemma3nTextScaledWordEmbedding(nn.Embedding):
986
+ """
987
+ This module overrides nn.Embedding's forward by multiplying the embeddings by the embedding scale.
988
+ """
989
 
990
  def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
991
  super().__init__(num_embeddings, embedding_dim, padding_idx)
992
+ self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
993
 
994
  def forward(self, input_ids: torch.Tensor):
995
  return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
996
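# Illustrative sketch (not part of the original file): the scaled lookup is
# just nn.Embedding followed by a constant multiply; an embed_scale of
# hidden_size**0.5 is assumed here purely for the example.
def _scaled_embedding_sketch():
    import torch

    emb = Gemma3nTextScaledWordEmbedding(100, 16, padding_idx=0, embed_scale=16**0.5)
    return emb(torch.tensor([[1, 2, 3]]))  # plain lookup * sqrt(16)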
 
997
 
998
  class Gemma3nTextLaurelBlock(nn.Module):
999
+ """Learned Augmented Residual Layer"""
1000
 
1001
  def __init__(self, config: Gemma3nTextConfig):
1002
  super().__init__()
 
1052
 
1053
 
1054
  class Gemma3nTextAltUp(nn.Module):
1055
+ """Alternating Updates (AltUp)
1056
 
1057
+ The AltUp module wraps transformer layers. The `predict` step modifies the
1058
+ input to the transformer layer, and the `correct` step propagates the output
1059
  of the transformer layer to the sparsely updated dimensions.
1060
 
1061
  See more in the research paper:
1062
 
1063
  https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
1064
+ """
1065
 
1066
  def __init__(self, config: Gemma3nTextConfig):
1067
  super().__init__()
 
1071
  self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
1072
  self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
1073
  self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
1074
+ self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
1075
 
1076
  def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
1077
  router_inputs = self.router_norm(x) * self.router_input_scale
 
1079
  return torch.tanh(routed.float()).type_as(x)
1080
 
1081
  def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
1082
+ """Predicts the output of a layer using a trainable map.
1083
 
1084
  Args:
1085
+ hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
1086
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
1087
 
1088
  Returns:
1089
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
1090
+ """
1091
  modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
1092
 
1093
  if self.training and self.config.altup_coef_clip is not None:
 
1107
  return predictions.contiguous().type_as(hidden_states)
1108
 
1109
  def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
1110
+ """Corrects the predictions relative to the activated inputs.
1111
 
1112
  Args:
1113
+ predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
1114
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
1115
+ activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
1116
 
1117
  Returns:
1118
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
1119
  predictions relative to the activated input embeddings.
1120
+ """
1121
  modalities = self.compute_router_modalities(activated)
1122
  innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
1123
  innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
 
1125
  if self.config.altup_coef_clip is not None:
1126
  self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
1127
 
1128
+ # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
1129
  # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
1130
  # and expand on dim1 for broadcastability
1131
  all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
 
1136
  return corrected.contiguous().type_as(activated)
1137
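# Illustrative sketch (not part of the original file): the correct() step in
# miniature, assuming two altup streams with active index 0. Every stream is
# nudged by the innovation (activated output minus the active prediction),
# weighted by a per-stream, per-token coefficient.
def _altup_correct_sketch():
    import torch

    predictions = torch.randn(2, 1, 4, 8)  # [num_altup_inputs, B, T, D]
    activated = torch.randn(1, 4, 8)  # output of the wrapped layer
    innovation = activated - predictions[0]
    coefs = torch.rand(2, 1, 4, 1)  # stands in for correction_coefs(modalities) + 1
    return predictions + innovation.unsqueeze(0) * coefs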
 
1138
  def forward(self, corrected: torch.Tensor) -> torch.Tensor:
1139
+ """
1140
+ This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
1141
  (which is an nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
1142
+ `scale_corrected_output`.
1143
+ """
1144
  return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
1145
 
1146
  def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
1147
+ """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
1148
  return self.forward(corrected)
1149
 
1150
 
1151
  class Gemma3nTextRotaryEmbedding(nn.Module):
1152
  def __init__(self, config: Gemma3nTextConfig, device=None):
1153
  super().__init__()
1154
+ # BC: "rope_type" was originally "type"
1155
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1156
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
1157
  else:
1158
+ self.rope_type = "default"
1159
  self.max_seq_len_cached = config.max_position_embeddings
1160
  self.original_max_seq_len = config.max_position_embeddings
1161
 
 
1163
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
1164
 
1165
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
1166
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1167
  self.original_inv_freq = self.inv_freq
1168
 
1169
  @torch.no_grad()
 
1172
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1173
  position_ids_expanded = position_ids[:, None, :].float()
1174
 
1175
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1176
  with torch.autocast(device_type=device_type, enabled=False): # Force float32
1177
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1178
  emb = torch.cat((freqs, freqs), dim=-1)
 
1183
 
1184
 
1185
  def rotate_half(x):
1186
+ """Rotates half the hidden dims of the input."""
1187
  x1 = x[..., : x.shape[-1] // 2]
1188
  x2 = x[..., x.shape[-1] // 2 :]
1189
  return torch.cat((-x2, x1), dim=-1)
1190
 
1191
 
1192
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1193
+ """
1194
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
1195
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
1196
+ """
1197
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
1198
  if n_rep == 1:
1199
  return hidden_states
 
1243
  position_ids: Optional[torch.Tensor] = None,
1244
  unsqueeze_dim: int = 1,
1245
  ):
1246
+ """Applies Rotary Position Embedding to the query and key tensors.
1247
 
1248
  Args:
1249
+ x (`torch.Tensor`): The tensor to embed.
1250
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
1251
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
1252
+ position_ids (`torch.Tensor`, *optional*):
1253
  Deprecated and unused.
1254
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
1255
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
1256
  sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
1257
  that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
1258
  k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
1259
  cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
1260
  the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
1261
  Returns:
1262
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
1263
+ """
1264
  cos = cos.unsqueeze(unsqueeze_dim)
1265
  sin = sin.unsqueeze(unsqueeze_dim)
1266
  return (x * cos) + (rotate_half(x) * sin)
1267
 
1268
 
1269
  class Gemma3nTextAttention(nn.Module):
1270
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
1271
 
1272
  def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
1273
  super().__init__()
1274
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
1275
  self.config = config
1276
  self.layer_idx = layer_idx
1277
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
1278
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
1279
  self.attention_dropout = self.config.attention_dropout
1280
  self.is_causal = True
 
1356
  if past_key_value is not None:
1357
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
1358
  cache_kwargs = {
1359
+ "sin": sin,
1360
+ "cos": cos,
1361
+ "cache_position": cache_position,
1362
+ "sliding_window": self.sliding_window,
1363
  }
1364
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1365
 
1366
  attention_interface: Callable = eager_attention_forward
1367
+ if self.config._attn_implementation != "eager":
1368
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
1369
 
1370
  attn_output, attn_weights = attention_interface(
 
1407
  self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
1408
  self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
1409
 
1410
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
1411
  def forward(
1412
  self,
1413
  hidden_states: torch.Tensor,
 
1460
  if self.config.altup_correct_scale:
1461
  first_prediction = self.altup.scale_corrected_output(first_prediction)
1462
 
1463
+ # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
1464
  first_prediction = self.per_layer_input_gate(first_prediction)
1465
  first_prediction = self.act_fn(first_prediction)
1466
  first_prediction = torch.multiply(first_prediction, per_layer_input)
1467
 
1468
+ # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
1469
  first_prediction = self.per_layer_projection(first_prediction)
1470
  first_prediction = self.post_per_layer_input_norm(first_prediction)
1471
  corrected_predictions[1:] += first_prediction
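# Illustrative sketch (not part of the original file): the gate/project pattern
# above in miniature, assuming toy sizes and a GELU activation. The hidden
# state is gated, multiplied with the per-layer embedding, then projected back.
def _per_layer_gate_sketch():
    import torch

    hidden, per_layer = torch.randn(1, 4, 8), torch.randn(1, 4, 6)
    gate = torch.nn.Linear(8, 6, bias=False)
    proj = torch.nn.Linear(6, 8, bias=False)
    return proj(torch.nn.functional.gelu(gate(hidden)) * per_layer)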
 
1481
  @auto_docstring
1482
  class Gemma3nPreTrainedModel(PreTrainedModel):
1483
  config_class = Gemma3nConfig
1484
+ base_model_prefix = ""
1485
  supports_gradient_checkpointing = True
1486
+ _no_split_modules = ["Gemma3nTextDecoderLayer"]
1487
+ _skip_keys_device_placement = ["past_key_values"]
1488
  _supports_flash_attn_3 = True
1489
  _supports_flash_attn_2 = True
1490
  _supports_sdpa = True
 
1495
  _supports_attention_backend = True
1496
 
1497
  def _init_weights(self, module):
1498
+ # important: this ported version of Gemma3n isn't meant for training from scratch - only
1499
  # inference and fine-tuning - so the proper init weights code has been removed
1500
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
1501
 
1502
  if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
1503
  module.weight.data.normal_(mean=0.0, std=std)
 
1518
  module.correct_output_scale.data.zero_()
1519
 
1520
 
1521
+ @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
1522
  class Gemma3nTextModel(Gemma3nPreTrainedModel):
1523
  config_class = Gemma3nTextConfig
1524
 
 
1544
  # defaults should hold values for global RoPE.
1545
  config = copy.deepcopy(config)
1546
  config.rope_theta = config.rope_local_base_freq
1547
+ config.rope_scaling = {"rope_type": "default"}
1548
  self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
1549
 
1550
  self.hidden_size = config.hidden_size
 
1573
  [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
1574
  )
1575
 
1576
+ self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
1577
+ self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
1578
 
1579
  # Initialize weights and apply final processing
1580
  self.post_init()
 
1601
  cache_position: Optional[torch.LongTensor] = None,
1602
  **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1603
  ) -> BaseModelOutputWithPast:
1604
+ r"""
1605
  per_layer_inputs (torch.Tensor, *optional*, defaults to None):
1606
  Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
1607
+ """
1608
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1609
  output_hidden_states = (
1610
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1612
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1613
 
1614
  if (input_ids is None) ^ (inputs_embeds is not None):
1615
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1616
 
1617
  if self.gradient_checkpointing and self.training and use_cache:
1618
  logger.warning_once(
1619
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1620
  )
1621
  use_cache = False
1622
 
 
1640
  if position_ids is None:
1641
  position_ids = cache_position.unsqueeze(0)
1642
 
1643
+ # It may already have been prepared by e.g. `generate`
1644
  if not isinstance(causal_mask_mapping := attention_mask, dict):
1645
  # Prepare mask arguments
1646
  mask_kwargs = {
1647
+ "config": self.config,
1648
+ "input_embeds": inputs_embeds,
1649
+ "attention_mask": attention_mask,
1650
+ "cache_position": cache_position,
1651
+ "past_key_values": past_key_values,
1652
  }
1653
  # Create the masks
1654
  causal_mask_mapping = {
1655
+ "full_attention": create_causal_mask(**mask_kwargs),
1656
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
1657
  }
1658
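# Illustrative sketch (not part of the original file): a toy sliding-window
# causal mask, assuming window = 3. Position i may attend to j iff j <= i and
# i - j < window, which is what "sliding_attention" layers consume above.
def _sliding_mask_sketch(seq_len: int = 6, window: int = 3):
    import torch

    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window)  # [seq, seq] boolean mask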
 
1659
  # embed positions
 
1669
 
1670
  temp_hidden_states = [hidden_states_0]
1671
  for i in range(1, self.config.altup_num_inputs):
1672
+ # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
1673
  altup_proj = self.altup_projections[i - 1](hidden_states_0)
1674
  current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
1675
  new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
 
1717
  target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
1718
  temp_hidden_states = [hidden_states[0]]
1719
  for i in range(1, self.config.altup_num_inputs):
1720
+ # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
1721
  altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
1722
  current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
1723
  new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
 
1771
  )
1772
 
1773
 
1774
+ @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
1775
  class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
1776
+ _tied_weights_keys = ["lm_head.weight"]
1777
+ _tp_plan = {"lm_head": "colwise_rep"}
1778
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1779
  config_class = Gemma3nTextConfig
1780
+ base_model_prefix = "model"
1781
+ _checkpoint_conversion_mapping = {"model.language_model": "model"}
1782
 
1783
  def __init__(self, config: Gemma3nTextConfig):
1784
  super().__init__(config)
 
1824
  logits_to_keep: Union[int, torch.Tensor] = 0,
1825
  **loss_kwargs,
1826
  ) -> CausalLMOutputWithPast:
1827
+ r"""
1828
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1829
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1830
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1831
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1832
 
1833
  Example:
1834
 
1835
+ ```python
1836
  >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
1837
 
1838
+ >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-3n-E2B-it")
1839
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3n-E2B-it")
1840
 
1841
+ >>> prompt = "What is your favorite condiment?"
1842
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1843
 
1844
  >>> # Generate
1845
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1846
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1847
+ "What is your favorite condiment?"
1848
+ ```"""
1849
 
1850
+ if self.training and self.config._attn_implementation != "eager":
1851
  logger.warning_once(
1852
+ "It is strongly recommended to train Gemma3n models with the `eager` attention implementation "
1853
+ f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
1854
  )
1855
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1856
  output_hidden_states = (
 
1893
 
1894
 
1895
  class Gemma3nMultimodalEmbedder(nn.Module):
1896
+ """Embeds token ids or soft tokens for multimodal content into language model space."""
1897
 
1898
  def __init__(
1899
  self,
 
1919
  input_ids: Optional[torch.LongTensor] = None,
1920
  inputs_embeds: Optional[torch.Tensor] = None,
1921
  ) -> torch.Tensor:
1922
+ """Embeds token ids or soft tokens for multimodal content into language model space.
1923
 
1924
  Args:
1925
  input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
1926
+ `[vocab_offset, vocab_offset + vocab_size)`.
1927
  inputs_embeds: A torch.Tensor containing the soft tokens to embed.
1928
 
1929
  Returns:
1930
+ A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
1931
+ """
1932
  if (input_ids is None) ^ (inputs_embeds is not None):
1933
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1934
 
1935
  if inputs_embeds is not None:
1936
  emb_norm = self.soft_embedding_norm(inputs_embeds)
 
1943
 
1944
 
1945
  @auto_docstring(
1946
+ custom_intro="""
1947
  The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
1948
  language modeling head.
1949
+ """
1950
  )
1951
  class Gemma3nModel(Gemma3nPreTrainedModel):
1952
  _checkpoint_conversion_mapping = {}
1953
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
1954
  accepts_loss_kwargs = False
1955
 
1956
  def __init__(self, config: Gemma3nConfig):
 
1981
  return self.language_model
1982
 
1983
  def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
1984
+ """
1985
  Projects the last hidden state from the vision model into language model space.
1986
 
1987
  Args:
1988
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
1989
  The tensors corresponding to the input images.
1990
 
1991
  Returns:
1992
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
1993
+ """
1994
  vision_outputs = self.vision_tower(
1995
  pixel_values=pixel_values, do_pooling=False, return_dict=True
1996
  ).last_hidden_state
 
2024
  output_hidden_states: Optional[bool] = None,
2025
  **lm_kwargs,
2026
  ) -> Gemma3nCausalLMOutputWithPast:
2027
+ r"""
2028
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
2029
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
2030
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
2031
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
2032
 
2033
  Example:
2034
 
2035
+ ```python
2036
  >>> from PIL import Image
2037
  >>> import requests
2038
  >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
2039
 
2040
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E2B-it")
2041
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it")
2042
 
2043
+ >>> prompt = "Where is the cat standing?"
2044
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
2045
  >>> image = Image.open(requests.get(url, stream=True).raw)
2046
 
2047
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
2048
 
2049
  >>> # Generate
2050
  >>> generate_ids = model.generate(**inputs,)
2051
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2052
+ "Where is the cat standing?\nsnow"
2053
+ ```
2054
+ """
2055
  if (input_ids is None) ^ (inputs_embeds is not None):
2056
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
2057
 
2058
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2059
  output_hidden_states = (
 
2103
  if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
2104
  image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
2105
  raise ValueError(
2106
+ f"Number of images does not match number of special image tokens in the input text. "
2107
+ f"Got {image_tokens_in_text} image tokens in the text and "
2108
+ f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
2109
  )
2110
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
2111
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
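# Illustrative sketch (not part of the original file): masked_scatter in
# miniature, assuming two image-token positions and a matching number of
# feature vectors. The flagged positions are overwritten in order.
def _masked_scatter_sketch():
    import torch

    inputs_embeds = torch.zeros(1, 4, 3)
    mask = torch.tensor([[0, 1, 1, 0]], dtype=torch.bool)[..., None].expand(1, 4, 3)
    image_features = torch.ones(2, 3)
    return inputs_embeds.masked_scatter(mask, image_features)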
 
2140
  if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
2141
  audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
2142
  raise ValueError(
2143
+ f"Number of audio input features does not match number of special audio tokens in the input text. "
2144
+ f"Got {audio_tokens_in_text} audio tokens in the text and "
2145
+ f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
2146
  )
2147
  audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
2148
  inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
 
2174
  def get_audio_features(
2175
  self, input_features: torch.Tensor, input_features_mask: torch.Tensor
2176
  ) -> tuple[torch.Tensor, torch.Tensor]:
2177
+ """
2178
  Projects the last hidden state from the audio encoder into language model space.
2179
 
2180
  Args:
2181
+ input_features (`torch.FloatTensor` of shape `(num_images, seq_length, num_features)`):
2182
  The tensors corresponding to the input audio.
2183
+ input_features_mask (`torch.FloatTensor` of shape `(num_images, seq_length)`):
2184
  The attention mask for the input audio.
2185
 
2186
  Returns:
2187
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_images, audio_length, embed_dim)`.
2188
+ """
2189
  audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
2190
  return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
2191
 
2192
 
2193
  @auto_docstring(
2194
+ custom_intro="""
2195
  The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
2196
  head.
2197
+ """
2198
  )
2199
  class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
2200
  _checkpoint_conversion_mapping = {}
2201
+ _tied_weights_keys = ["lm_head.weight"]
2202
+ base_model_prefix = "model"
2203
 
2204
  def __init__(self, config: Gemma3nConfig):
2205
  super().__init__(config)
 
2239
 
2240
  @property
2241
  def multi_modal_projector(self):
2242
+ raise AttributeError("Use embed_vision instead of multi_modal_projector.")
2243
 
2244
  @can_return_tuple
2245
  @auto_docstring
 
2262
  logits_to_keep: Union[int, torch.Tensor] = 0,
2263
  **lm_kwargs,
2264
  ) -> Gemma3nCausalLMOutputWithPast:
2265
+ r"""
2266
  input_features (torch.Tensor, *optional*, defaults to None):
2267
  The audio inputs to be encoded.
2268
  input_features_mask (torch.Tensor, *optional*, defaults to None):
2269
  The attention mask for the input audio.
2270
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
2271
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
2272
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
2273
  ignored (masked), the loss is only computed for the tokens with labels in
2274
+ `[0, ..., config.text_config.vocab_size]`.
2275
 
2276
  Example:
2277
 
2278
+ ```python
2279
  >>> from PIL import Image
2280
  >>> import requests
2281
  >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
2282
 
2283
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E2B-it")
2284
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it")
2285
 
2286
  >>> messages = [
2287
  ... {
2288
+ ... "role": "system",
2289
+ ... "content": [
2290
+ ... {"type": "text", "text": "You are a helpful assistant."}
2291
  ... ]
2292
  ... },
2293
  ... {
2294
+ ... "role": "user", "content": [
2295
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
2296
+ ... {"type": "text", "text": "Where is the cat standing?"},
2297
  ... ]
2298
  ... },
2299
  ... ]
 
2302
  ... messages,
2303
  ... tokenizer=True,
2304
  ... return_dict=True,
2305
+ ... return_tensors="pt",
2306
  ... add_generation_prompt=True
2307
  ... )
2308
  >>> # Generate
2309
  >>> generate_ids = model.generate(**inputs)
2310
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2311
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
2312
+ ```
2313
+ """
2314
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2315
  output_hidden_states = (
2316
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
2393
  labels=None,
2394
  **kwargs,
2395
  ):
2396
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
2397
  model_inputs = super().prepare_inputs_for_generation(
2398
  input_ids,
2399
  past_key_values=past_key_values,
 
2407
  **kwargs,
2408
  )
2409
 
2410
+ # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
2411
  # tokens anymore. Otherwise multimodal inputs should be passed to model.
2412
  # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
2413
  if cache_position[0] == 0:
2414
+ model_inputs["pixel_values"] = pixel_values
2415
+ model_inputs["input_features"] = input_features
2416
+ model_inputs["input_features_mask"] = input_features_mask
2417
 
2418
  return model_inputs
2419
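# Illustrative sketch (not part of the original file): the cache_position
# check above in miniature. Multimodal tensors are only forwarded on the
# prefill step (cache_position starting at 0); cached decoding steps pass None.
def _prefill_only_sketch(cache_position, pixel_values):
    return pixel_values if cache_position[0] == 0 else None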
 
 
2423
 
2424
 
2425
  __all__ = [
2426
+ "Gemma3nAudioEncoder",
2427
+ "Gemma3nForCausalLM",
2428
+ "Gemma3nForConditionalGeneration",
2429
+ "Gemma3nModel",
2430
+ "Gemma3nPreTrainedModel",
2431
+ "Gemma3nTextModel",
2432
  ]