zifei9 committed
Commit d07dd25 · verified · 1 Parent(s): 1b9b030

Update modeling_mistral.py

Files changed (1)
  1. modeling_mistral.py +1 -35
modeling_mistral.py CHANGED
@@ -34,7 +34,7 @@ from transformers.utils import (
 )
 from .configuration_mistral import MistralConfig
 from transformers.models.mistral.modeling_mistral import MistralRMSNorm
-
+from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb
 
 logger = logging.get_logger(__name__)
 
@@ -58,40 +58,6 @@ class MistralMLP(nn.Module):
         return down_proj
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
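Note: this commit drops the local copies of rotate_half and apply_rotary_pos_emb in favor of the upstream implementation in transformers.models.mistral.modeling_mistral. The minimal sketch below shows how the imported apply_rotary_pos_emb is typically called; it is not part of the diff above, the attention call site is not shown in this commit, and all tensor names and shapes here are illustrative assumptions based on the stock Mistral attention layout.

import torch
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb

# Illustrative sizes (assumptions, not taken from this repository's config).
batch, n_heads, n_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64

# Query/key states as produced by the attention projections:
#   q: [batch, n_heads, seq_len, head_dim], k: [batch, n_kv_heads, seq_len, head_dim]
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Build cos/sin the way a rotary-embedding module would emit them: [batch, seq_len, head_dim].
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
positions = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)        # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)         # [seq_len, head_dim]
cos = emb.cos()[None, :, :].expand(batch, -1, -1)
sin = emb.sin()[None, :, :].expand(batch, -1, -1)

# unsqueeze_dim=1 inserts a heads axis so cos/sin broadcast over both q and k.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 16, 64]) torch.Size([2, 2, 16, 64])

Because the removed local helpers were byte-for-byte copies of the upstream functions, importing apply_rotary_pos_emb from transformers keeps behavior identical while removing the duplicated code.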