BK-Lee commited on
Commit
7e1e0aa
1 Parent(s): 15b745f
Files changed (1) hide show
  1. meteor/arch/modeling_internlm2.py +2 -2
meteor/arch/modeling_internlm2.py CHANGED
@@ -277,8 +277,8 @@ def rotate_half(x):
277
  # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
278
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
279
  """Applies Rotary Position Embedding to the query and key tensors."""
280
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
281
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
282
  q_embed = (q * cos) + (rotate_half(q) * sin)
283
  k_embed = (k * cos) + (rotate_half(k) * sin)
284
  return q_embed, k_embed
 
277
  # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
278
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
279
  """Applies Rotary Position Embedding to the query and key tensors."""
280
+ cos = cos.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
281
+ sin = sin.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
282
  q_embed = (q * cos) + (rotate_half(q) * sin)
283
  k_embed = (k * cos) + (rotate_half(k) * sin)
284
  return q_embed, k_embed