LeroyDyer commited on
Commit
8a153c2
·
verified ·
1 Parent(s): a88942a

Update modeling_mistral.py

Browse files
Files changed (1) hide show
  1. modeling_mistral.py +32 -4
modeling_mistral.py CHANGED
@@ -475,14 +475,42 @@ class MistralAttention(nn.Module):
475
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
476
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
477
 
478
- self.rotary_emb = MistralRotaryEmbedding(
479
- self.head_dim,
480
- max_position_embeddings=self.max_position_embeddings,
481
- base=self.rope_theta,
482
  )
483
 
484
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
485
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
  def forward(
488
  self,
 
475
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
476
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
477
 
478
+ self._init_rope()
 
 
 
479
  )
480
 
481
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
482
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
483
+ def _init_rope(self):
484
+ if self.config.rope_scaling is None:
485
+ self.rotary_emb = MistralRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)
486
+ else:
487
+ scaling_type = self.config.rope_scaling["type"]
488
+ scaling_factor = self.config.rope_scaling["factor"]
489
+ if scaling_type == "linear":
490
+ self.rotary_emb = MistralLinearScalingRotaryEmbedding(
491
+ self.head_dim, max_position_embeddings=self.max_position_embeddings,
492
+ scaling_factor=scaling_factor, base=self.rope_theta,
493
+ )
494
+ elif scaling_type == "dynamic":
495
+ self.rotary_emb = MistralDynamicNTKScalingRotaryEmbedding(
496
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor,
497
+ base=self.rope_theta,
498
+ )
499
+ elif scaling_type == "yarn":
500
+ original_max_position_embeddings = self.config.rope_scaling["original_max_position_embeddings"]
501
+ self.rotary_emb = MistralYaRNScaledRotaryEmbedding(
502
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scale=scaling_factor,
503
+ original_max_position_embeddings=original_max_position_embeddings, base=self.rope_theta,
504
+ )
505
+ elif scaling_type == "dynamic-yarn":
506
+ original_max_position_embeddings = self.config.rope_scaling["original_max_position_embeddings"]
507
+ self.rotary_emb = MistralDynamicYaRNScaledRotaryEmbedding(
508
+ self.head_dim, max_position_embeddings=self.max_position_embeddings,
509
+ original_max_position_embeddings=original_max_position_embeddings, base=self.rope_theta,
510
+ )
511
+ else:
512
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
513
+
514
 
515
  def forward(
516
  self,