Spaces:
Sleeping
Sleeping
v1
Browse files
meteor/arch/modeling_internlm2.py
CHANGED
@@ -277,8 +277,8 @@ def rotate_half(x):
|
|
277 |
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
278 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; rows are selected with ``position_ids``.
        sin: Sine table; rows are selected with ``position_ids``.
        position_ids: Index tensor used to gather the cos/sin entries for
            each position. Its device decides where the gather runs.
        unsqueeze_dim: Dimension at which a singleton axis is inserted into
            the gathered cos/sin so they can broadcast against ``q`` and
            ``k``. (Presumably the heads axis -- confirm against the caller.)

    Returns:
        A ``(q_embed, k_embed)`` tuple holding the rotated query and key.
    """
    # Indexing a table with an index tensor that lives on another device can
    # raise (e.g. CPU table, CUDA indices), so move the tables onto the
    # device of position_ids before gathering.
    cos = cos.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
|
|
277 |
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
278 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate the query and key tensors with their rotary position embedding.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, gathered by ``position_ids``.
        sin: Sine table, gathered by ``position_ids``.
        position_ids: Index tensor picking the cos/sin entries per position;
            both tables are moved to its device before the gather.
        unsqueeze_dim: Where to insert the singleton axis in the gathered
            cos/sin so they broadcast against ``q`` and ``k``.

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    index_device = position_ids.device

    # Gather on the device that holds the indices, then add the broadcast axis.
    cos_sel = cos.to(index_device)[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin.to(index_device)[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same formula for query and key: t*cos + rotate_half(t)*sin.
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)
|