zifei9 committed
Commit 4ea6b44 · verified · Parent(s): 62f1e9b

Update modeling_llama.py

Files changed (1): modeling_llama.py (+1 -34)
modeling_llama.py CHANGED
@@ -50,6 +50,7 @@ from transformers.utils import (
 )
 from .configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
 if is_flash_attn_2_available():
@@ -158,40 +159,6 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
         return cos, sin
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
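
The commit drops the file's local copies of rotate_half and apply_rotary_pos_emb and imports the helper from transformers.models.llama.modeling_llama instead. As a quick illustration (not part of the commit), the minimal sketch below calls the imported helper with the tensor shapes described in the removed docstring; the batch/head/sequence sizes are made up for the example, and it assumes a transformers version whose apply_rotary_pos_emb keeps the (q, k, cos, sin, position_ids=None, unsqueeze_dim=1) signature shown above.

# Minimal sketch (not part of the commit): the upstream helper as a drop-in
# replacement for the removed local definition.
import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

batch, heads, seq_len, head_dim = 1, 2, 4, 8  # illustrative sizes only

# Query/key tensors shaped [batch_size, heads, seq_len, head_dim].
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# cos/sin shaped [batch_size, seq_len, head_dim]; the default unsqueeze_dim=1
# broadcasts them over the heads dimension, per the removed docstring.
cos = torch.randn(batch, seq_len, head_dim)
sin = torch.randn(batch, seq_len, head_dim)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert q_embed.shape == q.shape and k_embed.shape == k.shape

With q and k laid out as [batch_size, heads, seq_len, head_dim], the default unsqueeze_dim=1 is the right choice; a [batch_size, seq_len, heads, head_dim] layout would need unsqueeze_dim=2 instead.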