zifei9 committed (verified)
Commit 2e86aea · Parent: 9dd3e90

Update modeling_gemma.py

Files changed (1)
  1. modeling_gemma.py +1 -35
modeling_gemma.py CHANGED
@@ -48,6 +48,7 @@ from transformers.utils import (
 )
 from .configuration_gemma import GemmaConfig
 from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
+from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb
 
 logger = logging.get_logger(__name__)
 
@@ -188,41 +189,6 @@ class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
         cos, sin = super().forward(x, position_ids)
         return cos, sin
 
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
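
The commit drops the local copies of rotate_half and apply_rotary_pos_emb in favor of the implementation shipped with transformers; the removed bodies match the upstream helper, so behavior is unchanged while the checkpoint stops carrying a private copy that could drift from upstream fixes. A minimal usage sketch of the imported helper follows, assuming a transformers version that includes the Gemma model; the cos/sin construction mirrors the usual rotary-embedding recipe with base 10000, which is an assumption, not something shown in this diff.

    import torch
    from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb

    # Toy shapes: q/k are [batch_size, heads, seq_len, head_dim].
    batch, heads, seq_len, head_dim = 1, 8, 16, 256
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)

    # Build cos/sin the way a rotary embedding module typically does: outer
    # product of positions and inverse frequencies, duplicated to span the
    # full head_dim. base=10000.0 is the common default, assumed here.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                       # [seq_len, head_dim]
    cos = emb.cos()[None]                                         # [batch_size, seq_len, head_dim]
    sin = emb.sin()[None]

    # unsqueeze_dim=1 (the default) inserts a heads axis into cos/sin so they
    # broadcast against q/k; per the removed docstring, pass unsqueeze_dim=2
    # when q/k are laid out as [batch_size, seq_len, heads, head_dim].
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 16, 256]) for both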